yumingj commited on
Commit
d34627b
1 Parent(s): 097b7e6

init commit

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. README.md +5 -5
  2. Text2Human/.gitignore +9 -0
  3. Text2Human/LICENSE +21 -0
  4. Text2Human/README.md +255 -0
  5. Text2Human/configs/index_pred_net.yml +84 -0
  6. Text2Human/configs/parsing_gen.yml +40 -0
  7. Text2Human/configs/parsing_token.yml +47 -0
  8. Text2Human/configs/sample_from_parsing.yml +93 -0
  9. Text2Human/configs/sample_from_pose.yml +107 -0
  10. Text2Human/configs/sampler.yml +83 -0
  11. Text2Human/configs/vqvae_bottom.yml +72 -0
  12. Text2Human/configs/vqvae_top.yml +53 -0
  13. Text2Human/data/__init__.py +0 -0
  14. Text2Human/data/mask_dataset.py +59 -0
  15. Text2Human/data/parsing_generation_segm_attr_dataset.py +80 -0
  16. Text2Human/data/pose_attr_dataset.py +109 -0
  17. Text2Human/data/segm_attr_dataset.py +167 -0
  18. Text2Human/environment/text2human_env.yaml +114 -0
  19. Text2Human/models/__init__.py +42 -0
  20. Text2Human/models/archs/__init__.py +0 -0
  21. Text2Human/models/archs/fcn_arch.py +418 -0
  22. Text2Human/models/archs/shape_attr_embedding_arch.py +35 -0
  23. Text2Human/models/archs/transformer_arch.py +273 -0
  24. Text2Human/models/archs/unet_arch.py +693 -0
  25. Text2Human/models/archs/vqgan_arch.py +1203 -0
  26. Text2Human/models/hierarchy_inference_model.py +363 -0
  27. Text2Human/models/hierarchy_vqgan_model.py +374 -0
  28. Text2Human/models/losses/__init__.py +0 -0
  29. Text2Human/models/losses/accuracy.py +46 -0
  30. Text2Human/models/losses/cross_entropy_loss.py +246 -0
  31. Text2Human/models/losses/segmentation_loss.py +25 -0
  32. Text2Human/models/losses/vqgan_loss.py +114 -0
  33. Text2Human/models/parsing_gen_model.py +220 -0
  34. Text2Human/models/sample_model.py +500 -0
  35. Text2Human/models/transformer_model.py +482 -0
  36. Text2Human/models/vqgan_model.py +551 -0
  37. Text2Human/sample_from_parsing.py +53 -0
  38. Text2Human/sample_from_pose.py +52 -0
  39. Text2Human/train_index_prediction.py +133 -0
  40. Text2Human/train_parsing_gen.py +136 -0
  41. Text2Human/train_parsing_token.py +122 -0
  42. Text2Human/train_sampler.py +122 -0
  43. Text2Human/train_vqvae.py +132 -0
  44. Text2Human/ui/__init__.py +0 -0
  45. Text2Human/ui/mouse_event.py +129 -0
  46. Text2Human/ui/ui.py +313 -0
  47. Text2Human/ui_demo.py +285 -0
  48. Text2Human/ui_util/__init__.py +0 -0
  49. Text2Human/ui_util/config.py +25 -0
  50. Text2Human/utils/__init__.py +0 -0
README.md CHANGED
@@ -1,12 +1,12 @@
1
  ---
2
  title: Text2Human
3
- emoji: 🌖
4
- colorFrom: yellow
5
- colorTo: blue
6
  sdk: gradio
7
- sdk_version: 3.1.3
8
  app_file: app.py
9
  pinned: false
10
  ---
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
  ---
2
  title: Text2Human
3
+ emoji: 🏃
4
+ colorFrom: purple
5
+ colorTo: gray
6
  sdk: gradio
7
+ sdk_version: 3.0.17
8
  app_file: app.py
9
  pinned: false
10
  ---
11
 
12
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces#reference
Text2Human/.gitignore ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
1
+ __pycache__/
2
+ .cache/
3
+ datasets/*
4
+ experiments/*
5
+ tb_logger/*
6
+ results/*
7
+ *.png
8
+ *.txt
9
+ *.pth
Text2Human/LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2022 Yuming Jiang
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
Text2Human/README.md ADDED
@@ -0,0 +1,255 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Text2Human - Official PyTorch Implementation
2
+
3
+ <!-- <img src="./doc_images/overview.jpg" width="96%" height="96%"> -->
4
+
5
+ This repository provides the official PyTorch implementation for the following paper:
6
+
7
+ **Text2Human: Text-Driven Controllable Human Image Generation**</br>
8
+ [Yuming Jiang](https://yumingj.github.io/), [Shuai Yang](https://williamyang1991.github.io/), [Haonan Qiu](http://haonanqiu.com/), [Wayne Wu](https://dblp.org/pid/50/8731.html), [Chen Change Loy](https://www.mmlab-ntu.com/person/ccloy/) and [Ziwei Liu](https://liuziwei7.github.io/)</br>
9
+ In ACM Transactions on Graphics (Proceedings of SIGGRAPH), 2022.
10
+
11
+ From [MMLab@NTU](https://www.mmlab-ntu.com/index.html) affliated with S-Lab, Nanyang Technological University and SenseTime Research.
12
+
13
+ <table>
14
+ <tr>
15
+ <td><img src="assets/1.png" width="100%"/></td>
16
+ <td><img src="assets/2.png" width="100%"/></td>
17
+ <td><img src="assets/3.png" width="100%"/></td>
18
+ <td><img src="assets/4.png" width="100%"/></td>
19
+ </tr>
20
+ <tr>
21
+ <td align='center' width='24%'>The lady wears a short-sleeve T-shirt with pure color pattern, and a short and denim skirt.</td>
22
+ <td align='center' width='24%'>The man wears a long and floral shirt, and long pants with the pure color pattern.</td>
23
+ <td align='center' width='24%'>A lady is wearing a sleeveless pure-color shirt and long jeans</td>
24
+ <td align='center' width='24%'>The man wears a short-sleeve T-shirt with the pure color pattern and a short pants with the pure color pattern.</td>
25
+ <tr>
26
+ </table>
27
+
28
+ [**[Project Page]**](https://yumingj.github.io/projects/Text2Human.html) | [**[Paper]**](https://arxiv.org/pdf/2205.15996.pdf) | [**[Dataset]**](https://github.com/yumingj/DeepFashion-MultiModal) | [**[Demo Video]**](https://youtu.be/yKh4VORA_E0)
29
+
30
+
31
+ ## Updates
32
+
33
+ - [05/2022] Paper and demo video are released.
34
+ - [05/2022] Code is released.
35
+ - [05/2022] This website is created.
36
+
37
+ ## Installation
38
+ **Clone this repo:**
39
+ ```bash
40
+ git clone https://github.com/yumingj/Text2Human.git
41
+ cd Text2Human
42
+ ```
43
+ **Dependencies:**
44
+
45
+ All dependencies for defining the environment are provided in `environment/text2human_env.yaml`.
46
+ We recommend using [Anaconda](https://docs.anaconda.com/anaconda/install/) to manage the python environment:
47
+ ```bash
48
+ conda env create -f ./environment/text2human_env.yaml
49
+ conda activate text2human
50
+ conda install -c huggingface tokenizers=0.9.4
51
+ conda install -c huggingface transformers=4.0.0
52
+ conda install -c conda-forge sentence-transformers=2.0.0
53
+ ```
54
+
55
+ If it doesn't work, you may need to install the following packages on your own:
56
+ - Python 3.6
57
+ - PyTorch 1.7.1
58
+ - CUDA 10.1
59
+ - [sentence-transformers](https://huggingface.co/sentence-transformers) 2.0.0
60
+ - [tokenizers](https://pypi.org/project/tokenizers/) 0.9.4
61
+ - [transformers](https://huggingface.co/docs/transformers/installation) 4.0.0
62
+
63
+ ## (1) Dataset Preparation
64
+
65
+ In this work, we contribute a large-scale high-quality dataset with rich multi-modal annotations named [DeepFashion-MultiModal](https://github.com/yumingj/DeepFashion-MultiModal) Dataset.
66
+ Here we pre-processed the raw annotations of the original dataset for the task of text-driven controllable human image generation. The pre-processing pipeline consists of:
67
+ - align the human body in the center of the images according to the human pose
68
+ - fuse the clothing color and clothing fabric annotations into one texture annotation
69
+ - do some annotation cleaning and image filtering
70
+ - split the whole dataset into the training set and testing set
71
+
72
+ You can download our processed dataset from this [Google Drive](https://drive.google.com/file/d/1KIoFfRZNQVn6RV_wTxG2wZmY8f2T_84B/view?usp=sharing). If you want to access the raw annotations, please refer to the [DeepFashion-MultiModal](https://github.com/yumingj/DeepFashion-MultiModal) Dataset.
73
+
74
+ After downloading the dataset, unzip the file and put them under the dataset folder with the following structure:
75
+ ```
76
+ ./datasets
77
+ ├── train_images
78
+ ├── xxx.png
79
+ ...
80
+ ├── xxx.png
81
+ └── xxx.png
82
+ ├── test_images
83
+ % the same structure as in train_images
84
+ ├── densepose
85
+ % the same structure as in train_images
86
+ ├── segm
87
+ % the same structure as in train_images
88
+ ├── shape_ann
89
+ ├── test_ann_file.txt
90
+ ├── train_ann_file.txt
91
+ └── val_ann_file.txt
92
+ └── texture_ann
93
+ ├── test
94
+ ├── lower_fused.txt
95
+ ├── outer_fused.txt
96
+ └── upper_fused.txt
97
+ ├── train
98
+ % the same files as in test
99
+ └── val
100
+ % the same files as in test
101
+ ```
102
+
103
+ ## (2) Sampling
104
+
105
+ ### Inference Notebook
106
+ <img src="https://colab.research.google.com/assets/colab-badge.svg" height=22.5></a></br>
107
+ Coming soon.
108
+
109
+
110
+ ### Pretrained Models
111
+
112
+ Pretrained models can be downloaded from this [Google Drive](https://drive.google.com/file/d/1VyI8_AbPwAUaZJPaPba8zxsFIWumlDen/view?usp=sharing). Unzip the file and put them under the dataset folder with the following structure:
113
+ ```
114
+ pretrained_models
115
+ ├── index_pred_net.pth
116
+ ├── parsing_gen.pth
117
+ ├── parsing_token.pth
118
+ ├── sampler.pth
119
+ ├── vqvae_bottom.pth
120
+ └── vqvae_top.pth
121
+ ```
122
+
123
+ ### Generation from Paring Maps
124
+ You can generate images from given parsing maps and pre-defined texture annotations:
125
+ ```python
126
+ python sample_from_parsing.py -opt ./configs/sample_from_parsing.yml
127
+ ```
128
+ The results are saved in the folder `./results/sampling_from_parsing`.
129
+
130
+ ### Generation from Poses
131
+ You can generate images from given human poses and pre-defined clothing shape and texture annotations:
132
+ ```python
133
+ python sample_from_pose.py -opt ./configs/sample_from_pose.yml
134
+ ```
135
+
136
+ **Remarks**: The above two scripts generate images without language interactions. If you want to generate images using texts, you can use the notebook or our user interface.
137
+
138
+ ### User Interface
139
+
140
+ ```python
141
+ python ui_demo.py
142
+ ```
143
+ <img src="./assets/ui.png" width="100%">
144
+
145
+ The descriptions for shapes should follow the following format:
146
+ ```
147
+ <gender>, <sleeve length>, <length of lower clothing>, <outer clothing type>, <other accessories1>, ...
148
+
149
+ Note: The outer clothing type and accessories can be omitted.
150
+
151
+ Examples:
152
+ man, sleeveless T-shirt, long pants
153
+ woman, short-sleeve T-shirt, short jeans
154
+ ```
155
+
156
+ The descriptions for textures should follow the following format:
157
+ ```
158
+ <upper clothing texture>, <lower clothing texture>, <outer clothing texture>
159
+
160
+ Note: Currently, we only support 5 types of textures, i.e., pure color, stripe/spline, plaid/lattice,
161
+ floral, denim. Your inputs should be restricted to these textures.
162
+ ```
163
+
164
+ ## (3) Training Text2Human
165
+
166
+ ### Stage I: Pose to Parsing
167
+ Train the parsing generation network. If you want to skip the training of this network, you can download our pretrained model from [here](https://drive.google.com/file/d/1MNyFLGqIQcOMg_HhgwCmKqdwfQSjeg_6/view?usp=sharing).
168
+ ```python
169
+ python train_parsing_gen.py -opt ./configs/parsing_gen.yml
170
+ ```
171
+
172
+ ### Stage II: Parsing to Human
173
+
174
+ **Step 1: Train the top level of the hierarchical VQVAE.**
175
+ We provide our pretrained model [here](https://drive.google.com/file/d/1TwypUg85gPFJtMwBLUjVS66FKR3oaTz8/view?usp=sharing). This model is trained by:
176
+ ```python
177
+ python train_vqvae.py -opt ./configs/vqvae_top.yml
178
+ ```
179
+
180
+ **Step 2: Train the bottom level of the hierarchical VQVAE.**
181
+ We provide our pretrained model [here](https://drive.google.com/file/d/15hzbY-RG-ILgzUqqGC0qMzlS4OayPdRH/view?usp=sharing). This model is trained by:
182
+ ```python
183
+ python train_vqvae.py -opt ./configs/vqvae_bottom.yml
184
+ ```
185
+
186
+ **Stage 3 & 4: Train the sampler with mixture-of-experts.** To train the sampler, we first need to train a model to tokenize the parsing maps. You can access our pretrained parsing maps [here](https://drive.google.com/file/d/1GLHoOeCP6sMao1-R63ahJMJF7-J00uir/view?usp=sharing).
187
+ ```python
188
+ python train_parsing_token.py -opt ./configs/parsing_token.yml
189
+ ```
190
+
191
+ With the parsing tokenization model, the sampler is trained by:
192
+ ```python
193
+ python train_sampler.py -opt ./configs/sampler.yml
194
+ ```
195
+ Our pretrained sampler is provided [here](https://drive.google.com/file/d/1OQO_kG2fK7eKiG1VJH1OL782X71UQAmS/view?usp=sharing).
196
+
197
+ **Stage 5: Train the index prediction network.**
198
+ We provide our pretrained index prediction network [here](https://drive.google.com/file/d/1rqhkQD-JGd7YBeIfDvMV-vjfbNHpIhYm/view?usp=sharing). It is trained by:
199
+ ```python
200
+ python train_index_prediction.py -opt ./configs/index_pred_net.yml
201
+ ```
202
+
203
+
204
+ **Remarks**: In the config files, we use the path to our models as the required pretrained models. If you want to train the models from scratch, please replace the path to your own one. We set the numbers of the training epochs as large numbers and you can choose the best epoch for each model. For your reference, our pretrained parsing generation network is trained for 50 epochs, top-level VQVAE is trained for 135 epochs, bottom-level VQVAE is trained for 70 epochs, parsing tokenization network is trained for 20 epochs, sampler is trained for 95 epochs, and the index prediction network is trained for 70 epochs.
205
+
206
+ ## (4) Results
207
+
208
+ Please visit our [Project Page](https://yumingj.github.io/projects/Text2Human.html#results) to view more results.</br>
209
+ You can select the attribtues to customize the desired human images.
210
+ [<img src="./assets/results.png" width="90%">
211
+ ](https://yumingj.github.io/projects/Text2Human.html#results)
212
+
213
+ ## DeepFashion-MultiModal Dataset
214
+
215
+ <img src="./assets/dataset_logo.png" width="90%">
216
+
217
+ In this work, we also propose **DeepFashion-MultiModal**, a large-scale high-quality human dataset with rich multi-modal annotations. It has the following properties:
218
+ 1. It contains 44,096 high-resolution human images, including 12,701 full body human images.
219
+ 2. For each full body images, we **manually annotate** the human parsing labels of 24 classes.
220
+ 3. For each full body images, we **manually annotate** the keypoints.
221
+ 4. We extract DensePose for each human image.
222
+ 5. Each image is **manually annotated** with attributes for both clothes shapes and textures.
223
+ 6. We provide a textual description for each image.
224
+
225
+ <img src="./assets/dataset_overview.png" width="100%">
226
+
227
+ Please refer to [this repo](https://github.com/yumingj/DeepFashion-MultiModal) for more details about our proposed dataset.
228
+
229
+ ## TODO List
230
+
231
+ - [ ] Release 1024x512 version of Text2Human.
232
+ - [ ] Train the Text2Human using [SHHQ dataset](https://stylegan-human.github.io/).
233
+
234
+ ## Citation
235
+
236
+ If you find this work useful for your research, please consider citing our paper:
237
+
238
+ ```bibtex
239
+ @article{jiang2022text2human,
240
+ title={Text2Human: Text-Driven Controllable Human Image Generation},
241
+ author={Jiang, Yuming and Yang, Shuai and Qiu, Haonan and Wu, Wayne and Loy, Chen Change and Liu, Ziwei},
242
+ journal={ACM Transactions on Graphics (TOG)},
243
+ volume={41},
244
+ number={4},
245
+ articleno={162},
246
+ pages={1--11},
247
+ year={2022},
248
+ publisher={ACM New York, NY, USA},
249
+ doi={10.1145/3528223.3530104},
250
+ }
251
+ ```
252
+
253
+ ## Acknowledgments
254
+
255
+ Part of the code is borrowed from [unleashing-transformers](https://github.com/samb-t/unleashing-transformers), [taming-transformers](https://github.com/CompVis/taming-transformers) and [mmsegmentation](https://github.com/open-mmlab/mmsegmentation).
Text2Human/configs/index_pred_net.yml ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: index_prediction_network
2
+ use_tb_logger: true
3
+ set_CUDA_VISIBLE_DEVICES: ~
4
+ gpu_ids: [3]
5
+
6
+ # dataset configs
7
+ batch_size: 4
8
+ num_workers: 4
9
+ train_img_dir: ./datasets/train_images
10
+ test_img_dir: ./datasets/test_images
11
+ segm_dir: ./datasets/segm
12
+ pose_dir: ./datasets/densepose
13
+ train_ann_file: ./datasets/texture_ann/train
14
+ val_ann_file: ./datasets/texture_ann/val
15
+ test_ann_file: ./datasets/texture_ann/test
16
+ downsample_factor: 2
17
+
18
+ model_type: VQGANTextureAwareSpatialHierarchyInferenceModel
19
+ # network configs
20
+ embed_dim: 256
21
+ n_embed: 1024
22
+ codebook_spatial_size: 2
23
+
24
+ # bottom level vqvae
25
+ bot_n_embed: 512
26
+ bot_double_z: false
27
+ bot_z_channels: 256
28
+ bot_resolution: 512
29
+ bot_in_channels: 3
30
+ bot_out_ch: 3
31
+ bot_ch: 128
32
+ bot_ch_mult: [1, 1, 2, 4]
33
+ bot_num_res_blocks: 2
34
+ bot_attn_resolutions: [64]
35
+ bot_dropout: 0.0
36
+ bot_vae_path: ./pretrained_models/vqvae_bottom.pth
37
+
38
+ # top level vqgan
39
+ top_double_z: false
40
+ top_z_channels: 256
41
+ top_resolution: 512
42
+ top_in_channels: 3
43
+ top_out_ch: 3
44
+ top_ch: 128
45
+ top_ch_mult: [1, 1, 2, 2, 4]
46
+ top_num_res_blocks: 2
47
+ top_attn_resolutions: [32]
48
+ top_dropout: 0.0
49
+ top_vae_path: ./pretrained_models/vqvae_top.pth
50
+
51
+ # unet configs
52
+ encoder_in_channels: 256
53
+ fc_in_channels: 64
54
+ fc_in_index: 4
55
+ fc_channels: 64
56
+ fc_num_convs: 1
57
+ fc_concat_input: False
58
+ fc_dropout_ratio: 0.1
59
+ fc_num_classes: 512
60
+ fc_align_corners: False
61
+
62
+ disc_layers: 3
63
+ disc_weight_max: 1
64
+ disc_start_step: 30001
65
+ n_channels: 3
66
+ ndf: 64
67
+ nf: 128
68
+ perceptual_weight: 1.0
69
+
70
+ num_segm_classes: 24
71
+
72
+ # training configs
73
+ val_freq: 5
74
+ print_freq: 100
75
+ weight_decay: 0
76
+ manual_seed: 2021
77
+ num_epochs: 100
78
+ lr: !!float 1.0e-04
79
+ lr_decay: step
80
+ gamma: 1.0
81
+ step: 50
82
+ optimizer: Adam
83
+ loss_function: cross_entropy
84
+
Text2Human/configs/parsing_gen.yml ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: parsing_generation
2
+ use_tb_logger: true
3
+ set_CUDA_VISIBLE_DEVICES: ~
4
+ gpu_ids: [3]
5
+
6
+ # dataset configs
7
+ batch_size: 8
8
+ num_workers: 4
9
+ segm_dir: ./datasets/segm
10
+ pose_dir: ./datasets/densepose
11
+ train_ann_file: ./datasets/shape_ann/train_ann_file.txt
12
+ val_ann_file: ./datasets/shape_ann/val_ann_file.txt
13
+ test_ann_file: ./datasets/shape_ann/test_ann_file.txt
14
+ downsample_factor: 2
15
+
16
+ model_type: ParsingGenModel
17
+ # network configs
18
+ embedder_dim: 8
19
+ embedder_out_dim: 128
20
+ attr_class_num: [2, 4, 6, 5, 4, 3, 5, 5, 3, 2, 2, 2, 2, 2, 2]
21
+ encoder_in_channels: 1
22
+ fc_in_channels: 64
23
+ fc_in_index: 4
24
+ fc_channels: 64
25
+ fc_num_convs: 1
26
+ fc_concat_input: False
27
+ fc_dropout_ratio: 0.1
28
+ fc_num_classes: 24
29
+ fc_align_corners: False
30
+
31
+ # training configs
32
+ val_freq: 5
33
+ print_freq: 100
34
+ weight_decay: 0
35
+ manual_seed: 2021
36
+ num_epochs: 100
37
+ lr: !!float 1e-4
38
+ lr_decay: step
39
+ gamma: 0.1
40
+ step: 50
Text2Human/configs/parsing_token.yml ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: parsing_tokenization
2
+ use_tb_logger: true
3
+ set_CUDA_VISIBLE_DEVICES: ~
4
+ gpu_ids: [3]
5
+
6
+ # dataset configs
7
+ batch_size: 4
8
+ num_workers: 4
9
+ train_img_dir: ./datasets/train_images
10
+ test_img_dir: ./datasets/test_images
11
+ segm_dir: ./datasets/segm
12
+ pose_dir: ./datasets/densepose
13
+ train_ann_file: ./datasets/texture_ann/train
14
+ val_ann_file: ./datasets/texture_ann/val
15
+ test_ann_file: ./datasets/texture_ann/test
16
+ downsample_factor: 2
17
+
18
+ model_type: VQSegmentationModel
19
+ # network configs
20
+ embed_dim: 32
21
+ n_embed: 1024
22
+ image_key: "segmentation"
23
+ n_labels: 24
24
+ double_z: false
25
+ z_channels: 32
26
+ resolution: 512
27
+ in_channels: 24
28
+ out_ch: 24
29
+ ch: 64
30
+ ch_mult: [1, 1, 2, 2, 4]
31
+ num_res_blocks: 1
32
+ attn_resolutions: [16]
33
+ dropout: 0.0
34
+
35
+ num_segm_classes: 24
36
+
37
+
38
+ # training configs
39
+ val_freq: 5
40
+ print_freq: 100
41
+ weight_decay: 0
42
+ manual_seed: 2021
43
+ num_epochs: 100
44
+ lr: !!float 4.5e-05
45
+ lr_decay: step
46
+ gamma: 0.1
47
+ step: 50
Text2Human/configs/sample_from_parsing.yml ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: sample_from_parsing
2
+ use_tb_logger: true
3
+ set_CUDA_VISIBLE_DEVICES: ~
4
+ gpu_ids: [3]
5
+
6
+ # dataset configs
7
+ batch_size: 4
8
+ num_workers: 4
9
+ test_img_dir: ./datasets/test_images
10
+ segm_dir: ./datasets/segm
11
+ pose_dir: ./datasets/densepose
12
+ test_ann_file: ./datasets/texture_ann/test
13
+ downsample_factor: 2
14
+
15
+ model_type: SampleFromParsingModel
16
+ # network configs
17
+ embed_dim: 256
18
+ n_embed: 1024
19
+ codebook_spatial_size: 2
20
+
21
+ # bottom level vqvae
22
+ bot_n_embed: 512
23
+ bot_codebook_spatial_size: 2
24
+ bot_double_z: false
25
+ bot_z_channels: 256
26
+ bot_resolution: 512
27
+ bot_in_channels: 3
28
+ bot_out_ch: 3
29
+ bot_ch: 128
30
+ bot_ch_mult: [1, 1, 2, 4]
31
+ bot_num_res_blocks: 2
32
+ bot_attn_resolutions: [64]
33
+ bot_dropout: 0.0
34
+ bot_vae_path: ./pretrained_models/vqvae_bottom.pth
35
+
36
+ # top level vqgan
37
+ top_double_z: false
38
+ top_z_channels: 256
39
+ top_resolution: 512
40
+ top_in_channels: 3
41
+ top_out_ch: 3
42
+ top_ch: 128
43
+ top_ch_mult: [1, 1, 2, 2, 4]
44
+ top_num_res_blocks: 2
45
+ top_attn_resolutions: [32]
46
+ top_dropout: 0.0
47
+ top_vae_path: ./pretrained_models/vqvae_top.pth
48
+
49
+ # unet configs
50
+ index_pred_encoder_in_channels: 256
51
+ index_pred_fc_in_channels: 64
52
+ index_pred_fc_in_index: 4
53
+ index_pred_fc_channels: 64
54
+ index_pred_fc_num_convs: 1
55
+ index_pred_fc_concat_input: False
56
+ index_pred_fc_dropout_ratio: 0.1
57
+ index_pred_fc_num_classes: 512
58
+ index_pred_fc_align_corners: False
59
+ pretrained_index_network: ./pretrained_models/index_pred_net.pth
60
+
61
+ # segmentation tokenization
62
+ segm_double_z: false
63
+ segm_z_channels: 32
64
+ segm_resolution: 512
65
+ segm_in_channels: 24
66
+ segm_out_ch: 24
67
+ segm_ch: 64
68
+ segm_ch_mult: [1, 1, 2, 2, 4]
69
+ segm_num_res_blocks: 1
70
+ segm_attn_resolutions: [16]
71
+ segm_dropout: 0.0
72
+ segm_num_segm_classes: 24
73
+ segm_n_embed: 1024
74
+ segm_embed_dim: 32
75
+ segm_token_path: ./pretrained_models/parsing_token.pth
76
+
77
+ # sampler configs
78
+ codebook_size: 18432
79
+ segm_codebook_size: 1024
80
+ texture_codebook_size: 18
81
+ bert_n_emb: 512
82
+ bert_n_layers: 24
83
+ bert_n_head: 8
84
+ block_size: 512 # 32 x 16
85
+ latent_shape: [32, 16]
86
+ embd_pdrop: 0.0
87
+ resid_pdrop: 0.0
88
+ attn_pdrop: 0.0
89
+ num_head: 18
90
+ pretrained_sampler: ./pretrained_models/sampler.pth
91
+
92
+ manual_seed: 2021
93
+ sample_steps: 256
Text2Human/configs/sample_from_pose.yml ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: sample_from_pose
2
+ use_tb_logger: true
3
+ set_CUDA_VISIBLE_DEVICES: ~
4
+ gpu_ids: [3]
5
+
6
+ # dataset configs
7
+ batch_size: 4
8
+ num_workers: 4
9
+ pose_dir: ./datasets/densepose
10
+ texture_ann_file: ./datasets/texture_ann/test
11
+ shape_ann_path: ./datasets/shape_ann/test_ann_file.txt
12
+ downsample_factor: 2
13
+
14
+ model_type: SampleFromPoseModel
15
+ # network configs
16
+ embed_dim: 256
17
+ n_embed: 1024
18
+ codebook_spatial_size: 2
19
+
20
+ # bottom level vqgan
21
+ bot_n_embed: 512
22
+ bot_codebook_spatial_size: 2
23
+ bot_double_z: false
24
+ bot_z_channels: 256
25
+ bot_resolution: 512
26
+ bot_in_channels: 3
27
+ bot_out_ch: 3
28
+ bot_ch: 128
29
+ bot_ch_mult: [1, 1, 2, 4]
30
+ bot_num_res_blocks: 2
31
+ bot_attn_resolutions: [64]
32
+ bot_dropout: 0.0
33
+ bot_vae_path: ./pretrained_models/vqvae_bottom.pth
34
+
35
+ # top level vqgan
36
+ top_double_z: false
37
+ top_z_channels: 256
38
+ top_resolution: 512
39
+ top_in_channels: 3
40
+ top_out_ch: 3
41
+ top_ch: 128
42
+ top_ch_mult: [1, 1, 2, 2, 4]
43
+ top_num_res_blocks: 2
44
+ top_attn_resolutions: [32]
45
+ top_dropout: 0.0
46
+ top_vae_path: ./pretrained_models/vqvae_top.pth
47
+
48
+ # unet configs
49
+ index_pred_encoder_in_channels: 256
50
+ index_pred_fc_in_channels: 64
51
+ index_pred_fc_in_index: 4
52
+ index_pred_fc_channels: 64
53
+ index_pred_fc_num_convs: 1
54
+ index_pred_fc_concat_input: False
55
+ index_pred_fc_dropout_ratio: 0.1
56
+ index_pred_fc_num_classes: 512
57
+ index_pred_fc_align_corners: False
58
+ pretrained_index_network: ./pretrained_models/index_pred_net.pth
59
+
60
+ # segmentation tokenization
61
+ segm_double_z: false
62
+ segm_z_channels: 32
63
+ segm_resolution: 512
64
+ segm_in_channels: 24
65
+ segm_out_ch: 24
66
+ segm_ch: 64
67
+ segm_ch_mult: [1, 1, 2, 2, 4]
68
+ segm_num_res_blocks: 1
69
+ segm_attn_resolutions: [16]
70
+ segm_dropout: 0.0
71
+ segm_num_segm_classes: 24
72
+ segm_n_embed: 1024
73
+ segm_embed_dim: 32
74
+ segm_token_path: ./pretrained_models/parsing_token.pth
75
+
76
+ # sampler configs
77
+ codebook_size: 18432
78
+ segm_codebook_size: 1024
79
+ texture_codebook_size: 18
80
+ bert_n_emb: 512
81
+ bert_n_layers: 24
82
+ bert_n_head: 8
83
+ block_size: 512 # 32 x 16
84
+ latent_shape: [32, 16]
85
+ embd_pdrop: 0.0
86
+ resid_pdrop: 0.0
87
+ attn_pdrop: 0.0
88
+ num_head: 18
89
+ pretrained_sampler: ./pretrained_models/sampler.pth
90
+
91
+ # shape network configs
92
+ shape_embedder_dim: 8
93
+ shape_embedder_out_dim: 128
94
+ shape_attr_class_num: [2, 4, 6, 5, 4, 3, 5, 5, 3, 2, 2, 2, 2, 2, 2]
95
+ shape_encoder_in_channels: 1
96
+ shape_fc_in_channels: 64
97
+ shape_fc_in_index: 4
98
+ shape_fc_channels: 64
99
+ shape_fc_num_convs: 1
100
+ shape_fc_concat_input: False
101
+ shape_fc_dropout_ratio: 0.1
102
+ shape_fc_num_classes: 24
103
+ shape_fc_align_corners: False
104
+ pretrained_parsing_gen: ./pretrained_models/parsing_gen.pth
105
+
106
+ manual_seed: 2021
107
+ sample_steps: 256
Text2Human/configs/sampler.yml ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: sampler
2
+ use_tb_logger: true
3
+ set_CUDA_VISIBLE_DEVICES: ~
4
+ gpu_ids: [3]
5
+
6
+ # dataset configs
7
+ batch_size: 4
8
+ num_workers: 1
9
+ train_img_dir: ./datasets/train_images
10
+ test_img_dir: ./datasets/test_images
11
+ segm_dir: ./datasets/segm
12
+ pose_dir: ./datasets/densepose
13
+ train_ann_file: ./datasets/texture_ann/train
14
+ val_ann_file: ./datasets/texture_ann/val
15
+ test_ann_file: ./datasets/texture_ann/test
16
+ downsample_factor: 2
17
+
18
+ # pretrained models
19
+ img_ae_path: ./pretrained_models/vqvae_top.pth
20
+ segm_ae_path: ./pretrained_models/parsing_token.pth
21
+
22
+ model_type: TransformerTextureAwareModel
23
+ # network configs
24
+
25
+ # image autoencoder
26
+ img_embed_dim: 256
27
+ img_n_embed: 1024
28
+ img_double_z: false
29
+ img_z_channels: 256
30
+ img_resolution: 512
31
+ img_in_channels: 3
32
+ img_out_ch: 3
33
+ img_ch: 128
34
+ img_ch_mult: [1, 1, 2, 2, 4]
35
+ img_num_res_blocks: 2
36
+ img_attn_resolutions: [32]
37
+ img_dropout: 0.0
38
+
39
+ # segmentation tokenization
40
+ segm_double_z: false
41
+ segm_z_channels: 32
42
+ segm_resolution: 512
43
+ segm_in_channels: 24
44
+ segm_out_ch: 24
45
+ segm_ch: 64
46
+ segm_ch_mult: [1, 1, 2, 2, 4]
47
+ segm_num_res_blocks: 1
48
+ segm_attn_resolutions: [16]
49
+ segm_dropout: 0.0
50
+ segm_num_segm_classes: 24
51
+ segm_n_embed: 1024
52
+ segm_embed_dim: 32
53
+
54
+ # sampler configs
55
+ codebook_size: 18432
56
+ segm_codebook_size: 1024
57
+ texture_codebook_size: 18
58
+ bert_n_emb: 512
59
+ bert_n_layers: 24
60
+ bert_n_head: 8
61
+ block_size: 512 # 32 x 16
62
+ latent_shape: [32, 16]
63
+ embd_pdrop: 0.0
64
+ resid_pdrop: 0.0
65
+ attn_pdrop: 0.0
66
+ num_head: 18
67
+
68
+ # loss configs
69
+ loss_type: reweighted_elbo
70
+ mask_schedule: random
71
+
72
+ sample_steps: 256
73
+
74
+ # training configs
75
+ val_freq: 5
76
+ print_freq: 100
77
+ weight_decay: 0
78
+ manual_seed: 2021
79
+ num_epochs: 100
80
+ lr: !!float 1e-4
81
+ lr_decay: step
82
+ gamma: 1.0
83
+ step: 50
Text2Human/configs/vqvae_bottom.yml ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: vqvae_bottom
2
+ use_tb_logger: true
3
+ set_CUDA_VISIBLE_DEVICES: ~
4
+ gpu_ids: [3]
5
+
6
+ # dataset configs
7
+ batch_size: 4
8
+ num_workers: 4
9
+ train_img_dir: ./datasets/train_images
10
+ test_img_dir: ./datasets/test_images
11
+ segm_dir: ./datasets/segm
12
+ pose_dir: ./datasets/densepose
13
+ train_ann_file: ./datasets/texture_ann/train
14
+ val_ann_file: ./datasets/texture_ann/val
15
+ test_ann_file: ./datasets/texture_ann/test
16
+ downsample_factor: 2
17
+
18
+ model_type: HierarchyVQSpatialTextureAwareModel
19
+ # network configs
20
+ embed_dim: 256
21
+ n_embed: 1024
22
+ codebook_spatial_size: 2
23
+
24
+ # bottom level vqvae
25
+ bot_n_embed: 512
26
+ bot_double_z: false
27
+ bot_z_channels: 256
28
+ bot_resolution: 512
29
+ bot_in_channels: 3
30
+ bot_out_ch: 3
31
+ bot_ch: 128
32
+ bot_ch_mult: [1, 1, 2, 4]
33
+ bot_num_res_blocks: 2
34
+ bot_attn_resolutions: [64]
35
+ bot_dropout: 0.0
36
+
37
+ # top level vqgan
38
+ top_double_z: false
39
+ top_z_channels: 256
40
+ top_resolution: 512
41
+ top_in_channels: 3
42
+ top_out_ch: 3
43
+ top_ch: 128
44
+ top_ch_mult: [1, 1, 2, 2, 4]
45
+ top_num_res_blocks: 2
46
+ top_attn_resolutions: [32]
47
+ top_dropout: 0.0
48
+ top_vae_path: ./pretrained_models/vqvae_top.pth
49
+
50
+ fix_decoder: false
51
+
52
+ disc_layers: 3
53
+ disc_weight_max: 1
54
+ disc_start_step: 1
55
+ n_channels: 3
56
+ ndf: 64
57
+ nf: 128
58
+ perceptual_weight: 1.0
59
+
60
+ num_segm_classes: 24
61
+
62
+ # training configs
63
+ val_freq: 5
64
+ print_freq: 100
65
+ weight_decay: 0
66
+ manual_seed: 2021
67
+ num_epochs: 1000
68
+ lr: !!float 1.0e-04
69
+ lr_decay: step
70
+ gamma: 1.0
71
+ step: 50
72
+
Text2Human/configs/vqvae_top.yml ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: vqvae_top
2
+ use_tb_logger: true
3
+ set_CUDA_VISIBLE_DEVICES: ~
4
+ gpu_ids: [3]
5
+
6
+ # dataset configs
7
+ batch_size: 4
8
+ num_workers: 4
9
+ train_img_dir: ./datasets/train_images
10
+ test_img_dir: ./datasets/test_images
11
+ segm_dir: ./datasets/segm
12
+ pose_dir: ./datasets/densepose
13
+ train_ann_file: ./datasets/texture_ann/train
14
+ val_ann_file: ./datasets/texture_ann/val
15
+ test_ann_file: ./datasets/texture_ann/test
16
+ downsample_factor: 2
17
+
18
+ model_type: VQImageSegmTextureModel
19
+ # network configs
20
+ embed_dim: 256
21
+ n_embed: 1024
22
+ double_z: false
23
+ z_channels: 256
24
+ resolution: 512
25
+ in_channels: 3
26
+ out_ch: 3
27
+ ch: 128
28
+ ch_mult: [1, 1, 2, 2, 4]
29
+ num_res_blocks: 2
30
+ attn_resolutions: [32]
31
+ dropout: 0.0
32
+
33
+ disc_layers: 3
34
+ disc_weight_max: 0
35
+ disc_start_step: 3000000000000000000000000001
36
+ n_channels: 3
37
+ ndf: 64
38
+ nf: 128
39
+ perceptual_weight: 1.0
40
+
41
+ num_segm_classes: 24
42
+
43
+
44
+ # training configs
45
+ val_freq: 5
46
+ print_freq: 100
47
+ weight_decay: 0
48
+ manual_seed: 2021
49
+ num_epochs: 1000
50
+ lr: !!float 1.0e-04
51
+ lr_decay: step
52
+ gamma: 1.0
53
+ step: 50
Text2Human/data/__init__.py ADDED
File without changes
Text2Human/data/mask_dataset.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import os.path
3
+ import random
4
+
5
+ import numpy as np
6
+ import torch
7
+ import torch.utils.data as data
8
+ from PIL import Image
9
+
10
+
11
+ class MaskDataset(data.Dataset):
12
+
13
+ def __init__(self, segm_dir, ann_dir, downsample_factor=2, xflip=False):
14
+
15
+ self._segm_path = segm_dir
16
+ self._image_fnames = []
17
+
18
+ self.downsample_factor = downsample_factor
19
+ self.xflip = xflip
20
+
21
+ # load attributes
22
+ assert os.path.exists(f'{ann_dir}/upper_fused.txt')
23
+ for idx, row in enumerate(
24
+ open(os.path.join(f'{ann_dir}/upper_fused.txt'), 'r')):
25
+ annotations = row.split()
26
+ self._image_fnames.append(annotations[0])
27
+
28
+ def _open_file(self, path_prefix, fname):
29
+ return open(os.path.join(path_prefix, fname), 'rb')
30
+
31
+ def _load_segm(self, raw_idx):
32
+ fname = self._image_fnames[raw_idx]
33
+ fname = f'{fname[:-4]}_segm.png'
34
+ with self._open_file(self._segm_path, fname) as f:
35
+ segm = Image.open(f)
36
+ if self.downsample_factor != 1:
37
+ width, height = segm.size
38
+ width = width // self.downsample_factor
39
+ height = height // self.downsample_factor
40
+ segm = segm.resize(
41
+ size=(width, height), resample=Image.NEAREST)
42
+ segm = np.array(segm)
43
+ # segm = segm[:, :, np.newaxis].transpose(2, 0, 1)
44
+ return segm.astype(np.float32)
45
+
46
+ def __getitem__(self, index):
47
+ segm = self._load_segm(index)
48
+
49
+ if self.xflip and random.random() > 0.5:
50
+ segm = segm[:, ::-1].copy()
51
+
52
+ segm = torch.from_numpy(segm).long()
53
+
54
+ return_dict = {'segm': segm, 'img_name': self._image_fnames[index]}
55
+
56
+ return return_dict
57
+
58
+ def __len__(self):
59
+ return len(self._image_fnames)
Text2Human/data/parsing_generation_segm_attr_dataset.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import os.path
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch.utils.data as data
7
+ from PIL import Image
8
+
9
+
10
+ class ParsingGenerationDeepFashionAttrSegmDataset(data.Dataset):
11
+
12
+ def __init__(self, segm_dir, pose_dir, ann_file, downsample_factor=2):
13
+ self._densepose_path = pose_dir
14
+ self._segm_path = segm_dir
15
+ self._image_fnames = []
16
+ self.attrs = []
17
+
18
+ self.downsample_factor = downsample_factor
19
+
20
+ # training, ground-truth available
21
+ assert os.path.exists(ann_file)
22
+ for row in open(os.path.join(ann_file), 'r'):
23
+ annotations = row.split()
24
+ self._image_fnames.append(annotations[0])
25
+ self.attrs.append([int(i) for i in annotations[1:]])
26
+
27
+ def _open_file(self, path_prefix, fname):
28
+ return open(os.path.join(path_prefix, fname), 'rb')
29
+
30
+ def _load_densepose(self, raw_idx):
31
+ fname = self._image_fnames[raw_idx]
32
+ fname = f'{fname[:-4]}_densepose.png'
33
+ with self._open_file(self._densepose_path, fname) as f:
34
+ densepose = Image.open(f)
35
+ if self.downsample_factor != 1:
36
+ width, height = densepose.size
37
+ width = width // self.downsample_factor
38
+ height = height // self.downsample_factor
39
+ densepose = densepose.resize(
40
+ size=(width, height), resample=Image.NEAREST)
41
+ # channel-wise IUV order, [3, H, W]
42
+ densepose = np.array(densepose)[:, :, 2:].transpose(2, 0, 1)
43
+ return densepose.astype(np.float32)
44
+
45
+ def _load_segm(self, raw_idx):
46
+ fname = self._image_fnames[raw_idx]
47
+ fname = f'{fname[:-4]}_segm.png'
48
+ with self._open_file(self._segm_path, fname) as f:
49
+ segm = Image.open(f)
50
+ if self.downsample_factor != 1:
51
+ width, height = segm.size
52
+ width = width // self.downsample_factor
53
+ height = height // self.downsample_factor
54
+ segm = segm.resize(
55
+ size=(width, height), resample=Image.NEAREST)
56
+ segm = np.array(segm)
57
+ return segm.astype(np.float32)
58
+
59
+ def __getitem__(self, index):
60
+ pose = self._load_densepose(index)
61
+ segm = self._load_segm(index)
62
+ attr = self.attrs[index]
63
+
64
+ pose = torch.from_numpy(pose)
65
+ segm = torch.LongTensor(segm)
66
+ attr = torch.LongTensor(attr)
67
+
68
+ pose = pose / 12. - 1
69
+
70
+ return_dict = {
71
+ 'densepose': pose,
72
+ 'segm': segm,
73
+ 'attr': attr,
74
+ 'img_name': self._image_fnames[index]
75
+ }
76
+
77
+ return return_dict
78
+
79
+ def __len__(self):
80
+ return len(self._image_fnames)
Text2Human/data/pose_attr_dataset.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import os.path
3
+ import random
4
+
5
+ import numpy as np
6
+ import torch
7
+ import torch.utils.data as data
8
+ from PIL import Image
9
+
10
+
11
+ class DeepFashionAttrPoseDataset(data.Dataset):
12
+
13
+ def __init__(self,
14
+ pose_dir,
15
+ texture_ann_dir,
16
+ shape_ann_path,
17
+ downsample_factor=2,
18
+ xflip=False):
19
+ self._densepose_path = pose_dir
20
+ self._image_fnames_target = []
21
+ self._image_fnames = []
22
+ self.upper_fused_attrs = []
23
+ self.lower_fused_attrs = []
24
+ self.outer_fused_attrs = []
25
+ self.shape_attrs = []
26
+
27
+ self.downsample_factor = downsample_factor
28
+ self.xflip = xflip
29
+
30
+ # load attributes
31
+ assert os.path.exists(f'{texture_ann_dir}/upper_fused.txt')
32
+ for idx, row in enumerate(
33
+ open(os.path.join(f'{texture_ann_dir}/upper_fused.txt'), 'r')):
34
+ annotations = row.split()
35
+ self._image_fnames_target.append(annotations[0])
36
+ self._image_fnames.append(f'{annotations[0].split(".")[0]}.png')
37
+ self.upper_fused_attrs.append(int(annotations[1]))
38
+
39
+ assert len(self._image_fnames_target) == len(self.upper_fused_attrs)
40
+
41
+ assert os.path.exists(f'{texture_ann_dir}/lower_fused.txt')
42
+ for idx, row in enumerate(
43
+ open(os.path.join(f'{texture_ann_dir}/lower_fused.txt'), 'r')):
44
+ annotations = row.split()
45
+ assert self._image_fnames_target[idx] == annotations[0]
46
+ self.lower_fused_attrs.append(int(annotations[1]))
47
+
48
+ assert len(self._image_fnames_target) == len(self.lower_fused_attrs)
49
+
50
+ assert os.path.exists(f'{texture_ann_dir}/outer_fused.txt')
51
+ for idx, row in enumerate(
52
+ open(os.path.join(f'{texture_ann_dir}/outer_fused.txt'), 'r')):
53
+ annotations = row.split()
54
+ assert self._image_fnames_target[idx] == annotations[0]
55
+ self.outer_fused_attrs.append(int(annotations[1]))
56
+
57
+ assert len(self._image_fnames_target) == len(self.outer_fused_attrs)
58
+
59
+ assert os.path.exists(shape_ann_path)
60
+ for idx, row in enumerate(open(os.path.join(shape_ann_path), 'r')):
61
+ annotations = row.split()
62
+ assert self._image_fnames_target[idx] == annotations[0]
63
+ self.shape_attrs.append([int(i) for i in annotations[1:]])
64
+
65
+ def _open_file(self, path_prefix, fname):
66
+ return open(os.path.join(path_prefix, fname), 'rb')
67
+
68
+ def _load_densepose(self, raw_idx):
69
+ fname = self._image_fnames[raw_idx]
70
+ fname = f'{fname[:-4]}_densepose.png'
71
+ with self._open_file(self._densepose_path, fname) as f:
72
+ densepose = Image.open(f)
73
+ if self.downsample_factor != 1:
74
+ width, height = densepose.size
75
+ width = width // self.downsample_factor
76
+ height = height // self.downsample_factor
77
+ densepose = densepose.resize(
78
+ size=(width, height), resample=Image.NEAREST)
79
+ # channel-wise IUV order, [3, H, W]
80
+ densepose = np.array(densepose)[:, :, 2:].transpose(2, 0, 1)
81
+ return densepose.astype(np.float32)
82
+
83
+ def __getitem__(self, index):
84
+ pose = self._load_densepose(index)
85
+ shape_attr = self.shape_attrs[index]
86
+ shape_attr = torch.LongTensor(shape_attr)
87
+
88
+ if self.xflip and random.random() > 0.5:
89
+ pose = pose[:, :, ::-1].copy()
90
+
91
+ upper_fused_attr = self.upper_fused_attrs[index]
92
+ lower_fused_attr = self.lower_fused_attrs[index]
93
+ outer_fused_attr = self.outer_fused_attrs[index]
94
+
95
+ pose = pose / 12. - 1
96
+
97
+ return_dict = {
98
+ 'densepose': pose,
99
+ 'img_name': self._image_fnames_target[index],
100
+ 'shape_attr': shape_attr,
101
+ 'upper_fused_attr': upper_fused_attr,
102
+ 'lower_fused_attr': lower_fused_attr,
103
+ 'outer_fused_attr': outer_fused_attr,
104
+ }
105
+
106
+ return return_dict
107
+
108
+ def __len__(self):
109
+ return len(self._image_fnames)
Text2Human/data/segm_attr_dataset.py ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import os.path
3
+ import random
4
+
5
+ import numpy as np
6
+ import torch
7
+ import torch.utils.data as data
8
+ from PIL import Image
9
+
10
+
11
+ class DeepFashionAttrSegmDataset(data.Dataset):
12
+
13
+ def __init__(self,
14
+ img_dir,
15
+ segm_dir,
16
+ pose_dir,
17
+ ann_dir,
18
+ downsample_factor=2,
19
+ xflip=False):
20
+ self._img_path = img_dir
21
+ self._densepose_path = pose_dir
22
+ self._segm_path = segm_dir
23
+ self._image_fnames = []
24
+ self.upper_fused_attrs = []
25
+ self.lower_fused_attrs = []
26
+ self.outer_fused_attrs = []
27
+
28
+ self.downsample_factor = downsample_factor
29
+ self.xflip = xflip
30
+
31
+ # load attributes
32
+ assert os.path.exists(f'{ann_dir}/upper_fused.txt')
33
+ for idx, row in enumerate(
34
+ open(os.path.join(f'{ann_dir}/upper_fused.txt'), 'r')):
35
+ annotations = row.split()
36
+ self._image_fnames.append(annotations[0])
37
+ # assert self._image_fnames[idx] == annotations[0]
38
+ self.upper_fused_attrs.append(int(annotations[1]))
39
+
40
+ assert len(self._image_fnames) == len(self.upper_fused_attrs)
41
+
42
+ assert os.path.exists(f'{ann_dir}/lower_fused.txt')
43
+ for idx, row in enumerate(
44
+ open(os.path.join(f'{ann_dir}/lower_fused.txt'), 'r')):
45
+ annotations = row.split()
46
+ assert self._image_fnames[idx] == annotations[0]
47
+ self.lower_fused_attrs.append(int(annotations[1]))
48
+
49
+ assert len(self._image_fnames) == len(self.lower_fused_attrs)
50
+
51
+ assert os.path.exists(f'{ann_dir}/outer_fused.txt')
52
+ for idx, row in enumerate(
53
+ open(os.path.join(f'{ann_dir}/outer_fused.txt'), 'r')):
54
+ annotations = row.split()
55
+ assert self._image_fnames[idx] == annotations[0]
56
+ self.outer_fused_attrs.append(int(annotations[1]))
57
+
58
+ assert len(self._image_fnames) == len(self.outer_fused_attrs)
59
+
60
+ # remove the overlapping item between upper cls and lower cls
61
+ # cls 21 can appear with upper clothes
62
+ # cls 4 can appear with lower clothes
63
+ self.upper_cls = [1., 4.]
64
+ self.lower_cls = [3., 5., 21.]
65
+ self.outer_cls = [2.]
66
+ self.other_cls = [
67
+ 11., 18., 7., 8., 9., 10., 12., 16., 17., 19., 20., 22., 23., 15.,
68
+ 14., 13., 0., 6.
69
+ ]
70
+
71
+ def _open_file(self, path_prefix, fname):
72
+ return open(os.path.join(path_prefix, fname), 'rb')
73
+
74
+ def _load_raw_image(self, raw_idx):
75
+ fname = self._image_fnames[raw_idx]
76
+ with self._open_file(self._img_path, fname) as f:
77
+ image = Image.open(f)
78
+ if self.downsample_factor != 1:
79
+ width, height = image.size
80
+ width = width // self.downsample_factor
81
+ height = height // self.downsample_factor
82
+ image = image.resize(
83
+ size=(width, height), resample=Image.LANCZOS)
84
+ image = np.array(image)
85
+ if image.ndim == 2:
86
+ image = image[:, :, np.newaxis] # HW => HWC
87
+ image = image.transpose(2, 0, 1) # HWC => CHW
88
+ return image
89
+
90
+ def _load_densepose(self, raw_idx):
91
+ fname = self._image_fnames[raw_idx]
92
+ fname = f'{fname[:-4]}_densepose.png'
93
+ with self._open_file(self._densepose_path, fname) as f:
94
+ densepose = Image.open(f)
95
+ if self.downsample_factor != 1:
96
+ width, height = densepose.size
97
+ width = width // self.downsample_factor
98
+ height = height // self.downsample_factor
99
+ densepose = densepose.resize(
100
+ size=(width, height), resample=Image.NEAREST)
101
+ # channel-wise IUV order, [3, H, W]
102
+ densepose = np.array(densepose)[:, :, 2:].transpose(2, 0, 1)
103
+ return densepose.astype(np.float32)
104
+
105
+ def _load_segm(self, raw_idx):
106
+ fname = self._image_fnames[raw_idx]
107
+ fname = f'{fname[:-4]}_segm.png'
108
+ with self._open_file(self._segm_path, fname) as f:
109
+ segm = Image.open(f)
110
+ if self.downsample_factor != 1:
111
+ width, height = segm.size
112
+ width = width // self.downsample_factor
113
+ height = height // self.downsample_factor
114
+ segm = segm.resize(
115
+ size=(width, height), resample=Image.NEAREST)
116
+ segm = np.array(segm)
117
+ segm = segm[:, :, np.newaxis].transpose(2, 0, 1)
118
+ return segm.astype(np.float32)
119
+
120
+ def __getitem__(self, index):
121
+ image = self._load_raw_image(index)
122
+ pose = self._load_densepose(index)
123
+ segm = self._load_segm(index)
124
+
125
+ if self.xflip and random.random() > 0.5:
126
+ assert image.ndim == 3 # CHW
127
+ image = image[:, :, ::-1].copy()
128
+ pose = pose[:, :, ::-1].copy()
129
+ segm = segm[:, :, ::-1].copy()
130
+
131
+ image = torch.from_numpy(image)
132
+ segm = torch.from_numpy(segm)
133
+
134
+ upper_fused_attr = self.upper_fused_attrs[index]
135
+ lower_fused_attr = self.lower_fused_attrs[index]
136
+ outer_fused_attr = self.outer_fused_attrs[index]
137
+
138
+ # mask 0: denotes the common codebook,
139
+ # mask (attr + 1): denotes the texture-specific codebook
140
+ mask = torch.zeros_like(segm)
141
+ if upper_fused_attr != 17:
142
+ for cls in self.upper_cls:
143
+ mask[segm == cls] = upper_fused_attr + 1
144
+
145
+ if lower_fused_attr != 17:
146
+ for cls in self.lower_cls:
147
+ mask[segm == cls] = lower_fused_attr + 1
148
+
149
+ if outer_fused_attr != 17:
150
+ for cls in self.outer_cls:
151
+ mask[segm == cls] = outer_fused_attr + 1
152
+
153
+ pose = pose / 12. - 1
154
+ image = image / 127.5 - 1
155
+
156
+ return_dict = {
157
+ 'image': image,
158
+ 'densepose': pose,
159
+ 'segm': segm,
160
+ 'texture_mask': mask,
161
+ 'img_name': self._image_fnames[index]
162
+ }
163
+
164
+ return return_dict
165
+
166
+ def __len__(self):
167
+ return len(self._image_fnames)
Text2Human/environment/text2human_env.yaml ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: text2human
2
+ channels:
3
+ - pytorch
4
+ - anaconda
5
+ - defaults
6
+ dependencies:
7
+ - _libgcc_mutex=0.1=main
8
+ - astroid=2.5=py36h06a4308_1
9
+ - blas=1.0=mkl
10
+ - brotlipy=0.7.0=py36h7b6447c_1000
11
+ - ca-certificates=2021.10.26=h06a4308_2
12
+ - certifi=2021.5.30=py36h06a4308_0
13
+ - cffi=1.14.3=py36he30daa8_0
14
+ - chardet=3.0.4=py36_1003
15
+ - click=8.0.3=pyhd3eb1b0_0
16
+ - cryptography=3.1.1=py36h1ba5d50_0
17
+ - cudatoolkit=10.1.243=h6bb024c_0
18
+ - dataclasses=0.8=pyh4f3eec9_6
19
+ - dbus=1.13.18=hb2f20db_0
20
+ - expat=2.2.10=he6710b0_2
21
+ - filelock=3.4.0=pyhd3eb1b0_0
22
+ - fontconfig=2.13.0=h9420a91_0
23
+ - freetype=2.10.4=h5ab3b9f_0
24
+ - glib=2.56.2=hd408876_0
25
+ - gst-plugins-base=1.14.0=hbbd80ab_1
26
+ - gstreamer=1.14.0=hb453b48_1
27
+ - icu=58.2=he6710b0_3
28
+ - idna=2.10=py_0
29
+ - importlib-metadata=4.8.1=py36h06a4308_0
30
+ - importlib_metadata=4.8.1=hd3eb1b0_0
31
+ - intel-openmp=2020.2=254
32
+ - isort=5.7.0=pyhd3eb1b0_0
33
+ - joblib=1.0.1=pyhd3eb1b0_0
34
+ - jpeg=9b=habf39ab_1
35
+ - lazy-object-proxy=1.5.2=py36h27cfd23_0
36
+ - lcms2=2.11=h396b838_0
37
+ - ld_impl_linux-64=2.33.1=h53a641e_7
38
+ - libffi=3.3=he6710b0_2
39
+ - libgcc-ng=9.1.0=hdf63c60_0
40
+ - libpng=1.6.37=hbc83047_0
41
+ - libprotobuf=3.17.2=h4ff587b_1
42
+ - libstdcxx-ng=9.1.0=hdf63c60_0
43
+ - libtiff=4.2.0=h3942068_0
44
+ - libuuid=1.0.3=h1bed415_2
45
+ - libuv=1.40.0=h7b6447c_0
46
+ - libwebp-base=1.2.0=h27cfd23_0
47
+ - libxcb=1.14=h7b6447c_0
48
+ - libxml2=2.9.10=hb55368b_3
49
+ - lz4-c=1.9.3=h2531618_0
50
+ - mccabe=0.6.1=py36_1
51
+ - mkl=2020.2=256
52
+ - mkl-service=2.3.0=py36he8ac12f_0
53
+ - mkl_fft=1.3.0=py36h54f3939_0
54
+ - mkl_random=1.1.1=py36h0573a6f_0
55
+ - ncurses=6.2=he6710b0_1
56
+ - ninja=1.10.2=h5e70eb0_2
57
+ - numpy=1.19.2=py36h54aff64_0
58
+ - numpy-base=1.19.2=py36hfa32c7d_0
59
+ - olefile=0.46=py36_0
60
+ - openssl=1.1.1m=h7f8727e_0
61
+ - packaging=21.3=pyhd3eb1b0_0
62
+ - pcre=8.44=he6710b0_0
63
+ - pillow=8.1.2=py36he98fc37_0
64
+ - pip=21.0.1=py36h06a4308_0
65
+ - protobuf=3.17.2=py36h295c915_0
66
+ - pycparser=2.20=py_2
67
+ - pylint=2.7.2=py36h06a4308_1
68
+ - pyopenssl=19.1.0=py_1
69
+ - pyqt=5.9.2=py36h05f1152_2
70
+ - pysocks=1.7.1=py36_0
71
+ - python=3.6.13=hdb3f193_0
72
+ - pytorch=1.7.1=py3.6_cuda10.1.243_cudnn7.6.3_0
73
+ - qt=5.9.7=h5867ecd_1
74
+ - readline=8.1=h27cfd23_0
75
+ - regex=2021.8.3=py36h7f8727e_0
76
+ - requests=2.24.0=py_0
77
+ - setuptools=52.0.0=py36h06a4308_0
78
+ - sip=4.19.8=py36hf484d3e_0
79
+ - six=1.15.0=py36h06a4308_0
80
+ - sqlite=3.35.2=hdfb4753_0
81
+ - tk=8.6.10=hbc83047_0
82
+ - toml=0.10.2=pyhd3eb1b0_0
83
+ - torchvision=0.8.2=py36_cu101
84
+ - tqdm=4.62.3=pyhd3eb1b0_1
85
+ - typed-ast=1.4.2=py36h27cfd23_1
86
+ - typing-extensions=3.10.0.2=hd3eb1b0_0
87
+ - typing_extensions=3.10.0.2=pyh06a4308_0
88
+ - urllib3=1.25.11=py_0
89
+ - wheel=0.36.2=pyhd3eb1b0_0
90
+ - wrapt=1.12.1=py36h7b6447c_1
91
+ - xz=5.2.5=h7b6447c_0
92
+ - yaml=0.2.5=h7b6447c_0
93
+ - zipp=3.6.0=pyhd3eb1b0_0
94
+ - zlib=1.2.11=h7b6447c_3
95
+ - zstd=1.4.5=h9ceee32_0
96
+ - pip:
97
+ - addict==2.4.0
98
+ - cycler==0.11.0
99
+ - einops==0.4.0
100
+ - kiwisolver==1.3.1
101
+ - matplotlib==3.3.4
102
+ - mmcv-full==1.2.1
103
+ - mmsegmentation==0.9.0
104
+ - nltk==3.6.7
105
+ - opencv-python==4.5.5.62
106
+ - pyparsing==3.0.7
107
+ - python-dateutil==2.8.2
108
+ - pyyaml==6.0
109
+ - scikit-learn==0.24.2
110
+ - scipy==1.5.4
111
+ - sentencepiece==0.1.96
112
+ - terminaltables==3.1.10
113
+ - threadpoolctl==3.0.0
114
+ - yapf==0.32.0
Text2Human/models/__init__.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import glob
2
+ import importlib
3
+ import logging
4
+ import os.path as osp
5
+
6
+ # automatically scan and import model modules
7
+ # scan all the files under the 'models' folder and collect files ending with
8
+ # '_model.py'
9
+ model_folder = osp.dirname(osp.abspath(__file__))
10
+ model_filenames = [
11
+ osp.splitext(osp.basename(v))[0]
12
+ for v in glob.glob(f'{model_folder}/*_model.py')
13
+ ]
14
+ # import all the model modules
15
+ _model_modules = [
16
+ importlib.import_module(f'models.{file_name}')
17
+ for file_name in model_filenames
18
+ ]
19
+
20
+
21
+ def create_model(opt):
22
+ """Create model.
23
+
24
+ Args:
25
+ opt (dict): Configuration. It constains:
26
+ model_type (str): Model type.
27
+ """
28
+ model_type = opt['model_type']
29
+
30
+ # dynamically instantiation
31
+ for module in _model_modules:
32
+ model_cls = getattr(module, model_type, None)
33
+ if model_cls is not None:
34
+ break
35
+ if model_cls is None:
36
+ raise ValueError(f'Model {model_type} is not found.')
37
+
38
+ model = model_cls(opt)
39
+
40
+ logger = logging.getLogger('base')
41
+ logger.info(f'Model [{model.__class__.__name__}] is created.')
42
+ return model
Text2Human/models/archs/__init__.py ADDED
File without changes
Text2Human/models/archs/fcn_arch.py ADDED
@@ -0,0 +1,418 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from mmcv.cnn import ConvModule, normal_init
4
+ from mmseg.ops import resize
5
+
6
+
7
+ class BaseDecodeHead(nn.Module):
8
+ """Base class for BaseDecodeHead.
9
+
10
+ Args:
11
+ in_channels (int|Sequence[int]): Input channels.
12
+ channels (int): Channels after modules, before conv_seg.
13
+ num_classes (int): Number of classes.
14
+ dropout_ratio (float): Ratio of dropout layer. Default: 0.1.
15
+ conv_cfg (dict|None): Config of conv layers. Default: None.
16
+ norm_cfg (dict|None): Config of norm layers. Default: None.
17
+ act_cfg (dict): Config of activation layers.
18
+ Default: dict(type='ReLU')
19
+ in_index (int|Sequence[int]): Input feature index. Default: -1
20
+ input_transform (str|None): Transformation type of input features.
21
+ Options: 'resize_concat', 'multiple_select', None.
22
+ 'resize_concat': Multiple feature maps will be resize to the
23
+ same size as first one and than concat together.
24
+ Usually used in FCN head of HRNet.
25
+ 'multiple_select': Multiple feature maps will be bundle into
26
+ a list and passed into decode head.
27
+ None: Only one select feature map is allowed.
28
+ Default: None.
29
+ loss_decode (dict): Config of decode loss.
30
+ Default: dict(type='CrossEntropyLoss').
31
+ ignore_index (int | None): The label index to be ignored. When using
32
+ masked BCE loss, ignore_index should be set to None. Default: 255
33
+ sampler (dict|None): The config of segmentation map sampler.
34
+ Default: None.
35
+ align_corners (bool): align_corners argument of F.interpolate.
36
+ Default: False.
37
+ """
38
+
39
+ def __init__(self,
40
+ in_channels,
41
+ channels,
42
+ *,
43
+ num_classes,
44
+ dropout_ratio=0.1,
45
+ conv_cfg=None,
46
+ norm_cfg=dict(type='BN'),
47
+ act_cfg=dict(type='ReLU'),
48
+ in_index=-1,
49
+ input_transform=None,
50
+ ignore_index=255,
51
+ align_corners=False):
52
+ super(BaseDecodeHead, self).__init__()
53
+ self._init_inputs(in_channels, in_index, input_transform)
54
+ self.channels = channels
55
+ self.num_classes = num_classes
56
+ self.dropout_ratio = dropout_ratio
57
+ self.conv_cfg = conv_cfg
58
+ self.norm_cfg = norm_cfg
59
+ self.act_cfg = act_cfg
60
+ self.in_index = in_index
61
+
62
+ self.ignore_index = ignore_index
63
+ self.align_corners = align_corners
64
+
65
+ self.conv_seg = nn.Conv2d(channels, num_classes, kernel_size=1)
66
+ if dropout_ratio > 0:
67
+ self.dropout = nn.Dropout2d(dropout_ratio)
68
+ else:
69
+ self.dropout = None
70
+
71
+ def extra_repr(self):
72
+ """Extra repr."""
73
+ s = f'input_transform={self.input_transform}, ' \
74
+ f'ignore_index={self.ignore_index}, ' \
75
+ f'align_corners={self.align_corners}'
76
+ return s
77
+
78
+ def _init_inputs(self, in_channels, in_index, input_transform):
79
+ """Check and initialize input transforms.
80
+
81
+ The in_channels, in_index and input_transform must match.
82
+ Specifically, when input_transform is None, only single feature map
83
+ will be selected. So in_channels and in_index must be of type int.
84
+ When input_transform
85
+
86
+ Args:
87
+ in_channels (int|Sequence[int]): Input channels.
88
+ in_index (int|Sequence[int]): Input feature index.
89
+ input_transform (str|None): Transformation type of input features.
90
+ Options: 'resize_concat', 'multiple_select', None.
91
+ 'resize_concat': Multiple feature maps will be resize to the
92
+ same size as first one and than concat together.
93
+ Usually used in FCN head of HRNet.
94
+ 'multiple_select': Multiple feature maps will be bundle into
95
+ a list and passed into decode head.
96
+ None: Only one select feature map is allowed.
97
+ """
98
+
99
+ if input_transform is not None:
100
+ assert input_transform in ['resize_concat', 'multiple_select']
101
+ self.input_transform = input_transform
102
+ self.in_index = in_index
103
+ if input_transform is not None:
104
+ assert isinstance(in_channels, (list, tuple))
105
+ assert isinstance(in_index, (list, tuple))
106
+ assert len(in_channels) == len(in_index)
107
+ if input_transform == 'resize_concat':
108
+ self.in_channels = sum(in_channels)
109
+ else:
110
+ self.in_channels = in_channels
111
+ else:
112
+ assert isinstance(in_channels, int)
113
+ assert isinstance(in_index, int)
114
+ self.in_channels = in_channels
115
+
116
+ def init_weights(self):
117
+ """Initialize weights of classification layer."""
118
+ normal_init(self.conv_seg, mean=0, std=0.01)
119
+
120
+ def _transform_inputs(self, inputs):
121
+ """Transform inputs for decoder.
122
+
123
+ Args:
124
+ inputs (list[Tensor]): List of multi-level img features.
125
+
126
+ Returns:
127
+ Tensor: The transformed inputs
128
+ """
129
+
130
+ if self.input_transform == 'resize_concat':
131
+ inputs = [inputs[i] for i in self.in_index]
132
+ upsampled_inputs = [
133
+ resize(
134
+ input=x,
135
+ size=inputs[0].shape[2:],
136
+ mode='bilinear',
137
+ align_corners=self.align_corners) for x in inputs
138
+ ]
139
+ inputs = torch.cat(upsampled_inputs, dim=1)
140
+ elif self.input_transform == 'multiple_select':
141
+ inputs = [inputs[i] for i in self.in_index]
142
+ else:
143
+ inputs = inputs[self.in_index]
144
+
145
+ return inputs
146
+
147
+ def forward(self, inputs):
148
+ """Placeholder of forward function."""
149
+ pass
150
+
151
+ def cls_seg(self, feat):
152
+ """Classify each pixel."""
153
+ if self.dropout is not None:
154
+ feat = self.dropout(feat)
155
+ output = self.conv_seg(feat)
156
+ return output
157
+
158
+
159
+ class FCNHead(BaseDecodeHead):
160
+ """Fully Convolution Networks for Semantic Segmentation.
161
+
162
+ This head is implemented of `FCNNet <https://arxiv.org/abs/1411.4038>`_.
163
+
164
+ Args:
165
+ num_convs (int): Number of convs in the head. Default: 2.
166
+ kernel_size (int): The kernel size for convs in the head. Default: 3.
167
+ concat_input (bool): Whether concat the input and output of convs
168
+ before classification layer.
169
+ """
170
+
171
+ def __init__(self,
172
+ num_convs=2,
173
+ kernel_size=3,
174
+ concat_input=True,
175
+ **kwargs):
176
+ assert num_convs >= 0
177
+ self.num_convs = num_convs
178
+ self.concat_input = concat_input
179
+ self.kernel_size = kernel_size
180
+ super(FCNHead, self).__init__(**kwargs)
181
+ if num_convs == 0:
182
+ assert self.in_channels == self.channels
183
+
184
+ convs = []
185
+ convs.append(
186
+ ConvModule(
187
+ self.in_channels,
188
+ self.channels,
189
+ kernel_size=kernel_size,
190
+ padding=kernel_size // 2,
191
+ conv_cfg=self.conv_cfg,
192
+ norm_cfg=self.norm_cfg,
193
+ act_cfg=self.act_cfg))
194
+ for i in range(num_convs - 1):
195
+ convs.append(
196
+ ConvModule(
197
+ self.channels,
198
+ self.channels,
199
+ kernel_size=kernel_size,
200
+ padding=kernel_size // 2,
201
+ conv_cfg=self.conv_cfg,
202
+ norm_cfg=self.norm_cfg,
203
+ act_cfg=self.act_cfg))
204
+ if num_convs == 0:
205
+ self.convs = nn.Identity()
206
+ else:
207
+ self.convs = nn.Sequential(*convs)
208
+ if self.concat_input:
209
+ self.conv_cat = ConvModule(
210
+ self.in_channels + self.channels,
211
+ self.channels,
212
+ kernel_size=kernel_size,
213
+ padding=kernel_size // 2,
214
+ conv_cfg=self.conv_cfg,
215
+ norm_cfg=self.norm_cfg,
216
+ act_cfg=self.act_cfg)
217
+
218
+ def forward(self, inputs):
219
+ """Forward function."""
220
+ x = self._transform_inputs(inputs)
221
+ output = self.convs(x)
222
+ if self.concat_input:
223
+ output = self.conv_cat(torch.cat([x, output], dim=1))
224
+ output = self.cls_seg(output)
225
+ return output
226
+
227
+
228
+ class MultiHeadFCNHead(nn.Module):
229
+ """Fully Convolution Networks for Semantic Segmentation.
230
+
231
+ This head is implemented of `FCNNet <https://arxiv.org/abs/1411.4038>`_.
232
+
233
+ Args:
234
+ num_convs (int): Number of convs in the head. Default: 2.
235
+ kernel_size (int): The kernel size for convs in the head. Default: 3.
236
+ concat_input (bool): Whether concat the input and output of convs
237
+ before classification layer.
238
+ """
239
+
240
+ def __init__(self,
241
+ in_channels,
242
+ channels,
243
+ *,
244
+ num_classes,
245
+ dropout_ratio=0.1,
246
+ conv_cfg=None,
247
+ norm_cfg=dict(type='BN'),
248
+ act_cfg=dict(type='ReLU'),
249
+ in_index=-1,
250
+ input_transform=None,
251
+ ignore_index=255,
252
+ align_corners=False,
253
+ num_convs=2,
254
+ kernel_size=3,
255
+ concat_input=True,
256
+ num_head=18,
257
+ **kwargs):
258
+ super(MultiHeadFCNHead, self).__init__()
259
+ assert num_convs >= 0
260
+ self.num_convs = num_convs
261
+ self.concat_input = concat_input
262
+ self.kernel_size = kernel_size
263
+ self._init_inputs(in_channels, in_index, input_transform)
264
+ self.channels = channels
265
+ self.num_classes = num_classes
266
+ self.dropout_ratio = dropout_ratio
267
+ self.conv_cfg = conv_cfg
268
+ self.norm_cfg = norm_cfg
269
+ self.act_cfg = act_cfg
270
+ self.in_index = in_index
271
+ self.num_head = num_head
272
+
273
+ self.ignore_index = ignore_index
274
+ self.align_corners = align_corners
275
+
276
+ if dropout_ratio > 0:
277
+ self.dropout = nn.Dropout2d(dropout_ratio)
278
+
279
+ conv_seg_head_list = []
280
+ for _ in range(self.num_head):
281
+ conv_seg_head_list.append(
282
+ nn.Conv2d(channels, num_classes, kernel_size=1))
283
+
284
+ self.conv_seg_head_list = nn.ModuleList(conv_seg_head_list)
285
+
286
+ self.init_weights()
287
+
288
+ if num_convs == 0:
289
+ assert self.in_channels == self.channels
290
+
291
+ convs_list = []
292
+ conv_cat_list = []
293
+
294
+ for _ in range(self.num_head):
295
+ convs = []
296
+ convs.append(
297
+ ConvModule(
298
+ self.in_channels,
299
+ self.channels,
300
+ kernel_size=kernel_size,
301
+ padding=kernel_size // 2,
302
+ conv_cfg=self.conv_cfg,
303
+ norm_cfg=self.norm_cfg,
304
+ act_cfg=self.act_cfg))
305
+ for _ in range(num_convs - 1):
306
+ convs.append(
307
+ ConvModule(
308
+ self.channels,
309
+ self.channels,
310
+ kernel_size=kernel_size,
311
+ padding=kernel_size // 2,
312
+ conv_cfg=self.conv_cfg,
313
+ norm_cfg=self.norm_cfg,
314
+ act_cfg=self.act_cfg))
315
+ if num_convs == 0:
316
+ convs_list.append(nn.Identity())
317
+ else:
318
+ convs_list.append(nn.Sequential(*convs))
319
+ if self.concat_input:
320
+ conv_cat_list.append(
321
+ ConvModule(
322
+ self.in_channels + self.channels,
323
+ self.channels,
324
+ kernel_size=kernel_size,
325
+ padding=kernel_size // 2,
326
+ conv_cfg=self.conv_cfg,
327
+ norm_cfg=self.norm_cfg,
328
+ act_cfg=self.act_cfg))
329
+
330
+ self.convs_list = nn.ModuleList(convs_list)
331
+ self.conv_cat_list = nn.ModuleList(conv_cat_list)
332
+
333
+ def forward(self, inputs):
334
+ """Forward function."""
335
+ x = self._transform_inputs(inputs)
336
+
337
+ output_list = []
338
+ for head_idx in range(self.num_head):
339
+ output = self.convs_list[head_idx](x)
340
+ if self.concat_input:
341
+ output = self.conv_cat_list[head_idx](
342
+ torch.cat([x, output], dim=1))
343
+ if self.dropout is not None:
344
+ output = self.dropout(output)
345
+ output = self.conv_seg_head_list[head_idx](output)
346
+ output_list.append(output)
347
+
348
+ return output_list
349
+
350
+ def _init_inputs(self, in_channels, in_index, input_transform):
351
+ """Check and initialize input transforms.
352
+
353
+ The in_channels, in_index and input_transform must match.
354
+ Specifically, when input_transform is None, only single feature map
355
+ will be selected. So in_channels and in_index must be of type int.
356
+ When input_transform
357
+
358
+ Args:
359
+ in_channels (int|Sequence[int]): Input channels.
360
+ in_index (int|Sequence[int]): Input feature index.
361
+ input_transform (str|None): Transformation type of input features.
362
+ Options: 'resize_concat', 'multiple_select', None.
363
+ 'resize_concat': Multiple feature maps will be resize to the
364
+ same size as first one and than concat together.
365
+ Usually used in FCN head of HRNet.
366
+ 'multiple_select': Multiple feature maps will be bundle into
367
+ a list and passed into decode head.
368
+ None: Only one select feature map is allowed.
369
+ """
370
+
371
+ if input_transform is not None:
372
+ assert input_transform in ['resize_concat', 'multiple_select']
373
+ self.input_transform = input_transform
374
+ self.in_index = in_index
375
+ if input_transform is not None:
376
+ assert isinstance(in_channels, (list, tuple))
377
+ assert isinstance(in_index, (list, tuple))
378
+ assert len(in_channels) == len(in_index)
379
+ if input_transform == 'resize_concat':
380
+ self.in_channels = sum(in_channels)
381
+ else:
382
+ self.in_channels = in_channels
383
+ else:
384
+ assert isinstance(in_channels, int)
385
+ assert isinstance(in_index, int)
386
+ self.in_channels = in_channels
387
+
388
+ def init_weights(self):
389
+ """Initialize weights of classification layer."""
390
+ for conv_seg_head in self.conv_seg_head_list:
391
+ normal_init(conv_seg_head, mean=0, std=0.01)
392
+
393
+ def _transform_inputs(self, inputs):
394
+ """Transform inputs for decoder.
395
+
396
+ Args:
397
+ inputs (list[Tensor]): List of multi-level img features.
398
+
399
+ Returns:
400
+ Tensor: The transformed inputs
401
+ """
402
+
403
+ if self.input_transform == 'resize_concat':
404
+ inputs = [inputs[i] for i in self.in_index]
405
+ upsampled_inputs = [
406
+ resize(
407
+ input=x,
408
+ size=inputs[0].shape[2:],
409
+ mode='bilinear',
410
+ align_corners=self.align_corners) for x in inputs
411
+ ]
412
+ inputs = torch.cat(upsampled_inputs, dim=1)
413
+ elif self.input_transform == 'multiple_select':
414
+ inputs = [inputs[i] for i in self.in_index]
415
+ else:
416
+ inputs = inputs[self.in_index]
417
+
418
+ return inputs
Text2Human/models/archs/shape_attr_embedding_arch.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from torch import nn
4
+
5
+
6
+ class ShapeAttrEmbedding(nn.Module):
7
+
8
+ def __init__(self, dim, out_dim, cls_num_list):
9
+ super(ShapeAttrEmbedding, self).__init__()
10
+
11
+ for idx, cls_num in enumerate(cls_num_list):
12
+ setattr(
13
+ self, f'attr_{idx}',
14
+ nn.Sequential(
15
+ nn.Linear(cls_num, dim), nn.LeakyReLU(),
16
+ nn.Linear(dim, dim)))
17
+ self.cls_num_list = cls_num_list
18
+ self.attr_num = len(cls_num_list)
19
+ self.fusion = nn.Sequential(
20
+ nn.Linear(dim * self.attr_num, out_dim), nn.LeakyReLU(),
21
+ nn.Linear(out_dim, out_dim))
22
+
23
+ def forward(self, attr):
24
+ attr_embedding_list = []
25
+ for idx in range(self.attr_num):
26
+ attr_embed_fc = getattr(self, f'attr_{idx}')
27
+ attr_embedding_list.append(
28
+ attr_embed_fc(
29
+ F.one_hot(
30
+ attr[:, idx],
31
+ num_classes=self.cls_num_list[idx]).to(torch.float32)))
32
+ attr_embedding = torch.cat(attr_embedding_list, dim=1)
33
+ attr_embedding = self.fusion(attr_embedding)
34
+
35
+ return attr_embedding
Text2Human/models/archs/transformer_arch.py ADDED
@@ -0,0 +1,273 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+
9
+ class CausalSelfAttention(nn.Module):
10
+ """
11
+ A vanilla multi-head masked self-attention layer with a projection at the end.
12
+ It is possible to use torch.nn.MultiheadAttention here but I am including an
13
+ explicit implementation here to show that there is nothing too scary here.
14
+ """
15
+
16
+ def __init__(self, bert_n_emb, bert_n_head, attn_pdrop, resid_pdrop,
17
+ latent_shape, sampler):
18
+ super().__init__()
19
+ assert bert_n_emb % bert_n_head == 0
20
+ # key, query, value projections for all heads
21
+ self.key = nn.Linear(bert_n_emb, bert_n_emb)
22
+ self.query = nn.Linear(bert_n_emb, bert_n_emb)
23
+ self.value = nn.Linear(bert_n_emb, bert_n_emb)
24
+ # regularization
25
+ self.attn_drop = nn.Dropout(attn_pdrop)
26
+ self.resid_drop = nn.Dropout(resid_pdrop)
27
+ # output projection
28
+ self.proj = nn.Linear(bert_n_emb, bert_n_emb)
29
+ self.n_head = bert_n_head
30
+ self.causal = True if sampler == 'autoregressive' else False
31
+ if self.causal:
32
+ block_size = np.prod(latent_shape)
33
+ mask = torch.tril(torch.ones(block_size, block_size))
34
+ self.register_buffer("mask", mask.view(1, 1, block_size,
35
+ block_size))
36
+
37
+ def forward(self, x, layer_past=None):
38
+ B, T, C = x.size()
39
+
40
+ # calculate query, key, values for all heads in batch and move head forward to be the batch dim
41
+ k = self.key(x).view(B, T, self.n_head,
42
+ C // self.n_head).transpose(1,
43
+ 2) # (B, nh, T, hs)
44
+ q = self.query(x).view(B, T, self.n_head,
45
+ C // self.n_head).transpose(1,
46
+ 2) # (B, nh, T, hs)
47
+ v = self.value(x).view(B, T, self.n_head,
48
+ C // self.n_head).transpose(1,
49
+ 2) # (B, nh, T, hs)
50
+
51
+ present = torch.stack((k, v))
52
+ if self.causal and layer_past is not None:
53
+ past_key, past_value = layer_past
54
+ k = torch.cat((past_key, k), dim=-2)
55
+ v = torch.cat((past_value, v), dim=-2)
56
+
57
+ # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
58
+ att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
59
+
60
+ if self.causal and layer_past is None:
61
+ att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float('-inf'))
62
+
63
+ att = F.softmax(att, dim=-1)
64
+ att = self.attn_drop(att)
65
+ y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
66
+ # re-assemble all head outputs side by side
67
+ y = y.transpose(1, 2).contiguous().view(B, T, C)
68
+
69
+ # output projection
70
+ y = self.resid_drop(self.proj(y))
71
+ return y, present
72
+
73
+
74
+ class Block(nn.Module):
75
+ """ an unassuming Transformer block """
76
+
77
+ def __init__(self, bert_n_emb, resid_pdrop, bert_n_head, attn_pdrop,
78
+ latent_shape, sampler):
79
+ super().__init__()
80
+ self.ln1 = nn.LayerNorm(bert_n_emb)
81
+ self.ln2 = nn.LayerNorm(bert_n_emb)
82
+ self.attn = CausalSelfAttention(bert_n_emb, bert_n_head, attn_pdrop,
83
+ resid_pdrop, latent_shape, sampler)
84
+ self.mlp = nn.Sequential(
85
+ nn.Linear(bert_n_emb, 4 * bert_n_emb),
86
+ nn.GELU(), # nice
87
+ nn.Linear(4 * bert_n_emb, bert_n_emb),
88
+ nn.Dropout(resid_pdrop),
89
+ )
90
+
91
+ def forward(self, x, layer_past=None, return_present=False):
92
+
93
+ attn, present = self.attn(self.ln1(x), layer_past)
94
+ x = x + attn
95
+ x = x + self.mlp(self.ln2(x))
96
+
97
+ if layer_past is not None or return_present:
98
+ return x, present
99
+ return x
100
+
101
+
102
+ class Transformer(nn.Module):
103
+ """ the full GPT language model, with a context size of block_size """
104
+
105
+ def __init__(self,
106
+ codebook_size,
107
+ segm_codebook_size,
108
+ bert_n_emb,
109
+ bert_n_layers,
110
+ bert_n_head,
111
+ block_size,
112
+ latent_shape,
113
+ embd_pdrop,
114
+ resid_pdrop,
115
+ attn_pdrop,
116
+ sampler='absorbing'):
117
+ super().__init__()
118
+
119
+ self.vocab_size = codebook_size + 1
120
+ self.n_embd = bert_n_emb
121
+ self.block_size = block_size
122
+ self.n_layers = bert_n_layers
123
+ self.codebook_size = codebook_size
124
+ self.segm_codebook_size = segm_codebook_size
125
+ self.causal = sampler == 'autoregressive'
126
+ if self.causal:
127
+ self.vocab_size = codebook_size
128
+
129
+ self.tok_emb = nn.Embedding(self.vocab_size, self.n_embd)
130
+ self.pos_emb = nn.Parameter(
131
+ torch.zeros(1, self.block_size, self.n_embd))
132
+ self.segm_emb = nn.Embedding(self.segm_codebook_size, self.n_embd)
133
+ self.start_tok = nn.Parameter(torch.zeros(1, 1, self.n_embd))
134
+ self.drop = nn.Dropout(embd_pdrop)
135
+
136
+ # transformer
137
+ self.blocks = nn.Sequential(*[
138
+ Block(bert_n_emb, resid_pdrop, bert_n_head, attn_pdrop,
139
+ latent_shape, sampler) for _ in range(self.n_layers)
140
+ ])
141
+ # decoder head
142
+ self.ln_f = nn.LayerNorm(self.n_embd)
143
+ self.head = nn.Linear(self.n_embd, self.codebook_size, bias=False)
144
+
145
+ def get_block_size(self):
146
+ return self.block_size
147
+
148
+ def _init_weights(self, module):
149
+ if isinstance(module, (nn.Linear, nn.Embedding)):
150
+ module.weight.data.normal_(mean=0.0, std=0.02)
151
+ if isinstance(module, nn.Linear) and module.bias is not None:
152
+ module.bias.data.zero_()
153
+ elif isinstance(module, nn.LayerNorm):
154
+ module.bias.data.zero_()
155
+ module.weight.data.fill_(1.0)
156
+
157
+ def forward(self, idx, segm_tokens, t=None):
158
+ # each index maps to a (learnable) vector
159
+ token_embeddings = self.tok_emb(idx)
160
+
161
+ segm_embeddings = self.segm_emb(segm_tokens)
162
+
163
+ if self.causal:
164
+ token_embeddings = torch.cat((self.start_tok.repeat(
165
+ token_embeddings.size(0), 1, 1), token_embeddings),
166
+ dim=1)
167
+
168
+ t = token_embeddings.shape[1]
169
+ assert t <= self.block_size, "Cannot forward, model block size is exhausted."
170
+ # each position maps to a (learnable) vector
171
+
172
+ position_embeddings = self.pos_emb[:, :t, :]
173
+
174
+ x = token_embeddings + position_embeddings + segm_embeddings
175
+ x = self.drop(x)
176
+ for block in self.blocks:
177
+ x = block(x)
178
+ x = self.ln_f(x)
179
+ logits = self.head(x)
180
+
181
+ return logits
182
+
183
+
184
+ class TransformerMultiHead(nn.Module):
185
+ """ the full GPT language model, with a context size of block_size """
186
+
187
+ def __init__(self,
188
+ codebook_size,
189
+ segm_codebook_size,
190
+ texture_codebook_size,
191
+ bert_n_emb,
192
+ bert_n_layers,
193
+ bert_n_head,
194
+ block_size,
195
+ latent_shape,
196
+ embd_pdrop,
197
+ resid_pdrop,
198
+ attn_pdrop,
199
+ num_head,
200
+ sampler='absorbing'):
201
+ super().__init__()
202
+
203
+ self.vocab_size = codebook_size + 1
204
+ self.n_embd = bert_n_emb
205
+ self.block_size = block_size
206
+ self.n_layers = bert_n_layers
207
+ self.codebook_size = codebook_size
208
+ self.segm_codebook_size = segm_codebook_size
209
+ self.texture_codebook_size = texture_codebook_size
210
+ self.causal = sampler == 'autoregressive'
211
+ if self.causal:
212
+ self.vocab_size = codebook_size
213
+
214
+ self.tok_emb = nn.Embedding(self.vocab_size, self.n_embd)
215
+ self.pos_emb = nn.Parameter(
216
+ torch.zeros(1, self.block_size, self.n_embd))
217
+ self.segm_emb = nn.Embedding(self.segm_codebook_size, self.n_embd)
218
+ self.texture_emb = nn.Embedding(self.texture_codebook_size,
219
+ self.n_embd)
220
+ self.start_tok = nn.Parameter(torch.zeros(1, 1, self.n_embd))
221
+ self.drop = nn.Dropout(embd_pdrop)
222
+
223
+ # transformer
224
+ self.blocks = nn.Sequential(*[
225
+ Block(bert_n_emb, resid_pdrop, bert_n_head, attn_pdrop,
226
+ latent_shape, sampler) for _ in range(self.n_layers)
227
+ ])
228
+ # decoder head
229
+ self.num_head = num_head
230
+ self.head_class_num = codebook_size // self.num_head
231
+ self.ln_f = nn.LayerNorm(self.n_embd)
232
+ self.head_list = nn.ModuleList([
233
+ nn.Linear(self.n_embd, self.head_class_num, bias=False)
234
+ for _ in range(self.num_head)
235
+ ])
236
+
237
+ def get_block_size(self):
238
+ return self.block_size
239
+
240
+ def _init_weights(self, module):
241
+ if isinstance(module, (nn.Linear, nn.Embedding)):
242
+ module.weight.data.normal_(mean=0.0, std=0.02)
243
+ if isinstance(module, nn.Linear) and module.bias is not None:
244
+ module.bias.data.zero_()
245
+ elif isinstance(module, nn.LayerNorm):
246
+ module.bias.data.zero_()
247
+ module.weight.data.fill_(1.0)
248
+
249
+ def forward(self, idx, segm_tokens, texture_tokens, t=None):
250
+ # each index maps to a (learnable) vector
251
+ token_embeddings = self.tok_emb(idx)
252
+ segm_embeddings = self.segm_emb(segm_tokens)
253
+ texture_embeddings = self.texture_emb(texture_tokens)
254
+
255
+ if self.causal:
256
+ token_embeddings = torch.cat((self.start_tok.repeat(
257
+ token_embeddings.size(0), 1, 1), token_embeddings),
258
+ dim=1)
259
+
260
+ t = token_embeddings.shape[1]
261
+ assert t <= self.block_size, "Cannot forward, model block size is exhausted."
262
+ # each position maps to a (learnable) vector
263
+
264
+ position_embeddings = self.pos_emb[:, :t, :]
265
+
266
+ x = token_embeddings + position_embeddings + segm_embeddings + texture_embeddings
267
+ x = self.drop(x)
268
+ for block in self.blocks:
269
+ x = block(x)
270
+ x = self.ln_f(x)
271
+ logits_list = [self.head_list[i](x) for i in range(self.num_head)]
272
+
273
+ return logits_list
Text2Human/models/archs/unet_arch.py ADDED
@@ -0,0 +1,693 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.utils.checkpoint as cp
4
+ from mmcv.cnn import (UPSAMPLE_LAYERS, ConvModule, build_activation_layer,
5
+ build_norm_layer, build_upsample_layer, constant_init,
6
+ kaiming_init)
7
+ from mmcv.runner import load_checkpoint
8
+ from mmcv.utils.parrots_wrapper import _BatchNorm
9
+ from mmseg.utils import get_root_logger
10
+
11
+
12
+ class UpConvBlock(nn.Module):
13
+ """Upsample convolution block in decoder for UNet.
14
+
15
+ This upsample convolution block consists of one upsample module
16
+ followed by one convolution block. The upsample module expands the
17
+ high-level low-resolution feature map and the convolution block fuses
18
+ the upsampled high-level low-resolution feature map and the low-level
19
+ high-resolution feature map from encoder.
20
+
21
+ Args:
22
+ conv_block (nn.Sequential): Sequential of convolutional layers.
23
+ in_channels (int): Number of input channels of the high-level
24
+ skip_channels (int): Number of input channels of the low-level
25
+ high-resolution feature map from encoder.
26
+ out_channels (int): Number of output channels.
27
+ num_convs (int): Number of convolutional layers in the conv_block.
28
+ Default: 2.
29
+ stride (int): Stride of convolutional layer in conv_block. Default: 1.
30
+ dilation (int): Dilation rate of convolutional layer in conv_block.
31
+ Default: 1.
32
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
33
+ memory while slowing down the training speed. Default: False.
34
+ conv_cfg (dict | None): Config dict for convolution layer.
35
+ Default: None.
36
+ norm_cfg (dict | None): Config dict for normalization layer.
37
+ Default: dict(type='BN').
38
+ act_cfg (dict | None): Config dict for activation layer in ConvModule.
39
+ Default: dict(type='ReLU').
40
+ upsample_cfg (dict): The upsample config of the upsample module in
41
+ decoder. Default: dict(type='InterpConv'). If the size of
42
+ high-level feature map is the same as that of skip feature map
43
+ (low-level feature map from encoder), it does not need upsample the
44
+ high-level feature map and the upsample_cfg is None.
45
+ dcn (bool): Use deformable convoluton in convolutional layer or not.
46
+ Default: None.
47
+ plugins (dict): plugins for convolutional layers. Default: None.
48
+ """
49
+
50
+ def __init__(self,
51
+ conv_block,
52
+ in_channels,
53
+ skip_channels,
54
+ out_channels,
55
+ num_convs=2,
56
+ stride=1,
57
+ dilation=1,
58
+ with_cp=False,
59
+ conv_cfg=None,
60
+ norm_cfg=dict(type='BN'),
61
+ act_cfg=dict(type='ReLU'),
62
+ upsample_cfg=dict(type='InterpConv'),
63
+ dcn=None,
64
+ plugins=None):
65
+ super(UpConvBlock, self).__init__()
66
+ assert dcn is None, 'Not implemented yet.'
67
+ assert plugins is None, 'Not implemented yet.'
68
+
69
+ self.conv_block = conv_block(
70
+ in_channels=2 * skip_channels,
71
+ out_channels=out_channels,
72
+ num_convs=num_convs,
73
+ stride=stride,
74
+ dilation=dilation,
75
+ with_cp=with_cp,
76
+ conv_cfg=conv_cfg,
77
+ norm_cfg=norm_cfg,
78
+ act_cfg=act_cfg,
79
+ dcn=None,
80
+ plugins=None)
81
+ if upsample_cfg is not None:
82
+ self.upsample = build_upsample_layer(
83
+ cfg=upsample_cfg,
84
+ in_channels=in_channels,
85
+ out_channels=skip_channels,
86
+ with_cp=with_cp,
87
+ norm_cfg=norm_cfg,
88
+ act_cfg=act_cfg)
89
+ else:
90
+ self.upsample = ConvModule(
91
+ in_channels,
92
+ skip_channels,
93
+ kernel_size=1,
94
+ stride=1,
95
+ padding=0,
96
+ conv_cfg=conv_cfg,
97
+ norm_cfg=norm_cfg,
98
+ act_cfg=act_cfg)
99
+
100
+ def forward(self, skip, x):
101
+ """Forward function."""
102
+
103
+ x = self.upsample(x)
104
+ out = torch.cat([skip, x], dim=1)
105
+ out = self.conv_block(out)
106
+
107
+ return out
108
+
109
+
110
+ class BasicConvBlock(nn.Module):
111
+ """Basic convolutional block for UNet.
112
+
113
+ This module consists of several plain convolutional layers.
114
+
115
+ Args:
116
+ in_channels (int): Number of input channels.
117
+ out_channels (int): Number of output channels.
118
+ num_convs (int): Number of convolutional layers. Default: 2.
119
+ stride (int): Whether use stride convolution to downsample
120
+ the input feature map. If stride=2, it only uses stride convolution
121
+ in the first convolutional layer to downsample the input feature
122
+ map. Options are 1 or 2. Default: 1.
123
+ dilation (int): Whether use dilated convolution to expand the
124
+ receptive field. Set dilation rate of each convolutional layer and
125
+ the dilation rate of the first convolutional layer is always 1.
126
+ Default: 1.
127
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
128
+ memory while slowing down the training speed. Default: False.
129
+ conv_cfg (dict | None): Config dict for convolution layer.
130
+ Default: None.
131
+ norm_cfg (dict | None): Config dict for normalization layer.
132
+ Default: dict(type='BN').
133
+ act_cfg (dict | None): Config dict for activation layer in ConvModule.
134
+ Default: dict(type='ReLU').
135
+ dcn (bool): Use deformable convoluton in convolutional layer or not.
136
+ Default: None.
137
+ plugins (dict): plugins for convolutional layers. Default: None.
138
+ """
139
+
140
+ def __init__(self,
141
+ in_channels,
142
+ out_channels,
143
+ num_convs=2,
144
+ stride=1,
145
+ dilation=1,
146
+ with_cp=False,
147
+ conv_cfg=None,
148
+ norm_cfg=dict(type='BN'),
149
+ act_cfg=dict(type='ReLU'),
150
+ dcn=None,
151
+ plugins=None):
152
+ super(BasicConvBlock, self).__init__()
153
+ assert dcn is None, 'Not implemented yet.'
154
+ assert plugins is None, 'Not implemented yet.'
155
+
156
+ self.with_cp = with_cp
157
+ convs = []
158
+ for i in range(num_convs):
159
+ convs.append(
160
+ ConvModule(
161
+ in_channels=in_channels if i == 0 else out_channels,
162
+ out_channels=out_channels,
163
+ kernel_size=3,
164
+ stride=stride if i == 0 else 1,
165
+ dilation=1 if i == 0 else dilation,
166
+ padding=1 if i == 0 else dilation,
167
+ conv_cfg=conv_cfg,
168
+ norm_cfg=norm_cfg,
169
+ act_cfg=act_cfg))
170
+
171
+ self.convs = nn.Sequential(*convs)
172
+
173
+ def forward(self, x):
174
+ """Forward function."""
175
+
176
+ if self.with_cp and x.requires_grad:
177
+ out = cp.checkpoint(self.convs, x)
178
+ else:
179
+ out = self.convs(x)
180
+ return out
181
+
182
+
183
+ class DeconvModule(nn.Module):
184
+ """Deconvolution upsample module in decoder for UNet (2X upsample).
185
+
186
+ This module uses deconvolution to upsample feature map in the decoder
187
+ of UNet.
188
+
189
+ Args:
190
+ in_channels (int): Number of input channels.
191
+ out_channels (int): Number of output channels.
192
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
193
+ memory while slowing down the training speed. Default: False.
194
+ norm_cfg (dict | None): Config dict for normalization layer.
195
+ Default: dict(type='BN').
196
+ act_cfg (dict | None): Config dict for activation layer in ConvModule.
197
+ Default: dict(type='ReLU').
198
+ kernel_size (int): Kernel size of the convolutional layer. Default: 4.
199
+ """
200
+
201
+ def __init__(self,
202
+ in_channels,
203
+ out_channels,
204
+ with_cp=False,
205
+ norm_cfg=dict(type='BN'),
206
+ act_cfg=dict(type='ReLU'),
207
+ *,
208
+ kernel_size=4,
209
+ scale_factor=2):
210
+ super(DeconvModule, self).__init__()
211
+
212
+ assert (kernel_size - scale_factor >= 0) and\
213
+ (kernel_size - scale_factor) % 2 == 0,\
214
+ f'kernel_size should be greater than or equal to scale_factor '\
215
+ f'and (kernel_size - scale_factor) should be even numbers, '\
216
+ f'while the kernel size is {kernel_size} and scale_factor is '\
217
+ f'{scale_factor}.'
218
+
219
+ stride = scale_factor
220
+ padding = (kernel_size - scale_factor) // 2
221
+ self.with_cp = with_cp
222
+ deconv = nn.ConvTranspose2d(
223
+ in_channels,
224
+ out_channels,
225
+ kernel_size=kernel_size,
226
+ stride=stride,
227
+ padding=padding)
228
+
229
+ norm_name, norm = build_norm_layer(norm_cfg, out_channels)
230
+ activate = build_activation_layer(act_cfg)
231
+ self.deconv_upsamping = nn.Sequential(deconv, norm, activate)
232
+
233
+ def forward(self, x):
234
+ """Forward function."""
235
+
236
+ if self.with_cp and x.requires_grad:
237
+ out = cp.checkpoint(self.deconv_upsamping, x)
238
+ else:
239
+ out = self.deconv_upsamping(x)
240
+ return out
241
+
242
+
243
+ @UPSAMPLE_LAYERS.register_module()
244
+ class InterpConv(nn.Module):
245
+ """Interpolation upsample module in decoder for UNet.
246
+
247
+ This module uses interpolation to upsample feature map in the decoder
248
+ of UNet. It consists of one interpolation upsample layer and one
249
+ convolutional layer. It can be one interpolation upsample layer followed
250
+ by one convolutional layer (conv_first=False) or one convolutional layer
251
+ followed by one interpolation upsample layer (conv_first=True).
252
+
253
+ Args:
254
+ in_channels (int): Number of input channels.
255
+ out_channels (int): Number of output channels.
256
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
257
+ memory while slowing down the training speed. Default: False.
258
+ norm_cfg (dict | None): Config dict for normalization layer.
259
+ Default: dict(type='BN').
260
+ act_cfg (dict | None): Config dict for activation layer in ConvModule.
261
+ Default: dict(type='ReLU').
262
+ conv_cfg (dict | None): Config dict for convolution layer.
263
+ Default: None.
264
+ conv_first (bool): Whether convolutional layer or interpolation
265
+ upsample layer first. Default: False. It means interpolation
266
+ upsample layer followed by one convolutional layer.
267
+ kernel_size (int): Kernel size of the convolutional layer. Default: 1.
268
+ stride (int): Stride of the convolutional layer. Default: 1.
269
+ padding (int): Padding of the convolutional layer. Default: 1.
270
+ upsampe_cfg (dict): Interpolation config of the upsample layer.
271
+ Default: dict(
272
+ scale_factor=2, mode='bilinear', align_corners=False).
273
+ """
274
+
275
+ def __init__(self,
276
+ in_channels,
277
+ out_channels,
278
+ with_cp=False,
279
+ norm_cfg=dict(type='BN'),
280
+ act_cfg=dict(type='ReLU'),
281
+ *,
282
+ conv_cfg=None,
283
+ conv_first=False,
284
+ kernel_size=1,
285
+ stride=1,
286
+ padding=0,
287
+ upsampe_cfg=dict(
288
+ scale_factor=2, mode='bilinear', align_corners=False)):
289
+ super(InterpConv, self).__init__()
290
+
291
+ self.with_cp = with_cp
292
+ conv = ConvModule(
293
+ in_channels,
294
+ out_channels,
295
+ kernel_size=kernel_size,
296
+ stride=stride,
297
+ padding=padding,
298
+ conv_cfg=conv_cfg,
299
+ norm_cfg=norm_cfg,
300
+ act_cfg=act_cfg)
301
+ upsample = nn.Upsample(**upsampe_cfg)
302
+ if conv_first:
303
+ self.interp_upsample = nn.Sequential(conv, upsample)
304
+ else:
305
+ self.interp_upsample = nn.Sequential(upsample, conv)
306
+
307
+ def forward(self, x):
308
+ """Forward function."""
309
+
310
+ if self.with_cp and x.requires_grad:
311
+ out = cp.checkpoint(self.interp_upsample, x)
312
+ else:
313
+ out = self.interp_upsample(x)
314
+ return out
315
+
316
+
317
+ class UNet(nn.Module):
318
+ """UNet backbone.
319
+ U-Net: Convolutional Networks for Biomedical Image Segmentation.
320
+ https://arxiv.org/pdf/1505.04597.pdf
321
+
322
+ Args:
323
+ in_channels (int): Number of input image channels. Default" 3.
324
+ base_channels (int): Number of base channels of each stage.
325
+ The output channels of the first stage. Default: 64.
326
+ num_stages (int): Number of stages in encoder, normally 5. Default: 5.
327
+ strides (Sequence[int 1 | 2]): Strides of each stage in encoder.
328
+ len(strides) is equal to num_stages. Normally the stride of the
329
+ first stage in encoder is 1. If strides[i]=2, it uses stride
330
+ convolution to downsample in the correspondence encoder stage.
331
+ Default: (1, 1, 1, 1, 1).
332
+ enc_num_convs (Sequence[int]): Number of convolutional layers in the
333
+ convolution block of the correspondence encoder stage.
334
+ Default: (2, 2, 2, 2, 2).
335
+ dec_num_convs (Sequence[int]): Number of convolutional layers in the
336
+ convolution block of the correspondence decoder stage.
337
+ Default: (2, 2, 2, 2).
338
+ downsamples (Sequence[int]): Whether use MaxPool to downsample the
339
+ feature map after the first stage of encoder
340
+ (stages: [1, num_stages)). If the correspondence encoder stage use
341
+ stride convolution (strides[i]=2), it will never use MaxPool to
342
+ downsample, even downsamples[i-1]=True.
343
+ Default: (True, True, True, True).
344
+ enc_dilations (Sequence[int]): Dilation rate of each stage in encoder.
345
+ Default: (1, 1, 1, 1, 1).
346
+ dec_dilations (Sequence[int]): Dilation rate of each stage in decoder.
347
+ Default: (1, 1, 1, 1).
348
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
349
+ memory while slowing down the training speed. Default: False.
350
+ conv_cfg (dict | None): Config dict for convolution layer.
351
+ Default: None.
352
+ norm_cfg (dict | None): Config dict for normalization layer.
353
+ Default: dict(type='BN').
354
+ act_cfg (dict | None): Config dict for activation layer in ConvModule.
355
+ Default: dict(type='ReLU').
356
+ upsample_cfg (dict): The upsample config of the upsample module in
357
+ decoder. Default: dict(type='InterpConv').
358
+ norm_eval (bool): Whether to set norm layers to eval mode, namely,
359
+ freeze running stats (mean and var). Note: Effect on Batch Norm
360
+ and its variants only. Default: False.
361
+ dcn (bool): Use deformable convolution in convolutional layer or not.
362
+ Default: None.
363
+ plugins (dict): plugins for convolutional layers. Default: None.
364
+
365
+ Notice:
366
+ The input image size should be devisible by the whole downsample rate
367
+ of the encoder. More detail of the whole downsample rate can be found
368
+ in UNet._check_input_devisible.
369
+
370
+ """
371
+
372
+ def __init__(self,
373
+ in_channels=3,
374
+ base_channels=64,
375
+ num_stages=5,
376
+ strides=(1, 1, 1, 1, 1),
377
+ enc_num_convs=(2, 2, 2, 2, 2),
378
+ dec_num_convs=(2, 2, 2, 2),
379
+ downsamples=(True, True, True, True),
380
+ enc_dilations=(1, 1, 1, 1, 1),
381
+ dec_dilations=(1, 1, 1, 1),
382
+ with_cp=False,
383
+ conv_cfg=None,
384
+ norm_cfg=dict(type='BN'),
385
+ act_cfg=dict(type='ReLU'),
386
+ upsample_cfg=dict(type='InterpConv'),
387
+ norm_eval=False,
388
+ dcn=None,
389
+ plugins=None):
390
+ super(UNet, self).__init__()
391
+ assert dcn is None, 'Not implemented yet.'
392
+ assert plugins is None, 'Not implemented yet.'
393
+ assert len(strides) == num_stages, \
394
+ 'The length of strides should be equal to num_stages, '\
395
+ f'while the strides is {strides}, the length of '\
396
+ f'strides is {len(strides)}, and the num_stages is '\
397
+ f'{num_stages}.'
398
+ assert len(enc_num_convs) == num_stages, \
399
+ 'The length of enc_num_convs should be equal to num_stages, '\
400
+ f'while the enc_num_convs is {enc_num_convs}, the length of '\
401
+ f'enc_num_convs is {len(enc_num_convs)}, and the num_stages is '\
402
+ f'{num_stages}.'
403
+ assert len(dec_num_convs) == (num_stages-1), \
404
+ 'The length of dec_num_convs should be equal to (num_stages-1), '\
405
+ f'while the dec_num_convs is {dec_num_convs}, the length of '\
406
+ f'dec_num_convs is {len(dec_num_convs)}, and the num_stages is '\
407
+ f'{num_stages}.'
408
+ assert len(downsamples) == (num_stages-1), \
409
+ 'The length of downsamples should be equal to (num_stages-1), '\
410
+ f'while the downsamples is {downsamples}, the length of '\
411
+ f'downsamples is {len(downsamples)}, and the num_stages is '\
412
+ f'{num_stages}.'
413
+ assert len(enc_dilations) == num_stages, \
414
+ 'The length of enc_dilations should be equal to num_stages, '\
415
+ f'while the enc_dilations is {enc_dilations}, the length of '\
416
+ f'enc_dilations is {len(enc_dilations)}, and the num_stages is '\
417
+ f'{num_stages}.'
418
+ assert len(dec_dilations) == (num_stages-1), \
419
+ 'The length of dec_dilations should be equal to (num_stages-1), '\
420
+ f'while the dec_dilations is {dec_dilations}, the length of '\
421
+ f'dec_dilations is {len(dec_dilations)}, and the num_stages is '\
422
+ f'{num_stages}.'
423
+ self.num_stages = num_stages
424
+ self.strides = strides
425
+ self.downsamples = downsamples
426
+ self.norm_eval = norm_eval
427
+
428
+ self.encoder = nn.ModuleList()
429
+ self.decoder = nn.ModuleList()
430
+
431
+ for i in range(num_stages):
432
+ enc_conv_block = []
433
+ if i != 0:
434
+ if strides[i] == 1 and downsamples[i - 1]:
435
+ enc_conv_block.append(nn.MaxPool2d(kernel_size=2))
436
+ upsample = (strides[i] != 1 or downsamples[i - 1])
437
+ self.decoder.append(
438
+ UpConvBlock(
439
+ conv_block=BasicConvBlock,
440
+ in_channels=base_channels * 2**i,
441
+ skip_channels=base_channels * 2**(i - 1),
442
+ out_channels=base_channels * 2**(i - 1),
443
+ num_convs=dec_num_convs[i - 1],
444
+ stride=1,
445
+ dilation=dec_dilations[i - 1],
446
+ with_cp=with_cp,
447
+ conv_cfg=conv_cfg,
448
+ norm_cfg=norm_cfg,
449
+ act_cfg=act_cfg,
450
+ upsample_cfg=upsample_cfg if upsample else None,
451
+ dcn=None,
452
+ plugins=None))
453
+
454
+ enc_conv_block.append(
455
+ BasicConvBlock(
456
+ in_channels=in_channels,
457
+ out_channels=base_channels * 2**i,
458
+ num_convs=enc_num_convs[i],
459
+ stride=strides[i],
460
+ dilation=enc_dilations[i],
461
+ with_cp=with_cp,
462
+ conv_cfg=conv_cfg,
463
+ norm_cfg=norm_cfg,
464
+ act_cfg=act_cfg,
465
+ dcn=None,
466
+ plugins=None))
467
+ self.encoder.append((nn.Sequential(*enc_conv_block)))
468
+ in_channels = base_channels * 2**i
469
+
470
+ def forward(self, x):
471
+ enc_outs = []
472
+
473
+ for enc in self.encoder:
474
+ x = enc(x)
475
+ enc_outs.append(x)
476
+ dec_outs = [x]
477
+ for i in reversed(range(len(self.decoder))):
478
+ x = self.decoder[i](enc_outs[i], x)
479
+ dec_outs.append(x)
480
+
481
+ return dec_outs
482
+
483
+ def init_weights(self, pretrained=None):
484
+ """Initialize the weights in backbone.
485
+
486
+ Args:
487
+ pretrained (str, optional): Path to pre-trained weights.
488
+ Defaults to None.
489
+ """
490
+ if isinstance(pretrained, str):
491
+ logger = get_root_logger()
492
+ load_checkpoint(self, pretrained, strict=False, logger=logger)
493
+ elif pretrained is None:
494
+ for m in self.modules():
495
+ if isinstance(m, nn.Conv2d):
496
+ kaiming_init(m)
497
+ elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
498
+ constant_init(m, 1)
499
+ else:
500
+ raise TypeError('pretrained must be a str or None')
501
+
502
+
503
+ class ShapeUNet(nn.Module):
504
+ """ShapeUNet backbone with small modifications.
505
+ U-Net: Convolutional Networks for Biomedical Image Segmentation.
506
+ https://arxiv.org/pdf/1505.04597.pdf
507
+
508
+ Args:
509
+ in_channels (int): Number of input image channels. Default" 3.
510
+ base_channels (int): Number of base channels of each stage.
511
+ The output channels of the first stage. Default: 64.
512
+ num_stages (int): Number of stages in encoder, normally 5. Default: 5.
513
+ strides (Sequence[int 1 | 2]): Strides of each stage in encoder.
514
+ len(strides) is equal to num_stages. Normally the stride of the
515
+ first stage in encoder is 1. If strides[i]=2, it uses stride
516
+ convolution to downsample in the correspondance encoder stage.
517
+ Default: (1, 1, 1, 1, 1).
518
+ enc_num_convs (Sequence[int]): Number of convolutional layers in the
519
+ convolution block of the correspondance encoder stage.
520
+ Default: (2, 2, 2, 2, 2).
521
+ dec_num_convs (Sequence[int]): Number of convolutional layers in the
522
+ convolution block of the correspondance decoder stage.
523
+ Default: (2, 2, 2, 2).
524
+ downsamples (Sequence[int]): Whether use MaxPool to downsample the
525
+ feature map after the first stage of encoder
526
+ (stages: [1, num_stages)). If the correspondance encoder stage use
527
+ stride convolution (strides[i]=2), it will never use MaxPool to
528
+ downsample, even downsamples[i-1]=True.
529
+ Default: (True, True, True, True).
530
+ enc_dilations (Sequence[int]): Dilation rate of each stage in encoder.
531
+ Default: (1, 1, 1, 1, 1).
532
+ dec_dilations (Sequence[int]): Dilation rate of each stage in decoder.
533
+ Default: (1, 1, 1, 1).
534
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
535
+ memory while slowing down the training speed. Default: False.
536
+ conv_cfg (dict | None): Config dict for convolution layer.
537
+ Default: None.
538
+ norm_cfg (dict | None): Config dict for normalization layer.
539
+ Default: dict(type='BN').
540
+ act_cfg (dict | None): Config dict for activation layer in ConvModule.
541
+ Default: dict(type='ReLU').
542
+ upsample_cfg (dict): The upsample config of the upsample module in
543
+ decoder. Default: dict(type='InterpConv').
544
+ norm_eval (bool): Whether to set norm layers to eval mode, namely,
545
+ freeze running stats (mean and var). Note: Effect on Batch Norm
546
+ and its variants only. Default: False.
547
+ dcn (bool): Use deformable convoluton in convolutional layer or not.
548
+ Default: None.
549
+ plugins (dict): plugins for convolutional layers. Default: None.
550
+
551
+ Notice:
552
+ The input image size should be devisible by the whole downsample rate
553
+ of the encoder. More detail of the whole downsample rate can be found
554
+ in UNet._check_input_devisible.
555
+
556
+ """
557
+
558
+ def __init__(self,
559
+ in_channels=3,
560
+ base_channels=64,
561
+ num_stages=5,
562
+ attr_embedding=128,
563
+ strides=(1, 1, 1, 1, 1),
564
+ enc_num_convs=(2, 2, 2, 2, 2),
565
+ dec_num_convs=(2, 2, 2, 2),
566
+ downsamples=(True, True, True, True),
567
+ enc_dilations=(1, 1, 1, 1, 1),
568
+ dec_dilations=(1, 1, 1, 1),
569
+ with_cp=False,
570
+ conv_cfg=None,
571
+ norm_cfg=dict(type='BN'),
572
+ act_cfg=dict(type='ReLU'),
573
+ upsample_cfg=dict(type='InterpConv'),
574
+ norm_eval=False,
575
+ dcn=None,
576
+ plugins=None):
577
+ super(ShapeUNet, self).__init__()
578
+ assert dcn is None, 'Not implemented yet.'
579
+ assert plugins is None, 'Not implemented yet.'
580
+ assert len(strides) == num_stages, \
581
+ 'The length of strides should be equal to num_stages, '\
582
+ f'while the strides is {strides}, the length of '\
583
+ f'strides is {len(strides)}, and the num_stages is '\
584
+ f'{num_stages}.'
585
+ assert len(enc_num_convs) == num_stages, \
586
+ 'The length of enc_num_convs should be equal to num_stages, '\
587
+ f'while the enc_num_convs is {enc_num_convs}, the length of '\
588
+ f'enc_num_convs is {len(enc_num_convs)}, and the num_stages is '\
589
+ f'{num_stages}.'
590
+ assert len(dec_num_convs) == (num_stages-1), \
591
+ 'The length of dec_num_convs should be equal to (num_stages-1), '\
592
+ f'while the dec_num_convs is {dec_num_convs}, the length of '\
593
+ f'dec_num_convs is {len(dec_num_convs)}, and the num_stages is '\
594
+ f'{num_stages}.'
595
+ assert len(downsamples) == (num_stages-1), \
596
+ 'The length of downsamples should be equal to (num_stages-1), '\
597
+ f'while the downsamples is {downsamples}, the length of '\
598
+ f'downsamples is {len(downsamples)}, and the num_stages is '\
599
+ f'{num_stages}.'
600
+ assert len(enc_dilations) == num_stages, \
601
+ 'The length of enc_dilations should be equal to num_stages, '\
602
+ f'while the enc_dilations is {enc_dilations}, the length of '\
603
+ f'enc_dilations is {len(enc_dilations)}, and the num_stages is '\
604
+ f'{num_stages}.'
605
+ assert len(dec_dilations) == (num_stages-1), \
606
+ 'The length of dec_dilations should be equal to (num_stages-1), '\
607
+ f'while the dec_dilations is {dec_dilations}, the length of '\
608
+ f'dec_dilations is {len(dec_dilations)}, and the num_stages is '\
609
+ f'{num_stages}.'
610
+ self.num_stages = num_stages
611
+ self.strides = strides
612
+ self.downsamples = downsamples
613
+ self.norm_eval = norm_eval
614
+
615
+ self.encoder = nn.ModuleList()
616
+ self.decoder = nn.ModuleList()
617
+
618
+ for i in range(num_stages):
619
+ enc_conv_block = []
620
+ if i != 0:
621
+ if strides[i] == 1 and downsamples[i - 1]:
622
+ enc_conv_block.append(nn.MaxPool2d(kernel_size=2))
623
+ upsample = (strides[i] != 1 or downsamples[i - 1])
624
+ self.decoder.append(
625
+ UpConvBlock(
626
+ conv_block=BasicConvBlock,
627
+ in_channels=base_channels * 2**i,
628
+ skip_channels=base_channels * 2**(i - 1),
629
+ out_channels=base_channels * 2**(i - 1),
630
+ num_convs=dec_num_convs[i - 1],
631
+ stride=1,
632
+ dilation=dec_dilations[i - 1],
633
+ with_cp=with_cp,
634
+ conv_cfg=conv_cfg,
635
+ norm_cfg=norm_cfg,
636
+ act_cfg=act_cfg,
637
+ upsample_cfg=upsample_cfg if upsample else None,
638
+ dcn=None,
639
+ plugins=None))
640
+
641
+ enc_conv_block.append(
642
+ BasicConvBlock(
643
+ in_channels=in_channels + attr_embedding,
644
+ out_channels=base_channels * 2**i,
645
+ num_convs=enc_num_convs[i],
646
+ stride=strides[i],
647
+ dilation=enc_dilations[i],
648
+ with_cp=with_cp,
649
+ conv_cfg=conv_cfg,
650
+ norm_cfg=norm_cfg,
651
+ act_cfg=act_cfg,
652
+ dcn=None,
653
+ plugins=None))
654
+ self.encoder.append((nn.Sequential(*enc_conv_block)))
655
+ in_channels = base_channels * 2**i
656
+
657
+ def forward(self, x, attr_embedding):
658
+ enc_outs = []
659
+ Be, Ce = attr_embedding.size()
660
+ for enc in self.encoder:
661
+ _, _, H, W = x.size()
662
+ x = enc(
663
+ torch.cat([
664
+ x,
665
+ attr_embedding.view(Be, Ce, 1, 1).expand((Be, Ce, H, W))
666
+ ],
667
+ dim=1))
668
+ enc_outs.append(x)
669
+ dec_outs = [x]
670
+ for i in reversed(range(len(self.decoder))):
671
+ x = self.decoder[i](enc_outs[i], x)
672
+ dec_outs.append(x)
673
+
674
+ return dec_outs
675
+
676
+ def init_weights(self, pretrained=None):
677
+ """Initialize the weights in backbone.
678
+
679
+ Args:
680
+ pretrained (str, optional): Path to pre-trained weights.
681
+ Defaults to None.
682
+ """
683
+ if isinstance(pretrained, str):
684
+ logger = get_root_logger()
685
+ load_checkpoint(self, pretrained, strict=False, logger=logger)
686
+ elif pretrained is None:
687
+ for m in self.modules():
688
+ if isinstance(m, nn.Conv2d):
689
+ kaiming_init(m)
690
+ elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
691
+ constant_init(m, 1)
692
+ else:
693
+ raise TypeError('pretrained must be a str or None')
Text2Human/models/archs/vqgan_arch.py ADDED
@@ -0,0 +1,1203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # pytorch_diffusion + derived encoder decoder
2
+ import math
3
+ from urllib.request import proxy_bypass
4
+
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from einops import rearrange
10
+
11
+
12
+ class VectorQuantizer(nn.Module):
13
+ """
14
+ Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly
15
+ avoids costly matrix multiplications and allows for post-hoc remapping of indices.
16
+ """
17
+
18
+ # NOTE: due to a bug the beta term was applied to the wrong term. for
19
+ # backwards compatibility we use the buggy version by default, but you can
20
+ # specify legacy=False to fix it.
21
+ def __init__(self,
22
+ n_e,
23
+ e_dim,
24
+ beta,
25
+ remap=None,
26
+ unknown_index="random",
27
+ sane_index_shape=False,
28
+ legacy=True):
29
+ super().__init__()
30
+ self.n_e = n_e
31
+ self.e_dim = e_dim
32
+ self.beta = beta
33
+ self.legacy = legacy
34
+
35
+ self.embedding = nn.Embedding(self.n_e, self.e_dim)
36
+ self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
37
+
38
+ self.remap = remap
39
+ if self.remap is not None:
40
+ self.register_buffer("used", torch.tensor(np.load(self.remap)))
41
+ self.re_embed = self.used.shape[0]
42
+ self.unknown_index = unknown_index # "random" or "extra" or integer
43
+ if self.unknown_index == "extra":
44
+ self.unknown_index = self.re_embed
45
+ self.re_embed = self.re_embed + 1
46
+ print(f"Remapping {self.n_e} indices to {self.re_embed} indices. "
47
+ f"Using {self.unknown_index} for unknown indices.")
48
+ else:
49
+ self.re_embed = n_e
50
+
51
+ self.sane_index_shape = sane_index_shape
52
+
53
+ def remap_to_used(self, inds):
54
+ ishape = inds.shape
55
+ assert len(ishape) > 1
56
+ inds = inds.reshape(ishape[0], -1)
57
+ used = self.used.to(inds)
58
+ match = (inds[:, :, None] == used[None, None, ...]).long()
59
+ new = match.argmax(-1)
60
+ unknown = match.sum(2) < 1
61
+ if self.unknown_index == "random":
62
+ new[unknown] = torch.randint(
63
+ 0, self.re_embed,
64
+ size=new[unknown].shape).to(device=new.device)
65
+ else:
66
+ new[unknown] = self.unknown_index
67
+ return new.reshape(ishape)
68
+
69
+ def unmap_to_all(self, inds):
70
+ ishape = inds.shape
71
+ assert len(ishape) > 1
72
+ inds = inds.reshape(ishape[0], -1)
73
+ used = self.used.to(inds)
74
+ if self.re_embed > self.used.shape[0]: # extra token
75
+ inds[inds >= self.used.shape[0]] = 0 # simply set to zero
76
+ back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
77
+ return back.reshape(ishape)
78
+
79
+ def forward(self, z, temp=None, rescale_logits=False, return_logits=False):
80
+ assert temp is None or temp == 1.0, "Only for interface compatible with Gumbel"
81
+ assert rescale_logits == False, "Only for interface compatible with Gumbel"
82
+ assert return_logits == False, "Only for interface compatible with Gumbel"
83
+ # reshape z -> (batch, height, width, channel) and flatten
84
+ z = rearrange(z, 'b c h w -> b h w c').contiguous()
85
+ z_flattened = z.view(-1, self.e_dim)
86
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
87
+
88
+ d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
89
+ torch.sum(self.embedding.weight**2, dim=1) - 2 * \
90
+ torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n'))
91
+
92
+ min_encoding_indices = torch.argmin(d, dim=1)
93
+ z_q = self.embedding(min_encoding_indices).view(z.shape)
94
+ perplexity = None
95
+ min_encodings = None
96
+
97
+ # compute loss for embedding
98
+ if not self.legacy:
99
+ loss = self.beta * torch.mean((z_q.detach()-z)**2) + \
100
+ torch.mean((z_q - z.detach()) ** 2)
101
+ else:
102
+ loss = torch.mean((z_q.detach()-z)**2) + self.beta * \
103
+ torch.mean((z_q - z.detach()) ** 2)
104
+
105
+ # preserve gradients
106
+ z_q = z + (z_q - z).detach()
107
+
108
+ # reshape back to match original input shape
109
+ z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous()
110
+
111
+ if self.remap is not None:
112
+ min_encoding_indices = min_encoding_indices.reshape(
113
+ z.shape[0], -1) # add batch axis
114
+ min_encoding_indices = self.remap_to_used(min_encoding_indices)
115
+ min_encoding_indices = min_encoding_indices.reshape(-1,
116
+ 1) # flatten
117
+
118
+ if self.sane_index_shape:
119
+ min_encoding_indices = min_encoding_indices.reshape(
120
+ z_q.shape[0], z_q.shape[2], z_q.shape[3])
121
+
122
+ return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
123
+
124
+ def get_codebook_entry(self, indices, shape):
125
+ # shape specifying (batch, height, width, channel)
126
+ if self.remap is not None:
127
+ indices = indices.reshape(shape[0], -1) # add batch axis
128
+ indices = self.unmap_to_all(indices)
129
+ indices = indices.reshape(-1) # flatten again
130
+
131
+ # get quantized latent vectors
132
+ z_q = self.embedding(indices)
133
+
134
+ if shape is not None:
135
+ z_q = z_q.view(shape)
136
+ # reshape back to match original input shape
137
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
138
+
139
+ return z_q
140
+
141
+
142
+ class VectorQuantizerTexture(nn.Module):
143
+ """
144
+ Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly
145
+ avoids costly matrix multiplications and allows for post-hoc remapping of indices.
146
+ """
147
+
148
+ # NOTE: due to a bug the beta term was applied to the wrong term. for
149
+ # backwards compatibility we use the buggy version by default, but you can
150
+ # specify legacy=False to fix it.
151
+ def __init__(self,
152
+ n_e,
153
+ e_dim,
154
+ beta,
155
+ remap=None,
156
+ unknown_index="random",
157
+ sane_index_shape=False,
158
+ legacy=True):
159
+ super().__init__()
160
+ self.n_e = n_e
161
+ self.e_dim = e_dim
162
+ self.beta = beta
163
+ self.legacy = legacy
164
+
165
+ # TODO: decide number of embeddings
166
+ self.embedding_list = nn.ModuleList(
167
+ [nn.Embedding(self.n_e, self.e_dim) for i in range(18)])
168
+ for embedding in self.embedding_list:
169
+ embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
170
+
171
+ self.remap = remap
172
+ if self.remap is not None:
173
+ self.register_buffer("used", torch.tensor(np.load(self.remap)))
174
+ self.re_embed = self.used.shape[0]
175
+ self.unknown_index = unknown_index # "random" or "extra" or integer
176
+ if self.unknown_index == "extra":
177
+ self.unknown_index = self.re_embed
178
+ self.re_embed = self.re_embed + 1
179
+ print(f"Remapping {self.n_e} indices to {self.re_embed} indices. "
180
+ f"Using {self.unknown_index} for unknown indices.")
181
+ else:
182
+ self.re_embed = n_e
183
+
184
+ self.sane_index_shape = sane_index_shape
185
+
186
+ def remap_to_used(self, inds):
187
+ ishape = inds.shape
188
+ assert len(ishape) > 1
189
+ inds = inds.reshape(ishape[0], -1)
190
+ used = self.used.to(inds)
191
+ match = (inds[:, :, None] == used[None, None, ...]).long()
192
+ new = match.argmax(-1)
193
+ unknown = match.sum(2) < 1
194
+ if self.unknown_index == "random":
195
+ new[unknown] = torch.randint(
196
+ 0, self.re_embed,
197
+ size=new[unknown].shape).to(device=new.device)
198
+ else:
199
+ new[unknown] = self.unknown_index
200
+ return new.reshape(ishape)
201
+
202
+ def unmap_to_all(self, inds):
203
+ ishape = inds.shape
204
+ assert len(ishape) > 1
205
+ inds = inds.reshape(ishape[0], -1)
206
+ used = self.used.to(inds)
207
+ if self.re_embed > self.used.shape[0]: # extra token
208
+ inds[inds >= self.used.shape[0]] = 0 # simply set to zero
209
+ back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
210
+ return back.reshape(ishape)
211
+
212
+ def forward(self,
213
+ z,
214
+ segm_map,
215
+ temp=None,
216
+ rescale_logits=False,
217
+ return_logits=False):
218
+ assert temp is None or temp == 1.0, "Only for interface compatible with Gumbel"
219
+ assert rescale_logits == False, "Only for interface compatible with Gumbel"
220
+ assert return_logits == False, "Only for interface compatible with Gumbel"
221
+
222
+ segm_map = F.interpolate(segm_map, size=z.size()[2:], mode='nearest')
223
+ # reshape z -> (batch, height, width, channel) and flatten
224
+ z = rearrange(z, 'b c h w -> b h w c').contiguous()
225
+ z_flattened = z.view(-1, self.e_dim)
226
+
227
+ # flatten segm_map (b, h, w)
228
+ segm_map_flatten = segm_map.view(-1)
229
+
230
+ z_q = torch.zeros_like(z_flattened)
231
+ min_encoding_indices_list = []
232
+ min_encoding_indices_continual = torch.full(
233
+ segm_map_flatten.size(),
234
+ fill_value=-1,
235
+ dtype=torch.long,
236
+ device=segm_map_flatten.device)
237
+ for codebook_idx in range(18):
238
+ min_encoding_indices = torch.full(
239
+ segm_map_flatten.size(),
240
+ fill_value=-1,
241
+ dtype=torch.long,
242
+ device=segm_map_flatten.device)
243
+ if torch.sum(segm_map_flatten == codebook_idx) > 0:
244
+ z_selected = z_flattened[segm_map_flatten == codebook_idx]
245
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
246
+ d_selected = torch.sum(
247
+ z_selected**2, dim=1, keepdim=True) + torch.sum(
248
+ self.embedding_list[codebook_idx].weight**2,
249
+ dim=1) - 2 * torch.einsum(
250
+ 'bd,dn->bn', z_selected,
251
+ rearrange(self.embedding_list[codebook_idx].weight,
252
+ 'n d -> d n'))
253
+ min_encoding_indices_selected = torch.argmin(d_selected, dim=1)
254
+ z_q_selected = self.embedding_list[codebook_idx](
255
+ min_encoding_indices_selected)
256
+ z_q[segm_map_flatten == codebook_idx] = z_q_selected
257
+ min_encoding_indices[
258
+ segm_map_flatten ==
259
+ codebook_idx] = min_encoding_indices_selected
260
+ min_encoding_indices_continual[
261
+ segm_map_flatten ==
262
+ codebook_idx] = min_encoding_indices_selected + 1024 * codebook_idx
263
+ min_encoding_indices = min_encoding_indices.reshape(
264
+ z.shape[0], z.shape[1], z.shape[2])
265
+ min_encoding_indices_list.append(min_encoding_indices)
266
+
267
+ min_encoding_indices_continual = min_encoding_indices_continual.reshape(
268
+ z.shape[0], z.shape[1], z.shape[2])
269
+ z_q = z_q.view(z.shape)
270
+ perplexity = None
271
+
272
+ # compute loss for embedding
273
+ if not self.legacy:
274
+ loss = self.beta * torch.mean((z_q.detach()-z)**2) + \
275
+ torch.mean((z_q - z.detach()) ** 2)
276
+ else:
277
+ loss = torch.mean((z_q.detach()-z)**2) + self.beta * \
278
+ torch.mean((z_q - z.detach()) ** 2)
279
+
280
+ # preserve gradients
281
+ z_q = z + (z_q - z).detach()
282
+
283
+ # reshape back to match original input shape
284
+ z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous()
285
+
286
+ return z_q, loss, (perplexity, min_encoding_indices_continual,
287
+ min_encoding_indices_list)
288
+
289
+ def get_codebook_entry(self, indices_list, segm_map, shape):
290
+ # flatten segm_map (b, h, w)
291
+ segm_map = F.interpolate(
292
+ segm_map, size=(shape[1], shape[2]), mode='nearest')
293
+ segm_map_flatten = segm_map.view(-1)
294
+
295
+ z_q = torch.zeros((shape[0] * shape[1] * shape[2]),
296
+ self.e_dim).to(segm_map.device)
297
+ for codebook_idx in range(18):
298
+ if torch.sum(segm_map_flatten == codebook_idx) > 0:
299
+ min_encoding_indices_selected = indices_list[
300
+ codebook_idx].view(-1)[segm_map_flatten == codebook_idx]
301
+ z_q_selected = self.embedding_list[codebook_idx](
302
+ min_encoding_indices_selected)
303
+ z_q[segm_map_flatten == codebook_idx] = z_q_selected
304
+
305
+ z_q = z_q.view(shape)
306
+ # reshape back to match original input shape
307
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
308
+
309
+ return z_q
310
+
311
+
312
+ def sample_patches(inputs, patch_size=3, stride=1):
313
+ """Extract sliding local patches from an input feature tensor.
314
+ The sampled pathes are row-major.
315
+ Args:
316
+ inputs (Tensor): the input feature maps, shape: (n, c, h, w).
317
+ patch_size (int): the spatial size of sampled patches. Default: 3.
318
+ stride (int): the stride of sampling. Default: 1.
319
+ Returns:
320
+ patches (Tensor): extracted patches, shape: (n, c * patch_size *
321
+ patch_size, n_patches).
322
+ """
323
+
324
+ patches = F.unfold(inputs, (patch_size, patch_size), stride=stride)
325
+
326
+ return patches
327
+
328
+
329
+ class VectorQuantizerSpatialTextureAware(nn.Module):
330
+ """
331
+ Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly
332
+ avoids costly matrix multiplications and allows for post-hoc remapping of indices.
333
+ """
334
+
335
+ # NOTE: due to a bug the beta term was applied to the wrong term. for
336
+ # backwards compatibility we use the buggy version by default, but you can
337
+ # specify legacy=False to fix it.
338
+ def __init__(self,
339
+ n_e,
340
+ e_dim,
341
+ beta,
342
+ spatial_size,
343
+ remap=None,
344
+ unknown_index="random",
345
+ sane_index_shape=False,
346
+ legacy=True):
347
+ super().__init__()
348
+ self.n_e = n_e
349
+ self.e_dim = e_dim * spatial_size * spatial_size
350
+ self.beta = beta
351
+ self.legacy = legacy
352
+ self.spatial_size = spatial_size
353
+
354
+ # TODO: decide number of embeddings
355
+ self.embedding_list = nn.ModuleList(
356
+ [nn.Embedding(self.n_e, self.e_dim) for i in range(18)])
357
+ for embedding in self.embedding_list:
358
+ embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
359
+
360
+ self.remap = remap
361
+ if self.remap is not None:
362
+ self.register_buffer("used", torch.tensor(np.load(self.remap)))
363
+ self.re_embed = self.used.shape[0]
364
+ self.unknown_index = unknown_index # "random" or "extra" or integer
365
+ if self.unknown_index == "extra":
366
+ self.unknown_index = self.re_embed
367
+ self.re_embed = self.re_embed + 1
368
+ print(f"Remapping {self.n_e} indices to {self.re_embed} indices. "
369
+ f"Using {self.unknown_index} for unknown indices.")
370
+ else:
371
+ self.re_embed = n_e
372
+
373
+ self.sane_index_shape = sane_index_shape
374
+
375
+ def forward(self,
376
+ z,
377
+ segm_map,
378
+ temp=None,
379
+ rescale_logits=False,
380
+ return_logits=False):
381
+ assert temp is None or temp == 1.0, "Only for interface compatible with Gumbel"
382
+ assert rescale_logits == False, "Only for interface compatible with Gumbel"
383
+ assert return_logits == False, "Only for interface compatible with Gumbel"
384
+
385
+ segm_map = F.interpolate(
386
+ segm_map,
387
+ size=(z.size(2) // self.spatial_size,
388
+ z.size(3) // self.spatial_size),
389
+ mode='nearest')
390
+
391
+ # reshape z -> (batch, height, width, channel) and flatten
392
+ # z = rearrange(z, 'b c h w -> b h w c').contiguous() ?
393
+ z_patches = sample_patches(
394
+ z, patch_size=self.spatial_size,
395
+ stride=self.spatial_size).permute(0, 2, 1)
396
+ z_patches_flattened = z_patches.reshape(-1, self.e_dim)
397
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
398
+
399
+ # flatten segm_map (b, h, w)
400
+ segm_map_flatten = segm_map.view(-1)
401
+
402
+ z_q = torch.zeros_like(z_patches_flattened)
403
+ min_encoding_indices_list = []
404
+ min_encoding_indices_continual = torch.full(
405
+ segm_map_flatten.size(),
406
+ fill_value=-1,
407
+ dtype=torch.long,
408
+ device=segm_map_flatten.device)
409
+
410
+ for codebook_idx in range(18):
411
+ min_encoding_indices = torch.full(
412
+ segm_map_flatten.size(),
413
+ fill_value=-1,
414
+ dtype=torch.long,
415
+ device=segm_map_flatten.device)
416
+ if torch.sum(segm_map_flatten == codebook_idx) > 0:
417
+ z_selected = z_patches_flattened[segm_map_flatten ==
418
+ codebook_idx]
419
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
420
+ d_selected = torch.sum(
421
+ z_selected**2, dim=1, keepdim=True) + torch.sum(
422
+ self.embedding_list[codebook_idx].weight**2,
423
+ dim=1) - 2 * torch.einsum(
424
+ 'bd,dn->bn', z_selected,
425
+ rearrange(self.embedding_list[codebook_idx].weight,
426
+ 'n d -> d n'))
427
+ min_encoding_indices_selected = torch.argmin(d_selected, dim=1)
428
+ z_q_selected = self.embedding_list[codebook_idx](
429
+ min_encoding_indices_selected)
430
+ z_q[segm_map_flatten == codebook_idx] = z_q_selected
431
+ min_encoding_indices[
432
+ segm_map_flatten ==
433
+ codebook_idx] = min_encoding_indices_selected
434
+ min_encoding_indices_continual[
435
+ segm_map_flatten ==
436
+ codebook_idx] = min_encoding_indices_selected + self.n_e * codebook_idx
437
+ min_encoding_indices = min_encoding_indices.reshape(
438
+ z_patches.shape[0], segm_map.shape[2], segm_map.shape[3])
439
+ min_encoding_indices_list.append(min_encoding_indices)
440
+
441
+ z_q = F.fold(
442
+ z_q.view(z_patches.shape).permute(0, 2, 1),
443
+ z.size()[2:],
444
+ kernel_size=(self.spatial_size, self.spatial_size),
445
+ stride=self.spatial_size)
446
+
447
+ perplexity = None
448
+
449
+ # compute loss for embedding
450
+ if not self.legacy:
451
+ loss = self.beta * torch.mean((z_q.detach()-z)**2) + \
452
+ torch.mean((z_q - z.detach()) ** 2)
453
+ else:
454
+ loss = torch.mean((z_q.detach()-z)**2) + self.beta * \
455
+ torch.mean((z_q - z.detach()) ** 2)
456
+
457
+ # preserve gradients
458
+ z_q = z + (z_q - z).detach()
459
+
460
+ return z_q, loss, (perplexity, min_encoding_indices_continual,
461
+ min_encoding_indices_list)
462
+
463
+ def get_codebook_entry(self, indices_list, segm_map, shape):
464
+ # flatten segm_map (b, h, w)
465
+ segm_map = F.interpolate(
466
+ segm_map, size=(shape[1], shape[2]), mode='nearest')
467
+ segm_map_flatten = segm_map.view(-1)
468
+
469
+ z_q = torch.zeros((shape[0] * shape[1] * shape[2]),
470
+ self.e_dim).to(segm_map.device)
471
+ for codebook_idx in range(18):
472
+ if torch.sum(segm_map_flatten == codebook_idx) > 0:
473
+ min_encoding_indices_selected = indices_list[
474
+ codebook_idx].view(-1)[segm_map_flatten == codebook_idx]
475
+ z_q_selected = self.embedding_list[codebook_idx](
476
+ min_encoding_indices_selected)
477
+ z_q[segm_map_flatten == codebook_idx] = z_q_selected
478
+
479
+ z_q = F.fold(
480
+ z_q.view(((shape[0], shape[1] * shape[2],
481
+ self.e_dim))).permute(0, 2, 1),
482
+ (shape[1] * self.spatial_size, shape[2] * self.spatial_size),
483
+ kernel_size=(self.spatial_size, self.spatial_size),
484
+ stride=self.spatial_size)
485
+
486
+ return z_q
487
+
488
+
489
+ def get_timestep_embedding(timesteps, embedding_dim):
490
+ """
491
+ This matches the implementation in Denoising Diffusion Probabilistic Models:
492
+ From Fairseq.
493
+ Build sinusoidal embeddings.
494
+ This matches the implementation in tensor2tensor, but differs slightly
495
+ from the description in Section 3.5 of "Attention Is All You Need".
496
+ """
497
+ assert len(timesteps.shape) == 1
498
+
499
+ half_dim = embedding_dim // 2
500
+ emb = math.log(10000) / (half_dim - 1)
501
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
502
+ emb = emb.to(device=timesteps.device)
503
+ emb = timesteps.float()[:, None] * emb[None, :]
504
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
505
+ if embedding_dim % 2 == 1: # zero pad
506
+ emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
507
+ return emb
508
+
509
+
510
+ def nonlinearity(x):
511
+ # swish
512
+ return x * torch.sigmoid(x)
513
+
514
+
515
+ def Normalize(in_channels):
516
+ return torch.nn.GroupNorm(
517
+ num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
518
+
519
+
520
+ class Upsample(nn.Module):
521
+
522
+ def __init__(self, in_channels, with_conv):
523
+ super().__init__()
524
+ self.with_conv = with_conv
525
+ if self.with_conv:
526
+ self.conv = torch.nn.Conv2d(
527
+ in_channels, in_channels, kernel_size=3, stride=1, padding=1)
528
+
529
+ def forward(self, x):
530
+ x = torch.nn.functional.interpolate(
531
+ x, scale_factor=2.0, mode="nearest")
532
+ if self.with_conv:
533
+ x = self.conv(x)
534
+ return x
535
+
536
+
537
+ class Downsample(nn.Module):
538
+
539
+ def __init__(self, in_channels, with_conv):
540
+ super().__init__()
541
+ self.with_conv = with_conv
542
+ if self.with_conv:
543
+ # no asymmetric padding in torch conv, must do it ourselves
544
+ self.conv = torch.nn.Conv2d(
545
+ in_channels, in_channels, kernel_size=3, stride=2, padding=0)
546
+
547
+ def forward(self, x):
548
+ if self.with_conv:
549
+ pad = (0, 1, 0, 1)
550
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
551
+ x = self.conv(x)
552
+ else:
553
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
554
+ return x
555
+
556
+
557
+ class ResnetBlock(nn.Module):
558
+
559
+ def __init__(self,
560
+ *,
561
+ in_channels,
562
+ out_channels=None,
563
+ conv_shortcut=False,
564
+ dropout,
565
+ temb_channels=512):
566
+ super().__init__()
567
+ self.in_channels = in_channels
568
+ out_channels = in_channels if out_channels is None else out_channels
569
+ self.out_channels = out_channels
570
+ self.use_conv_shortcut = conv_shortcut
571
+
572
+ self.norm1 = Normalize(in_channels)
573
+ self.conv1 = torch.nn.Conv2d(
574
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1)
575
+ if temb_channels > 0:
576
+ self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
577
+ self.norm2 = Normalize(out_channels)
578
+ self.dropout = torch.nn.Dropout(dropout)
579
+ self.conv2 = torch.nn.Conv2d(
580
+ out_channels, out_channels, kernel_size=3, stride=1, padding=1)
581
+ if self.in_channels != self.out_channels:
582
+ if self.use_conv_shortcut:
583
+ self.conv_shortcut = torch.nn.Conv2d(
584
+ in_channels,
585
+ out_channels,
586
+ kernel_size=3,
587
+ stride=1,
588
+ padding=1)
589
+ else:
590
+ self.nin_shortcut = torch.nn.Conv2d(
591
+ in_channels,
592
+ out_channels,
593
+ kernel_size=1,
594
+ stride=1,
595
+ padding=0)
596
+
597
+ def forward(self, x, temb):
598
+ h = x
599
+ h = self.norm1(h)
600
+ h = nonlinearity(h)
601
+ h = self.conv1(h)
602
+
603
+ if temb is not None:
604
+ h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
605
+
606
+ h = self.norm2(h)
607
+ h = nonlinearity(h)
608
+ h = self.dropout(h)
609
+ h = self.conv2(h)
610
+
611
+ if self.in_channels != self.out_channels:
612
+ if self.use_conv_shortcut:
613
+ x = self.conv_shortcut(x)
614
+ else:
615
+ x = self.nin_shortcut(x)
616
+
617
+ return x + h
618
+
619
+
620
+ class AttnBlock(nn.Module):
621
+
622
+ def __init__(self, in_channels):
623
+ super().__init__()
624
+ self.in_channels = in_channels
625
+
626
+ self.norm = Normalize(in_channels)
627
+ self.q = torch.nn.Conv2d(
628
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0)
629
+ self.k = torch.nn.Conv2d(
630
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0)
631
+ self.v = torch.nn.Conv2d(
632
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0)
633
+ self.proj_out = torch.nn.Conv2d(
634
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0)
635
+
636
+ def forward(self, x):
637
+ h_ = x
638
+ h_ = self.norm(h_)
639
+ q = self.q(h_)
640
+ k = self.k(h_)
641
+ v = self.v(h_)
642
+
643
+ # compute attention
644
+ b, c, h, w = q.shape
645
+ q = q.reshape(b, c, h * w)
646
+ q = q.permute(0, 2, 1) # b,hw,c
647
+ k = k.reshape(b, c, h * w) # b,c,hw
648
+ w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
649
+ w_ = w_ * (int(c)**(-0.5))
650
+ w_ = torch.nn.functional.softmax(w_, dim=2)
651
+
652
+ # attend to values
653
+ v = v.reshape(b, c, h * w)
654
+ w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
655
+ h_ = torch.bmm(
656
+ v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
657
+ h_ = h_.reshape(b, c, h, w)
658
+
659
+ h_ = self.proj_out(h_)
660
+
661
+ return x + h_
662
+
663
+
664
+ class Model(nn.Module):
665
+
666
+ def __init__(self,
667
+ *,
668
+ ch,
669
+ out_ch,
670
+ ch_mult=(1, 2, 4, 8),
671
+ num_res_blocks,
672
+ attn_resolutions,
673
+ dropout=0.0,
674
+ resamp_with_conv=True,
675
+ in_channels,
676
+ resolution,
677
+ use_timestep=True):
678
+ super().__init__()
679
+ self.ch = ch
680
+ self.temb_ch = self.ch * 4
681
+ self.num_resolutions = len(ch_mult)
682
+ self.num_res_blocks = num_res_blocks
683
+ self.resolution = resolution
684
+ self.in_channels = in_channels
685
+
686
+ self.use_timestep = use_timestep
687
+ if self.use_timestep:
688
+ # timestep embedding
689
+ self.temb = nn.Module()
690
+ self.temb.dense = nn.ModuleList([
691
+ torch.nn.Linear(self.ch, self.temb_ch),
692
+ torch.nn.Linear(self.temb_ch, self.temb_ch),
693
+ ])
694
+
695
+ # downsampling
696
+ self.conv_in = torch.nn.Conv2d(
697
+ in_channels, self.ch, kernel_size=3, stride=1, padding=1)
698
+
699
+ curr_res = resolution
700
+ in_ch_mult = (1, ) + tuple(ch_mult)
701
+ self.down = nn.ModuleList()
702
+ for i_level in range(self.num_resolutions):
703
+ block = nn.ModuleList()
704
+ attn = nn.ModuleList()
705
+ block_in = ch * in_ch_mult[i_level]
706
+ block_out = ch * ch_mult[i_level]
707
+ for i_block in range(self.num_res_blocks):
708
+ block.append(
709
+ ResnetBlock(
710
+ in_channels=block_in,
711
+ out_channels=block_out,
712
+ temb_channels=self.temb_ch,
713
+ dropout=dropout))
714
+ block_in = block_out
715
+ if curr_res in attn_resolutions:
716
+ attn.append(AttnBlock(block_in))
717
+ down = nn.Module()
718
+ down.block = block
719
+ down.attn = attn
720
+ if i_level != self.num_resolutions - 1:
721
+ down.downsample = Downsample(block_in, resamp_with_conv)
722
+ curr_res = curr_res // 2
723
+ self.down.append(down)
724
+
725
+ # middle
726
+ self.mid = nn.Module()
727
+ self.mid.block_1 = ResnetBlock(
728
+ in_channels=block_in,
729
+ out_channels=block_in,
730
+ temb_channels=self.temb_ch,
731
+ dropout=dropout)
732
+ self.mid.attn_1 = AttnBlock(block_in)
733
+ self.mid.block_2 = ResnetBlock(
734
+ in_channels=block_in,
735
+ out_channels=block_in,
736
+ temb_channels=self.temb_ch,
737
+ dropout=dropout)
738
+
739
+ # upsampling
740
+ self.up = nn.ModuleList()
741
+ for i_level in reversed(range(self.num_resolutions)):
742
+ block = nn.ModuleList()
743
+ attn = nn.ModuleList()
744
+ block_out = ch * ch_mult[i_level]
745
+ skip_in = ch * ch_mult[i_level]
746
+ for i_block in range(self.num_res_blocks + 1):
747
+ if i_block == self.num_res_blocks:
748
+ skip_in = ch * in_ch_mult[i_level]
749
+ block.append(
750
+ ResnetBlock(
751
+ in_channels=block_in + skip_in,
752
+ out_channels=block_out,
753
+ temb_channels=self.temb_ch,
754
+ dropout=dropout))
755
+ block_in = block_out
756
+ if curr_res in attn_resolutions:
757
+ attn.append(AttnBlock(block_in))
758
+ up = nn.Module()
759
+ up.block = block
760
+ up.attn = attn
761
+ if i_level != 0:
762
+ up.upsample = Upsample(block_in, resamp_with_conv)
763
+ curr_res = curr_res * 2
764
+ self.up.insert(0, up) # prepend to get consistent order
765
+
766
+ # end
767
+ self.norm_out = Normalize(block_in)
768
+ self.conv_out = torch.nn.Conv2d(
769
+ block_in, out_ch, kernel_size=3, stride=1, padding=1)
770
+
771
+ def forward(self, x, t=None):
772
+ #assert x.shape[2] == x.shape[3] == self.resolution
773
+
774
+ if self.use_timestep:
775
+ # timestep embedding
776
+ assert t is not None
777
+ temb = get_timestep_embedding(t, self.ch)
778
+ temb = self.temb.dense[0](temb)
779
+ temb = nonlinearity(temb)
780
+ temb = self.temb.dense[1](temb)
781
+ else:
782
+ temb = None
783
+
784
+ # downsampling
785
+ hs = [self.conv_in(x)]
786
+ for i_level in range(self.num_resolutions):
787
+ for i_block in range(self.num_res_blocks):
788
+ h = self.down[i_level].block[i_block](hs[-1], temb)
789
+ if len(self.down[i_level].attn) > 0:
790
+ h = self.down[i_level].attn[i_block](h)
791
+ hs.append(h)
792
+ if i_level != self.num_resolutions - 1:
793
+ hs.append(self.down[i_level].downsample(hs[-1]))
794
+
795
+ # middle
796
+ h = hs[-1]
797
+ h = self.mid.block_1(h, temb)
798
+ h = self.mid.attn_1(h)
799
+ h = self.mid.block_2(h, temb)
800
+
801
+ # upsampling
802
+ for i_level in reversed(range(self.num_resolutions)):
803
+ for i_block in range(self.num_res_blocks + 1):
804
+ h = self.up[i_level].block[i_block](torch.cat([h, hs.pop()],
805
+ dim=1), temb)
806
+ if len(self.up[i_level].attn) > 0:
807
+ h = self.up[i_level].attn[i_block](h)
808
+ if i_level != 0:
809
+ h = self.up[i_level].upsample(h)
810
+
811
+ # end
812
+ h = self.norm_out(h)
813
+ h = nonlinearity(h)
814
+ h = self.conv_out(h)
815
+ return h
816
+
817
+
818
+ class Encoder(nn.Module):
819
+
820
+ def __init__(self,
821
+ ch,
822
+ num_res_blocks,
823
+ attn_resolutions,
824
+ in_channels,
825
+ resolution,
826
+ z_channels,
827
+ ch_mult=(1, 2, 4, 8),
828
+ dropout=0.0,
829
+ resamp_with_conv=True,
830
+ double_z=True):
831
+ super().__init__()
832
+ self.ch = ch
833
+ self.temb_ch = 0
834
+ self.num_resolutions = len(ch_mult)
835
+ self.num_res_blocks = num_res_blocks
836
+ self.resolution = resolution
837
+ self.in_channels = in_channels
838
+
839
+ # downsampling
840
+ self.conv_in = torch.nn.Conv2d(
841
+ in_channels, self.ch, kernel_size=3, stride=1, padding=1)
842
+
843
+ curr_res = resolution
844
+ in_ch_mult = (1, ) + tuple(ch_mult)
845
+ self.down = nn.ModuleList()
846
+ for i_level in range(self.num_resolutions):
847
+ block = nn.ModuleList()
848
+ attn = nn.ModuleList()
849
+ block_in = ch * in_ch_mult[i_level]
850
+ block_out = ch * ch_mult[i_level]
851
+ for i_block in range(self.num_res_blocks):
852
+ block.append(
853
+ ResnetBlock(
854
+ in_channels=block_in,
855
+ out_channels=block_out,
856
+ temb_channels=self.temb_ch,
857
+ dropout=dropout))
858
+ block_in = block_out
859
+ if curr_res in attn_resolutions:
860
+ attn.append(AttnBlock(block_in))
861
+ down = nn.Module()
862
+ down.block = block
863
+ down.attn = attn
864
+ if i_level != self.num_resolutions - 1:
865
+ down.downsample = Downsample(block_in, resamp_with_conv)
866
+ curr_res = curr_res // 2
867
+ self.down.append(down)
868
+
869
+ # middle
870
+ self.mid = nn.Module()
871
+ self.mid.block_1 = ResnetBlock(
872
+ in_channels=block_in,
873
+ out_channels=block_in,
874
+ temb_channels=self.temb_ch,
875
+ dropout=dropout)
876
+ self.mid.attn_1 = AttnBlock(block_in)
877
+ self.mid.block_2 = ResnetBlock(
878
+ in_channels=block_in,
879
+ out_channels=block_in,
880
+ temb_channels=self.temb_ch,
881
+ dropout=dropout)
882
+
883
+ # end
884
+ self.norm_out = Normalize(block_in)
885
+ self.conv_out = torch.nn.Conv2d(
886
+ block_in,
887
+ 2 * z_channels if double_z else z_channels,
888
+ kernel_size=3,
889
+ stride=1,
890
+ padding=1)
891
+
892
+ def forward(self, x):
893
+ #assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution)
894
+
895
+ # timestep embedding
896
+ temb = None
897
+
898
+ # downsampling
899
+ hs = [self.conv_in(x)]
900
+ for i_level in range(self.num_resolutions):
901
+ for i_block in range(self.num_res_blocks):
902
+ h = self.down[i_level].block[i_block](hs[-1], temb)
903
+ if len(self.down[i_level].attn) > 0:
904
+ h = self.down[i_level].attn[i_block](h)
905
+ hs.append(h)
906
+ if i_level != self.num_resolutions - 1:
907
+ hs.append(self.down[i_level].downsample(hs[-1]))
908
+
909
+ # middle
910
+ h = hs[-1]
911
+ h = self.mid.block_1(h, temb)
912
+ h = self.mid.attn_1(h)
913
+ h = self.mid.block_2(h, temb)
914
+
915
+ # end
916
+ h = self.norm_out(h)
917
+ h = nonlinearity(h)
918
+ h = self.conv_out(h)
919
+ return h
920
+
921
+
922
+ class Decoder(nn.Module):
923
+
924
+ def __init__(self,
925
+ in_channels,
926
+ resolution,
927
+ z_channels,
928
+ ch,
929
+ out_ch,
930
+ num_res_blocks,
931
+ attn_resolutions,
932
+ ch_mult=(1, 2, 4, 8),
933
+ dropout=0.0,
934
+ resamp_with_conv=True,
935
+ give_pre_end=False):
936
+ super().__init__()
937
+ self.ch = ch
938
+ self.temb_ch = 0
939
+ self.num_resolutions = len(ch_mult)
940
+ self.num_res_blocks = num_res_blocks
941
+ self.resolution = resolution
942
+ self.in_channels = in_channels
943
+ self.give_pre_end = give_pre_end
944
+
945
+ # compute in_ch_mult, block_in and curr_res at lowest res
946
+ in_ch_mult = (1, ) + tuple(ch_mult)
947
+ block_in = ch * ch_mult[self.num_resolutions - 1]
948
+ curr_res = resolution // 2**(self.num_resolutions - 1)
949
+ self.z_shape = (1, z_channels, curr_res, curr_res // 2)
950
+ print("Working with z of shape {} = {} dimensions.".format(
951
+ self.z_shape, np.prod(self.z_shape)))
952
+
953
+ # z to block_in
954
+ self.conv_in = torch.nn.Conv2d(
955
+ z_channels, block_in, kernel_size=3, stride=1, padding=1)
956
+
957
+ # middle
958
+ self.mid = nn.Module()
959
+ self.mid.block_1 = ResnetBlock(
960
+ in_channels=block_in,
961
+ out_channels=block_in,
962
+ temb_channels=self.temb_ch,
963
+ dropout=dropout)
964
+ self.mid.attn_1 = AttnBlock(block_in)
965
+ self.mid.block_2 = ResnetBlock(
966
+ in_channels=block_in,
967
+ out_channels=block_in,
968
+ temb_channels=self.temb_ch,
969
+ dropout=dropout)
970
+
971
+ # upsampling
972
+ self.up = nn.ModuleList()
973
+ for i_level in reversed(range(self.num_resolutions)):
974
+ block = nn.ModuleList()
975
+ attn = nn.ModuleList()
976
+ block_out = ch * ch_mult[i_level]
977
+ for i_block in range(self.num_res_blocks + 1):
978
+ block.append(
979
+ ResnetBlock(
980
+ in_channels=block_in,
981
+ out_channels=block_out,
982
+ temb_channels=self.temb_ch,
983
+ dropout=dropout))
984
+ block_in = block_out
985
+ if curr_res in attn_resolutions:
986
+ attn.append(AttnBlock(block_in))
987
+ up = nn.Module()
988
+ up.block = block
989
+ up.attn = attn
990
+ if i_level != 0:
991
+ up.upsample = Upsample(block_in, resamp_with_conv)
992
+ curr_res = curr_res * 2
993
+ self.up.insert(0, up) # prepend to get consistent order
994
+
995
+ # end
996
+ self.norm_out = Normalize(block_in)
997
+ self.conv_out = torch.nn.Conv2d(
998
+ block_in, out_ch, kernel_size=3, stride=1, padding=1)
999
+
1000
+ def forward(self, z, bot_h=None):
1001
+ #assert z.shape[1:] == self.z_shape[1:]
1002
+ self.last_z_shape = z.shape
1003
+
1004
+ # timestep embedding
1005
+ temb = None
1006
+
1007
+ # z to block_in
1008
+ h = self.conv_in(z)
1009
+
1010
+ # middle
1011
+ h = self.mid.block_1(h, temb)
1012
+ h = self.mid.attn_1(h)
1013
+ h = self.mid.block_2(h, temb)
1014
+
1015
+ # upsampling
1016
+ for i_level in reversed(range(self.num_resolutions)):
1017
+ for i_block in range(self.num_res_blocks + 1):
1018
+ h = self.up[i_level].block[i_block](h, temb)
1019
+ if len(self.up[i_level].attn) > 0:
1020
+ h = self.up[i_level].attn[i_block](h)
1021
+ if i_level != 0:
1022
+ h = self.up[i_level].upsample(h)
1023
+ if i_level == 4 and bot_h is not None:
1024
+ h += bot_h
1025
+
1026
+ # end
1027
+ if self.give_pre_end:
1028
+ return h
1029
+
1030
+ h = self.norm_out(h)
1031
+ h = nonlinearity(h)
1032
+ h = self.conv_out(h)
1033
+ return h
1034
+
1035
+ def get_feature_top(self, z):
1036
+ #assert z.shape[1:] == self.z_shape[1:]
1037
+ self.last_z_shape = z.shape
1038
+
1039
+ # timestep embedding
1040
+ temb = None
1041
+
1042
+ # z to block_in
1043
+ h = self.conv_in(z)
1044
+
1045
+ # middle
1046
+ h = self.mid.block_1(h, temb)
1047
+ h = self.mid.attn_1(h)
1048
+ h = self.mid.block_2(h, temb)
1049
+
1050
+ # upsampling
1051
+ for i_level in reversed(range(self.num_resolutions)):
1052
+ for i_block in range(self.num_res_blocks + 1):
1053
+ h = self.up[i_level].block[i_block](h, temb)
1054
+ if len(self.up[i_level].attn) > 0:
1055
+ h = self.up[i_level].attn[i_block](h)
1056
+ if i_level != 0:
1057
+ h = self.up[i_level].upsample(h)
1058
+ if i_level == 4:
1059
+ return h
1060
+
1061
+ def get_feature_middle(self, z, mid_h):
1062
+ #assert z.shape[1:] == self.z_shape[1:]
1063
+ self.last_z_shape = z.shape
1064
+
1065
+ # timestep embedding
1066
+ temb = None
1067
+
1068
+ # z to block_in
1069
+ h = self.conv_in(z)
1070
+
1071
+ # middle
1072
+ h = self.mid.block_1(h, temb)
1073
+ h = self.mid.attn_1(h)
1074
+ h = self.mid.block_2(h, temb)
1075
+
1076
+ # upsampling
1077
+ for i_level in reversed(range(self.num_resolutions)):
1078
+ for i_block in range(self.num_res_blocks + 1):
1079
+ h = self.up[i_level].block[i_block](h, temb)
1080
+ if len(self.up[i_level].attn) > 0:
1081
+ h = self.up[i_level].attn[i_block](h)
1082
+ if i_level != 0:
1083
+ h = self.up[i_level].upsample(h)
1084
+ if i_level == 4:
1085
+ h += mid_h
1086
+ if i_level == 3:
1087
+ return h
1088
+
1089
+
1090
+ class DecoderRes(nn.Module):
1091
+
1092
+ def __init__(self,
1093
+ in_channels,
1094
+ resolution,
1095
+ z_channels,
1096
+ ch,
1097
+ num_res_blocks,
1098
+ ch_mult=(1, 2, 4, 8),
1099
+ dropout=0.0,
1100
+ give_pre_end=False):
1101
+ super().__init__()
1102
+ self.ch = ch
1103
+ self.temb_ch = 0
1104
+ self.num_resolutions = len(ch_mult)
1105
+ self.num_res_blocks = num_res_blocks
1106
+ self.resolution = resolution
1107
+ self.in_channels = in_channels
1108
+ self.give_pre_end = give_pre_end
1109
+
1110
+ # compute in_ch_mult, block_in and curr_res at lowest res
1111
+ in_ch_mult = (1, ) + tuple(ch_mult)
1112
+ block_in = ch * ch_mult[self.num_resolutions - 1]
1113
+ curr_res = resolution // 2**(self.num_resolutions - 1)
1114
+ self.z_shape = (1, z_channels, curr_res, curr_res // 2)
1115
+ print("Working with z of shape {} = {} dimensions.".format(
1116
+ self.z_shape, np.prod(self.z_shape)))
1117
+
1118
+ # z to block_in
1119
+ self.conv_in = torch.nn.Conv2d(
1120
+ z_channels, block_in, kernel_size=3, stride=1, padding=1)
1121
+
1122
+ # middle
1123
+ self.mid = nn.Module()
1124
+ self.mid.block_1 = ResnetBlock(
1125
+ in_channels=block_in,
1126
+ out_channels=block_in,
1127
+ temb_channels=self.temb_ch,
1128
+ dropout=dropout)
1129
+ self.mid.attn_1 = AttnBlock(block_in)
1130
+ self.mid.block_2 = ResnetBlock(
1131
+ in_channels=block_in,
1132
+ out_channels=block_in,
1133
+ temb_channels=self.temb_ch,
1134
+ dropout=dropout)
1135
+
1136
+ def forward(self, z):
1137
+ #assert z.shape[1:] == self.z_shape[1:]
1138
+ self.last_z_shape = z.shape
1139
+
1140
+ # timestep embedding
1141
+ temb = None
1142
+
1143
+ # z to block_in
1144
+ h = self.conv_in(z)
1145
+
1146
+ # middle
1147
+ h = self.mid.block_1(h, temb)
1148
+ h = self.mid.attn_1(h)
1149
+ h = self.mid.block_2(h, temb)
1150
+
1151
+ return h
1152
+
1153
+
1154
+ # patch based discriminator
1155
+ class Discriminator(nn.Module):
1156
+
1157
+ def __init__(self, nc, ndf, n_layers=3):
1158
+ super().__init__()
1159
+
1160
+ layers = [
1161
+ nn.Conv2d(nc, ndf, kernel_size=4, stride=2, padding=1),
1162
+ nn.LeakyReLU(0.2, True)
1163
+ ]
1164
+ ndf_mult = 1
1165
+ ndf_mult_prev = 1
1166
+ for n in range(1,
1167
+ n_layers): # gradually increase the number of filters
1168
+ ndf_mult_prev = ndf_mult
1169
+ ndf_mult = min(2**n, 8)
1170
+ layers += [
1171
+ nn.Conv2d(
1172
+ ndf * ndf_mult_prev,
1173
+ ndf * ndf_mult,
1174
+ kernel_size=4,
1175
+ stride=2,
1176
+ padding=1,
1177
+ bias=False),
1178
+ nn.BatchNorm2d(ndf * ndf_mult),
1179
+ nn.LeakyReLU(0.2, True)
1180
+ ]
1181
+
1182
+ ndf_mult_prev = ndf_mult
1183
+ ndf_mult = min(2**n_layers, 8)
1184
+
1185
+ layers += [
1186
+ nn.Conv2d(
1187
+ ndf * ndf_mult_prev,
1188
+ ndf * ndf_mult,
1189
+ kernel_size=4,
1190
+ stride=1,
1191
+ padding=1,
1192
+ bias=False),
1193
+ nn.BatchNorm2d(ndf * ndf_mult),
1194
+ nn.LeakyReLU(0.2, True)
1195
+ ]
1196
+
1197
+ layers += [
1198
+ nn.Conv2d(ndf * ndf_mult, 1, kernel_size=4, stride=1, padding=1)
1199
+ ] # output 1 channel prediction map
1200
+ self.main = nn.Sequential(*layers)
1201
+
1202
+ def forward(self, x):
1203
+ return self.main(x)
Text2Human/models/hierarchy_inference_model.py ADDED
@@ -0,0 +1,363 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import math
3
+ from collections import OrderedDict
4
+
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from torchvision.utils import save_image
8
+
9
+ from models.archs.fcn_arch import MultiHeadFCNHead
10
+ from models.archs.unet_arch import UNet
11
+ from models.archs.vqgan_arch import (Decoder, DecoderRes, Encoder,
12
+ VectorQuantizerSpatialTextureAware,
13
+ VectorQuantizerTexture)
14
+ from models.losses.accuracy import accuracy
15
+ from models.losses.cross_entropy_loss import CrossEntropyLoss
16
+
17
+ logger = logging.getLogger('base')
18
+
19
+
20
+ class VQGANTextureAwareSpatialHierarchyInferenceModel():
21
+
22
+ def __init__(self, opt):
23
+ self.opt = opt
24
+ self.device = torch.device('cuda')
25
+ self.is_train = opt['is_train']
26
+
27
+ self.top_encoder = Encoder(
28
+ ch=opt['top_ch'],
29
+ num_res_blocks=opt['top_num_res_blocks'],
30
+ attn_resolutions=opt['top_attn_resolutions'],
31
+ ch_mult=opt['top_ch_mult'],
32
+ in_channels=opt['top_in_channels'],
33
+ resolution=opt['top_resolution'],
34
+ z_channels=opt['top_z_channels'],
35
+ double_z=opt['top_double_z'],
36
+ dropout=opt['top_dropout']).to(self.device)
37
+ self.decoder = Decoder(
38
+ in_channels=opt['top_in_channels'],
39
+ resolution=opt['top_resolution'],
40
+ z_channels=opt['top_z_channels'],
41
+ ch=opt['top_ch'],
42
+ out_ch=opt['top_out_ch'],
43
+ num_res_blocks=opt['top_num_res_blocks'],
44
+ attn_resolutions=opt['top_attn_resolutions'],
45
+ ch_mult=opt['top_ch_mult'],
46
+ dropout=opt['top_dropout'],
47
+ resamp_with_conv=True,
48
+ give_pre_end=False).to(self.device)
49
+ self.top_quantize = VectorQuantizerTexture(
50
+ 1024, opt['embed_dim'], beta=0.25).to(self.device)
51
+ self.top_quant_conv = torch.nn.Conv2d(opt["top_z_channels"],
52
+ opt['embed_dim'],
53
+ 1).to(self.device)
54
+ self.top_post_quant_conv = torch.nn.Conv2d(opt['embed_dim'],
55
+ opt["top_z_channels"],
56
+ 1).to(self.device)
57
+ self.load_top_pretrain_models()
58
+
59
+ self.bot_encoder = Encoder(
60
+ ch=opt['bot_ch'],
61
+ num_res_blocks=opt['bot_num_res_blocks'],
62
+ attn_resolutions=opt['bot_attn_resolutions'],
63
+ ch_mult=opt['bot_ch_mult'],
64
+ in_channels=opt['bot_in_channels'],
65
+ resolution=opt['bot_resolution'],
66
+ z_channels=opt['bot_z_channels'],
67
+ double_z=opt['bot_double_z'],
68
+ dropout=opt['bot_dropout']).to(self.device)
69
+ self.bot_decoder_res = DecoderRes(
70
+ in_channels=opt['bot_in_channels'],
71
+ resolution=opt['bot_resolution'],
72
+ z_channels=opt['bot_z_channels'],
73
+ ch=opt['bot_ch'],
74
+ num_res_blocks=opt['bot_num_res_blocks'],
75
+ ch_mult=opt['bot_ch_mult'],
76
+ dropout=opt['bot_dropout'],
77
+ give_pre_end=False).to(self.device)
78
+ self.bot_quantize = VectorQuantizerSpatialTextureAware(
79
+ opt['bot_n_embed'],
80
+ opt['embed_dim'],
81
+ beta=0.25,
82
+ spatial_size=opt['codebook_spatial_size']).to(self.device)
83
+ self.bot_quant_conv = torch.nn.Conv2d(opt["bot_z_channels"],
84
+ opt['embed_dim'],
85
+ 1).to(self.device)
86
+ self.bot_post_quant_conv = torch.nn.Conv2d(opt['embed_dim'],
87
+ opt["bot_z_channels"],
88
+ 1).to(self.device)
89
+
90
+ self.load_bot_pretrain_network()
91
+
92
+ self.guidance_encoder = UNet(
93
+ in_channels=opt['encoder_in_channels']).to(self.device)
94
+ self.index_decoder = MultiHeadFCNHead(
95
+ in_channels=opt['fc_in_channels'],
96
+ in_index=opt['fc_in_index'],
97
+ channels=opt['fc_channels'],
98
+ num_convs=opt['fc_num_convs'],
99
+ concat_input=opt['fc_concat_input'],
100
+ dropout_ratio=opt['fc_dropout_ratio'],
101
+ num_classes=opt['fc_num_classes'],
102
+ align_corners=opt['fc_align_corners'],
103
+ num_head=18).to(self.device)
104
+
105
+ self.init_training_settings()
106
+
107
+ def init_training_settings(self):
108
+ optim_params = []
109
+ for v in self.guidance_encoder.parameters():
110
+ if v.requires_grad:
111
+ optim_params.append(v)
112
+ for v in self.index_decoder.parameters():
113
+ if v.requires_grad:
114
+ optim_params.append(v)
115
+ # set up optimizers
116
+ if self.opt['optimizer'] == 'Adam':
117
+ self.optimizer = torch.optim.Adam(
118
+ optim_params,
119
+ self.opt['lr'],
120
+ weight_decay=self.opt['weight_decay'])
121
+ elif self.opt['optimizer'] == 'SGD':
122
+ self.optimizer = torch.optim.SGD(
123
+ optim_params,
124
+ self.opt['lr'],
125
+ momentum=self.opt['momentum'],
126
+ weight_decay=self.opt['weight_decay'])
127
+ self.log_dict = OrderedDict()
128
+ if self.opt['loss_function'] == 'cross_entropy':
129
+ self.loss_func = CrossEntropyLoss().to(self.device)
130
+
131
+ def load_top_pretrain_models(self):
132
+ # load pretrained vqgan for segmentation mask
133
+ top_vae_checkpoint = torch.load(self.opt['top_vae_path'])
134
+ self.top_encoder.load_state_dict(
135
+ top_vae_checkpoint['encoder'], strict=True)
136
+ self.decoder.load_state_dict(
137
+ top_vae_checkpoint['decoder'], strict=True)
138
+ self.top_quantize.load_state_dict(
139
+ top_vae_checkpoint['quantize'], strict=True)
140
+ self.top_quant_conv.load_state_dict(
141
+ top_vae_checkpoint['quant_conv'], strict=True)
142
+ self.top_post_quant_conv.load_state_dict(
143
+ top_vae_checkpoint['post_quant_conv'], strict=True)
144
+ self.top_encoder.eval()
145
+ self.top_quantize.eval()
146
+ self.top_quant_conv.eval()
147
+ self.top_post_quant_conv.eval()
148
+
149
+ def load_bot_pretrain_network(self):
150
+ checkpoint = torch.load(self.opt['bot_vae_path'])
151
+ self.bot_encoder.load_state_dict(
152
+ checkpoint['bot_encoder'], strict=True)
153
+ self.bot_decoder_res.load_state_dict(
154
+ checkpoint['bot_decoder_res'], strict=True)
155
+ self.decoder.load_state_dict(checkpoint['decoder'], strict=True)
156
+ self.bot_quantize.load_state_dict(
157
+ checkpoint['bot_quantize'], strict=True)
158
+ self.bot_quant_conv.load_state_dict(
159
+ checkpoint['bot_quant_conv'], strict=True)
160
+ self.bot_post_quant_conv.load_state_dict(
161
+ checkpoint['bot_post_quant_conv'], strict=True)
162
+
163
+ self.bot_encoder.eval()
164
+ self.bot_decoder_res.eval()
165
+ self.decoder.eval()
166
+ self.bot_quantize.eval()
167
+ self.bot_quant_conv.eval()
168
+ self.bot_post_quant_conv.eval()
169
+
170
+ def top_encode(self, x, mask):
171
+ h = self.top_encoder(x)
172
+ h = self.top_quant_conv(h)
173
+ quant, _, _ = self.top_quantize(h, mask)
174
+ quant = self.top_post_quant_conv(quant)
175
+
176
+ return quant, quant
177
+
178
+ def feed_data(self, data):
179
+ self.image = data['image'].to(self.device)
180
+ self.texture_mask = data['texture_mask'].float().to(self.device)
181
+ self.get_gt_indices()
182
+
183
+ self.texture_tokens = F.interpolate(
184
+ self.texture_mask, size=(32, 16),
185
+ mode='nearest').view(self.image.size(0), -1).long()
186
+
187
+ def bot_encode(self, x, mask):
188
+ h = self.bot_encoder(x)
189
+ h = self.bot_quant_conv(h)
190
+ _, _, (_, _, indices_list) = self.bot_quantize(h, mask)
191
+
192
+ return indices_list
193
+
194
+ def get_gt_indices(self):
195
+ self.quant_t, self.feature_t = self.top_encode(self.image,
196
+ self.texture_mask)
197
+ self.gt_indices_list = self.bot_encode(self.image, self.texture_mask)
198
+
199
+ def index_to_image(self, index_bottom_list, texture_mask):
200
+ quant_b = self.bot_quantize.get_codebook_entry(
201
+ index_bottom_list, texture_mask,
202
+ (index_bottom_list[0].size(0), index_bottom_list[0].size(1),
203
+ index_bottom_list[0].size(2),
204
+ self.opt["bot_z_channels"])) #.permute(0, 3, 1, 2)
205
+ quant_b = self.bot_post_quant_conv(quant_b)
206
+ bot_dec_res = self.bot_decoder_res(quant_b)
207
+
208
+ dec = self.decoder(self.quant_t, bot_h=bot_dec_res)
209
+
210
+ return dec
211
+
212
+ def get_vis(self, pred_img_index, rec_img_index, texture_mask, save_path):
213
+ rec_img = self.index_to_image(rec_img_index, texture_mask)
214
+ pred_img = self.index_to_image(pred_img_index, texture_mask)
215
+
216
+ base_img = self.decoder(self.quant_t)
217
+ img_cat = torch.cat([
218
+ self.image,
219
+ rec_img,
220
+ base_img,
221
+ pred_img,
222
+ ], dim=3).detach()
223
+ img_cat = ((img_cat + 1) / 2)
224
+ img_cat = img_cat.clamp_(0, 1)
225
+ save_image(img_cat, save_path, nrow=1, padding=4)
226
+
227
+ def optimize_parameters(self):
228
+ self.guidance_encoder.train()
229
+ self.index_decoder.train()
230
+
231
+ self.feature_enc = self.guidance_encoder(self.feature_t)
232
+ self.memory_logits_list = self.index_decoder(self.feature_enc)
233
+
234
+ loss = 0
235
+ for i in range(18):
236
+ loss += self.loss_func(
237
+ self.memory_logits_list[i],
238
+ self.gt_indices_list[i],
239
+ ignore_index=-1)
240
+
241
+ self.optimizer.zero_grad()
242
+ loss.backward()
243
+ self.optimizer.step()
244
+
245
+ self.log_dict['loss_total'] = loss
246
+
247
+ def inference(self, data_loader, save_dir):
248
+ self.guidance_encoder.eval()
249
+ self.index_decoder.eval()
250
+
251
+ acc = 0
252
+ num = 0
253
+
254
+ for _, data in enumerate(data_loader):
255
+ self.feed_data(data)
256
+ img_name = data['img_name']
257
+
258
+ num += self.image.size(0)
259
+
260
+ texture_mask_flatten = self.texture_tokens.view(-1)
261
+ min_encodings_indices_list = [
262
+ torch.full(
263
+ texture_mask_flatten.size(),
264
+ fill_value=-1,
265
+ dtype=torch.long,
266
+ device=texture_mask_flatten.device) for _ in range(18)
267
+ ]
268
+ with torch.no_grad():
269
+ self.feature_enc = self.guidance_encoder(self.feature_t)
270
+ memory_logits_list = self.index_decoder(self.feature_enc)
271
+ # memory_indices_pred = memory_logits.argmax(dim=1)
272
+ batch_acc = 0
273
+ for codebook_idx, memory_logits in enumerate(memory_logits_list):
274
+ region_of_interest = texture_mask_flatten == codebook_idx
275
+ if torch.sum(region_of_interest) > 0:
276
+ memory_indices_pred = memory_logits.argmax(dim=1).view(-1)
277
+ batch_acc += torch.sum(
278
+ memory_indices_pred[region_of_interest] ==
279
+ self.gt_indices_list[codebook_idx].view(
280
+ -1)[region_of_interest])
281
+ memory_indices_pred = memory_indices_pred
282
+ min_encodings_indices_list[codebook_idx][
283
+ region_of_interest] = memory_indices_pred[
284
+ region_of_interest]
285
+ min_encodings_indices_return_list = [
286
+ min_encodings_indices.view(self.gt_indices_list[0].size())
287
+ for min_encodings_indices in min_encodings_indices_list
288
+ ]
289
+ batch_acc = batch_acc / self.gt_indices_list[codebook_idx].numel(
290
+ ) * self.image.size(0)
291
+ acc += batch_acc
292
+ self.get_vis(min_encodings_indices_return_list,
293
+ self.gt_indices_list, self.texture_mask,
294
+ f'{save_dir}/{img_name[0]}')
295
+
296
+ self.guidance_encoder.train()
297
+ self.index_decoder.train()
298
+ return (acc / num).item()
299
+
300
+ def load_network(self):
301
+ checkpoint = torch.load(self.opt['pretrained_models'])
302
+ self.guidance_encoder.load_state_dict(
303
+ checkpoint['guidance_encoder'], strict=True)
304
+ self.guidance_encoder.eval()
305
+
306
+ self.index_decoder.load_state_dict(
307
+ checkpoint['index_decoder'], strict=True)
308
+ self.index_decoder.eval()
309
+
310
+ def save_network(self, save_path):
311
+ """Save networks.
312
+
313
+ Args:
314
+ net (nn.Module): Network to be saved.
315
+ net_label (str): Network label.
316
+ current_iter (int): Current iter number.
317
+ """
318
+
319
+ save_dict = {}
320
+ save_dict['guidance_encoder'] = self.guidance_encoder.state_dict()
321
+ save_dict['index_decoder'] = self.index_decoder.state_dict()
322
+
323
+ torch.save(save_dict, save_path)
324
+
325
+ def update_learning_rate(self, epoch):
326
+ """Update learning rate.
327
+
328
+ Args:
329
+ current_iter (int): Current iteration.
330
+ warmup_iter (int): Warmup iter numbers. -1 for no warmup.
331
+ Default: -1.
332
+ """
333
+ lr = self.optimizer.param_groups[0]['lr']
334
+
335
+ if self.opt['lr_decay'] == 'step':
336
+ lr = self.opt['lr'] * (
337
+ self.opt['gamma']**(epoch // self.opt['step']))
338
+ elif self.opt['lr_decay'] == 'cos':
339
+ lr = self.opt['lr'] * (
340
+ 1 + math.cos(math.pi * epoch / self.opt['num_epochs'])) / 2
341
+ elif self.opt['lr_decay'] == 'linear':
342
+ lr = self.opt['lr'] * (1 - epoch / self.opt['num_epochs'])
343
+ elif self.opt['lr_decay'] == 'linear2exp':
344
+ if epoch < self.opt['turning_point'] + 1:
345
+ # learning rate decay as 95%
346
+ # at the turning point (1 / 95% = 1.0526)
347
+ lr = self.opt['lr'] * (
348
+ 1 - epoch / int(self.opt['turning_point'] * 1.0526))
349
+ else:
350
+ lr *= self.opt['gamma']
351
+ elif self.opt['lr_decay'] == 'schedule':
352
+ if epoch in self.opt['schedule']:
353
+ lr *= self.opt['gamma']
354
+ else:
355
+ raise ValueError('Unknown lr mode {}'.format(self.opt['lr_decay']))
356
+ # set learning rate
357
+ for param_group in self.optimizer.param_groups:
358
+ param_group['lr'] = lr
359
+
360
+ return lr
361
+
362
+ def get_current_log(self):
363
+ return self.log_dict
Text2Human/models/hierarchy_vqgan_model.py ADDED
@@ -0,0 +1,374 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import sys
3
+ from collections import OrderedDict
4
+
5
+ sys.path.append('..')
6
+ import lpips
7
+ import torch
8
+ import torch.nn.functional as F
9
+ from torchvision.utils import save_image
10
+
11
+ from models.archs.vqgan_arch import (Decoder, DecoderRes, Discriminator,
12
+ Encoder,
13
+ VectorQuantizerSpatialTextureAware,
14
+ VectorQuantizerTexture)
15
+ from models.losses.vqgan_loss import (DiffAugment, adopt_weight,
16
+ calculate_adaptive_weight, hinge_d_loss)
17
+
18
+
19
+ class HierarchyVQSpatialTextureAwareModel():
20
+
21
+ def __init__(self, opt):
22
+ self.opt = opt
23
+ self.device = torch.device('cuda')
24
+ self.top_encoder = Encoder(
25
+ ch=opt['top_ch'],
26
+ num_res_blocks=opt['top_num_res_blocks'],
27
+ attn_resolutions=opt['top_attn_resolutions'],
28
+ ch_mult=opt['top_ch_mult'],
29
+ in_channels=opt['top_in_channels'],
30
+ resolution=opt['top_resolution'],
31
+ z_channels=opt['top_z_channels'],
32
+ double_z=opt['top_double_z'],
33
+ dropout=opt['top_dropout']).to(self.device)
34
+ self.decoder = Decoder(
35
+ in_channels=opt['top_in_channels'],
36
+ resolution=opt['top_resolution'],
37
+ z_channels=opt['top_z_channels'],
38
+ ch=opt['top_ch'],
39
+ out_ch=opt['top_out_ch'],
40
+ num_res_blocks=opt['top_num_res_blocks'],
41
+ attn_resolutions=opt['top_attn_resolutions'],
42
+ ch_mult=opt['top_ch_mult'],
43
+ dropout=opt['top_dropout'],
44
+ resamp_with_conv=True,
45
+ give_pre_end=False).to(self.device)
46
+ self.top_quantize = VectorQuantizerTexture(
47
+ 1024, opt['embed_dim'], beta=0.25).to(self.device)
48
+ self.top_quant_conv = torch.nn.Conv2d(opt["top_z_channels"],
49
+ opt['embed_dim'],
50
+ 1).to(self.device)
51
+ self.top_post_quant_conv = torch.nn.Conv2d(opt['embed_dim'],
52
+ opt["top_z_channels"],
53
+ 1).to(self.device)
54
+ self.load_top_pretrain_models()
55
+
56
+ self.bot_encoder = Encoder(
57
+ ch=opt['bot_ch'],
58
+ num_res_blocks=opt['bot_num_res_blocks'],
59
+ attn_resolutions=opt['bot_attn_resolutions'],
60
+ ch_mult=opt['bot_ch_mult'],
61
+ in_channels=opt['bot_in_channels'],
62
+ resolution=opt['bot_resolution'],
63
+ z_channels=opt['bot_z_channels'],
64
+ double_z=opt['bot_double_z'],
65
+ dropout=opt['bot_dropout']).to(self.device)
66
+ self.bot_decoder_res = DecoderRes(
67
+ in_channels=opt['bot_in_channels'],
68
+ resolution=opt['bot_resolution'],
69
+ z_channels=opt['bot_z_channels'],
70
+ ch=opt['bot_ch'],
71
+ num_res_blocks=opt['bot_num_res_blocks'],
72
+ ch_mult=opt['bot_ch_mult'],
73
+ dropout=opt['bot_dropout'],
74
+ give_pre_end=False).to(self.device)
75
+ self.bot_quantize = VectorQuantizerSpatialTextureAware(
76
+ opt['bot_n_embed'],
77
+ opt['embed_dim'],
78
+ beta=0.25,
79
+ spatial_size=opt['codebook_spatial_size']).to(self.device)
80
+ self.bot_quant_conv = torch.nn.Conv2d(opt["bot_z_channels"],
81
+ opt['embed_dim'],
82
+ 1).to(self.device)
83
+ self.bot_post_quant_conv = torch.nn.Conv2d(opt['embed_dim'],
84
+ opt["bot_z_channels"],
85
+ 1).to(self.device)
86
+
87
+ self.disc = Discriminator(
88
+ opt['n_channels'], opt['ndf'],
89
+ n_layers=opt['disc_layers']).to(self.device)
90
+ self.perceptual = lpips.LPIPS(net="vgg").to(self.device)
91
+ self.perceptual_weight = opt['perceptual_weight']
92
+ self.disc_start_step = opt['disc_start_step']
93
+ self.disc_weight_max = opt['disc_weight_max']
94
+ self.diff_aug = opt['diff_aug']
95
+ self.policy = "color,translation"
96
+
97
+ self.load_discriminator_models()
98
+
99
+ self.disc.train()
100
+
101
+ self.fix_decoder = opt['fix_decoder']
102
+
103
+ self.init_training_settings()
104
+
105
+ def load_top_pretrain_models(self):
106
+ # load pretrained vqgan for segmentation mask
107
+ top_vae_checkpoint = torch.load(self.opt['top_vae_path'])
108
+ self.top_encoder.load_state_dict(
109
+ top_vae_checkpoint['encoder'], strict=True)
110
+ self.decoder.load_state_dict(
111
+ top_vae_checkpoint['decoder'], strict=True)
112
+ self.top_quantize.load_state_dict(
113
+ top_vae_checkpoint['quantize'], strict=True)
114
+ self.top_quant_conv.load_state_dict(
115
+ top_vae_checkpoint['quant_conv'], strict=True)
116
+ self.top_post_quant_conv.load_state_dict(
117
+ top_vae_checkpoint['post_quant_conv'], strict=True)
118
+ self.top_encoder.eval()
119
+ self.top_quantize.eval()
120
+ self.top_quant_conv.eval()
121
+ self.top_post_quant_conv.eval()
122
+
123
+ def init_training_settings(self):
124
+ self.log_dict = OrderedDict()
125
+ self.configure_optimizers()
126
+
127
+ def configure_optimizers(self):
128
+ optim_params = []
129
+ for v in self.bot_encoder.parameters():
130
+ if v.requires_grad:
131
+ optim_params.append(v)
132
+ for v in self.bot_decoder_res.parameters():
133
+ if v.requires_grad:
134
+ optim_params.append(v)
135
+ for v in self.bot_quantize.parameters():
136
+ if v.requires_grad:
137
+ optim_params.append(v)
138
+ for v in self.bot_quant_conv.parameters():
139
+ if v.requires_grad:
140
+ optim_params.append(v)
141
+ for v in self.bot_post_quant_conv.parameters():
142
+ if v.requires_grad:
143
+ optim_params.append(v)
144
+ if not self.fix_decoder:
145
+ for name, v in self.decoder.named_parameters():
146
+ if v.requires_grad:
147
+ if 'up.0' in name:
148
+ optim_params.append(v)
149
+ if 'up.1' in name:
150
+ optim_params.append(v)
151
+ if 'up.2' in name:
152
+ optim_params.append(v)
153
+ if 'up.3' in name:
154
+ optim_params.append(v)
155
+
156
+ self.optimizer = torch.optim.Adam(optim_params, lr=self.opt['lr'])
157
+
158
+ self.disc_optimizer = torch.optim.Adam(
159
+ self.disc.parameters(), lr=self.opt['lr'])
160
+
161
+ def load_discriminator_models(self):
162
+ # load pretrained vqgan for segmentation mask
163
+ top_vae_checkpoint = torch.load(self.opt['top_vae_path'])
164
+ self.disc.load_state_dict(
165
+ top_vae_checkpoint['discriminator'], strict=True)
166
+
167
+ def save_network(self, save_path):
168
+ """Save networks.
169
+ """
170
+
171
+ save_dict = {}
172
+ save_dict['bot_encoder'] = self.bot_encoder.state_dict()
173
+ save_dict['bot_decoder_res'] = self.bot_decoder_res.state_dict()
174
+ save_dict['decoder'] = self.decoder.state_dict()
175
+ save_dict['bot_quantize'] = self.bot_quantize.state_dict()
176
+ save_dict['bot_quant_conv'] = self.bot_quant_conv.state_dict()
177
+ save_dict['bot_post_quant_conv'] = self.bot_post_quant_conv.state_dict(
178
+ )
179
+ save_dict['discriminator'] = self.disc.state_dict()
180
+ torch.save(save_dict, save_path)
181
+
182
+ def load_network(self):
183
+ checkpoint = torch.load(self.opt['pretrained_models'])
184
+ self.bot_encoder.load_state_dict(
185
+ checkpoint['bot_encoder'], strict=True)
186
+ self.bot_decoder_res.load_state_dict(
187
+ checkpoint['bot_decoder_res'], strict=True)
188
+ self.decoder.load_state_dict(checkpoint['decoder'], strict=True)
189
+ self.bot_quantize.load_state_dict(
190
+ checkpoint['bot_quantize'], strict=True)
191
+ self.bot_quant_conv.load_state_dict(
192
+ checkpoint['bot_quant_conv'], strict=True)
193
+ self.bot_post_quant_conv.load_state_dict(
194
+ checkpoint['bot_post_quant_conv'], strict=True)
195
+
196
+ def optimize_parameters(self, data, step):
197
+ self.bot_encoder.train()
198
+ self.bot_decoder_res.train()
199
+ if not self.fix_decoder:
200
+ self.decoder.train()
201
+ self.bot_quantize.train()
202
+ self.bot_quant_conv.train()
203
+ self.bot_post_quant_conv.train()
204
+
205
+ loss, d_loss = self.training_step(data, step)
206
+ self.optimizer.zero_grad()
207
+ loss.backward()
208
+ self.optimizer.step()
209
+
210
+ if step > self.disc_start_step:
211
+ self.disc_optimizer.zero_grad()
212
+ d_loss.backward()
213
+ self.disc_optimizer.step()
214
+
215
+ def top_encode(self, x, mask):
216
+ h = self.top_encoder(x)
217
+ h = self.top_quant_conv(h)
218
+ quant, _, _ = self.top_quantize(h, mask)
219
+ quant = self.top_post_quant_conv(quant)
220
+ return quant
221
+
222
+ def bot_encode(self, x, mask):
223
+ h = self.bot_encoder(x)
224
+ h = self.bot_quant_conv(h)
225
+ quant, emb_loss, info = self.bot_quantize(h, mask)
226
+ quant = self.bot_post_quant_conv(quant)
227
+ bot_dec_res = self.bot_decoder_res(quant)
228
+ return bot_dec_res, emb_loss, info
229
+
230
+ def decode(self, quant_top, bot_dec_res):
231
+ dec = self.decoder(quant_top, bot_h=bot_dec_res)
232
+ return dec
233
+
234
+ def forward_step(self, input, mask):
235
+ with torch.no_grad():
236
+ quant_top = self.top_encode(input, mask)
237
+ bot_dec_res, diff, _ = self.bot_encode(input, mask)
238
+ dec = self.decode(quant_top, bot_dec_res)
239
+ return dec, diff
240
+
241
+ def feed_data(self, data):
242
+ x = data['image'].float().to(self.device)
243
+ mask = data['texture_mask'].float().to(self.device)
244
+
245
+ return x, mask
246
+
247
+ def training_step(self, data, step):
248
+ x, mask = self.feed_data(data)
249
+ xrec, codebook_loss = self.forward_step(x, mask)
250
+
251
+ # get recon/perceptual loss
252
+ recon_loss = torch.abs(x.contiguous() - xrec.contiguous())
253
+ p_loss = self.perceptual(x.contiguous(), xrec.contiguous())
254
+ nll_loss = recon_loss + self.perceptual_weight * p_loss
255
+ nll_loss = torch.mean(nll_loss)
256
+
257
+ # augment for input to discriminator
258
+ if self.diff_aug:
259
+ xrec = DiffAugment(xrec, policy=self.policy)
260
+
261
+ # update generator
262
+ logits_fake = self.disc(xrec)
263
+ g_loss = -torch.mean(logits_fake)
264
+ last_layer = self.decoder.conv_out.weight
265
+ d_weight = calculate_adaptive_weight(nll_loss, g_loss, last_layer,
266
+ self.disc_weight_max)
267
+ d_weight *= adopt_weight(1, step, self.disc_start_step)
268
+ loss = nll_loss + d_weight * g_loss + codebook_loss
269
+
270
+ self.log_dict["loss"] = loss
271
+ self.log_dict["l1"] = recon_loss.mean().item()
272
+ self.log_dict["perceptual"] = p_loss.mean().item()
273
+ self.log_dict["nll_loss"] = nll_loss.item()
274
+ self.log_dict["g_loss"] = g_loss.item()
275
+ self.log_dict["d_weight"] = d_weight
276
+ self.log_dict["codebook_loss"] = codebook_loss.item()
277
+
278
+ if step > self.disc_start_step:
279
+ if self.diff_aug:
280
+ logits_real = self.disc(
281
+ DiffAugment(x.contiguous().detach(), policy=self.policy))
282
+ else:
283
+ logits_real = self.disc(x.contiguous().detach())
284
+ logits_fake = self.disc(xrec.contiguous().detach(
285
+ )) # detach so that generator isn"t also updated
286
+ d_loss = hinge_d_loss(logits_real, logits_fake)
287
+ self.log_dict["d_loss"] = d_loss
288
+ else:
289
+ d_loss = None
290
+
291
+ return loss, d_loss
292
+
293
+ @torch.no_grad()
294
+ def inference(self, data_loader, save_dir):
295
+ self.bot_encoder.eval()
296
+ self.bot_decoder_res.eval()
297
+ self.decoder.eval()
298
+ self.bot_quantize.eval()
299
+ self.bot_quant_conv.eval()
300
+ self.bot_post_quant_conv.eval()
301
+
302
+ loss_total = 0
303
+ num = 0
304
+
305
+ for _, data in enumerate(data_loader):
306
+ img_name = data['img_name'][0]
307
+ x, mask = self.feed_data(data)
308
+ xrec, _ = self.forward_step(x, mask)
309
+
310
+ recon_loss = torch.abs(x.contiguous() - xrec.contiguous())
311
+ p_loss = self.perceptual(x.contiguous(), xrec.contiguous())
312
+ nll_loss = recon_loss + self.perceptual_weight * p_loss
313
+ nll_loss = torch.mean(nll_loss)
314
+ loss_total += nll_loss
315
+
316
+ num += x.size(0)
317
+
318
+ if x.shape[1] > 3:
319
+ # colorize with random projection
320
+ assert xrec.shape[1] > 3
321
+ # convert logits to indices
322
+ xrec = torch.argmax(xrec, dim=1, keepdim=True)
323
+ xrec = F.one_hot(xrec, num_classes=x.shape[1])
324
+ xrec = xrec.squeeze(1).permute(0, 3, 1, 2).float()
325
+ x = self.to_rgb(x)
326
+ xrec = self.to_rgb(xrec)
327
+
328
+ img_cat = torch.cat([x, xrec], dim=3).detach()
329
+ img_cat = ((img_cat + 1) / 2)
330
+ img_cat = img_cat.clamp_(0, 1)
331
+ save_image(
332
+ img_cat, f'{save_dir}/{img_name}.png', nrow=1, padding=4)
333
+
334
+ return (loss_total / num).item()
335
+
336
+ def get_current_log(self):
337
+ return self.log_dict
338
+
339
+ def update_learning_rate(self, epoch):
340
+ """Update learning rate.
341
+
342
+ Args:
343
+ current_iter (int): Current iteration.
344
+ warmup_iter (int): Warmup iter numbers. -1 for no warmup.
345
+ Default: -1.
346
+ """
347
+ lr = self.optimizer.param_groups[0]['lr']
348
+
349
+ if self.opt['lr_decay'] == 'step':
350
+ lr = self.opt['lr'] * (
351
+ self.opt['gamma']**(epoch // self.opt['step']))
352
+ elif self.opt['lr_decay'] == 'cos':
353
+ lr = self.opt['lr'] * (
354
+ 1 + math.cos(math.pi * epoch / self.opt['num_epochs'])) / 2
355
+ elif self.opt['lr_decay'] == 'linear':
356
+ lr = self.opt['lr'] * (1 - epoch / self.opt['num_epochs'])
357
+ elif self.opt['lr_decay'] == 'linear2exp':
358
+ if epoch < self.opt['turning_point'] + 1:
359
+ # learning rate decay as 95%
360
+ # at the turning point (1 / 95% = 1.0526)
361
+ lr = self.opt['lr'] * (
362
+ 1 - epoch / int(self.opt['turning_point'] * 1.0526))
363
+ else:
364
+ lr *= self.opt['gamma']
365
+ elif self.opt['lr_decay'] == 'schedule':
366
+ if epoch in self.opt['schedule']:
367
+ lr *= self.opt['gamma']
368
+ else:
369
+ raise ValueError('Unknown lr mode {}'.format(self.opt['lr_decay']))
370
+ # set learning rate
371
+ for param_group in self.optimizer.param_groups:
372
+ param_group['lr'] = lr
373
+
374
+ return lr
Text2Human/models/losses/__init__.py ADDED
File without changes
Text2Human/models/losses/accuracy.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ def accuracy(pred, target, topk=1, thresh=None):
2
+ """Calculate accuracy according to the prediction and target.
3
+
4
+ Args:
5
+ pred (torch.Tensor): The model prediction, shape (N, num_class, ...)
6
+ target (torch.Tensor): The target of each prediction, shape (N, , ...)
7
+ topk (int | tuple[int], optional): If the predictions in ``topk``
8
+ matches the target, the predictions will be regarded as
9
+ correct ones. Defaults to 1.
10
+ thresh (float, optional): If not None, predictions with scores under
11
+ this threshold are considered incorrect. Default to None.
12
+
13
+ Returns:
14
+ float | tuple[float]: If the input ``topk`` is a single integer,
15
+ the function will return a single float as accuracy. If
16
+ ``topk`` is a tuple containing multiple integers, the
17
+ function will return a tuple containing accuracies of
18
+ each ``topk`` number.
19
+ """
20
+ assert isinstance(topk, (int, tuple))
21
+ if isinstance(topk, int):
22
+ topk = (topk, )
23
+ return_single = True
24
+ else:
25
+ return_single = False
26
+
27
+ maxk = max(topk)
28
+ if pred.size(0) == 0:
29
+ accu = [pred.new_tensor(0.) for i in range(len(topk))]
30
+ return accu[0] if return_single else accu
31
+ assert pred.ndim == target.ndim + 1
32
+ assert pred.size(0) == target.size(0)
33
+ assert maxk <= pred.size(1), \
34
+ f'maxk {maxk} exceeds pred dimension {pred.size(1)}'
35
+ pred_value, pred_label = pred.topk(maxk, dim=1)
36
+ # transpose to shape (maxk, N, ...)
37
+ pred_label = pred_label.transpose(0, 1)
38
+ correct = pred_label.eq(target.unsqueeze(0).expand_as(pred_label))
39
+ if thresh is not None:
40
+ # Only prediction values larger than thresh are counted as correct
41
+ correct = correct & (pred_value > thresh).t()
42
+ res = []
43
+ for k in topk:
44
+ correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
45
+ res.append(correct_k.mul_(100.0 / target.numel()))
46
+ return res[0] if return_single else res
Text2Human/models/losses/cross_entropy_loss.py ADDED
@@ -0,0 +1,246 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
+ def reduce_loss(loss, reduction):
7
+ """Reduce loss as specified.
8
+
9
+ Args:
10
+ loss (Tensor): Elementwise loss tensor.
11
+ reduction (str): Options are "none", "mean" and "sum".
12
+
13
+ Return:
14
+ Tensor: Reduced loss tensor.
15
+ """
16
+ reduction_enum = F._Reduction.get_enum(reduction)
17
+ # none: 0, elementwise_mean:1, sum: 2
18
+ if reduction_enum == 0:
19
+ return loss
20
+ elif reduction_enum == 1:
21
+ return loss.mean()
22
+ elif reduction_enum == 2:
23
+ return loss.sum()
24
+
25
+
26
+ def weight_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None):
27
+ """Apply element-wise weight and reduce loss.
28
+
29
+ Args:
30
+ loss (Tensor): Element-wise loss.
31
+ weight (Tensor): Element-wise weights.
32
+ reduction (str): Same as built-in losses of PyTorch.
33
+ avg_factor (float): Avarage factor when computing the mean of losses.
34
+
35
+ Returns:
36
+ Tensor: Processed loss values.
37
+ """
38
+ # if weight is specified, apply element-wise weight
39
+ if weight is not None:
40
+ assert weight.dim() == loss.dim()
41
+ if weight.dim() > 1:
42
+ assert weight.size(1) == 1 or weight.size(1) == loss.size(1)
43
+ loss = loss * weight
44
+
45
+ # if avg_factor is not specified, just reduce the loss
46
+ if avg_factor is None:
47
+ loss = reduce_loss(loss, reduction)
48
+ else:
49
+ # if reduction is mean, then average the loss by avg_factor
50
+ if reduction == 'mean':
51
+ loss = loss.sum() / avg_factor
52
+ # if reduction is 'none', then do nothing, otherwise raise an error
53
+ elif reduction != 'none':
54
+ raise ValueError('avg_factor can not be used with reduction="sum"')
55
+ return loss
56
+
57
+
58
+ def cross_entropy(pred,
59
+ label,
60
+ weight=None,
61
+ class_weight=None,
62
+ reduction='mean',
63
+ avg_factor=None,
64
+ ignore_index=-100):
65
+ """The wrapper function for :func:`F.cross_entropy`"""
66
+ # class_weight is a manual rescaling weight given to each class.
67
+ # If given, has to be a Tensor of size C element-wise losses
68
+ loss = F.cross_entropy(
69
+ pred,
70
+ label,
71
+ weight=class_weight,
72
+ reduction='none',
73
+ ignore_index=ignore_index)
74
+
75
+ # apply weights and do the reduction
76
+ if weight is not None:
77
+ weight = weight.float()
78
+ loss = weight_reduce_loss(
79
+ loss, weight=weight, reduction=reduction, avg_factor=avg_factor)
80
+
81
+ return loss
82
+
83
+
84
+ def _expand_onehot_labels(labels, label_weights, target_shape, ignore_index):
85
+ """Expand onehot labels to match the size of prediction."""
86
+ bin_labels = labels.new_zeros(target_shape)
87
+ valid_mask = (labels >= 0) & (labels != ignore_index)
88
+ inds = torch.nonzero(valid_mask, as_tuple=True)
89
+
90
+ if inds[0].numel() > 0:
91
+ if labels.dim() == 3:
92
+ bin_labels[inds[0], labels[valid_mask], inds[1], inds[2]] = 1
93
+ else:
94
+ bin_labels[inds[0], labels[valid_mask]] = 1
95
+
96
+ valid_mask = valid_mask.unsqueeze(1).expand(target_shape).float()
97
+ if label_weights is None:
98
+ bin_label_weights = valid_mask
99
+ else:
100
+ bin_label_weights = label_weights.unsqueeze(1).expand(target_shape)
101
+ bin_label_weights *= valid_mask
102
+
103
+ return bin_labels, bin_label_weights
104
+
105
+
106
+ def binary_cross_entropy(pred,
107
+ label,
108
+ weight=None,
109
+ reduction='mean',
110
+ avg_factor=None,
111
+ class_weight=None,
112
+ ignore_index=255):
113
+ """Calculate the binary CrossEntropy loss.
114
+
115
+ Args:
116
+ pred (torch.Tensor): The prediction with shape (N, 1).
117
+ label (torch.Tensor): The learning label of the prediction.
118
+ weight (torch.Tensor, optional): Sample-wise loss weight.
119
+ reduction (str, optional): The method used to reduce the loss.
120
+ Options are "none", "mean" and "sum".
121
+ avg_factor (int, optional): Average factor that is used to average
122
+ the loss. Defaults to None.
123
+ class_weight (list[float], optional): The weight for each class.
124
+ ignore_index (int | None): The label index to be ignored. Default: 255
125
+
126
+ Returns:
127
+ torch.Tensor: The calculated loss
128
+ """
129
+ if pred.dim() != label.dim():
130
+ assert (pred.dim() == 2 and label.dim() == 1) or (
131
+ pred.dim() == 4 and label.dim() == 3), \
132
+ 'Only pred shape [N, C], label shape [N] or pred shape [N, C, ' \
133
+ 'H, W], label shape [N, H, W] are supported'
134
+ label, weight = _expand_onehot_labels(label, weight, pred.shape,
135
+ ignore_index)
136
+
137
+ # weighted element-wise losses
138
+ if weight is not None:
139
+ weight = weight.float()
140
+ loss = F.binary_cross_entropy_with_logits(
141
+ pred, label.float(), pos_weight=class_weight, reduction='none')
142
+ # do the reduction for the weighted loss
143
+ loss = weight_reduce_loss(
144
+ loss, weight, reduction=reduction, avg_factor=avg_factor)
145
+
146
+ return loss
147
+
148
+
149
+ def mask_cross_entropy(pred,
150
+ target,
151
+ label,
152
+ reduction='mean',
153
+ avg_factor=None,
154
+ class_weight=None,
155
+ ignore_index=None):
156
+ """Calculate the CrossEntropy loss for masks.
157
+
158
+ Args:
159
+ pred (torch.Tensor): The prediction with shape (N, C), C is the number
160
+ of classes.
161
+ target (torch.Tensor): The learning label of the prediction.
162
+ label (torch.Tensor): ``label`` indicates the class label of the mask'
163
+ corresponding object. This will be used to select the mask in the
164
+ of the class which the object belongs to when the mask prediction
165
+ if not class-agnostic.
166
+ reduction (str, optional): The method used to reduce the loss.
167
+ Options are "none", "mean" and "sum".
168
+ avg_factor (int, optional): Average factor that is used to average
169
+ the loss. Defaults to None.
170
+ class_weight (list[float], optional): The weight for each class.
171
+ ignore_index (None): Placeholder, to be consistent with other loss.
172
+ Default: None.
173
+
174
+ Returns:
175
+ torch.Tensor: The calculated loss
176
+ """
177
+ assert ignore_index is None, 'BCE loss does not support ignore_index'
178
+ # TODO: handle these two reserved arguments
179
+ assert reduction == 'mean' and avg_factor is None
180
+ num_rois = pred.size()[0]
181
+ inds = torch.arange(0, num_rois, dtype=torch.long, device=pred.device)
182
+ pred_slice = pred[inds, label].squeeze(1)
183
+ return F.binary_cross_entropy_with_logits(
184
+ pred_slice, target, weight=class_weight, reduction='mean')[None]
185
+
186
+
187
+ class CrossEntropyLoss(nn.Module):
188
+ """CrossEntropyLoss.
189
+
190
+ Args:
191
+ use_sigmoid (bool, optional): Whether the prediction uses sigmoid
192
+ of softmax. Defaults to False.
193
+ use_mask (bool, optional): Whether to use mask cross entropy loss.
194
+ Defaults to False.
195
+ reduction (str, optional): . Defaults to 'mean'.
196
+ Options are "none", "mean" and "sum".
197
+ class_weight (list[float], optional): Weight of each class.
198
+ Defaults to None.
199
+ loss_weight (float, optional): Weight of the loss. Defaults to 1.0.
200
+ """
201
+
202
+ def __init__(self,
203
+ use_sigmoid=False,
204
+ use_mask=False,
205
+ reduction='mean',
206
+ class_weight=None,
207
+ loss_weight=1.0):
208
+ super(CrossEntropyLoss, self).__init__()
209
+ assert (use_sigmoid is False) or (use_mask is False)
210
+ self.use_sigmoid = use_sigmoid
211
+ self.use_mask = use_mask
212
+ self.reduction = reduction
213
+ self.loss_weight = loss_weight
214
+ self.class_weight = class_weight
215
+
216
+ if self.use_sigmoid:
217
+ self.cls_criterion = binary_cross_entropy
218
+ elif self.use_mask:
219
+ self.cls_criterion = mask_cross_entropy
220
+ else:
221
+ self.cls_criterion = cross_entropy
222
+
223
+ def forward(self,
224
+ cls_score,
225
+ label,
226
+ weight=None,
227
+ avg_factor=None,
228
+ reduction_override=None,
229
+ **kwargs):
230
+ """Forward function."""
231
+ assert reduction_override in (None, 'none', 'mean', 'sum')
232
+ reduction = (
233
+ reduction_override if reduction_override else self.reduction)
234
+ if self.class_weight is not None:
235
+ class_weight = cls_score.new_tensor(self.class_weight)
236
+ else:
237
+ class_weight = None
238
+ loss_cls = self.loss_weight * self.cls_criterion(
239
+ cls_score,
240
+ label,
241
+ weight,
242
+ class_weight=class_weight,
243
+ reduction=reduction,
244
+ avg_factor=avg_factor,
245
+ **kwargs)
246
+ return loss_cls
Text2Human/models/losses/segmentation_loss.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ import torch.nn.functional as F
3
+
4
+
5
+ class BCELoss(nn.Module):
6
+
7
+ def forward(self, prediction, target):
8
+ loss = F.binary_cross_entropy_with_logits(prediction, target)
9
+ return loss, {}
10
+
11
+
12
+ class BCELossWithQuant(nn.Module):
13
+
14
+ def __init__(self, codebook_weight=1.):
15
+ super().__init__()
16
+ self.codebook_weight = codebook_weight
17
+
18
+ def forward(self, qloss, target, prediction, split):
19
+ bce_loss = F.binary_cross_entropy_with_logits(prediction, target)
20
+ loss = bce_loss + self.codebook_weight * qloss
21
+ return loss, {
22
+ "{}/total_loss".format(split): loss.clone().detach().mean(),
23
+ "{}/bce_loss".format(split): bce_loss.detach().mean(),
24
+ "{}/quant_loss".format(split): qloss.detach().mean()
25
+ }
Text2Human/models/losses/vqgan_loss.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
+
5
+ def calculate_adaptive_weight(recon_loss, g_loss, last_layer, disc_weight_max):
6
+ recon_grads = torch.autograd.grad(
7
+ recon_loss, last_layer, retain_graph=True)[0]
8
+ g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
9
+
10
+ d_weight = torch.norm(recon_grads) / (torch.norm(g_grads) + 1e-4)
11
+ d_weight = torch.clamp(d_weight, 0.0, disc_weight_max).detach()
12
+ return d_weight
13
+
14
+
15
+ def adopt_weight(weight, global_step, threshold=0, value=0.):
16
+ if global_step < threshold:
17
+ weight = value
18
+ return weight
19
+
20
+
21
+ @torch.jit.script
22
+ def hinge_d_loss(logits_real, logits_fake):
23
+ loss_real = torch.mean(F.relu(1. - logits_real))
24
+ loss_fake = torch.mean(F.relu(1. + logits_fake))
25
+ d_loss = 0.5 * (loss_real + loss_fake)
26
+ return d_loss
27
+
28
+
29
+ def DiffAugment(x, policy='', channels_first=True):
30
+ if policy:
31
+ if not channels_first:
32
+ x = x.permute(0, 3, 1, 2)
33
+ for p in policy.split(','):
34
+ for f in AUGMENT_FNS[p]:
35
+ x = f(x)
36
+ if not channels_first:
37
+ x = x.permute(0, 2, 3, 1)
38
+ x = x.contiguous()
39
+ return x
40
+
41
+
42
+ def rand_brightness(x):
43
+ x = x + (
44
+ torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5)
45
+ return x
46
+
47
+
48
+ def rand_saturation(x):
49
+ x_mean = x.mean(dim=1, keepdim=True)
50
+ x = (x - x_mean) * (torch.rand(
51
+ x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) * 2) + x_mean
52
+ return x
53
+
54
+
55
+ def rand_contrast(x):
56
+ x_mean = x.mean(dim=[1, 2, 3], keepdim=True)
57
+ x = (x - x_mean) * (torch.rand(
58
+ x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5) + x_mean
59
+ return x
60
+
61
+
62
+ def rand_translation(x, ratio=0.125):
63
+ shift_x, shift_y = int(x.size(2) * ratio +
64
+ 0.5), int(x.size(3) * ratio + 0.5)
65
+ translation_x = torch.randint(
66
+ -shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device)
67
+ translation_y = torch.randint(
68
+ -shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device)
69
+ grid_batch, grid_x, grid_y = torch.meshgrid(
70
+ torch.arange(x.size(0), dtype=torch.long, device=x.device),
71
+ torch.arange(x.size(2), dtype=torch.long, device=x.device),
72
+ torch.arange(x.size(3), dtype=torch.long, device=x.device),
73
+ )
74
+ grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1)
75
+ grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1)
76
+ x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0])
77
+ x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x,
78
+ grid_y].permute(0, 3, 1, 2)
79
+ return x
80
+
81
+
82
+ def rand_cutout(x, ratio=0.5):
83
+ cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
84
+ offset_x = torch.randint(
85
+ 0,
86
+ x.size(2) + (1 - cutout_size[0] % 2),
87
+ size=[x.size(0), 1, 1],
88
+ device=x.device)
89
+ offset_y = torch.randint(
90
+ 0,
91
+ x.size(3) + (1 - cutout_size[1] % 2),
92
+ size=[x.size(0), 1, 1],
93
+ device=x.device)
94
+ grid_batch, grid_x, grid_y = torch.meshgrid(
95
+ torch.arange(x.size(0), dtype=torch.long, device=x.device),
96
+ torch.arange(cutout_size[0], dtype=torch.long, device=x.device),
97
+ torch.arange(cutout_size[1], dtype=torch.long, device=x.device),
98
+ )
99
+ grid_x = torch.clamp(
100
+ grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 1)
101
+ grid_y = torch.clamp(
102
+ grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1)
103
+ mask = torch.ones(
104
+ x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device)
105
+ mask[grid_batch, grid_x, grid_y] = 0
106
+ x = x * mask.unsqueeze(1)
107
+ return x
108
+
109
+
110
+ AUGMENT_FNS = {
111
+ 'color': [rand_brightness, rand_saturation, rand_contrast],
112
+ 'translation': [rand_translation],
113
+ 'cutout': [rand_cutout],
114
+ }
Text2Human/models/parsing_gen_model.py ADDED
@@ -0,0 +1,220 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import math
3
+ from collections import OrderedDict
4
+
5
+ import mmcv
6
+ import numpy as np
7
+ import torch
8
+ from torchvision.utils import save_image
9
+
10
+ from models.archs.fcn_arch import FCNHead
11
+ from models.archs.shape_attr_embedding_arch import ShapeAttrEmbedding
12
+ from models.archs.unet_arch import ShapeUNet
13
+ from models.losses.accuracy import accuracy
14
+ from models.losses.cross_entropy_loss import CrossEntropyLoss
15
+
16
+ logger = logging.getLogger('base')
17
+
18
+
19
+ class ParsingGenModel():
20
+ """Paring Generation model.
21
+ """
22
+
23
+ def __init__(self, opt):
24
+ self.opt = opt
25
+ self.device = torch.device('cuda')
26
+ self.is_train = opt['is_train']
27
+
28
+ self.attr_embedder = ShapeAttrEmbedding(
29
+ dim=opt['embedder_dim'],
30
+ out_dim=opt['embedder_out_dim'],
31
+ cls_num_list=opt['attr_class_num']).to(self.device)
32
+ self.parsing_encoder = ShapeUNet(
33
+ in_channels=opt['encoder_in_channels']).to(self.device)
34
+ self.parsing_decoder = FCNHead(
35
+ in_channels=opt['fc_in_channels'],
36
+ in_index=opt['fc_in_index'],
37
+ channels=opt['fc_channels'],
38
+ num_convs=opt['fc_num_convs'],
39
+ concat_input=opt['fc_concat_input'],
40
+ dropout_ratio=opt['fc_dropout_ratio'],
41
+ num_classes=opt['fc_num_classes'],
42
+ align_corners=opt['fc_align_corners'],
43
+ ).to(self.device)
44
+
45
+ self.init_training_settings()
46
+
47
+ self.palette = [[0, 0, 0], [255, 250, 250], [220, 220, 220],
48
+ [250, 235, 215], [255, 250, 205], [211, 211, 211],
49
+ [70, 130, 180], [127, 255, 212], [0, 100, 0],
50
+ [50, 205, 50], [255, 255, 0], [245, 222, 179],
51
+ [255, 140, 0], [255, 0, 0], [16, 78, 139],
52
+ [144, 238, 144], [50, 205, 174], [50, 155, 250],
53
+ [160, 140, 88], [213, 140, 88], [90, 140, 90],
54
+ [185, 210, 205], [130, 165, 180], [225, 141, 151]]
55
+
56
+ def init_training_settings(self):
57
+ optim_params = []
58
+ for v in self.attr_embedder.parameters():
59
+ if v.requires_grad:
60
+ optim_params.append(v)
61
+ for v in self.parsing_encoder.parameters():
62
+ if v.requires_grad:
63
+ optim_params.append(v)
64
+ for v in self.parsing_decoder.parameters():
65
+ if v.requires_grad:
66
+ optim_params.append(v)
67
+ # set up optimizers
68
+ self.optimizer = torch.optim.Adam(
69
+ optim_params,
70
+ self.opt['lr'],
71
+ weight_decay=self.opt['weight_decay'])
72
+ self.log_dict = OrderedDict()
73
+ self.entropy_loss = CrossEntropyLoss().to(self.device)
74
+
75
+ def feed_data(self, data):
76
+ self.pose = data['densepose'].to(self.device)
77
+ self.attr = data['attr'].to(self.device)
78
+ self.segm = data['segm'].to(self.device)
79
+
80
+ def optimize_parameters(self):
81
+ self.attr_embedder.train()
82
+ self.parsing_encoder.train()
83
+ self.parsing_decoder.train()
84
+
85
+ self.attr_embedding = self.attr_embedder(self.attr)
86
+ self.pose_enc = self.parsing_encoder(self.pose, self.attr_embedding)
87
+ self.seg_logits = self.parsing_decoder(self.pose_enc)
88
+
89
+ loss = self.entropy_loss(self.seg_logits, self.segm)
90
+
91
+ self.optimizer.zero_grad()
92
+ loss.backward()
93
+ self.optimizer.step()
94
+
95
+ self.log_dict['loss_total'] = loss
96
+
97
+ def get_vis(self, save_path):
98
+ img_cat = torch.cat([
99
+ self.pose,
100
+ self.segm,
101
+ ], dim=3).detach()
102
+ img_cat = ((img_cat + 1) / 2)
103
+
104
+ img_cat = img_cat.clamp_(0, 1)
105
+
106
+ save_image(img_cat, save_path, nrow=1, padding=4)
107
+
108
+ def inference(self, data_loader, save_dir):
109
+ self.attr_embedder.eval()
110
+ self.parsing_encoder.eval()
111
+ self.parsing_decoder.eval()
112
+
113
+ acc = 0
114
+ num = 0
115
+
116
+ for _, data in enumerate(data_loader):
117
+ pose = data['densepose'].to(self.device)
118
+ attr = data['attr'].to(self.device)
119
+ segm = data['segm'].to(self.device)
120
+ img_name = data['img_name']
121
+
122
+ num += pose.size(0)
123
+ with torch.no_grad():
124
+ attr_embedding = self.attr_embedder(attr)
125
+ pose_enc = self.parsing_encoder(pose, attr_embedding)
126
+ seg_logits = self.parsing_decoder(pose_enc)
127
+ seg_pred = seg_logits.argmax(dim=1)
128
+ acc += accuracy(seg_logits, segm)
129
+ palette_label = self.palette_result(segm.cpu().numpy())
130
+ palette_pred = self.palette_result(seg_pred.cpu().numpy())
131
+ pose_numpy = ((pose[0] + 1) / 2. * 255.).expand(
132
+ 3,
133
+ pose[0].size(1),
134
+ pose[0].size(2),
135
+ ).cpu().numpy().clip(0, 255).astype(np.uint8).transpose(1, 2, 0)
136
+ concat_result = np.concatenate(
137
+ (pose_numpy, palette_pred, palette_label), axis=1)
138
+ mmcv.imwrite(concat_result, f'{save_dir}/{img_name[0]}')
139
+
140
+ self.attr_embedder.train()
141
+ self.parsing_encoder.train()
142
+ self.parsing_decoder.train()
143
+ return (acc / num).item()
144
+
145
+ def get_current_log(self):
146
+ return self.log_dict
147
+
148
+ def update_learning_rate(self, epoch):
149
+ """Update learning rate.
150
+
151
+ Args:
152
+ current_iter (int): Current iteration.
153
+ warmup_iter (int): Warmup iter numbers. -1 for no warmup.
154
+ Default: -1.
155
+ """
156
+ lr = self.optimizer.param_groups[0]['lr']
157
+
158
+ if self.opt['lr_decay'] == 'step':
159
+ lr = self.opt['lr'] * (
160
+ self.opt['gamma']**(epoch // self.opt['step']))
161
+ elif self.opt['lr_decay'] == 'cos':
162
+ lr = self.opt['lr'] * (
163
+ 1 + math.cos(math.pi * epoch / self.opt['num_epochs'])) / 2
164
+ elif self.opt['lr_decay'] == 'linear':
165
+ lr = self.opt['lr'] * (1 - epoch / self.opt['num_epochs'])
166
+ elif self.opt['lr_decay'] == 'linear2exp':
167
+ if epoch < self.opt['turning_point'] + 1:
168
+ # learning rate decay as 95%
169
+ # at the turning point (1 / 95% = 1.0526)
170
+ lr = self.opt['lr'] * (
171
+ 1 - epoch / int(self.opt['turning_point'] * 1.0526))
172
+ else:
173
+ lr *= self.opt['gamma']
174
+ elif self.opt['lr_decay'] == 'schedule':
175
+ if epoch in self.opt['schedule']:
176
+ lr *= self.opt['gamma']
177
+ else:
178
+ raise ValueError('Unknown lr mode {}'.format(self.opt['lr_decay']))
179
+ # set learning rate
180
+ for param_group in self.optimizer.param_groups:
181
+ param_group['lr'] = lr
182
+
183
+ return lr
184
+
185
+ def save_network(self, save_path):
186
+ """Save networks.
187
+ """
188
+
189
+ save_dict = {}
190
+ save_dict['embedder'] = self.attr_embedder.state_dict()
191
+ save_dict['encoder'] = self.parsing_encoder.state_dict()
192
+ save_dict['decoder'] = self.parsing_decoder.state_dict()
193
+
194
+ torch.save(save_dict, save_path)
195
+
196
+ def load_network(self):
197
+ checkpoint = torch.load(self.opt['pretrained_parsing_gen'])
198
+
199
+ self.attr_embedder.load_state_dict(checkpoint['embedder'], strict=True)
200
+ self.attr_embedder.eval()
201
+
202
+ self.parsing_encoder.load_state_dict(
203
+ checkpoint['encoder'], strict=True)
204
+ self.parsing_encoder.eval()
205
+
206
+ self.parsing_decoder.load_state_dict(
207
+ checkpoint['decoder'], strict=True)
208
+ self.parsing_decoder.eval()
209
+
210
+ def palette_result(self, result):
211
+ seg = result[0]
212
+ palette = np.array(self.palette)
213
+ assert palette.shape[1] == 3
214
+ assert len(palette.shape) == 2
215
+ color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
216
+ for label, color in enumerate(palette):
217
+ color_seg[seg == label, :] = color
218
+ # convert to BGR
219
+ color_seg = color_seg[..., ::-1]
220
+ return color_seg
Text2Human/models/sample_model.py ADDED
@@ -0,0 +1,500 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.distributions as dists
6
+ import torch.nn.functional as F
7
+ from torchvision.utils import save_image
8
+
9
+ from models.archs.fcn_arch import FCNHead, MultiHeadFCNHead
10
+ from models.archs.shape_attr_embedding_arch import ShapeAttrEmbedding
11
+ from models.archs.transformer_arch import TransformerMultiHead
12
+ from models.archs.unet_arch import ShapeUNet, UNet
13
+ from models.archs.vqgan_arch import (Decoder, DecoderRes, Encoder,
14
+ VectorQuantizer,
15
+ VectorQuantizerSpatialTextureAware,
16
+ VectorQuantizerTexture)
17
+
18
+ logger = logging.getLogger('base')
19
+
20
+
21
+ class BaseSampleModel():
22
+ """Base Model"""
23
+
24
+ def __init__(self, opt):
25
+ self.opt = opt
26
+ self.device = torch.device('cuda')
27
+
28
+ # hierarchical VQVAE
29
+ self.decoder = Decoder(
30
+ in_channels=opt['top_in_channels'],
31
+ resolution=opt['top_resolution'],
32
+ z_channels=opt['top_z_channels'],
33
+ ch=opt['top_ch'],
34
+ out_ch=opt['top_out_ch'],
35
+ num_res_blocks=opt['top_num_res_blocks'],
36
+ attn_resolutions=opt['top_attn_resolutions'],
37
+ ch_mult=opt['top_ch_mult'],
38
+ dropout=opt['top_dropout'],
39
+ resamp_with_conv=True,
40
+ give_pre_end=False).to(self.device)
41
+ self.top_quantize = VectorQuantizerTexture(
42
+ 1024, opt['embed_dim'], beta=0.25).to(self.device)
43
+ self.top_post_quant_conv = torch.nn.Conv2d(opt['embed_dim'],
44
+ opt["top_z_channels"],
45
+ 1).to(self.device)
46
+ self.load_top_pretrain_models()
47
+
48
+ self.bot_decoder_res = DecoderRes(
49
+ in_channels=opt['bot_in_channels'],
50
+ resolution=opt['bot_resolution'],
51
+ z_channels=opt['bot_z_channels'],
52
+ ch=opt['bot_ch'],
53
+ num_res_blocks=opt['bot_num_res_blocks'],
54
+ ch_mult=opt['bot_ch_mult'],
55
+ dropout=opt['bot_dropout'],
56
+ give_pre_end=False).to(self.device)
57
+ self.bot_quantize = VectorQuantizerSpatialTextureAware(
58
+ opt['bot_n_embed'],
59
+ opt['embed_dim'],
60
+ beta=0.25,
61
+ spatial_size=opt['bot_codebook_spatial_size']).to(self.device)
62
+ self.bot_post_quant_conv = torch.nn.Conv2d(opt['embed_dim'],
63
+ opt["bot_z_channels"],
64
+ 1).to(self.device)
65
+ self.load_bot_pretrain_network()
66
+
67
+ # top -> bot prediction
68
+ self.index_pred_guidance_encoder = UNet(
69
+ in_channels=opt['index_pred_encoder_in_channels']).to(self.device)
70
+ self.index_pred_decoder = MultiHeadFCNHead(
71
+ in_channels=opt['index_pred_fc_in_channels'],
72
+ in_index=opt['index_pred_fc_in_index'],
73
+ channels=opt['index_pred_fc_channels'],
74
+ num_convs=opt['index_pred_fc_num_convs'],
75
+ concat_input=opt['index_pred_fc_concat_input'],
76
+ dropout_ratio=opt['index_pred_fc_dropout_ratio'],
77
+ num_classes=opt['index_pred_fc_num_classes'],
78
+ align_corners=opt['index_pred_fc_align_corners'],
79
+ num_head=18).to(self.device)
80
+ self.load_index_pred_network()
81
+
82
+ # VAE for segmentation mask
83
+ self.segm_encoder = Encoder(
84
+ ch=opt['segm_ch'],
85
+ num_res_blocks=opt['segm_num_res_blocks'],
86
+ attn_resolutions=opt['segm_attn_resolutions'],
87
+ ch_mult=opt['segm_ch_mult'],
88
+ in_channels=opt['segm_in_channels'],
89
+ resolution=opt['segm_resolution'],
90
+ z_channels=opt['segm_z_channels'],
91
+ double_z=opt['segm_double_z'],
92
+ dropout=opt['segm_dropout']).to(self.device)
93
+ self.segm_quantizer = VectorQuantizer(
94
+ opt['segm_n_embed'],
95
+ opt['segm_embed_dim'],
96
+ beta=0.25,
97
+ sane_index_shape=True).to(self.device)
98
+ self.segm_quant_conv = torch.nn.Conv2d(opt["segm_z_channels"],
99
+ opt['segm_embed_dim'],
100
+ 1).to(self.device)
101
+ self.load_pretrained_segm_token()
102
+
103
+ # define sampler
104
+ self.sampler_fn = TransformerMultiHead(
105
+ codebook_size=opt['codebook_size'],
106
+ segm_codebook_size=opt['segm_codebook_size'],
107
+ texture_codebook_size=opt['texture_codebook_size'],
108
+ bert_n_emb=opt['bert_n_emb'],
109
+ bert_n_layers=opt['bert_n_layers'],
110
+ bert_n_head=opt['bert_n_head'],
111
+ block_size=opt['block_size'],
112
+ latent_shape=opt['latent_shape'],
113
+ embd_pdrop=opt['embd_pdrop'],
114
+ resid_pdrop=opt['resid_pdrop'],
115
+ attn_pdrop=opt['attn_pdrop'],
116
+ num_head=opt['num_head']).to(self.device)
117
+ self.load_sampler_pretrained_network()
118
+
119
+ self.shape = tuple(opt['latent_shape'])
120
+
121
+ self.mask_id = opt['codebook_size']
122
+ self.sample_steps = opt['sample_steps']
123
+
124
+ def load_top_pretrain_models(self):
125
+ # load pretrained vqgan
126
+ top_vae_checkpoint = torch.load(self.opt['top_vae_path'])
127
+
128
+ self.decoder.load_state_dict(
129
+ top_vae_checkpoint['decoder'], strict=True)
130
+ self.top_quantize.load_state_dict(
131
+ top_vae_checkpoint['quantize'], strict=True)
132
+ self.top_post_quant_conv.load_state_dict(
133
+ top_vae_checkpoint['post_quant_conv'], strict=True)
134
+
135
+ self.decoder.eval()
136
+ self.top_quantize.eval()
137
+ self.top_post_quant_conv.eval()
138
+
139
+ def load_bot_pretrain_network(self):
140
+ checkpoint = torch.load(self.opt['bot_vae_path'])
141
+ self.bot_decoder_res.load_state_dict(
142
+ checkpoint['bot_decoder_res'], strict=True)
143
+ self.decoder.load_state_dict(checkpoint['decoder'], strict=True)
144
+ self.bot_quantize.load_state_dict(
145
+ checkpoint['bot_quantize'], strict=True)
146
+ self.bot_post_quant_conv.load_state_dict(
147
+ checkpoint['bot_post_quant_conv'], strict=True)
148
+
149
+ self.bot_decoder_res.eval()
150
+ self.decoder.eval()
151
+ self.bot_quantize.eval()
152
+ self.bot_post_quant_conv.eval()
153
+
154
+ def load_pretrained_segm_token(self):
155
+ # load pretrained vqgan for segmentation mask
156
+ segm_token_checkpoint = torch.load(self.opt['segm_token_path'])
157
+ self.segm_encoder.load_state_dict(
158
+ segm_token_checkpoint['encoder'], strict=True)
159
+ self.segm_quantizer.load_state_dict(
160
+ segm_token_checkpoint['quantize'], strict=True)
161
+ self.segm_quant_conv.load_state_dict(
162
+ segm_token_checkpoint['quant_conv'], strict=True)
163
+
164
+ self.segm_encoder.eval()
165
+ self.segm_quantizer.eval()
166
+ self.segm_quant_conv.eval()
167
+
168
+ def load_index_pred_network(self):
169
+ checkpoint = torch.load(self.opt['pretrained_index_network'])
170
+ self.index_pred_guidance_encoder.load_state_dict(
171
+ checkpoint['guidance_encoder'], strict=True)
172
+ self.index_pred_decoder.load_state_dict(
173
+ checkpoint['index_decoder'], strict=True)
174
+
175
+ self.index_pred_guidance_encoder.eval()
176
+ self.index_pred_decoder.eval()
177
+
178
+ def load_sampler_pretrained_network(self):
179
+ checkpoint = torch.load(self.opt['pretrained_sampler'])
180
+ self.sampler_fn.load_state_dict(checkpoint, strict=True)
181
+ self.sampler_fn.eval()
182
+
183
+ def bot_index_prediction(self, feature_top, texture_mask):
184
+ self.index_pred_guidance_encoder.eval()
185
+ self.index_pred_decoder.eval()
186
+
187
+ texture_tokens = F.interpolate(
188
+ texture_mask, (32, 16), mode='nearest').view(self.batch_size,
189
+ -1).long()
190
+
191
+ texture_mask_flatten = texture_tokens.view(-1)
192
+ min_encodings_indices_list = [
193
+ torch.full(
194
+ texture_mask_flatten.size(),
195
+ fill_value=-1,
196
+ dtype=torch.long,
197
+ device=texture_mask_flatten.device) for _ in range(18)
198
+ ]
199
+ with torch.no_grad():
200
+ feature_enc = self.index_pred_guidance_encoder(feature_top)
201
+ memory_logits_list = self.index_pred_decoder(feature_enc)
202
+ for codebook_idx, memory_logits in enumerate(memory_logits_list):
203
+ region_of_interest = texture_mask_flatten == codebook_idx
204
+ if torch.sum(region_of_interest) > 0:
205
+ memory_indices_pred = memory_logits.argmax(dim=1).view(-1)
206
+ memory_indices_pred = memory_indices_pred
207
+ min_encodings_indices_list[codebook_idx][
208
+ region_of_interest] = memory_indices_pred[
209
+ region_of_interest]
210
+ min_encodings_indices_return_list = [
211
+ min_encodings_indices.view((1, 32, 16))
212
+ for min_encodings_indices in min_encodings_indices_list
213
+ ]
214
+
215
+ return min_encodings_indices_return_list
216
+
217
+ def sample_and_refine(self, save_dir=None, img_name=None):
218
+ # sample 32x16 features indices
219
+ sampled_top_indices_list = self.sample_fn(
220
+ temp=1, sample_steps=self.sample_steps)
221
+
222
+ for sample_idx in range(self.batch_size):
223
+ sample_indices = [
224
+ sampled_indices_cur[sample_idx:sample_idx + 1]
225
+ for sampled_indices_cur in sampled_top_indices_list
226
+ ]
227
+ top_quant = self.top_quantize.get_codebook_entry(
228
+ sample_indices, self.texture_mask[sample_idx:sample_idx + 1],
229
+ (sample_indices[0].size(0), self.shape[0], self.shape[1],
230
+ self.opt["top_z_channels"]))
231
+
232
+ top_quant = self.top_post_quant_conv(top_quant)
233
+
234
+ bot_indices_list = self.bot_index_prediction(
235
+ top_quant, self.texture_mask[sample_idx:sample_idx + 1])
236
+
237
+ quant_bot = self.bot_quantize.get_codebook_entry(
238
+ bot_indices_list, self.texture_mask[sample_idx:sample_idx + 1],
239
+ (bot_indices_list[0].size(0), bot_indices_list[0].size(1),
240
+ bot_indices_list[0].size(2),
241
+ self.opt["bot_z_channels"])) #.permute(0, 3, 1, 2)
242
+ quant_bot = self.bot_post_quant_conv(quant_bot)
243
+ bot_dec_res = self.bot_decoder_res(quant_bot)
244
+
245
+ dec = self.decoder(top_quant, bot_h=bot_dec_res)
246
+
247
+ dec = ((dec + 1) / 2)
248
+ dec = dec.clamp_(0, 1)
249
+ if save_dir is None and img_name is None:
250
+ return dec
251
+ else:
252
+ save_image(
253
+ dec,
254
+ f'{save_dir}/{img_name[sample_idx]}',
255
+ nrow=1,
256
+ padding=4)
257
+
258
+ def sample_fn(self, temp=1.0, sample_steps=None):
259
+ self.sampler_fn.eval()
260
+
261
+ x_t = torch.ones((self.batch_size, np.prod(self.shape)),
262
+ device=self.device).long() * self.mask_id
263
+ unmasked = torch.zeros_like(x_t, device=self.device).bool()
264
+ sample_steps = list(range(1, sample_steps + 1))
265
+
266
+ texture_tokens = F.interpolate(
267
+ self.texture_mask, (32, 16),
268
+ mode='nearest').view(self.batch_size, -1).long()
269
+
270
+ texture_mask_flatten = texture_tokens.view(-1)
271
+
272
+ # min_encodings_indices_list would be used to visualize the image
273
+ min_encodings_indices_list = [
274
+ torch.full(
275
+ texture_mask_flatten.size(),
276
+ fill_value=-1,
277
+ dtype=torch.long,
278
+ device=texture_mask_flatten.device) for _ in range(18)
279
+ ]
280
+
281
+ for t in reversed(sample_steps):
282
+ t = torch.full((self.batch_size, ),
283
+ t,
284
+ device=self.device,
285
+ dtype=torch.long)
286
+
287
+ # where to unmask
288
+ changes = torch.rand(
289
+ x_t.shape, device=self.device) < 1 / t.float().unsqueeze(-1)
290
+ # don't unmask somewhere already unmasked
291
+ changes = torch.bitwise_xor(changes,
292
+ torch.bitwise_and(changes, unmasked))
293
+ # update mask with changes
294
+ unmasked = torch.bitwise_or(unmasked, changes)
295
+
296
+ x_0_logits_list = self.sampler_fn(
297
+ x_t, self.segm_tokens, texture_tokens, t=t)
298
+
299
+ changes_flatten = changes.view(-1)
300
+ ori_shape = x_t.shape # [b, h*w]
301
+ x_t = x_t.view(-1) # [b*h*w]
302
+ for codebook_idx, x_0_logits in enumerate(x_0_logits_list):
303
+ if torch.sum(texture_mask_flatten[changes_flatten] ==
304
+ codebook_idx) > 0:
305
+ # scale by temperature
306
+ x_0_logits = x_0_logits / temp
307
+ x_0_dist = dists.Categorical(logits=x_0_logits)
308
+ x_0_hat = x_0_dist.sample().long()
309
+ x_0_hat = x_0_hat.view(-1)
310
+
311
+ # only replace the changed indices with corresponding codebook_idx
312
+ changes_segm = torch.bitwise_and(
313
+ changes_flatten, texture_mask_flatten == codebook_idx)
314
+
315
+ # x_t would be the input to the transformer, so the index range should be continual one
316
+ x_t[changes_segm] = x_0_hat[
317
+ changes_segm] + 1024 * codebook_idx
318
+ min_encodings_indices_list[codebook_idx][
319
+ changes_segm] = x_0_hat[changes_segm]
320
+
321
+ x_t = x_t.view(ori_shape) # [b, h*w]
322
+
323
+ min_encodings_indices_return_list = [
324
+ min_encodings_indices.view(ori_shape)
325
+ for min_encodings_indices in min_encodings_indices_list
326
+ ]
327
+
328
+ self.sampler_fn.train()
329
+
330
+ return min_encodings_indices_return_list
331
+
332
+ @torch.no_grad()
333
+ def get_quantized_segm(self, segm):
334
+ segm_one_hot = F.one_hot(
335
+ segm.squeeze(1).long(),
336
+ num_classes=self.opt['segm_num_segm_classes']).permute(
337
+ 0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
338
+ encoded_segm_mask = self.segm_encoder(segm_one_hot)
339
+ encoded_segm_mask = self.segm_quant_conv(encoded_segm_mask)
340
+ _, _, [_, _, segm_tokens] = self.segm_quantizer(encoded_segm_mask)
341
+
342
+ return segm_tokens
343
+
344
+
345
+ class SampleFromParsingModel(BaseSampleModel):
346
+ """SampleFromParsing model.
347
+ """
348
+
349
+ def feed_data(self, data):
350
+ self.segm = data['segm'].to(self.device)
351
+ self.texture_mask = data['texture_mask'].to(self.device)
352
+ self.batch_size = self.segm.size(0)
353
+
354
+ self.segm_tokens = self.get_quantized_segm(self.segm)
355
+ self.segm_tokens = self.segm_tokens.view(self.batch_size, -1)
356
+
357
+ def inference(self, data_loader, save_dir):
358
+ for _, data in enumerate(data_loader):
359
+ img_name = data['img_name']
360
+ self.feed_data(data)
361
+ with torch.no_grad():
362
+ self.sample_and_refine(save_dir, img_name)
363
+
364
+
365
+ class SampleFromPoseModel(BaseSampleModel):
366
+ """SampleFromPose model.
367
+ """
368
+
369
+ def __init__(self, opt):
370
+ super().__init__(opt)
371
+ # pose-to-parsing
372
+ self.shape_attr_embedder = ShapeAttrEmbedding(
373
+ dim=opt['shape_embedder_dim'],
374
+ out_dim=opt['shape_embedder_out_dim'],
375
+ cls_num_list=opt['shape_attr_class_num']).to(self.device)
376
+ self.shape_parsing_encoder = ShapeUNet(
377
+ in_channels=opt['shape_encoder_in_channels']).to(self.device)
378
+ self.shape_parsing_decoder = FCNHead(
379
+ in_channels=opt['shape_fc_in_channels'],
380
+ in_index=opt['shape_fc_in_index'],
381
+ channels=opt['shape_fc_channels'],
382
+ num_convs=opt['shape_fc_num_convs'],
383
+ concat_input=opt['shape_fc_concat_input'],
384
+ dropout_ratio=opt['shape_fc_dropout_ratio'],
385
+ num_classes=opt['shape_fc_num_classes'],
386
+ align_corners=opt['shape_fc_align_corners'],
387
+ ).to(self.device)
388
+ self.load_shape_generation_models()
389
+
390
+ self.palette = [[0, 0, 0], [255, 250, 250], [220, 220, 220],
391
+ [250, 235, 215], [255, 250, 205], [211, 211, 211],
392
+ [70, 130, 180], [127, 255, 212], [0, 100, 0],
393
+ [50, 205, 50], [255, 255, 0], [245, 222, 179],
394
+ [255, 140, 0], [255, 0, 0], [16, 78, 139],
395
+ [144, 238, 144], [50, 205, 174], [50, 155, 250],
396
+ [160, 140, 88], [213, 140, 88], [90, 140, 90],
397
+ [185, 210, 205], [130, 165, 180], [225, 141, 151]]
398
+
399
+ def load_shape_generation_models(self):
400
+ checkpoint = torch.load(self.opt['pretrained_parsing_gen'])
401
+
402
+ self.shape_attr_embedder.load_state_dict(
403
+ checkpoint['embedder'], strict=True)
404
+ self.shape_attr_embedder.eval()
405
+
406
+ self.shape_parsing_encoder.load_state_dict(
407
+ checkpoint['encoder'], strict=True)
408
+ self.shape_parsing_encoder.eval()
409
+
410
+ self.shape_parsing_decoder.load_state_dict(
411
+ checkpoint['decoder'], strict=True)
412
+ self.shape_parsing_decoder.eval()
413
+
414
+ def feed_data(self, data):
415
+ self.pose = data['densepose'].to(self.device)
416
+ self.batch_size = self.pose.size(0)
417
+
418
+ self.shape_attr = data['shape_attr'].to(self.device)
419
+ self.upper_fused_attr = data['upper_fused_attr'].to(self.device)
420
+ self.lower_fused_attr = data['lower_fused_attr'].to(self.device)
421
+ self.outer_fused_attr = data['outer_fused_attr'].to(self.device)
422
+
423
+ def inference(self, data_loader, save_dir):
424
+ for _, data in enumerate(data_loader):
425
+ img_name = data['img_name']
426
+ self.feed_data(data)
427
+ with torch.no_grad():
428
+ self.generate_parsing_map()
429
+ self.generate_quantized_segm()
430
+ self.generate_texture_map()
431
+ self.sample_and_refine(save_dir, img_name)
432
+
433
+ def generate_parsing_map(self):
434
+ with torch.no_grad():
435
+ attr_embedding = self.shape_attr_embedder(self.shape_attr)
436
+ pose_enc = self.shape_parsing_encoder(self.pose, attr_embedding)
437
+ seg_logits = self.shape_parsing_decoder(pose_enc)
438
+ self.segm = seg_logits.argmax(dim=1)
439
+ self.segm = self.segm.unsqueeze(1)
440
+
441
+ def generate_quantized_segm(self):
442
+ self.segm_tokens = self.get_quantized_segm(self.segm)
443
+ self.segm_tokens = self.segm_tokens.view(self.batch_size, -1)
444
+
445
+ def generate_texture_map(self):
446
+ upper_cls = [1., 4.]
447
+ lower_cls = [3., 5., 21.]
448
+ outer_cls = [2.]
449
+
450
+ mask_batch = []
451
+ for idx in range(self.batch_size):
452
+ mask = torch.zeros_like(self.segm[idx])
453
+ upper_fused_attr = self.upper_fused_attr[idx]
454
+ lower_fused_attr = self.lower_fused_attr[idx]
455
+ outer_fused_attr = self.outer_fused_attr[idx]
456
+ if upper_fused_attr != 17:
457
+ for cls in upper_cls:
458
+ mask[self.segm[idx] == cls] = upper_fused_attr + 1
459
+
460
+ if lower_fused_attr != 17:
461
+ for cls in lower_cls:
462
+ mask[self.segm[idx] == cls] = lower_fused_attr + 1
463
+
464
+ if outer_fused_attr != 17:
465
+ for cls in outer_cls:
466
+ mask[self.segm[idx] == cls] = outer_fused_attr + 1
467
+
468
+ mask_batch.append(mask)
469
+ self.texture_mask = torch.stack(mask_batch, dim=0).to(torch.float32)
470
+
471
+ def feed_pose_data(self, pose_img):
472
+ # for ui demo
473
+
474
+ self.pose = pose_img.to(self.device)
475
+ self.batch_size = self.pose.size(0)
476
+
477
+ def feed_shape_attributes(self, shape_attr):
478
+ # for ui demo
479
+
480
+ self.shape_attr = shape_attr.to(self.device)
481
+
482
+ def feed_texture_attributes(self, texture_attr):
483
+ # for ui demo
484
+
485
+ self.upper_fused_attr = texture_attr[0].unsqueeze(0).to(self.device)
486
+ self.lower_fused_attr = texture_attr[1].unsqueeze(0).to(self.device)
487
+ self.outer_fused_attr = texture_attr[2].unsqueeze(0).to(self.device)
488
+
489
+ def palette_result(self, result):
490
+
491
+ seg = result[0]
492
+ palette = np.array(self.palette)
493
+ assert palette.shape[1] == 3
494
+ assert len(palette.shape) == 2
495
+ color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
496
+ for label, color in enumerate(palette):
497
+ color_seg[seg == label, :] = color
498
+ # convert to BGR
499
+ # color_seg = color_seg[..., ::-1]
500
+ return color_seg
Text2Human/models/transformer_model.py ADDED
@@ -0,0 +1,482 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import math
3
+ from collections import OrderedDict
4
+
5
+ import numpy as np
6
+ import torch
7
+ import torch.distributions as dists
8
+ import torch.nn.functional as F
9
+ from torchvision.utils import save_image
10
+
11
+ from models.archs.transformer_arch import TransformerMultiHead
12
+ from models.archs.vqgan_arch import (Decoder, Encoder, VectorQuantizer,
13
+ VectorQuantizerTexture)
14
+
15
+ logger = logging.getLogger('base')
16
+
17
+
18
+ class TransformerTextureAwareModel():
19
+ """Texture-Aware Diffusion based Transformer model.
20
+ """
21
+
22
+ def __init__(self, opt):
23
+ self.opt = opt
24
+ self.device = torch.device('cuda')
25
+ self.is_train = opt['is_train']
26
+
27
+ # VQVAE for image
28
+ self.img_encoder = Encoder(
29
+ ch=opt['img_ch'],
30
+ num_res_blocks=opt['img_num_res_blocks'],
31
+ attn_resolutions=opt['img_attn_resolutions'],
32
+ ch_mult=opt['img_ch_mult'],
33
+ in_channels=opt['img_in_channels'],
34
+ resolution=opt['img_resolution'],
35
+ z_channels=opt['img_z_channels'],
36
+ double_z=opt['img_double_z'],
37
+ dropout=opt['img_dropout']).to(self.device)
38
+ self.img_decoder = Decoder(
39
+ in_channels=opt['img_in_channels'],
40
+ resolution=opt['img_resolution'],
41
+ z_channels=opt['img_z_channels'],
42
+ ch=opt['img_ch'],
43
+ out_ch=opt['img_out_ch'],
44
+ num_res_blocks=opt['img_num_res_blocks'],
45
+ attn_resolutions=opt['img_attn_resolutions'],
46
+ ch_mult=opt['img_ch_mult'],
47
+ dropout=opt['img_dropout'],
48
+ resamp_with_conv=True,
49
+ give_pre_end=False).to(self.device)
50
+ self.img_quantizer = VectorQuantizerTexture(
51
+ opt['img_n_embed'], opt['img_embed_dim'],
52
+ beta=0.25).to(self.device)
53
+ self.img_quant_conv = torch.nn.Conv2d(opt["img_z_channels"],
54
+ opt['img_embed_dim'],
55
+ 1).to(self.device)
56
+ self.img_post_quant_conv = torch.nn.Conv2d(opt['img_embed_dim'],
57
+ opt["img_z_channels"],
58
+ 1).to(self.device)
59
+ self.load_pretrained_image_vae()
60
+
61
+ # VAE for segmentation mask
62
+ self.segm_encoder = Encoder(
63
+ ch=opt['segm_ch'],
64
+ num_res_blocks=opt['segm_num_res_blocks'],
65
+ attn_resolutions=opt['segm_attn_resolutions'],
66
+ ch_mult=opt['segm_ch_mult'],
67
+ in_channels=opt['segm_in_channels'],
68
+ resolution=opt['segm_resolution'],
69
+ z_channels=opt['segm_z_channels'],
70
+ double_z=opt['segm_double_z'],
71
+ dropout=opt['segm_dropout']).to(self.device)
72
+ self.segm_quantizer = VectorQuantizer(
73
+ opt['segm_n_embed'],
74
+ opt['segm_embed_dim'],
75
+ beta=0.25,
76
+ sane_index_shape=True).to(self.device)
77
+ self.segm_quant_conv = torch.nn.Conv2d(opt["segm_z_channels"],
78
+ opt['segm_embed_dim'],
79
+ 1).to(self.device)
80
+ self.load_pretrained_segm_vae()
81
+
82
+ # define sampler
83
+ self._denoise_fn = TransformerMultiHead(
84
+ codebook_size=opt['codebook_size'],
85
+ segm_codebook_size=opt['segm_codebook_size'],
86
+ texture_codebook_size=opt['texture_codebook_size'],
87
+ bert_n_emb=opt['bert_n_emb'],
88
+ bert_n_layers=opt['bert_n_layers'],
89
+ bert_n_head=opt['bert_n_head'],
90
+ block_size=opt['block_size'],
91
+ latent_shape=opt['latent_shape'],
92
+ embd_pdrop=opt['embd_pdrop'],
93
+ resid_pdrop=opt['resid_pdrop'],
94
+ attn_pdrop=opt['attn_pdrop'],
95
+ num_head=opt['num_head']).to(self.device)
96
+
97
+ self.num_classes = opt['codebook_size']
98
+ self.shape = tuple(opt['latent_shape'])
99
+ self.num_timesteps = 1000
100
+
101
+ self.mask_id = opt['codebook_size']
102
+ self.loss_type = opt['loss_type']
103
+ self.mask_schedule = opt['mask_schedule']
104
+
105
+ self.sample_steps = opt['sample_steps']
106
+
107
+ self.init_training_settings()
108
+
109
+ def load_pretrained_image_vae(self):
110
+ # load pretrained vqgan for segmentation mask
111
+ img_ae_checkpoint = torch.load(self.opt['img_ae_path'])
112
+ self.img_encoder.load_state_dict(
113
+ img_ae_checkpoint['encoder'], strict=True)
114
+ self.img_decoder.load_state_dict(
115
+ img_ae_checkpoint['decoder'], strict=True)
116
+ self.img_quantizer.load_state_dict(
117
+ img_ae_checkpoint['quantize'], strict=True)
118
+ self.img_quant_conv.load_state_dict(
119
+ img_ae_checkpoint['quant_conv'], strict=True)
120
+ self.img_post_quant_conv.load_state_dict(
121
+ img_ae_checkpoint['post_quant_conv'], strict=True)
122
+ self.img_encoder.eval()
123
+ self.img_decoder.eval()
124
+ self.img_quantizer.eval()
125
+ self.img_quant_conv.eval()
126
+ self.img_post_quant_conv.eval()
127
+
128
+ def load_pretrained_segm_vae(self):
129
+ # load pretrained vqgan for segmentation mask
130
+ segm_ae_checkpoint = torch.load(self.opt['segm_ae_path'])
131
+ self.segm_encoder.load_state_dict(
132
+ segm_ae_checkpoint['encoder'], strict=True)
133
+ self.segm_quantizer.load_state_dict(
134
+ segm_ae_checkpoint['quantize'], strict=True)
135
+ self.segm_quant_conv.load_state_dict(
136
+ segm_ae_checkpoint['quant_conv'], strict=True)
137
+ self.segm_encoder.eval()
138
+ self.segm_quantizer.eval()
139
+ self.segm_quant_conv.eval()
140
+
141
+ def init_training_settings(self):
142
+ optim_params = []
143
+ for v in self._denoise_fn.parameters():
144
+ if v.requires_grad:
145
+ optim_params.append(v)
146
+ # set up optimizer
147
+ self.optimizer = torch.optim.Adam(
148
+ optim_params,
149
+ self.opt['lr'],
150
+ weight_decay=self.opt['weight_decay'])
151
+ self.log_dict = OrderedDict()
152
+
153
+ @torch.no_grad()
154
+ def get_quantized_img(self, image, texture_mask):
155
+ encoded_img = self.img_encoder(image)
156
+ encoded_img = self.img_quant_conv(encoded_img)
157
+
158
+ # img_tokens_input is the continual index for the input of transformer
159
+ # img_tokens_gt_list is the index for 18 texture-aware codebooks respectively
160
+ _, _, [_, img_tokens_input, img_tokens_gt_list
161
+ ] = self.img_quantizer(encoded_img, texture_mask)
162
+
163
+ # reshape the tokens
164
+ b = image.size(0)
165
+ img_tokens_input = img_tokens_input.view(b, -1)
166
+ img_tokens_gt_return_list = [
167
+ img_tokens_gt.view(b, -1) for img_tokens_gt in img_tokens_gt_list
168
+ ]
169
+
170
+ return img_tokens_input, img_tokens_gt_return_list
171
+
172
+ @torch.no_grad()
173
+ def decode(self, quant):
174
+ quant = self.img_post_quant_conv(quant)
175
+ dec = self.img_decoder(quant)
176
+ return dec
177
+
178
+ @torch.no_grad()
179
+ def decode_image_indices(self, indices_list, texture_mask):
180
+ quant = self.img_quantizer.get_codebook_entry(
181
+ indices_list, texture_mask,
182
+ (indices_list[0].size(0), self.shape[0], self.shape[1],
183
+ self.opt["img_z_channels"]))
184
+ dec = self.decode(quant)
185
+
186
+ return dec
187
+
188
+ def sample_time(self, b, device, method='uniform'):
189
+ if method == 'importance':
190
+ if not (self.Lt_count > 10).all():
191
+ return self.sample_time(b, device, method='uniform')
192
+
193
+ Lt_sqrt = torch.sqrt(self.Lt_history + 1e-10) + 0.0001
194
+ Lt_sqrt[0] = Lt_sqrt[1] # Overwrite decoder term with L1.
195
+ pt_all = Lt_sqrt / Lt_sqrt.sum()
196
+
197
+ t = torch.multinomial(pt_all, num_samples=b, replacement=True)
198
+
199
+ pt = pt_all.gather(dim=0, index=t)
200
+
201
+ return t, pt
202
+
203
+ elif method == 'uniform':
204
+ t = torch.randint(
205
+ 1, self.num_timesteps + 1, (b, ), device=device).long()
206
+ pt = torch.ones_like(t).float() / self.num_timesteps
207
+ return t, pt
208
+
209
+ else:
210
+ raise ValueError
211
+
212
+ def q_sample(self, x_0, x_0_gt_list, t):
213
+ # samples q(x_t | x_0)
214
+ # randomly set token to mask with probability t/T
215
+ # x_t, x_0_ignore = x_0.clone(), x_0.clone()
216
+ x_t = x_0.clone()
217
+
218
+ mask = torch.rand_like(x_t.float()) < (
219
+ t.float().unsqueeze(-1) / self.num_timesteps)
220
+ x_t[mask] = self.mask_id
221
+ # x_0_ignore[torch.bitwise_not(mask)] = -1
222
+
223
+ # for every gt token list, we also need to do the mask
224
+ x_0_gt_ignore_list = []
225
+ for x_0_gt in x_0_gt_list:
226
+ x_0_gt_ignore = x_0_gt.clone()
227
+ x_0_gt_ignore[torch.bitwise_not(mask)] = -1
228
+ x_0_gt_ignore_list.append(x_0_gt_ignore)
229
+
230
+ return x_t, x_0_gt_ignore_list, mask
231
+
232
+ def _train_loss(self, x_0, x_0_gt_list):
233
+ b, device = x_0.size(0), x_0.device
234
+
235
+ # choose what time steps to compute loss at
236
+ t, pt = self.sample_time(b, device, 'uniform')
237
+
238
+ # make x noisy and denoise
239
+ if self.mask_schedule == 'random':
240
+ x_t, x_0_gt_ignore_list, mask = self.q_sample(
241
+ x_0=x_0, x_0_gt_list=x_0_gt_list, t=t)
242
+ else:
243
+ raise NotImplementedError
244
+
245
+ # sample p(x_0 | x_t)
246
+ x_0_hat_logits_list = self._denoise_fn(
247
+ x_t, self.segm_tokens, self.texture_tokens, t=t)
248
+
249
+ # Always compute ELBO for comparison purposes
250
+ cross_entropy_loss = 0
251
+ for x_0_hat_logits, x_0_gt_ignore in zip(x_0_hat_logits_list,
252
+ x_0_gt_ignore_list):
253
+ cross_entropy_loss += F.cross_entropy(
254
+ x_0_hat_logits.permute(0, 2, 1),
255
+ x_0_gt_ignore,
256
+ ignore_index=-1,
257
+ reduction='none').sum(1)
258
+ vb_loss = cross_entropy_loss / t
259
+ vb_loss = vb_loss / pt
260
+ vb_loss = vb_loss / (math.log(2) * x_0.shape[1:].numel())
261
+ if self.loss_type == 'elbo':
262
+ loss = vb_loss
263
+ elif self.loss_type == 'mlm':
264
+ denom = mask.float().sum(1)
265
+ denom[denom == 0] = 1 # prevent divide by 0 errors.
266
+ loss = cross_entropy_loss / denom
267
+ elif self.loss_type == 'reweighted_elbo':
268
+ weight = (1 - (t / self.num_timesteps))
269
+ loss = weight * cross_entropy_loss
270
+ loss = loss / (math.log(2) * x_0.shape[1:].numel())
271
+ else:
272
+ raise ValueError
273
+
274
+ return loss.mean(), vb_loss.mean()
275
+
276
+ def feed_data(self, data):
277
+ self.image = data['image'].to(self.device)
278
+ self.segm = data['segm'].to(self.device)
279
+ self.texture_mask = data['texture_mask'].to(self.device)
280
+ self.input_indices, self.gt_indices_list = self.get_quantized_img(
281
+ self.image, self.texture_mask)
282
+
283
+ self.texture_tokens = F.interpolate(
284
+ self.texture_mask, size=self.shape,
285
+ mode='nearest').view(self.image.size(0), -1).long()
286
+
287
+ self.segm_tokens = self.get_quantized_segm(self.segm)
288
+ self.segm_tokens = self.segm_tokens.view(self.image.size(0), -1)
289
+
290
+ def optimize_parameters(self):
291
+ self._denoise_fn.train()
292
+
293
+ loss, vb_loss = self._train_loss(self.input_indices,
294
+ self.gt_indices_list)
295
+
296
+ self.optimizer.zero_grad()
297
+ loss.backward()
298
+ self.optimizer.step()
299
+
300
+ self.log_dict['loss'] = loss
301
+ self.log_dict['vb_loss'] = vb_loss
302
+
303
+ self._denoise_fn.eval()
304
+
305
+ @torch.no_grad()
306
+ def get_quantized_segm(self, segm):
307
+ segm_one_hot = F.one_hot(
308
+ segm.squeeze(1).long(),
309
+ num_classes=self.opt['segm_num_segm_classes']).permute(
310
+ 0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
311
+ encoded_segm_mask = self.segm_encoder(segm_one_hot)
312
+ encoded_segm_mask = self.segm_quant_conv(encoded_segm_mask)
313
+ _, _, [_, _, segm_tokens] = self.segm_quantizer(encoded_segm_mask)
314
+
315
+ return segm_tokens
316
+
317
+ def sample_fn(self, temp=1.0, sample_steps=None):
318
+ self._denoise_fn.eval()
319
+
320
+ b, device = self.image.size(0), 'cuda'
321
+ x_t = torch.ones(
322
+ (b, np.prod(self.shape)), device=device).long() * self.mask_id
323
+ unmasked = torch.zeros_like(x_t, device=device).bool()
324
+ sample_steps = list(range(1, sample_steps + 1))
325
+
326
+ texture_mask_flatten = self.texture_tokens.view(-1)
327
+
328
+ # min_encodings_indices_list would be used to visualize the image
329
+ min_encodings_indices_list = [
330
+ torch.full(
331
+ texture_mask_flatten.size(),
332
+ fill_value=-1,
333
+ dtype=torch.long,
334
+ device=texture_mask_flatten.device) for _ in range(18)
335
+ ]
336
+
337
+ for t in reversed(sample_steps):
338
+ print(f'Sample timestep {t:4d}', end='\r')
339
+ t = torch.full((b, ), t, device=device, dtype=torch.long)
340
+
341
+ # where to unmask
342
+ changes = torch.rand(
343
+ x_t.shape, device=device) < 1 / t.float().unsqueeze(-1)
344
+ # don't unmask somewhere already unmasked
345
+ changes = torch.bitwise_xor(changes,
346
+ torch.bitwise_and(changes, unmasked))
347
+ # update mask with changes
348
+ unmasked = torch.bitwise_or(unmasked, changes)
349
+
350
+ x_0_logits_list = self._denoise_fn(
351
+ x_t, self.segm_tokens, self.texture_tokens, t=t)
352
+
353
+ changes_flatten = changes.view(-1)
354
+ ori_shape = x_t.shape # [b, h*w]
355
+ x_t = x_t.view(-1) # [b*h*w]
356
+ for codebook_idx, x_0_logits in enumerate(x_0_logits_list):
357
+ if torch.sum(texture_mask_flatten[changes_flatten] ==
358
+ codebook_idx) > 0:
359
+ # scale by temperature
360
+ x_0_logits = x_0_logits / temp
361
+ x_0_dist = dists.Categorical(logits=x_0_logits)
362
+ x_0_hat = x_0_dist.sample().long()
363
+ x_0_hat = x_0_hat.view(-1)
364
+
365
+ # only replace the changed indices with corresponding codebook_idx
366
+ changes_segm = torch.bitwise_and(
367
+ changes_flatten, texture_mask_flatten == codebook_idx)
368
+
369
+ # x_t would be the input to the transformer, so the index range should be continual one
370
+ x_t[changes_segm] = x_0_hat[
371
+ changes_segm] + 1024 * codebook_idx
372
+ min_encodings_indices_list[codebook_idx][
373
+ changes_segm] = x_0_hat[changes_segm]
374
+
375
+ x_t = x_t.view(ori_shape) # [b, h*w]
376
+
377
+ min_encodings_indices_return_list = [
378
+ min_encodings_indices.view(ori_shape)
379
+ for min_encodings_indices in min_encodings_indices_list
380
+ ]
381
+
382
+ self._denoise_fn.train()
383
+
384
+ return min_encodings_indices_return_list
385
+
386
+ def get_vis(self, image, gt_indices, predicted_indices, texture_mask,
387
+ save_path):
388
+ # original image
389
+ ori_img = self.decode_image_indices(gt_indices, texture_mask)
390
+ # pred image
391
+ pred_img = self.decode_image_indices(predicted_indices, texture_mask)
392
+ img_cat = torch.cat([
393
+ image,
394
+ ori_img,
395
+ pred_img,
396
+ ], dim=3).detach()
397
+ img_cat = ((img_cat + 1) / 2)
398
+ img_cat = img_cat.clamp_(0, 1)
399
+ save_image(img_cat, save_path, nrow=1, padding=4)
400
+
401
+ def inference(self, data_loader, save_dir):
402
+ self._denoise_fn.eval()
403
+
404
+ for _, data in enumerate(data_loader):
405
+ img_name = data['img_name']
406
+ self.feed_data(data)
407
+ b = self.image.size(0)
408
+ with torch.no_grad():
409
+ sampled_indices_list = self.sample_fn(
410
+ temp=1, sample_steps=self.sample_steps)
411
+ for idx in range(b):
412
+ self.get_vis(self.image[idx:idx + 1], [
413
+ gt_indices[idx:idx + 1]
414
+ for gt_indices in self.gt_indices_list
415
+ ], [
416
+ sampled_indices[idx:idx + 1]
417
+ for sampled_indices in sampled_indices_list
418
+ ], self.texture_mask[idx:idx + 1],
419
+ f'{save_dir}/{img_name[idx]}')
420
+
421
+ self._denoise_fn.train()
422
+
423
+ def get_current_log(self):
424
+ return self.log_dict
425
+
426
+ def update_learning_rate(self, epoch, iters=None):
427
+ """Update learning rate.
428
+
429
+ Args:
430
+ current_iter (int): Current iteration.
431
+ warmup_iter (int): Warmup iter numbers. -1 for no warmup.
432
+ Default: -1.
433
+ """
434
+ lr = self.optimizer.param_groups[0]['lr']
435
+
436
+ if self.opt['lr_decay'] == 'step':
437
+ lr = self.opt['lr'] * (
438
+ self.opt['gamma']**(epoch // self.opt['step']))
439
+ elif self.opt['lr_decay'] == 'cos':
440
+ lr = self.opt['lr'] * (
441
+ 1 + math.cos(math.pi * epoch / self.opt['num_epochs'])) / 2
442
+ elif self.opt['lr_decay'] == 'linear':
443
+ lr = self.opt['lr'] * (1 - epoch / self.opt['num_epochs'])
444
+ elif self.opt['lr_decay'] == 'linear2exp':
445
+ if epoch < self.opt['turning_point'] + 1:
446
+ # learning rate decay as 95%
447
+ # at the turning point (1 / 95% = 1.0526)
448
+ lr = self.opt['lr'] * (
449
+ 1 - epoch / int(self.opt['turning_point'] * 1.0526))
450
+ else:
451
+ lr *= self.opt['gamma']
452
+ elif self.opt['lr_decay'] == 'schedule':
453
+ if epoch in self.opt['schedule']:
454
+ lr *= self.opt['gamma']
455
+ elif self.opt['lr_decay'] == 'warm_up':
456
+ if iters <= self.opt['warmup_iters']:
457
+ lr = self.opt['lr'] * float(iters) / self.opt['warmup_iters']
458
+ else:
459
+ lr = self.opt['lr']
460
+ else:
461
+ raise ValueError('Unknown lr mode {}'.format(self.opt['lr_decay']))
462
+ # set learning rate
463
+ for param_group in self.optimizer.param_groups:
464
+ param_group['lr'] = lr
465
+
466
+ return lr
467
+
468
+ def save_network(self, net, save_path):
469
+ """Save networks.
470
+
471
+ Args:
472
+ net (nn.Module): Network to be saved.
473
+ net_label (str): Network label.
474
+ current_iter (int): Current iter number.
475
+ """
476
+ state_dict = net.state_dict()
477
+ torch.save(state_dict, save_path)
478
+
479
+ def load_network(self):
480
+ checkpoint = torch.load(self.opt['pretrained_sampler'])
481
+ self._denoise_fn.load_state_dict(checkpoint, strict=True)
482
+ self._denoise_fn.eval()
Text2Human/models/vqgan_model.py ADDED
@@ -0,0 +1,551 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import sys
3
+ from collections import OrderedDict
4
+
5
+ sys.path.append('..')
6
+ import lpips
7
+ import torch
8
+ import torch.nn.functional as F
9
+ from torchvision.utils import save_image
10
+
11
+ from models.archs.vqgan_arch import (Decoder, Discriminator, Encoder,
12
+ VectorQuantizer, VectorQuantizerTexture)
13
+ from models.losses.segmentation_loss import BCELossWithQuant
14
+ from models.losses.vqgan_loss import (DiffAugment, adopt_weight,
15
+ calculate_adaptive_weight, hinge_d_loss)
16
+
17
+
18
+ class VQModel():
19
+
20
+ def __init__(self, opt):
21
+ super().__init__()
22
+ self.opt = opt
23
+ self.device = torch.device('cuda')
24
+ self.encoder = Encoder(
25
+ ch=opt['ch'],
26
+ num_res_blocks=opt['num_res_blocks'],
27
+ attn_resolutions=opt['attn_resolutions'],
28
+ ch_mult=opt['ch_mult'],
29
+ in_channels=opt['in_channels'],
30
+ resolution=opt['resolution'],
31
+ z_channels=opt['z_channels'],
32
+ double_z=opt['double_z'],
33
+ dropout=opt['dropout']).to(self.device)
34
+ self.decoder = Decoder(
35
+ in_channels=opt['in_channels'],
36
+ resolution=opt['resolution'],
37
+ z_channels=opt['z_channels'],
38
+ ch=opt['ch'],
39
+ out_ch=opt['out_ch'],
40
+ num_res_blocks=opt['num_res_blocks'],
41
+ attn_resolutions=opt['attn_resolutions'],
42
+ ch_mult=opt['ch_mult'],
43
+ dropout=opt['dropout'],
44
+ resamp_with_conv=True,
45
+ give_pre_end=False).to(self.device)
46
+ self.quantize = VectorQuantizer(
47
+ opt['n_embed'], opt['embed_dim'], beta=0.25).to(self.device)
48
+ self.quant_conv = torch.nn.Conv2d(opt["z_channels"], opt['embed_dim'],
49
+ 1).to(self.device)
50
+ self.post_quant_conv = torch.nn.Conv2d(opt['embed_dim'],
51
+ opt["z_channels"],
52
+ 1).to(self.device)
53
+
54
+ def init_training_settings(self):
55
+ self.loss = BCELossWithQuant()
56
+ self.log_dict = OrderedDict()
57
+ self.configure_optimizers()
58
+
59
+ def save_network(self, save_path):
60
+ """Save networks.
61
+
62
+ Args:
63
+ net (nn.Module): Network to be saved.
64
+ net_label (str): Network label.
65
+ current_iter (int): Current iter number.
66
+ """
67
+
68
+ save_dict = {}
69
+ save_dict['encoder'] = self.encoder.state_dict()
70
+ save_dict['decoder'] = self.decoder.state_dict()
71
+ save_dict['quantize'] = self.quantize.state_dict()
72
+ save_dict['quant_conv'] = self.quant_conv.state_dict()
73
+ save_dict['post_quant_conv'] = self.post_quant_conv.state_dict()
74
+ save_dict['discriminator'] = self.disc.state_dict()
75
+ torch.save(save_dict, save_path)
76
+
77
+ def load_network(self):
78
+ checkpoint = torch.load(self.opt['pretrained_models'])
79
+ self.encoder.load_state_dict(checkpoint['encoder'], strict=True)
80
+ self.decoder.load_state_dict(checkpoint['decoder'], strict=True)
81
+ self.quantize.load_state_dict(checkpoint['quantize'], strict=True)
82
+ self.quant_conv.load_state_dict(checkpoint['quant_conv'], strict=True)
83
+ self.post_quant_conv.load_state_dict(
84
+ checkpoint['post_quant_conv'], strict=True)
85
+
86
+ def optimize_parameters(self, data, current_iter):
87
+ self.encoder.train()
88
+ self.decoder.train()
89
+ self.quantize.train()
90
+ self.quant_conv.train()
91
+ self.post_quant_conv.train()
92
+
93
+ loss = self.training_step(data)
94
+ self.optimizer.zero_grad()
95
+ loss.backward()
96
+ self.optimizer.step()
97
+
98
+ def encode(self, x):
99
+ h = self.encoder(x)
100
+ h = self.quant_conv(h)
101
+ quant, emb_loss, info = self.quantize(h)
102
+ return quant, emb_loss, info
103
+
104
+ def decode(self, quant):
105
+ quant = self.post_quant_conv(quant)
106
+ dec = self.decoder(quant)
107
+ return dec
108
+
109
+ def decode_code(self, code_b):
110
+ quant_b = self.quantize.embed_code(code_b)
111
+ dec = self.decode(quant_b)
112
+ return dec
113
+
114
+ def forward_step(self, input):
115
+ quant, diff, _ = self.encode(input)
116
+ dec = self.decode(quant)
117
+ return dec, diff
118
+
119
+ def feed_data(self, data):
120
+ x = data['segm']
121
+ x = F.one_hot(x, num_classes=self.opt['num_segm_classes'])
122
+
123
+ if len(x.shape) == 3:
124
+ x = x[..., None]
125
+ x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format)
126
+ return x.float().to(self.device)
127
+
128
+ def get_current_log(self):
129
+ return self.log_dict
130
+
131
+ def update_learning_rate(self, epoch):
132
+ """Update learning rate.
133
+
134
+ Args:
135
+ current_iter (int): Current iteration.
136
+ warmup_iter (int): Warmup iter numbers. -1 for no warmup.
137
+ Default: -1.
138
+ """
139
+ lr = self.optimizer.param_groups[0]['lr']
140
+
141
+ if self.opt['lr_decay'] == 'step':
142
+ lr = self.opt['lr'] * (
143
+ self.opt['gamma']**(epoch // self.opt['step']))
144
+ elif self.opt['lr_decay'] == 'cos':
145
+ lr = self.opt['lr'] * (
146
+ 1 + math.cos(math.pi * epoch / self.opt['num_epochs'])) / 2
147
+ elif self.opt['lr_decay'] == 'linear':
148
+ lr = self.opt['lr'] * (1 - epoch / self.opt['num_epochs'])
149
+ elif self.opt['lr_decay'] == 'linear2exp':
150
+ if epoch < self.opt['turning_point'] + 1:
151
+ # learning rate decay as 95%
152
+ # at the turning point (1 / 95% = 1.0526)
153
+ lr = self.opt['lr'] * (
154
+ 1 - epoch / int(self.opt['turning_point'] * 1.0526))
155
+ else:
156
+ lr *= self.opt['gamma']
157
+ elif self.opt['lr_decay'] == 'schedule':
158
+ if epoch in self.opt['schedule']:
159
+ lr *= self.opt['gamma']
160
+ else:
161
+ raise ValueError('Unknown lr mode {}'.format(self.opt['lr_decay']))
162
+ # set learning rate
163
+ for param_group in self.optimizer.param_groups:
164
+ param_group['lr'] = lr
165
+
166
+ return lr
167
+
168
+
169
+ class VQSegmentationModel(VQModel):
170
+
171
+ def __init__(self, opt):
172
+ super().__init__(opt)
173
+ self.colorize = torch.randn(3, opt['num_segm_classes'], 1,
174
+ 1).to(self.device)
175
+
176
+ self.init_training_settings()
177
+
178
+ def configure_optimizers(self):
179
+ self.optimizer = torch.optim.Adam(
180
+ list(self.encoder.parameters()) + list(self.decoder.parameters()) +
181
+ list(self.quantize.parameters()) +
182
+ list(self.quant_conv.parameters()) +
183
+ list(self.post_quant_conv.parameters()),
184
+ lr=self.opt['lr'],
185
+ betas=(0.5, 0.9))
186
+
187
+ def training_step(self, data):
188
+ x = self.feed_data(data)
189
+ xrec, qloss = self.forward_step(x)
190
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, split="train")
191
+ self.log_dict.update(log_dict_ae)
192
+ return aeloss
193
+
194
+ def to_rgb(self, x):
195
+ x = F.conv2d(x, weight=self.colorize)
196
+ x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
197
+ return x
198
+
199
+ @torch.no_grad()
200
+ def inference(self, data_loader, save_dir):
201
+ self.encoder.eval()
202
+ self.decoder.eval()
203
+ self.quantize.eval()
204
+ self.quant_conv.eval()
205
+ self.post_quant_conv.eval()
206
+
207
+ loss_total = 0
208
+ loss_bce = 0
209
+ loss_quant = 0
210
+ num = 0
211
+
212
+ for _, data in enumerate(data_loader):
213
+ img_name = data['img_name'][0]
214
+ x = self.feed_data(data)
215
+ xrec, qloss = self.forward_step(x)
216
+ _, log_dict_ae = self.loss(qloss, x, xrec, split="val")
217
+
218
+ loss_total += log_dict_ae['val/total_loss']
219
+ loss_bce += log_dict_ae['val/bce_loss']
220
+ loss_quant += log_dict_ae['val/quant_loss']
221
+
222
+ num += x.size(0)
223
+
224
+ if x.shape[1] > 3:
225
+ # colorize with random projection
226
+ assert xrec.shape[1] > 3
227
+ # convert logits to indices
228
+ xrec = torch.argmax(xrec, dim=1, keepdim=True)
229
+ xrec = F.one_hot(xrec, num_classes=x.shape[1])
230
+ xrec = xrec.squeeze(1).permute(0, 3, 1, 2).float()
231
+ x = self.to_rgb(x)
232
+ xrec = self.to_rgb(xrec)
233
+
234
+ img_cat = torch.cat([x, xrec], dim=3).detach()
235
+ img_cat = ((img_cat + 1) / 2)
236
+ img_cat = img_cat.clamp_(0, 1)
237
+ save_image(
238
+ img_cat, f'{save_dir}/{img_name}.png', nrow=1, padding=4)
239
+
240
+ return (loss_total / num).item(), (loss_bce /
241
+ num).item(), (loss_quant /
242
+ num).item()
243
+
244
+
245
+ class VQImageModel(VQModel):
246
+
247
+ def __init__(self, opt):
248
+ super().__init__(opt)
249
+ self.disc = Discriminator(
250
+ opt['n_channels'], opt['ndf'],
251
+ n_layers=opt['disc_layers']).to(self.device)
252
+ self.perceptual = lpips.LPIPS(net="vgg").to(self.device)
253
+ self.perceptual_weight = opt['perceptual_weight']
254
+ self.disc_start_step = opt['disc_start_step']
255
+ self.disc_weight_max = opt['disc_weight_max']
256
+ self.diff_aug = opt['diff_aug']
257
+ self.policy = "color,translation"
258
+
259
+ self.disc.train()
260
+
261
+ self.init_training_settings()
262
+
263
+ def feed_data(self, data):
264
+ x = data['image']
265
+
266
+ return x.float().to(self.device)
267
+
268
+ def init_training_settings(self):
269
+ self.log_dict = OrderedDict()
270
+ self.configure_optimizers()
271
+
272
+ def configure_optimizers(self):
273
+ self.optimizer = torch.optim.Adam(
274
+ list(self.encoder.parameters()) + list(self.decoder.parameters()) +
275
+ list(self.quantize.parameters()) +
276
+ list(self.quant_conv.parameters()) +
277
+ list(self.post_quant_conv.parameters()),
278
+ lr=self.opt['lr'])
279
+
280
+ self.disc_optimizer = torch.optim.Adam(
281
+ self.disc.parameters(), lr=self.opt['lr'])
282
+
283
+ def training_step(self, data, step):
284
+ x = self.feed_data(data)
285
+ xrec, codebook_loss = self.forward_step(x)
286
+
287
+ # get recon/perceptual loss
288
+ recon_loss = torch.abs(x.contiguous() - xrec.contiguous())
289
+ p_loss = self.perceptual(x.contiguous(), xrec.contiguous())
290
+ nll_loss = recon_loss + self.perceptual_weight * p_loss
291
+ nll_loss = torch.mean(nll_loss)
292
+
293
+ # augment for input to discriminator
294
+ if self.diff_aug:
295
+ xrec = DiffAugment(xrec, policy=self.policy)
296
+
297
+ # update generator
298
+ logits_fake = self.disc(xrec)
299
+ g_loss = -torch.mean(logits_fake)
300
+ last_layer = self.decoder.conv_out.weight
301
+ d_weight = calculate_adaptive_weight(nll_loss, g_loss, last_layer,
302
+ self.disc_weight_max)
303
+ d_weight *= adopt_weight(1, step, self.disc_start_step)
304
+ loss = nll_loss + d_weight * g_loss + codebook_loss
305
+
306
+ self.log_dict["loss"] = loss
307
+ self.log_dict["l1"] = recon_loss.mean().item()
308
+ self.log_dict["perceptual"] = p_loss.mean().item()
309
+ self.log_dict["nll_loss"] = nll_loss.item()
310
+ self.log_dict["g_loss"] = g_loss.item()
311
+ self.log_dict["d_weight"] = d_weight
312
+ self.log_dict["codebook_loss"] = codebook_loss.item()
313
+
314
+ if step > self.disc_start_step:
315
+ if self.diff_aug:
316
+ logits_real = self.disc(
317
+ DiffAugment(x.contiguous().detach(), policy=self.policy))
318
+ else:
319
+ logits_real = self.disc(x.contiguous().detach())
320
+ logits_fake = self.disc(xrec.contiguous().detach(
321
+ )) # detach so that generator isn"t also updated
322
+ d_loss = hinge_d_loss(logits_real, logits_fake)
323
+ self.log_dict["d_loss"] = d_loss
324
+ else:
325
+ d_loss = None
326
+
327
+ return loss, d_loss
328
+
329
+ def optimize_parameters(self, data, step):
330
+ self.encoder.train()
331
+ self.decoder.train()
332
+ self.quantize.train()
333
+ self.quant_conv.train()
334
+ self.post_quant_conv.train()
335
+
336
+ loss, d_loss = self.training_step(data, step)
337
+ self.optimizer.zero_grad()
338
+ loss.backward()
339
+ self.optimizer.step()
340
+
341
+ if step > self.disc_start_step:
342
+ self.disc_optimizer.zero_grad()
343
+ d_loss.backward()
344
+ self.disc_optimizer.step()
345
+
346
+ @torch.no_grad()
347
+ def inference(self, data_loader, save_dir):
348
+ self.encoder.eval()
349
+ self.decoder.eval()
350
+ self.quantize.eval()
351
+ self.quant_conv.eval()
352
+ self.post_quant_conv.eval()
353
+
354
+ loss_total = 0
355
+ num = 0
356
+
357
+ for _, data in enumerate(data_loader):
358
+ img_name = data['img_name'][0]
359
+ x = self.feed_data(data)
360
+ xrec, _ = self.forward_step(x)
361
+
362
+ recon_loss = torch.abs(x.contiguous() - xrec.contiguous())
363
+ p_loss = self.perceptual(x.contiguous(), xrec.contiguous())
364
+ nll_loss = recon_loss + self.perceptual_weight * p_loss
365
+ nll_loss = torch.mean(nll_loss)
366
+ loss_total += nll_loss
367
+
368
+ num += x.size(0)
369
+
370
+ if x.shape[1] > 3:
371
+ # colorize with random projection
372
+ assert xrec.shape[1] > 3
373
+ # convert logits to indices
374
+ xrec = torch.argmax(xrec, dim=1, keepdim=True)
375
+ xrec = F.one_hot(xrec, num_classes=x.shape[1])
376
+ xrec = xrec.squeeze(1).permute(0, 3, 1, 2).float()
377
+ x = self.to_rgb(x)
378
+ xrec = self.to_rgb(xrec)
379
+
380
+ img_cat = torch.cat([x, xrec], dim=3).detach()
381
+ img_cat = ((img_cat + 1) / 2)
382
+ img_cat = img_cat.clamp_(0, 1)
383
+ save_image(
384
+ img_cat, f'{save_dir}/{img_name}.png', nrow=1, padding=4)
385
+
386
+ return (loss_total / num).item()
387
+
388
+
389
+ class VQImageSegmTextureModel(VQImageModel):
390
+
391
+ def __init__(self, opt):
392
+ self.opt = opt
393
+ self.device = torch.device('cuda')
394
+ self.encoder = Encoder(
395
+ ch=opt['ch'],
396
+ num_res_blocks=opt['num_res_blocks'],
397
+ attn_resolutions=opt['attn_resolutions'],
398
+ ch_mult=opt['ch_mult'],
399
+ in_channels=opt['in_channels'],
400
+ resolution=opt['resolution'],
401
+ z_channels=opt['z_channels'],
402
+ double_z=opt['double_z'],
403
+ dropout=opt['dropout']).to(self.device)
404
+ self.decoder = Decoder(
405
+ in_channels=opt['in_channels'],
406
+ resolution=opt['resolution'],
407
+ z_channels=opt['z_channels'],
408
+ ch=opt['ch'],
409
+ out_ch=opt['out_ch'],
410
+ num_res_blocks=opt['num_res_blocks'],
411
+ attn_resolutions=opt['attn_resolutions'],
412
+ ch_mult=opt['ch_mult'],
413
+ dropout=opt['dropout'],
414
+ resamp_with_conv=True,
415
+ give_pre_end=False).to(self.device)
416
+ self.quantize = VectorQuantizerTexture(
417
+ opt['n_embed'], opt['embed_dim'], beta=0.25).to(self.device)
418
+ self.quant_conv = torch.nn.Conv2d(opt["z_channels"], opt['embed_dim'],
419
+ 1).to(self.device)
420
+ self.post_quant_conv = torch.nn.Conv2d(opt['embed_dim'],
421
+ opt["z_channels"],
422
+ 1).to(self.device)
423
+
424
+ self.disc = Discriminator(
425
+ opt['n_channels'], opt['ndf'],
426
+ n_layers=opt['disc_layers']).to(self.device)
427
+ self.perceptual = lpips.LPIPS(net="vgg").to(self.device)
428
+ self.perceptual_weight = opt['perceptual_weight']
429
+ self.disc_start_step = opt['disc_start_step']
430
+ self.disc_weight_max = opt['disc_weight_max']
431
+ self.diff_aug = opt['diff_aug']
432
+ self.policy = "color,translation"
433
+
434
+ self.disc.train()
435
+
436
+ self.init_training_settings()
437
+
438
+ def feed_data(self, data):
439
+ x = data['image'].float().to(self.device)
440
+ mask = data['texture_mask'].float().to(self.device)
441
+
442
+ return x, mask
443
+
444
+ def training_step(self, data, step):
445
+ x, mask = self.feed_data(data)
446
+ xrec, codebook_loss = self.forward_step(x, mask)
447
+
448
+ # get recon/perceptual loss
449
+ recon_loss = torch.abs(x.contiguous() - xrec.contiguous())
450
+ p_loss = self.perceptual(x.contiguous(), xrec.contiguous())
451
+ nll_loss = recon_loss + self.perceptual_weight * p_loss
452
+ nll_loss = torch.mean(nll_loss)
453
+
454
+ # augment for input to discriminator
455
+ if self.diff_aug:
456
+ xrec = DiffAugment(xrec, policy=self.policy)
457
+
458
+ # update generator
459
+ logits_fake = self.disc(xrec)
460
+ g_loss = -torch.mean(logits_fake)
461
+ last_layer = self.decoder.conv_out.weight
462
+ d_weight = calculate_adaptive_weight(nll_loss, g_loss, last_layer,
463
+ self.disc_weight_max)
464
+ d_weight *= adopt_weight(1, step, self.disc_start_step)
465
+ loss = nll_loss + d_weight * g_loss + codebook_loss
466
+
467
+ self.log_dict["loss"] = loss
468
+ self.log_dict["l1"] = recon_loss.mean().item()
469
+ self.log_dict["perceptual"] = p_loss.mean().item()
470
+ self.log_dict["nll_loss"] = nll_loss.item()
471
+ self.log_dict["g_loss"] = g_loss.item()
472
+ self.log_dict["d_weight"] = d_weight
473
+ self.log_dict["codebook_loss"] = codebook_loss.item()
474
+
475
+ if step > self.disc_start_step:
476
+ if self.diff_aug:
477
+ logits_real = self.disc(
478
+ DiffAugment(x.contiguous().detach(), policy=self.policy))
479
+ else:
480
+ logits_real = self.disc(x.contiguous().detach())
481
+ logits_fake = self.disc(xrec.contiguous().detach(
482
+ )) # detach so that generator isn"t also updated
483
+ d_loss = hinge_d_loss(logits_real, logits_fake)
484
+ self.log_dict["d_loss"] = d_loss
485
+ else:
486
+ d_loss = None
487
+
488
+ return loss, d_loss
489
+
490
+ @torch.no_grad()
491
+ def inference(self, data_loader, save_dir):
492
+ self.encoder.eval()
493
+ self.decoder.eval()
494
+ self.quantize.eval()
495
+ self.quant_conv.eval()
496
+ self.post_quant_conv.eval()
497
+
498
+ loss_total = 0
499
+ num = 0
500
+
501
+ for _, data in enumerate(data_loader):
502
+ img_name = data['img_name'][0]
503
+ x, mask = self.feed_data(data)
504
+ xrec, _ = self.forward_step(x, mask)
505
+
506
+ recon_loss = torch.abs(x.contiguous() - xrec.contiguous())
507
+ p_loss = self.perceptual(x.contiguous(), xrec.contiguous())
508
+ nll_loss = recon_loss + self.perceptual_weight * p_loss
509
+ nll_loss = torch.mean(nll_loss)
510
+ loss_total += nll_loss
511
+
512
+ num += x.size(0)
513
+
514
+ if x.shape[1] > 3:
515
+ # colorize with random projection
516
+ assert xrec.shape[1] > 3
517
+ # convert logits to indices
518
+ xrec = torch.argmax(xrec, dim=1, keepdim=True)
519
+ xrec = F.one_hot(xrec, num_classes=x.shape[1])
520
+ xrec = xrec.squeeze(1).permute(0, 3, 1, 2).float()
521
+ x = self.to_rgb(x)
522
+ xrec = self.to_rgb(xrec)
523
+
524
+ img_cat = torch.cat([x, xrec], dim=3).detach()
525
+ img_cat = ((img_cat + 1) / 2)
526
+ img_cat = img_cat.clamp_(0, 1)
527
+ save_image(
528
+ img_cat, f'{save_dir}/{img_name}.png', nrow=1, padding=4)
529
+
530
+ return (loss_total / num).item()
531
+
532
+ def encode(self, x, mask):
533
+ h = self.encoder(x)
534
+ h = self.quant_conv(h)
535
+ quant, emb_loss, info = self.quantize(h, mask)
536
+ return quant, emb_loss, info
537
+
538
+ def decode(self, quant):
539
+ quant = self.post_quant_conv(quant)
540
+ dec = self.decoder(quant)
541
+ return dec
542
+
543
+ def decode_code(self, code_b):
544
+ quant_b = self.quantize.embed_code(code_b)
545
+ dec = self.decode(quant_b)
546
+ return dec
547
+
548
+ def forward_step(self, input, mask):
549
+ quant, diff, _ = self.encode(input, mask)
550
+ dec = self.decode(quant)
551
+ return dec, diff
Text2Human/sample_from_parsing.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import logging
3
+ import os.path as osp
4
+ import random
5
+
6
+ import torch
7
+
8
+ from data.segm_attr_dataset import DeepFashionAttrSegmDataset
9
+ from models import create_model
10
+ from utils.logger import get_root_logger
11
+ from utils.options import dict2str, dict_to_nonedict, parse
12
+ from utils.util import make_exp_dirs, set_random_seed
13
+
14
+
15
+ def main():
16
+ # options
17
+ parser = argparse.ArgumentParser()
18
+ parser.add_argument('-opt', type=str, help='Path to option YAML file.')
19
+ args = parser.parse_args()
20
+ opt = parse(args.opt, is_train=False)
21
+
22
+ # mkdir and loggers
23
+ make_exp_dirs(opt)
24
+ log_file = osp.join(opt['path']['log'], f"test_{opt['name']}.log")
25
+ logger = get_root_logger(
26
+ logger_name='base', log_level=logging.INFO, log_file=log_file)
27
+ logger.info(dict2str(opt))
28
+
29
+ # convert to NoneDict, which returns None for missing keys
30
+ opt = dict_to_nonedict(opt)
31
+
32
+ # random seed
33
+ seed = opt['manual_seed']
34
+ if seed is None:
35
+ seed = random.randint(1, 10000)
36
+ logger.info(f'Random seed: {seed}')
37
+ set_random_seed(seed)
38
+
39
+ test_dataset = DeepFashionAttrSegmDataset(
40
+ img_dir=opt['test_img_dir'],
41
+ segm_dir=opt['segm_dir'],
42
+ pose_dir=opt['pose_dir'],
43
+ ann_dir=opt['test_ann_file'])
44
+ test_loader = torch.utils.data.DataLoader(
45
+ dataset=test_dataset, batch_size=4, shuffle=False)
46
+ logger.info(f'Number of test set: {len(test_dataset)}.')
47
+
48
+ model = create_model(opt)
49
+ _ = model.inference(test_loader, opt['path']['results_root'])
50
+
51
+
52
+ if __name__ == '__main__':
53
+ main()
Text2Human/sample_from_pose.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import logging
3
+ import os.path as osp
4
+ import random
5
+
6
+ import torch
7
+
8
+ from data.pose_attr_dataset import DeepFashionAttrPoseDataset
9
+ from models import create_model
10
+ from utils.logger import get_root_logger
11
+ from utils.options import dict2str, dict_to_nonedict, parse
12
+ from utils.util import make_exp_dirs, set_random_seed
13
+
14
+
15
+ def main():
16
+ # options
17
+ parser = argparse.ArgumentParser()
18
+ parser.add_argument('-opt', type=str, help='Path to option YAML file.')
19
+ args = parser.parse_args()
20
+ opt = parse(args.opt, is_train=False)
21
+
22
+ # mkdir and loggers
23
+ make_exp_dirs(opt)
24
+ log_file = osp.join(opt['path']['log'], f"test_{opt['name']}.log")
25
+ logger = get_root_logger(
26
+ logger_name='base', log_level=logging.INFO, log_file=log_file)
27
+ logger.info(dict2str(opt))
28
+
29
+ # convert to NoneDict, which returns None for missing keys
30
+ opt = dict_to_nonedict(opt)
31
+
32
+ # random seed
33
+ seed = opt['manual_seed']
34
+ if seed is None:
35
+ seed = random.randint(1, 10000)
36
+ logger.info(f'Random seed: {seed}')
37
+ set_random_seed(seed)
38
+
39
+ test_dataset = DeepFashionAttrPoseDataset(
40
+ pose_dir=opt['pose_dir'],
41
+ texture_ann_dir=opt['texture_ann_file'],
42
+ shape_ann_path=opt['shape_ann_path'])
43
+ test_loader = torch.utils.data.DataLoader(
44
+ dataset=test_dataset, batch_size=4, shuffle=False)
45
+ logger.info(f'Number of test set: {len(test_dataset)}.')
46
+
47
+ model = create_model(opt)
48
+ _ = model.inference(test_loader, opt['path']['results_root'])
49
+
50
+
51
+ if __name__ == '__main__':
52
+ main()
Text2Human/train_index_prediction.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import logging
3
+ import os
4
+ import os.path as osp
5
+ import random
6
+ import time
7
+
8
+ import torch
9
+
10
+ from data.segm_attr_dataset import DeepFashionAttrSegmDataset
11
+ from models import create_model
12
+ from utils.logger import MessageLogger, get_root_logger, init_tb_logger
13
+ from utils.options import dict2str, dict_to_nonedict, parse
14
+ from utils.util import make_exp_dirs
15
+
16
+
17
+ def main():
18
+ # options
19
+ parser = argparse.ArgumentParser()
20
+ parser.add_argument('-opt', type=str, help='Path to option YAML file.')
21
+ args = parser.parse_args()
22
+ opt = parse(args.opt, is_train=True)
23
+
24
+ # mkdir and loggers
25
+ make_exp_dirs(opt)
26
+ log_file = osp.join(opt['path']['log'], f"train_{opt['name']}.log")
27
+ logger = get_root_logger(
28
+ logger_name='base', log_level=logging.INFO, log_file=log_file)
29
+ logger.info(dict2str(opt))
30
+ # initialize tensorboard logger
31
+ tb_logger = None
32
+ if opt['use_tb_logger'] and 'debug' not in opt['name']:
33
+ tb_logger = init_tb_logger(log_dir='./tb_logger/' + opt['name'])
34
+
35
+ # convert to NoneDict, which returns None for missing keys
36
+ opt = dict_to_nonedict(opt)
37
+
38
+ # set up data loader
39
+ train_dataset = DeepFashionAttrSegmDataset(
40
+ img_dir=opt['train_img_dir'],
41
+ segm_dir=opt['segm_dir'],
42
+ pose_dir=opt['pose_dir'],
43
+ ann_dir=opt['train_ann_file'],
44
+ xflip=True)
45
+ train_loader = torch.utils.data.DataLoader(
46
+ dataset=train_dataset,
47
+ batch_size=opt['batch_size'],
48
+ shuffle=True,
49
+ num_workers=opt['num_workers'],
50
+ drop_last=True)
51
+ logger.info(f'Number of train set: {len(train_dataset)}.')
52
+ opt['max_iters'] = opt['num_epochs'] * len(
53
+ train_dataset) // opt['batch_size']
54
+
55
+ val_dataset = DeepFashionAttrSegmDataset(
56
+ img_dir=opt['train_img_dir'],
57
+ segm_dir=opt['segm_dir'],
58
+ pose_dir=opt['pose_dir'],
59
+ ann_dir=opt['val_ann_file'])
60
+ val_loader = torch.utils.data.DataLoader(
61
+ dataset=val_dataset, batch_size=1, shuffle=False)
62
+ logger.info(f'Number of val set: {len(val_dataset)}.')
63
+
64
+ test_dataset = DeepFashionAttrSegmDataset(
65
+ img_dir=opt['test_img_dir'],
66
+ segm_dir=opt['segm_dir'],
67
+ pose_dir=opt['pose_dir'],
68
+ ann_dir=opt['test_ann_file'])
69
+ test_loader = torch.utils.data.DataLoader(
70
+ dataset=test_dataset, batch_size=1, shuffle=False)
71
+ logger.info(f'Number of test set: {len(test_dataset)}.')
72
+
73
+ current_iter = 0
74
+ best_epoch = None
75
+ best_acc = 0
76
+
77
+ model = create_model(opt)
78
+
79
+ data_time, iter_time = 0, 0
80
+ current_iter = 0
81
+
82
+ # create message logger (formatted outputs)
83
+ msg_logger = MessageLogger(opt, current_iter, tb_logger)
84
+
85
+ for epoch in range(opt['num_epochs']):
86
+ lr = model.update_learning_rate(epoch)
87
+
88
+ for _, batch_data in enumerate(train_loader):
89
+ data_time = time.time() - data_time
90
+
91
+ current_iter += 1
92
+
93
+ model.feed_data(batch_data)
94
+ model.optimize_parameters()
95
+
96
+ iter_time = time.time() - iter_time
97
+ if current_iter % opt['print_freq'] == 0:
98
+ log_vars = {'epoch': epoch, 'iter': current_iter}
99
+ log_vars.update({'lrs': [lr]})
100
+ log_vars.update({'time': iter_time, 'data_time': data_time})
101
+ log_vars.update(model.get_current_log())
102
+ msg_logger(log_vars)
103
+
104
+ data_time = time.time()
105
+ iter_time = time.time()
106
+
107
+ if epoch % opt['val_freq'] == 0:
108
+ save_dir = f'{opt["path"]["visualization"]}/valset/epoch_{epoch:03d}' # noqa
109
+ os.makedirs(save_dir, exist_ok=opt['debug'])
110
+ val_acc = model.inference(val_loader, save_dir)
111
+
112
+ save_dir = f'{opt["path"]["visualization"]}/testset/epoch_{epoch:03d}' # noqa
113
+ os.makedirs(save_dir, exist_ok=opt['debug'])
114
+ test_acc = model.inference(test_loader, save_dir)
115
+
116
+ logger.info(
117
+ f'Epoch: {epoch}, val_acc: {val_acc: .4f}, test_acc: {test_acc: .4f}.'
118
+ )
119
+
120
+ if test_acc > best_acc:
121
+ best_epoch = epoch
122
+ best_acc = test_acc
123
+
124
+ logger.info(f'Best epoch: {best_epoch}, '
125
+ f'Best test acc: {best_acc: .4f}.')
126
+
127
+ # save model
128
+ model.save_network(
129
+ f'{opt["path"]["models"]}/models_epoch{epoch}.pth')
130
+
131
+
132
+ if __name__ == '__main__':
133
+ main()
Text2Human/train_parsing_gen.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import logging
3
+ import os
4
+ import os.path as osp
5
+ import random
6
+ import time
7
+
8
+ import torch
9
+
10
+ from data.parsing_generation_segm_attr_dataset import \
11
+ ParsingGenerationDeepFashionAttrSegmDataset
12
+ from models import create_model
13
+ from utils.logger import MessageLogger, get_root_logger, init_tb_logger
14
+ from utils.options import dict2str, dict_to_nonedict, parse
15
+ from utils.util import make_exp_dirs
16
+
17
+
18
+ def main():
19
+ # options
20
+ parser = argparse.ArgumentParser()
21
+ parser.add_argument('-opt', type=str, help='Path to option YAML file.')
22
+ args = parser.parse_args()
23
+ opt = parse(args.opt, is_train=True)
24
+
25
+ # mkdir and loggers
26
+ make_exp_dirs(opt)
27
+ log_file = osp.join(opt['path']['log'], f"train_{opt['name']}.log")
28
+ logger = get_root_logger(
29
+ logger_name='base', log_level=logging.INFO, log_file=log_file)
30
+ logger.info(dict2str(opt))
31
+ # initialize tensorboard logger
32
+ tb_logger = None
33
+ if opt['use_tb_logger'] and 'debug' not in opt['name']:
34
+ tb_logger = init_tb_logger(log_dir='./tb_logger/' + opt['name'])
35
+
36
+ # convert to NoneDict, which returns None for missing keys
37
+ opt = dict_to_nonedict(opt)
38
+
39
+ # set up data loader
40
+ train_dataset = ParsingGenerationDeepFashionAttrSegmDataset(
41
+ segm_dir=opt['segm_dir'],
42
+ pose_dir=opt['pose_dir'],
43
+ ann_file=opt['train_ann_file'])
44
+ train_loader = torch.utils.data.DataLoader(
45
+ dataset=train_dataset,
46
+ batch_size=opt['batch_size'],
47
+ shuffle=True,
48
+ num_workers=opt['num_workers'],
49
+ drop_last=True)
50
+ logger.info(f'Number of train set: {len(train_dataset)}.')
51
+ opt['max_iters'] = opt['num_epochs'] * len(
52
+ train_dataset) // opt['batch_size']
53
+
54
+ val_dataset = ParsingGenerationDeepFashionAttrSegmDataset(
55
+ segm_dir=opt['segm_dir'],
56
+ pose_dir=opt['pose_dir'],
57
+ ann_file=opt['val_ann_file'])
58
+ val_loader = torch.utils.data.DataLoader(
59
+ dataset=val_dataset,
60
+ batch_size=1,
61
+ shuffle=False,
62
+ num_workers=opt['num_workers'])
63
+ logger.info(f'Number of val set: {len(val_dataset)}.')
64
+
65
+ test_dataset = ParsingGenerationDeepFashionAttrSegmDataset(
66
+ segm_dir=opt['segm_dir'],
67
+ pose_dir=opt['pose_dir'],
68
+ ann_file=opt['test_ann_file'])
69
+ test_loader = torch.utils.data.DataLoader(
70
+ dataset=test_dataset,
71
+ batch_size=1,
72
+ shuffle=False,
73
+ num_workers=opt['num_workers'])
74
+ logger.info(f'Number of test set: {len(test_dataset)}.')
75
+
76
+ current_iter = 0
77
+ best_epoch = None
78
+ best_acc = 0
79
+
80
+ model = create_model(opt)
81
+
82
+ data_time, iter_time = 0, 0
83
+ current_iter = 0
84
+
85
+ # create message logger (formatted outputs)
86
+ msg_logger = MessageLogger(opt, current_iter, tb_logger)
87
+
88
+ for epoch in range(opt['num_epochs']):
89
+ lr = model.update_learning_rate(epoch)
90
+
91
+ for _, batch_data in enumerate(train_loader):
92
+ data_time = time.time() - data_time
93
+
94
+ current_iter += 1
95
+
96
+ model.feed_data(batch_data)
97
+ model.optimize_parameters()
98
+
99
+ iter_time = time.time() - iter_time
100
+ if current_iter % opt['print_freq'] == 0:
101
+ log_vars = {'epoch': epoch, 'iter': current_iter}
102
+ log_vars.update({'lrs': [lr]})
103
+ log_vars.update({'time': iter_time, 'data_time': data_time})
104
+ log_vars.update(model.get_current_log())
105
+ msg_logger(log_vars)
106
+
107
+ data_time = time.time()
108
+ iter_time = time.time()
109
+
110
+ if epoch % opt['val_freq'] == 0:
111
+ save_dir = f'{opt["path"]["visualization"]}/valset/epoch_{epoch:03d}'
112
+ os.makedirs(save_dir, exist_ok=opt['debug'])
113
+ val_acc = model.inference(val_loader, save_dir)
114
+
115
+ save_dir = f'{opt["path"]["visualization"]}/testset/epoch_{epoch:03d}'
116
+ os.makedirs(save_dir, exist_ok=opt['debug'])
117
+ test_acc = model.inference(test_loader, save_dir)
118
+
119
+ logger.info(f'Epoch: {epoch}, '
120
+ f'val_acc: {val_acc: .4f}, '
121
+ f'test_acc: {test_acc: .4f}.')
122
+
123
+ if test_acc > best_acc:
124
+ best_epoch = epoch
125
+ best_acc = test_acc
126
+
127
+ logger.info(f'Best epoch: {best_epoch}, '
128
+ f'Best test acc: {best_acc: .4f}.')
129
+
130
+ # save model
131
+ model.save_network(
132
+ f'{opt["path"]["models"]}/parsing_generation_epoch{epoch}.pth')
133
+
134
+
135
+ if __name__ == '__main__':
136
+ main()
Text2Human/train_parsing_token.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import logging
3
+ import os
4
+ import os.path as osp
5
+ import random
6
+ import time
7
+
8
+ import torch
9
+
10
+ from data.mask_dataset import MaskDataset
11
+ from models import create_model
12
+ from utils.logger import MessageLogger, get_root_logger, init_tb_logger
13
+ from utils.options import dict2str, dict_to_nonedict, parse
14
+ from utils.util import make_exp_dirs
15
+
16
+
17
+ def main():
18
+ # options
19
+ parser = argparse.ArgumentParser()
20
+ parser.add_argument('-opt', type=str, help='Path to option YAML file.')
21
+ args = parser.parse_args()
22
+ opt = parse(args.opt, is_train=True)
23
+
24
+ # mkdir and loggers
25
+ make_exp_dirs(opt)
26
+ log_file = osp.join(opt['path']['log'], f"train_{opt['name']}.log")
27
+ logger = get_root_logger(
28
+ logger_name='base', log_level=logging.INFO, log_file=log_file)
29
+ logger.info(dict2str(opt))
30
+ # initialize tensorboard logger
31
+ tb_logger = None
32
+ if opt['use_tb_logger'] and 'debug' not in opt['name']:
33
+ tb_logger = init_tb_logger(log_dir='./tb_logger/' + opt['name'])
34
+
35
+ # convert to NoneDict, which returns None for missing keys
36
+ opt = dict_to_nonedict(opt)
37
+
38
+ # set up data loader
39
+ train_dataset = MaskDataset(
40
+ segm_dir=opt['segm_dir'], ann_dir=opt['train_ann_file'], xflip=True)
41
+ train_loader = torch.utils.data.DataLoader(
42
+ dataset=train_dataset,
43
+ batch_size=opt['batch_size'],
44
+ shuffle=True,
45
+ num_workers=opt['num_workers'],
46
+ persistent_workers=True,
47
+ drop_last=True)
48
+ logger.info(f'Number of train set: {len(train_dataset)}.')
49
+ opt['max_iters'] = opt['num_epochs'] * len(
50
+ train_dataset) // opt['batch_size']
51
+
52
+ val_dataset = MaskDataset(
53
+ segm_dir=opt['segm_dir'], ann_dir=opt['val_ann_file'])
54
+ val_loader = torch.utils.data.DataLoader(
55
+ dataset=val_dataset, batch_size=1, shuffle=False)
56
+ logger.info(f'Number of val set: {len(val_dataset)}.')
57
+
58
+ test_dataset = MaskDataset(
59
+ segm_dir=opt['segm_dir'], ann_dir=opt['test_ann_file'])
60
+ test_loader = torch.utils.data.DataLoader(
61
+ dataset=test_dataset, batch_size=1, shuffle=False)
62
+ logger.info(f'Number of test set: {len(test_dataset)}.')
63
+
64
+ current_iter = 0
65
+ best_epoch = None
66
+ best_loss = 100000
67
+
68
+ model = create_model(opt)
69
+
70
+ data_time, iter_time = 0, 0
71
+ current_iter = 0
72
+
73
+ # create message logger (formatted outputs)
74
+ msg_logger = MessageLogger(opt, current_iter, tb_logger)
75
+
76
+ for epoch in range(opt['num_epochs']):
77
+ lr = model.update_learning_rate(epoch)
78
+
79
+ for _, batch_data in enumerate(train_loader):
80
+ data_time = time.time() - data_time
81
+
82
+ current_iter += 1
83
+
84
+ model.optimize_parameters(batch_data, current_iter)
85
+
86
+ iter_time = time.time() - iter_time
87
+ if current_iter % opt['print_freq'] == 0:
88
+ log_vars = {'epoch': epoch, 'iter': current_iter}
89
+ log_vars.update({'lrs': [lr]})
90
+ log_vars.update({'time': iter_time, 'data_time': data_time})
91
+ log_vars.update(model.get_current_log())
92
+ msg_logger(log_vars)
93
+
94
+ data_time = time.time()
95
+ iter_time = time.time()
96
+
97
+ if epoch % opt['val_freq'] == 0:
98
+ save_dir = f'{opt["path"]["visualization"]}/valset/epoch_{epoch:03d}' # noqa
99
+ os.makedirs(save_dir, exist_ok=opt['debug'])
100
+ val_loss_total, _, _ = model.inference(val_loader, save_dir)
101
+
102
+ save_dir = f'{opt["path"]["visualization"]}/testset/epoch_{epoch:03d}' # noqa
103
+ os.makedirs(save_dir, exist_ok=opt['debug'])
104
+ test_loss_total, _, _ = model.inference(test_loader, save_dir)
105
+
106
+ logger.info(f'Epoch: {epoch}, '
107
+ f'val_loss_total: {val_loss_total}, '
108
+ f'test_loss_total: {test_loss_total}.')
109
+
110
+ if test_loss_total < best_loss:
111
+ best_epoch = epoch
112
+ best_loss = test_loss_total
113
+
114
+ logger.info(f'Best epoch: {best_epoch}, '
115
+ f'Best test loss: {best_loss: .4f}.')
116
+
117
+ # save model
118
+ model.save_network(f'{opt["path"]["models"]}/epoch{epoch}.pth')
119
+
120
+
121
+ if __name__ == '__main__':
122
+ main()
Text2Human/train_sampler.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import logging
3
+ import os
4
+ import os.path as osp
5
+ import random
6
+ import time
7
+
8
+ import torch
9
+
10
+ from data.segm_attr_dataset import DeepFashionAttrSegmDataset
11
+ from models import create_model
12
+ from utils.logger import MessageLogger, get_root_logger, init_tb_logger
13
+ from utils.options import dict2str, dict_to_nonedict, parse
14
+ from utils.util import make_exp_dirs
15
+
16
+
17
+ def main():
18
+ # options
19
+ parser = argparse.ArgumentParser()
20
+ parser.add_argument('-opt', type=str, help='Path to option YAML file.')
21
+ args = parser.parse_args()
22
+ opt = parse(args.opt, is_train=True)
23
+
24
+ # mkdir and loggers
25
+ make_exp_dirs(opt)
26
+ log_file = osp.join(opt['path']['log'], f"train_{opt['name']}.log")
27
+ logger = get_root_logger(
28
+ logger_name='base', log_level=logging.INFO, log_file=log_file)
29
+ logger.info(dict2str(opt))
30
+ # initialize tensorboard logger
31
+ tb_logger = None
32
+ if opt['use_tb_logger'] and 'debug' not in opt['name']:
33
+ tb_logger = init_tb_logger(log_dir='./tb_logger/' + opt['name'])
34
+
35
+ # convert to NoneDict, which returns None for missing keys
36
+ opt = dict_to_nonedict(opt)
37
+
38
+ # set up data loader
39
+ train_dataset = DeepFashionAttrSegmDataset(
40
+ img_dir=opt['train_img_dir'],
41
+ segm_dir=opt['segm_dir'],
42
+ pose_dir=opt['pose_dir'],
43
+ ann_dir=opt['train_ann_file'],
44
+ xflip=True)
45
+ train_loader = torch.utils.data.DataLoader(
46
+ dataset=train_dataset,
47
+ batch_size=opt['batch_size'],
48
+ shuffle=True,
49
+ num_workers=opt['num_workers'],
50
+ persistent_workers=True,
51
+ drop_last=True)
52
+ logger.info(f'Number of train set: {len(train_dataset)}.')
53
+ opt['max_iters'] = opt['num_epochs'] * len(
54
+ train_dataset) // opt['batch_size']
55
+
56
+ val_dataset = DeepFashionAttrSegmDataset(
57
+ img_dir=opt['train_img_dir'],
58
+ segm_dir=opt['segm_dir'],
59
+ pose_dir=opt['pose_dir'],
60
+ ann_dir=opt['val_ann_file'])
61
+ val_loader = torch.utils.data.DataLoader(
62
+ dataset=val_dataset, batch_size=opt['batch_size'], shuffle=False)
63
+ logger.info(f'Number of val set: {len(val_dataset)}.')
64
+
65
+ test_dataset = DeepFashionAttrSegmDataset(
66
+ img_dir=opt['test_img_dir'],
67
+ segm_dir=opt['segm_dir'],
68
+ pose_dir=opt['pose_dir'],
69
+ ann_dir=opt['test_ann_file'])
70
+ test_loader = torch.utils.data.DataLoader(
71
+ dataset=test_dataset, batch_size=opt['batch_size'], shuffle=False)
72
+ logger.info(f'Number of test set: {len(test_dataset)}.')
73
+
74
+ current_iter = 0
75
+
76
+ model = create_model(opt)
77
+
78
+ data_time, iter_time = 0, 0
79
+ current_iter = 0
80
+
81
+ # create message logger (formatted outputs)
82
+ msg_logger = MessageLogger(opt, current_iter, tb_logger)
83
+
84
+ for epoch in range(opt['num_epochs']):
85
+ lr = model.update_learning_rate(epoch, current_iter)
86
+
87
+ for _, batch_data in enumerate(train_loader):
88
+ data_time = time.time() - data_time
89
+
90
+ current_iter += 1
91
+
92
+ model.feed_data(batch_data)
93
+ model.optimize_parameters()
94
+
95
+ iter_time = time.time() - iter_time
96
+ if current_iter % opt['print_freq'] == 0:
97
+ log_vars = {'epoch': epoch, 'iter': current_iter}
98
+ log_vars.update({'lrs': [lr]})
99
+ log_vars.update({'time': iter_time, 'data_time': data_time})
100
+ log_vars.update(model.get_current_log())
101
+ msg_logger(log_vars)
102
+
103
+ data_time = time.time()
104
+ iter_time = time.time()
105
+
106
+ if epoch % opt['val_freq'] == 0 and epoch != 0:
107
+ save_dir = f'{opt["path"]["visualization"]}/valset/epoch_{epoch:03d}' # noqa
108
+ os.makedirs(save_dir, exist_ok=opt['debug'])
109
+ model.inference(val_loader, save_dir)
110
+
111
+ save_dir = f'{opt["path"]["visualization"]}/testset/epoch_{epoch:03d}' # noqa
112
+ os.makedirs(save_dir, exist_ok=opt['debug'])
113
+ model.inference(test_loader, save_dir)
114
+
115
+ # save model
116
+ model.save_network(
117
+ model._denoise_fn,
118
+ f'{opt["path"]["models"]}/sampler_epoch{epoch}.pth')
119
+
120
+
121
+ if __name__ == '__main__':
122
+ main()
Text2Human/train_vqvae.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import logging
3
+ import os
4
+ import os.path as osp
5
+ import random
6
+ import time
7
+
8
+ import torch
9
+
10
+ from data.segm_attr_dataset import DeepFashionAttrSegmDataset
11
+ from models import create_model
12
+ from utils.logger import MessageLogger, get_root_logger, init_tb_logger
13
+ from utils.options import dict2str, dict_to_nonedict, parse
14
+ from utils.util import make_exp_dirs
15
+
16
+
17
+ def main():
18
+ # options
19
+ parser = argparse.ArgumentParser()
20
+ parser.add_argument('-opt', type=str, help='Path to option YAML file.')
21
+ args = parser.parse_args()
22
+ opt = parse(args.opt, is_train=True)
23
+
24
+ # mkdir and loggers
25
+ make_exp_dirs(opt)
26
+ log_file = osp.join(opt['path']['log'], f"train_{opt['name']}.log")
27
+ logger = get_root_logger(
28
+ logger_name='base', log_level=logging.INFO, log_file=log_file)
29
+ logger.info(dict2str(opt))
30
+ # initialize tensorboard logger
31
+ tb_logger = None
32
+ if opt['use_tb_logger'] and 'debug' not in opt['name']:
33
+ tb_logger = init_tb_logger(log_dir='./tb_logger/' + opt['name'])
34
+
35
+ # convert to NoneDict, which returns None for missing keys
36
+ opt = dict_to_nonedict(opt)
37
+
38
+ # set up data loader
39
+ train_dataset = DeepFashionAttrSegmDataset(
40
+ img_dir=opt['train_img_dir'],
41
+ segm_dir=opt['segm_dir'],
42
+ pose_dir=opt['pose_dir'],
43
+ ann_dir=opt['train_ann_file'],
44
+ xflip=True)
45
+ train_loader = torch.utils.data.DataLoader(
46
+ dataset=train_dataset,
47
+ batch_size=opt['batch_size'],
48
+ shuffle=True,
49
+ num_workers=opt['num_workers'],
50
+ persistent_workers=True,
51
+ drop_last=True)
52
+ logger.info(f'Number of train set: {len(train_dataset)}.')
53
+ opt['max_iters'] = opt['num_epochs'] * len(
54
+ train_dataset) // opt['batch_size']
55
+
56
+ val_dataset = DeepFashionAttrSegmDataset(
57
+ img_dir=opt['train_img_dir'],
58
+ segm_dir=opt['segm_dir'],
59
+ pose_dir=opt['pose_dir'],
60
+ ann_dir=opt['val_ann_file'])
61
+ val_loader = torch.utils.data.DataLoader(
62
+ dataset=val_dataset, batch_size=1, shuffle=False)
63
+ logger.info(f'Number of val set: {len(val_dataset)}.')
64
+
65
+ test_dataset = DeepFashionAttrSegmDataset(
66
+ img_dir=opt['test_img_dir'],
67
+ segm_dir=opt['segm_dir'],
68
+ pose_dir=opt['pose_dir'],
69
+ ann_dir=opt['test_ann_file'])
70
+ test_loader = torch.utils.data.DataLoader(
71
+ dataset=test_dataset, batch_size=1, shuffle=False)
72
+ logger.info(f'Number of test set: {len(test_dataset)}.')
73
+
74
+ current_iter = 0
75
+ best_epoch = None
76
+ best_loss = 100000
77
+
78
+ model = create_model(opt)
79
+
80
+ data_time, iter_time = 0, 0
81
+ current_iter = 0
82
+
83
+ # create message logger (formatted outputs)
84
+ msg_logger = MessageLogger(opt, current_iter, tb_logger)
85
+
86
+ for epoch in range(opt['num_epochs']):
87
+ lr = model.update_learning_rate(epoch)
88
+
89
+ for _, batch_data in enumerate(train_loader):
90
+ data_time = time.time() - data_time
91
+
92
+ current_iter += 1
93
+
94
+ model.optimize_parameters(batch_data, current_iter)
95
+
96
+ iter_time = time.time() - iter_time
97
+ if current_iter % opt['print_freq'] == 0:
98
+ log_vars = {'epoch': epoch, 'iter': current_iter}
99
+ log_vars.update({'lrs': [lr]})
100
+ log_vars.update({'time': iter_time, 'data_time': data_time})
101
+ log_vars.update(model.get_current_log())
102
+ msg_logger(log_vars)
103
+
104
+ data_time = time.time()
105
+ iter_time = time.time()
106
+
107
+ if epoch % opt['val_freq'] == 0:
108
+ save_dir = f'{opt["path"]["visualization"]}/valset/epoch_{epoch:03d}' # noqa
109
+ os.makedirs(save_dir, exist_ok=opt['debug'])
110
+ val_loss_total = model.inference(val_loader, save_dir)
111
+
112
+ save_dir = f'{opt["path"]["visualization"]}/testset/epoch_{epoch:03d}' # noqa
113
+ os.makedirs(save_dir, exist_ok=opt['debug'])
114
+ test_loss_total = model.inference(test_loader, save_dir)
115
+
116
+ logger.info(f'Epoch: {epoch}, '
117
+ f'val_loss_total: {val_loss_total}, '
118
+ f'test_loss_total: {test_loss_total}.')
119
+
120
+ if test_loss_total < best_loss:
121
+ best_epoch = epoch
122
+ best_loss = test_loss_total
123
+
124
+ logger.info(f'Best epoch: {best_epoch}, '
125
+ f'Best test loss: {best_loss: .4f}.')
126
+
127
+ # save model
128
+ model.save_network(f'{opt["path"]["models"]}/epoch{epoch}.pth')
129
+
130
+
131
+ if __name__ == '__main__':
132
+ main()
Text2Human/ui/__init__.py ADDED
File without changes
Text2Human/ui/mouse_event.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import numpy as np
4
+ from PyQt5.QtCore import *
5
+ from PyQt5.QtGui import *
6
+ from PyQt5.QtWidgets import *
7
+
8
+ color_list = [
9
+ QColor(0, 0, 0),
10
+ QColor(255, 250, 250),
11
+ QColor(220, 220, 220),
12
+ QColor(250, 235, 215),
13
+ QColor(255, 250, 205),
14
+ QColor(211, 211, 211),
15
+ QColor(70, 130, 180),
16
+ QColor(127, 255, 212),
17
+ QColor(0, 100, 0),
18
+ QColor(50, 205, 50),
19
+ QColor(255, 255, 0),
20
+ QColor(245, 222, 179),
21
+ QColor(255, 140, 0),
22
+ QColor(255, 0, 0),
23
+ QColor(16, 78, 139),
24
+ QColor(144, 238, 144),
25
+ QColor(50, 205, 174),
26
+ QColor(50, 155, 250),
27
+ QColor(160, 140, 88),
28
+ QColor(213, 140, 88),
29
+ QColor(90, 140, 90),
30
+ QColor(185, 210, 205),
31
+ QColor(130, 165, 180),
32
+ QColor(225, 141, 151)
33
+ ]
34
+
35
+
36
+ class GraphicsScene(QGraphicsScene):
37
+
38
+ def __init__(self, mode, size, parent=None):
39
+ QGraphicsScene.__init__(self, parent)
40
+ self.mode = mode
41
+ self.size = size
42
+ self.mouse_clicked = False
43
+ self.prev_pt = None
44
+
45
+ # self.masked_image = None
46
+
47
+ # save the points
48
+ self.mask_points = []
49
+ for i in range(len(color_list)):
50
+ self.mask_points.append([])
51
+
52
+ # save the size of points
53
+ self.size_points = []
54
+ for i in range(len(color_list)):
55
+ self.size_points.append([])
56
+
57
+ # save the history of edit
58
+ self.history = []
59
+
60
+ def reset(self):
61
+ # save the points
62
+ self.mask_points = []
63
+ for i in range(len(color_list)):
64
+ self.mask_points.append([])
65
+ # save the size of points
66
+ self.size_points = []
67
+ for i in range(len(color_list)):
68
+ self.size_points.append([])
69
+ # save the history of edit
70
+ self.history = []
71
+
72
+ self.mode = 0
73
+ self.prev_pt = None
74
+
75
+ def mousePressEvent(self, event):
76
+ self.mouse_clicked = True
77
+
78
+ def mouseReleaseEvent(self, event):
79
+ self.prev_pt = None
80
+ self.mouse_clicked = False
81
+
82
+ def mouseMoveEvent(self, event): # drawing
83
+ if self.mouse_clicked:
84
+ if self.prev_pt:
85
+ self.drawMask(self.prev_pt, event.scenePos(),
86
+ color_list[self.mode], self.size)
87
+ pts = {}
88
+ pts['prev'] = (int(self.prev_pt.x()), int(self.prev_pt.y()))
89
+ pts['curr'] = (int(event.scenePos().x()),
90
+ int(event.scenePos().y()))
91
+
92
+ self.size_points[self.mode].append(self.size)
93
+ self.mask_points[self.mode].append(pts)
94
+ self.history.append(self.mode)
95
+ self.prev_pt = event.scenePos()
96
+ else:
97
+ self.prev_pt = event.scenePos()
98
+
99
+ def drawMask(self, prev_pt, curr_pt, color, size):
100
+ lineItem = QGraphicsLineItem(QLineF(prev_pt, curr_pt))
101
+ lineItem.setPen(QPen(color, size, Qt.SolidLine)) # rect
102
+ self.addItem(lineItem)
103
+
104
+ def erase_prev_pt(self):
105
+ self.prev_pt = None
106
+
107
+ def reset_items(self):
108
+ for i in range(len(self.items())):
109
+ item = self.items()[0]
110
+ self.removeItem(item)
111
+
112
+ def undo(self):
113
+ if len(self.items()) > 1:
114
+ if len(self.items()) >= 9:
115
+ for i in range(8):
116
+ item = self.items()[0]
117
+ self.removeItem(item)
118
+ if self.history[-1] == self.mode:
119
+ self.mask_points[self.mode].pop()
120
+ self.size_points[self.mode].pop()
121
+ self.history.pop()
122
+ else:
123
+ for i in range(len(self.items()) - 1):
124
+ item = self.items()[0]
125
+ self.removeItem(item)
126
+ if self.history[-1] == self.mode:
127
+ self.mask_points[self.mode].pop()
128
+ self.size_points[self.mode].pop()
129
+ self.history.pop()
Text2Human/ui/ui.py ADDED
@@ -0,0 +1,313 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PyQt5 import QtCore, QtGui, QtWidgets
2
+ from PyQt5.QtCore import *
3
+ from PyQt5.QtGui import *
4
+ from PyQt5.QtWidgets import *
5
+
6
+
7
+ class Ui_Form(object):
8
+
9
+ def setupUi(self, Form):
10
+ Form.setObjectName("Form")
11
+ Form.resize(1250, 670)
12
+
13
+ self.pushButton_2 = QtWidgets.QPushButton(Form)
14
+ self.pushButton_2.setGeometry(QtCore.QRect(20, 60, 97, 27))
15
+ self.pushButton_2.setObjectName("pushButton_2")
16
+
17
+ self.pushButton_6 = QtWidgets.QPushButton(Form)
18
+ self.pushButton_6.setGeometry(QtCore.QRect(20, 100, 97, 27))
19
+ self.pushButton_6.setObjectName("pushButton_6")
20
+
21
+ # Generate Parsing
22
+ self.pushButton_0 = QtWidgets.QPushButton(Form)
23
+ self.pushButton_0.setGeometry(QtCore.QRect(126, 60, 150, 27))
24
+ self.pushButton_0.setObjectName("pushButton_0")
25
+
26
+ # Generate Human
27
+ self.pushButton_1 = QtWidgets.QPushButton(Form)
28
+ self.pushButton_1.setGeometry(QtCore.QRect(126, 100, 150, 27))
29
+ self.pushButton_1.setObjectName("pushButton_1")
30
+
31
+ # shape text box
32
+ self.label_heading_1 = QtWidgets.QLabel(Form)
33
+ self.label_heading_1.setText('Describe the shape.')
34
+ self.label_heading_1.setObjectName("label_heading_1")
35
+ self.label_heading_1.setGeometry(QtCore.QRect(320, 20, 200, 20))
36
+
37
+ self.message_box_1 = QtWidgets.QLineEdit(Form)
38
+ self.message_box_1.setGeometry(QtCore.QRect(320, 50, 256, 80))
39
+ self.message_box_1.setObjectName("message_box_1")
40
+ self.message_box_1.setAlignment(Qt.AlignTop)
41
+
42
+ # texture text box
43
+ self.label_heading_2 = QtWidgets.QLabel(Form)
44
+ self.label_heading_2.setText('Describe the textures.')
45
+ self.label_heading_2.setObjectName("label_heading_2")
46
+ self.label_heading_2.setGeometry(QtCore.QRect(620, 20, 200, 20))
47
+
48
+ self.message_box_2 = QtWidgets.QLineEdit(Form)
49
+ self.message_box_2.setGeometry(QtCore.QRect(620, 50, 256, 80))
50
+ self.message_box_2.setObjectName("message_box_2")
51
+ self.message_box_2.setAlignment(Qt.AlignTop)
52
+
53
+ # title icon
54
+ self.title_icon = QtWidgets.QLabel(Form)
55
+ self.title_icon.setGeometry(QtCore.QRect(30, 10, 200, 50))
56
+ self.title_icon.setPixmap(
57
+ QtGui.QPixmap('./ui/icons/icon_title.png').scaledToWidth(200))
58
+
59
+ # palette icon
60
+ self.palette_icon = QtWidgets.QLabel(Form)
61
+ self.palette_icon.setGeometry(QtCore.QRect(950, 10, 256, 128))
62
+ self.palette_icon.setPixmap(
63
+ QtGui.QPixmap('./ui/icons/icon_palette.png').scaledToWidth(256))
64
+
65
+ # top
66
+ self.pushButton_8 = QtWidgets.QPushButton(' top', Form)
67
+ self.pushButton_8.setGeometry(QtCore.QRect(940, 120, 120, 27))
68
+ self.pushButton_8.setObjectName("pushButton_8")
69
+ self.pushButton_8.setStyleSheet(
70
+ "text-align: left; padding-left: 10px;")
71
+ self.pushButton_8.setIcon(QIcon('./ui/color_blocks/class_top.png'))
72
+ # skin
73
+ self.pushButton_9 = QtWidgets.QPushButton(' skin', Form)
74
+ self.pushButton_9.setGeometry(QtCore.QRect(940, 165, 120, 27))
75
+ self.pushButton_9.setObjectName("pushButton_9")
76
+ self.pushButton_9.setStyleSheet(
77
+ "text-align: left; padding-left: 10px;")
78
+ self.pushButton_9.setIcon(QIcon('./ui/color_blocks/class_skin.png'))
79
+ # outer
80
+ self.pushButton_10 = QtWidgets.QPushButton(' outer', Form)
81
+ self.pushButton_10.setGeometry(QtCore.QRect(940, 210, 120, 27))
82
+ self.pushButton_10.setObjectName("pushButton_10")
83
+ self.pushButton_10.setStyleSheet(
84
+ "text-align: left; padding-left: 10px;")
85
+ self.pushButton_10.setIcon(QIcon('./ui/color_blocks/class_outer.png'))
86
+ # face
87
+ self.pushButton_11 = QtWidgets.QPushButton(' face', Form)
88
+ self.pushButton_11.setGeometry(QtCore.QRect(940, 255, 120, 27))
89
+ self.pushButton_11.setObjectName("pushButton_11")
90
+ self.pushButton_11.setStyleSheet(
91
+ "text-align: left; padding-left: 10px;")
92
+ self.pushButton_11.setIcon(QIcon('./ui/color_blocks/class_face.png'))
93
+ # skirt
94
+ self.pushButton_12 = QtWidgets.QPushButton(' skirt', Form)
95
+ self.pushButton_12.setGeometry(QtCore.QRect(940, 300, 120, 27))
96
+ self.pushButton_12.setObjectName("pushButton_12")
97
+ self.pushButton_12.setStyleSheet(
98
+ "text-align: left; padding-left: 10px;")
99
+ self.pushButton_12.setIcon(QIcon('./ui/color_blocks/class_skirt.png'))
100
+ # hair
101
+ self.pushButton_13 = QtWidgets.QPushButton(' hair', Form)
102
+ self.pushButton_13.setGeometry(QtCore.QRect(940, 345, 120, 27))
103
+ self.pushButton_13.setObjectName("pushButton_13")
104
+ self.pushButton_13.setStyleSheet(
105
+ "text-align: left; padding-left: 10px;")
106
+ self.pushButton_13.setIcon(QIcon('./ui/color_blocks/class_hair.png'))
107
+ # dress
108
+ self.pushButton_14 = QtWidgets.QPushButton(' dress', Form)
109
+ self.pushButton_14.setGeometry(QtCore.QRect(940, 390, 120, 27))
110
+ self.pushButton_14.setObjectName("pushButton_14")
111
+ self.pushButton_14.setStyleSheet(
112
+ "text-align: left; padding-left: 10px;")
113
+ self.pushButton_14.setIcon(QIcon('./ui/color_blocks/class_dress.png'))
114
+ # headwear
115
+ self.pushButton_15 = QtWidgets.QPushButton(' headwear', Form)
116
+ self.pushButton_15.setGeometry(QtCore.QRect(940, 435, 120, 27))
117
+ self.pushButton_15.setObjectName("pushButton_15")
118
+ self.pushButton_15.setStyleSheet(
119
+ "text-align: left; padding-left: 10px;")
120
+ self.pushButton_15.setIcon(
121
+ QIcon('./ui/color_blocks/class_headwear.png'))
122
+ # pants
123
+ self.pushButton_16 = QtWidgets.QPushButton(' pants', Form)
124
+ self.pushButton_16.setGeometry(QtCore.QRect(940, 480, 120, 27))
125
+ self.pushButton_16.setObjectName("pushButton_16")
126
+ self.pushButton_16.setStyleSheet(
127
+ "text-align: left; padding-left: 10px;")
128
+ self.pushButton_16.setIcon(QIcon('./ui/color_blocks/class_pants.png'))
129
+ # eyeglasses
130
+ self.pushButton_17 = QtWidgets.QPushButton(' eyeglass', Form)
131
+ self.pushButton_17.setGeometry(QtCore.QRect(940, 525, 120, 27))
132
+ self.pushButton_17.setObjectName("pushButton_17")
133
+ self.pushButton_17.setStyleSheet(
134
+ "text-align: left; padding-left: 10px;")
135
+ self.pushButton_17.setIcon(
136
+ QIcon('./ui/color_blocks/class_eyeglass.png'))
137
+ # rompers
138
+ self.pushButton_18 = QtWidgets.QPushButton(' rompers', Form)
139
+ self.pushButton_18.setGeometry(QtCore.QRect(940, 570, 120, 27))
140
+ self.pushButton_18.setObjectName("pushButton_18")
141
+ self.pushButton_18.setStyleSheet(
142
+ "text-align: left; padding-left: 10px;")
143
+ self.pushButton_18.setIcon(
144
+ QIcon('./ui/color_blocks/class_rompers.png'))
145
+ # footwear
146
+ self.pushButton_19 = QtWidgets.QPushButton(' footwear', Form)
147
+ self.pushButton_19.setGeometry(QtCore.QRect(940, 615, 120, 27))
148
+ self.pushButton_19.setObjectName("pushButton_19")
149
+ self.pushButton_19.setStyleSheet(
150
+ "text-align: left; padding-left: 10px;")
151
+ self.pushButton_19.setIcon(
152
+ QIcon('./ui/color_blocks/class_footwear.png'))
153
+
154
+ # leggings
155
+ self.pushButton_20 = QtWidgets.QPushButton(' leggings', Form)
156
+ self.pushButton_20.setGeometry(QtCore.QRect(1100, 120, 120, 27))
157
+ self.pushButton_20.setObjectName("pushButton_10")
158
+ self.pushButton_20.setStyleSheet(
159
+ "text-align: left; padding-left: 10px;")
160
+ self.pushButton_20.setIcon(
161
+ QIcon('./ui/color_blocks/class_leggings.png'))
162
+
163
+ # ring
164
+ self.pushButton_21 = QtWidgets.QPushButton(' ring', Form)
165
+ self.pushButton_21.setGeometry(QtCore.QRect(1100, 165, 120, 27))
166
+ self.pushButton_21.setObjectName("pushButton_2`0`")
167
+ self.pushButton_21.setStyleSheet(
168
+ "text-align: left; padding-left: 10px;")
169
+ self.pushButton_21.setIcon(QIcon('./ui/color_blocks/class_ring.png'))
170
+
171
+ # belt
172
+ self.pushButton_22 = QtWidgets.QPushButton(' belt', Form)
173
+ self.pushButton_22.setGeometry(QtCore.QRect(1100, 210, 120, 27))
174
+ self.pushButton_22.setObjectName("pushButton_2`0`")
175
+ self.pushButton_22.setStyleSheet(
176
+ "text-align: left; padding-left: 10px;")
177
+ self.pushButton_22.setIcon(QIcon('./ui/color_blocks/class_belt.png'))
178
+
179
+ # neckwear
180
+ self.pushButton_23 = QtWidgets.QPushButton(' neckwear', Form)
181
+ self.pushButton_23.setGeometry(QtCore.QRect(1100, 255, 120, 27))
182
+ self.pushButton_23.setObjectName("pushButton_2`0`")
183
+ self.pushButton_23.setStyleSheet(
184
+ "text-align: left; padding-left: 10px;")
185
+ self.pushButton_23.setIcon(
186
+ QIcon('./ui/color_blocks/class_neckwear.png'))
187
+
188
+ # wrist
189
+ self.pushButton_24 = QtWidgets.QPushButton(' wrist', Form)
190
+ self.pushButton_24.setGeometry(QtCore.QRect(1100, 300, 120, 27))
191
+ self.pushButton_24.setObjectName("pushButton_2`0`")
192
+ self.pushButton_24.setStyleSheet(
193
+ "text-align: left; padding-left: 10px;")
194
+ self.pushButton_24.setIcon(QIcon('./ui/color_blocks/class_wrist.png'))
195
+
196
+ # socks
197
+ self.pushButton_25 = QtWidgets.QPushButton(' socks', Form)
198
+ self.pushButton_25.setGeometry(QtCore.QRect(1100, 345, 120, 27))
199
+ self.pushButton_25.setObjectName("pushButton_2`0`")
200
+ self.pushButton_25.setStyleSheet(
201
+ "text-align: left; padding-left: 10px;")
202
+ self.pushButton_25.setIcon(QIcon('./ui/color_blocks/class_socks.png'))
203
+
204
+ # tie
205
+ self.pushButton_26 = QtWidgets.QPushButton(' tie', Form)
206
+ self.pushButton_26.setGeometry(QtCore.QRect(1100, 390, 120, 27))
207
+ self.pushButton_26.setObjectName("pushButton_2`0`")
208
+ self.pushButton_26.setStyleSheet(
209
+ "text-align: left; padding-left: 10px;")
210
+ self.pushButton_26.setIcon(QIcon('./ui/color_blocks/class_tie.png'))
211
+
212
+ # earstuds
213
+ self.pushButton_27 = QtWidgets.QPushButton(' necklace', Form)
214
+ self.pushButton_27.setGeometry(QtCore.QRect(1100, 435, 120, 27))
215
+ self.pushButton_27.setObjectName("pushButton_2`0`")
216
+ self.pushButton_27.setStyleSheet(
217
+ "text-align: left; padding-left: 10px;")
218
+ self.pushButton_27.setIcon(
219
+ QIcon('./ui/color_blocks/class_necklace.png'))
220
+
221
+ # necklace
222
+ self.pushButton_28 = QtWidgets.QPushButton(' earstuds', Form)
223
+ self.pushButton_28.setGeometry(QtCore.QRect(1100, 480, 120, 27))
224
+ self.pushButton_28.setObjectName("pushButton_2`0`")
225
+ self.pushButton_28.setStyleSheet(
226
+ "text-align: left; padding-left: 10px;")
227
+ self.pushButton_28.setIcon(
228
+ QIcon('./ui/color_blocks/class_earstuds.png'))
229
+
230
+ # bag
231
+ self.pushButton_29 = QtWidgets.QPushButton(' bag', Form)
232
+ self.pushButton_29.setGeometry(QtCore.QRect(1100, 525, 120, 27))
233
+ self.pushButton_29.setObjectName("pushButton_2`0`")
234
+ self.pushButton_29.setStyleSheet(
235
+ "text-align: left; padding-left: 10px;")
236
+ self.pushButton_29.setIcon(QIcon('./ui/color_blocks/class_bag.png'))
237
+
238
+ # glove
239
+ self.pushButton_30 = QtWidgets.QPushButton(' glove', Form)
240
+ self.pushButton_30.setGeometry(QtCore.QRect(1100, 570, 120, 27))
241
+ self.pushButton_30.setObjectName("pushButton_2`0`")
242
+ self.pushButton_30.setStyleSheet(
243
+ "text-align: left; padding-left: 10px;")
244
+ self.pushButton_30.setIcon(QIcon('./ui/color_blocks/class_glove.png'))
245
+
246
+ # background
247
+ self.pushButton_31 = QtWidgets.QPushButton(' background', Form)
248
+ self.pushButton_31.setGeometry(QtCore.QRect(1100, 615, 120, 27))
249
+ self.pushButton_31.setObjectName("pushButton_2`0`")
250
+ self.pushButton_31.setStyleSheet(
251
+ "text-align: left; padding-left: 10px;")
252
+ self.pushButton_31.setIcon(QIcon('./ui/color_blocks/class_bg.png'))
253
+
254
+ self.graphicsView = QtWidgets.QGraphicsView(Form)
255
+ self.graphicsView.setGeometry(QtCore.QRect(20, 140, 256, 512))
256
+ self.graphicsView.setObjectName("graphicsView")
257
+ self.graphicsView_2 = QtWidgets.QGraphicsView(Form)
258
+ self.graphicsView_2.setGeometry(QtCore.QRect(320, 140, 256, 512))
259
+ self.graphicsView_2.setObjectName("graphicsView_2")
260
+ self.graphicsView_3 = QtWidgets.QGraphicsView(Form)
261
+ self.graphicsView_3.setGeometry(QtCore.QRect(620, 140, 256, 512))
262
+ self.graphicsView_3.setObjectName("graphicsView_3")
263
+
264
+ self.retranslateUi(Form)
265
+ self.pushButton_2.clicked.connect(Form.open_densepose)
266
+ self.pushButton_6.clicked.connect(Form.save_img)
267
+ self.pushButton_8.clicked.connect(Form.top_mode)
268
+ self.pushButton_9.clicked.connect(Form.skin_mode)
269
+ self.pushButton_10.clicked.connect(Form.outer_mode)
270
+ self.pushButton_11.clicked.connect(Form.face_mode)
271
+ self.pushButton_12.clicked.connect(Form.skirt_mode)
272
+ self.pushButton_13.clicked.connect(Form.hair_mode)
273
+ self.pushButton_14.clicked.connect(Form.dress_mode)
274
+ self.pushButton_15.clicked.connect(Form.headwear_mode)
275
+ self.pushButton_16.clicked.connect(Form.pants_mode)
276
+ self.pushButton_17.clicked.connect(Form.eyeglass_mode)
277
+ self.pushButton_18.clicked.connect(Form.rompers_mode)
278
+ self.pushButton_19.clicked.connect(Form.footwear_mode)
279
+ self.pushButton_20.clicked.connect(Form.leggings_mode)
280
+ self.pushButton_21.clicked.connect(Form.ring_mode)
281
+ self.pushButton_22.clicked.connect(Form.belt_mode)
282
+ self.pushButton_23.clicked.connect(Form.neckwear_mode)
283
+ self.pushButton_24.clicked.connect(Form.wrist_mode)
284
+ self.pushButton_25.clicked.connect(Form.socks_mode)
285
+ self.pushButton_26.clicked.connect(Form.tie_mode)
286
+ self.pushButton_27.clicked.connect(Form.earstuds_mode)
287
+ self.pushButton_28.clicked.connect(Form.necklace_mode)
288
+ self.pushButton_29.clicked.connect(Form.bag_mode)
289
+ self.pushButton_30.clicked.connect(Form.glove_mode)
290
+ self.pushButton_31.clicked.connect(Form.background_mode)
291
+ self.pushButton_0.clicked.connect(Form.generate_parsing)
292
+ self.pushButton_1.clicked.connect(Form.generate_human)
293
+
294
+ QtCore.QMetaObject.connectSlotsByName(Form)
295
+
296
+ def retranslateUi(self, Form):
297
+ _translate = QtCore.QCoreApplication.translate
298
+ Form.setWindowTitle(_translate("Form", "Text2Human"))
299
+ self.pushButton_2.setText(_translate("Form", "Load Pose"))
300
+ self.pushButton_6.setText(_translate("Form", "Save Image"))
301
+
302
+ self.pushButton_0.setText(_translate("Form", "Generate Parsing"))
303
+ self.pushButton_1.setText(_translate("Form", "Generate Human"))
304
+
305
+
306
+ if __name__ == "__main__":
307
+ import sys
308
+ app = QtWidgets.QApplication(sys.argv)
309
+ Form = QtWidgets.QWidget()
310
+ ui = Ui_Form()
311
+ ui.setupUi(Form)
312
+ Form.show()
313
+ sys.exit(app.exec_())
Text2Human/ui_demo.py ADDED
@@ -0,0 +1,285 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+
3
+ import cv2
4
+ import numpy as np
5
+ import torch
6
+ from PIL import Image
7
+ from PyQt5.QtCore import *
8
+ from PyQt5.QtGui import *
9
+ from PyQt5.QtWidgets import *
10
+
11
+ from models.sample_model import SampleFromPoseModel
12
+ from ui.mouse_event import GraphicsScene
13
+ from ui.ui import Ui_Form
14
+ from utils.language_utils import (generate_shape_attributes,
15
+ generate_texture_attributes)
16
+ from utils.options import dict_to_nonedict, parse
17
+
18
+ color_list = [(0, 0, 0), (255, 250, 250), (220, 220, 220), (250, 235, 215),
19
+ (255, 250, 205), (211, 211, 211), (70, 130, 180),
20
+ (127, 255, 212), (0, 100, 0), (50, 205, 50), (255, 255, 0),
21
+ (245, 222, 179), (255, 140, 0), (255, 0, 0), (16, 78, 139),
22
+ (144, 238, 144), (50, 205, 174), (50, 155, 250), (160, 140, 88),
23
+ (213, 140, 88), (90, 140, 90), (185, 210, 205), (130, 165, 180),
24
+ (225, 141, 151)]
25
+
26
+
27
+ class Ex(QWidget, Ui_Form):
28
+
29
+ def __init__(self, opt):
30
+ super(Ex, self).__init__()
31
+ self.setupUi(self)
32
+ self.show()
33
+
34
+ self.output_img = None
35
+
36
+ self.mat_img = None
37
+
38
+ self.mode = 0
39
+ self.size = 6
40
+ self.mask = None
41
+ self.mask_m = None
42
+ self.img = None
43
+
44
+ # about UI
45
+ self.mouse_clicked = False
46
+ self.scene = QGraphicsScene()
47
+ self.graphicsView.setScene(self.scene)
48
+ self.graphicsView.setAlignment(Qt.AlignTop | Qt.AlignLeft)
49
+ self.graphicsView.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOff)
50
+ self.graphicsView.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff)
51
+
52
+ self.ref_scene = GraphicsScene(self.mode, self.size)
53
+ self.graphicsView_2.setScene(self.ref_scene)
54
+ self.graphicsView_2.setAlignment(Qt.AlignTop | Qt.AlignLeft)
55
+ self.graphicsView_2.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOff)
56
+ self.graphicsView_2.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff)
57
+
58
+ self.result_scene = QGraphicsScene()
59
+ self.graphicsView_3.setScene(self.result_scene)
60
+ self.graphicsView_3.setAlignment(Qt.AlignTop | Qt.AlignLeft)
61
+ self.graphicsView_3.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOff)
62
+ self.graphicsView_3.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff)
63
+
64
+ self.dlg = QColorDialog(self.graphicsView)
65
+ self.color = None
66
+
67
+ self.sample_model = SampleFromPoseModel(opt)
68
+
69
+ def open_densepose(self):
70
+ fileName, _ = QFileDialog.getOpenFileName(self, "Open File",
71
+ QDir.currentPath())
72
+ if fileName:
73
+ image = QPixmap(fileName)
74
+ mat_img = Image.open(fileName)
75
+ self.pose_img = mat_img.copy()
76
+ if image.isNull():
77
+ QMessageBox.information(self, "Image Viewer",
78
+ "Cannot load %s." % fileName)
79
+ return
80
+ image = image.scaled(self.graphicsView.size(),
81
+ Qt.IgnoreAspectRatio)
82
+
83
+ if len(self.scene.items()) > 0:
84
+ self.scene.removeItem(self.scene.items()[-1])
85
+ self.scene.addPixmap(image)
86
+
87
+ self.ref_scene.clear()
88
+ self.result_scene.clear()
89
+
90
+ # load pose to model
91
+ self.pose_img = np.array(
92
+ self.pose_img.resize(
93
+ size=(256, 512),
94
+ resample=Image.LANCZOS))[:, :, 2:].transpose(
95
+ 2, 0, 1).astype(np.float32)
96
+ self.pose_img = self.pose_img / 12. - 1
97
+
98
+ self.pose_img = torch.from_numpy(self.pose_img).unsqueeze(1)
99
+
100
+ self.sample_model.feed_pose_data(self.pose_img)
101
+
102
+ def generate_parsing(self):
103
+ self.ref_scene.reset_items()
104
+ self.ref_scene.reset()
105
+
106
+ shape_texts = self.message_box_1.text()
107
+
108
+ shape_attributes = generate_shape_attributes(shape_texts)
109
+ shape_attributes = torch.LongTensor(shape_attributes).unsqueeze(0)
110
+ self.sample_model.feed_shape_attributes(shape_attributes)
111
+
112
+ self.sample_model.generate_parsing_map()
113
+ self.sample_model.generate_quantized_segm()
114
+
115
+ self.colored_segm = self.sample_model.palette_result(
116
+ self.sample_model.segm[0].cpu())
117
+
118
+ self.mask_m = cv2.cvtColor(
119
+ cv2.cvtColor(self.colored_segm, cv2.COLOR_RGB2BGR),
120
+ cv2.COLOR_BGR2RGB)
121
+
122
+ qim = QImage(self.colored_segm.data.tobytes(),
123
+ self.colored_segm.shape[1], self.colored_segm.shape[0],
124
+ QImage.Format_RGB888)
125
+
126
+ image = QPixmap.fromImage(qim)
127
+
128
+ image = image.scaled(self.graphicsView.size(), Qt.IgnoreAspectRatio)
129
+
130
+ if len(self.ref_scene.items()) > 0:
131
+ self.ref_scene.removeItem(self.ref_scene.items()[-1])
132
+ self.ref_scene.addPixmap(image)
133
+
134
+ self.result_scene.clear()
135
+
136
+ def generate_human(self):
137
+ for i in range(24):
138
+ self.mask_m = self.make_mask(self.mask_m,
139
+ self.ref_scene.mask_points[i],
140
+ self.ref_scene.size_points[i],
141
+ color_list[i])
142
+
143
+ seg_map = np.full(self.mask_m.shape[:-1], -1)
144
+
145
+ # convert rgb to num
146
+ for index, color in enumerate(color_list):
147
+ seg_map[np.sum(self.mask_m == color, axis=2) == 3] = index
148
+ assert (seg_map != -1).all()
149
+
150
+ self.sample_model.segm = torch.from_numpy(seg_map).unsqueeze(
151
+ 0).unsqueeze(0).to(self.sample_model.device)
152
+ self.sample_model.generate_quantized_segm()
153
+
154
+ texture_texts = self.message_box_2.text()
155
+ texture_attributes = generate_texture_attributes(texture_texts)
156
+
157
+ texture_attributes = torch.LongTensor(texture_attributes)
158
+
159
+ self.sample_model.feed_texture_attributes(texture_attributes)
160
+
161
+ self.sample_model.generate_texture_map()
162
+ result = self.sample_model.sample_and_refine()
163
+ result = result.permute(0, 2, 3, 1)
164
+ result = result.detach().cpu().numpy()
165
+ result = result * 255
166
+
167
+ result = np.asarray(result[0, :, :, :], dtype=np.uint8)
168
+
169
+ self.output_img = result
170
+
171
+ qim = QImage(result.data.tobytes(), result.shape[1], result.shape[0],
172
+ QImage.Format_RGB888)
173
+ image = QPixmap.fromImage(qim)
174
+
175
+ image = image.scaled(self.graphicsView.size(), Qt.IgnoreAspectRatio)
176
+
177
+ if len(self.result_scene.items()) > 0:
178
+ self.result_scene.removeItem(self.result_scene.items()[-1])
179
+ self.result_scene.addPixmap(image)
180
+
181
+ def top_mode(self):
182
+ self.ref_scene.mode = 1
183
+
184
+ def skin_mode(self):
185
+ self.ref_scene.mode = 15
186
+
187
+ def outer_mode(self):
188
+ self.ref_scene.mode = 2
189
+
190
+ def face_mode(self):
191
+ self.ref_scene.mode = 14
192
+
193
+ def skirt_mode(self):
194
+ self.ref_scene.mode = 3
195
+
196
+ def hair_mode(self):
197
+ self.ref_scene.mode = 13
198
+
199
+ def dress_mode(self):
200
+ self.ref_scene.mode = 4
201
+
202
+ def headwear_mode(self):
203
+ self.ref_scene.mode = 7
204
+
205
+ def pants_mode(self):
206
+ self.ref_scene.mode = 5
207
+
208
+ def eyeglass_mode(self):
209
+ self.ref_scene.mode = 8
210
+
211
+ def rompers_mode(self):
212
+ self.ref_scene.mode = 21
213
+
214
+ def footwear_mode(self):
215
+ self.ref_scene.mode = 11
216
+
217
+ def leggings_mode(self):
218
+ self.ref_scene.mode = 6
219
+
220
+ def ring_mode(self):
221
+ self.ref_scene.mode = 16
222
+
223
+ def belt_mode(self):
224
+ self.ref_scene.mode = 10
225
+
226
+ def neckwear_mode(self):
227
+ self.ref_scene.mode = 9
228
+
229
+ def wrist_mode(self):
230
+ self.ref_scene.mode = 17
231
+
232
+ def socks_mode(self):
233
+ self.ref_scene.mode = 18
234
+
235
+ def tie_mode(self):
236
+ self.ref_scene.mode = 23
237
+
238
+ def earstuds_mode(self):
239
+ self.ref_scene.mode = 22
240
+
241
+ def necklace_mode(self):
242
+ self.ref_scene.mode = 20
243
+
244
+ def bag_mode(self):
245
+ self.ref_scene.mode = 12
246
+
247
+ def glove_mode(self):
248
+ self.ref_scene.mode = 19
249
+
250
+ def background_mode(self):
251
+ self.ref_scene.mode = 0
252
+
253
+ def make_mask(self, mask, pts, sizes, color):
254
+ if len(pts) > 0:
255
+ for idx, pt in enumerate(pts):
256
+ cv2.line(mask, pt['prev'], pt['curr'], color, sizes[idx])
257
+ return mask
258
+
259
+ def save_img(self):
260
+ if type(self.output_img):
261
+ fileName, _ = QFileDialog.getSaveFileName(self, "Save File",
262
+ QDir.currentPath())
263
+ cv2.imwrite(fileName + '.png', self.output_img[:, :, ::-1])
264
+
265
+ def undo(self):
266
+ self.scene.undo()
267
+
268
+ def clear(self):
269
+
270
+ self.ref_scene.reset_items()
271
+ self.ref_scene.reset()
272
+
273
+ self.ref_scene.clear()
274
+
275
+ self.result_scene.clear()
276
+
277
+
278
+ if __name__ == '__main__':
279
+
280
+ app = QApplication(sys.argv)
281
+ opt = './configs/sample_from_pose.yml'
282
+ opt = parse(opt, is_train=False)
283
+ opt = dict_to_nonedict(opt)
284
+ ex = Ex(opt)
285
+ sys.exit(app.exec_())
Text2Human/ui_util/__init__.py ADDED
File without changes
Text2Human/ui_util/config.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import logging
3
+ import os
4
+
5
+ import yaml
6
+
7
+ logger = logging.getLogger()
8
+
9
+ class Config(object):
10
+ def __init__(self, filename=None):
11
+ assert os.path.exists(filename), "ERROR: Config File doesn't exist."
12
+ try:
13
+ with open(filename, 'r') as f:
14
+ self._cfg_dict = yaml.load(f)
15
+ # parent of IOError, OSError *and* WindowsError where available
16
+ except EnvironmentError:
17
+ logger.error('Please check the file with name of "%s"', filename)
18
+ logger.info(' APP CONFIG '.center(80, '-'))
19
+ logger.info(''.center(80, '-'))
20
+
21
+ def __getattr__(self, name):
22
+ value = self._cfg_dict[name]
23
+ if isinstance(value, dict):
24
+ value = DictAsMember(value)
25
+ return value
Text2Human/utils/__init__.py ADDED
File without changes