Upload 37 files
- Text2Human/configs/index_pred_net.yml +84 -0
- Text2Human/configs/parsing_gen.yml +40 -0
- Text2Human/configs/parsing_token.yml +47 -0
- Text2Human/configs/sample_from_parsing.yml +93 -0
- Text2Human/configs/sample_from_pose.yml +107 -0
- Text2Human/configs/sampler.yml +83 -0
- Text2Human/configs/vqvae_bottom.yml +72 -0
- Text2Human/configs/vqvae_top.yml +53 -0
- models/__init__.py +42 -0
- models/archs/__init__.py +0 -0
- models/archs/__pycache__/__init__.cpython-38.pyc +0 -0
- models/archs/__pycache__/fcn_arch.cpython-38.pyc +0 -0
- models/archs/__pycache__/shape_attr_embedding_arch.cpython-38.pyc +0 -0
- models/archs/__pycache__/transformer_arch.cpython-38.pyc +0 -0
- models/archs/__pycache__/unet_arch.cpython-38.pyc +0 -0
- models/archs/__pycache__/vqgan_arch.cpython-38.pyc +0 -0
- models/archs/fcn_arch.py +418 -0
- models/archs/shape_attr_embedding_arch.py +35 -0
- models/archs/transformer_arch.py +273 -0
- models/archs/unet_arch.py +693 -0
- models/archs/vqgan_arch.py +1203 -0
- models/hierarchy_inference_model.py +363 -0
- models/hierarchy_vqgan_model.py +374 -0
- models/losses/__init__.py +0 -0
- models/losses/__pycache__/__init__.cpython-38.pyc +0 -0
- models/losses/__pycache__/accuracy.cpython-38.pyc +0 -0
- models/losses/__pycache__/cross_entropy_loss.cpython-38.pyc +0 -0
- models/losses/__pycache__/segmentation_loss.cpython-38.pyc +0 -0
- models/losses/__pycache__/vqgan_loss.cpython-38.pyc +0 -0
- models/losses/accuracy.py +46 -0
- models/losses/cross_entropy_loss.py +246 -0
- models/losses/segmentation_loss.py +25 -0
- models/losses/vqgan_loss.py +114 -0
- models/parsing_gen_model.py +220 -0
- models/sample_model.py +498 -0
- models/transformer_model.py +482 -0
- models/vqgan_model.py +551 -0
Text2Human/configs/index_pred_net.yml
ADDED
@@ -0,0 +1,84 @@
name: index_prediction_network
use_tb_logger: true
set_CUDA_VISIBLE_DEVICES: ~
gpu_ids: [3]

# dataset configs
batch_size: 4
num_workers: 4
train_img_dir: ./datasets/train_images
test_img_dir: ./datasets/test_images
segm_dir: ./datasets/segm
pose_dir: ./datasets/densepose
train_ann_file: ./datasets/texture_ann/train
val_ann_file: ./datasets/texture_ann/val
test_ann_file: ./datasets/texture_ann/test
downsample_factor: 2

model_type: VQGANTextureAwareSpatialHierarchyInferenceModel
# network configs
embed_dim: 256
n_embed: 1024
codebook_spatial_size: 2

# bottom level vqvae
bot_n_embed: 512
bot_double_z: false
bot_z_channels: 256
bot_resolution: 512
bot_in_channels: 3
bot_out_ch: 3
bot_ch: 128
bot_ch_mult: [1, 1, 2, 4]
bot_num_res_blocks: 2
bot_attn_resolutions: [64]
bot_dropout: 0.0
bot_vae_path: ./pretrained_models/vqvae_bottom.pth

# top level vqgan
top_double_z: false
top_z_channels: 256
top_resolution: 512
top_in_channels: 3
top_out_ch: 3
top_ch: 128
top_ch_mult: [1, 1, 2, 2, 4]
top_num_res_blocks: 2
top_attn_resolutions: [32]
top_dropout: 0.0
top_vae_path: ./pretrained_models/vqvae_top.pth

# unet configs
encoder_in_channels: 256
fc_in_channels: 64
fc_in_index: 4
fc_channels: 64
fc_num_convs: 1
fc_concat_input: False
fc_dropout_ratio: 0.1
fc_num_classes: 512
fc_align_corners: False

disc_layers: 3
disc_weight_max: 1
disc_start_step: 30001
n_channels: 3
ndf: 64
nf: 128
perceptual_weight: 1.0

num_segm_classes: 24

# training configs
val_freq: 5
print_freq: 100
weight_decay: 0
manual_seed: 2021
num_epochs: 100
lr: !!float 1.0e-04
lr_decay: step
gamma: 1.0
step: 50
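These configs are plain YAML, so they can be loaded directly into the `opt` dict that the models expect. A minimal sketch (my own helper, not part of this upload; the name `load_config` is hypothetical):

import yaml

def load_config(path):
    # PyYAML is sufficient for these files. The explicit `!!float` tag on `lr`
    # guards against PyYAML's habit of reading exponent literals without a
    # decimal point (e.g. `1e-4`) as strings instead of floats.
    with open(path, 'r') as f:
        return yaml.safe_load(f)

opt = load_config('Text2Human/configs/index_pred_net.yml')
print(type(opt['lr']), opt['model_type'])  # <class 'float'> VQGANTextureAwareSpatialHierarchyInferenceModel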
Text2Human/configs/parsing_gen.yml
ADDED
@@ -0,0 +1,40 @@
name: parsing_generation
use_tb_logger: true
set_CUDA_VISIBLE_DEVICES: ~
gpu_ids: [3]

# dataset configs
batch_size: 8
num_workers: 4
segm_dir: ./datasets/segm
pose_dir: ./datasets/densepose
train_ann_file: ./datasets/shape_ann/train_ann_file.txt
val_ann_file: ./datasets/shape_ann/val_ann_file.txt
test_ann_file: ./datasets/shape_ann/test_ann_file.txt
downsample_factor: 2

model_type: ParsingGenModel
# network configs
embedder_dim: 8
embedder_out_dim: 128
attr_class_num: [2, 4, 6, 5, 4, 3, 5, 5, 3, 2, 2, 2, 2, 2, 2]
encoder_in_channels: 1
fc_in_channels: 64
fc_in_index: 4
fc_channels: 64
fc_num_convs: 1
fc_concat_input: False
fc_dropout_ratio: 0.1
fc_num_classes: 24
fc_align_corners: False

# training configs
val_freq: 5
print_freq: 100
weight_decay: 0
manual_seed: 2021
num_epochs: 100
lr: !!float 1e-4
lr_decay: step
gamma: 0.1
step: 50
Text2Human/configs/parsing_token.yml
ADDED
@@ -0,0 +1,47 @@
name: parsing_tokenization
use_tb_logger: true
set_CUDA_VISIBLE_DEVICES: ~
gpu_ids: [3]

# dataset configs
batch_size: 4
num_workers: 4
train_img_dir: ./datasets/train_images
test_img_dir: ./datasets/test_images
segm_dir: ./datasets/segm
pose_dir: ./datasets/densepose
train_ann_file: ./datasets/texture_ann/train
val_ann_file: ./datasets/texture_ann/val
test_ann_file: ./datasets/texture_ann/test
downsample_factor: 2

model_type: VQSegmentationModel
# network configs
embed_dim: 32
n_embed: 1024
image_key: "segmentation"
n_labels: 24
double_z: false
z_channels: 32
resolution: 512
in_channels: 24
out_ch: 24
ch: 64
ch_mult: [1, 1, 2, 2, 4]
num_res_blocks: 1
attn_resolutions: [16]
dropout: 0.0

num_segm_classes: 24


# training configs
val_freq: 5
print_freq: 100
weight_decay: 0
manual_seed: 2021
num_epochs: 100
lr: !!float 4.5e-05
lr_decay: step
gamma: 0.1
step: 50
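Since `in_channels` equals `n_labels` (24), the parsing map presumably enters the segmentation tokenizer as a 24-channel one-hot tensor rather than a single-channel label map. A minimal sketch of that conversion (my own assumption about the data pipeline, not code from this upload):

import torch
import torch.nn.functional as F

def segm_to_onehot(segm, n_labels=24):
    # segm: LongTensor of shape (B, H, W) with values in [0, n_labels)
    onehot = F.one_hot(segm, num_classes=n_labels)  # (B, H, W, 24)
    return onehot.permute(0, 3, 1, 2).float()       # (B, 24, H, W)

x = segm_to_onehot(torch.randint(0, 24, (2, 512, 256)))
print(x.shape)  # torch.Size([2, 24, 512, 256])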
Text2Human/configs/sample_from_parsing.yml
ADDED
@@ -0,0 +1,93 @@
name: sample_from_parsing
use_tb_logger: true
set_CUDA_VISIBLE_DEVICES: ~
gpu_ids: [3]

# dataset configs
batch_size: 4
num_workers: 4
test_img_dir: ./datasets/test_images
segm_dir: ./datasets/segm
pose_dir: ./datasets/densepose
test_ann_file: ./datasets/texture_ann/test
downsample_factor: 2

model_type: SampleFromParsingModel
# network configs
embed_dim: 256
n_embed: 1024
codebook_spatial_size: 2

# bottom level vqvae
bot_n_embed: 512
bot_codebook_spatial_size: 2
bot_double_z: false
bot_z_channels: 256
bot_resolution: 512
bot_in_channels: 3
bot_out_ch: 3
bot_ch: 128
bot_ch_mult: [1, 1, 2, 4]
bot_num_res_blocks: 2
bot_attn_resolutions: [64]
bot_dropout: 0.0
bot_vae_path: ./pretrained_models/vqvae_bottom.pth

# top level vqgan
top_double_z: false
top_z_channels: 256
top_resolution: 512
top_in_channels: 3
top_out_ch: 3
top_ch: 128
top_ch_mult: [1, 1, 2, 2, 4]
top_num_res_blocks: 2
top_attn_resolutions: [32]
top_dropout: 0.0
top_vae_path: ./pretrained_models/vqvae_top.pth

# unet configs
index_pred_encoder_in_channels: 256
index_pred_fc_in_channels: 64
index_pred_fc_in_index: 4
index_pred_fc_channels: 64
index_pred_fc_num_convs: 1
index_pred_fc_concat_input: False
index_pred_fc_dropout_ratio: 0.1
index_pred_fc_num_classes: 512
index_pred_fc_align_corners: False
pretrained_index_network: ./pretrained_models/index_pred_net.pth

# segmentation tokenization
segm_double_z: false
segm_z_channels: 32
segm_resolution: 512
segm_in_channels: 24
segm_out_ch: 24
segm_ch: 64
segm_ch_mult: [1, 1, 2, 2, 4]
segm_num_res_blocks: 1
segm_attn_resolutions: [16]
segm_dropout: 0.0
segm_num_segm_classes: 24
segm_n_embed: 1024
segm_embed_dim: 32
segm_token_path: ./pretrained_models/parsing_token.pth

# sampler configs
codebook_size: 18432
segm_codebook_size: 1024
texture_codebook_size: 18
bert_n_emb: 512
bert_n_layers: 24
bert_n_head: 8
block_size: 512 # 32 x 16
latent_shape: [32, 16]
embd_pdrop: 0.0
resid_pdrop: 0.0
attn_pdrop: 0.0
num_head: 18
pretrained_sampler: ./pretrained_models/sampler.pth

manual_seed: 2021
sample_steps: 256
Text2Human/configs/sample_from_pose.yml
ADDED
@@ -0,0 +1,107 @@
name: sample_from_pose
use_tb_logger: true
set_CUDA_VISIBLE_DEVICES: ~
gpu_ids: [3]

# dataset configs
batch_size: 4
num_workers: 4
pose_dir: ./datasets/densepose
texture_ann_file: ./datasets/texture_ann/test
shape_ann_path: ./datasets/shape_ann/test_ann_file.txt
downsample_factor: 2

model_type: SampleFromPoseModel
# network configs
embed_dim: 256
n_embed: 1024
codebook_spatial_size: 2

# bottom level vqgan
bot_n_embed: 512
bot_codebook_spatial_size: 2
bot_double_z: false
bot_z_channels: 256
bot_resolution: 512
bot_in_channels: 3
bot_out_ch: 3
bot_ch: 128
bot_ch_mult: [1, 1, 2, 4]
bot_num_res_blocks: 2
bot_attn_resolutions: [64]
bot_dropout: 0.0
bot_vae_path: ./pretrained_models/vqvae_bottom.pth

# top level vqgan
top_double_z: false
top_z_channels: 256
top_resolution: 512
top_in_channels: 3
top_out_ch: 3
top_ch: 128
top_ch_mult: [1, 1, 2, 2, 4]
top_num_res_blocks: 2
top_attn_resolutions: [32]
top_dropout: 0.0
top_vae_path: ./pretrained_models/vqvae_top.pth

# unet configs
index_pred_encoder_in_channels: 256
index_pred_fc_in_channels: 64
index_pred_fc_in_index: 4
index_pred_fc_channels: 64
index_pred_fc_num_convs: 1
index_pred_fc_concat_input: False
index_pred_fc_dropout_ratio: 0.1
index_pred_fc_num_classes: 512
index_pred_fc_align_corners: False
pretrained_index_network: ./pretrained_models/index_pred_net.pth

# segmentation tokenization
segm_double_z: false
segm_z_channels: 32
segm_resolution: 512
segm_in_channels: 24
segm_out_ch: 24
segm_ch: 64
segm_ch_mult: [1, 1, 2, 2, 4]
segm_num_res_blocks: 1
segm_attn_resolutions: [16]
segm_dropout: 0.0
segm_num_segm_classes: 24
segm_n_embed: 1024
segm_embed_dim: 32
segm_token_path: ./pretrained_models/parsing_token.pth

# sampler configs
codebook_size: 18432
segm_codebook_size: 1024
texture_codebook_size: 18
bert_n_emb: 512
bert_n_layers: 24
bert_n_head: 8
block_size: 512 # 32 x 16
latent_shape: [32, 16]
embd_pdrop: 0.0
resid_pdrop: 0.0
attn_pdrop: 0.0
num_head: 18
pretrained_sampler: ./pretrained_models/sampler.pth

# shape network configs
shape_embedder_dim: 8
shape_embedder_out_dim: 128
shape_attr_class_num: [2, 4, 6, 5, 4, 3, 5, 5, 3, 2, 2, 2, 2, 2, 2]
shape_encoder_in_channels: 1
shape_fc_in_channels: 64
shape_fc_in_index: 4
shape_fc_channels: 64
shape_fc_num_convs: 1
shape_fc_concat_input: False
shape_fc_dropout_ratio: 0.1
shape_fc_num_classes: 24
shape_fc_align_corners: False
pretrained_parsing_gen: ./pretrained_models/parsing_gen.pth

manual_seed: 2021
sample_steps: 256
Text2Human/configs/sampler.yml
ADDED
@@ -0,0 +1,83 @@
name: sampler
use_tb_logger: true
set_CUDA_VISIBLE_DEVICES: ~
gpu_ids: [3]

# dataset configs
batch_size: 4
num_workers: 1
train_img_dir: ./datasets/train_images
test_img_dir: ./datasets/test_images
segm_dir: ./datasets/segm
pose_dir: ./datasets/densepose
train_ann_file: ./datasets/texture_ann/train
val_ann_file: ./datasets/texture_ann/val
test_ann_file: ./datasets/texture_ann/test
downsample_factor: 2

# pretrained models
img_ae_path: ./pretrained_models/vqvae_top.pth
segm_ae_path: ./pretrained_models/parsing_token.pth

model_type: TransformerTextureAwareModel
# network configs

# image autoencoder
img_embed_dim: 256
img_n_embed: 1024
img_double_z: false
img_z_channels: 256
img_resolution: 512
img_in_channels: 3
img_out_ch: 3
img_ch: 128
img_ch_mult: [1, 1, 2, 2, 4]
img_num_res_blocks: 2
img_attn_resolutions: [32]
img_dropout: 0.0

# segmentation tokenization
segm_double_z: false
segm_z_channels: 32
segm_resolution: 512
segm_in_channels: 24
segm_out_ch: 24
segm_ch: 64
segm_ch_mult: [1, 1, 2, 2, 4]
segm_num_res_blocks: 1
segm_attn_resolutions: [16]
segm_dropout: 0.0
segm_num_segm_classes: 24
segm_n_embed: 1024
segm_embed_dim: 32

# sampler configs
codebook_size: 18432
segm_codebook_size: 1024
texture_codebook_size: 18
bert_n_emb: 512
bert_n_layers: 24
bert_n_head: 8
block_size: 512 # 32 x 16
latent_shape: [32, 16]
embd_pdrop: 0.0
resid_pdrop: 0.0
attn_pdrop: 0.0
num_head: 18

# loss configs
loss_type: reweighted_elbo
mask_schedule: random

sample_steps: 256

# training configs
val_freq: 5
print_freq: 100
weight_decay: 0
manual_seed: 2021
num_epochs: 100
lr: !!float 1e-4
lr_decay: step
gamma: 1.0
step: 50
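A quick consistency check on the sampler hyper-parameters (my own sketch, not part of the upload): the transformer's context length has to cover the 32 x 16 latent grid, and the texture-aware output splits the image codebook evenly across `num_head` heads.

import math

opt = {'block_size': 512, 'latent_shape': [32, 16],
       'codebook_size': 18432, 'num_head': 18}

assert opt['block_size'] == math.prod(opt['latent_shape'])  # 32 * 16 = 512
assert opt['codebook_size'] % opt['num_head'] == 0
print(opt['codebook_size'] // opt['num_head'])              # 1024 classes per head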
Text2Human/configs/vqvae_bottom.yml
ADDED
@@ -0,0 +1,72 @@
name: vqvae_bottom
use_tb_logger: true
set_CUDA_VISIBLE_DEVICES: ~
gpu_ids: [3]

# dataset configs
batch_size: 4
num_workers: 4
train_img_dir: ./datasets/train_images
test_img_dir: ./datasets/test_images
segm_dir: ./datasets/segm
pose_dir: ./datasets/densepose
train_ann_file: ./datasets/texture_ann/train
val_ann_file: ./datasets/texture_ann/val
test_ann_file: ./datasets/texture_ann/test
downsample_factor: 2

model_type: HierarchyVQSpatialTextureAwareModel
# network configs
embed_dim: 256
n_embed: 1024
codebook_spatial_size: 2

# bottom level vqvae
bot_n_embed: 512
bot_double_z: false
bot_z_channels: 256
bot_resolution: 512
bot_in_channels: 3
bot_out_ch: 3
bot_ch: 128
bot_ch_mult: [1, 1, 2, 4]
bot_num_res_blocks: 2
bot_attn_resolutions: [64]
bot_dropout: 0.0

# top level vqgan
top_double_z: false
top_z_channels: 256
top_resolution: 512
top_in_channels: 3
top_out_ch: 3
top_ch: 128
top_ch_mult: [1, 1, 2, 2, 4]
top_num_res_blocks: 2
top_attn_resolutions: [32]
top_dropout: 0.0
top_vae_path: ./pretrained_models/vqvae_top.pth

fix_decoder: false

disc_layers: 3
disc_weight_max: 1
disc_start_step: 1
n_channels: 3
ndf: 64
nf: 128
perceptual_weight: 1.0

num_segm_classes: 24

# training configs
val_freq: 5
print_freq: 100
weight_decay: 0
manual_seed: 2021
num_epochs: 1000
lr: !!float 1.0e-04
lr_decay: step
gamma: 1.0
step: 50
Text2Human/configs/vqvae_top.yml
ADDED
@@ -0,0 +1,53 @@
name: vqvae_top
use_tb_logger: true
set_CUDA_VISIBLE_DEVICES: ~
gpu_ids: [3]

# dataset configs
batch_size: 4
num_workers: 4
train_img_dir: ./datasets/train_images
test_img_dir: ./datasets/test_images
segm_dir: ./datasets/segm
pose_dir: ./datasets/densepose
train_ann_file: ./datasets/texture_ann/train
val_ann_file: ./datasets/texture_ann/val
test_ann_file: ./datasets/texture_ann/test
downsample_factor: 2

model_type: VQImageSegmTextureModel
# network configs
embed_dim: 256
n_embed: 1024
double_z: false
z_channels: 256
resolution: 512
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [1, 1, 2, 2, 4]
num_res_blocks: 2
attn_resolutions: [32]
dropout: 0.0

disc_layers: 3
disc_weight_max: 0
disc_start_step: 3000000000000000000000000001
n_channels: 3
ndf: 64
nf: 128
perceptual_weight: 1.0

num_segm_classes: 24


# training configs
val_freq: 5
print_freq: 100
weight_decay: 0
manual_seed: 2021
num_epochs: 1000
lr: !!float 1.0e-04
lr_decay: step
gamma: 1.0
step: 50
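A small arithmetic note on the top-level VQVAE (my own sketch; the 512 x 256 input size is an assumption based on the Text2Human image resolution, not stated in this config): with a `ch_mult` of length 5 the encoder downsamples by a factor of 2^(len(ch_mult) - 1) = 16 in each dimension, which is what produces the 32 x 16 token grid used by the sampler configs.

ch_mult = [1, 1, 2, 2, 4]
factor = 2 ** (len(ch_mult) - 1)     # 16
print(512 // factor, 256 // factor)  # 32 16  -> latent_shape [32, 16], block_size 512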
models/__init__.py
ADDED
@@ -0,0 +1,42 @@
import glob
import importlib
import logging
import os.path as osp

# automatically scan and import model modules
# scan all the files under the 'models' folder and collect files ending with
# '_model.py'
model_folder = osp.dirname(osp.abspath(__file__))
model_filenames = [
    osp.splitext(osp.basename(v))[0]
    for v in glob.glob(f'{model_folder}/*_model.py')
]
# import all the model modules
_model_modules = [
    importlib.import_module(f'models.{file_name}')
    for file_name in model_filenames
]


def create_model(opt):
    """Create model.

    Args:
        opt (dict): Configuration. It contains:
            model_type (str): Model type.
    """
    model_type = opt['model_type']

    # dynamic instantiation
    model_cls = None
    for module in _model_modules:
        model_cls = getattr(module, model_type, None)
        if model_cls is not None:
            break
    if model_cls is None:
        raise ValueError(f'Model {model_type} is not found.')

    model = model_cls(opt)

    logger = logging.getLogger('base')
    logger.info(f'Model [{model.__class__.__name__}] is created.')
    return model
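A minimal usage sketch for this factory (my own example; it assumes PyYAML is installed, the repository root is on the import path, and the config path is adjusted to the actual layout):

import yaml
from models import create_model

# `model_type` in the YAML selects the class (e.g. ParsingGenModel) from the
# auto-imported *_model.py modules.
with open('Text2Human/configs/parsing_gen.yml', 'r') as f:
    opt = yaml.safe_load(f)

model = create_model(opt)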
models/archs/__init__.py
ADDED
File without changes
models/archs/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (126 Bytes)
models/archs/__pycache__/fcn_arch.cpython-38.pyc
ADDED
Binary file (10.5 kB)
models/archs/__pycache__/shape_attr_embedding_arch.cpython-38.pyc
ADDED
Binary file (1.33 kB)
models/archs/__pycache__/transformer_arch.cpython-38.pyc
ADDED
Binary file (7.61 kB)
models/archs/__pycache__/unet_arch.cpython-38.pyc
ADDED
Binary file (21.9 kB)
models/archs/__pycache__/vqgan_arch.cpython-38.pyc
ADDED
Binary file (24.5 kB)
models/archs/fcn_arch.py
ADDED
@@ -0,0 +1,418 @@
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule, normal_init
from mmseg.ops import resize


class BaseDecodeHead(nn.Module):
    """Base class for BaseDecodeHead.

    Args:
        in_channels (int|Sequence[int]): Input channels.
        channels (int): Channels after modules, before conv_seg.
        num_classes (int): Number of classes.
        dropout_ratio (float): Ratio of dropout layer. Default: 0.1.
        conv_cfg (dict|None): Config of conv layers. Default: None.
        norm_cfg (dict|None): Config of norm layers. Default: None.
        act_cfg (dict): Config of activation layers.
            Default: dict(type='ReLU')
        in_index (int|Sequence[int]): Input feature index. Default: -1
        input_transform (str|None): Transformation type of input features.
            Options: 'resize_concat', 'multiple_select', None.
            'resize_concat': Multiple feature maps will be resized to the
                same size as the first one and then concatenated together.
                Usually used in FCN head of HRNet.
            'multiple_select': Multiple feature maps will be bundled into
                a list and passed into decode head.
            None: Only one select feature map is allowed.
            Default: None.
        loss_decode (dict): Config of decode loss.
            Default: dict(type='CrossEntropyLoss').
        ignore_index (int | None): The label index to be ignored. When using
            masked BCE loss, ignore_index should be set to None. Default: 255
        sampler (dict|None): The config of segmentation map sampler.
            Default: None.
        align_corners (bool): align_corners argument of F.interpolate.
            Default: False.
    """

    def __init__(self,
                 in_channels,
                 channels,
                 *,
                 num_classes,
                 dropout_ratio=0.1,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 in_index=-1,
                 input_transform=None,
                 ignore_index=255,
                 align_corners=False):
        super(BaseDecodeHead, self).__init__()
        self._init_inputs(in_channels, in_index, input_transform)
        self.channels = channels
        self.num_classes = num_classes
        self.dropout_ratio = dropout_ratio
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.in_index = in_index

        self.ignore_index = ignore_index
        self.align_corners = align_corners

        self.conv_seg = nn.Conv2d(channels, num_classes, kernel_size=1)
        if dropout_ratio > 0:
            self.dropout = nn.Dropout2d(dropout_ratio)
        else:
            self.dropout = None

    def extra_repr(self):
        """Extra repr."""
        s = f'input_transform={self.input_transform}, ' \
            f'ignore_index={self.ignore_index}, ' \
            f'align_corners={self.align_corners}'
        return s

    def _init_inputs(self, in_channels, in_index, input_transform):
        """Check and initialize input transforms.

        The in_channels, in_index and input_transform must match.
        Specifically, when input_transform is None, only a single feature map
        will be selected, so in_channels and in_index must be of type int.
        When input_transform is not None, in_channels and in_index must be
        list or tuple, with the same length.

        Args:
            in_channels (int|Sequence[int]): Input channels.
            in_index (int|Sequence[int]): Input feature index.
            input_transform (str|None): Transformation type of input features.
                Options: 'resize_concat', 'multiple_select', None.
                'resize_concat': Multiple feature maps will be resized to the
                    same size as the first one and then concatenated together.
                    Usually used in FCN head of HRNet.
                'multiple_select': Multiple feature maps will be bundled into
                    a list and passed into decode head.
                None: Only one select feature map is allowed.
        """

        if input_transform is not None:
            assert input_transform in ['resize_concat', 'multiple_select']
        self.input_transform = input_transform
        self.in_index = in_index
        if input_transform is not None:
            assert isinstance(in_channels, (list, tuple))
            assert isinstance(in_index, (list, tuple))
            assert len(in_channels) == len(in_index)
            if input_transform == 'resize_concat':
                self.in_channels = sum(in_channels)
            else:
                self.in_channels = in_channels
        else:
            assert isinstance(in_channels, int)
            assert isinstance(in_index, int)
            self.in_channels = in_channels

    def init_weights(self):
        """Initialize weights of classification layer."""
        normal_init(self.conv_seg, mean=0, std=0.01)

    def _transform_inputs(self, inputs):
        """Transform inputs for decoder.

        Args:
            inputs (list[Tensor]): List of multi-level img features.

        Returns:
            Tensor: The transformed inputs
        """

        if self.input_transform == 'resize_concat':
            inputs = [inputs[i] for i in self.in_index]
            upsampled_inputs = [
                resize(
                    input=x,
                    size=inputs[0].shape[2:],
                    mode='bilinear',
                    align_corners=self.align_corners) for x in inputs
            ]
            inputs = torch.cat(upsampled_inputs, dim=1)
        elif self.input_transform == 'multiple_select':
            inputs = [inputs[i] for i in self.in_index]
        else:
            inputs = inputs[self.in_index]

        return inputs

    def forward(self, inputs):
        """Placeholder of forward function."""
        pass

    def cls_seg(self, feat):
        """Classify each pixel."""
        if self.dropout is not None:
            feat = self.dropout(feat)
        output = self.conv_seg(feat)
        return output


class FCNHead(BaseDecodeHead):
    """Fully Convolutional Networks for Semantic Segmentation.

    This head is an implementation of `FCNNet <https://arxiv.org/abs/1411.4038>`_.

    Args:
        num_convs (int): Number of convs in the head. Default: 2.
        kernel_size (int): The kernel size for convs in the head. Default: 3.
        concat_input (bool): Whether to concat the input and output of convs
            before the classification layer.
    """

    def __init__(self,
                 num_convs=2,
                 kernel_size=3,
                 concat_input=True,
                 **kwargs):
        assert num_convs >= 0
        self.num_convs = num_convs
        self.concat_input = concat_input
        self.kernel_size = kernel_size
        super(FCNHead, self).__init__(**kwargs)
        if num_convs == 0:
            assert self.in_channels == self.channels

        convs = []
        convs.append(
            ConvModule(
                self.in_channels,
                self.channels,
                kernel_size=kernel_size,
                padding=kernel_size // 2,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg))
        for i in range(num_convs - 1):
            convs.append(
                ConvModule(
                    self.channels,
                    self.channels,
                    kernel_size=kernel_size,
                    padding=kernel_size // 2,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg))
        if num_convs == 0:
            self.convs = nn.Identity()
        else:
            self.convs = nn.Sequential(*convs)
        if self.concat_input:
            self.conv_cat = ConvModule(
                self.in_channels + self.channels,
                self.channels,
                kernel_size=kernel_size,
                padding=kernel_size // 2,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg)

    def forward(self, inputs):
        """Forward function."""
        x = self._transform_inputs(inputs)
        output = self.convs(x)
        if self.concat_input:
            output = self.conv_cat(torch.cat([x, output], dim=1))
        output = self.cls_seg(output)
        return output


class MultiHeadFCNHead(nn.Module):
    """Fully Convolutional Networks for Semantic Segmentation.

    This head is an implementation of `FCNNet <https://arxiv.org/abs/1411.4038>`_.

    Args:
        num_convs (int): Number of convs in the head. Default: 2.
        kernel_size (int): The kernel size for convs in the head. Default: 3.
        concat_input (bool): Whether to concat the input and output of convs
            before the classification layer.
    """

    def __init__(self,
                 in_channels,
                 channels,
                 *,
                 num_classes,
                 dropout_ratio=0.1,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 in_index=-1,
                 input_transform=None,
                 ignore_index=255,
                 align_corners=False,
                 num_convs=2,
                 kernel_size=3,
                 concat_input=True,
                 num_head=18,
                 **kwargs):
        super(MultiHeadFCNHead, self).__init__()
        assert num_convs >= 0
        self.num_convs = num_convs
        self.concat_input = concat_input
        self.kernel_size = kernel_size
        self._init_inputs(in_channels, in_index, input_transform)
        self.channels = channels
        self.num_classes = num_classes
        self.dropout_ratio = dropout_ratio
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.in_index = in_index
        self.num_head = num_head

        self.ignore_index = ignore_index
        self.align_corners = align_corners

        if dropout_ratio > 0:
            self.dropout = nn.Dropout2d(dropout_ratio)

        conv_seg_head_list = []
        for _ in range(self.num_head):
            conv_seg_head_list.append(
                nn.Conv2d(channels, num_classes, kernel_size=1))

        self.conv_seg_head_list = nn.ModuleList(conv_seg_head_list)

        self.init_weights()

        if num_convs == 0:
            assert self.in_channels == self.channels

        convs_list = []
        conv_cat_list = []

        for _ in range(self.num_head):
            convs = []
            convs.append(
                ConvModule(
                    self.in_channels,
                    self.channels,
                    kernel_size=kernel_size,
                    padding=kernel_size // 2,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg))
            for _ in range(num_convs - 1):
                convs.append(
                    ConvModule(
                        self.channels,
                        self.channels,
                        kernel_size=kernel_size,
                        padding=kernel_size // 2,
                        conv_cfg=self.conv_cfg,
                        norm_cfg=self.norm_cfg,
                        act_cfg=self.act_cfg))
            if num_convs == 0:
                convs_list.append(nn.Identity())
            else:
                convs_list.append(nn.Sequential(*convs))
            if self.concat_input:
                conv_cat_list.append(
                    ConvModule(
                        self.in_channels + self.channels,
                        self.channels,
                        kernel_size=kernel_size,
                        padding=kernel_size // 2,
                        conv_cfg=self.conv_cfg,
                        norm_cfg=self.norm_cfg,
                        act_cfg=self.act_cfg))

        self.convs_list = nn.ModuleList(convs_list)
        self.conv_cat_list = nn.ModuleList(conv_cat_list)

    def forward(self, inputs):
        """Forward function."""
        x = self._transform_inputs(inputs)

        output_list = []
        for head_idx in range(self.num_head):
            output = self.convs_list[head_idx](x)
            if self.concat_input:
                output = self.conv_cat_list[head_idx](
                    torch.cat([x, output], dim=1))
            if self.dropout is not None:
                output = self.dropout(output)
            output = self.conv_seg_head_list[head_idx](output)
            output_list.append(output)

        return output_list

    def _init_inputs(self, in_channels, in_index, input_transform):
        """Check and initialize input transforms.

        The in_channels, in_index and input_transform must match.
        Specifically, when input_transform is None, only a single feature map
        will be selected, so in_channels and in_index must be of type int.
        When input_transform is not None, in_channels and in_index must be
        list or tuple, with the same length.

        Args:
            in_channels (int|Sequence[int]): Input channels.
            in_index (int|Sequence[int]): Input feature index.
            input_transform (str|None): Transformation type of input features.
                Options: 'resize_concat', 'multiple_select', None.
                'resize_concat': Multiple feature maps will be resized to the
                    same size as the first one and then concatenated together.
                    Usually used in FCN head of HRNet.
                'multiple_select': Multiple feature maps will be bundled into
                    a list and passed into decode head.
                None: Only one select feature map is allowed.
        """

        if input_transform is not None:
            assert input_transform in ['resize_concat', 'multiple_select']
        self.input_transform = input_transform
        self.in_index = in_index
        if input_transform is not None:
            assert isinstance(in_channels, (list, tuple))
            assert isinstance(in_index, (list, tuple))
            assert len(in_channels) == len(in_index)
            if input_transform == 'resize_concat':
                self.in_channels = sum(in_channels)
            else:
                self.in_channels = in_channels
        else:
            assert isinstance(in_channels, int)
            assert isinstance(in_index, int)
            self.in_channels = in_channels

    def init_weights(self):
        """Initialize weights of classification layer."""
        for conv_seg_head in self.conv_seg_head_list:
            normal_init(conv_seg_head, mean=0, std=0.01)

    def _transform_inputs(self, inputs):
        """Transform inputs for decoder.

        Args:
            inputs (list[Tensor]): List of multi-level img features.

        Returns:
            Tensor: The transformed inputs
        """

        if self.input_transform == 'resize_concat':
            inputs = [inputs[i] for i in self.in_index]
            upsampled_inputs = [
                resize(
                    input=x,
                    size=inputs[0].shape[2:],
                    mode='bilinear',
                    align_corners=self.align_corners) for x in inputs
            ]
            inputs = torch.cat(upsampled_inputs, dim=1)
        elif self.input_transform == 'multiple_select':
            inputs = [inputs[i] for i in self.in_index]
        else:
            inputs = inputs[self.in_index]

        return inputs
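One plausible wiring of the `fc_*` keys from index_pred_net.yml onto this multi-head decoder (my own sketch, not taken from the repo's model code; it assumes mmcv and mmseg are installed):

from models.archs.fcn_arch import MultiHeadFCNHead

index_pred_head = MultiHeadFCNHead(
    in_channels=64,       # fc_in_channels
    channels=64,          # fc_channels
    num_classes=512,      # fc_num_classes: one logit per bottom-codebook entry
    in_index=4,           # fc_in_index: pick the fifth feature map from the encoder
    num_convs=1,          # fc_num_convs
    concat_input=False,   # fc_concat_input
    dropout_ratio=0.1,    # fc_dropout_ratio
    align_corners=False,  # fc_align_corners
    num_head=18)          # num_head from the sampler section
print(len(index_pred_head.conv_seg_head_list))  # 18 parallel classifiers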
models/archs/shape_attr_embedding_arch.py
ADDED
@@ -0,0 +1,35 @@
import torch
import torch.nn.functional as F
from torch import nn


class ShapeAttrEmbedding(nn.Module):

    def __init__(self, dim, out_dim, cls_num_list):
        super(ShapeAttrEmbedding, self).__init__()

        for idx, cls_num in enumerate(cls_num_list):
            setattr(
                self, f'attr_{idx}',
                nn.Sequential(
                    nn.Linear(cls_num, dim), nn.LeakyReLU(),
                    nn.Linear(dim, dim)))
        self.cls_num_list = cls_num_list
        self.attr_num = len(cls_num_list)
        self.fusion = nn.Sequential(
            nn.Linear(dim * self.attr_num, out_dim), nn.LeakyReLU(),
            nn.Linear(out_dim, out_dim))

    def forward(self, attr):
        attr_embedding_list = []
        for idx in range(self.attr_num):
            attr_embed_fc = getattr(self, f'attr_{idx}')
            attr_embedding_list.append(
                attr_embed_fc(
                    F.one_hot(
                        attr[:, idx],
                        num_classes=self.cls_num_list[idx]).to(torch.float32)))
        attr_embedding = torch.cat(attr_embedding_list, dim=1)
        attr_embedding = self.fusion(attr_embedding)

        return attr_embedding
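A quick usage example (my own, using the values from parsing_gen.yml: embedder_dim 8, embedder_out_dim 128, and the 15-entry attr_class_num list):

import torch
from models.archs.shape_attr_embedding_arch import ShapeAttrEmbedding

attr_class_num = [2, 4, 6, 5, 4, 3, 5, 5, 3, 2, 2, 2, 2, 2, 2]
embedder = ShapeAttrEmbedding(dim=8, out_dim=128, cls_num_list=attr_class_num)

# one random label per attribute, bounded by that attribute's class count
attr = torch.stack([torch.randint(0, n, (4,)) for n in attr_class_num], dim=1)
print(embedder(attr).shape)  # torch.Size([4, 128])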
models/archs/transformer_arch.py
ADDED
@@ -0,0 +1,273 @@
import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class CausalSelfAttention(nn.Module):
    """
    A vanilla multi-head masked self-attention layer with a projection at the end.
    It is possible to use torch.nn.MultiheadAttention here but I am including an
    explicit implementation here to show that there is nothing too scary here.
    """

    def __init__(self, bert_n_emb, bert_n_head, attn_pdrop, resid_pdrop,
                 latent_shape, sampler):
        super().__init__()
        assert bert_n_emb % bert_n_head == 0
        # key, query, value projections for all heads
        self.key = nn.Linear(bert_n_emb, bert_n_emb)
        self.query = nn.Linear(bert_n_emb, bert_n_emb)
        self.value = nn.Linear(bert_n_emb, bert_n_emb)
        # regularization
        self.attn_drop = nn.Dropout(attn_pdrop)
        self.resid_drop = nn.Dropout(resid_pdrop)
        # output projection
        self.proj = nn.Linear(bert_n_emb, bert_n_emb)
        self.n_head = bert_n_head
        self.causal = True if sampler == 'autoregressive' else False
        if self.causal:
            block_size = np.prod(latent_shape)
            mask = torch.tril(torch.ones(block_size, block_size))
            self.register_buffer("mask", mask.view(1, 1, block_size,
                                                   block_size))

    def forward(self, x, layer_past=None):
        B, T, C = x.size()

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        k = self.key(x).view(B, T, self.n_head,
                             C // self.n_head).transpose(1,
                                                         2)  # (B, nh, T, hs)
        q = self.query(x).view(B, T, self.n_head,
                               C // self.n_head).transpose(1,
                                                           2)  # (B, nh, T, hs)
        v = self.value(x).view(B, T, self.n_head,
                               C // self.n_head).transpose(1,
                                                           2)  # (B, nh, T, hs)

        present = torch.stack((k, v))
        if self.causal and layer_past is not None:
            past_key, past_value = layer_past
            k = torch.cat((past_key, k), dim=-2)
            v = torch.cat((past_value, v), dim=-2)

        # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))

        if self.causal and layer_past is None:
            att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float('-inf'))

        att = F.softmax(att, dim=-1)
        att = self.attn_drop(att)
        y = att @ v  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        # re-assemble all head outputs side by side
        y = y.transpose(1, 2).contiguous().view(B, T, C)

        # output projection
        y = self.resid_drop(self.proj(y))
        return y, present


class Block(nn.Module):
    """ an unassuming Transformer block """

    def __init__(self, bert_n_emb, resid_pdrop, bert_n_head, attn_pdrop,
                 latent_shape, sampler):
        super().__init__()
        self.ln1 = nn.LayerNorm(bert_n_emb)
        self.ln2 = nn.LayerNorm(bert_n_emb)
        self.attn = CausalSelfAttention(bert_n_emb, bert_n_head, attn_pdrop,
                                        resid_pdrop, latent_shape, sampler)
        self.mlp = nn.Sequential(
            nn.Linear(bert_n_emb, 4 * bert_n_emb),
            nn.GELU(),  # nice
            nn.Linear(4 * bert_n_emb, bert_n_emb),
            nn.Dropout(resid_pdrop),
        )

    def forward(self, x, layer_past=None, return_present=False):

        attn, present = self.attn(self.ln1(x), layer_past)
        x = x + attn
        x = x + self.mlp(self.ln2(x))

        if layer_past is not None or return_present:
            return x, present
        return x


class Transformer(nn.Module):
    """ the full GPT language model, with a context size of block_size """

    def __init__(self,
                 codebook_size,
                 segm_codebook_size,
                 bert_n_emb,
                 bert_n_layers,
                 bert_n_head,
                 block_size,
                 latent_shape,
                 embd_pdrop,
                 resid_pdrop,
                 attn_pdrop,
                 sampler='absorbing'):
        super().__init__()

        self.vocab_size = codebook_size + 1
        self.n_embd = bert_n_emb
        self.block_size = block_size
        self.n_layers = bert_n_layers
        self.codebook_size = codebook_size
        self.segm_codebook_size = segm_codebook_size
        self.causal = sampler == 'autoregressive'
        if self.causal:
            self.vocab_size = codebook_size

        self.tok_emb = nn.Embedding(self.vocab_size, self.n_embd)
        self.pos_emb = nn.Parameter(
            torch.zeros(1, self.block_size, self.n_embd))
        self.segm_emb = nn.Embedding(self.segm_codebook_size, self.n_embd)
        self.start_tok = nn.Parameter(torch.zeros(1, 1, self.n_embd))
        self.drop = nn.Dropout(embd_pdrop)

        # transformer
        self.blocks = nn.Sequential(*[
            Block(bert_n_emb, resid_pdrop, bert_n_head, attn_pdrop,
                  latent_shape, sampler) for _ in range(self.n_layers)
        ])
        # decoder head
        self.ln_f = nn.LayerNorm(self.n_embd)
        self.head = nn.Linear(self.n_embd, self.codebook_size, bias=False)

    def get_block_size(self):
        return self.block_size

    def _init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def forward(self, idx, segm_tokens, t=None):
        # each index maps to a (learnable) vector
        token_embeddings = self.tok_emb(idx)

        segm_embeddings = self.segm_emb(segm_tokens)

        if self.causal:
            token_embeddings = torch.cat((self.start_tok.repeat(
                token_embeddings.size(0), 1, 1), token_embeddings),
                                         dim=1)

        t = token_embeddings.shape[1]
        assert t <= self.block_size, "Cannot forward, model block size is exhausted."
        # each position maps to a (learnable) vector

        position_embeddings = self.pos_emb[:, :t, :]

        x = token_embeddings + position_embeddings + segm_embeddings
        x = self.drop(x)
        for block in self.blocks:
            x = block(x)
        x = self.ln_f(x)
        logits = self.head(x)

        return logits


class TransformerMultiHead(nn.Module):
    """ the full GPT language model, with a context size of block_size """

    def __init__(self,
                 codebook_size,
                 segm_codebook_size,
                 texture_codebook_size,
                 bert_n_emb,
                 bert_n_layers,
                 bert_n_head,
                 block_size,
                 latent_shape,
                 embd_pdrop,
                 resid_pdrop,
                 attn_pdrop,
                 num_head,
                 sampler='absorbing'):
        super().__init__()

        self.vocab_size = codebook_size + 1
        self.n_embd = bert_n_emb
        self.block_size = block_size
        self.n_layers = bert_n_layers
        self.codebook_size = codebook_size
        self.segm_codebook_size = segm_codebook_size
        self.texture_codebook_size = texture_codebook_size
        self.causal = sampler == 'autoregressive'
        if self.causal:
            self.vocab_size = codebook_size

        self.tok_emb = nn.Embedding(self.vocab_size, self.n_embd)
        self.pos_emb = nn.Parameter(
            torch.zeros(1, self.block_size, self.n_embd))
        self.segm_emb = nn.Embedding(self.segm_codebook_size, self.n_embd)
        self.texture_emb = nn.Embedding(self.texture_codebook_size,
                                        self.n_embd)
        self.start_tok = nn.Parameter(torch.zeros(1, 1, self.n_embd))
        self.drop = nn.Dropout(embd_pdrop)

        # transformer
        self.blocks = nn.Sequential(*[
            Block(bert_n_emb, resid_pdrop, bert_n_head, attn_pdrop,
                  latent_shape, sampler) for _ in range(self.n_layers)
        ])
        # decoder head
        self.num_head = num_head
        self.head_class_num = codebook_size // self.num_head
        self.ln_f = nn.LayerNorm(self.n_embd)
        self.head_list = nn.ModuleList([
            nn.Linear(self.n_embd, self.head_class_num, bias=False)
            for _ in range(self.num_head)
        ])

    def get_block_size(self):
        return self.block_size

    def _init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def forward(self, idx, segm_tokens, texture_tokens, t=None):
        # each index maps to a (learnable) vector
        token_embeddings = self.tok_emb(idx)
        segm_embeddings = self.segm_emb(segm_tokens)
        texture_embeddings = self.texture_emb(texture_tokens)

        if self.causal:
            token_embeddings = torch.cat((self.start_tok.repeat(
                token_embeddings.size(0), 1, 1), token_embeddings),
                                         dim=1)

        t = token_embeddings.shape[1]
        assert t <= self.block_size, "Cannot forward, model block size is exhausted."
        # each position maps to a (learnable) vector

        position_embeddings = self.pos_emb[:, :t, :]

        x = token_embeddings + position_embeddings + segm_embeddings + texture_embeddings
        x = self.drop(x)
        for block in self.blocks:
            x = block(x)
        x = self.ln_f(x)
        logits_list = [self.head_list[i](x) for i in range(self.num_head)]

        return logits_list
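A minimal smoke test for the non-causal default (`sampler='absorbing'`), wired with the values from sampler.yml (my own sketch; the extra vocabulary entry is presumably a mask token, since `vocab_size = codebook_size + 1` in the non-causal branch):

import torch
from models.archs.transformer_arch import Transformer

model = Transformer(
    codebook_size=18432, segm_codebook_size=1024,
    bert_n_emb=512, bert_n_layers=24, bert_n_head=8,
    block_size=512, latent_shape=[32, 16],
    embd_pdrop=0.0, resid_pdrop=0.0, attn_pdrop=0.0,
    sampler='absorbing')

idx = torch.randint(0, 18432 + 1, (1, 512))  # image tokens; id 18432 is the extra entry
segm = torch.randint(0, 1024, (1, 512))      # parsing tokens from the segmentation VQVAE
logits = model(idx, segm)
print(logits.shape)  # torch.Size([1, 512, 18432])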
models/archs/unet_arch.py
ADDED
@@ -0,0 +1,693 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch
import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn import (UPSAMPLE_LAYERS, ConvModule, build_activation_layer,
                      build_norm_layer, build_upsample_layer, constant_init,
                      kaiming_init)
from mmcv.runner import load_checkpoint
from mmcv.utils.parrots_wrapper import _BatchNorm
from mmseg.utils import get_root_logger


class UpConvBlock(nn.Module):
    """Upsample convolution block in decoder for UNet.

    This upsample convolution block consists of one upsample module
    followed by one convolution block. The upsample module expands the
    high-level low-resolution feature map and the convolution block fuses
    the upsampled high-level low-resolution feature map and the low-level
    high-resolution feature map from encoder.

    Args:
        conv_block (nn.Sequential): Sequential of convolutional layers.
        in_channels (int): Number of input channels of the high-level
        skip_channels (int): Number of input channels of the low-level
            high-resolution feature map from encoder.
        out_channels (int): Number of output channels.
        num_convs (int): Number of convolutional layers in the conv_block.
            Default: 2.
        stride (int): Stride of convolutional layer in conv_block. Default: 1.
        dilation (int): Dilation rate of convolutional layer in conv_block.
            Default: 1.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
        conv_cfg (dict | None): Config dict for convolution layer.
            Default: None.
        norm_cfg (dict | None): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict | None): Config dict for activation layer in ConvModule.
            Default: dict(type='ReLU').
        upsample_cfg (dict): The upsample config of the upsample module in
            decoder. Default: dict(type='InterpConv'). If the size of
            high-level feature map is the same as that of skip feature map
            (low-level feature map from encoder), it does not need upsample the
            high-level feature map and the upsample_cfg is None.
        dcn (bool): Use deformable convolution in convolutional layer or not.
            Default: None.
        plugins (dict): plugins for convolutional layers. Default: None.
    """

    def __init__(self,
                 conv_block,
                 in_channels,
                 skip_channels,
                 out_channels,
                 num_convs=2,
                 stride=1,
                 dilation=1,
                 with_cp=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 upsample_cfg=dict(type='InterpConv'),
                 dcn=None,
                 plugins=None):
        super(UpConvBlock, self).__init__()
        assert dcn is None, 'Not implemented yet.'
        assert plugins is None, 'Not implemented yet.'

        self.conv_block = conv_block(
            in_channels=2 * skip_channels,
            out_channels=out_channels,
            num_convs=num_convs,
            stride=stride,
            dilation=dilation,
            with_cp=with_cp,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            dcn=None,
            plugins=None)
        if upsample_cfg is not None:
            self.upsample = build_upsample_layer(
                cfg=upsample_cfg,
                in_channels=in_channels,
                out_channels=skip_channels,
                with_cp=with_cp,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg)
        else:
            self.upsample = ConvModule(
                in_channels,
                skip_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg)

    def forward(self, skip, x):
        """Forward function."""

        x = self.upsample(x)
        out = torch.cat([skip, x], dim=1)
        out = self.conv_block(out)

        return out


class BasicConvBlock(nn.Module):
    """Basic convolutional block for UNet.

    This module consists of several plain convolutional layers.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        num_convs (int): Number of convolutional layers. Default: 2.
        stride (int): Whether use stride convolution to downsample
            the input feature map. If stride=2, it only uses stride convolution
            in the first convolutional layer to downsample the input feature
            map. Options are 1 or 2. Default: 1.
        dilation (int): Whether use dilated convolution to expand the
            receptive field. Set dilation rate of each convolutional layer and
            the dilation rate of the first convolutional layer is always 1.
            Default: 1.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
        conv_cfg (dict | None): Config dict for convolution layer.
            Default: None.
        norm_cfg (dict | None): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict | None): Config dict for activation layer in ConvModule.
            Default: dict(type='ReLU').
        dcn (bool): Use deformable convolution in convolutional layer or not.
            Default: None.
        plugins (dict): plugins for convolutional layers. Default: None.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 num_convs=2,
                 stride=1,
                 dilation=1,
                 with_cp=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 dcn=None,
                 plugins=None):
        super(BasicConvBlock, self).__init__()
        assert dcn is None, 'Not implemented yet.'
        assert plugins is None, 'Not implemented yet.'

        self.with_cp = with_cp
        convs = []
        for i in range(num_convs):
            convs.append(
                ConvModule(
                    in_channels=in_channels if i == 0 else out_channels,
                    out_channels=out_channels,
                    kernel_size=3,
                    stride=stride if i == 0 else 1,
                    dilation=1 if i == 0 else dilation,
                    padding=1 if i == 0 else dilation,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg))

        self.convs = nn.Sequential(*convs)

    def forward(self, x):
        """Forward function."""

        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(self.convs, x)
        else:
            out = self.convs(x)
        return out


class DeconvModule(nn.Module):
    """Deconvolution upsample module in decoder for UNet (2X upsample).

    This module uses deconvolution to upsample feature map in the decoder
    of UNet.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
        norm_cfg (dict | None): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict | None): Config dict for activation layer in ConvModule.
            Default: dict(type='ReLU').
        kernel_size (int): Kernel size of the convolutional layer. Default: 4.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 with_cp=False,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 *,
                 kernel_size=4,
                 scale_factor=2):
        super(DeconvModule, self).__init__()

        assert (kernel_size - scale_factor >= 0) and\
            (kernel_size - scale_factor) % 2 == 0,\
            f'kernel_size should be greater than or equal to scale_factor '\
            f'and (kernel_size - scale_factor) should be even numbers, '\
            f'while the kernel size is {kernel_size} and scale_factor is '\
            f'{scale_factor}.'

        stride = scale_factor
        padding = (kernel_size - scale_factor) // 2
        self.with_cp = with_cp
        deconv = nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding)

        norm_name, norm = build_norm_layer(norm_cfg, out_channels)
        activate = build_activation_layer(act_cfg)
        self.deconv_upsamping = nn.Sequential(deconv, norm, activate)

    def forward(self, x):
        """Forward function."""

        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(self.deconv_upsamping, x)
        else:
            out = self.deconv_upsamping(x)
        return out


@UPSAMPLE_LAYERS.register_module()
class InterpConv(nn.Module):
    """Interpolation upsample module in decoder for UNet.

    This module uses interpolation to upsample feature map in the decoder
    of UNet. It consists of one interpolation upsample layer and one
    convolutional layer. It can be one interpolation upsample layer followed
    by one convolutional layer (conv_first=False) or one convolutional layer
    followed by one interpolation upsample layer (conv_first=True).

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
        norm_cfg (dict | None): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict | None): Config dict for activation layer in ConvModule.
            Default: dict(type='ReLU').
        conv_cfg (dict | None): Config dict for convolution layer.
            Default: None.
        conv_first (bool): Whether convolutional layer or interpolation
            upsample layer first. Default: False. It means interpolation
            upsample layer followed by one convolutional layer.
        kernel_size (int): Kernel size of the convolutional layer. Default: 1.
        stride (int): Stride of the convolutional layer. Default: 1.
        padding (int): Padding of the convolutional layer. Default: 1.
        upsampe_cfg (dict): Interpolation config of the upsample layer.
            Default: dict(
                scale_factor=2, mode='bilinear', align_corners=False).
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 with_cp=False,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 *,
                 conv_cfg=None,
                 conv_first=False,
                 kernel_size=1,
                 stride=1,
                 padding=0,
                 upsampe_cfg=dict(
                     scale_factor=2, mode='bilinear', align_corners=False)):
        super(InterpConv, self).__init__()

        self.with_cp = with_cp
        conv = ConvModule(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        upsample = nn.Upsample(**upsampe_cfg)
        if conv_first:
            self.interp_upsample = nn.Sequential(conv, upsample)
        else:
            self.interp_upsample = nn.Sequential(upsample, conv)

    def forward(self, x):
        """Forward function."""

        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(self.interp_upsample, x)
        else:
            out = self.interp_upsample(x)
        return out


class UNet(nn.Module):
    """UNet backbone.
    U-Net: Convolutional Networks for Biomedical Image Segmentation.
    https://arxiv.org/pdf/1505.04597.pdf

    Args:
        in_channels (int): Number of input image channels. Default: 3.
        base_channels (int): Number of base channels of each stage.
            The output channels of the first stage. Default: 64.
        num_stages (int): Number of stages in encoder, normally 5. Default: 5.
        strides (Sequence[int 1 | 2]): Strides of each stage in encoder.
            len(strides) is equal to num_stages. Normally the stride of the
            first stage in encoder is 1. If strides[i]=2, it uses stride
            convolution to downsample in the correspondence encoder stage.
            Default: (1, 1, 1, 1, 1).
        enc_num_convs (Sequence[int]): Number of convolutional layers in the
            convolution block of the correspondence encoder stage.
            Default: (2, 2, 2, 2, 2).
        dec_num_convs (Sequence[int]): Number of convolutional layers in the
            convolution block of the correspondence decoder stage.
            Default: (2, 2, 2, 2).
        downsamples (Sequence[int]): Whether use MaxPool to downsample the
            feature map after the first stage of encoder
            (stages: [1, num_stages)). If the correspondence encoder stage use
            stride convolution (strides[i]=2), it will never use MaxPool to
            downsample, even downsamples[i-1]=True.
            Default: (True, True, True, True).
        enc_dilations (Sequence[int]): Dilation rate of each stage in encoder.
            Default: (1, 1, 1, 1, 1).
        dec_dilations (Sequence[int]): Dilation rate of each stage in decoder.
            Default: (1, 1, 1, 1).
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
        conv_cfg (dict | None): Config dict for convolution layer.
            Default: None.
        norm_cfg (dict | None): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict | None): Config dict for activation layer in ConvModule.
            Default: dict(type='ReLU').
        upsample_cfg (dict): The upsample config of the upsample module in
            decoder. Default: dict(type='InterpConv').
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Default: False.
        dcn (bool): Use deformable convolution in convolutional layer or not.
            Default: None.
        plugins (dict): plugins for convolutional layers. Default: None.

    Notice:
        The input image size should be divisible by the whole downsample rate
        of the encoder. More detail of the whole downsample rate can be found
        in UNet._check_input_divisible.

    """

    def __init__(self,
                 in_channels=3,
                 base_channels=64,
                 num_stages=5,
                 strides=(1, 1, 1, 1, 1),
                 enc_num_convs=(2, 2, 2, 2, 2),
                 dec_num_convs=(2, 2, 2, 2),
                 downsamples=(True, True, True, True),
                 enc_dilations=(1, 1, 1, 1, 1),
                 dec_dilations=(1, 1, 1, 1),
                 with_cp=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 upsample_cfg=dict(type='InterpConv'),
                 norm_eval=False,
                 dcn=None,
                 plugins=None):
        super(UNet, self).__init__()
        assert dcn is None, 'Not implemented yet.'
        assert plugins is None, 'Not implemented yet.'
        assert len(strides) == num_stages, \
            'The length of strides should be equal to num_stages, '\
            f'while the strides is {strides}, the length of '\
            f'strides is {len(strides)}, and the num_stages is '\
            f'{num_stages}.'
        assert len(enc_num_convs) == num_stages, \
            'The length of enc_num_convs should be equal to num_stages, '\
            f'while the enc_num_convs is {enc_num_convs}, the length of '\
            f'enc_num_convs is {len(enc_num_convs)}, and the num_stages is '\
            f'{num_stages}.'
        assert len(dec_num_convs) == (num_stages-1), \
            'The length of dec_num_convs should be equal to (num_stages-1), '\
            f'while the dec_num_convs is {dec_num_convs}, the length of '\
            f'dec_num_convs is {len(dec_num_convs)}, and the num_stages is '\
            f'{num_stages}.'
        assert len(downsamples) == (num_stages-1), \
            'The length of downsamples should be equal to (num_stages-1), '\
            f'while the downsamples is {downsamples}, the length of '\
            f'downsamples is {len(downsamples)}, and the num_stages is '\
            f'{num_stages}.'
        assert len(enc_dilations) == num_stages, \
            'The length of enc_dilations should be equal to num_stages, '\
            f'while the enc_dilations is {enc_dilations}, the length of '\
            f'enc_dilations is {len(enc_dilations)}, and the num_stages is '\
            f'{num_stages}.'
        assert len(dec_dilations) == (num_stages-1), \
            'The length of dec_dilations should be equal to (num_stages-1), '\
            f'while the dec_dilations is {dec_dilations}, the length of '\
            f'dec_dilations is {len(dec_dilations)}, and the num_stages is '\
            f'{num_stages}.'
        self.num_stages = num_stages
        self.strides = strides
        self.downsamples = downsamples
        self.norm_eval = norm_eval

        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()

        for i in range(num_stages):
            enc_conv_block = []
            if i != 0:
                if strides[i] == 1 and downsamples[i - 1]:
                    enc_conv_block.append(nn.MaxPool2d(kernel_size=2))
                upsample = (strides[i] != 1 or downsamples[i - 1])
                self.decoder.append(
                    UpConvBlock(
                        conv_block=BasicConvBlock,
                        in_channels=base_channels * 2**i,
                        skip_channels=base_channels * 2**(i - 1),
                        out_channels=base_channels * 2**(i - 1),
                        num_convs=dec_num_convs[i - 1],
                        stride=1,
                        dilation=dec_dilations[i - 1],
                        with_cp=with_cp,
                        conv_cfg=conv_cfg,
                        norm_cfg=norm_cfg,
                        act_cfg=act_cfg,
                        upsample_cfg=upsample_cfg if upsample else None,
                        dcn=None,
                        plugins=None))

            enc_conv_block.append(
                BasicConvBlock(
                    in_channels=in_channels,
                    out_channels=base_channels * 2**i,
                    num_convs=enc_num_convs[i],
                    stride=strides[i],
                    dilation=enc_dilations[i],
                    with_cp=with_cp,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg,
                    dcn=None,
                    plugins=None))
            self.encoder.append((nn.Sequential(*enc_conv_block)))
            in_channels = base_channels * 2**i

    def forward(self, x):
        enc_outs = []

        for enc in self.encoder:
            x = enc(x)
            enc_outs.append(x)
        dec_outs = [x]
        for i in reversed(range(len(self.decoder))):
            x = self.decoder[i](enc_outs[i], x)
            dec_outs.append(x)

        return dec_outs

    def init_weights(self, pretrained=None):
        """Initialize the weights in backbone.

        Args:
            pretrained (str, optional): Path to pre-trained weights.
                Defaults to None.
        """
        if isinstance(pretrained, str):
            logger = get_root_logger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    kaiming_init(m)
                elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
                    constant_init(m, 1)
        else:
            raise TypeError('pretrained must be a str or None')

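Aside (not part of unet_arch.py): the UNet above returns the list of decoder outputs rather than a single tensor, with the deepest feature first and the full-resolution feature last. A minimal usage sketch under the default arguments (assumes mmcv/mmseg are installed and the input side length is divisible by the encoder's total downsample rate of 16):

import torch

unet = UNet(in_channels=3, base_channels=64, num_stages=5)
unet.init_weights(pretrained=None)

x = torch.randn(1, 3, 256, 256)
dec_outs = unet(x)
print([tuple(o.shape) for o in dec_outs])
# [(1, 1024, 16, 16), (1, 512, 32, 32), (1, 256, 64, 64), (1, 128, 128, 128), (1, 64, 256, 256)]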
class ShapeUNet(nn.Module):
    """ShapeUNet backbone with small modifications.
    U-Net: Convolutional Networks for Biomedical Image Segmentation.
    https://arxiv.org/pdf/1505.04597.pdf

    Args:
        in_channels (int): Number of input image channels. Default: 3.
        base_channels (int): Number of base channels of each stage.
            The output channels of the first stage. Default: 64.
        num_stages (int): Number of stages in encoder, normally 5. Default: 5.
        strides (Sequence[int 1 | 2]): Strides of each stage in encoder.
            len(strides) is equal to num_stages. Normally the stride of the
            first stage in encoder is 1. If strides[i]=2, it uses stride
            convolution to downsample in the correspondence encoder stage.
            Default: (1, 1, 1, 1, 1).
        enc_num_convs (Sequence[int]): Number of convolutional layers in the
            convolution block of the correspondence encoder stage.
            Default: (2, 2, 2, 2, 2).
        dec_num_convs (Sequence[int]): Number of convolutional layers in the
            convolution block of the correspondence decoder stage.
            Default: (2, 2, 2, 2).
        downsamples (Sequence[int]): Whether use MaxPool to downsample the
            feature map after the first stage of encoder
            (stages: [1, num_stages)). If the correspondence encoder stage use
            stride convolution (strides[i]=2), it will never use MaxPool to
            downsample, even downsamples[i-1]=True.
            Default: (True, True, True, True).
        enc_dilations (Sequence[int]): Dilation rate of each stage in encoder.
            Default: (1, 1, 1, 1, 1).
        dec_dilations (Sequence[int]): Dilation rate of each stage in decoder.
            Default: (1, 1, 1, 1).
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
        conv_cfg (dict | None): Config dict for convolution layer.
            Default: None.
        norm_cfg (dict | None): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict | None): Config dict for activation layer in ConvModule.
            Default: dict(type='ReLU').
        upsample_cfg (dict): The upsample config of the upsample module in
            decoder. Default: dict(type='InterpConv').
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Default: False.
        dcn (bool): Use deformable convolution in convolutional layer or not.
            Default: None.
        plugins (dict): plugins for convolutional layers. Default: None.

    Notice:
        The input image size should be divisible by the whole downsample rate
        of the encoder. More detail of the whole downsample rate can be found
        in UNet._check_input_divisible.

    """

    def __init__(self,
                 in_channels=3,
                 base_channels=64,
                 num_stages=5,
                 attr_embedding=128,
                 strides=(1, 1, 1, 1, 1),
                 enc_num_convs=(2, 2, 2, 2, 2),
                 dec_num_convs=(2, 2, 2, 2),
                 downsamples=(True, True, True, True),
                 enc_dilations=(1, 1, 1, 1, 1),
                 dec_dilations=(1, 1, 1, 1),
                 with_cp=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 upsample_cfg=dict(type='InterpConv'),
                 norm_eval=False,
                 dcn=None,
                 plugins=None):
        super(ShapeUNet, self).__init__()
        assert dcn is None, 'Not implemented yet.'
        assert plugins is None, 'Not implemented yet.'
        assert len(strides) == num_stages, \
            'The length of strides should be equal to num_stages, '\
            f'while the strides is {strides}, the length of '\
            f'strides is {len(strides)}, and the num_stages is '\
            f'{num_stages}.'
        assert len(enc_num_convs) == num_stages, \
            'The length of enc_num_convs should be equal to num_stages, '\
            f'while the enc_num_convs is {enc_num_convs}, the length of '\
            f'enc_num_convs is {len(enc_num_convs)}, and the num_stages is '\
            f'{num_stages}.'
        assert len(dec_num_convs) == (num_stages-1), \
            'The length of dec_num_convs should be equal to (num_stages-1), '\
            f'while the dec_num_convs is {dec_num_convs}, the length of '\
            f'dec_num_convs is {len(dec_num_convs)}, and the num_stages is '\
            f'{num_stages}.'
        assert len(downsamples) == (num_stages-1), \
            'The length of downsamples should be equal to (num_stages-1), '\
            f'while the downsamples is {downsamples}, the length of '\
            f'downsamples is {len(downsamples)}, and the num_stages is '\
            f'{num_stages}.'
        assert len(enc_dilations) == num_stages, \
            'The length of enc_dilations should be equal to num_stages, '\
            f'while the enc_dilations is {enc_dilations}, the length of '\
            f'enc_dilations is {len(enc_dilations)}, and the num_stages is '\
            f'{num_stages}.'
        assert len(dec_dilations) == (num_stages-1), \
            'The length of dec_dilations should be equal to (num_stages-1), '\
            f'while the dec_dilations is {dec_dilations}, the length of '\
            f'dec_dilations is {len(dec_dilations)}, and the num_stages is '\
            f'{num_stages}.'
        self.num_stages = num_stages
        self.strides = strides
        self.downsamples = downsamples
        self.norm_eval = norm_eval

        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()

        for i in range(num_stages):
            enc_conv_block = []
            if i != 0:
                if strides[i] == 1 and downsamples[i - 1]:
                    enc_conv_block.append(nn.MaxPool2d(kernel_size=2))
                upsample = (strides[i] != 1 or downsamples[i - 1])
                self.decoder.append(
                    UpConvBlock(
                        conv_block=BasicConvBlock,
                        in_channels=base_channels * 2**i,
                        skip_channels=base_channels * 2**(i - 1),
                        out_channels=base_channels * 2**(i - 1),
                        num_convs=dec_num_convs[i - 1],
                        stride=1,
                        dilation=dec_dilations[i - 1],
                        with_cp=with_cp,
                        conv_cfg=conv_cfg,
                        norm_cfg=norm_cfg,
                        act_cfg=act_cfg,
                        upsample_cfg=upsample_cfg if upsample else None,
                        dcn=None,
                        plugins=None))

            enc_conv_block.append(
                BasicConvBlock(
                    in_channels=in_channels + attr_embedding,
                    out_channels=base_channels * 2**i,
                    num_convs=enc_num_convs[i],
                    stride=strides[i],
                    dilation=enc_dilations[i],
                    with_cp=with_cp,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg,
                    dcn=None,
                    plugins=None))
            self.encoder.append((nn.Sequential(*enc_conv_block)))
            in_channels = base_channels * 2**i

    def forward(self, x, attr_embedding):
        enc_outs = []
        Be, Ce = attr_embedding.size()
        for enc in self.encoder:
            _, _, H, W = x.size()
            x = enc(
                torch.cat([
                    x,
                    attr_embedding.view(Be, Ce, 1, 1).expand((Be, Ce, H, W))
                ],
                          dim=1))
            enc_outs.append(x)
        dec_outs = [x]
        for i in reversed(range(len(self.decoder))):
            x = self.decoder[i](enc_outs[i], x)
            dec_outs.append(x)

        return dec_outs

    def init_weights(self, pretrained=None):
        """Initialize the weights in backbone.

        Args:
            pretrained (str, optional): Path to pre-trained weights.
                Defaults to None.
        """
        if isinstance(pretrained, str):
            logger = get_root_logger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    kaiming_init(m)
                elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
                    constant_init(m, 1)
        else:
            raise TypeError('pretrained must be a str or None')
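Aside (not part of unet_arch.py): ShapeUNet differs from UNet only in that a per-sample attribute embedding is broadcast to the spatial size of the current feature map and concatenated to the input of every encoder stage, which is why each encoder block expects in_channels + attr_embedding channels. A sketch of a call with illustrative channel counts (the real attribute encoder lives in shape_attr_embedding_arch.py and is not shown here):

import torch

shape_unet = ShapeUNet(in_channels=24, attr_embedding=128)  # 24 is an illustrative input-channel count
parsing = torch.randn(2, 24, 256, 256)
attr_emb = torch.randn(2, 128)  # one embedding vector per sample

dec_outs = shape_unet(parsing, attr_emb)
print(dec_outs[-1].shape)  # torch.Size([2, 64, 256, 256])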
models/archs/vqgan_arch.py
ADDED
@@ -0,0 +1,1203 @@
# pytorch_diffusion + derived encoder decoder
import math
from urllib.request import proxy_bypass

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange


class VectorQuantizer(nn.Module):
    """
    Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly
    avoids costly matrix multiplications and allows for post-hoc remapping of indices.
    """

    # NOTE: due to a bug the beta term was applied to the wrong term. for
    # backwards compatibility we use the buggy version by default, but you can
    # specify legacy=False to fix it.
    def __init__(self,
                 n_e,
                 e_dim,
                 beta,
                 remap=None,
                 unknown_index="random",
                 sane_index_shape=False,
                 legacy=True):
        super().__init__()
        self.n_e = n_e
        self.e_dim = e_dim
        self.beta = beta
        self.legacy = legacy

        self.embedding = nn.Embedding(self.n_e, self.e_dim)
        self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)

        self.remap = remap
        if self.remap is not None:
            self.register_buffer("used", torch.tensor(np.load(self.remap)))
            self.re_embed = self.used.shape[0]
            self.unknown_index = unknown_index  # "random" or "extra" or integer
            if self.unknown_index == "extra":
                self.unknown_index = self.re_embed
                self.re_embed = self.re_embed + 1
            print(f"Remapping {self.n_e} indices to {self.re_embed} indices. "
                  f"Using {self.unknown_index} for unknown indices.")
        else:
            self.re_embed = n_e

        self.sane_index_shape = sane_index_shape

    def remap_to_used(self, inds):
        ishape = inds.shape
        assert len(ishape) > 1
        inds = inds.reshape(ishape[0], -1)
        used = self.used.to(inds)
        match = (inds[:, :, None] == used[None, None, ...]).long()
        new = match.argmax(-1)
        unknown = match.sum(2) < 1
        if self.unknown_index == "random":
            new[unknown] = torch.randint(
                0, self.re_embed,
                size=new[unknown].shape).to(device=new.device)
        else:
            new[unknown] = self.unknown_index
        return new.reshape(ishape)

    def unmap_to_all(self, inds):
        ishape = inds.shape
        assert len(ishape) > 1
        inds = inds.reshape(ishape[0], -1)
        used = self.used.to(inds)
        if self.re_embed > self.used.shape[0]:  # extra token
            inds[inds >= self.used.shape[0]] = 0  # simply set to zero
        back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
        return back.reshape(ishape)

    def forward(self, z, temp=None, rescale_logits=False, return_logits=False):
        assert temp is None or temp == 1.0, "Only for interface compatible with Gumbel"
        assert rescale_logits == False, "Only for interface compatible with Gumbel"
        assert return_logits == False, "Only for interface compatible with Gumbel"
        # reshape z -> (batch, height, width, channel) and flatten
        z = rearrange(z, 'b c h w -> b h w c').contiguous()
        z_flattened = z.view(-1, self.e_dim)
        # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z

        d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
            torch.sum(self.embedding.weight**2, dim=1) - 2 * \
            torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n'))

        min_encoding_indices = torch.argmin(d, dim=1)
        z_q = self.embedding(min_encoding_indices).view(z.shape)
        perplexity = None
        min_encodings = None

        # compute loss for embedding
        if not self.legacy:
            loss = self.beta * torch.mean((z_q.detach()-z)**2) + \
                torch.mean((z_q - z.detach()) ** 2)
        else:
            loss = torch.mean((z_q.detach()-z)**2) + self.beta * \
                torch.mean((z_q - z.detach()) ** 2)

        # preserve gradients
        z_q = z + (z_q - z).detach()

        # reshape back to match original input shape
        z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous()

        if self.remap is not None:
            min_encoding_indices = min_encoding_indices.reshape(
                z.shape[0], -1)  # add batch axis
            min_encoding_indices = self.remap_to_used(min_encoding_indices)
            min_encoding_indices = min_encoding_indices.reshape(-1,
                                                                1)  # flatten

        if self.sane_index_shape:
            min_encoding_indices = min_encoding_indices.reshape(
                z_q.shape[0], z_q.shape[2], z_q.shape[3])

        return z_q, loss, (perplexity, min_encodings, min_encoding_indices)

    def get_codebook_entry(self, indices, shape):
        # shape specifying (batch, height, width, channel)
        if self.remap is not None:
            indices = indices.reshape(shape[0], -1)  # add batch axis
            indices = self.unmap_to_all(indices)
            indices = indices.reshape(-1)  # flatten again

        # get quantized latent vectors
        z_q = self.embedding(indices)

        if shape is not None:
            z_q = z_q.view(shape)
            # reshape back to match original input shape
            z_q = z_q.permute(0, 3, 1, 2).contiguous()

        return z_q

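Aside (not part of vqgan_arch.py): the quantizer above snaps every spatial position of the encoder output to its nearest codebook entry and uses the straight-through trick (z_q = z + (z_q - z).detach()) so gradients still reach the encoder. A quick sanity-check sketch on a random feature map, with illustrative sizes:

import torch

quant = VectorQuantizer(n_e=1024, e_dim=256, beta=0.25, sane_index_shape=True)
z = torch.randn(2, 256, 16, 16, requires_grad=True)

z_q, codebook_loss, (_, _, indices) = quant(z)
print(z_q.shape, indices.shape)  # torch.Size([2, 256, 16, 16]) torch.Size([2, 16, 16])

z_q.sum().backward()             # straight-through: the gradient flows back into z
print(z.grad.shape)              # torch.Size([2, 256, 16, 16])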
class VectorQuantizerTexture(nn.Module):
    """
    Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly
    avoids costly matrix multiplications and allows for post-hoc remapping of indices.
    """

    # NOTE: due to a bug the beta term was applied to the wrong term. for
    # backwards compatibility we use the buggy version by default, but you can
    # specify legacy=False to fix it.
    def __init__(self,
                 n_e,
                 e_dim,
                 beta,
                 remap=None,
                 unknown_index="random",
                 sane_index_shape=False,
                 legacy=True):
        super().__init__()
        self.n_e = n_e
        self.e_dim = e_dim
        self.beta = beta
        self.legacy = legacy

        # TODO: decide number of embeddings
        self.embedding_list = nn.ModuleList(
            [nn.Embedding(self.n_e, self.e_dim) for i in range(18)])
        for embedding in self.embedding_list:
            embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)

        self.remap = remap
        if self.remap is not None:
            self.register_buffer("used", torch.tensor(np.load(self.remap)))
            self.re_embed = self.used.shape[0]
            self.unknown_index = unknown_index  # "random" or "extra" or integer
            if self.unknown_index == "extra":
                self.unknown_index = self.re_embed
                self.re_embed = self.re_embed + 1
            print(f"Remapping {self.n_e} indices to {self.re_embed} indices. "
                  f"Using {self.unknown_index} for unknown indices.")
        else:
            self.re_embed = n_e

        self.sane_index_shape = sane_index_shape

    def remap_to_used(self, inds):
        ishape = inds.shape
        assert len(ishape) > 1
        inds = inds.reshape(ishape[0], -1)
        used = self.used.to(inds)
        match = (inds[:, :, None] == used[None, None, ...]).long()
        new = match.argmax(-1)
        unknown = match.sum(2) < 1
        if self.unknown_index == "random":
            new[unknown] = torch.randint(
                0, self.re_embed,
                size=new[unknown].shape).to(device=new.device)
        else:
            new[unknown] = self.unknown_index
        return new.reshape(ishape)

    def unmap_to_all(self, inds):
        ishape = inds.shape
        assert len(ishape) > 1
        inds = inds.reshape(ishape[0], -1)
        used = self.used.to(inds)
        if self.re_embed > self.used.shape[0]:  # extra token
            inds[inds >= self.used.shape[0]] = 0  # simply set to zero
        back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
        return back.reshape(ishape)

    def forward(self,
                z,
                segm_map,
                temp=None,
                rescale_logits=False,
                return_logits=False):
        assert temp is None or temp == 1.0, "Only for interface compatible with Gumbel"
        assert rescale_logits == False, "Only for interface compatible with Gumbel"
        assert return_logits == False, "Only for interface compatible with Gumbel"

        segm_map = F.interpolate(segm_map, size=z.size()[2:], mode='nearest')
        # reshape z -> (batch, height, width, channel) and flatten
        z = rearrange(z, 'b c h w -> b h w c').contiguous()
        z_flattened = z.view(-1, self.e_dim)

        # flatten segm_map (b, h, w)
        segm_map_flatten = segm_map.view(-1)

        z_q = torch.zeros_like(z_flattened)
        min_encoding_indices_list = []
        min_encoding_indices_continual = torch.full(
            segm_map_flatten.size(),
            fill_value=-1,
            dtype=torch.long,
            device=segm_map_flatten.device)
        for codebook_idx in range(18):
            min_encoding_indices = torch.full(
                segm_map_flatten.size(),
                fill_value=-1,
                dtype=torch.long,
                device=segm_map_flatten.device)
            if torch.sum(segm_map_flatten == codebook_idx) > 0:
                z_selected = z_flattened[segm_map_flatten == codebook_idx]
                # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
                d_selected = torch.sum(
                    z_selected**2, dim=1, keepdim=True) + torch.sum(
                        self.embedding_list[codebook_idx].weight**2,
                        dim=1) - 2 * torch.einsum(
                            'bd,dn->bn', z_selected,
                            rearrange(self.embedding_list[codebook_idx].weight,
                                      'n d -> d n'))
                min_encoding_indices_selected = torch.argmin(d_selected, dim=1)
                z_q_selected = self.embedding_list[codebook_idx](
                    min_encoding_indices_selected)
                z_q[segm_map_flatten == codebook_idx] = z_q_selected
                min_encoding_indices[
                    segm_map_flatten ==
                    codebook_idx] = min_encoding_indices_selected
                min_encoding_indices_continual[
                    segm_map_flatten ==
                    codebook_idx] = min_encoding_indices_selected + 1024 * codebook_idx
            min_encoding_indices = min_encoding_indices.reshape(
                z.shape[0], z.shape[1], z.shape[2])
            min_encoding_indices_list.append(min_encoding_indices)

        min_encoding_indices_continual = min_encoding_indices_continual.reshape(
            z.shape[0], z.shape[1], z.shape[2])
        z_q = z_q.view(z.shape)
        perplexity = None

        # compute loss for embedding
        if not self.legacy:
            loss = self.beta * torch.mean((z_q.detach()-z)**2) + \
                torch.mean((z_q - z.detach()) ** 2)
        else:
            loss = torch.mean((z_q.detach()-z)**2) + self.beta * \
                torch.mean((z_q - z.detach()) ** 2)

        # preserve gradients
        z_q = z + (z_q - z).detach()

        # reshape back to match original input shape
        z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous()

        return z_q, loss, (perplexity, min_encoding_indices_continual,
                           min_encoding_indices_list)

    def get_codebook_entry(self, indices_list, segm_map, shape):
        # flatten segm_map (b, h, w)
        segm_map = F.interpolate(
            segm_map, size=(shape[1], shape[2]), mode='nearest')
        segm_map_flatten = segm_map.view(-1)

        z_q = torch.zeros((shape[0] * shape[1] * shape[2]),
                          self.e_dim).to(segm_map.device)
        for codebook_idx in range(18):
            if torch.sum(segm_map_flatten == codebook_idx) > 0:
                min_encoding_indices_selected = indices_list[
                    codebook_idx].view(-1)[segm_map_flatten == codebook_idx]
                z_q_selected = self.embedding_list[codebook_idx](
                    min_encoding_indices_selected)
                z_q[segm_map_flatten == codebook_idx] = z_q_selected

        z_q = z_q.view(shape)
        # reshape back to match original input shape
        z_q = z_q.permute(0, 3, 1, 2).contiguous()

        return z_q

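Aside (not part of vqgan_arch.py): VectorQuantizerTexture keeps 18 separate codebooks and routes each latent position to the codebook chosen by the (nearest-downsampled) segmentation map, so every texture class is quantized against its own entries; the returned "continual" index map offsets each index by 1024 times the codebook id. An illustrative call with made-up sizes:

import torch

quant_tex = VectorQuantizerTexture(n_e=1024, e_dim=256, beta=0.25)
z = torch.randn(2, 256, 16, 16)
segm_map = torch.randint(0, 18, (2, 1, 64, 64)).float()  # one codebook id per pixel

z_q, loss, (_, flat_indices, indices_list) = quant_tex(z, segm_map)
print(z_q.shape)           # torch.Size([2, 256, 16, 16])
print(flat_indices.shape)  # torch.Size([2, 16, 16])
print(len(indices_list))   # 18, one (2, 16, 16) index map per codebook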
def sample_patches(inputs, patch_size=3, stride=1):
    """Extract sliding local patches from an input feature tensor.
    The sampled patches are row-major.
    Args:
        inputs (Tensor): the input feature maps, shape: (n, c, h, w).
        patch_size (int): the spatial size of sampled patches. Default: 3.
        stride (int): the stride of sampling. Default: 1.
    Returns:
        patches (Tensor): extracted patches, shape: (n, c * patch_size *
            patch_size, n_patches).
    """

    patches = F.unfold(inputs, (patch_size, patch_size), stride=stride)

    return patches


class VectorQuantizerSpatialTextureAware(nn.Module):
    """
    Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly
    avoids costly matrix multiplications and allows for post-hoc remapping of indices.
    """

    # NOTE: due to a bug the beta term was applied to the wrong term. for
    # backwards compatibility we use the buggy version by default, but you can
    # specify legacy=False to fix it.
    def __init__(self,
                 n_e,
                 e_dim,
                 beta,
                 spatial_size,
                 remap=None,
                 unknown_index="random",
                 sane_index_shape=False,
                 legacy=True):
        super().__init__()
        self.n_e = n_e
        self.e_dim = e_dim * spatial_size * spatial_size
        self.beta = beta
        self.legacy = legacy
        self.spatial_size = spatial_size

        # TODO: decide number of embeddings
        self.embedding_list = nn.ModuleList(
            [nn.Embedding(self.n_e, self.e_dim) for i in range(18)])
        for embedding in self.embedding_list:
            embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)

        self.remap = remap
        if self.remap is not None:
            self.register_buffer("used", torch.tensor(np.load(self.remap)))
            self.re_embed = self.used.shape[0]
            self.unknown_index = unknown_index  # "random" or "extra" or integer
            if self.unknown_index == "extra":
                self.unknown_index = self.re_embed
                self.re_embed = self.re_embed + 1
            print(f"Remapping {self.n_e} indices to {self.re_embed} indices. "
                  f"Using {self.unknown_index} for unknown indices.")
        else:
            self.re_embed = n_e

        self.sane_index_shape = sane_index_shape

    def forward(self,
                z,
                segm_map,
                temp=None,
                rescale_logits=False,
                return_logits=False):
        assert temp is None or temp == 1.0, "Only for interface compatible with Gumbel"
        assert rescale_logits == False, "Only for interface compatible with Gumbel"
        assert return_logits == False, "Only for interface compatible with Gumbel"

        segm_map = F.interpolate(
            segm_map,
            size=(z.size(2) // self.spatial_size,
                  z.size(3) // self.spatial_size),
            mode='nearest')

        # reshape z -> (batch, height, width, channel) and flatten
        # z = rearrange(z, 'b c h w -> b h w c').contiguous() ?
        z_patches = sample_patches(
            z, patch_size=self.spatial_size,
            stride=self.spatial_size).permute(0, 2, 1)
        z_patches_flattened = z_patches.reshape(-1, self.e_dim)
        # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z

        # flatten segm_map (b, h, w)
        segm_map_flatten = segm_map.view(-1)

        z_q = torch.zeros_like(z_patches_flattened)
        min_encoding_indices_list = []
        min_encoding_indices_continual = torch.full(
            segm_map_flatten.size(),
            fill_value=-1,
            dtype=torch.long,
            device=segm_map_flatten.device)

        for codebook_idx in range(18):
            min_encoding_indices = torch.full(
                segm_map_flatten.size(),
                fill_value=-1,
                dtype=torch.long,
                device=segm_map_flatten.device)
            if torch.sum(segm_map_flatten == codebook_idx) > 0:
                z_selected = z_patches_flattened[segm_map_flatten ==
                                                 codebook_idx]
                # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
                d_selected = torch.sum(
                    z_selected**2, dim=1, keepdim=True) + torch.sum(
                        self.embedding_list[codebook_idx].weight**2,
                        dim=1) - 2 * torch.einsum(
                            'bd,dn->bn', z_selected,
                            rearrange(self.embedding_list[codebook_idx].weight,
                                      'n d -> d n'))
                min_encoding_indices_selected = torch.argmin(d_selected, dim=1)
                z_q_selected = self.embedding_list[codebook_idx](
                    min_encoding_indices_selected)
                z_q[segm_map_flatten == codebook_idx] = z_q_selected
                min_encoding_indices[
                    segm_map_flatten ==
                    codebook_idx] = min_encoding_indices_selected
                min_encoding_indices_continual[
                    segm_map_flatten ==
                    codebook_idx] = min_encoding_indices_selected + self.n_e * codebook_idx
            min_encoding_indices = min_encoding_indices.reshape(
                z_patches.shape[0], segm_map.shape[2], segm_map.shape[3])
            min_encoding_indices_list.append(min_encoding_indices)

        z_q = F.fold(
            z_q.view(z_patches.shape).permute(0, 2, 1),
            z.size()[2:],
            kernel_size=(self.spatial_size, self.spatial_size),
            stride=self.spatial_size)

        perplexity = None

        # compute loss for embedding
        if not self.legacy:
            loss = self.beta * torch.mean((z_q.detach()-z)**2) + \
                torch.mean((z_q - z.detach()) ** 2)
        else:
            loss = torch.mean((z_q.detach()-z)**2) + self.beta * \
                torch.mean((z_q - z.detach()) ** 2)

        # preserve gradients
        z_q = z + (z_q - z).detach()

        return z_q, loss, (perplexity, min_encoding_indices_continual,
                           min_encoding_indices_list)

    def get_codebook_entry(self, indices_list, segm_map, shape):
        # flatten segm_map (b, h, w)
        segm_map = F.interpolate(
            segm_map, size=(shape[1], shape[2]), mode='nearest')
        segm_map_flatten = segm_map.view(-1)

        z_q = torch.zeros((shape[0] * shape[1] * shape[2]),
                          self.e_dim).to(segm_map.device)
        for codebook_idx in range(18):
            if torch.sum(segm_map_flatten == codebook_idx) > 0:
                min_encoding_indices_selected = indices_list[
                    codebook_idx].view(-1)[segm_map_flatten == codebook_idx]
                z_q_selected = self.embedding_list[codebook_idx](
                    min_encoding_indices_selected)
                z_q[segm_map_flatten == codebook_idx] = z_q_selected

        z_q = F.fold(
            z_q.view(((shape[0], shape[1] * shape[2],
                       self.e_dim))).permute(0, 2, 1),
            (shape[1] * self.spatial_size, shape[2] * self.spatial_size),
            kernel_size=(self.spatial_size, self.spatial_size),
            stride=self.spatial_size)

        return z_q

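Aside (not part of vqgan_arch.py): the spatially-aware variant quantizes non-overlapping spatial_size x spatial_size patches via unfold/fold instead of single positions, so internally e_dim is multiplied by spatial_size**2 and each codebook entry reconstructs a whole patch. An illustrative call with made-up sizes:

import torch

quant_patch = VectorQuantizerSpatialTextureAware(
    n_e=512, e_dim=256, beta=0.25, spatial_size=2)
z = torch.randn(2, 256, 32, 32)
segm_map = torch.randint(0, 18, (2, 1, 64, 64)).float()

z_q, loss, (_, _, indices_list) = quant_patch(z, segm_map)
print(z_q.shape)              # torch.Size([2, 256, 32, 32])
print(indices_list[0].shape)  # torch.Size([2, 16, 16]) -- one index per 2x2 patch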
def get_timestep_embedding(timesteps, embedding_dim):
    """
    This matches the implementation in Denoising Diffusion Probabilistic Models:
    From Fairseq.
    Build sinusoidal embeddings.
    This matches the implementation in tensor2tensor, but differs slightly
    from the description in Section 3.5 of "Attention Is All You Need".
    """
    assert len(timesteps.shape) == 1

    half_dim = embedding_dim // 2
    emb = math.log(10000) / (half_dim - 1)
    emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
    emb = emb.to(device=timesteps.device)
    emb = timesteps.float()[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb


def nonlinearity(x):
    # swish
    return x * torch.sigmoid(x)


def Normalize(in_channels):
    return torch.nn.GroupNorm(
        num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)


class Upsample(nn.Module):

    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            self.conv = torch.nn.Conv2d(
                in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        x = torch.nn.functional.interpolate(
            x, scale_factor=2.0, mode="nearest")
        if self.with_conv:
            x = self.conv(x)
        return x


class Downsample(nn.Module):

    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            # no asymmetric padding in torch conv, must do it ourselves
            self.conv = torch.nn.Conv2d(
                in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x):
        if self.with_conv:
            pad = (0, 1, 0, 1)
            x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
            x = self.conv(x)
        else:
            x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
        return x


class ResnetBlock(nn.Module):

    def __init__(self,
                 *,
                 in_channels,
                 out_channels=None,
                 conv_shortcut=False,
                 dropout,
                 temb_channels=512):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = Normalize(in_channels)
        self.conv1 = torch.nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if temb_channels > 0:
            self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
        self.norm2 = Normalize(out_channels)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = torch.nn.Conv2d(
                    in_channels,
                    out_channels,
                    kernel_size=3,
                    stride=1,
                    padding=1)
            else:
                self.nin_shortcut = torch.nn.Conv2d(
                    in_channels,
                    out_channels,
                    kernel_size=1,
                    stride=1,
                    padding=0)

    def forward(self, x, temb):
        h = x
        h = self.norm1(h)
        h = nonlinearity(h)
        h = self.conv1(h)

        if temb is not None:
            h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]

        h = self.norm2(h)
        h = nonlinearity(h)
        h = self.dropout(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)

        return x + h


class AttnBlock(nn.Module):

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.k = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.v = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.proj_out = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        b, c, h, w = q.shape
        q = q.reshape(b, c, h * w)
        q = q.permute(0, 2, 1)  # b,hw,c
        k = k.reshape(b, c, h * w)  # b,c,hw
        w_ = torch.bmm(q, k)  # b,hw,hw    w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
        w_ = w_ * (int(c)**(-0.5))
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # attend to values
        v = v.reshape(b, c, h * w)
        w_ = w_.permute(0, 2, 1)  # b,hw,hw (first hw of k, second of q)
        h_ = torch.bmm(
            v, w_)  # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
        h_ = h_.reshape(b, c, h, w)

        h_ = self.proj_out(h_)

        return x + h_


class Model(nn.Module):

    def __init__(self,
                 *,
                 ch,
                 out_ch,
                 ch_mult=(1, 2, 4, 8),
                 num_res_blocks,
                 attn_resolutions,
                 dropout=0.0,
                 resamp_with_conv=True,
                 in_channels,
                 resolution,
                 use_timestep=True):
        super().__init__()
        self.ch = ch
        self.temb_ch = self.ch * 4
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        self.use_timestep = use_timestep
        if self.use_timestep:
            # timestep embedding
            self.temb = nn.Module()
|
690 |
+
self.temb.dense = nn.ModuleList([
|
691 |
+
torch.nn.Linear(self.ch, self.temb_ch),
|
692 |
+
torch.nn.Linear(self.temb_ch, self.temb_ch),
|
693 |
+
])
|
694 |
+
|
695 |
+
# downsampling
|
696 |
+
self.conv_in = torch.nn.Conv2d(
|
697 |
+
in_channels, self.ch, kernel_size=3, stride=1, padding=1)
|
698 |
+
|
699 |
+
curr_res = resolution
|
700 |
+
in_ch_mult = (1, ) + tuple(ch_mult)
|
701 |
+
self.down = nn.ModuleList()
|
702 |
+
for i_level in range(self.num_resolutions):
|
703 |
+
block = nn.ModuleList()
|
704 |
+
attn = nn.ModuleList()
|
705 |
+
block_in = ch * in_ch_mult[i_level]
|
706 |
+
block_out = ch * ch_mult[i_level]
|
707 |
+
for i_block in range(self.num_res_blocks):
|
708 |
+
block.append(
|
709 |
+
ResnetBlock(
|
710 |
+
in_channels=block_in,
|
711 |
+
out_channels=block_out,
|
712 |
+
temb_channels=self.temb_ch,
|
713 |
+
dropout=dropout))
|
714 |
+
block_in = block_out
|
715 |
+
if curr_res in attn_resolutions:
|
716 |
+
attn.append(AttnBlock(block_in))
|
717 |
+
down = nn.Module()
|
718 |
+
down.block = block
|
719 |
+
down.attn = attn
|
720 |
+
if i_level != self.num_resolutions - 1:
|
721 |
+
down.downsample = Downsample(block_in, resamp_with_conv)
|
722 |
+
curr_res = curr_res // 2
|
723 |
+
self.down.append(down)
|
724 |
+
|
725 |
+
# middle
|
726 |
+
self.mid = nn.Module()
|
727 |
+
self.mid.block_1 = ResnetBlock(
|
728 |
+
in_channels=block_in,
|
729 |
+
out_channels=block_in,
|
730 |
+
temb_channels=self.temb_ch,
|
731 |
+
dropout=dropout)
|
732 |
+
self.mid.attn_1 = AttnBlock(block_in)
|
733 |
+
self.mid.block_2 = ResnetBlock(
|
734 |
+
in_channels=block_in,
|
735 |
+
out_channels=block_in,
|
736 |
+
temb_channels=self.temb_ch,
|
737 |
+
dropout=dropout)
|
738 |
+
|
739 |
+
# upsampling
|
740 |
+
self.up = nn.ModuleList()
|
741 |
+
for i_level in reversed(range(self.num_resolutions)):
|
742 |
+
block = nn.ModuleList()
|
743 |
+
attn = nn.ModuleList()
|
744 |
+
block_out = ch * ch_mult[i_level]
|
745 |
+
skip_in = ch * ch_mult[i_level]
|
746 |
+
for i_block in range(self.num_res_blocks + 1):
|
747 |
+
if i_block == self.num_res_blocks:
|
748 |
+
skip_in = ch * in_ch_mult[i_level]
|
749 |
+
block.append(
|
750 |
+
ResnetBlock(
|
751 |
+
in_channels=block_in + skip_in,
|
752 |
+
out_channels=block_out,
|
753 |
+
temb_channels=self.temb_ch,
|
754 |
+
dropout=dropout))
|
755 |
+
block_in = block_out
|
756 |
+
if curr_res in attn_resolutions:
|
757 |
+
attn.append(AttnBlock(block_in))
|
758 |
+
up = nn.Module()
|
759 |
+
up.block = block
|
760 |
+
up.attn = attn
|
761 |
+
if i_level != 0:
|
762 |
+
up.upsample = Upsample(block_in, resamp_with_conv)
|
763 |
+
curr_res = curr_res * 2
|
764 |
+
self.up.insert(0, up) # prepend to get consistent order
|
765 |
+
|
766 |
+
# end
|
767 |
+
self.norm_out = Normalize(block_in)
|
768 |
+
self.conv_out = torch.nn.Conv2d(
|
769 |
+
block_in, out_ch, kernel_size=3, stride=1, padding=1)
|
770 |
+
|
771 |
+
def forward(self, x, t=None):
|
772 |
+
#assert x.shape[2] == x.shape[3] == self.resolution
|
773 |
+
|
774 |
+
if self.use_timestep:
|
775 |
+
# timestep embedding
|
776 |
+
assert t is not None
|
777 |
+
temb = get_timestep_embedding(t, self.ch)
|
778 |
+
temb = self.temb.dense[0](temb)
|
779 |
+
temb = nonlinearity(temb)
|
780 |
+
temb = self.temb.dense[1](temb)
|
781 |
+
else:
|
782 |
+
temb = None
|
783 |
+
|
784 |
+
# downsampling
|
785 |
+
hs = [self.conv_in(x)]
|
786 |
+
for i_level in range(self.num_resolutions):
|
787 |
+
for i_block in range(self.num_res_blocks):
|
788 |
+
h = self.down[i_level].block[i_block](hs[-1], temb)
|
789 |
+
if len(self.down[i_level].attn) > 0:
|
790 |
+
h = self.down[i_level].attn[i_block](h)
|
791 |
+
hs.append(h)
|
792 |
+
if i_level != self.num_resolutions - 1:
|
793 |
+
hs.append(self.down[i_level].downsample(hs[-1]))
|
794 |
+
|
795 |
+
# middle
|
796 |
+
h = hs[-1]
|
797 |
+
h = self.mid.block_1(h, temb)
|
798 |
+
h = self.mid.attn_1(h)
|
799 |
+
h = self.mid.block_2(h, temb)
|
800 |
+
|
801 |
+
# upsampling
|
802 |
+
for i_level in reversed(range(self.num_resolutions)):
|
803 |
+
for i_block in range(self.num_res_blocks + 1):
|
804 |
+
h = self.up[i_level].block[i_block](torch.cat([h, hs.pop()],
|
805 |
+
dim=1), temb)
|
806 |
+
if len(self.up[i_level].attn) > 0:
|
807 |
+
h = self.up[i_level].attn[i_block](h)
|
808 |
+
if i_level != 0:
|
809 |
+
h = self.up[i_level].upsample(h)
|
810 |
+
|
811 |
+
# end
|
812 |
+
h = self.norm_out(h)
|
813 |
+
h = nonlinearity(h)
|
814 |
+
h = self.conv_out(h)
|
815 |
+
return h
|
816 |
+
|
817 |
+
|
818 |
+
class Encoder(nn.Module):
|
819 |
+
|
820 |
+
def __init__(self,
|
821 |
+
ch,
|
822 |
+
num_res_blocks,
|
823 |
+
attn_resolutions,
|
824 |
+
in_channels,
|
825 |
+
resolution,
|
826 |
+
z_channels,
|
827 |
+
ch_mult=(1, 2, 4, 8),
|
828 |
+
dropout=0.0,
|
829 |
+
resamp_with_conv=True,
|
830 |
+
double_z=True):
|
831 |
+
super().__init__()
|
832 |
+
self.ch = ch
|
833 |
+
self.temb_ch = 0
|
834 |
+
self.num_resolutions = len(ch_mult)
|
835 |
+
self.num_res_blocks = num_res_blocks
|
836 |
+
self.resolution = resolution
|
837 |
+
self.in_channels = in_channels
|
838 |
+
|
839 |
+
# downsampling
|
840 |
+
self.conv_in = torch.nn.Conv2d(
|
841 |
+
in_channels, self.ch, kernel_size=3, stride=1, padding=1)
|
842 |
+
|
843 |
+
curr_res = resolution
|
844 |
+
in_ch_mult = (1, ) + tuple(ch_mult)
|
845 |
+
self.down = nn.ModuleList()
|
846 |
+
for i_level in range(self.num_resolutions):
|
847 |
+
block = nn.ModuleList()
|
848 |
+
attn = nn.ModuleList()
|
849 |
+
block_in = ch * in_ch_mult[i_level]
|
850 |
+
block_out = ch * ch_mult[i_level]
|
851 |
+
for i_block in range(self.num_res_blocks):
|
852 |
+
block.append(
|
853 |
+
ResnetBlock(
|
854 |
+
in_channels=block_in,
|
855 |
+
out_channels=block_out,
|
856 |
+
temb_channels=self.temb_ch,
|
857 |
+
dropout=dropout))
|
858 |
+
block_in = block_out
|
859 |
+
if curr_res in attn_resolutions:
|
860 |
+
attn.append(AttnBlock(block_in))
|
861 |
+
down = nn.Module()
|
862 |
+
down.block = block
|
863 |
+
down.attn = attn
|
864 |
+
if i_level != self.num_resolutions - 1:
|
865 |
+
down.downsample = Downsample(block_in, resamp_with_conv)
|
866 |
+
curr_res = curr_res // 2
|
867 |
+
self.down.append(down)
|
868 |
+
|
869 |
+
# middle
|
870 |
+
self.mid = nn.Module()
|
871 |
+
self.mid.block_1 = ResnetBlock(
|
872 |
+
in_channels=block_in,
|
873 |
+
out_channels=block_in,
|
874 |
+
temb_channels=self.temb_ch,
|
875 |
+
dropout=dropout)
|
876 |
+
self.mid.attn_1 = AttnBlock(block_in)
|
877 |
+
self.mid.block_2 = ResnetBlock(
|
878 |
+
in_channels=block_in,
|
879 |
+
out_channels=block_in,
|
880 |
+
temb_channels=self.temb_ch,
|
881 |
+
dropout=dropout)
|
882 |
+
|
883 |
+
# end
|
884 |
+
self.norm_out = Normalize(block_in)
|
885 |
+
self.conv_out = torch.nn.Conv2d(
|
886 |
+
block_in,
|
887 |
+
2 * z_channels if double_z else z_channels,
|
888 |
+
kernel_size=3,
|
889 |
+
stride=1,
|
890 |
+
padding=1)
|
891 |
+
|
892 |
+
def forward(self, x):
|
893 |
+
#assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution)
|
894 |
+
|
895 |
+
# timestep embedding
|
896 |
+
temb = None
|
897 |
+
|
898 |
+
# downsampling
|
899 |
+
hs = [self.conv_in(x)]
|
900 |
+
for i_level in range(self.num_resolutions):
|
901 |
+
for i_block in range(self.num_res_blocks):
|
902 |
+
h = self.down[i_level].block[i_block](hs[-1], temb)
|
903 |
+
if len(self.down[i_level].attn) > 0:
|
904 |
+
h = self.down[i_level].attn[i_block](h)
|
905 |
+
hs.append(h)
|
906 |
+
if i_level != self.num_resolutions - 1:
|
907 |
+
hs.append(self.down[i_level].downsample(hs[-1]))
|
908 |
+
|
909 |
+
# middle
|
910 |
+
h = hs[-1]
|
911 |
+
h = self.mid.block_1(h, temb)
|
912 |
+
h = self.mid.attn_1(h)
|
913 |
+
h = self.mid.block_2(h, temb)
|
914 |
+
|
915 |
+
# end
|
916 |
+
h = self.norm_out(h)
|
917 |
+
h = nonlinearity(h)
|
918 |
+
h = self.conv_out(h)
|
919 |
+
return h
|
920 |
+
|
921 |
+
|
922 |
+
class Decoder(nn.Module):
|
923 |
+
|
924 |
+
def __init__(self,
|
925 |
+
in_channels,
|
926 |
+
resolution,
|
927 |
+
z_channels,
|
928 |
+
ch,
|
929 |
+
out_ch,
|
930 |
+
num_res_blocks,
|
931 |
+
attn_resolutions,
|
932 |
+
ch_mult=(1, 2, 4, 8),
|
933 |
+
dropout=0.0,
|
934 |
+
resamp_with_conv=True,
|
935 |
+
give_pre_end=False):
|
936 |
+
super().__init__()
|
937 |
+
self.ch = ch
|
938 |
+
self.temb_ch = 0
|
939 |
+
self.num_resolutions = len(ch_mult)
|
940 |
+
self.num_res_blocks = num_res_blocks
|
941 |
+
self.resolution = resolution
|
942 |
+
self.in_channels = in_channels
|
943 |
+
self.give_pre_end = give_pre_end
|
944 |
+
|
945 |
+
# compute in_ch_mult, block_in and curr_res at lowest res
|
946 |
+
in_ch_mult = (1, ) + tuple(ch_mult)
|
947 |
+
block_in = ch * ch_mult[self.num_resolutions - 1]
|
948 |
+
curr_res = resolution // 2**(self.num_resolutions - 1)
|
949 |
+
self.z_shape = (1, z_channels, curr_res, curr_res // 2)
|
950 |
+
print("Working with z of shape {} = {} dimensions.".format(
|
951 |
+
self.z_shape, np.prod(self.z_shape)))
|
952 |
+
|
953 |
+
# z to block_in
|
954 |
+
self.conv_in = torch.nn.Conv2d(
|
955 |
+
z_channels, block_in, kernel_size=3, stride=1, padding=1)
|
956 |
+
|
957 |
+
# middle
|
958 |
+
self.mid = nn.Module()
|
959 |
+
self.mid.block_1 = ResnetBlock(
|
960 |
+
in_channels=block_in,
|
961 |
+
out_channels=block_in,
|
962 |
+
temb_channels=self.temb_ch,
|
963 |
+
dropout=dropout)
|
964 |
+
self.mid.attn_1 = AttnBlock(block_in)
|
965 |
+
self.mid.block_2 = ResnetBlock(
|
966 |
+
in_channels=block_in,
|
967 |
+
out_channels=block_in,
|
968 |
+
temb_channels=self.temb_ch,
|
969 |
+
dropout=dropout)
|
970 |
+
|
971 |
+
# upsampling
|
972 |
+
self.up = nn.ModuleList()
|
973 |
+
for i_level in reversed(range(self.num_resolutions)):
|
974 |
+
block = nn.ModuleList()
|
975 |
+
attn = nn.ModuleList()
|
976 |
+
block_out = ch * ch_mult[i_level]
|
977 |
+
for i_block in range(self.num_res_blocks + 1):
|
978 |
+
block.append(
|
979 |
+
ResnetBlock(
|
980 |
+
in_channels=block_in,
|
981 |
+
out_channels=block_out,
|
982 |
+
temb_channels=self.temb_ch,
|
983 |
+
dropout=dropout))
|
984 |
+
block_in = block_out
|
985 |
+
if curr_res in attn_resolutions:
|
986 |
+
attn.append(AttnBlock(block_in))
|
987 |
+
up = nn.Module()
|
988 |
+
up.block = block
|
989 |
+
up.attn = attn
|
990 |
+
if i_level != 0:
|
991 |
+
up.upsample = Upsample(block_in, resamp_with_conv)
|
992 |
+
curr_res = curr_res * 2
|
993 |
+
self.up.insert(0, up) # prepend to get consistent order
|
994 |
+
|
995 |
+
# end
|
996 |
+
self.norm_out = Normalize(block_in)
|
997 |
+
self.conv_out = torch.nn.Conv2d(
|
998 |
+
block_in, out_ch, kernel_size=3, stride=1, padding=1)
|
999 |
+
|
1000 |
+
def forward(self, z, bot_h=None):
|
1001 |
+
#assert z.shape[1:] == self.z_shape[1:]
|
1002 |
+
self.last_z_shape = z.shape
|
1003 |
+
|
1004 |
+
# timestep embedding
|
1005 |
+
temb = None
|
1006 |
+
|
1007 |
+
# z to block_in
|
1008 |
+
h = self.conv_in(z)
|
1009 |
+
|
1010 |
+
# middle
|
1011 |
+
h = self.mid.block_1(h, temb)
|
1012 |
+
h = self.mid.attn_1(h)
|
1013 |
+
h = self.mid.block_2(h, temb)
|
1014 |
+
|
1015 |
+
# upsampling
|
1016 |
+
for i_level in reversed(range(self.num_resolutions)):
|
1017 |
+
for i_block in range(self.num_res_blocks + 1):
|
1018 |
+
h = self.up[i_level].block[i_block](h, temb)
|
1019 |
+
if len(self.up[i_level].attn) > 0:
|
1020 |
+
h = self.up[i_level].attn[i_block](h)
|
1021 |
+
if i_level != 0:
|
1022 |
+
h = self.up[i_level].upsample(h)
|
1023 |
+
if i_level == 4 and bot_h is not None:
|
1024 |
+
h += bot_h
|
1025 |
+
|
1026 |
+
# end
|
1027 |
+
if self.give_pre_end:
|
1028 |
+
return h
|
1029 |
+
|
1030 |
+
h = self.norm_out(h)
|
1031 |
+
h = nonlinearity(h)
|
1032 |
+
h = self.conv_out(h)
|
1033 |
+
return h
|
1034 |
+
|
1035 |
+
def get_feature_top(self, z):
|
1036 |
+
#assert z.shape[1:] == self.z_shape[1:]
|
1037 |
+
self.last_z_shape = z.shape
|
1038 |
+
|
1039 |
+
# timestep embedding
|
1040 |
+
temb = None
|
1041 |
+
|
1042 |
+
# z to block_in
|
1043 |
+
h = self.conv_in(z)
|
1044 |
+
|
1045 |
+
# middle
|
1046 |
+
h = self.mid.block_1(h, temb)
|
1047 |
+
h = self.mid.attn_1(h)
|
1048 |
+
h = self.mid.block_2(h, temb)
|
1049 |
+
|
1050 |
+
# upsampling
|
1051 |
+
for i_level in reversed(range(self.num_resolutions)):
|
1052 |
+
for i_block in range(self.num_res_blocks + 1):
|
1053 |
+
h = self.up[i_level].block[i_block](h, temb)
|
1054 |
+
if len(self.up[i_level].attn) > 0:
|
1055 |
+
h = self.up[i_level].attn[i_block](h)
|
1056 |
+
if i_level != 0:
|
1057 |
+
h = self.up[i_level].upsample(h)
|
1058 |
+
if i_level == 4:
|
1059 |
+
return h
|
1060 |
+
|
1061 |
+
def get_feature_middle(self, z, mid_h):
|
1062 |
+
#assert z.shape[1:] == self.z_shape[1:]
|
1063 |
+
self.last_z_shape = z.shape
|
1064 |
+
|
1065 |
+
# timestep embedding
|
1066 |
+
temb = None
|
1067 |
+
|
1068 |
+
# z to block_in
|
1069 |
+
h = self.conv_in(z)
|
1070 |
+
|
1071 |
+
# middle
|
1072 |
+
h = self.mid.block_1(h, temb)
|
1073 |
+
h = self.mid.attn_1(h)
|
1074 |
+
h = self.mid.block_2(h, temb)
|
1075 |
+
|
1076 |
+
# upsampling
|
1077 |
+
for i_level in reversed(range(self.num_resolutions)):
|
1078 |
+
for i_block in range(self.num_res_blocks + 1):
|
1079 |
+
h = self.up[i_level].block[i_block](h, temb)
|
1080 |
+
if len(self.up[i_level].attn) > 0:
|
1081 |
+
h = self.up[i_level].attn[i_block](h)
|
1082 |
+
if i_level != 0:
|
1083 |
+
h = self.up[i_level].upsample(h)
|
1084 |
+
if i_level == 4:
|
1085 |
+
h += mid_h
|
1086 |
+
if i_level == 3:
|
1087 |
+
return h
|
1088 |
+
|
1089 |
+
|
1090 |
+
class DecoderRes(nn.Module):
|
1091 |
+
|
1092 |
+
def __init__(self,
|
1093 |
+
in_channels,
|
1094 |
+
resolution,
|
1095 |
+
z_channels,
|
1096 |
+
ch,
|
1097 |
+
num_res_blocks,
|
1098 |
+
ch_mult=(1, 2, 4, 8),
|
1099 |
+
dropout=0.0,
|
1100 |
+
give_pre_end=False):
|
1101 |
+
super().__init__()
|
1102 |
+
self.ch = ch
|
1103 |
+
self.temb_ch = 0
|
1104 |
+
self.num_resolutions = len(ch_mult)
|
1105 |
+
self.num_res_blocks = num_res_blocks
|
1106 |
+
self.resolution = resolution
|
1107 |
+
self.in_channels = in_channels
|
1108 |
+
self.give_pre_end = give_pre_end
|
1109 |
+
|
1110 |
+
# compute in_ch_mult, block_in and curr_res at lowest res
|
1111 |
+
in_ch_mult = (1, ) + tuple(ch_mult)
|
1112 |
+
block_in = ch * ch_mult[self.num_resolutions - 1]
|
1113 |
+
curr_res = resolution // 2**(self.num_resolutions - 1)
|
1114 |
+
self.z_shape = (1, z_channels, curr_res, curr_res // 2)
|
1115 |
+
print("Working with z of shape {} = {} dimensions.".format(
|
1116 |
+
self.z_shape, np.prod(self.z_shape)))
|
1117 |
+
|
1118 |
+
# z to block_in
|
1119 |
+
self.conv_in = torch.nn.Conv2d(
|
1120 |
+
z_channels, block_in, kernel_size=3, stride=1, padding=1)
|
1121 |
+
|
1122 |
+
# middle
|
1123 |
+
self.mid = nn.Module()
|
1124 |
+
self.mid.block_1 = ResnetBlock(
|
1125 |
+
in_channels=block_in,
|
1126 |
+
out_channels=block_in,
|
1127 |
+
temb_channels=self.temb_ch,
|
1128 |
+
dropout=dropout)
|
1129 |
+
self.mid.attn_1 = AttnBlock(block_in)
|
1130 |
+
self.mid.block_2 = ResnetBlock(
|
1131 |
+
in_channels=block_in,
|
1132 |
+
out_channels=block_in,
|
1133 |
+
temb_channels=self.temb_ch,
|
1134 |
+
dropout=dropout)
|
1135 |
+
|
1136 |
+
def forward(self, z):
|
1137 |
+
#assert z.shape[1:] == self.z_shape[1:]
|
1138 |
+
self.last_z_shape = z.shape
|
1139 |
+
|
1140 |
+
# timestep embedding
|
1141 |
+
temb = None
|
1142 |
+
|
1143 |
+
# z to block_in
|
1144 |
+
h = self.conv_in(z)
|
1145 |
+
|
1146 |
+
# middle
|
1147 |
+
h = self.mid.block_1(h, temb)
|
1148 |
+
h = self.mid.attn_1(h)
|
1149 |
+
h = self.mid.block_2(h, temb)
|
1150 |
+
|
1151 |
+
return h
|
1152 |
+
|
1153 |
+
|
1154 |
+
# patch based discriminator
|
1155 |
+
class Discriminator(nn.Module):
|
1156 |
+
|
1157 |
+
def __init__(self, nc, ndf, n_layers=3):
|
1158 |
+
super().__init__()
|
1159 |
+
|
1160 |
+
layers = [
|
1161 |
+
nn.Conv2d(nc, ndf, kernel_size=4, stride=2, padding=1),
|
1162 |
+
nn.LeakyReLU(0.2, True)
|
1163 |
+
]
|
1164 |
+
ndf_mult = 1
|
1165 |
+
ndf_mult_prev = 1
|
1166 |
+
for n in range(1,
|
1167 |
+
n_layers): # gradually increase the number of filters
|
1168 |
+
ndf_mult_prev = ndf_mult
|
1169 |
+
ndf_mult = min(2**n, 8)
|
1170 |
+
layers += [
|
1171 |
+
nn.Conv2d(
|
1172 |
+
ndf * ndf_mult_prev,
|
1173 |
+
ndf * ndf_mult,
|
1174 |
+
kernel_size=4,
|
1175 |
+
stride=2,
|
1176 |
+
padding=1,
|
1177 |
+
bias=False),
|
1178 |
+
nn.BatchNorm2d(ndf * ndf_mult),
|
1179 |
+
nn.LeakyReLU(0.2, True)
|
1180 |
+
]
|
1181 |
+
|
1182 |
+
ndf_mult_prev = ndf_mult
|
1183 |
+
ndf_mult = min(2**n_layers, 8)
|
1184 |
+
|
1185 |
+
layers += [
|
1186 |
+
nn.Conv2d(
|
1187 |
+
ndf * ndf_mult_prev,
|
1188 |
+
ndf * ndf_mult,
|
1189 |
+
kernel_size=4,
|
1190 |
+
stride=1,
|
1191 |
+
padding=1,
|
1192 |
+
bias=False),
|
1193 |
+
nn.BatchNorm2d(ndf * ndf_mult),
|
1194 |
+
nn.LeakyReLU(0.2, True)
|
1195 |
+
]
|
1196 |
+
|
1197 |
+
layers += [
|
1198 |
+
nn.Conv2d(ndf * ndf_mult, 1, kernel_size=4, stride=1, padding=1)
|
1199 |
+
] # output 1 channel prediction map
|
1200 |
+
self.main = nn.Sequential(*layers)
|
1201 |
+
|
1202 |
+
def forward(self, x):
|
1203 |
+
return self.main(x)
|
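A minimal usage sketch for the patch-based Discriminator defined above, assuming the module is importable as models.archs.vqgan_arch; the input size and argument values are illustrative, not taken from the configs:

import torch
from models.archs.vqgan_arch import Discriminator

disc = Discriminator(nc=3, ndf=64, n_layers=3)
dummy = torch.randn(1, 3, 256, 256)   # a fake RGB image
logits = disc(dummy)                  # per-patch real/fake logits
print(logits.shape)                   # torch.Size([1, 1, 30, 30])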
models/hierarchy_inference_model.py
ADDED
@@ -0,0 +1,363 @@
1 |
+
import logging
|
2 |
+
import math
|
3 |
+
from collections import OrderedDict
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from torchvision.utils import save_image
|
8 |
+
|
9 |
+
from models.archs.fcn_arch import MultiHeadFCNHead
|
10 |
+
from models.archs.unet_arch import UNet
|
11 |
+
from models.archs.vqgan_arch import (Decoder, DecoderRes, Encoder,
|
12 |
+
VectorQuantizerSpatialTextureAware,
|
13 |
+
VectorQuantizerTexture)
|
14 |
+
from models.losses.accuracy import accuracy
|
15 |
+
from models.losses.cross_entropy_loss import CrossEntropyLoss
|
16 |
+
|
17 |
+
logger = logging.getLogger('base')
|
18 |
+
|
19 |
+
|
20 |
+
class VQGANTextureAwareSpatialHierarchyInferenceModel():
|
21 |
+
|
22 |
+
def __init__(self, opt):
|
23 |
+
self.opt = opt
|
24 |
+
self.device = torch.device('cuda')
|
25 |
+
self.is_train = opt['is_train']
|
26 |
+
|
27 |
+
self.top_encoder = Encoder(
|
28 |
+
ch=opt['top_ch'],
|
29 |
+
num_res_blocks=opt['top_num_res_blocks'],
|
30 |
+
attn_resolutions=opt['top_attn_resolutions'],
|
31 |
+
ch_mult=opt['top_ch_mult'],
|
32 |
+
in_channels=opt['top_in_channels'],
|
33 |
+
resolution=opt['top_resolution'],
|
34 |
+
z_channels=opt['top_z_channels'],
|
35 |
+
double_z=opt['top_double_z'],
|
36 |
+
dropout=opt['top_dropout']).to(self.device)
|
37 |
+
self.decoder = Decoder(
|
38 |
+
in_channels=opt['top_in_channels'],
|
39 |
+
resolution=opt['top_resolution'],
|
40 |
+
z_channels=opt['top_z_channels'],
|
41 |
+
ch=opt['top_ch'],
|
42 |
+
out_ch=opt['top_out_ch'],
|
43 |
+
num_res_blocks=opt['top_num_res_blocks'],
|
44 |
+
attn_resolutions=opt['top_attn_resolutions'],
|
45 |
+
ch_mult=opt['top_ch_mult'],
|
46 |
+
dropout=opt['top_dropout'],
|
47 |
+
resamp_with_conv=True,
|
48 |
+
give_pre_end=False).to(self.device)
|
49 |
+
self.top_quantize = VectorQuantizerTexture(
|
50 |
+
1024, opt['embed_dim'], beta=0.25).to(self.device)
|
51 |
+
self.top_quant_conv = torch.nn.Conv2d(opt["top_z_channels"],
|
52 |
+
opt['embed_dim'],
|
53 |
+
1).to(self.device)
|
54 |
+
self.top_post_quant_conv = torch.nn.Conv2d(opt['embed_dim'],
|
55 |
+
opt["top_z_channels"],
|
56 |
+
1).to(self.device)
|
57 |
+
self.load_top_pretrain_models()
|
58 |
+
|
59 |
+
self.bot_encoder = Encoder(
|
60 |
+
ch=opt['bot_ch'],
|
61 |
+
num_res_blocks=opt['bot_num_res_blocks'],
|
62 |
+
attn_resolutions=opt['bot_attn_resolutions'],
|
63 |
+
ch_mult=opt['bot_ch_mult'],
|
64 |
+
in_channels=opt['bot_in_channels'],
|
65 |
+
resolution=opt['bot_resolution'],
|
66 |
+
z_channels=opt['bot_z_channels'],
|
67 |
+
double_z=opt['bot_double_z'],
|
68 |
+
dropout=opt['bot_dropout']).to(self.device)
|
69 |
+
self.bot_decoder_res = DecoderRes(
|
70 |
+
in_channels=opt['bot_in_channels'],
|
71 |
+
resolution=opt['bot_resolution'],
|
72 |
+
z_channels=opt['bot_z_channels'],
|
73 |
+
ch=opt['bot_ch'],
|
74 |
+
num_res_blocks=opt['bot_num_res_blocks'],
|
75 |
+
ch_mult=opt['bot_ch_mult'],
|
76 |
+
dropout=opt['bot_dropout'],
|
77 |
+
give_pre_end=False).to(self.device)
|
78 |
+
self.bot_quantize = VectorQuantizerSpatialTextureAware(
|
79 |
+
opt['bot_n_embed'],
|
80 |
+
opt['embed_dim'],
|
81 |
+
beta=0.25,
|
82 |
+
spatial_size=opt['codebook_spatial_size']).to(self.device)
|
83 |
+
self.bot_quant_conv = torch.nn.Conv2d(opt["bot_z_channels"],
|
84 |
+
opt['embed_dim'],
|
85 |
+
1).to(self.device)
|
86 |
+
self.bot_post_quant_conv = torch.nn.Conv2d(opt['embed_dim'],
|
87 |
+
opt["bot_z_channels"],
|
88 |
+
1).to(self.device)
|
89 |
+
|
90 |
+
self.load_bot_pretrain_network()
|
91 |
+
|
92 |
+
self.guidance_encoder = UNet(
|
93 |
+
in_channels=opt['encoder_in_channels']).to(self.device)
|
94 |
+
self.index_decoder = MultiHeadFCNHead(
|
95 |
+
in_channels=opt['fc_in_channels'],
|
96 |
+
in_index=opt['fc_in_index'],
|
97 |
+
channels=opt['fc_channels'],
|
98 |
+
num_convs=opt['fc_num_convs'],
|
99 |
+
concat_input=opt['fc_concat_input'],
|
100 |
+
dropout_ratio=opt['fc_dropout_ratio'],
|
101 |
+
num_classes=opt['fc_num_classes'],
|
102 |
+
align_corners=opt['fc_align_corners'],
|
103 |
+
num_head=18).to(self.device)
|
104 |
+
|
105 |
+
self.init_training_settings()
|
106 |
+
|
107 |
+
def init_training_settings(self):
|
108 |
+
optim_params = []
|
109 |
+
for v in self.guidance_encoder.parameters():
|
110 |
+
if v.requires_grad:
|
111 |
+
optim_params.append(v)
|
112 |
+
for v in self.index_decoder.parameters():
|
113 |
+
if v.requires_grad:
|
114 |
+
optim_params.append(v)
|
115 |
+
# set up optimizers
|
116 |
+
if self.opt['optimizer'] == 'Adam':
|
117 |
+
self.optimizer = torch.optim.Adam(
|
118 |
+
optim_params,
|
119 |
+
self.opt['lr'],
|
120 |
+
weight_decay=self.opt['weight_decay'])
|
121 |
+
elif self.opt['optimizer'] == 'SGD':
|
122 |
+
self.optimizer = torch.optim.SGD(
|
123 |
+
optim_params,
|
124 |
+
self.opt['lr'],
|
125 |
+
momentum=self.opt['momentum'],
|
126 |
+
weight_decay=self.opt['weight_decay'])
|
127 |
+
self.log_dict = OrderedDict()
|
128 |
+
if self.opt['loss_function'] == 'cross_entropy':
|
129 |
+
self.loss_func = CrossEntropyLoss().to(self.device)
|
130 |
+
|
131 |
+
def load_top_pretrain_models(self):
|
132 |
+
# load pretrained vqgan for segmentation mask
|
133 |
+
top_vae_checkpoint = torch.load(self.opt['top_vae_path'])
|
134 |
+
self.top_encoder.load_state_dict(
|
135 |
+
top_vae_checkpoint['encoder'], strict=True)
|
136 |
+
self.decoder.load_state_dict(
|
137 |
+
top_vae_checkpoint['decoder'], strict=True)
|
138 |
+
self.top_quantize.load_state_dict(
|
139 |
+
top_vae_checkpoint['quantize'], strict=True)
|
140 |
+
self.top_quant_conv.load_state_dict(
|
141 |
+
top_vae_checkpoint['quant_conv'], strict=True)
|
142 |
+
self.top_post_quant_conv.load_state_dict(
|
143 |
+
top_vae_checkpoint['post_quant_conv'], strict=True)
|
144 |
+
self.top_encoder.eval()
|
145 |
+
self.top_quantize.eval()
|
146 |
+
self.top_quant_conv.eval()
|
147 |
+
self.top_post_quant_conv.eval()
|
148 |
+
|
149 |
+
def load_bot_pretrain_network(self):
|
150 |
+
checkpoint = torch.load(self.opt['bot_vae_path'])
|
151 |
+
self.bot_encoder.load_state_dict(
|
152 |
+
checkpoint['bot_encoder'], strict=True)
|
153 |
+
self.bot_decoder_res.load_state_dict(
|
154 |
+
checkpoint['bot_decoder_res'], strict=True)
|
155 |
+
self.decoder.load_state_dict(checkpoint['decoder'], strict=True)
|
156 |
+
self.bot_quantize.load_state_dict(
|
157 |
+
checkpoint['bot_quantize'], strict=True)
|
158 |
+
self.bot_quant_conv.load_state_dict(
|
159 |
+
checkpoint['bot_quant_conv'], strict=True)
|
160 |
+
self.bot_post_quant_conv.load_state_dict(
|
161 |
+
checkpoint['bot_post_quant_conv'], strict=True)
|
162 |
+
|
163 |
+
self.bot_encoder.eval()
|
164 |
+
self.bot_decoder_res.eval()
|
165 |
+
self.decoder.eval()
|
166 |
+
self.bot_quantize.eval()
|
167 |
+
self.bot_quant_conv.eval()
|
168 |
+
self.bot_post_quant_conv.eval()
|
169 |
+
|
170 |
+
def top_encode(self, x, mask):
|
171 |
+
h = self.top_encoder(x)
|
172 |
+
h = self.top_quant_conv(h)
|
173 |
+
quant, _, _ = self.top_quantize(h, mask)
|
174 |
+
quant = self.top_post_quant_conv(quant)
|
175 |
+
|
176 |
+
return quant, quant
|
177 |
+
|
178 |
+
def feed_data(self, data):
|
179 |
+
self.image = data['image'].to(self.device)
|
180 |
+
self.texture_mask = data['texture_mask'].float().to(self.device)
|
181 |
+
self.get_gt_indices()
|
182 |
+
|
183 |
+
self.texture_tokens = F.interpolate(
|
184 |
+
self.texture_mask, size=(32, 16),
|
185 |
+
mode='nearest').view(self.image.size(0), -1).long()
|
186 |
+
|
187 |
+
def bot_encode(self, x, mask):
|
188 |
+
h = self.bot_encoder(x)
|
189 |
+
h = self.bot_quant_conv(h)
|
190 |
+
_, _, (_, _, indices_list) = self.bot_quantize(h, mask)
|
191 |
+
|
192 |
+
return indices_list
|
193 |
+
|
194 |
+
def get_gt_indices(self):
|
195 |
+
self.quant_t, self.feature_t = self.top_encode(self.image,
|
196 |
+
self.texture_mask)
|
197 |
+
self.gt_indices_list = self.bot_encode(self.image, self.texture_mask)
|
198 |
+
|
199 |
+
def index_to_image(self, index_bottom_list, texture_mask):
|
200 |
+
quant_b = self.bot_quantize.get_codebook_entry(
|
201 |
+
index_bottom_list, texture_mask,
|
202 |
+
(index_bottom_list[0].size(0), index_bottom_list[0].size(1),
|
203 |
+
index_bottom_list[0].size(2),
|
204 |
+
self.opt["bot_z_channels"])) #.permute(0, 3, 1, 2)
|
205 |
+
quant_b = self.bot_post_quant_conv(quant_b)
|
206 |
+
bot_dec_res = self.bot_decoder_res(quant_b)
|
207 |
+
|
208 |
+
dec = self.decoder(self.quant_t, bot_h=bot_dec_res)
|
209 |
+
|
210 |
+
return dec
|
211 |
+
|
212 |
+
def get_vis(self, pred_img_index, rec_img_index, texture_mask, save_path):
|
213 |
+
rec_img = self.index_to_image(rec_img_index, texture_mask)
|
214 |
+
pred_img = self.index_to_image(pred_img_index, texture_mask)
|
215 |
+
|
216 |
+
base_img = self.decoder(self.quant_t)
|
217 |
+
img_cat = torch.cat([
|
218 |
+
self.image,
|
219 |
+
rec_img,
|
220 |
+
base_img,
|
221 |
+
pred_img,
|
222 |
+
], dim=3).detach()
|
223 |
+
img_cat = ((img_cat + 1) / 2)
|
224 |
+
img_cat = img_cat.clamp_(0, 1)
|
225 |
+
save_image(img_cat, save_path, nrow=1, padding=4)
|
226 |
+
|
227 |
+
def optimize_parameters(self):
|
228 |
+
self.guidance_encoder.train()
|
229 |
+
self.index_decoder.train()
|
230 |
+
|
231 |
+
self.feature_enc = self.guidance_encoder(self.feature_t)
|
232 |
+
self.memory_logits_list = self.index_decoder(self.feature_enc)
|
233 |
+
|
234 |
+
loss = 0
|
235 |
+
for i in range(18):
|
236 |
+
loss += self.loss_func(
|
237 |
+
self.memory_logits_list[i],
|
238 |
+
self.gt_indices_list[i],
|
239 |
+
ignore_index=-1)
|
240 |
+
|
241 |
+
self.optimizer.zero_grad()
|
242 |
+
loss.backward()
|
243 |
+
self.optimizer.step()
|
244 |
+
|
245 |
+
self.log_dict['loss_total'] = loss
|
246 |
+
|
247 |
+
def inference(self, data_loader, save_dir):
|
248 |
+
self.guidance_encoder.eval()
|
249 |
+
self.index_decoder.eval()
|
250 |
+
|
251 |
+
acc = 0
|
252 |
+
num = 0
|
253 |
+
|
254 |
+
for _, data in enumerate(data_loader):
|
255 |
+
self.feed_data(data)
|
256 |
+
img_name = data['img_name']
|
257 |
+
|
258 |
+
num += self.image.size(0)
|
259 |
+
|
260 |
+
texture_mask_flatten = self.texture_tokens.view(-1)
|
261 |
+
min_encodings_indices_list = [
|
262 |
+
torch.full(
|
263 |
+
texture_mask_flatten.size(),
|
264 |
+
fill_value=-1,
|
265 |
+
dtype=torch.long,
|
266 |
+
device=texture_mask_flatten.device) for _ in range(18)
|
267 |
+
]
|
268 |
+
with torch.no_grad():
|
269 |
+
self.feature_enc = self.guidance_encoder(self.feature_t)
|
270 |
+
memory_logits_list = self.index_decoder(self.feature_enc)
|
271 |
+
# memory_indices_pred = memory_logits.argmax(dim=1)
|
272 |
+
batch_acc = 0
|
273 |
+
for codebook_idx, memory_logits in enumerate(memory_logits_list):
|
274 |
+
region_of_interest = texture_mask_flatten == codebook_idx
|
275 |
+
if torch.sum(region_of_interest) > 0:
|
276 |
+
memory_indices_pred = memory_logits.argmax(dim=1).view(-1)
|
277 |
+
batch_acc += torch.sum(
|
278 |
+
memory_indices_pred[region_of_interest] ==
|
279 |
+
self.gt_indices_list[codebook_idx].view(
|
280 |
+
-1)[region_of_interest])
|
281 |
+
memory_indices_pred = memory_indices_pred
|
282 |
+
min_encodings_indices_list[codebook_idx][
|
283 |
+
region_of_interest] = memory_indices_pred[
|
284 |
+
region_of_interest]
|
285 |
+
min_encodings_indices_return_list = [
|
286 |
+
min_encodings_indices.view(self.gt_indices_list[0].size())
|
287 |
+
for min_encodings_indices in min_encodings_indices_list
|
288 |
+
]
|
289 |
+
batch_acc = batch_acc / self.gt_indices_list[codebook_idx].numel(
|
290 |
+
) * self.image.size(0)
|
291 |
+
acc += batch_acc
|
292 |
+
self.get_vis(min_encodings_indices_return_list,
|
293 |
+
self.gt_indices_list, self.texture_mask,
|
294 |
+
f'{save_dir}/{img_name[0]}')
|
295 |
+
|
296 |
+
self.guidance_encoder.train()
|
297 |
+
self.index_decoder.train()
|
298 |
+
return (acc / num).item()
|
299 |
+
|
300 |
+
def load_network(self):
|
301 |
+
checkpoint = torch.load(self.opt['pretrained_models'])
|
302 |
+
self.guidance_encoder.load_state_dict(
|
303 |
+
checkpoint['guidance_encoder'], strict=True)
|
304 |
+
self.guidance_encoder.eval()
|
305 |
+
|
306 |
+
self.index_decoder.load_state_dict(
|
307 |
+
checkpoint['index_decoder'], strict=True)
|
308 |
+
self.index_decoder.eval()
|
309 |
+
|
310 |
+
def save_network(self, save_path):
|
311 |
+
"""Save networks.
|
312 |
+
|
313 |
+
Args:
|
314 |
+
net (nn.Module): Network to be saved.
|
315 |
+
net_label (str): Network label.
|
316 |
+
current_iter (int): Current iter number.
|
317 |
+
"""
|
318 |
+
|
319 |
+
save_dict = {}
|
320 |
+
save_dict['guidance_encoder'] = self.guidance_encoder.state_dict()
|
321 |
+
save_dict['index_decoder'] = self.index_decoder.state_dict()
|
322 |
+
|
323 |
+
torch.save(save_dict, save_path)
|
324 |
+
|
325 |
+
def update_learning_rate(self, epoch):
|
326 |
+
"""Update learning rate.
|
327 |
+
|
328 |
+
Args:
|
329 |
+
current_iter (int): Current iteration.
|
330 |
+
warmup_iter (int): Warmup iter numbers. -1 for no warmup.
|
331 |
+
Default: -1.
|
332 |
+
"""
|
333 |
+
lr = self.optimizer.param_groups[0]['lr']
|
334 |
+
|
335 |
+
if self.opt['lr_decay'] == 'step':
|
336 |
+
lr = self.opt['lr'] * (
|
337 |
+
self.opt['gamma']**(epoch // self.opt['step']))
|
338 |
+
elif self.opt['lr_decay'] == 'cos':
|
339 |
+
lr = self.opt['lr'] * (
|
340 |
+
1 + math.cos(math.pi * epoch / self.opt['num_epochs'])) / 2
|
341 |
+
elif self.opt['lr_decay'] == 'linear':
|
342 |
+
lr = self.opt['lr'] * (1 - epoch / self.opt['num_epochs'])
|
343 |
+
elif self.opt['lr_decay'] == 'linear2exp':
|
344 |
+
if epoch < self.opt['turning_point'] + 1:
|
345 |
+
# decay the learning rate to 95%
|
346 |
+
# at the turning point (1 / 95% = 1.0526)
|
347 |
+
lr = self.opt['lr'] * (
|
348 |
+
1 - epoch / int(self.opt['turning_point'] * 1.0526))
|
349 |
+
else:
|
350 |
+
lr *= self.opt['gamma']
|
351 |
+
elif self.opt['lr_decay'] == 'schedule':
|
352 |
+
if epoch in self.opt['schedule']:
|
353 |
+
lr *= self.opt['gamma']
|
354 |
+
else:
|
355 |
+
raise ValueError('Unknown lr mode {}'.format(self.opt['lr_decay']))
|
356 |
+
# set learning rate
|
357 |
+
for param_group in self.optimizer.param_groups:
|
358 |
+
param_group['lr'] = lr
|
359 |
+
|
360 |
+
return lr
|
361 |
+
|
362 |
+
def get_current_log(self):
|
363 |
+
return self.log_dict
|
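For orientation, the 'cos' branch of update_learning_rate above anneals the base learning rate to zero over num_epochs; a standalone restatement of just that formula (illustrative, not part of the file):

import math

def cosine_lr(base_lr, epoch, num_epochs):
    # lr = base_lr * (1 + cos(pi * epoch / num_epochs)) / 2
    return base_lr * (1 + math.cos(math.pi * epoch / num_epochs)) / 2

# e.g. with base_lr = 1e-4: epoch 0 -> 1e-4, epoch = num_epochs / 2 -> 5e-5, epoch = num_epochs -> 0.0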
models/hierarchy_vqgan_model.py
ADDED
@@ -0,0 +1,374 @@
1 |
+
import math
|
2 |
+
import sys
|
3 |
+
from collections import OrderedDict
|
4 |
+
|
5 |
+
sys.path.append('..')
|
6 |
+
import lpips
|
7 |
+
import torch
|
8 |
+
import torch.nn.functional as F
|
9 |
+
from torchvision.utils import save_image
|
10 |
+
|
11 |
+
from models.archs.vqgan_arch import (Decoder, DecoderRes, Discriminator,
|
12 |
+
Encoder,
|
13 |
+
VectorQuantizerSpatialTextureAware,
|
14 |
+
VectorQuantizerTexture)
|
15 |
+
from models.losses.vqgan_loss import (DiffAugment, adopt_weight,
|
16 |
+
calculate_adaptive_weight, hinge_d_loss)
|
17 |
+
|
18 |
+
|
19 |
+
class HierarchyVQSpatialTextureAwareModel():
|
20 |
+
|
21 |
+
def __init__(self, opt):
|
22 |
+
self.opt = opt
|
23 |
+
self.device = torch.device('cuda')
|
24 |
+
self.top_encoder = Encoder(
|
25 |
+
ch=opt['top_ch'],
|
26 |
+
num_res_blocks=opt['top_num_res_blocks'],
|
27 |
+
attn_resolutions=opt['top_attn_resolutions'],
|
28 |
+
ch_mult=opt['top_ch_mult'],
|
29 |
+
in_channels=opt['top_in_channels'],
|
30 |
+
resolution=opt['top_resolution'],
|
31 |
+
z_channels=opt['top_z_channels'],
|
32 |
+
double_z=opt['top_double_z'],
|
33 |
+
dropout=opt['top_dropout']).to(self.device)
|
34 |
+
self.decoder = Decoder(
|
35 |
+
in_channels=opt['top_in_channels'],
|
36 |
+
resolution=opt['top_resolution'],
|
37 |
+
z_channels=opt['top_z_channels'],
|
38 |
+
ch=opt['top_ch'],
|
39 |
+
out_ch=opt['top_out_ch'],
|
40 |
+
num_res_blocks=opt['top_num_res_blocks'],
|
41 |
+
attn_resolutions=opt['top_attn_resolutions'],
|
42 |
+
ch_mult=opt['top_ch_mult'],
|
43 |
+
dropout=opt['top_dropout'],
|
44 |
+
resamp_with_conv=True,
|
45 |
+
give_pre_end=False).to(self.device)
|
46 |
+
self.top_quantize = VectorQuantizerTexture(
|
47 |
+
1024, opt['embed_dim'], beta=0.25).to(self.device)
|
48 |
+
self.top_quant_conv = torch.nn.Conv2d(opt["top_z_channels"],
|
49 |
+
opt['embed_dim'],
|
50 |
+
1).to(self.device)
|
51 |
+
self.top_post_quant_conv = torch.nn.Conv2d(opt['embed_dim'],
|
52 |
+
opt["top_z_channels"],
|
53 |
+
1).to(self.device)
|
54 |
+
self.load_top_pretrain_models()
|
55 |
+
|
56 |
+
self.bot_encoder = Encoder(
|
57 |
+
ch=opt['bot_ch'],
|
58 |
+
num_res_blocks=opt['bot_num_res_blocks'],
|
59 |
+
attn_resolutions=opt['bot_attn_resolutions'],
|
60 |
+
ch_mult=opt['bot_ch_mult'],
|
61 |
+
in_channels=opt['bot_in_channels'],
|
62 |
+
resolution=opt['bot_resolution'],
|
63 |
+
z_channels=opt['bot_z_channels'],
|
64 |
+
double_z=opt['bot_double_z'],
|
65 |
+
dropout=opt['bot_dropout']).to(self.device)
|
66 |
+
self.bot_decoder_res = DecoderRes(
|
67 |
+
in_channels=opt['bot_in_channels'],
|
68 |
+
resolution=opt['bot_resolution'],
|
69 |
+
z_channels=opt['bot_z_channels'],
|
70 |
+
ch=opt['bot_ch'],
|
71 |
+
num_res_blocks=opt['bot_num_res_blocks'],
|
72 |
+
ch_mult=opt['bot_ch_mult'],
|
73 |
+
dropout=opt['bot_dropout'],
|
74 |
+
give_pre_end=False).to(self.device)
|
75 |
+
self.bot_quantize = VectorQuantizerSpatialTextureAware(
|
76 |
+
opt['bot_n_embed'],
|
77 |
+
opt['embed_dim'],
|
78 |
+
beta=0.25,
|
79 |
+
spatial_size=opt['codebook_spatial_size']).to(self.device)
|
80 |
+
self.bot_quant_conv = torch.nn.Conv2d(opt["bot_z_channels"],
|
81 |
+
opt['embed_dim'],
|
82 |
+
1).to(self.device)
|
83 |
+
self.bot_post_quant_conv = torch.nn.Conv2d(opt['embed_dim'],
|
84 |
+
opt["bot_z_channels"],
|
85 |
+
1).to(self.device)
|
86 |
+
|
87 |
+
self.disc = Discriminator(
|
88 |
+
opt['n_channels'], opt['ndf'],
|
89 |
+
n_layers=opt['disc_layers']).to(self.device)
|
90 |
+
self.perceptual = lpips.LPIPS(net="vgg").to(self.device)
|
91 |
+
self.perceptual_weight = opt['perceptual_weight']
|
92 |
+
self.disc_start_step = opt['disc_start_step']
|
93 |
+
self.disc_weight_max = opt['disc_weight_max']
|
94 |
+
self.diff_aug = opt['diff_aug']
|
95 |
+
self.policy = "color,translation"
|
96 |
+
|
97 |
+
self.load_discriminator_models()
|
98 |
+
|
99 |
+
self.disc.train()
|
100 |
+
|
101 |
+
self.fix_decoder = opt['fix_decoder']
|
102 |
+
|
103 |
+
self.init_training_settings()
|
104 |
+
|
105 |
+
def load_top_pretrain_models(self):
|
106 |
+
# load pretrained vqgan for segmentation mask
|
107 |
+
top_vae_checkpoint = torch.load(self.opt['top_vae_path'])
|
108 |
+
self.top_encoder.load_state_dict(
|
109 |
+
top_vae_checkpoint['encoder'], strict=True)
|
110 |
+
self.decoder.load_state_dict(
|
111 |
+
top_vae_checkpoint['decoder'], strict=True)
|
112 |
+
self.top_quantize.load_state_dict(
|
113 |
+
top_vae_checkpoint['quantize'], strict=True)
|
114 |
+
self.top_quant_conv.load_state_dict(
|
115 |
+
top_vae_checkpoint['quant_conv'], strict=True)
|
116 |
+
self.top_post_quant_conv.load_state_dict(
|
117 |
+
top_vae_checkpoint['post_quant_conv'], strict=True)
|
118 |
+
self.top_encoder.eval()
|
119 |
+
self.top_quantize.eval()
|
120 |
+
self.top_quant_conv.eval()
|
121 |
+
self.top_post_quant_conv.eval()
|
122 |
+
|
123 |
+
def init_training_settings(self):
|
124 |
+
self.log_dict = OrderedDict()
|
125 |
+
self.configure_optimizers()
|
126 |
+
|
127 |
+
def configure_optimizers(self):
|
128 |
+
optim_params = []
|
129 |
+
for v in self.bot_encoder.parameters():
|
130 |
+
if v.requires_grad:
|
131 |
+
optim_params.append(v)
|
132 |
+
for v in self.bot_decoder_res.parameters():
|
133 |
+
if v.requires_grad:
|
134 |
+
optim_params.append(v)
|
135 |
+
for v in self.bot_quantize.parameters():
|
136 |
+
if v.requires_grad:
|
137 |
+
optim_params.append(v)
|
138 |
+
for v in self.bot_quant_conv.parameters():
|
139 |
+
if v.requires_grad:
|
140 |
+
optim_params.append(v)
|
141 |
+
for v in self.bot_post_quant_conv.parameters():
|
142 |
+
if v.requires_grad:
|
143 |
+
optim_params.append(v)
|
144 |
+
if not self.fix_decoder:
|
145 |
+
for name, v in self.decoder.named_parameters():
|
146 |
+
if v.requires_grad:
|
147 |
+
if 'up.0' in name:
|
148 |
+
optim_params.append(v)
|
149 |
+
if 'up.1' in name:
|
150 |
+
optim_params.append(v)
|
151 |
+
if 'up.2' in name:
|
152 |
+
optim_params.append(v)
|
153 |
+
if 'up.3' in name:
|
154 |
+
optim_params.append(v)
|
155 |
+
|
156 |
+
self.optimizer = torch.optim.Adam(optim_params, lr=self.opt['lr'])
|
157 |
+
|
158 |
+
self.disc_optimizer = torch.optim.Adam(
|
159 |
+
self.disc.parameters(), lr=self.opt['lr'])
|
160 |
+
|
161 |
+
def load_discriminator_models(self):
|
162 |
+
# load pretrained vqgan for segmentation mask
|
163 |
+
top_vae_checkpoint = torch.load(self.opt['top_vae_path'])
|
164 |
+
self.disc.load_state_dict(
|
165 |
+
top_vae_checkpoint['discriminator'], strict=True)
|
166 |
+
|
167 |
+
def save_network(self, save_path):
|
168 |
+
"""Save networks.
|
169 |
+
"""
|
170 |
+
|
171 |
+
save_dict = {}
|
172 |
+
save_dict['bot_encoder'] = self.bot_encoder.state_dict()
|
173 |
+
save_dict['bot_decoder_res'] = self.bot_decoder_res.state_dict()
|
174 |
+
save_dict['decoder'] = self.decoder.state_dict()
|
175 |
+
save_dict['bot_quantize'] = self.bot_quantize.state_dict()
|
176 |
+
save_dict['bot_quant_conv'] = self.bot_quant_conv.state_dict()
|
177 |
+
save_dict['bot_post_quant_conv'] = self.bot_post_quant_conv.state_dict(
|
178 |
+
)
|
179 |
+
save_dict['discriminator'] = self.disc.state_dict()
|
180 |
+
torch.save(save_dict, save_path)
|
181 |
+
|
182 |
+
def load_network(self):
|
183 |
+
checkpoint = torch.load(self.opt['pretrained_models'])
|
184 |
+
self.bot_encoder.load_state_dict(
|
185 |
+
checkpoint['bot_encoder'], strict=True)
|
186 |
+
self.bot_decoder_res.load_state_dict(
|
187 |
+
checkpoint['bot_decoder_res'], strict=True)
|
188 |
+
self.decoder.load_state_dict(checkpoint['decoder'], strict=True)
|
189 |
+
self.bot_quantize.load_state_dict(
|
190 |
+
checkpoint['bot_quantize'], strict=True)
|
191 |
+
self.bot_quant_conv.load_state_dict(
|
192 |
+
checkpoint['bot_quant_conv'], strict=True)
|
193 |
+
self.bot_post_quant_conv.load_state_dict(
|
194 |
+
checkpoint['bot_post_quant_conv'], strict=True)
|
195 |
+
|
196 |
+
def optimize_parameters(self, data, step):
|
197 |
+
self.bot_encoder.train()
|
198 |
+
self.bot_decoder_res.train()
|
199 |
+
if not self.fix_decoder:
|
200 |
+
self.decoder.train()
|
201 |
+
self.bot_quantize.train()
|
202 |
+
self.bot_quant_conv.train()
|
203 |
+
self.bot_post_quant_conv.train()
|
204 |
+
|
205 |
+
loss, d_loss = self.training_step(data, step)
|
206 |
+
self.optimizer.zero_grad()
|
207 |
+
loss.backward()
|
208 |
+
self.optimizer.step()
|
209 |
+
|
210 |
+
if step > self.disc_start_step:
|
211 |
+
self.disc_optimizer.zero_grad()
|
212 |
+
d_loss.backward()
|
213 |
+
self.disc_optimizer.step()
|
214 |
+
|
215 |
+
def top_encode(self, x, mask):
|
216 |
+
h = self.top_encoder(x)
|
217 |
+
h = self.top_quant_conv(h)
|
218 |
+
quant, _, _ = self.top_quantize(h, mask)
|
219 |
+
quant = self.top_post_quant_conv(quant)
|
220 |
+
return quant
|
221 |
+
|
222 |
+
def bot_encode(self, x, mask):
|
223 |
+
h = self.bot_encoder(x)
|
224 |
+
h = self.bot_quant_conv(h)
|
225 |
+
quant, emb_loss, info = self.bot_quantize(h, mask)
|
226 |
+
quant = self.bot_post_quant_conv(quant)
|
227 |
+
bot_dec_res = self.bot_decoder_res(quant)
|
228 |
+
return bot_dec_res, emb_loss, info
|
229 |
+
|
230 |
+
def decode(self, quant_top, bot_dec_res):
|
231 |
+
dec = self.decoder(quant_top, bot_h=bot_dec_res)
|
232 |
+
return dec
|
233 |
+
|
234 |
+
def forward_step(self, input, mask):
|
235 |
+
with torch.no_grad():
|
236 |
+
quant_top = self.top_encode(input, mask)
|
237 |
+
bot_dec_res, diff, _ = self.bot_encode(input, mask)
|
238 |
+
dec = self.decode(quant_top, bot_dec_res)
|
239 |
+
return dec, diff
|
240 |
+
|
241 |
+
def feed_data(self, data):
|
242 |
+
x = data['image'].float().to(self.device)
|
243 |
+
mask = data['texture_mask'].float().to(self.device)
|
244 |
+
|
245 |
+
return x, mask
|
246 |
+
|
247 |
+
def training_step(self, data, step):
|
248 |
+
x, mask = self.feed_data(data)
|
249 |
+
xrec, codebook_loss = self.forward_step(x, mask)
|
250 |
+
|
251 |
+
# get recon/perceptual loss
|
252 |
+
recon_loss = torch.abs(x.contiguous() - xrec.contiguous())
|
253 |
+
p_loss = self.perceptual(x.contiguous(), xrec.contiguous())
|
254 |
+
nll_loss = recon_loss + self.perceptual_weight * p_loss
|
255 |
+
nll_loss = torch.mean(nll_loss)
|
256 |
+
|
257 |
+
# augment for input to discriminator
|
258 |
+
if self.diff_aug:
|
259 |
+
xrec = DiffAugment(xrec, policy=self.policy)
|
260 |
+
|
261 |
+
# update generator
|
262 |
+
logits_fake = self.disc(xrec)
|
263 |
+
g_loss = -torch.mean(logits_fake)
|
264 |
+
last_layer = self.decoder.conv_out.weight
|
265 |
+
d_weight = calculate_adaptive_weight(nll_loss, g_loss, last_layer,
|
266 |
+
self.disc_weight_max)
|
267 |
+
d_weight *= adopt_weight(1, step, self.disc_start_step)
|
268 |
+
loss = nll_loss + d_weight * g_loss + codebook_loss
|
269 |
+
|
270 |
+
self.log_dict["loss"] = loss
|
271 |
+
self.log_dict["l1"] = recon_loss.mean().item()
|
272 |
+
self.log_dict["perceptual"] = p_loss.mean().item()
|
273 |
+
self.log_dict["nll_loss"] = nll_loss.item()
|
274 |
+
self.log_dict["g_loss"] = g_loss.item()
|
275 |
+
self.log_dict["d_weight"] = d_weight
|
276 |
+
self.log_dict["codebook_loss"] = codebook_loss.item()
|
277 |
+
|
278 |
+
if step > self.disc_start_step:
|
279 |
+
if self.diff_aug:
|
280 |
+
logits_real = self.disc(
|
281 |
+
DiffAugment(x.contiguous().detach(), policy=self.policy))
|
282 |
+
else:
|
283 |
+
logits_real = self.disc(x.contiguous().detach())
|
284 |
+
logits_fake = self.disc(xrec.contiguous().detach(
|
285 |
+
)) # detach so that generator isn't also updated
|
286 |
+
d_loss = hinge_d_loss(logits_real, logits_fake)
|
287 |
+
self.log_dict["d_loss"] = d_loss
|
288 |
+
else:
|
289 |
+
d_loss = None
|
290 |
+
|
291 |
+
return loss, d_loss
|
292 |
+
|
293 |
+
@torch.no_grad()
|
294 |
+
def inference(self, data_loader, save_dir):
|
295 |
+
self.bot_encoder.eval()
|
296 |
+
self.bot_decoder_res.eval()
|
297 |
+
self.decoder.eval()
|
298 |
+
self.bot_quantize.eval()
|
299 |
+
self.bot_quant_conv.eval()
|
300 |
+
self.bot_post_quant_conv.eval()
|
301 |
+
|
302 |
+
loss_total = 0
|
303 |
+
num = 0
|
304 |
+
|
305 |
+
for _, data in enumerate(data_loader):
|
306 |
+
img_name = data['img_name'][0]
|
307 |
+
x, mask = self.feed_data(data)
|
308 |
+
xrec, _ = self.forward_step(x, mask)
|
309 |
+
|
310 |
+
recon_loss = torch.abs(x.contiguous() - xrec.contiguous())
|
311 |
+
p_loss = self.perceptual(x.contiguous(), xrec.contiguous())
|
312 |
+
nll_loss = recon_loss + self.perceptual_weight * p_loss
|
313 |
+
nll_loss = torch.mean(nll_loss)
|
314 |
+
loss_total += nll_loss
|
315 |
+
|
316 |
+
num += x.size(0)
|
317 |
+
|
318 |
+
if x.shape[1] > 3:
|
319 |
+
# colorize with random projection
|
320 |
+
assert xrec.shape[1] > 3
|
321 |
+
# convert logits to indices
|
322 |
+
xrec = torch.argmax(xrec, dim=1, keepdim=True)
|
323 |
+
xrec = F.one_hot(xrec, num_classes=x.shape[1])
|
324 |
+
xrec = xrec.squeeze(1).permute(0, 3, 1, 2).float()
|
325 |
+
x = self.to_rgb(x)
|
326 |
+
xrec = self.to_rgb(xrec)
|
327 |
+
|
328 |
+
img_cat = torch.cat([x, xrec], dim=3).detach()
|
329 |
+
img_cat = ((img_cat + 1) / 2)
|
330 |
+
img_cat = img_cat.clamp_(0, 1)
|
331 |
+
save_image(
|
332 |
+
img_cat, f'{save_dir}/{img_name}.png', nrow=1, padding=4)
|
333 |
+
|
334 |
+
return (loss_total / num).item()
|
335 |
+
|
336 |
+
def get_current_log(self):
|
337 |
+
return self.log_dict
|
338 |
+
|
339 |
+
def update_learning_rate(self, epoch):
|
340 |
+
"""Update learning rate.
|
341 |
+
|
342 |
+
Args:
|
343 |
+
current_iter (int): Current iteration.
|
344 |
+
warmup_iter (int): Warmup iter numbers. -1 for no warmup.
|
345 |
+
Default: -1.
|
346 |
+
"""
|
347 |
+
lr = self.optimizer.param_groups[0]['lr']
|
348 |
+
|
349 |
+
if self.opt['lr_decay'] == 'step':
|
350 |
+
lr = self.opt['lr'] * (
|
351 |
+
self.opt['gamma']**(epoch // self.opt['step']))
|
352 |
+
elif self.opt['lr_decay'] == 'cos':
|
353 |
+
lr = self.opt['lr'] * (
|
354 |
+
1 + math.cos(math.pi * epoch / self.opt['num_epochs'])) / 2
|
355 |
+
elif self.opt['lr_decay'] == 'linear':
|
356 |
+
lr = self.opt['lr'] * (1 - epoch / self.opt['num_epochs'])
|
357 |
+
elif self.opt['lr_decay'] == 'linear2exp':
|
358 |
+
if epoch < self.opt['turning_point'] + 1:
|
359 |
+
# decay the learning rate to 95%
|
360 |
+
# at the turning point (1 / 95% = 1.0526)
|
361 |
+
lr = self.opt['lr'] * (
|
362 |
+
1 - epoch / int(self.opt['turning_point'] * 1.0526))
|
363 |
+
else:
|
364 |
+
lr *= self.opt['gamma']
|
365 |
+
elif self.opt['lr_decay'] == 'schedule':
|
366 |
+
if epoch in self.opt['schedule']:
|
367 |
+
lr *= self.opt['gamma']
|
368 |
+
else:
|
369 |
+
raise ValueError('Unknown lr mode {}'.format(self.opt['lr_decay']))
|
370 |
+
# set learning rate
|
371 |
+
for param_group in self.optimizer.param_groups:
|
372 |
+
param_group['lr'] = lr
|
373 |
+
|
374 |
+
return lr
|
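hinge_d_loss used in training_step above is imported from models.losses.vqgan_loss (added later in this commit); the standard hinge formulation it corresponds to looks like the following sketch, shown here only for orientation and not copied from that file:

import torch
import torch.nn.functional as F

def hinge_d_loss_sketch(logits_real, logits_fake):
    # push real patches above +1 and generated patches below -1,
    # then average the two hinge terms
    loss_real = torch.mean(F.relu(1.0 - logits_real))
    loss_fake = torch.mean(F.relu(1.0 + logits_fake))
    return 0.5 * (loss_real + loss_fake)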
models/losses/__init__.py
ADDED
File without changes
|
models/losses/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (127 Bytes).
|
|
models/losses/__pycache__/accuracy.cpython-38.pyc
ADDED
Binary file (2.02 kB).
|
|
models/losses/__pycache__/cross_entropy_loss.cpython-38.pyc
ADDED
Binary file (6.76 kB).
|
|
models/losses/__pycache__/segmentation_loss.cpython-38.pyc
ADDED
Binary file (1.3 kB).
|
|
models/losses/__pycache__/vqgan_loss.cpython-38.pyc
ADDED
Binary file (3.65 kB).
|
|
models/losses/accuracy.py
ADDED
@@ -0,0 +1,46 @@
1 |
+
def accuracy(pred, target, topk=1, thresh=None):
|
2 |
+
"""Calculate accuracy according to the prediction and target.
|
3 |
+
|
4 |
+
Args:
|
5 |
+
pred (torch.Tensor): The model prediction, shape (N, num_class, ...)
|
6 |
+
target (torch.Tensor): The target of each prediction, shape (N, ...)
|
7 |
+
topk (int | tuple[int], optional): If the predictions in ``topk``
|
8 |
+
matches the target, the predictions will be regarded as
|
9 |
+
correct ones. Defaults to 1.
|
10 |
+
thresh (float, optional): If not None, predictions with scores under
|
11 |
+
this threshold are considered incorrect. Defaults to None.
|
12 |
+
|
13 |
+
Returns:
|
14 |
+
float | tuple[float]: If the input ``topk`` is a single integer,
|
15 |
+
the function will return a single float as accuracy. If
|
16 |
+
``topk`` is a tuple containing multiple integers, the
|
17 |
+
function will return a tuple containing accuracies of
|
18 |
+
each ``topk`` number.
|
19 |
+
"""
|
20 |
+
assert isinstance(topk, (int, tuple))
|
21 |
+
if isinstance(topk, int):
|
22 |
+
topk = (topk, )
|
23 |
+
return_single = True
|
24 |
+
else:
|
25 |
+
return_single = False
|
26 |
+
|
27 |
+
maxk = max(topk)
|
28 |
+
if pred.size(0) == 0:
|
29 |
+
accu = [pred.new_tensor(0.) for i in range(len(topk))]
|
30 |
+
return accu[0] if return_single else accu
|
31 |
+
assert pred.ndim == target.ndim + 1
|
32 |
+
assert pred.size(0) == target.size(0)
|
33 |
+
assert maxk <= pred.size(1), \
|
34 |
+
f'maxk {maxk} exceeds pred dimension {pred.size(1)}'
|
35 |
+
pred_value, pred_label = pred.topk(maxk, dim=1)
|
36 |
+
# transpose to shape (maxk, N, ...)
|
37 |
+
pred_label = pred_label.transpose(0, 1)
|
38 |
+
correct = pred_label.eq(target.unsqueeze(0).expand_as(pred_label))
|
39 |
+
if thresh is not None:
|
40 |
+
# Only prediction values larger than thresh are counted as correct
|
41 |
+
correct = correct & (pred_value > thresh).t()
|
42 |
+
res = []
|
43 |
+
for k in topk:
|
44 |
+
correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
|
45 |
+
res.append(correct_k.mul_(100.0 / target.numel()))
|
46 |
+
return res[0] if return_single else res
|
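A small usage example for accuracy() above, assuming the module is importable as models.losses.accuracy; the shapes and values are made up:

import torch
from models.losses.accuracy import accuracy

pred = torch.randn(4, 10, 32, 32)             # (N, num_class, H, W) logits
target = torch.randint(0, 10, (4, 32, 32))    # (N, H, W) ground-truth indices
top1 = accuracy(pred, target, topk=1)         # top-1 accuracy in percent
print(float(top1))                            # roughly 10.0 for random logits over 10 classes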
models/losses/cross_entropy_loss.py
ADDED
@@ -0,0 +1,246 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


def reduce_loss(loss, reduction):
    """Reduce loss as specified.

    Args:
        loss (Tensor): Elementwise loss tensor.
        reduction (str): Options are "none", "mean" and "sum".

    Return:
        Tensor: Reduced loss tensor.
    """
    reduction_enum = F._Reduction.get_enum(reduction)
    # none: 0, elementwise_mean:1, sum: 2
    if reduction_enum == 0:
        return loss
    elif reduction_enum == 1:
        return loss.mean()
    elif reduction_enum == 2:
        return loss.sum()


def weight_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None):
    """Apply element-wise weight and reduce loss.

    Args:
        loss (Tensor): Element-wise loss.
        weight (Tensor): Element-wise weights.
        reduction (str): Same as built-in losses of PyTorch.
        avg_factor (float): Average factor when computing the mean of losses.

    Returns:
        Tensor: Processed loss values.
    """
    # if weight is specified, apply element-wise weight
    if weight is not None:
        assert weight.dim() == loss.dim()
        if weight.dim() > 1:
            assert weight.size(1) == 1 or weight.size(1) == loss.size(1)
        loss = loss * weight

    # if avg_factor is not specified, just reduce the loss
    if avg_factor is None:
        loss = reduce_loss(loss, reduction)
    else:
        # if reduction is mean, then average the loss by avg_factor
        if reduction == 'mean':
            loss = loss.sum() / avg_factor
        # if reduction is 'none', then do nothing, otherwise raise an error
        elif reduction != 'none':
            raise ValueError('avg_factor can not be used with reduction="sum"')
    return loss


def cross_entropy(pred,
                  label,
                  weight=None,
                  class_weight=None,
                  reduction='mean',
                  avg_factor=None,
                  ignore_index=-100):
    """The wrapper function for :func:`F.cross_entropy`"""
    # class_weight is a manual rescaling weight given to each class.
    # If given, it has to be a Tensor of size C. Computes element-wise losses.
    loss = F.cross_entropy(
        pred,
        label,
        weight=class_weight,
        reduction='none',
        ignore_index=ignore_index)

    # apply weights and do the reduction
    if weight is not None:
        weight = weight.float()
    loss = weight_reduce_loss(
        loss, weight=weight, reduction=reduction, avg_factor=avg_factor)

    return loss


def _expand_onehot_labels(labels, label_weights, target_shape, ignore_index):
    """Expand onehot labels to match the size of prediction."""
    bin_labels = labels.new_zeros(target_shape)
    valid_mask = (labels >= 0) & (labels != ignore_index)
    inds = torch.nonzero(valid_mask, as_tuple=True)

    if inds[0].numel() > 0:
        if labels.dim() == 3:
            bin_labels[inds[0], labels[valid_mask], inds[1], inds[2]] = 1
        else:
            bin_labels[inds[0], labels[valid_mask]] = 1

    valid_mask = valid_mask.unsqueeze(1).expand(target_shape).float()
    if label_weights is None:
        bin_label_weights = valid_mask
    else:
        bin_label_weights = label_weights.unsqueeze(1).expand(target_shape)
        bin_label_weights *= valid_mask

    return bin_labels, bin_label_weights


def binary_cross_entropy(pred,
                         label,
                         weight=None,
                         reduction='mean',
                         avg_factor=None,
                         class_weight=None,
                         ignore_index=255):
    """Calculate the binary CrossEntropy loss.

    Args:
        pred (torch.Tensor): The prediction with shape (N, 1).
        label (torch.Tensor): The learning label of the prediction.
        weight (torch.Tensor, optional): Sample-wise loss weight.
        reduction (str, optional): The method used to reduce the loss.
            Options are "none", "mean" and "sum".
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.
        class_weight (list[float], optional): The weight for each class.
        ignore_index (int | None): The label index to be ignored. Default: 255

    Returns:
        torch.Tensor: The calculated loss
    """
    if pred.dim() != label.dim():
        assert (pred.dim() == 2 and label.dim() == 1) or (
            pred.dim() == 4 and label.dim() == 3), \
            'Only pred shape [N, C], label shape [N] or pred shape [N, C, ' \
            'H, W], label shape [N, H, W] are supported'
        label, weight = _expand_onehot_labels(label, weight, pred.shape,
                                              ignore_index)

    # weighted element-wise losses
    if weight is not None:
        weight = weight.float()
    loss = F.binary_cross_entropy_with_logits(
        pred, label.float(), pos_weight=class_weight, reduction='none')
    # do the reduction for the weighted loss
    loss = weight_reduce_loss(
        loss, weight, reduction=reduction, avg_factor=avg_factor)

    return loss


def mask_cross_entropy(pred,
                       target,
                       label,
                       reduction='mean',
                       avg_factor=None,
                       class_weight=None,
                       ignore_index=None):
    """Calculate the CrossEntropy loss for masks.

    Args:
        pred (torch.Tensor): The prediction with shape (N, C), C is the number
            of classes.
        target (torch.Tensor): The learning label of the prediction.
        label (torch.Tensor): ``label`` indicates the class label of the mask's
            corresponding object. It is used to select the mask of the class
            which the object belongs to when the mask prediction is not
            class-agnostic.
        reduction (str, optional): The method used to reduce the loss.
            Options are "none", "mean" and "sum".
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.
        class_weight (list[float], optional): The weight for each class.
        ignore_index (None): Placeholder, to be consistent with other loss.
            Default: None.

    Returns:
        torch.Tensor: The calculated loss
    """
    assert ignore_index is None, 'BCE loss does not support ignore_index'
    # TODO: handle these two reserved arguments
    assert reduction == 'mean' and avg_factor is None
    num_rois = pred.size()[0]
    inds = torch.arange(0, num_rois, dtype=torch.long, device=pred.device)
    pred_slice = pred[inds, label].squeeze(1)
    return F.binary_cross_entropy_with_logits(
        pred_slice, target, weight=class_weight, reduction='mean')[None]


class CrossEntropyLoss(nn.Module):
    """CrossEntropyLoss.

    Args:
        use_sigmoid (bool, optional): Whether the prediction uses sigmoid
            or softmax. Defaults to False.
        use_mask (bool, optional): Whether to use mask cross entropy loss.
            Defaults to False.
        reduction (str, optional): The method used to reduce the loss.
            Options are "none", "mean" and "sum". Defaults to 'mean'.
        class_weight (list[float], optional): Weight of each class.
            Defaults to None.
        loss_weight (float, optional): Weight of the loss. Defaults to 1.0.
    """

    def __init__(self,
                 use_sigmoid=False,
                 use_mask=False,
                 reduction='mean',
                 class_weight=None,
                 loss_weight=1.0):
        super(CrossEntropyLoss, self).__init__()
        assert (use_sigmoid is False) or (use_mask is False)
        self.use_sigmoid = use_sigmoid
        self.use_mask = use_mask
        self.reduction = reduction
        self.loss_weight = loss_weight
        self.class_weight = class_weight

        if self.use_sigmoid:
            self.cls_criterion = binary_cross_entropy
        elif self.use_mask:
            self.cls_criterion = mask_cross_entropy
        else:
            self.cls_criterion = cross_entropy

    def forward(self,
                cls_score,
                label,
                weight=None,
                avg_factor=None,
                reduction_override=None,
                **kwargs):
        """Forward function."""
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        if self.class_weight is not None:
            class_weight = cls_score.new_tensor(self.class_weight)
        else:
            class_weight = None
        loss_cls = self.loss_weight * self.cls_criterion(
            cls_score,
            label,
            weight,
            class_weight=class_weight,
            reduction=reduction,
            avg_factor=avg_factor,
            **kwargs)
        return loss_cls
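A short sketch of how the wrapper class above is typically driven for segmentation logits; the shapes and values are hypothetical, not taken from the commit.

```python
import torch

criterion = CrossEntropyLoss(use_sigmoid=False, loss_weight=1.0)

cls_score = torch.randn(2, 24, 8, 8, requires_grad=True)   # (N, C, H, W) logits
label = torch.randint(0, 24, (2, 8, 8))                     # (N, H, W) class indices

loss = criterion(cls_score, label)                          # mean-reduced scalar
loss_sum = criterion(cls_score, label, reduction_override='sum')
loss.backward()
```

With `use_sigmoid=True` the same interface dispatches to `binary_cross_entropy`, and `use_mask=True` dispatches to `mask_cross_entropy`.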
models/losses/segmentation_loss.py
ADDED
@@ -0,0 +1,25 @@
import torch.nn as nn
import torch.nn.functional as F


class BCELoss(nn.Module):

    def forward(self, prediction, target):
        loss = F.binary_cross_entropy_with_logits(prediction, target)
        return loss, {}


class BCELossWithQuant(nn.Module):

    def __init__(self, codebook_weight=1.):
        super().__init__()
        self.codebook_weight = codebook_weight

    def forward(self, qloss, target, prediction, split):
        bce_loss = F.binary_cross_entropy_with_logits(prediction, target)
        loss = bce_loss + self.codebook_weight * qloss
        return loss, {
            "{}/total_loss".format(split): loss.clone().detach().mean(),
            "{}/bce_loss".format(split): bce_loss.detach().mean(),
            "{}/quant_loss".format(split): qloss.detach().mean()
        }
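`BCELossWithQuant` adds the VQ commitment term, scaled by `codebook_weight`, to a BCE reconstruction term and returns a log dict keyed by the split name. A hypothetical call, not part of the commit:

```python
import torch

criterion = BCELossWithQuant(codebook_weight=1.0)

prediction = torch.randn(2, 24, 64, 32)                  # parsing logits
target = torch.randint(0, 2, (2, 24, 64, 32)).float()    # binary targets
qloss = torch.tensor(0.05)                               # quantization loss from the VQ layer

loss, log = criterion(qloss, target, prediction, split='train')
# log: {'train/total_loss': ..., 'train/bce_loss': ..., 'train/quant_loss': ...}
```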
models/losses/vqgan_loss.py
ADDED
@@ -0,0 +1,114 @@
import torch
import torch.nn.functional as F


def calculate_adaptive_weight(recon_loss, g_loss, last_layer, disc_weight_max):
    recon_grads = torch.autograd.grad(
        recon_loss, last_layer, retain_graph=True)[0]
    g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]

    d_weight = torch.norm(recon_grads) / (torch.norm(g_grads) + 1e-4)
    d_weight = torch.clamp(d_weight, 0.0, disc_weight_max).detach()
    return d_weight


def adopt_weight(weight, global_step, threshold=0, value=0.):
    if global_step < threshold:
        weight = value
    return weight


@torch.jit.script
def hinge_d_loss(logits_real, logits_fake):
    loss_real = torch.mean(F.relu(1. - logits_real))
    loss_fake = torch.mean(F.relu(1. + logits_fake))
    d_loss = 0.5 * (loss_real + loss_fake)
    return d_loss


def DiffAugment(x, policy='', channels_first=True):
    if policy:
        if not channels_first:
            x = x.permute(0, 3, 1, 2)
        for p in policy.split(','):
            for f in AUGMENT_FNS[p]:
                x = f(x)
        if not channels_first:
            x = x.permute(0, 2, 3, 1)
        x = x.contiguous()
    return x


def rand_brightness(x):
    x = x + (
        torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5)
    return x


def rand_saturation(x):
    x_mean = x.mean(dim=1, keepdim=True)
    x = (x - x_mean) * (torch.rand(
        x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) * 2) + x_mean
    return x


def rand_contrast(x):
    x_mean = x.mean(dim=[1, 2, 3], keepdim=True)
    x = (x - x_mean) * (torch.rand(
        x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5) + x_mean
    return x


def rand_translation(x, ratio=0.125):
    shift_x, shift_y = int(x.size(2) * ratio +
                           0.5), int(x.size(3) * ratio + 0.5)
    translation_x = torch.randint(
        -shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device)
    translation_y = torch.randint(
        -shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device)
    grid_batch, grid_x, grid_y = torch.meshgrid(
        torch.arange(x.size(0), dtype=torch.long, device=x.device),
        torch.arange(x.size(2), dtype=torch.long, device=x.device),
        torch.arange(x.size(3), dtype=torch.long, device=x.device),
    )
    grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1)
    grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1)
    x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0])
    x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x,
                                               grid_y].permute(0, 3, 1, 2)
    return x


def rand_cutout(x, ratio=0.5):
    cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
    offset_x = torch.randint(
        0,
        x.size(2) + (1 - cutout_size[0] % 2),
        size=[x.size(0), 1, 1],
        device=x.device)
    offset_y = torch.randint(
        0,
        x.size(3) + (1 - cutout_size[1] % 2),
        size=[x.size(0), 1, 1],
        device=x.device)
    grid_batch, grid_x, grid_y = torch.meshgrid(
        torch.arange(x.size(0), dtype=torch.long, device=x.device),
        torch.arange(cutout_size[0], dtype=torch.long, device=x.device),
        torch.arange(cutout_size[1], dtype=torch.long, device=x.device),
    )
    grid_x = torch.clamp(
        grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 1)
    grid_y = torch.clamp(
        grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1)
    mask = torch.ones(
        x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device)
    mask[grid_batch, grid_x, grid_y] = 0
    x = x * mask.unsqueeze(1)
    return x


AUGMENT_FNS = {
    'color': [rand_brightness, rand_saturation, rand_contrast],
    'translation': [rand_translation],
    'cutout': [rand_cutout],
}
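A sketch of how these GAN utilities fit together during discriminator training; the tensors and the 0.8/5000 values below are hypothetical:

```python
import torch

imgs = torch.rand(4, 3, 64, 32)  # NCHW images

# same differentiable augmentations for real and fake inputs to the discriminator
imgs_aug = DiffAugment(imgs, policy='color,translation,cutout')

# hinge loss on hypothetical discriminator logits
logits_real = torch.randn(4, 1)
logits_fake = torch.randn(4, 1)
d_loss = hinge_d_loss(logits_real, logits_fake)

# keep the adversarial weight at 0 until a warm-up threshold is reached
disc_weight = adopt_weight(weight=0.8, global_step=1000, threshold=5000)  # -> 0.0
```

`calculate_adaptive_weight` additionally balances the generator's adversarial gradient against the reconstruction gradient at the decoder's last layer, so it must be called with losses that are still attached to the autograd graph.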
models/parsing_gen_model.py
ADDED
@@ -0,0 +1,220 @@
import logging
import math
from collections import OrderedDict

import mmcv
import numpy as np
import torch
from torchvision.utils import save_image

from models.archs.fcn_arch import FCNHead
from models.archs.shape_attr_embedding_arch import ShapeAttrEmbedding
from models.archs.unet_arch import ShapeUNet
from models.losses.accuracy import accuracy
from models.losses.cross_entropy_loss import CrossEntropyLoss

logger = logging.getLogger('base')


class ParsingGenModel():
    """Parsing Generation model.
    """

    def __init__(self, opt):
        self.opt = opt
        self.device = torch.device('cuda')
        self.is_train = opt['is_train']

        self.attr_embedder = ShapeAttrEmbedding(
            dim=opt['embedder_dim'],
            out_dim=opt['embedder_out_dim'],
            cls_num_list=opt['attr_class_num']).to(self.device)
        self.parsing_encoder = ShapeUNet(
            in_channels=opt['encoder_in_channels']).to(self.device)
        self.parsing_decoder = FCNHead(
            in_channels=opt['fc_in_channels'],
            in_index=opt['fc_in_index'],
            channels=opt['fc_channels'],
            num_convs=opt['fc_num_convs'],
            concat_input=opt['fc_concat_input'],
            dropout_ratio=opt['fc_dropout_ratio'],
            num_classes=opt['fc_num_classes'],
            align_corners=opt['fc_align_corners'],
        ).to(self.device)

        self.init_training_settings()

        self.palette = [[0, 0, 0], [255, 250, 250], [220, 220, 220],
                        [250, 235, 215], [255, 250, 205], [211, 211, 211],
                        [70, 130, 180], [127, 255, 212], [0, 100, 0],
                        [50, 205, 50], [255, 255, 0], [245, 222, 179],
                        [255, 140, 0], [255, 0, 0], [16, 78, 139],
                        [144, 238, 144], [50, 205, 174], [50, 155, 250],
                        [160, 140, 88], [213, 140, 88], [90, 140, 90],
                        [185, 210, 205], [130, 165, 180], [225, 141, 151]]

    def init_training_settings(self):
        optim_params = []
        for v in self.attr_embedder.parameters():
            if v.requires_grad:
                optim_params.append(v)
        for v in self.parsing_encoder.parameters():
            if v.requires_grad:
                optim_params.append(v)
        for v in self.parsing_decoder.parameters():
            if v.requires_grad:
                optim_params.append(v)
        # set up optimizers
        self.optimizer = torch.optim.Adam(
            optim_params,
            self.opt['lr'],
            weight_decay=self.opt['weight_decay'])
        self.log_dict = OrderedDict()
        self.entropy_loss = CrossEntropyLoss().to(self.device)

    def feed_data(self, data):
        self.pose = data['densepose'].to(self.device)
        self.attr = data['attr'].to(self.device)
        self.segm = data['segm'].to(self.device)

    def optimize_parameters(self):
        self.attr_embedder.train()
        self.parsing_encoder.train()
        self.parsing_decoder.train()

        self.attr_embedding = self.attr_embedder(self.attr)
        self.pose_enc = self.parsing_encoder(self.pose, self.attr_embedding)
        self.seg_logits = self.parsing_decoder(self.pose_enc)

        loss = self.entropy_loss(self.seg_logits, self.segm)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        self.log_dict['loss_total'] = loss

    def get_vis(self, save_path):
        img_cat = torch.cat([
            self.pose,
            self.segm,
        ], dim=3).detach()
        img_cat = ((img_cat + 1) / 2)

        img_cat = img_cat.clamp_(0, 1)

        save_image(img_cat, save_path, nrow=1, padding=4)

    def inference(self, data_loader, save_dir):
        self.attr_embedder.eval()
        self.parsing_encoder.eval()
        self.parsing_decoder.eval()

        acc = 0
        num = 0

        for _, data in enumerate(data_loader):
            pose = data['densepose'].to(self.device)
            attr = data['attr'].to(self.device)
            segm = data['segm'].to(self.device)
            img_name = data['img_name']

            num += pose.size(0)
            with torch.no_grad():
                attr_embedding = self.attr_embedder(attr)
                pose_enc = self.parsing_encoder(pose, attr_embedding)
                seg_logits = self.parsing_decoder(pose_enc)
            seg_pred = seg_logits.argmax(dim=1)
            acc += accuracy(seg_logits, segm)
            palette_label = self.palette_result(segm.cpu().numpy())
            palette_pred = self.palette_result(seg_pred.cpu().numpy())
            pose_numpy = ((pose[0] + 1) / 2. * 255.).expand(
                3,
                pose[0].size(1),
                pose[0].size(2),
            ).cpu().numpy().clip(0, 255).astype(np.uint8).transpose(1, 2, 0)
            concat_result = np.concatenate(
                (pose_numpy, palette_pred, palette_label), axis=1)
            mmcv.imwrite(concat_result, f'{save_dir}/{img_name[0]}')

        self.attr_embedder.train()
        self.parsing_encoder.train()
        self.parsing_decoder.train()
        return (acc / num).item()

    def get_current_log(self):
        return self.log_dict

    def update_learning_rate(self, epoch):
        """Update learning rate.

        Args:
            epoch (int): Current epoch.
        """
        lr = self.optimizer.param_groups[0]['lr']

        if self.opt['lr_decay'] == 'step':
            lr = self.opt['lr'] * (
                self.opt['gamma']**(epoch // self.opt['step']))
        elif self.opt['lr_decay'] == 'cos':
            lr = self.opt['lr'] * (
                1 + math.cos(math.pi * epoch / self.opt['num_epochs'])) / 2
        elif self.opt['lr_decay'] == 'linear':
            lr = self.opt['lr'] * (1 - epoch / self.opt['num_epochs'])
        elif self.opt['lr_decay'] == 'linear2exp':
            if epoch < self.opt['turning_point'] + 1:
                # learning rate decay as 95%
                # at the turning point (1 / 95% = 1.0526)
                lr = self.opt['lr'] * (
                    1 - epoch / int(self.opt['turning_point'] * 1.0526))
            else:
                lr *= self.opt['gamma']
        elif self.opt['lr_decay'] == 'schedule':
            if epoch in self.opt['schedule']:
                lr *= self.opt['gamma']
        else:
            raise ValueError('Unknown lr mode {}'.format(self.opt['lr_decay']))
        # set learning rate
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr

        return lr

    def save_network(self, save_path):
        """Save networks.
        """

        save_dict = {}
        save_dict['embedder'] = self.attr_embedder.state_dict()
        save_dict['encoder'] = self.parsing_encoder.state_dict()
        save_dict['decoder'] = self.parsing_decoder.state_dict()

        torch.save(save_dict, save_path)

    def load_network(self):
        checkpoint = torch.load(self.opt['pretrained_parsing_gen'])

        self.attr_embedder.load_state_dict(checkpoint['embedder'], strict=True)
        self.attr_embedder.eval()

        self.parsing_encoder.load_state_dict(
            checkpoint['encoder'], strict=True)
        self.parsing_encoder.eval()

        self.parsing_decoder.load_state_dict(
            checkpoint['decoder'], strict=True)
        self.parsing_decoder.eval()

    def palette_result(self, result):
        seg = result[0]
        palette = np.array(self.palette)
        assert palette.shape[1] == 3
        assert len(palette.shape) == 2
        color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
        for label, color in enumerate(palette):
            color_seg[seg == label, :] = color
        # convert to BGR
        color_seg = color_seg[..., ::-1]
        return color_seg
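`update_learning_rate` switches on `opt['lr_decay']`; the 'cos' branch, for example, follows lr = base_lr * (1 + cos(pi * epoch / num_epochs)) / 2. A standalone sketch with hypothetical values, independent of the class:

```python
import math

base_lr, num_epochs = 1e-4, 100

def cosine_lr(epoch):
    # same formula as the 'cos' branch above
    return base_lr * (1 + math.cos(math.pi * epoch / num_epochs)) / 2

print(cosine_lr(0))    # 1e-4 at the start
print(cosine_lr(50))   # 5e-5 halfway through
print(cosine_lr(100))  # ~0 at the end
```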
models/sample_model.py
ADDED
@@ -0,0 +1,498 @@
import logging

import numpy as np
import torch
import torch.distributions as dists
import torch.nn.functional as F
from torchvision.utils import save_image

from models.archs.fcn_arch import FCNHead, MultiHeadFCNHead
from models.archs.shape_attr_embedding_arch import ShapeAttrEmbedding
from models.archs.transformer_arch import TransformerMultiHead
from models.archs.unet_arch import ShapeUNet, UNet
from models.archs.vqgan_arch import (Decoder, DecoderRes, Encoder,
                                     VectorQuantizer,
                                     VectorQuantizerSpatialTextureAware,
                                     VectorQuantizerTexture)

logger = logging.getLogger('base')


class BaseSampleModel():
    """Base Model"""

    def __init__(self, opt):
        self.opt = opt
        self.device = torch.device('cuda')

        # hierarchical VQVAE
        self.decoder = Decoder(
            in_channels=opt['top_in_channels'],
            resolution=opt['top_resolution'],
            z_channels=opt['top_z_channels'],
            ch=opt['top_ch'],
            out_ch=opt['top_out_ch'],
            num_res_blocks=opt['top_num_res_blocks'],
            attn_resolutions=opt['top_attn_resolutions'],
            ch_mult=opt['top_ch_mult'],
            dropout=opt['top_dropout'],
            resamp_with_conv=True,
            give_pre_end=False).to(self.device)
        self.top_quantize = VectorQuantizerTexture(
            1024, opt['embed_dim'], beta=0.25).to(self.device)
        self.top_post_quant_conv = torch.nn.Conv2d(opt['embed_dim'],
                                                   opt["top_z_channels"],
                                                   1).to(self.device)
        self.load_top_pretrain_models()

        self.bot_decoder_res = DecoderRes(
            in_channels=opt['bot_in_channels'],
            resolution=opt['bot_resolution'],
            z_channels=opt['bot_z_channels'],
            ch=opt['bot_ch'],
            num_res_blocks=opt['bot_num_res_blocks'],
            ch_mult=opt['bot_ch_mult'],
            dropout=opt['bot_dropout'],
            give_pre_end=False).to(self.device)
        self.bot_quantize = VectorQuantizerSpatialTextureAware(
            opt['bot_n_embed'],
            opt['embed_dim'],
            beta=0.25,
            spatial_size=opt['bot_codebook_spatial_size']).to(self.device)
        self.bot_post_quant_conv = torch.nn.Conv2d(opt['embed_dim'],
                                                   opt["bot_z_channels"],
                                                   1).to(self.device)
        self.load_bot_pretrain_network()

        # top -> bot prediction
        self.index_pred_guidance_encoder = UNet(
            in_channels=opt['index_pred_encoder_in_channels']).to(self.device)
        self.index_pred_decoder = MultiHeadFCNHead(
            in_channels=opt['index_pred_fc_in_channels'],
            in_index=opt['index_pred_fc_in_index'],
            channels=opt['index_pred_fc_channels'],
            num_convs=opt['index_pred_fc_num_convs'],
            concat_input=opt['index_pred_fc_concat_input'],
            dropout_ratio=opt['index_pred_fc_dropout_ratio'],
            num_classes=opt['index_pred_fc_num_classes'],
            align_corners=opt['index_pred_fc_align_corners'],
            num_head=18).to(self.device)
        self.load_index_pred_network()

        # VAE for segmentation mask
        self.segm_encoder = Encoder(
            ch=opt['segm_ch'],
            num_res_blocks=opt['segm_num_res_blocks'],
            attn_resolutions=opt['segm_attn_resolutions'],
            ch_mult=opt['segm_ch_mult'],
            in_channels=opt['segm_in_channels'],
            resolution=opt['segm_resolution'],
            z_channels=opt['segm_z_channels'],
            double_z=opt['segm_double_z'],
            dropout=opt['segm_dropout']).to(self.device)
        self.segm_quantizer = VectorQuantizer(
            opt['segm_n_embed'],
            opt['segm_embed_dim'],
            beta=0.25,
            sane_index_shape=True).to(self.device)
        self.segm_quant_conv = torch.nn.Conv2d(opt["segm_z_channels"],
                                               opt['segm_embed_dim'],
                                               1).to(self.device)
        self.load_pretrained_segm_token()

        # define sampler
        self.sampler_fn = TransformerMultiHead(
            codebook_size=opt['codebook_size'],
            segm_codebook_size=opt['segm_codebook_size'],
            texture_codebook_size=opt['texture_codebook_size'],
            bert_n_emb=opt['bert_n_emb'],
            bert_n_layers=opt['bert_n_layers'],
            bert_n_head=opt['bert_n_head'],
            block_size=opt['block_size'],
            latent_shape=opt['latent_shape'],
            embd_pdrop=opt['embd_pdrop'],
            resid_pdrop=opt['resid_pdrop'],
            attn_pdrop=opt['attn_pdrop'],
            num_head=opt['num_head']).to(self.device)
        self.load_sampler_pretrained_network()

        self.shape = tuple(opt['latent_shape'])

        self.mask_id = opt['codebook_size']
        self.sample_steps = opt['sample_steps']

    def load_top_pretrain_models(self):
        # load pretrained vqgan
        top_vae_checkpoint = torch.load(self.opt['top_vae_path'])

        self.decoder.load_state_dict(
            top_vae_checkpoint['decoder'], strict=True)
        self.top_quantize.load_state_dict(
            top_vae_checkpoint['quantize'], strict=True)
        self.top_post_quant_conv.load_state_dict(
            top_vae_checkpoint['post_quant_conv'], strict=True)

        self.decoder.eval()
        self.top_quantize.eval()
        self.top_post_quant_conv.eval()

    def load_bot_pretrain_network(self):
        checkpoint = torch.load(self.opt['bot_vae_path'])
        self.bot_decoder_res.load_state_dict(
            checkpoint['bot_decoder_res'], strict=True)
        self.decoder.load_state_dict(checkpoint['decoder'], strict=True)
        self.bot_quantize.load_state_dict(
            checkpoint['bot_quantize'], strict=True)
        self.bot_post_quant_conv.load_state_dict(
            checkpoint['bot_post_quant_conv'], strict=True)

        self.bot_decoder_res.eval()
        self.decoder.eval()
        self.bot_quantize.eval()
        self.bot_post_quant_conv.eval()

    def load_pretrained_segm_token(self):
        # load pretrained vqgan for segmentation mask
        segm_token_checkpoint = torch.load(self.opt['segm_token_path'])
        self.segm_encoder.load_state_dict(
            segm_token_checkpoint['encoder'], strict=True)
        self.segm_quantizer.load_state_dict(
            segm_token_checkpoint['quantize'], strict=True)
        self.segm_quant_conv.load_state_dict(
            segm_token_checkpoint['quant_conv'], strict=True)

        self.segm_encoder.eval()
        self.segm_quantizer.eval()
        self.segm_quant_conv.eval()

    def load_index_pred_network(self):
        checkpoint = torch.load(self.opt['pretrained_index_network'])
        self.index_pred_guidance_encoder.load_state_dict(
            checkpoint['guidance_encoder'], strict=True)
        self.index_pred_decoder.load_state_dict(
            checkpoint['index_decoder'], strict=True)

        self.index_pred_guidance_encoder.eval()
        self.index_pred_decoder.eval()

    def load_sampler_pretrained_network(self):
        checkpoint = torch.load(self.opt['pretrained_sampler'])
        self.sampler_fn.load_state_dict(checkpoint, strict=True)
        self.sampler_fn.eval()

    def bot_index_prediction(self, feature_top, texture_mask):
        self.index_pred_guidance_encoder.eval()
        self.index_pred_decoder.eval()

        texture_mask_flatten = F.interpolate(
            texture_mask, (32, 16), mode='nearest').view(-1).long()

        min_encodings_indices_list = [
            torch.full(
                texture_mask_flatten.size(),
                fill_value=-1,
                dtype=torch.long,
                device=texture_mask_flatten.device) for _ in range(18)
        ]
        with torch.no_grad():
            feature_enc = self.index_pred_guidance_encoder(feature_top)
            memory_logits_list = self.index_pred_decoder(feature_enc)
            for codebook_idx, memory_logits in enumerate(memory_logits_list):
                region_of_interest = texture_mask_flatten == codebook_idx
                if torch.sum(region_of_interest) > 0:
                    memory_indices_pred = memory_logits.argmax(dim=1).view(-1)
                    memory_indices_pred = memory_indices_pred
                    min_encodings_indices_list[codebook_idx][
                        region_of_interest] = memory_indices_pred[
                            region_of_interest]
        min_encodings_indices_return_list = [
            min_encodings_indices.view((1, 32, 16))
            for min_encodings_indices in min_encodings_indices_list
        ]

        return min_encodings_indices_return_list

    def sample_and_refine(self, save_dir=None, img_name=None):
        # sample 32x16 features indices
        sampled_top_indices_list = self.sample_fn(
            temp=1, sample_steps=self.sample_steps)

        for sample_idx in range(self.batch_size):
            sample_indices = [
                sampled_indices_cur[sample_idx:sample_idx + 1]
                for sampled_indices_cur in sampled_top_indices_list
            ]
            top_quant = self.top_quantize.get_codebook_entry(
                sample_indices, self.texture_mask[sample_idx:sample_idx + 1],
                (sample_indices[0].size(0), self.shape[0], self.shape[1],
                 self.opt["top_z_channels"]))

            top_quant = self.top_post_quant_conv(top_quant)

            bot_indices_list = self.bot_index_prediction(
                top_quant, self.texture_mask[sample_idx:sample_idx + 1])

            quant_bot = self.bot_quantize.get_codebook_entry(
                bot_indices_list, self.texture_mask[sample_idx:sample_idx + 1],
                (bot_indices_list[0].size(0), bot_indices_list[0].size(1),
                 bot_indices_list[0].size(2),
                 self.opt["bot_z_channels"]))  #.permute(0, 3, 1, 2)
            quant_bot = self.bot_post_quant_conv(quant_bot)
            bot_dec_res = self.bot_decoder_res(quant_bot)

            dec = self.decoder(top_quant, bot_h=bot_dec_res)

            dec = ((dec + 1) / 2)
            dec = dec.clamp_(0, 1)
            if save_dir is None and img_name is None:
                return dec
            else:
                save_image(
                    dec,
                    f'{save_dir}/{img_name[sample_idx]}',
                    nrow=1,
                    padding=4)

    def sample_fn(self, temp=1.0, sample_steps=None):
        self.sampler_fn.eval()

        x_t = torch.ones((self.batch_size, np.prod(self.shape)),
                         device=self.device).long() * self.mask_id
        unmasked = torch.zeros_like(x_t, device=self.device).bool()
        sample_steps = list(range(1, sample_steps + 1))

        texture_tokens = F.interpolate(
            self.texture_mask, (32, 16),
            mode='nearest').view(self.batch_size, -1).long()

        texture_mask_flatten = texture_tokens.view(-1)

        # min_encodings_indices_list would be used to visualize the image
        min_encodings_indices_list = [
            torch.full(
                texture_mask_flatten.size(),
                fill_value=-1,
                dtype=torch.long,
                device=texture_mask_flatten.device) for _ in range(18)
        ]

        for t in reversed(sample_steps):
            t = torch.full((self.batch_size, ),
                           t,
                           device=self.device,
                           dtype=torch.long)

            # where to unmask
            changes = torch.rand(
                x_t.shape, device=self.device) < 1 / t.float().unsqueeze(-1)
            # don't unmask somewhere already unmasked
            changes = torch.bitwise_xor(changes,
                                        torch.bitwise_and(changes, unmasked))
            # update mask with changes
            unmasked = torch.bitwise_or(unmasked, changes)

            x_0_logits_list = self.sampler_fn(
                x_t, self.segm_tokens, texture_tokens, t=t)

            changes_flatten = changes.view(-1)
            ori_shape = x_t.shape  # [b, h*w]
            x_t = x_t.view(-1)  # [b*h*w]
            for codebook_idx, x_0_logits in enumerate(x_0_logits_list):
                if torch.sum(texture_mask_flatten[changes_flatten] ==
                             codebook_idx) > 0:
                    # scale by temperature
                    x_0_logits = x_0_logits / temp
                    x_0_dist = dists.Categorical(logits=x_0_logits)
                    x_0_hat = x_0_dist.sample().long()
                    x_0_hat = x_0_hat.view(-1)

                    # only replace the changed indices with corresponding codebook_idx
                    changes_segm = torch.bitwise_and(
                        changes_flatten, texture_mask_flatten == codebook_idx)

                    # x_t would be the input to the transformer, so the index range should be continual one
                    x_t[changes_segm] = x_0_hat[
                        changes_segm] + 1024 * codebook_idx
                    min_encodings_indices_list[codebook_idx][
                        changes_segm] = x_0_hat[changes_segm]

            x_t = x_t.view(ori_shape)  # [b, h*w]

        min_encodings_indices_return_list = [
            min_encodings_indices.view(ori_shape)
            for min_encodings_indices in min_encodings_indices_list
        ]

        self.sampler_fn.train()

        return min_encodings_indices_return_list

    @torch.no_grad()
    def get_quantized_segm(self, segm):
        segm_one_hot = F.one_hot(
            segm.squeeze(1).long(),
            num_classes=self.opt['segm_num_segm_classes']).permute(
                0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
        encoded_segm_mask = self.segm_encoder(segm_one_hot)
        encoded_segm_mask = self.segm_quant_conv(encoded_segm_mask)
        _, _, [_, _, segm_tokens] = self.segm_quantizer(encoded_segm_mask)

        return segm_tokens


class SampleFromParsingModel(BaseSampleModel):
    """SampleFromParsing model.
    """

    def feed_data(self, data):
        self.segm = data['segm'].to(self.device)
        self.texture_mask = data['texture_mask'].to(self.device)
        self.batch_size = self.segm.size(0)

        self.segm_tokens = self.get_quantized_segm(self.segm)
        self.segm_tokens = self.segm_tokens.view(self.batch_size, -1)

    def inference(self, data_loader, save_dir):
        for _, data in enumerate(data_loader):
            img_name = data['img_name']
            self.feed_data(data)
            with torch.no_grad():
                self.sample_and_refine(save_dir, img_name)


class SampleFromPoseModel(BaseSampleModel):
    """SampleFromPose model.
    """

    def __init__(self, opt):
        super().__init__(opt)
        # pose-to-parsing
        self.shape_attr_embedder = ShapeAttrEmbedding(
            dim=opt['shape_embedder_dim'],
            out_dim=opt['shape_embedder_out_dim'],
            cls_num_list=opt['shape_attr_class_num']).to(self.device)
        self.shape_parsing_encoder = ShapeUNet(
            in_channels=opt['shape_encoder_in_channels']).to(self.device)
        self.shape_parsing_decoder = FCNHead(
            in_channels=opt['shape_fc_in_channels'],
            in_index=opt['shape_fc_in_index'],
            channels=opt['shape_fc_channels'],
            num_convs=opt['shape_fc_num_convs'],
            concat_input=opt['shape_fc_concat_input'],
            dropout_ratio=opt['shape_fc_dropout_ratio'],
            num_classes=opt['shape_fc_num_classes'],
            align_corners=opt['shape_fc_align_corners'],
        ).to(self.device)
        self.load_shape_generation_models()

        self.palette = [[0, 0, 0], [255, 250, 250], [220, 220, 220],
                        [250, 235, 215], [255, 250, 205], [211, 211, 211],
                        [70, 130, 180], [127, 255, 212], [0, 100, 0],
                        [50, 205, 50], [255, 255, 0], [245, 222, 179],
                        [255, 140, 0], [255, 0, 0], [16, 78, 139],
                        [144, 238, 144], [50, 205, 174], [50, 155, 250],
                        [160, 140, 88], [213, 140, 88], [90, 140, 90],
                        [185, 210, 205], [130, 165, 180], [225, 141, 151]]

    def load_shape_generation_models(self):
        checkpoint = torch.load(self.opt['pretrained_parsing_gen'])

        self.shape_attr_embedder.load_state_dict(
            checkpoint['embedder'], strict=True)
        self.shape_attr_embedder.eval()

        self.shape_parsing_encoder.load_state_dict(
            checkpoint['encoder'], strict=True)
        self.shape_parsing_encoder.eval()

        self.shape_parsing_decoder.load_state_dict(
            checkpoint['decoder'], strict=True)
        self.shape_parsing_decoder.eval()

    def feed_data(self, data):
        self.pose = data['densepose'].to(self.device)
        self.batch_size = self.pose.size(0)

        self.shape_attr = data['shape_attr'].to(self.device)
        self.upper_fused_attr = data['upper_fused_attr'].to(self.device)
        self.lower_fused_attr = data['lower_fused_attr'].to(self.device)
        self.outer_fused_attr = data['outer_fused_attr'].to(self.device)

    def inference(self, data_loader, save_dir):
        for _, data in enumerate(data_loader):
            img_name = data['img_name']
            self.feed_data(data)
            with torch.no_grad():
                self.generate_parsing_map()
                self.generate_quantized_segm()
                self.generate_texture_map()
                self.sample_and_refine(save_dir, img_name)

    def generate_parsing_map(self):
        with torch.no_grad():
            attr_embedding = self.shape_attr_embedder(self.shape_attr)
            pose_enc = self.shape_parsing_encoder(self.pose, attr_embedding)
            seg_logits = self.shape_parsing_decoder(pose_enc)
        self.segm = seg_logits.argmax(dim=1)
        self.segm = self.segm.unsqueeze(1)

    def generate_quantized_segm(self):
        self.segm_tokens = self.get_quantized_segm(self.segm)
        self.segm_tokens = self.segm_tokens.view(self.batch_size, -1)

    def generate_texture_map(self):
        upper_cls = [1., 4.]
        lower_cls = [3., 5., 21.]
        outer_cls = [2.]

        mask_batch = []
        for idx in range(self.batch_size):
            mask = torch.zeros_like(self.segm[idx])
            upper_fused_attr = self.upper_fused_attr[idx]
            lower_fused_attr = self.lower_fused_attr[idx]
            outer_fused_attr = self.outer_fused_attr[idx]
            if upper_fused_attr != 17:
                for cls in upper_cls:
                    mask[self.segm[idx] == cls] = upper_fused_attr + 1

            if lower_fused_attr != 17:
                for cls in lower_cls:
                    mask[self.segm[idx] == cls] = lower_fused_attr + 1

            if outer_fused_attr != 17:
                for cls in outer_cls:
                    mask[self.segm[idx] == cls] = outer_fused_attr + 1

            mask_batch.append(mask)
        self.texture_mask = torch.stack(mask_batch, dim=0).to(torch.float32)

    def feed_pose_data(self, pose_img):
        # for ui demo

        self.pose = pose_img.to(self.device)
        self.batch_size = self.pose.size(0)

    def feed_shape_attributes(self, shape_attr):
        # for ui demo

        self.shape_attr = shape_attr.to(self.device)

    def feed_texture_attributes(self, texture_attr):
        # for ui demo

        self.upper_fused_attr = texture_attr[0].unsqueeze(0).to(self.device)
        self.lower_fused_attr = texture_attr[1].unsqueeze(0).to(self.device)
        self.outer_fused_attr = texture_attr[2].unsqueeze(0).to(self.device)

    def palette_result(self, result):

        seg = result[0]
        palette = np.array(self.palette)
        assert palette.shape[1] == 3
        assert len(palette.shape) == 2
        color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
        for label, color in enumerate(palette):
            color_seg[seg == label, :] = color
        # convert to BGR
        # color_seg = color_seg[..., ::-1]
        return color_seg
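The loop in `sample_fn` above reveals masked tokens gradually: at step t every still-masked position is unmasked with probability 1/t, so by t = 1 everything has been revealed. A standalone sketch of just that schedule (hypothetical sizes, no transformer involved):

```python
import torch

num_tokens, sample_steps = 512, 16
unmasked = torch.zeros(num_tokens, dtype=torch.bool)

for t in reversed(range(1, sample_steps + 1)):
    # each still-masked token is revealed with probability 1/t
    changes = torch.rand(num_tokens) < 1.0 / t
    changes = changes & ~unmasked   # never flip already-revealed tokens
    unmasked = unmasked | changes

print(unmasked.all())  # tensor(True): at t == 1 the rest is revealed
```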
models/transformer_model.py
ADDED
@@ -0,0 +1,482 @@
1 |
+
import logging
|
2 |
+
import math
|
3 |
+
from collections import OrderedDict
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
import torch.distributions as dists
|
8 |
+
import torch.nn.functional as F
|
9 |
+
from torchvision.utils import save_image
|
10 |
+
|
11 |
+
from models.archs.transformer_arch import TransformerMultiHead
|
12 |
+
from models.archs.vqgan_arch import (Decoder, Encoder, VectorQuantizer,
|
13 |
+
VectorQuantizerTexture)
|
14 |
+
|
15 |
+
logger = logging.getLogger('base')
|
16 |
+
|
17 |
+
|
18 |
+
class TransformerTextureAwareModel():
|
19 |
+
"""Texture-Aware Diffusion based Transformer model.
|
20 |
+
"""
|
21 |
+
|
22 |
+
def __init__(self, opt):
|
23 |
+
self.opt = opt
|
24 |
+
self.device = torch.device('cuda')
|
25 |
+
self.is_train = opt['is_train']
|
26 |
+
|
27 |
+
# VQVAE for image
|
28 |
+
self.img_encoder = Encoder(
|
29 |
+
ch=opt['img_ch'],
|
30 |
+
num_res_blocks=opt['img_num_res_blocks'],
|
31 |
+
attn_resolutions=opt['img_attn_resolutions'],
|
32 |
+
ch_mult=opt['img_ch_mult'],
|
33 |
+
in_channels=opt['img_in_channels'],
|
34 |
+
resolution=opt['img_resolution'],
|
35 |
+
z_channels=opt['img_z_channels'],
|
36 |
+
double_z=opt['img_double_z'],
|
37 |
+
dropout=opt['img_dropout']).to(self.device)
|
38 |
+
self.img_decoder = Decoder(
|
39 |
+
in_channels=opt['img_in_channels'],
|
40 |
+
resolution=opt['img_resolution'],
|
41 |
+
z_channels=opt['img_z_channels'],
|
42 |
+
ch=opt['img_ch'],
|
43 |
+
out_ch=opt['img_out_ch'],
|
44 |
+
num_res_blocks=opt['img_num_res_blocks'],
|
45 |
+
attn_resolutions=opt['img_attn_resolutions'],
|
46 |
+
ch_mult=opt['img_ch_mult'],
|
47 |
+
dropout=opt['img_dropout'],
|
48 |
+
resamp_with_conv=True,
|
49 |
+
give_pre_end=False).to(self.device)
|
50 |
+
self.img_quantizer = VectorQuantizerTexture(
|
51 |
+
opt['img_n_embed'], opt['img_embed_dim'],
|
52 |
+
beta=0.25).to(self.device)
|
53 |
+
self.img_quant_conv = torch.nn.Conv2d(opt["img_z_channels"],
|
54 |
+
opt['img_embed_dim'],
|
55 |
+
1).to(self.device)
|
56 |
+
self.img_post_quant_conv = torch.nn.Conv2d(opt['img_embed_dim'],
|
57 |
+
opt["img_z_channels"],
|
58 |
+
1).to(self.device)
|
59 |
+
self.load_pretrained_image_vae()
|
60 |
+
|
61 |
+
# VAE for segmentation mask
|
62 |
+
self.segm_encoder = Encoder(
|
63 |
+
ch=opt['segm_ch'],
|
64 |
+
num_res_blocks=opt['segm_num_res_blocks'],
|
65 |
+
attn_resolutions=opt['segm_attn_resolutions'],
|
66 |
+
ch_mult=opt['segm_ch_mult'],
|
67 |
+
in_channels=opt['segm_in_channels'],
|
68 |
+
resolution=opt['segm_resolution'],
|
69 |
+
z_channels=opt['segm_z_channels'],
|
70 |
+
double_z=opt['segm_double_z'],
|
71 |
+
dropout=opt['segm_dropout']).to(self.device)
|
72 |
+
self.segm_quantizer = VectorQuantizer(
|
73 |
+
opt['segm_n_embed'],
|
74 |
+
opt['segm_embed_dim'],
|
75 |
+
beta=0.25,
|
76 |
+
sane_index_shape=True).to(self.device)
|
77 |
+
self.segm_quant_conv = torch.nn.Conv2d(opt["segm_z_channels"],
|
78 |
+
opt['segm_embed_dim'],
|
79 |
+
1).to(self.device)
|
80 |
+
self.load_pretrained_segm_vae()
|
81 |
+
|
82 |
+
# define sampler
|
83 |
+
self._denoise_fn = TransformerMultiHead(
|
84 |
+
codebook_size=opt['codebook_size'],
|
85 |
+
segm_codebook_size=opt['segm_codebook_size'],
|
86 |
+
texture_codebook_size=opt['texture_codebook_size'],
|
87 |
+
bert_n_emb=opt['bert_n_emb'],
|
88 |
+
bert_n_layers=opt['bert_n_layers'],
|
89 |
+
bert_n_head=opt['bert_n_head'],
|
90 |
+
block_size=opt['block_size'],
|
91 |
+
latent_shape=opt['latent_shape'],
|
92 |
+
embd_pdrop=opt['embd_pdrop'],
|
93 |
+
resid_pdrop=opt['resid_pdrop'],
|
94 |
+
attn_pdrop=opt['attn_pdrop'],
|
95 |
+
num_head=opt['num_head']).to(self.device)
|
96 |
+
|
97 |
+
self.num_classes = opt['codebook_size']
|
98 |
+
self.shape = tuple(opt['latent_shape'])
|
99 |
+
self.num_timesteps = 1000
|
100 |
+
|
101 |
+
self.mask_id = opt['codebook_size']
|
102 |
+
self.loss_type = opt['loss_type']
|
103 |
+
self.mask_schedule = opt['mask_schedule']
|
104 |
+
|
105 |
+
self.sample_steps = opt['sample_steps']
|
106 |
+
|
107 |
+
self.init_training_settings()
|
108 |
+
|
109 |
+
def load_pretrained_image_vae(self):
|
110 |
+
# load pretrained vqgan for segmentation mask
|
111 |
+
img_ae_checkpoint = torch.load(self.opt['img_ae_path'])
|
112 |
+
self.img_encoder.load_state_dict(
|
113 |
+
img_ae_checkpoint['encoder'], strict=True)
|
114 |
+
self.img_decoder.load_state_dict(
|
115 |
+
img_ae_checkpoint['decoder'], strict=True)
|
116 |
+
self.img_quantizer.load_state_dict(
|
117 |
+
img_ae_checkpoint['quantize'], strict=True)
|
118 |
+
        self.img_quant_conv.load_state_dict(
            img_ae_checkpoint['quant_conv'], strict=True)
        self.img_post_quant_conv.load_state_dict(
            img_ae_checkpoint['post_quant_conv'], strict=True)
        self.img_encoder.eval()
        self.img_decoder.eval()
        self.img_quantizer.eval()
        self.img_quant_conv.eval()
        self.img_post_quant_conv.eval()

    def load_pretrained_segm_vae(self):
        # load pretrained vqgan for segmentation mask
        segm_ae_checkpoint = torch.load(self.opt['segm_ae_path'])
        self.segm_encoder.load_state_dict(
            segm_ae_checkpoint['encoder'], strict=True)
        self.segm_quantizer.load_state_dict(
            segm_ae_checkpoint['quantize'], strict=True)
        self.segm_quant_conv.load_state_dict(
            segm_ae_checkpoint['quant_conv'], strict=True)
        self.segm_encoder.eval()
        self.segm_quantizer.eval()
        self.segm_quant_conv.eval()

    def init_training_settings(self):
        optim_params = []
        for v in self._denoise_fn.parameters():
            if v.requires_grad:
                optim_params.append(v)
        # set up optimizer
        self.optimizer = torch.optim.Adam(
            optim_params,
            self.opt['lr'],
            weight_decay=self.opt['weight_decay'])
        self.log_dict = OrderedDict()

    @torch.no_grad()
    def get_quantized_img(self, image, texture_mask):
        encoded_img = self.img_encoder(image)
        encoded_img = self.img_quant_conv(encoded_img)

        # img_tokens_input is the continual index for the input of transformer
        # img_tokens_gt_list is the index for 18 texture-aware codebooks respectively
        _, _, [_, img_tokens_input, img_tokens_gt_list
               ] = self.img_quantizer(encoded_img, texture_mask)

        # reshape the tokens
        b = image.size(0)
        img_tokens_input = img_tokens_input.view(b, -1)
        img_tokens_gt_return_list = [
            img_tokens_gt.view(b, -1) for img_tokens_gt in img_tokens_gt_list
        ]

        return img_tokens_input, img_tokens_gt_return_list

    @torch.no_grad()
    def decode(self, quant):
        quant = self.img_post_quant_conv(quant)
        dec = self.img_decoder(quant)
        return dec

    @torch.no_grad()
    def decode_image_indices(self, indices_list, texture_mask):
        quant = self.img_quantizer.get_codebook_entry(
            indices_list, texture_mask,
            (indices_list[0].size(0), self.shape[0], self.shape[1],
             self.opt["img_z_channels"]))
        dec = self.decode(quant)

        return dec

    def sample_time(self, b, device, method='uniform'):
        if method == 'importance':
            if not (self.Lt_count > 10).all():
                return self.sample_time(b, device, method='uniform')

            Lt_sqrt = torch.sqrt(self.Lt_history + 1e-10) + 0.0001
            Lt_sqrt[0] = Lt_sqrt[1]  # Overwrite decoder term with L1.
            pt_all = Lt_sqrt / Lt_sqrt.sum()

            t = torch.multinomial(pt_all, num_samples=b, replacement=True)

            pt = pt_all.gather(dim=0, index=t)

            return t, pt

        elif method == 'uniform':
            t = torch.randint(
                1, self.num_timesteps + 1, (b, ), device=device).long()
            pt = torch.ones_like(t).float() / self.num_timesteps
            return t, pt

        else:
            raise ValueError

    def q_sample(self, x_0, x_0_gt_list, t):
        # samples q(x_t | x_0)
        # randomly set token to mask with probability t/T
        # x_t, x_0_ignore = x_0.clone(), x_0.clone()
        x_t = x_0.clone()

        mask = torch.rand_like(x_t.float()) < (
            t.float().unsqueeze(-1) / self.num_timesteps)
        x_t[mask] = self.mask_id
        # x_0_ignore[torch.bitwise_not(mask)] = -1

        # for every gt token list, we also need to do the mask
        x_0_gt_ignore_list = []
        for x_0_gt in x_0_gt_list:
            x_0_gt_ignore = x_0_gt.clone()
            x_0_gt_ignore[torch.bitwise_not(mask)] = -1
            x_0_gt_ignore_list.append(x_0_gt_ignore)

        return x_t, x_0_gt_ignore_list, mask

    def _train_loss(self, x_0, x_0_gt_list):
        b, device = x_0.size(0), x_0.device

        # choose what time steps to compute loss at
        t, pt = self.sample_time(b, device, 'uniform')

        # make x noisy and denoise
        if self.mask_schedule == 'random':
            x_t, x_0_gt_ignore_list, mask = self.q_sample(
                x_0=x_0, x_0_gt_list=x_0_gt_list, t=t)
        else:
            raise NotImplementedError

        # sample p(x_0 | x_t)
        x_0_hat_logits_list = self._denoise_fn(
            x_t, self.segm_tokens, self.texture_tokens, t=t)

        # Always compute ELBO for comparison purposes
        cross_entropy_loss = 0
        for x_0_hat_logits, x_0_gt_ignore in zip(x_0_hat_logits_list,
                                                 x_0_gt_ignore_list):
            cross_entropy_loss += F.cross_entropy(
                x_0_hat_logits.permute(0, 2, 1),
                x_0_gt_ignore,
                ignore_index=-1,
                reduction='none').sum(1)
        vb_loss = cross_entropy_loss / t
        vb_loss = vb_loss / pt
        vb_loss = vb_loss / (math.log(2) * x_0.shape[1:].numel())
        if self.loss_type == 'elbo':
            loss = vb_loss
        elif self.loss_type == 'mlm':
            denom = mask.float().sum(1)
            denom[denom == 0] = 1  # prevent divide by 0 errors.
            loss = cross_entropy_loss / denom
        elif self.loss_type == 'reweighted_elbo':
            weight = (1 - (t / self.num_timesteps))
            loss = weight * cross_entropy_loss
            loss = loss / (math.log(2) * x_0.shape[1:].numel())
        else:
            raise ValueError

        return loss.mean(), vb_loss.mean()
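For reference (an illustrative aside, not part of the file above): the masking rule in q_sample is the entire forward process of this discrete diffusion. At timestep t out of T, each token is independently replaced by the mask id with probability t/T, and unmasked positions are set to -1 so the cross-entropy in _train_loss ignores them. A minimal standalone sketch of that rule, with made-up sizes, T and mask id:

    import torch

    num_timesteps = 4                     # hypothetical T
    mask_id = 9999                        # hypothetical mask token id
    x_0 = torch.randint(0, 10, (2, 6))    # [batch, tokens] of codebook indices
    t = torch.tensor([1, 4])              # one timestep per sample

    # mask each token independently with probability t / T
    mask = torch.rand_like(x_0.float()) < t.float().unsqueeze(-1) / num_timesteps
    x_t = x_0.clone()
    x_t[mask] = mask_id                   # corrupted input for the denoiser
    target = x_0.clone()
    target[~mask] = -1                    # unmasked positions are ignored by the loss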

    def feed_data(self, data):
        self.image = data['image'].to(self.device)
        self.segm = data['segm'].to(self.device)
        self.texture_mask = data['texture_mask'].to(self.device)
        self.input_indices, self.gt_indices_list = self.get_quantized_img(
            self.image, self.texture_mask)

        self.texture_tokens = F.interpolate(
            self.texture_mask, size=self.shape,
            mode='nearest').view(self.image.size(0), -1).long()

        self.segm_tokens = self.get_quantized_segm(self.segm)
        self.segm_tokens = self.segm_tokens.view(self.image.size(0), -1)

    def optimize_parameters(self):
        self._denoise_fn.train()

        loss, vb_loss = self._train_loss(self.input_indices,
                                         self.gt_indices_list)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        self.log_dict['loss'] = loss
        self.log_dict['vb_loss'] = vb_loss

        self._denoise_fn.eval()

    @torch.no_grad()
    def get_quantized_segm(self, segm):
        segm_one_hot = F.one_hot(
            segm.squeeze(1).long(),
            num_classes=self.opt['segm_num_segm_classes']).permute(
                0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
        encoded_segm_mask = self.segm_encoder(segm_one_hot)
        encoded_segm_mask = self.segm_quant_conv(encoded_segm_mask)
        _, _, [_, _, segm_tokens] = self.segm_quantizer(encoded_segm_mask)

        return segm_tokens

    def sample_fn(self, temp=1.0, sample_steps=None):
        self._denoise_fn.eval()

        b, device = self.image.size(0), 'cuda'
        x_t = torch.ones(
            (b, np.prod(self.shape)), device=device).long() * self.mask_id
        unmasked = torch.zeros_like(x_t, device=device).bool()
        sample_steps = list(range(1, sample_steps + 1))

        texture_mask_flatten = self.texture_tokens.view(-1)

        # min_encodings_indices_list would be used to visualize the image
        min_encodings_indices_list = [
            torch.full(
                texture_mask_flatten.size(),
                fill_value=-1,
                dtype=torch.long,
                device=texture_mask_flatten.device) for _ in range(18)
        ]

        for t in reversed(sample_steps):
            print(f'Sample timestep {t:4d}', end='\r')
            t = torch.full((b, ), t, device=device, dtype=torch.long)

            # where to unmask
            changes = torch.rand(
                x_t.shape, device=device) < 1 / t.float().unsqueeze(-1)
            # don't unmask somewhere already unmasked
            changes = torch.bitwise_xor(changes,
                                        torch.bitwise_and(changes, unmasked))
            # update mask with changes
            unmasked = torch.bitwise_or(unmasked, changes)

            x_0_logits_list = self._denoise_fn(
                x_t, self.segm_tokens, self.texture_tokens, t=t)

            changes_flatten = changes.view(-1)
            ori_shape = x_t.shape  # [b, h*w]
            x_t = x_t.view(-1)  # [b*h*w]
            for codebook_idx, x_0_logits in enumerate(x_0_logits_list):
                if torch.sum(texture_mask_flatten[changes_flatten] ==
                             codebook_idx) > 0:
                    # scale by temperature
                    x_0_logits = x_0_logits / temp
                    x_0_dist = dists.Categorical(logits=x_0_logits)
                    x_0_hat = x_0_dist.sample().long()
                    x_0_hat = x_0_hat.view(-1)

                    # only replace the changed indices with corresponding codebook_idx
                    changes_segm = torch.bitwise_and(
                        changes_flatten, texture_mask_flatten == codebook_idx)

                    # x_t would be the input to the transformer, so the index range should be continual one
                    x_t[changes_segm] = x_0_hat[
                        changes_segm] + 1024 * codebook_idx
                    min_encodings_indices_list[codebook_idx][
                        changes_segm] = x_0_hat[changes_segm]

            x_t = x_t.view(ori_shape)  # [b, h*w]

        min_encodings_indices_return_list = [
            min_encodings_indices.view(ori_shape)
            for min_encodings_indices in min_encodings_indices_list
        ]

        self._denoise_fn.train()

        return min_encodings_indices_return_list
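Also for reference (not part of the file): the sampler keeps two index spaces. Each of the 18 texture-aware codebooks has its own local range of 1024 entries, while the transformer input x_t uses one continual range obtained through the `+ 1024 * codebook_idx` offset above. A tiny sketch of that mapping, with a made-up index:

    codebook_size = 1024        # entries per texture-aware codebook, as used above
    num_codebooks = 18

    local_index = 37            # hypothetical entry inside codebook 5
    codebook_idx = 5
    continual_index = local_index + codebook_size * codebook_idx   # 5157, what x_t stores

    # invert the mapping when needed
    assert continual_index // codebook_size == codebook_idx
    assert continual_index % codebook_size == local_index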

    def get_vis(self, image, gt_indices, predicted_indices, texture_mask,
                save_path):
        # original image
        ori_img = self.decode_image_indices(gt_indices, texture_mask)
        # pred image
        pred_img = self.decode_image_indices(predicted_indices, texture_mask)
        img_cat = torch.cat([
            image,
            ori_img,
            pred_img,
        ], dim=3).detach()
        img_cat = ((img_cat + 1) / 2)
        img_cat = img_cat.clamp_(0, 1)
        save_image(img_cat, save_path, nrow=1, padding=4)

    def inference(self, data_loader, save_dir):
        self._denoise_fn.eval()

        for _, data in enumerate(data_loader):
            img_name = data['img_name']
            self.feed_data(data)
            b = self.image.size(0)
            with torch.no_grad():
                sampled_indices_list = self.sample_fn(
                    temp=1, sample_steps=self.sample_steps)
                for idx in range(b):
                    self.get_vis(self.image[idx:idx + 1], [
                        gt_indices[idx:idx + 1]
                        for gt_indices in self.gt_indices_list
                    ], [
                        sampled_indices[idx:idx + 1]
                        for sampled_indices in sampled_indices_list
                    ], self.texture_mask[idx:idx + 1],
                                 f'{save_dir}/{img_name[idx]}')

        self._denoise_fn.train()

    def get_current_log(self):
        return self.log_dict

    def update_learning_rate(self, epoch, iters=None):
        """Update learning rate.

        Args:
            current_iter (int): Current iteration.
            warmup_iter (int): Warmup iter numbers. -1 for no warmup.
                Default: -1.
        """
        lr = self.optimizer.param_groups[0]['lr']

        if self.opt['lr_decay'] == 'step':
            lr = self.opt['lr'] * (
                self.opt['gamma']**(epoch // self.opt['step']))
        elif self.opt['lr_decay'] == 'cos':
            lr = self.opt['lr'] * (
                1 + math.cos(math.pi * epoch / self.opt['num_epochs'])) / 2
        elif self.opt['lr_decay'] == 'linear':
            lr = self.opt['lr'] * (1 - epoch / self.opt['num_epochs'])
        elif self.opt['lr_decay'] == 'linear2exp':
            if epoch < self.opt['turning_point'] + 1:
                # learning rate decay as 95%
                # at the turning point (1 / 95% = 1.0526)
                lr = self.opt['lr'] * (
                    1 - epoch / int(self.opt['turning_point'] * 1.0526))
            else:
                lr *= self.opt['gamma']
        elif self.opt['lr_decay'] == 'schedule':
            if epoch in self.opt['schedule']:
                lr *= self.opt['gamma']
        elif self.opt['lr_decay'] == 'warm_up':
            if iters <= self.opt['warmup_iters']:
                lr = self.opt['lr'] * float(iters) / self.opt['warmup_iters']
            else:
                lr = self.opt['lr']
        else:
            raise ValueError('Unknown lr mode {}'.format(self.opt['lr_decay']))
        # set learning rate
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr

        return lr

    def save_network(self, net, save_path):
        """Save networks.

        Args:
            net (nn.Module): Network to be saved.
            net_label (str): Network label.
            current_iter (int): Current iter number.
        """
        state_dict = net.state_dict()
        torch.save(state_dict, save_path)

    def load_network(self):
        checkpoint = torch.load(self.opt['pretrained_sampler'])
        self._denoise_fn.load_state_dict(checkpoint, strict=True)
        self._denoise_fn.eval()
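A hedged sketch of how these entry points fit together, assuming `model` is an instance of the sampler class defined in this file and that `test_loader` and `batch` come from the project's data pipeline (both are placeholders here):

    # test time: restore the pretrained denoiser, then sample and save visualizations
    model.load_network()                       # reads opt['pretrained_sampler']
    model.inference(test_loader, 'results')    # feed_data -> sample_fn -> get_vis per image

    # one training iteration
    model.feed_data(batch)                     # quantize image/segm, build texture tokens
    model.optimize_parameters()                # _train_loss + Adam step
    print(model.get_current_log())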
models/vqgan_model.py
ADDED
@@ -0,0 +1,551 @@
import math
import sys
from collections import OrderedDict

sys.path.append('..')
import lpips
import torch
import torch.nn.functional as F
from torchvision.utils import save_image

from models.archs.vqgan_arch import (Decoder, Discriminator, Encoder,
                                     VectorQuantizer, VectorQuantizerTexture)
from models.losses.segmentation_loss import BCELossWithQuant
from models.losses.vqgan_loss import (DiffAugment, adopt_weight,
                                      calculate_adaptive_weight, hinge_d_loss)


class VQModel():

    def __init__(self, opt):
        super().__init__()
        self.opt = opt
        self.device = torch.device('cuda')
        self.encoder = Encoder(
            ch=opt['ch'],
            num_res_blocks=opt['num_res_blocks'],
            attn_resolutions=opt['attn_resolutions'],
            ch_mult=opt['ch_mult'],
            in_channels=opt['in_channels'],
            resolution=opt['resolution'],
            z_channels=opt['z_channels'],
            double_z=opt['double_z'],
            dropout=opt['dropout']).to(self.device)
        self.decoder = Decoder(
            in_channels=opt['in_channels'],
            resolution=opt['resolution'],
            z_channels=opt['z_channels'],
            ch=opt['ch'],
            out_ch=opt['out_ch'],
            num_res_blocks=opt['num_res_blocks'],
            attn_resolutions=opt['attn_resolutions'],
            ch_mult=opt['ch_mult'],
            dropout=opt['dropout'],
            resamp_with_conv=True,
            give_pre_end=False).to(self.device)
        self.quantize = VectorQuantizer(
            opt['n_embed'], opt['embed_dim'], beta=0.25).to(self.device)
        self.quant_conv = torch.nn.Conv2d(opt["z_channels"], opt['embed_dim'],
                                          1).to(self.device)
        self.post_quant_conv = torch.nn.Conv2d(opt['embed_dim'],
                                               opt["z_channels"],
                                               1).to(self.device)

    def init_training_settings(self):
        self.loss = BCELossWithQuant()
        self.log_dict = OrderedDict()
        self.configure_optimizers()

    def save_network(self, save_path):
        """Save networks.

        Args:
            net (nn.Module): Network to be saved.
            net_label (str): Network label.
            current_iter (int): Current iter number.
        """

        save_dict = {}
        save_dict['encoder'] = self.encoder.state_dict()
        save_dict['decoder'] = self.decoder.state_dict()
        save_dict['quantize'] = self.quantize.state_dict()
        save_dict['quant_conv'] = self.quant_conv.state_dict()
        save_dict['post_quant_conv'] = self.post_quant_conv.state_dict()
        save_dict['discriminator'] = self.disc.state_dict()
        torch.save(save_dict, save_path)

    def load_network(self):
        checkpoint = torch.load(self.opt['pretrained_models'])
        self.encoder.load_state_dict(checkpoint['encoder'], strict=True)
        self.decoder.load_state_dict(checkpoint['decoder'], strict=True)
        self.quantize.load_state_dict(checkpoint['quantize'], strict=True)
        self.quant_conv.load_state_dict(checkpoint['quant_conv'], strict=True)
        self.post_quant_conv.load_state_dict(
            checkpoint['post_quant_conv'], strict=True)

    def optimize_parameters(self, data, current_iter):
        self.encoder.train()
        self.decoder.train()
        self.quantize.train()
        self.quant_conv.train()
        self.post_quant_conv.train()

        loss = self.training_step(data)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def encode(self, x):
        h = self.encoder(x)
        h = self.quant_conv(h)
        quant, emb_loss, info = self.quantize(h)
        return quant, emb_loss, info

    def decode(self, quant):
        quant = self.post_quant_conv(quant)
        dec = self.decoder(quant)
        return dec

    def decode_code(self, code_b):
        quant_b = self.quantize.embed_code(code_b)
        dec = self.decode(quant_b)
        return dec

    def forward_step(self, input):
        quant, diff, _ = self.encode(input)
        dec = self.decode(quant)
        return dec, diff

    def feed_data(self, data):
        x = data['segm']
        x = F.one_hot(x, num_classes=self.opt['num_segm_classes'])

        if len(x.shape) == 3:
            x = x[..., None]
        x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format)
        return x.float().to(self.device)

    def get_current_log(self):
        return self.log_dict

    def update_learning_rate(self, epoch):
        """Update learning rate.

        Args:
            current_iter (int): Current iteration.
            warmup_iter (int): Warmup iter numbers. -1 for no warmup.
                Default: -1.
        """
        lr = self.optimizer.param_groups[0]['lr']

        if self.opt['lr_decay'] == 'step':
            lr = self.opt['lr'] * (
                self.opt['gamma']**(epoch // self.opt['step']))
        elif self.opt['lr_decay'] == 'cos':
            lr = self.opt['lr'] * (
                1 + math.cos(math.pi * epoch / self.opt['num_epochs'])) / 2
        elif self.opt['lr_decay'] == 'linear':
            lr = self.opt['lr'] * (1 - epoch / self.opt['num_epochs'])
        elif self.opt['lr_decay'] == 'linear2exp':
            if epoch < self.opt['turning_point'] + 1:
                # learning rate decay as 95%
                # at the turning point (1 / 95% = 1.0526)
                lr = self.opt['lr'] * (
                    1 - epoch / int(self.opt['turning_point'] * 1.0526))
            else:
                lr *= self.opt['gamma']
        elif self.opt['lr_decay'] == 'schedule':
            if epoch in self.opt['schedule']:
                lr *= self.opt['gamma']
        else:
            raise ValueError('Unknown lr mode {}'.format(self.opt['lr_decay']))
        # set learning rate
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr

        return lr
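For orientation (an illustrative sketch, not part of the file): VQModel is driven entirely by a flat opt dict whose keys match the constructor above. The values below are placeholders, and whether the spatial shapes line up depends on the Encoder/Decoder definitions in models/archs/vqgan_arch.py; forward_step is simply encode -> quantize -> decode.

    opt = {
        # placeholder values only; the real settings live in the config files shipped with this commit
        'ch': 128, 'num_res_blocks': 2, 'attn_resolutions': [16], 'ch_mult': [1, 2, 4],
        'in_channels': 3, 'out_ch': 3, 'resolution': 256, 'z_channels': 256,
        'double_z': False, 'dropout': 0.0, 'n_embed': 1024, 'embed_dim': 256,
    }

    model = VQModel(opt)                               # requires a CUDA device
    x = torch.randn(1, 3, opt['resolution'], opt['resolution']).to(model.device)
    xrec, codebook_loss = model.forward_step(x)        # reconstruction + quantization loss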


class VQSegmentationModel(VQModel):

    def __init__(self, opt):
        super().__init__(opt)
        self.colorize = torch.randn(3, opt['num_segm_classes'], 1,
                                    1).to(self.device)

        self.init_training_settings()

    def configure_optimizers(self):
        self.optimizer = torch.optim.Adam(
            list(self.encoder.parameters()) + list(self.decoder.parameters()) +
            list(self.quantize.parameters()) +
            list(self.quant_conv.parameters()) +
            list(self.post_quant_conv.parameters()),
            lr=self.opt['lr'],
            betas=(0.5, 0.9))

    def training_step(self, data):
        x = self.feed_data(data)
        xrec, qloss = self.forward_step(x)
        aeloss, log_dict_ae = self.loss(qloss, x, xrec, split="train")
        self.log_dict.update(log_dict_ae)
        return aeloss

    def to_rgb(self, x):
        x = F.conv2d(x, weight=self.colorize)
        x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
        return x

    @torch.no_grad()
    def inference(self, data_loader, save_dir):
        self.encoder.eval()
        self.decoder.eval()
        self.quantize.eval()
        self.quant_conv.eval()
        self.post_quant_conv.eval()

        loss_total = 0
        loss_bce = 0
        loss_quant = 0
        num = 0

        for _, data in enumerate(data_loader):
            img_name = data['img_name'][0]
            x = self.feed_data(data)
            xrec, qloss = self.forward_step(x)
            _, log_dict_ae = self.loss(qloss, x, xrec, split="val")

            loss_total += log_dict_ae['val/total_loss']
            loss_bce += log_dict_ae['val/bce_loss']
            loss_quant += log_dict_ae['val/quant_loss']

            num += x.size(0)

            if x.shape[1] > 3:
                # colorize with random projection
                assert xrec.shape[1] > 3
                # convert logits to indices
                xrec = torch.argmax(xrec, dim=1, keepdim=True)
                xrec = F.one_hot(xrec, num_classes=x.shape[1])
                xrec = xrec.squeeze(1).permute(0, 3, 1, 2).float()
                x = self.to_rgb(x)
                xrec = self.to_rgb(xrec)

            img_cat = torch.cat([x, xrec], dim=3).detach()
            img_cat = ((img_cat + 1) / 2)
            img_cat = img_cat.clamp_(0, 1)
            save_image(
                img_cat, f'{save_dir}/{img_name}.png', nrow=1, padding=4)

        return (loss_total / num).item(), (loss_bce /
                                           num).item(), (loss_quant /
                                                         num).item()


class VQImageModel(VQModel):

    def __init__(self, opt):
        super().__init__(opt)
        self.disc = Discriminator(
            opt['n_channels'], opt['ndf'],
            n_layers=opt['disc_layers']).to(self.device)
        self.perceptual = lpips.LPIPS(net="vgg").to(self.device)
        self.perceptual_weight = opt['perceptual_weight']
        self.disc_start_step = opt['disc_start_step']
        self.disc_weight_max = opt['disc_weight_max']
        self.diff_aug = opt['diff_aug']
        self.policy = "color,translation"

        self.disc.train()

        self.init_training_settings()

    def feed_data(self, data):
        x = data['image']

        return x.float().to(self.device)

    def init_training_settings(self):
        self.log_dict = OrderedDict()
        self.configure_optimizers()

    def configure_optimizers(self):
        self.optimizer = torch.optim.Adam(
            list(self.encoder.parameters()) + list(self.decoder.parameters()) +
            list(self.quantize.parameters()) +
            list(self.quant_conv.parameters()) +
            list(self.post_quant_conv.parameters()),
            lr=self.opt['lr'])

        self.disc_optimizer = torch.optim.Adam(
            self.disc.parameters(), lr=self.opt['lr'])

    def training_step(self, data, step):
        x = self.feed_data(data)
        xrec, codebook_loss = self.forward_step(x)

        # get recon/perceptual loss
        recon_loss = torch.abs(x.contiguous() - xrec.contiguous())
        p_loss = self.perceptual(x.contiguous(), xrec.contiguous())
        nll_loss = recon_loss + self.perceptual_weight * p_loss
        nll_loss = torch.mean(nll_loss)

        # augment for input to discriminator
        if self.diff_aug:
            xrec = DiffAugment(xrec, policy=self.policy)

        # update generator
        logits_fake = self.disc(xrec)
        g_loss = -torch.mean(logits_fake)
        last_layer = self.decoder.conv_out.weight
        d_weight = calculate_adaptive_weight(nll_loss, g_loss, last_layer,
                                             self.disc_weight_max)
        d_weight *= adopt_weight(1, step, self.disc_start_step)
        loss = nll_loss + d_weight * g_loss + codebook_loss

        self.log_dict["loss"] = loss
        self.log_dict["l1"] = recon_loss.mean().item()
        self.log_dict["perceptual"] = p_loss.mean().item()
        self.log_dict["nll_loss"] = nll_loss.item()
        self.log_dict["g_loss"] = g_loss.item()
        self.log_dict["d_weight"] = d_weight
        self.log_dict["codebook_loss"] = codebook_loss.item()

        if step > self.disc_start_step:
            if self.diff_aug:
                logits_real = self.disc(
                    DiffAugment(x.contiguous().detach(), policy=self.policy))
            else:
                logits_real = self.disc(x.contiguous().detach())
            logits_fake = self.disc(xrec.contiguous().detach(
            ))  # detach so that generator isn't also updated
            d_loss = hinge_d_loss(logits_real, logits_fake)
            self.log_dict["d_loss"] = d_loss
        else:
            d_loss = None

        return loss, d_loss

    def optimize_parameters(self, data, step):
        self.encoder.train()
        self.decoder.train()
        self.quantize.train()
        self.quant_conv.train()
        self.post_quant_conv.train()

        loss, d_loss = self.training_step(data, step)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        if step > self.disc_start_step:
            self.disc_optimizer.zero_grad()
            d_loss.backward()
            self.disc_optimizer.step()

    @torch.no_grad()
    def inference(self, data_loader, save_dir):
        self.encoder.eval()
        self.decoder.eval()
        self.quantize.eval()
        self.quant_conv.eval()
        self.post_quant_conv.eval()

        loss_total = 0
        num = 0

        for _, data in enumerate(data_loader):
            img_name = data['img_name'][0]
            x = self.feed_data(data)
            xrec, _ = self.forward_step(x)

            recon_loss = torch.abs(x.contiguous() - xrec.contiguous())
            p_loss = self.perceptual(x.contiguous(), xrec.contiguous())
            nll_loss = recon_loss + self.perceptual_weight * p_loss
            nll_loss = torch.mean(nll_loss)
            loss_total += nll_loss

            num += x.size(0)

            if x.shape[1] > 3:
                # colorize with random projection
                assert xrec.shape[1] > 3
                # convert logits to indices
                xrec = torch.argmax(xrec, dim=1, keepdim=True)
                xrec = F.one_hot(xrec, num_classes=x.shape[1])
                xrec = xrec.squeeze(1).permute(0, 3, 1, 2).float()
                x = self.to_rgb(x)
                xrec = self.to_rgb(xrec)

            img_cat = torch.cat([x, xrec], dim=3).detach()
            img_cat = ((img_cat + 1) / 2)
            img_cat = img_cat.clamp_(0, 1)
            save_image(
                img_cat, f'{save_dir}/{img_name}.png', nrow=1, padding=4)

        return (loss_total / num).item()
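A hedged sketch of the outer training loop this class expects (num_epochs, the loaders and the paths are placeholders; only the method names come from the class above). The discriminator only kicks in once step exceeds disc_start_step; before that, training_step returns d_loss=None and adopt_weight zeroes the adversarial term.

    step = 0
    for epoch in range(num_epochs):                    # num_epochs: placeholder
        model.update_learning_rate(epoch)
        for batch in train_loader:                     # train_loader: placeholder
            model.optimize_parameters(batch, step)     # AE step; disc step once step > disc_start_step
            step += 1
        val_nll = model.inference(val_loader, f'vis/epoch_{epoch}')   # mean recon + perceptual loss
        model.save_network(f'checkpoints/vqgan_epoch_{epoch}.pth')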


class VQImageSegmTextureModel(VQImageModel):

    def __init__(self, opt):
        self.opt = opt
        self.device = torch.device('cuda')
        self.encoder = Encoder(
            ch=opt['ch'],
            num_res_blocks=opt['num_res_blocks'],
            attn_resolutions=opt['attn_resolutions'],
            ch_mult=opt['ch_mult'],
            in_channels=opt['in_channels'],
            resolution=opt['resolution'],
            z_channels=opt['z_channels'],
            double_z=opt['double_z'],
            dropout=opt['dropout']).to(self.device)
        self.decoder = Decoder(
            in_channels=opt['in_channels'],
            resolution=opt['resolution'],
            z_channels=opt['z_channels'],
            ch=opt['ch'],
            out_ch=opt['out_ch'],
            num_res_blocks=opt['num_res_blocks'],
            attn_resolutions=opt['attn_resolutions'],
            ch_mult=opt['ch_mult'],
            dropout=opt['dropout'],
            resamp_with_conv=True,
            give_pre_end=False).to(self.device)
        self.quantize = VectorQuantizerTexture(
            opt['n_embed'], opt['embed_dim'], beta=0.25).to(self.device)
        self.quant_conv = torch.nn.Conv2d(opt["z_channels"], opt['embed_dim'],
                                          1).to(self.device)
        self.post_quant_conv = torch.nn.Conv2d(opt['embed_dim'],
                                               opt["z_channels"],
                                               1).to(self.device)

        self.disc = Discriminator(
            opt['n_channels'], opt['ndf'],
            n_layers=opt['disc_layers']).to(self.device)
        self.perceptual = lpips.LPIPS(net="vgg").to(self.device)
        self.perceptual_weight = opt['perceptual_weight']
        self.disc_start_step = opt['disc_start_step']
        self.disc_weight_max = opt['disc_weight_max']
        self.diff_aug = opt['diff_aug']
        self.policy = "color,translation"

        self.disc.train()

        self.init_training_settings()

    def feed_data(self, data):
        x = data['image'].float().to(self.device)
        mask = data['texture_mask'].float().to(self.device)

        return x, mask

    def training_step(self, data, step):
        x, mask = self.feed_data(data)
        xrec, codebook_loss = self.forward_step(x, mask)

        # get recon/perceptual loss
        recon_loss = torch.abs(x.contiguous() - xrec.contiguous())
        p_loss = self.perceptual(x.contiguous(), xrec.contiguous())
        nll_loss = recon_loss + self.perceptual_weight * p_loss
        nll_loss = torch.mean(nll_loss)

        # augment for input to discriminator
        if self.diff_aug:
            xrec = DiffAugment(xrec, policy=self.policy)

        # update generator
        logits_fake = self.disc(xrec)
        g_loss = -torch.mean(logits_fake)
        last_layer = self.decoder.conv_out.weight
        d_weight = calculate_adaptive_weight(nll_loss, g_loss, last_layer,
                                             self.disc_weight_max)
        d_weight *= adopt_weight(1, step, self.disc_start_step)
        loss = nll_loss + d_weight * g_loss + codebook_loss

        self.log_dict["loss"] = loss
        self.log_dict["l1"] = recon_loss.mean().item()
        self.log_dict["perceptual"] = p_loss.mean().item()
        self.log_dict["nll_loss"] = nll_loss.item()
        self.log_dict["g_loss"] = g_loss.item()
        self.log_dict["d_weight"] = d_weight
        self.log_dict["codebook_loss"] = codebook_loss.item()

        if step > self.disc_start_step:
            if self.diff_aug:
                logits_real = self.disc(
                    DiffAugment(x.contiguous().detach(), policy=self.policy))
            else:
                logits_real = self.disc(x.contiguous().detach())
            logits_fake = self.disc(xrec.contiguous().detach(
            ))  # detach so that generator isn't also updated
            d_loss = hinge_d_loss(logits_real, logits_fake)
            self.log_dict["d_loss"] = d_loss
        else:
            d_loss = None

        return loss, d_loss

    @torch.no_grad()
    def inference(self, data_loader, save_dir):
        self.encoder.eval()
        self.decoder.eval()
        self.quantize.eval()
        self.quant_conv.eval()
        self.post_quant_conv.eval()

        loss_total = 0
        num = 0

        for _, data in enumerate(data_loader):
            img_name = data['img_name'][0]
            x, mask = self.feed_data(data)
            xrec, _ = self.forward_step(x, mask)

            recon_loss = torch.abs(x.contiguous() - xrec.contiguous())
            p_loss = self.perceptual(x.contiguous(), xrec.contiguous())
            nll_loss = recon_loss + self.perceptual_weight * p_loss
            nll_loss = torch.mean(nll_loss)
            loss_total += nll_loss

            num += x.size(0)

            if x.shape[1] > 3:
                # colorize with random projection
                assert xrec.shape[1] > 3
                # convert logits to indices
                xrec = torch.argmax(xrec, dim=1, keepdim=True)
                xrec = F.one_hot(xrec, num_classes=x.shape[1])
                xrec = xrec.squeeze(1).permute(0, 3, 1, 2).float()
                x = self.to_rgb(x)
                xrec = self.to_rgb(xrec)

            img_cat = torch.cat([x, xrec], dim=3).detach()
            img_cat = ((img_cat + 1) / 2)
            img_cat = img_cat.clamp_(0, 1)
            save_image(
                img_cat, f'{save_dir}/{img_name}.png', nrow=1, padding=4)

        return (loss_total / num).item()

    def encode(self, x, mask):
        h = self.encoder(x)
        h = self.quant_conv(h)
        quant, emb_loss, info = self.quantize(h, mask)
        return quant, emb_loss, info

    def decode(self, quant):
        quant = self.post_quant_conv(quant)
        dec = self.decoder(quant)
        return dec

    def decode_code(self, code_b):
        quant_b = self.quantize.embed_code(code_b)
        dec = self.decode(quant_b)
        return dec

    def forward_step(self, input, mask):
        quant, diff, _ = self.encode(input, mask)
        dec = self.decode(quant)
        return dec, diff
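Finally, a hedged sketch of what changes for the texture-aware variant: feed_data now returns the image together with data['texture_mask'], and that mask is threaded through VectorQuantizerTexture on every encode. The tensor shapes below are placeholders, not values taken from the project.

    batch = {
        'image': torch.randn(1, 3, 512, 512),           # placeholder shapes
        'texture_mask': torch.zeros(1, 1, 512, 512),    # per-pixel texture group id
    }
    x, mask = model.feed_data(batch)                    # both moved to model.device
    xrec, codebook_loss = model.forward_step(x, mask)   # texture-conditioned quantization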