Spaces:
Runtime error
Runtime error
yitianlian
commited on
Commit
·
24be7a2
1
Parent(s):
c96d8cf
update demo
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- Text2Human/.gitignore +9 -0
- Text2Human/LICENSE +21 -0
- Text2Human/README.md +255 -0
- Text2Human/configs/index_pred_net.yml +84 -0
- Text2Human/configs/parsing_gen.yml +40 -0
- Text2Human/configs/parsing_token.yml +47 -0
- Text2Human/configs/sample_from_parsing.yml +93 -0
- Text2Human/configs/sample_from_pose.yml +107 -0
- Text2Human/configs/sampler.yml +83 -0
- Text2Human/configs/vqvae_bottom.yml +72 -0
- Text2Human/configs/vqvae_top.yml +53 -0
- Text2Human/data/__init__.py +0 -0
- Text2Human/data/mask_dataset.py +59 -0
- Text2Human/data/parsing_generation_segm_attr_dataset.py +80 -0
- Text2Human/data/pose_attr_dataset.py +109 -0
- Text2Human/data/segm_attr_dataset.py +167 -0
- Text2Human/environment/text2human_env.yaml +114 -0
- Text2Human/models/__init__.py +42 -0
- Text2Human/models/archs/__init__.py +0 -0
- Text2Human/models/archs/fcn_arch.py +418 -0
- Text2Human/models/archs/shape_attr_embedding_arch.py +35 -0
- Text2Human/models/archs/transformer_arch.py +273 -0
- Text2Human/models/archs/unet_arch.py +693 -0
- Text2Human/models/archs/vqgan_arch.py +1203 -0
- Text2Human/models/hierarchy_inference_model.py +363 -0
- Text2Human/models/hierarchy_vqgan_model.py +374 -0
- Text2Human/models/losses/__init__.py +0 -0
- Text2Human/models/losses/accuracy.py +46 -0
- Text2Human/models/losses/cross_entropy_loss.py +246 -0
- Text2Human/models/losses/segmentation_loss.py +25 -0
- Text2Human/models/losses/vqgan_loss.py +114 -0
- Text2Human/models/parsing_gen_model.py +220 -0
- Text2Human/models/sample_model.py +500 -0
- Text2Human/models/transformer_model.py +482 -0
- Text2Human/models/vqgan_model.py +551 -0
- Text2Human/sample_from_parsing.py +53 -0
- Text2Human/sample_from_pose.py +52 -0
- Text2Human/train_index_prediction.py +133 -0
- Text2Human/train_parsing_gen.py +136 -0
- Text2Human/train_parsing_token.py +122 -0
- Text2Human/train_sampler.py +122 -0
- Text2Human/train_vqvae.py +132 -0
- Text2Human/ui/__init__.py +0 -0
- Text2Human/ui/mouse_event.py +129 -0
- Text2Human/ui/ui.py +313 -0
- Text2Human/ui_demo.py +285 -0
- Text2Human/ui_util/__init__.py +0 -0
- Text2Human/ui_util/config.py +25 -0
- Text2Human/utils/__init__.py +0 -0
- Text2Human/utils/language_utils.py +315 -0
Text2Human/.gitignore
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
__pycache__/
|
2 |
+
.cache/
|
3 |
+
datasets/*
|
4 |
+
experiments/*
|
5 |
+
tb_logger/*
|
6 |
+
results/*
|
7 |
+
*.png
|
8 |
+
*.txt
|
9 |
+
*.pth
|
Text2Human/LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2022 Yuming Jiang
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
Text2Human/README.md
ADDED
@@ -0,0 +1,255 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Text2Human - Official PyTorch Implementation
|
2 |
+
|
3 |
+
<!-- <img src="./doc_images/overview.jpg" width="96%" height="96%"> -->
|
4 |
+
|
5 |
+
This repository provides the official PyTorch implementation for the following paper:
|
6 |
+
|
7 |
+
**Text2Human: Text-Driven Controllable Human Image Generation**</br>
|
8 |
+
[Yuming Jiang](https://yumingj.github.io/), [Shuai Yang](https://williamyang1991.github.io/), [Haonan Qiu](http://haonanqiu.com/), [Wayne Wu](https://dblp.org/pid/50/8731.html), [Chen Change Loy](https://www.mmlab-ntu.com/person/ccloy/) and [Ziwei Liu](https://liuziwei7.github.io/)</br>
|
9 |
+
In ACM Transactions on Graphics (Proceedings of SIGGRAPH), 2022.
|
10 |
+
|
11 |
+
From [MMLab@NTU](https://www.mmlab-ntu.com/index.html) affliated with S-Lab, Nanyang Technological University and SenseTime Research.
|
12 |
+
|
13 |
+
<table>
|
14 |
+
<tr>
|
15 |
+
<td><img src="assets/1.png" width="100%"/></td>
|
16 |
+
<td><img src="assets/2.png" width="100%"/></td>
|
17 |
+
<td><img src="assets/3.png" width="100%"/></td>
|
18 |
+
<td><img src="assets/4.png" width="100%"/></td>
|
19 |
+
</tr>
|
20 |
+
<tr>
|
21 |
+
<td align='center' width='24%'>The lady wears a short-sleeve T-shirt with pure color pattern, and a short and denim skirt.</td>
|
22 |
+
<td align='center' width='24%'>The man wears a long and floral shirt, and long pants with the pure color pattern.</td>
|
23 |
+
<td align='center' width='24%'>A lady is wearing a sleeveless pure-color shirt and long jeans</td>
|
24 |
+
<td align='center' width='24%'>The man wears a short-sleeve T-shirt with the pure color pattern and a short pants with the pure color pattern.</td>
|
25 |
+
<tr>
|
26 |
+
</table>
|
27 |
+
|
28 |
+
[**[Project Page]**](https://yumingj.github.io/projects/Text2Human.html) | [**[Paper]**](https://arxiv.org/pdf/2205.15996.pdf) | [**[Dataset]**](https://github.com/yumingj/DeepFashion-MultiModal) | [**[Demo Video]**](https://youtu.be/yKh4VORA_E0)
|
29 |
+
|
30 |
+
|
31 |
+
## Updates
|
32 |
+
|
33 |
+
- [05/2022] Paper and demo video are released.
|
34 |
+
- [05/2022] Code is released.
|
35 |
+
- [05/2022] This website is created.
|
36 |
+
|
37 |
+
## Installation
|
38 |
+
**Clone this repo:**
|
39 |
+
```bash
|
40 |
+
git clone https://github.com/yumingj/Text2Human.git
|
41 |
+
cd Text2Human
|
42 |
+
```
|
43 |
+
**Dependencies:**
|
44 |
+
|
45 |
+
All dependencies for defining the environment are provided in `environment/text2human_env.yaml`.
|
46 |
+
We recommend using [Anaconda](https://docs.anaconda.com/anaconda/install/) to manage the python environment:
|
47 |
+
```bash
|
48 |
+
conda env create -f ./environment/text2human_env.yaml
|
49 |
+
conda activate text2human
|
50 |
+
conda install -c huggingface tokenizers=0.9.4
|
51 |
+
conda install -c huggingface transformers=4.0.0
|
52 |
+
conda install -c conda-forge sentence-transformers=2.0.0
|
53 |
+
```
|
54 |
+
|
55 |
+
If it doesn't work, you may need to install the following packages on your own:
|
56 |
+
- Python 3.6
|
57 |
+
- PyTorch 1.7.1
|
58 |
+
- CUDA 10.1
|
59 |
+
- [sentence-transformers](https://huggingface.co/sentence-transformers) 2.0.0
|
60 |
+
- [tokenizers](https://pypi.org/project/tokenizers/) 0.9.4
|
61 |
+
- [transformers](https://huggingface.co/docs/transformers/installation) 4.0.0
|
62 |
+
|
63 |
+
## (1) Dataset Preparation
|
64 |
+
|
65 |
+
In this work, we contribute a large-scale high-quality dataset with rich multi-modal annotations named [DeepFashion-MultiModal](https://github.com/yumingj/DeepFashion-MultiModal) Dataset.
|
66 |
+
Here we pre-processed the raw annotations of the original dataset for the task of text-driven controllable human image generation. The pre-processing pipeline consists of:
|
67 |
+
- align the human body in the center of the images according to the human pose
|
68 |
+
- fuse the clothing color and clothing fabric annotations into one texture annotation
|
69 |
+
- do some annotation cleaning and image filtering
|
70 |
+
- split the whole dataset into the training set and testing set
|
71 |
+
|
72 |
+
You can download our processed dataset from this [Google Drive](https://drive.google.com/file/d/1KIoFfRZNQVn6RV_wTxG2wZmY8f2T_84B/view?usp=sharing). If you want to access the raw annotations, please refer to the [DeepFashion-MultiModal](https://github.com/yumingj/DeepFashion-MultiModal) Dataset.
|
73 |
+
|
74 |
+
After downloading the dataset, unzip the file and put them under the dataset folder with the following structure:
|
75 |
+
```
|
76 |
+
./datasets
|
77 |
+
├── train_images
|
78 |
+
├── xxx.png
|
79 |
+
...
|
80 |
+
├── xxx.png
|
81 |
+
└── xxx.png
|
82 |
+
├── test_images
|
83 |
+
% the same structure as in train_images
|
84 |
+
├── densepose
|
85 |
+
% the same structure as in train_images
|
86 |
+
├── segm
|
87 |
+
% the same structure as in train_images
|
88 |
+
├── shape_ann
|
89 |
+
├── test_ann_file.txt
|
90 |
+
├── train_ann_file.txt
|
91 |
+
└── val_ann_file.txt
|
92 |
+
└── texture_ann
|
93 |
+
├── test
|
94 |
+
├── lower_fused.txt
|
95 |
+
├── outer_fused.txt
|
96 |
+
└── upper_fused.txt
|
97 |
+
├── train
|
98 |
+
% the same files as in test
|
99 |
+
└── val
|
100 |
+
% the same files as in test
|
101 |
+
```
|
102 |
+
|
103 |
+
## (2) Sampling
|
104 |
+
|
105 |
+
### Inference Notebook
|
106 |
+
<img src="https://colab.research.google.com/assets/colab-badge.svg" height=22.5></a></br>
|
107 |
+
Coming soon.
|
108 |
+
|
109 |
+
|
110 |
+
### Pretrained Models
|
111 |
+
|
112 |
+
Pretrained models can be downloaded from this [Google Drive](https://drive.google.com/file/d/1VyI8_AbPwAUaZJPaPba8zxsFIWumlDen/view?usp=sharing). Unzip the file and put them under the dataset folder with the following structure:
|
113 |
+
```
|
114 |
+
pretrained_models
|
115 |
+
├── index_pred_net.pth
|
116 |
+
├── parsing_gen.pth
|
117 |
+
├── parsing_token.pth
|
118 |
+
├── sampler.pth
|
119 |
+
├── vqvae_bottom.pth
|
120 |
+
└── vqvae_top.pth
|
121 |
+
```
|
122 |
+
|
123 |
+
### Generation from Paring Maps
|
124 |
+
You can generate images from given parsing maps and pre-defined texture annotations:
|
125 |
+
```python
|
126 |
+
python sample_from_parsing.py -opt ./configs/sample_from_parsing.yml
|
127 |
+
```
|
128 |
+
The results are saved in the folder `./results/sampling_from_parsing`.
|
129 |
+
|
130 |
+
### Generation from Poses
|
131 |
+
You can generate images from given human poses and pre-defined clothing shape and texture annotations:
|
132 |
+
```python
|
133 |
+
python sample_from_pose.py -opt ./configs/sample_from_pose.yml
|
134 |
+
```
|
135 |
+
|
136 |
+
**Remarks**: The above two scripts generate images without language interactions. If you want to generate images using texts, you can use the notebook or our user interface.
|
137 |
+
|
138 |
+
### User Interface
|
139 |
+
|
140 |
+
```python
|
141 |
+
python ui_demo.py
|
142 |
+
```
|
143 |
+
<img src="./assets/ui.png" width="100%">
|
144 |
+
|
145 |
+
The descriptions for shapes should follow the following format:
|
146 |
+
```
|
147 |
+
<gender>, <sleeve length>, <length of lower clothing>, <outer clothing type>, <other accessories1>, ...
|
148 |
+
|
149 |
+
Note: The outer clothing type and accessories can be omitted.
|
150 |
+
|
151 |
+
Examples:
|
152 |
+
man, sleeveless T-shirt, long pants
|
153 |
+
woman, short-sleeve T-shirt, short jeans
|
154 |
+
```
|
155 |
+
|
156 |
+
The descriptions for textures should follow the following format:
|
157 |
+
```
|
158 |
+
<upper clothing texture>, <lower clothing texture>, <outer clothing texture>
|
159 |
+
|
160 |
+
Note: Currently, we only support 5 types of textures, i.e., pure color, stripe/spline, plaid/lattice,
|
161 |
+
floral, denim. Your inputs should be restricted to these textures.
|
162 |
+
```
|
163 |
+
|
164 |
+
## (3) Training Text2Human
|
165 |
+
|
166 |
+
### Stage I: Pose to Parsing
|
167 |
+
Train the parsing generation network. If you want to skip the training of this network, you can download our pretrained model from [here](https://drive.google.com/file/d/1MNyFLGqIQcOMg_HhgwCmKqdwfQSjeg_6/view?usp=sharing).
|
168 |
+
```python
|
169 |
+
python train_parsing_gen.py -opt ./configs/parsing_gen.yml
|
170 |
+
```
|
171 |
+
|
172 |
+
### Stage II: Parsing to Human
|
173 |
+
|
174 |
+
**Step 1: Train the top level of the hierarchical VQVAE.**
|
175 |
+
We provide our pretrained model [here](https://drive.google.com/file/d/1TwypUg85gPFJtMwBLUjVS66FKR3oaTz8/view?usp=sharing). This model is trained by:
|
176 |
+
```python
|
177 |
+
python train_vqvae.py -opt ./configs/vqvae_top.yml
|
178 |
+
```
|
179 |
+
|
180 |
+
**Step 2: Train the bottom level of the hierarchical VQVAE.**
|
181 |
+
We provide our pretrained model [here](https://drive.google.com/file/d/15hzbY-RG-ILgzUqqGC0qMzlS4OayPdRH/view?usp=sharing). This model is trained by:
|
182 |
+
```python
|
183 |
+
python train_vqvae.py -opt ./configs/vqvae_bottom.yml
|
184 |
+
```
|
185 |
+
|
186 |
+
**Stage 3 & 4: Train the sampler with mixture-of-experts.** To train the sampler, we first need to train a model to tokenize the parsing maps. You can access our pretrained parsing maps [here](https://drive.google.com/file/d/1GLHoOeCP6sMao1-R63ahJMJF7-J00uir/view?usp=sharing).
|
187 |
+
```python
|
188 |
+
python train_parsing_token.py -opt ./configs/parsing_token.yml
|
189 |
+
```
|
190 |
+
|
191 |
+
With the parsing tokenization model, the sampler is trained by:
|
192 |
+
```python
|
193 |
+
python train_sampler.py -opt ./configs/sampler.yml
|
194 |
+
```
|
195 |
+
Our pretrained sampler is provided [here](https://drive.google.com/file/d/1OQO_kG2fK7eKiG1VJH1OL782X71UQAmS/view?usp=sharing).
|
196 |
+
|
197 |
+
**Stage 5: Train the index prediction network.**
|
198 |
+
We provide our pretrained index prediction network [here](https://drive.google.com/file/d/1rqhkQD-JGd7YBeIfDvMV-vjfbNHpIhYm/view?usp=sharing). It is trained by:
|
199 |
+
```python
|
200 |
+
python train_index_prediction.py -opt ./configs/index_pred_net.yml
|
201 |
+
```
|
202 |
+
|
203 |
+
|
204 |
+
**Remarks**: In the config files, we use the path to our models as the required pretrained models. If you want to train the models from scratch, please replace the path to your own one. We set the numbers of the training epochs as large numbers and you can choose the best epoch for each model. For your reference, our pretrained parsing generation network is trained for 50 epochs, top-level VQVAE is trained for 135 epochs, bottom-level VQVAE is trained for 70 epochs, parsing tokenization network is trained for 20 epochs, sampler is trained for 95 epochs, and the index prediction network is trained for 70 epochs.
|
205 |
+
|
206 |
+
## (4) Results
|
207 |
+
|
208 |
+
Please visit our [Project Page](https://yumingj.github.io/projects/Text2Human.html#results) to view more results.</br>
|
209 |
+
You can select the attribtues to customize the desired human images.
|
210 |
+
[<img src="./assets/results.png" width="90%">
|
211 |
+
](https://yumingj.github.io/projects/Text2Human.html#results)
|
212 |
+
|
213 |
+
## DeepFashion-MultiModal Dataset
|
214 |
+
|
215 |
+
<img src="./assets/dataset_logo.png" width="90%">
|
216 |
+
|
217 |
+
In this work, we also propose **DeepFashion-MultiModal**, a large-scale high-quality human dataset with rich multi-modal annotations. It has the following properties:
|
218 |
+
1. It contains 44,096 high-resolution human images, including 12,701 full body human images.
|
219 |
+
2. For each full body images, we **manually annotate** the human parsing labels of 24 classes.
|
220 |
+
3. For each full body images, we **manually annotate** the keypoints.
|
221 |
+
4. We extract DensePose for each human image.
|
222 |
+
5. Each image is **manually annotated** with attributes for both clothes shapes and textures.
|
223 |
+
6. We provide a textual description for each image.
|
224 |
+
|
225 |
+
<img src="./assets/dataset_overview.png" width="100%">
|
226 |
+
|
227 |
+
Please refer to [this repo](https://github.com/yumingj/DeepFashion-MultiModal) for more details about our proposed dataset.
|
228 |
+
|
229 |
+
## TODO List
|
230 |
+
|
231 |
+
- [ ] Release 1024x512 version of Text2Human.
|
232 |
+
- [ ] Train the Text2Human using [SHHQ dataset](https://stylegan-human.github.io/).
|
233 |
+
|
234 |
+
## Citation
|
235 |
+
|
236 |
+
If you find this work useful for your research, please consider citing our paper:
|
237 |
+
|
238 |
+
```bibtex
|
239 |
+
@article{jiang2022text2human,
|
240 |
+
title={Text2Human: Text-Driven Controllable Human Image Generation},
|
241 |
+
author={Jiang, Yuming and Yang, Shuai and Qiu, Haonan and Wu, Wayne and Loy, Chen Change and Liu, Ziwei},
|
242 |
+
journal={ACM Transactions on Graphics (TOG)},
|
243 |
+
volume={41},
|
244 |
+
number={4},
|
245 |
+
articleno={162},
|
246 |
+
pages={1--11},
|
247 |
+
year={2022},
|
248 |
+
publisher={ACM New York, NY, USA},
|
249 |
+
doi={10.1145/3528223.3530104},
|
250 |
+
}
|
251 |
+
```
|
252 |
+
|
253 |
+
## Acknowledgments
|
254 |
+
|
255 |
+
Part of the code is borrowed from [unleashing-transformers](https://github.com/samb-t/unleashing-transformers), [taming-transformers](https://github.com/CompVis/taming-transformers) and [mmsegmentation](https://github.com/open-mmlab/mmsegmentation).
|
Text2Human/configs/index_pred_net.yml
ADDED
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: index_prediction_network
|
2 |
+
use_tb_logger: true
|
3 |
+
set_CUDA_VISIBLE_DEVICES: ~
|
4 |
+
gpu_ids: [3]
|
5 |
+
|
6 |
+
# dataset configs
|
7 |
+
batch_size: 4
|
8 |
+
num_workers: 4
|
9 |
+
train_img_dir: ./datasets/train_images
|
10 |
+
test_img_dir: ./datasets/test_images
|
11 |
+
segm_dir: ./datasets/segm
|
12 |
+
pose_dir: ./datasets/densepose
|
13 |
+
train_ann_file: ./datasets/texture_ann/train
|
14 |
+
val_ann_file: ./datasets/texture_ann/val
|
15 |
+
test_ann_file: ./datasets/texture_ann/test
|
16 |
+
downsample_factor: 2
|
17 |
+
|
18 |
+
model_type: VQGANTextureAwareSpatialHierarchyInferenceModel
|
19 |
+
# network configs
|
20 |
+
embed_dim: 256
|
21 |
+
n_embed: 1024
|
22 |
+
codebook_spatial_size: 2
|
23 |
+
|
24 |
+
# bottom level vqvae
|
25 |
+
bot_n_embed: 512
|
26 |
+
bot_double_z: false
|
27 |
+
bot_z_channels: 256
|
28 |
+
bot_resolution: 512
|
29 |
+
bot_in_channels: 3
|
30 |
+
bot_out_ch: 3
|
31 |
+
bot_ch: 128
|
32 |
+
bot_ch_mult: [1, 1, 2, 4]
|
33 |
+
bot_num_res_blocks: 2
|
34 |
+
bot_attn_resolutions: [64]
|
35 |
+
bot_dropout: 0.0
|
36 |
+
bot_vae_path: ./pretrained_models/vqvae_bottom.pth
|
37 |
+
|
38 |
+
# top level vqgan
|
39 |
+
top_double_z: false
|
40 |
+
top_z_channels: 256
|
41 |
+
top_resolution: 512
|
42 |
+
top_in_channels: 3
|
43 |
+
top_out_ch: 3
|
44 |
+
top_ch: 128
|
45 |
+
top_ch_mult: [1, 1, 2, 2, 4]
|
46 |
+
top_num_res_blocks: 2
|
47 |
+
top_attn_resolutions: [32]
|
48 |
+
top_dropout: 0.0
|
49 |
+
top_vae_path: ./pretrained_models/vqvae_top.pth
|
50 |
+
|
51 |
+
# unet configs
|
52 |
+
encoder_in_channels: 256
|
53 |
+
fc_in_channels: 64
|
54 |
+
fc_in_index: 4
|
55 |
+
fc_channels: 64
|
56 |
+
fc_num_convs: 1
|
57 |
+
fc_concat_input: False
|
58 |
+
fc_dropout_ratio: 0.1
|
59 |
+
fc_num_classes: 512
|
60 |
+
fc_align_corners: False
|
61 |
+
|
62 |
+
disc_layers: 3
|
63 |
+
disc_weight_max: 1
|
64 |
+
disc_start_step: 30001
|
65 |
+
n_channels: 3
|
66 |
+
ndf: 64
|
67 |
+
nf: 128
|
68 |
+
perceptual_weight: 1.0
|
69 |
+
|
70 |
+
num_segm_classes: 24
|
71 |
+
|
72 |
+
# training configs
|
73 |
+
val_freq: 5
|
74 |
+
print_freq: 100
|
75 |
+
weight_decay: 0
|
76 |
+
manual_seed: 2021
|
77 |
+
num_epochs: 100
|
78 |
+
lr: !!float 1.0e-04
|
79 |
+
lr_decay: step
|
80 |
+
gamma: 1.0
|
81 |
+
step: 50
|
82 |
+
optimizer: Adam
|
83 |
+
loss_function: cross_entropy
|
84 |
+
|
Text2Human/configs/parsing_gen.yml
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: parsing_generation
|
2 |
+
use_tb_logger: true
|
3 |
+
set_CUDA_VISIBLE_DEVICES: ~
|
4 |
+
gpu_ids: [3]
|
5 |
+
|
6 |
+
# dataset configs
|
7 |
+
batch_size: 8
|
8 |
+
num_workers: 4
|
9 |
+
segm_dir: ./datasets/segm
|
10 |
+
pose_dir: ./datasets/densepose
|
11 |
+
train_ann_file: ./datasets/shape_ann/train_ann_file.txt
|
12 |
+
val_ann_file: ./datasets/shape_ann/val_ann_file.txt
|
13 |
+
test_ann_file: ./datasets/shape_ann/test_ann_file.txt
|
14 |
+
downsample_factor: 2
|
15 |
+
|
16 |
+
model_type: ParsingGenModel
|
17 |
+
# network configs
|
18 |
+
embedder_dim: 8
|
19 |
+
embedder_out_dim: 128
|
20 |
+
attr_class_num: [2, 4, 6, 5, 4, 3, 5, 5, 3, 2, 2, 2, 2, 2, 2]
|
21 |
+
encoder_in_channels: 1
|
22 |
+
fc_in_channels: 64
|
23 |
+
fc_in_index: 4
|
24 |
+
fc_channels: 64
|
25 |
+
fc_num_convs: 1
|
26 |
+
fc_concat_input: False
|
27 |
+
fc_dropout_ratio: 0.1
|
28 |
+
fc_num_classes: 24
|
29 |
+
fc_align_corners: False
|
30 |
+
|
31 |
+
# training configs
|
32 |
+
val_freq: 5
|
33 |
+
print_freq: 100
|
34 |
+
weight_decay: 0
|
35 |
+
manual_seed: 2021
|
36 |
+
num_epochs: 100
|
37 |
+
lr: !!float 1e-4
|
38 |
+
lr_decay: step
|
39 |
+
gamma: 0.1
|
40 |
+
step: 50
|
Text2Human/configs/parsing_token.yml
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: parsing_tokenization
|
2 |
+
use_tb_logger: true
|
3 |
+
set_CUDA_VISIBLE_DEVICES: ~
|
4 |
+
gpu_ids: [3]
|
5 |
+
|
6 |
+
# dataset configs
|
7 |
+
batch_size: 4
|
8 |
+
num_workers: 4
|
9 |
+
train_img_dir: ./datasets/train_images
|
10 |
+
test_img_dir: ./datasets/test_images
|
11 |
+
segm_dir: ./datasets/segm
|
12 |
+
pose_dir: ./datasets/densepose
|
13 |
+
train_ann_file: ./datasets/texture_ann/train
|
14 |
+
val_ann_file: ./datasets/texture_ann/val
|
15 |
+
test_ann_file: ./datasets/texture_ann/test
|
16 |
+
downsample_factor: 2
|
17 |
+
|
18 |
+
model_type: VQSegmentationModel
|
19 |
+
# network configs
|
20 |
+
embed_dim: 32
|
21 |
+
n_embed: 1024
|
22 |
+
image_key: "segmentation"
|
23 |
+
n_labels: 24
|
24 |
+
double_z: false
|
25 |
+
z_channels: 32
|
26 |
+
resolution: 512
|
27 |
+
in_channels: 24
|
28 |
+
out_ch: 24
|
29 |
+
ch: 64
|
30 |
+
ch_mult: [1, 1, 2, 2, 4]
|
31 |
+
num_res_blocks: 1
|
32 |
+
attn_resolutions: [16]
|
33 |
+
dropout: 0.0
|
34 |
+
|
35 |
+
num_segm_classes: 24
|
36 |
+
|
37 |
+
|
38 |
+
# training configs
|
39 |
+
val_freq: 5
|
40 |
+
print_freq: 100
|
41 |
+
weight_decay: 0
|
42 |
+
manual_seed: 2021
|
43 |
+
num_epochs: 100
|
44 |
+
lr: !!float 4.5e-05
|
45 |
+
lr_decay: step
|
46 |
+
gamma: 0.1
|
47 |
+
step: 50
|
Text2Human/configs/sample_from_parsing.yml
ADDED
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: sample_from_parsing
|
2 |
+
use_tb_logger: true
|
3 |
+
set_CUDA_VISIBLE_DEVICES: ~
|
4 |
+
gpu_ids: [3]
|
5 |
+
|
6 |
+
# dataset configs
|
7 |
+
batch_size: 4
|
8 |
+
num_workers: 4
|
9 |
+
test_img_dir: ./datasets/test_images
|
10 |
+
segm_dir: ./datasets/segm
|
11 |
+
pose_dir: ./datasets/densepose
|
12 |
+
test_ann_file: ./datasets/texture_ann/test
|
13 |
+
downsample_factor: 2
|
14 |
+
|
15 |
+
model_type: SampleFromParsingModel
|
16 |
+
# network configs
|
17 |
+
embed_dim: 256
|
18 |
+
n_embed: 1024
|
19 |
+
codebook_spatial_size: 2
|
20 |
+
|
21 |
+
# bottom level vqvae
|
22 |
+
bot_n_embed: 512
|
23 |
+
bot_codebook_spatial_size: 2
|
24 |
+
bot_double_z: false
|
25 |
+
bot_z_channels: 256
|
26 |
+
bot_resolution: 512
|
27 |
+
bot_in_channels: 3
|
28 |
+
bot_out_ch: 3
|
29 |
+
bot_ch: 128
|
30 |
+
bot_ch_mult: [1, 1, 2, 4]
|
31 |
+
bot_num_res_blocks: 2
|
32 |
+
bot_attn_resolutions: [64]
|
33 |
+
bot_dropout: 0.0
|
34 |
+
bot_vae_path: ./pretrained_models/vqvae_bottom.pth
|
35 |
+
|
36 |
+
# top level vqgan
|
37 |
+
top_double_z: false
|
38 |
+
top_z_channels: 256
|
39 |
+
top_resolution: 512
|
40 |
+
top_in_channels: 3
|
41 |
+
top_out_ch: 3
|
42 |
+
top_ch: 128
|
43 |
+
top_ch_mult: [1, 1, 2, 2, 4]
|
44 |
+
top_num_res_blocks: 2
|
45 |
+
top_attn_resolutions: [32]
|
46 |
+
top_dropout: 0.0
|
47 |
+
top_vae_path: ./pretrained_models/vqvae_top.pth
|
48 |
+
|
49 |
+
# unet configs
|
50 |
+
index_pred_encoder_in_channels: 256
|
51 |
+
index_pred_fc_in_channels: 64
|
52 |
+
index_pred_fc_in_index: 4
|
53 |
+
index_pred_fc_channels: 64
|
54 |
+
index_pred_fc_num_convs: 1
|
55 |
+
index_pred_fc_concat_input: False
|
56 |
+
index_pred_fc_dropout_ratio: 0.1
|
57 |
+
index_pred_fc_num_classes: 512
|
58 |
+
index_pred_fc_align_corners: False
|
59 |
+
pretrained_index_network: ./pretrained_models/index_pred_net.pth
|
60 |
+
|
61 |
+
# segmentation tokenization
|
62 |
+
segm_double_z: false
|
63 |
+
segm_z_channels: 32
|
64 |
+
segm_resolution: 512
|
65 |
+
segm_in_channels: 24
|
66 |
+
segm_out_ch: 24
|
67 |
+
segm_ch: 64
|
68 |
+
segm_ch_mult: [1, 1, 2, 2, 4]
|
69 |
+
segm_num_res_blocks: 1
|
70 |
+
segm_attn_resolutions: [16]
|
71 |
+
segm_dropout: 0.0
|
72 |
+
segm_num_segm_classes: 24
|
73 |
+
segm_n_embed: 1024
|
74 |
+
segm_embed_dim: 32
|
75 |
+
segm_token_path: ./pretrained_models/parsing_token.pth
|
76 |
+
|
77 |
+
# sampler configs
|
78 |
+
codebook_size: 18432
|
79 |
+
segm_codebook_size: 1024
|
80 |
+
texture_codebook_size: 18
|
81 |
+
bert_n_emb: 512
|
82 |
+
bert_n_layers: 24
|
83 |
+
bert_n_head: 8
|
84 |
+
block_size: 512 # 32 x 16
|
85 |
+
latent_shape: [32, 16]
|
86 |
+
embd_pdrop: 0.0
|
87 |
+
resid_pdrop: 0.0
|
88 |
+
attn_pdrop: 0.0
|
89 |
+
num_head: 18
|
90 |
+
pretrained_sampler: ./pretrained_models/sampler.pth
|
91 |
+
|
92 |
+
manual_seed: 2021
|
93 |
+
sample_steps: 256
|
Text2Human/configs/sample_from_pose.yml
ADDED
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: sample_from_pose
|
2 |
+
use_tb_logger: true
|
3 |
+
set_CUDA_VISIBLE_DEVICES: ~
|
4 |
+
gpu_ids: [3]
|
5 |
+
|
6 |
+
# dataset configs
|
7 |
+
batch_size: 4
|
8 |
+
num_workers: 4
|
9 |
+
pose_dir: ./datasets/densepose
|
10 |
+
texture_ann_file: ./datasets/texture_ann/test
|
11 |
+
shape_ann_path: ./datasets/shape_ann/test_ann_file.txt
|
12 |
+
downsample_factor: 2
|
13 |
+
|
14 |
+
model_type: SampleFromPoseModel
|
15 |
+
# network configs
|
16 |
+
embed_dim: 256
|
17 |
+
n_embed: 1024
|
18 |
+
codebook_spatial_size: 2
|
19 |
+
|
20 |
+
# bottom level vqgan
|
21 |
+
bot_n_embed: 512
|
22 |
+
bot_codebook_spatial_size: 2
|
23 |
+
bot_double_z: false
|
24 |
+
bot_z_channels: 256
|
25 |
+
bot_resolution: 512
|
26 |
+
bot_in_channels: 3
|
27 |
+
bot_out_ch: 3
|
28 |
+
bot_ch: 128
|
29 |
+
bot_ch_mult: [1, 1, 2, 4]
|
30 |
+
bot_num_res_blocks: 2
|
31 |
+
bot_attn_resolutions: [64]
|
32 |
+
bot_dropout: 0.0
|
33 |
+
bot_vae_path: ./pretrained_models/vqvae_bottom.pth
|
34 |
+
|
35 |
+
# top level vqgan
|
36 |
+
top_double_z: false
|
37 |
+
top_z_channels: 256
|
38 |
+
top_resolution: 512
|
39 |
+
top_in_channels: 3
|
40 |
+
top_out_ch: 3
|
41 |
+
top_ch: 128
|
42 |
+
top_ch_mult: [1, 1, 2, 2, 4]
|
43 |
+
top_num_res_blocks: 2
|
44 |
+
top_attn_resolutions: [32]
|
45 |
+
top_dropout: 0.0
|
46 |
+
top_vae_path: ./pretrained_models/vqvae_top.pth
|
47 |
+
|
48 |
+
# unet configs
|
49 |
+
index_pred_encoder_in_channels: 256
|
50 |
+
index_pred_fc_in_channels: 64
|
51 |
+
index_pred_fc_in_index: 4
|
52 |
+
index_pred_fc_channels: 64
|
53 |
+
index_pred_fc_num_convs: 1
|
54 |
+
index_pred_fc_concat_input: False
|
55 |
+
index_pred_fc_dropout_ratio: 0.1
|
56 |
+
index_pred_fc_num_classes: 512
|
57 |
+
index_pred_fc_align_corners: False
|
58 |
+
pretrained_index_network: ./pretrained_models/index_pred_net.pth
|
59 |
+
|
60 |
+
# segmentation tokenization
|
61 |
+
segm_double_z: false
|
62 |
+
segm_z_channels: 32
|
63 |
+
segm_resolution: 512
|
64 |
+
segm_in_channels: 24
|
65 |
+
segm_out_ch: 24
|
66 |
+
segm_ch: 64
|
67 |
+
segm_ch_mult: [1, 1, 2, 2, 4]
|
68 |
+
segm_num_res_blocks: 1
|
69 |
+
segm_attn_resolutions: [16]
|
70 |
+
segm_dropout: 0.0
|
71 |
+
segm_num_segm_classes: 24
|
72 |
+
segm_n_embed: 1024
|
73 |
+
segm_embed_dim: 32
|
74 |
+
segm_token_path: ./pretrained_models/parsing_token.pth
|
75 |
+
|
76 |
+
# sampler configs
|
77 |
+
codebook_size: 18432
|
78 |
+
segm_codebook_size: 1024
|
79 |
+
texture_codebook_size: 18
|
80 |
+
bert_n_emb: 512
|
81 |
+
bert_n_layers: 24
|
82 |
+
bert_n_head: 8
|
83 |
+
block_size: 512 # 32 x 16
|
84 |
+
latent_shape: [32, 16]
|
85 |
+
embd_pdrop: 0.0
|
86 |
+
resid_pdrop: 0.0
|
87 |
+
attn_pdrop: 0.0
|
88 |
+
num_head: 18
|
89 |
+
pretrained_sampler: ./pretrained_models/sampler.pth
|
90 |
+
|
91 |
+
# shape network configs
|
92 |
+
shape_embedder_dim: 8
|
93 |
+
shape_embedder_out_dim: 128
|
94 |
+
shape_attr_class_num: [2, 4, 6, 5, 4, 3, 5, 5, 3, 2, 2, 2, 2, 2, 2]
|
95 |
+
shape_encoder_in_channels: 1
|
96 |
+
shape_fc_in_channels: 64
|
97 |
+
shape_fc_in_index: 4
|
98 |
+
shape_fc_channels: 64
|
99 |
+
shape_fc_num_convs: 1
|
100 |
+
shape_fc_concat_input: False
|
101 |
+
shape_fc_dropout_ratio: 0.1
|
102 |
+
shape_fc_num_classes: 24
|
103 |
+
shape_fc_align_corners: False
|
104 |
+
pretrained_parsing_gen: ./pretrained_models/parsing_gen.pth
|
105 |
+
|
106 |
+
manual_seed: 2021
|
107 |
+
sample_steps: 256
|
Text2Human/configs/sampler.yml
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: sampler
|
2 |
+
use_tb_logger: true
|
3 |
+
set_CUDA_VISIBLE_DEVICES: ~
|
4 |
+
gpu_ids: [3]
|
5 |
+
|
6 |
+
# dataset configs
|
7 |
+
batch_size: 4
|
8 |
+
num_workers: 1
|
9 |
+
train_img_dir: ./datasets/train_images
|
10 |
+
test_img_dir: ./datasets/test_images
|
11 |
+
segm_dir: ./datasets/segm
|
12 |
+
pose_dir: ./datasets/densepose
|
13 |
+
train_ann_file: ./datasets/texture_ann/train
|
14 |
+
val_ann_file: ./datasets/texture_ann/val
|
15 |
+
test_ann_file: ./datasets/texture_ann/test
|
16 |
+
downsample_factor: 2
|
17 |
+
|
18 |
+
# pretrained models
|
19 |
+
img_ae_path: ./pretrained_models/vqvae_top.pth
|
20 |
+
segm_ae_path: ./pretrained_models/parsing_token.pth
|
21 |
+
|
22 |
+
model_type: TransformerTextureAwareModel
|
23 |
+
# network configs
|
24 |
+
|
25 |
+
# image autoencoder
|
26 |
+
img_embed_dim: 256
|
27 |
+
img_n_embed: 1024
|
28 |
+
img_double_z: false
|
29 |
+
img_z_channels: 256
|
30 |
+
img_resolution: 512
|
31 |
+
img_in_channels: 3
|
32 |
+
img_out_ch: 3
|
33 |
+
img_ch: 128
|
34 |
+
img_ch_mult: [1, 1, 2, 2, 4]
|
35 |
+
img_num_res_blocks: 2
|
36 |
+
img_attn_resolutions: [32]
|
37 |
+
img_dropout: 0.0
|
38 |
+
|
39 |
+
# segmentation tokenization
|
40 |
+
segm_double_z: false
|
41 |
+
segm_z_channels: 32
|
42 |
+
segm_resolution: 512
|
43 |
+
segm_in_channels: 24
|
44 |
+
segm_out_ch: 24
|
45 |
+
segm_ch: 64
|
46 |
+
segm_ch_mult: [1, 1, 2, 2, 4]
|
47 |
+
segm_num_res_blocks: 1
|
48 |
+
segm_attn_resolutions: [16]
|
49 |
+
segm_dropout: 0.0
|
50 |
+
segm_num_segm_classes: 24
|
51 |
+
segm_n_embed: 1024
|
52 |
+
segm_embed_dim: 32
|
53 |
+
|
54 |
+
# sampler configs
|
55 |
+
codebook_size: 18432
|
56 |
+
segm_codebook_size: 1024
|
57 |
+
texture_codebook_size: 18
|
58 |
+
bert_n_emb: 512
|
59 |
+
bert_n_layers: 24
|
60 |
+
bert_n_head: 8
|
61 |
+
block_size: 512 # 32 x 16
|
62 |
+
latent_shape: [32, 16]
|
63 |
+
embd_pdrop: 0.0
|
64 |
+
resid_pdrop: 0.0
|
65 |
+
attn_pdrop: 0.0
|
66 |
+
num_head: 18
|
67 |
+
|
68 |
+
# loss configs
|
69 |
+
loss_type: reweighted_elbo
|
70 |
+
mask_schedule: random
|
71 |
+
|
72 |
+
sample_steps: 256
|
73 |
+
|
74 |
+
# training configs
|
75 |
+
val_freq: 5
|
76 |
+
print_freq: 100
|
77 |
+
weight_decay: 0
|
78 |
+
manual_seed: 2021
|
79 |
+
num_epochs: 100
|
80 |
+
lr: !!float 1e-4
|
81 |
+
lr_decay: step
|
82 |
+
gamma: 1.0
|
83 |
+
step: 50
|
Text2Human/configs/vqvae_bottom.yml
ADDED
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: vqvae_bottom
|
2 |
+
use_tb_logger: true
|
3 |
+
set_CUDA_VISIBLE_DEVICES: ~
|
4 |
+
gpu_ids: [3]
|
5 |
+
|
6 |
+
# dataset configs
|
7 |
+
batch_size: 4
|
8 |
+
num_workers: 4
|
9 |
+
train_img_dir: ./datasets/train_images
|
10 |
+
test_img_dir: ./datasets/test_images
|
11 |
+
segm_dir: ./datasets/segm
|
12 |
+
pose_dir: ./datasets/densepose
|
13 |
+
train_ann_file: ./datasets/texture_ann/train
|
14 |
+
val_ann_file: ./datasets/texture_ann/val
|
15 |
+
test_ann_file: ./datasets/texture_ann/test
|
16 |
+
downsample_factor: 2
|
17 |
+
|
18 |
+
model_type: HierarchyVQSpatialTextureAwareModel
|
19 |
+
# network configs
|
20 |
+
embed_dim: 256
|
21 |
+
n_embed: 1024
|
22 |
+
codebook_spatial_size: 2
|
23 |
+
|
24 |
+
# bottom level vqvae
|
25 |
+
bot_n_embed: 512
|
26 |
+
bot_double_z: false
|
27 |
+
bot_z_channels: 256
|
28 |
+
bot_resolution: 512
|
29 |
+
bot_in_channels: 3
|
30 |
+
bot_out_ch: 3
|
31 |
+
bot_ch: 128
|
32 |
+
bot_ch_mult: [1, 1, 2, 4]
|
33 |
+
bot_num_res_blocks: 2
|
34 |
+
bot_attn_resolutions: [64]
|
35 |
+
bot_dropout: 0.0
|
36 |
+
|
37 |
+
# top level vqgan
|
38 |
+
top_double_z: false
|
39 |
+
top_z_channels: 256
|
40 |
+
top_resolution: 512
|
41 |
+
top_in_channels: 3
|
42 |
+
top_out_ch: 3
|
43 |
+
top_ch: 128
|
44 |
+
top_ch_mult: [1, 1, 2, 2, 4]
|
45 |
+
top_num_res_blocks: 2
|
46 |
+
top_attn_resolutions: [32]
|
47 |
+
top_dropout: 0.0
|
48 |
+
top_vae_path: ./pretrained_models/vqvae_top.pth
|
49 |
+
|
50 |
+
fix_decoder: false
|
51 |
+
|
52 |
+
disc_layers: 3
|
53 |
+
disc_weight_max: 1
|
54 |
+
disc_start_step: 1
|
55 |
+
n_channels: 3
|
56 |
+
ndf: 64
|
57 |
+
nf: 128
|
58 |
+
perceptual_weight: 1.0
|
59 |
+
|
60 |
+
num_segm_classes: 24
|
61 |
+
|
62 |
+
# training configs
|
63 |
+
val_freq: 5
|
64 |
+
print_freq: 100
|
65 |
+
weight_decay: 0
|
66 |
+
manual_seed: 2021
|
67 |
+
num_epochs: 1000
|
68 |
+
lr: !!float 1.0e-04
|
69 |
+
lr_decay: step
|
70 |
+
gamma: 1.0
|
71 |
+
step: 50
|
72 |
+
|
Text2Human/configs/vqvae_top.yml
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: vqvae_top
|
2 |
+
use_tb_logger: true
|
3 |
+
set_CUDA_VISIBLE_DEVICES: ~
|
4 |
+
gpu_ids: [3]
|
5 |
+
|
6 |
+
# dataset configs
|
7 |
+
batch_size: 4
|
8 |
+
num_workers: 4
|
9 |
+
train_img_dir: ./datasets/train_images
|
10 |
+
test_img_dir: ./datasets/test_images
|
11 |
+
segm_dir: ./datasets/segm
|
12 |
+
pose_dir: ./datasets/densepose
|
13 |
+
train_ann_file: ./datasets/texture_ann/train
|
14 |
+
val_ann_file: ./datasets/texture_ann/val
|
15 |
+
test_ann_file: ./datasets/texture_ann/test
|
16 |
+
downsample_factor: 2
|
17 |
+
|
18 |
+
model_type: VQImageSegmTextureModel
|
19 |
+
# network configs
|
20 |
+
embed_dim: 256
|
21 |
+
n_embed: 1024
|
22 |
+
double_z: false
|
23 |
+
z_channels: 256
|
24 |
+
resolution: 512
|
25 |
+
in_channels: 3
|
26 |
+
out_ch: 3
|
27 |
+
ch: 128
|
28 |
+
ch_mult: [1, 1, 2, 2, 4]
|
29 |
+
num_res_blocks: 2
|
30 |
+
attn_resolutions: [32]
|
31 |
+
dropout: 0.0
|
32 |
+
|
33 |
+
disc_layers: 3
|
34 |
+
disc_weight_max: 0
|
35 |
+
disc_start_step: 3000000000000000000000000001
|
36 |
+
n_channels: 3
|
37 |
+
ndf: 64
|
38 |
+
nf: 128
|
39 |
+
perceptual_weight: 1.0
|
40 |
+
|
41 |
+
num_segm_classes: 24
|
42 |
+
|
43 |
+
|
44 |
+
# training configs
|
45 |
+
val_freq: 5
|
46 |
+
print_freq: 100
|
47 |
+
weight_decay: 0
|
48 |
+
manual_seed: 2021
|
49 |
+
num_epochs: 1000
|
50 |
+
lr: !!float 1.0e-04
|
51 |
+
lr_decay: step
|
52 |
+
gamma: 1.0
|
53 |
+
step: 50
|
Text2Human/data/__init__.py
ADDED
File without changes
|
Text2Human/data/mask_dataset.py
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import os.path
|
3 |
+
import random
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
import torch.utils.data as data
|
8 |
+
from PIL import Image
|
9 |
+
|
10 |
+
|
11 |
+
class MaskDataset(data.Dataset):
|
12 |
+
|
13 |
+
def __init__(self, segm_dir, ann_dir, downsample_factor=2, xflip=False):
|
14 |
+
|
15 |
+
self._segm_path = segm_dir
|
16 |
+
self._image_fnames = []
|
17 |
+
|
18 |
+
self.downsample_factor = downsample_factor
|
19 |
+
self.xflip = xflip
|
20 |
+
|
21 |
+
# load attributes
|
22 |
+
assert os.path.exists(f'{ann_dir}/upper_fused.txt')
|
23 |
+
for idx, row in enumerate(
|
24 |
+
open(os.path.join(f'{ann_dir}/upper_fused.txt'), 'r')):
|
25 |
+
annotations = row.split()
|
26 |
+
self._image_fnames.append(annotations[0])
|
27 |
+
|
28 |
+
def _open_file(self, path_prefix, fname):
|
29 |
+
return open(os.path.join(path_prefix, fname), 'rb')
|
30 |
+
|
31 |
+
def _load_segm(self, raw_idx):
|
32 |
+
fname = self._image_fnames[raw_idx]
|
33 |
+
fname = f'{fname[:-4]}_segm.png'
|
34 |
+
with self._open_file(self._segm_path, fname) as f:
|
35 |
+
segm = Image.open(f)
|
36 |
+
if self.downsample_factor != 1:
|
37 |
+
width, height = segm.size
|
38 |
+
width = width // self.downsample_factor
|
39 |
+
height = height // self.downsample_factor
|
40 |
+
segm = segm.resize(
|
41 |
+
size=(width, height), resample=Image.NEAREST)
|
42 |
+
segm = np.array(segm)
|
43 |
+
# segm = segm[:, :, np.newaxis].transpose(2, 0, 1)
|
44 |
+
return segm.astype(np.float32)
|
45 |
+
|
46 |
+
def __getitem__(self, index):
|
47 |
+
segm = self._load_segm(index)
|
48 |
+
|
49 |
+
if self.xflip and random.random() > 0.5:
|
50 |
+
segm = segm[:, ::-1].copy()
|
51 |
+
|
52 |
+
segm = torch.from_numpy(segm).long()
|
53 |
+
|
54 |
+
return_dict = {'segm': segm, 'img_name': self._image_fnames[index]}
|
55 |
+
|
56 |
+
return return_dict
|
57 |
+
|
58 |
+
def __len__(self):
|
59 |
+
return len(self._image_fnames)
|
Text2Human/data/parsing_generation_segm_attr_dataset.py
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import os.path
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
import torch.utils.data as data
|
7 |
+
from PIL import Image
|
8 |
+
|
9 |
+
|
10 |
+
class ParsingGenerationDeepFashionAttrSegmDataset(data.Dataset):
|
11 |
+
|
12 |
+
def __init__(self, segm_dir, pose_dir, ann_file, downsample_factor=2):
|
13 |
+
self._densepose_path = pose_dir
|
14 |
+
self._segm_path = segm_dir
|
15 |
+
self._image_fnames = []
|
16 |
+
self.attrs = []
|
17 |
+
|
18 |
+
self.downsample_factor = downsample_factor
|
19 |
+
|
20 |
+
# training, ground-truth available
|
21 |
+
assert os.path.exists(ann_file)
|
22 |
+
for row in open(os.path.join(ann_file), 'r'):
|
23 |
+
annotations = row.split()
|
24 |
+
self._image_fnames.append(annotations[0])
|
25 |
+
self.attrs.append([int(i) for i in annotations[1:]])
|
26 |
+
|
27 |
+
def _open_file(self, path_prefix, fname):
|
28 |
+
return open(os.path.join(path_prefix, fname), 'rb')
|
29 |
+
|
30 |
+
def _load_densepose(self, raw_idx):
|
31 |
+
fname = self._image_fnames[raw_idx]
|
32 |
+
fname = f'{fname[:-4]}_densepose.png'
|
33 |
+
with self._open_file(self._densepose_path, fname) as f:
|
34 |
+
densepose = Image.open(f)
|
35 |
+
if self.downsample_factor != 1:
|
36 |
+
width, height = densepose.size
|
37 |
+
width = width // self.downsample_factor
|
38 |
+
height = height // self.downsample_factor
|
39 |
+
densepose = densepose.resize(
|
40 |
+
size=(width, height), resample=Image.NEAREST)
|
41 |
+
# channel-wise IUV order, [3, H, W]
|
42 |
+
densepose = np.array(densepose)[:, :, 2:].transpose(2, 0, 1)
|
43 |
+
return densepose.astype(np.float32)
|
44 |
+
|
45 |
+
def _load_segm(self, raw_idx):
|
46 |
+
fname = self._image_fnames[raw_idx]
|
47 |
+
fname = f'{fname[:-4]}_segm.png'
|
48 |
+
with self._open_file(self._segm_path, fname) as f:
|
49 |
+
segm = Image.open(f)
|
50 |
+
if self.downsample_factor != 1:
|
51 |
+
width, height = segm.size
|
52 |
+
width = width // self.downsample_factor
|
53 |
+
height = height // self.downsample_factor
|
54 |
+
segm = segm.resize(
|
55 |
+
size=(width, height), resample=Image.NEAREST)
|
56 |
+
segm = np.array(segm)
|
57 |
+
return segm.astype(np.float32)
|
58 |
+
|
59 |
+
def __getitem__(self, index):
|
60 |
+
pose = self._load_densepose(index)
|
61 |
+
segm = self._load_segm(index)
|
62 |
+
attr = self.attrs[index]
|
63 |
+
|
64 |
+
pose = torch.from_numpy(pose)
|
65 |
+
segm = torch.LongTensor(segm)
|
66 |
+
attr = torch.LongTensor(attr)
|
67 |
+
|
68 |
+
pose = pose / 12. - 1
|
69 |
+
|
70 |
+
return_dict = {
|
71 |
+
'densepose': pose,
|
72 |
+
'segm': segm,
|
73 |
+
'attr': attr,
|
74 |
+
'img_name': self._image_fnames[index]
|
75 |
+
}
|
76 |
+
|
77 |
+
return return_dict
|
78 |
+
|
79 |
+
def __len__(self):
|
80 |
+
return len(self._image_fnames)
|
Text2Human/data/pose_attr_dataset.py
ADDED
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import os.path
|
3 |
+
import random
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
import torch.utils.data as data
|
8 |
+
from PIL import Image
|
9 |
+
|
10 |
+
|
11 |
+
class DeepFashionAttrPoseDataset(data.Dataset):
|
12 |
+
|
13 |
+
def __init__(self,
|
14 |
+
pose_dir,
|
15 |
+
texture_ann_dir,
|
16 |
+
shape_ann_path,
|
17 |
+
downsample_factor=2,
|
18 |
+
xflip=False):
|
19 |
+
self._densepose_path = pose_dir
|
20 |
+
self._image_fnames_target = []
|
21 |
+
self._image_fnames = []
|
22 |
+
self.upper_fused_attrs = []
|
23 |
+
self.lower_fused_attrs = []
|
24 |
+
self.outer_fused_attrs = []
|
25 |
+
self.shape_attrs = []
|
26 |
+
|
27 |
+
self.downsample_factor = downsample_factor
|
28 |
+
self.xflip = xflip
|
29 |
+
|
30 |
+
# load attributes
|
31 |
+
assert os.path.exists(f'{texture_ann_dir}/upper_fused.txt')
|
32 |
+
for idx, row in enumerate(
|
33 |
+
open(os.path.join(f'{texture_ann_dir}/upper_fused.txt'), 'r')):
|
34 |
+
annotations = row.split()
|
35 |
+
self._image_fnames_target.append(annotations[0])
|
36 |
+
self._image_fnames.append(f'{annotations[0].split(".")[0]}.png')
|
37 |
+
self.upper_fused_attrs.append(int(annotations[1]))
|
38 |
+
|
39 |
+
assert len(self._image_fnames_target) == len(self.upper_fused_attrs)
|
40 |
+
|
41 |
+
assert os.path.exists(f'{texture_ann_dir}/lower_fused.txt')
|
42 |
+
for idx, row in enumerate(
|
43 |
+
open(os.path.join(f'{texture_ann_dir}/lower_fused.txt'), 'r')):
|
44 |
+
annotations = row.split()
|
45 |
+
assert self._image_fnames_target[idx] == annotations[0]
|
46 |
+
self.lower_fused_attrs.append(int(annotations[1]))
|
47 |
+
|
48 |
+
assert len(self._image_fnames_target) == len(self.lower_fused_attrs)
|
49 |
+
|
50 |
+
assert os.path.exists(f'{texture_ann_dir}/outer_fused.txt')
|
51 |
+
for idx, row in enumerate(
|
52 |
+
open(os.path.join(f'{texture_ann_dir}/outer_fused.txt'), 'r')):
|
53 |
+
annotations = row.split()
|
54 |
+
assert self._image_fnames_target[idx] == annotations[0]
|
55 |
+
self.outer_fused_attrs.append(int(annotations[1]))
|
56 |
+
|
57 |
+
assert len(self._image_fnames_target) == len(self.outer_fused_attrs)
|
58 |
+
|
59 |
+
assert os.path.exists(shape_ann_path)
|
60 |
+
for idx, row in enumerate(open(os.path.join(shape_ann_path), 'r')):
|
61 |
+
annotations = row.split()
|
62 |
+
assert self._image_fnames_target[idx] == annotations[0]
|
63 |
+
self.shape_attrs.append([int(i) for i in annotations[1:]])
|
64 |
+
|
65 |
+
def _open_file(self, path_prefix, fname):
|
66 |
+
return open(os.path.join(path_prefix, fname), 'rb')
|
67 |
+
|
68 |
+
def _load_densepose(self, raw_idx):
|
69 |
+
fname = self._image_fnames[raw_idx]
|
70 |
+
fname = f'{fname[:-4]}_densepose.png'
|
71 |
+
with self._open_file(self._densepose_path, fname) as f:
|
72 |
+
densepose = Image.open(f)
|
73 |
+
if self.downsample_factor != 1:
|
74 |
+
width, height = densepose.size
|
75 |
+
width = width // self.downsample_factor
|
76 |
+
height = height // self.downsample_factor
|
77 |
+
densepose = densepose.resize(
|
78 |
+
size=(width, height), resample=Image.NEAREST)
|
79 |
+
# channel-wise IUV order, [3, H, W]
|
80 |
+
densepose = np.array(densepose)[:, :, 2:].transpose(2, 0, 1)
|
81 |
+
return densepose.astype(np.float32)
|
82 |
+
|
83 |
+
def __getitem__(self, index):
|
84 |
+
pose = self._load_densepose(index)
|
85 |
+
shape_attr = self.shape_attrs[index]
|
86 |
+
shape_attr = torch.LongTensor(shape_attr)
|
87 |
+
|
88 |
+
if self.xflip and random.random() > 0.5:
|
89 |
+
pose = pose[:, :, ::-1].copy()
|
90 |
+
|
91 |
+
upper_fused_attr = self.upper_fused_attrs[index]
|
92 |
+
lower_fused_attr = self.lower_fused_attrs[index]
|
93 |
+
outer_fused_attr = self.outer_fused_attrs[index]
|
94 |
+
|
95 |
+
pose = pose / 12. - 1
|
96 |
+
|
97 |
+
return_dict = {
|
98 |
+
'densepose': pose,
|
99 |
+
'img_name': self._image_fnames_target[index],
|
100 |
+
'shape_attr': shape_attr,
|
101 |
+
'upper_fused_attr': upper_fused_attr,
|
102 |
+
'lower_fused_attr': lower_fused_attr,
|
103 |
+
'outer_fused_attr': outer_fused_attr,
|
104 |
+
}
|
105 |
+
|
106 |
+
return return_dict
|
107 |
+
|
108 |
+
def __len__(self):
|
109 |
+
return len(self._image_fnames)
|
Text2Human/data/segm_attr_dataset.py
ADDED
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import os.path
|
3 |
+
import random
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
import torch.utils.data as data
|
8 |
+
from PIL import Image
|
9 |
+
|
10 |
+
|
11 |
+
class DeepFashionAttrSegmDataset(data.Dataset):
|
12 |
+
|
13 |
+
def __init__(self,
|
14 |
+
img_dir,
|
15 |
+
segm_dir,
|
16 |
+
pose_dir,
|
17 |
+
ann_dir,
|
18 |
+
downsample_factor=2,
|
19 |
+
xflip=False):
|
20 |
+
self._img_path = img_dir
|
21 |
+
self._densepose_path = pose_dir
|
22 |
+
self._segm_path = segm_dir
|
23 |
+
self._image_fnames = []
|
24 |
+
self.upper_fused_attrs = []
|
25 |
+
self.lower_fused_attrs = []
|
26 |
+
self.outer_fused_attrs = []
|
27 |
+
|
28 |
+
self.downsample_factor = downsample_factor
|
29 |
+
self.xflip = xflip
|
30 |
+
|
31 |
+
# load attributes
|
32 |
+
assert os.path.exists(f'{ann_dir}/upper_fused.txt')
|
33 |
+
for idx, row in enumerate(
|
34 |
+
open(os.path.join(f'{ann_dir}/upper_fused.txt'), 'r')):
|
35 |
+
annotations = row.split()
|
36 |
+
self._image_fnames.append(annotations[0])
|
37 |
+
# assert self._image_fnames[idx] == annotations[0]
|
38 |
+
self.upper_fused_attrs.append(int(annotations[1]))
|
39 |
+
|
40 |
+
assert len(self._image_fnames) == len(self.upper_fused_attrs)
|
41 |
+
|
42 |
+
assert os.path.exists(f'{ann_dir}/lower_fused.txt')
|
43 |
+
for idx, row in enumerate(
|
44 |
+
open(os.path.join(f'{ann_dir}/lower_fused.txt'), 'r')):
|
45 |
+
annotations = row.split()
|
46 |
+
assert self._image_fnames[idx] == annotations[0]
|
47 |
+
self.lower_fused_attrs.append(int(annotations[1]))
|
48 |
+
|
49 |
+
assert len(self._image_fnames) == len(self.lower_fused_attrs)
|
50 |
+
|
51 |
+
assert os.path.exists(f'{ann_dir}/outer_fused.txt')
|
52 |
+
for idx, row in enumerate(
|
53 |
+
open(os.path.join(f'{ann_dir}/outer_fused.txt'), 'r')):
|
54 |
+
annotations = row.split()
|
55 |
+
assert self._image_fnames[idx] == annotations[0]
|
56 |
+
self.outer_fused_attrs.append(int(annotations[1]))
|
57 |
+
|
58 |
+
assert len(self._image_fnames) == len(self.outer_fused_attrs)
|
59 |
+
|
60 |
+
# remove the overlapping item between upper cls and lower cls
|
61 |
+
# cls 21 can appear with upper clothes
|
62 |
+
# cls 4 can appear with lower clothes
|
63 |
+
self.upper_cls = [1., 4.]
|
64 |
+
self.lower_cls = [3., 5., 21.]
|
65 |
+
self.outer_cls = [2.]
|
66 |
+
self.other_cls = [
|
67 |
+
11., 18., 7., 8., 9., 10., 12., 16., 17., 19., 20., 22., 23., 15.,
|
68 |
+
14., 13., 0., 6.
|
69 |
+
]
|
70 |
+
|
71 |
+
def _open_file(self, path_prefix, fname):
|
72 |
+
return open(os.path.join(path_prefix, fname), 'rb')
|
73 |
+
|
74 |
+
def _load_raw_image(self, raw_idx):
|
75 |
+
fname = self._image_fnames[raw_idx]
|
76 |
+
with self._open_file(self._img_path, fname) as f:
|
77 |
+
image = Image.open(f)
|
78 |
+
if self.downsample_factor != 1:
|
79 |
+
width, height = image.size
|
80 |
+
width = width // self.downsample_factor
|
81 |
+
height = height // self.downsample_factor
|
82 |
+
image = image.resize(
|
83 |
+
size=(width, height), resample=Image.LANCZOS)
|
84 |
+
image = np.array(image)
|
85 |
+
if image.ndim == 2:
|
86 |
+
image = image[:, :, np.newaxis] # HW => HWC
|
87 |
+
image = image.transpose(2, 0, 1) # HWC => CHW
|
88 |
+
return image
|
89 |
+
|
90 |
+
def _load_densepose(self, raw_idx):
|
91 |
+
fname = self._image_fnames[raw_idx]
|
92 |
+
fname = f'{fname[:-4]}_densepose.png'
|
93 |
+
with self._open_file(self._densepose_path, fname) as f:
|
94 |
+
densepose = Image.open(f)
|
95 |
+
if self.downsample_factor != 1:
|
96 |
+
width, height = densepose.size
|
97 |
+
width = width // self.downsample_factor
|
98 |
+
height = height // self.downsample_factor
|
99 |
+
densepose = densepose.resize(
|
100 |
+
size=(width, height), resample=Image.NEAREST)
|
101 |
+
# channel-wise IUV order, [3, H, W]
|
102 |
+
densepose = np.array(densepose)[:, :, 2:].transpose(2, 0, 1)
|
103 |
+
return densepose.astype(np.float32)
|
104 |
+
|
105 |
+
def _load_segm(self, raw_idx):
|
106 |
+
fname = self._image_fnames[raw_idx]
|
107 |
+
fname = f'{fname[:-4]}_segm.png'
|
108 |
+
with self._open_file(self._segm_path, fname) as f:
|
109 |
+
segm = Image.open(f)
|
110 |
+
if self.downsample_factor != 1:
|
111 |
+
width, height = segm.size
|
112 |
+
width = width // self.downsample_factor
|
113 |
+
height = height // self.downsample_factor
|
114 |
+
segm = segm.resize(
|
115 |
+
size=(width, height), resample=Image.NEAREST)
|
116 |
+
segm = np.array(segm)
|
117 |
+
segm = segm[:, :, np.newaxis].transpose(2, 0, 1)
|
118 |
+
return segm.astype(np.float32)
|
119 |
+
|
120 |
+
def __getitem__(self, index):
|
121 |
+
image = self._load_raw_image(index)
|
122 |
+
pose = self._load_densepose(index)
|
123 |
+
segm = self._load_segm(index)
|
124 |
+
|
125 |
+
if self.xflip and random.random() > 0.5:
|
126 |
+
assert image.ndim == 3 # CHW
|
127 |
+
image = image[:, :, ::-1].copy()
|
128 |
+
pose = pose[:, :, ::-1].copy()
|
129 |
+
segm = segm[:, :, ::-1].copy()
|
130 |
+
|
131 |
+
image = torch.from_numpy(image)
|
132 |
+
segm = torch.from_numpy(segm)
|
133 |
+
|
134 |
+
upper_fused_attr = self.upper_fused_attrs[index]
|
135 |
+
lower_fused_attr = self.lower_fused_attrs[index]
|
136 |
+
outer_fused_attr = self.outer_fused_attrs[index]
|
137 |
+
|
138 |
+
# mask 0: denotes the common codebook,
|
139 |
+
# mask (attr + 1): denotes the texture-specific codebook
|
140 |
+
mask = torch.zeros_like(segm)
|
141 |
+
if upper_fused_attr != 17:
|
142 |
+
for cls in self.upper_cls:
|
143 |
+
mask[segm == cls] = upper_fused_attr + 1
|
144 |
+
|
145 |
+
if lower_fused_attr != 17:
|
146 |
+
for cls in self.lower_cls:
|
147 |
+
mask[segm == cls] = lower_fused_attr + 1
|
148 |
+
|
149 |
+
if outer_fused_attr != 17:
|
150 |
+
for cls in self.outer_cls:
|
151 |
+
mask[segm == cls] = outer_fused_attr + 1
|
152 |
+
|
153 |
+
pose = pose / 12. - 1
|
154 |
+
image = image / 127.5 - 1
|
155 |
+
|
156 |
+
return_dict = {
|
157 |
+
'image': image,
|
158 |
+
'densepose': pose,
|
159 |
+
'segm': segm,
|
160 |
+
'texture_mask': mask,
|
161 |
+
'img_name': self._image_fnames[index]
|
162 |
+
}
|
163 |
+
|
164 |
+
return return_dict
|
165 |
+
|
166 |
+
def __len__(self):
|
167 |
+
return len(self._image_fnames)
|
Text2Human/environment/text2human_env.yaml
ADDED
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: text2human
|
2 |
+
channels:
|
3 |
+
- pytorch
|
4 |
+
- anaconda
|
5 |
+
- defaults
|
6 |
+
dependencies:
|
7 |
+
- _libgcc_mutex=0.1=main
|
8 |
+
- astroid=2.5=py36h06a4308_1
|
9 |
+
- blas=1.0=mkl
|
10 |
+
- brotlipy=0.7.0=py36h7b6447c_1000
|
11 |
+
- ca-certificates=2021.10.26=h06a4308_2
|
12 |
+
- certifi=2021.5.30=py36h06a4308_0
|
13 |
+
- cffi=1.14.3=py36he30daa8_0
|
14 |
+
- chardet=3.0.4=py36_1003
|
15 |
+
- click=8.0.3=pyhd3eb1b0_0
|
16 |
+
- cryptography=3.1.1=py36h1ba5d50_0
|
17 |
+
- cudatoolkit=10.1.243=h6bb024c_0
|
18 |
+
- dataclasses=0.8=pyh4f3eec9_6
|
19 |
+
- dbus=1.13.18=hb2f20db_0
|
20 |
+
- expat=2.2.10=he6710b0_2
|
21 |
+
- filelock=3.4.0=pyhd3eb1b0_0
|
22 |
+
- fontconfig=2.13.0=h9420a91_0
|
23 |
+
- freetype=2.10.4=h5ab3b9f_0
|
24 |
+
- glib=2.56.2=hd408876_0
|
25 |
+
- gst-plugins-base=1.14.0=hbbd80ab_1
|
26 |
+
- gstreamer=1.14.0=hb453b48_1
|
27 |
+
- icu=58.2=he6710b0_3
|
28 |
+
- idna=2.10=py_0
|
29 |
+
- importlib-metadata=4.8.1=py36h06a4308_0
|
30 |
+
- importlib_metadata=4.8.1=hd3eb1b0_0
|
31 |
+
- intel-openmp=2020.2=254
|
32 |
+
- isort=5.7.0=pyhd3eb1b0_0
|
33 |
+
- joblib=1.0.1=pyhd3eb1b0_0
|
34 |
+
- jpeg=9b=habf39ab_1
|
35 |
+
- lazy-object-proxy=1.5.2=py36h27cfd23_0
|
36 |
+
- lcms2=2.11=h396b838_0
|
37 |
+
- ld_impl_linux-64=2.33.1=h53a641e_7
|
38 |
+
- libffi=3.3=he6710b0_2
|
39 |
+
- libgcc-ng=9.1.0=hdf63c60_0
|
40 |
+
- libpng=1.6.37=hbc83047_0
|
41 |
+
- libprotobuf=3.17.2=h4ff587b_1
|
42 |
+
- libstdcxx-ng=9.1.0=hdf63c60_0
|
43 |
+
- libtiff=4.2.0=h3942068_0
|
44 |
+
- libuuid=1.0.3=h1bed415_2
|
45 |
+
- libuv=1.40.0=h7b6447c_0
|
46 |
+
- libwebp-base=1.2.0=h27cfd23_0
|
47 |
+
- libxcb=1.14=h7b6447c_0
|
48 |
+
- libxml2=2.9.10=hb55368b_3
|
49 |
+
- lz4-c=1.9.3=h2531618_0
|
50 |
+
- mccabe=0.6.1=py36_1
|
51 |
+
- mkl=2020.2=256
|
52 |
+
- mkl-service=2.3.0=py36he8ac12f_0
|
53 |
+
- mkl_fft=1.3.0=py36h54f3939_0
|
54 |
+
- mkl_random=1.1.1=py36h0573a6f_0
|
55 |
+
- ncurses=6.2=he6710b0_1
|
56 |
+
- ninja=1.10.2=h5e70eb0_2
|
57 |
+
- numpy=1.19.2=py36h54aff64_0
|
58 |
+
- numpy-base=1.19.2=py36hfa32c7d_0
|
59 |
+
- olefile=0.46=py36_0
|
60 |
+
- openssl=1.1.1m=h7f8727e_0
|
61 |
+
- packaging=21.3=pyhd3eb1b0_0
|
62 |
+
- pcre=8.44=he6710b0_0
|
63 |
+
- pillow=8.1.2=py36he98fc37_0
|
64 |
+
- pip=21.0.1=py36h06a4308_0
|
65 |
+
- protobuf=3.17.2=py36h295c915_0
|
66 |
+
- pycparser=2.20=py_2
|
67 |
+
- pylint=2.7.2=py36h06a4308_1
|
68 |
+
- pyopenssl=19.1.0=py_1
|
69 |
+
- pyqt=5.9.2=py36h05f1152_2
|
70 |
+
- pysocks=1.7.1=py36_0
|
71 |
+
- python=3.6.13=hdb3f193_0
|
72 |
+
- pytorch=1.7.1=py3.6_cuda10.1.243_cudnn7.6.3_0
|
73 |
+
- qt=5.9.7=h5867ecd_1
|
74 |
+
- readline=8.1=h27cfd23_0
|
75 |
+
- regex=2021.8.3=py36h7f8727e_0
|
76 |
+
- requests=2.24.0=py_0
|
77 |
+
- setuptools=52.0.0=py36h06a4308_0
|
78 |
+
- sip=4.19.8=py36hf484d3e_0
|
79 |
+
- six=1.15.0=py36h06a4308_0
|
80 |
+
- sqlite=3.35.2=hdfb4753_0
|
81 |
+
- tk=8.6.10=hbc83047_0
|
82 |
+
- toml=0.10.2=pyhd3eb1b0_0
|
83 |
+
- torchvision=0.8.2=py36_cu101
|
84 |
+
- tqdm=4.62.3=pyhd3eb1b0_1
|
85 |
+
- typed-ast=1.4.2=py36h27cfd23_1
|
86 |
+
- typing-extensions=3.10.0.2=hd3eb1b0_0
|
87 |
+
- typing_extensions=3.10.0.2=pyh06a4308_0
|
88 |
+
- urllib3=1.25.11=py_0
|
89 |
+
- wheel=0.36.2=pyhd3eb1b0_0
|
90 |
+
- wrapt=1.12.1=py36h7b6447c_1
|
91 |
+
- xz=5.2.5=h7b6447c_0
|
92 |
+
- yaml=0.2.5=h7b6447c_0
|
93 |
+
- zipp=3.6.0=pyhd3eb1b0_0
|
94 |
+
- zlib=1.2.11=h7b6447c_3
|
95 |
+
- zstd=1.4.5=h9ceee32_0
|
96 |
+
- pip:
|
97 |
+
- addict==2.4.0
|
98 |
+
- cycler==0.11.0
|
99 |
+
- einops==0.4.0
|
100 |
+
- kiwisolver==1.3.1
|
101 |
+
- matplotlib==3.3.4
|
102 |
+
- mmcv-full==1.2.1
|
103 |
+
- mmsegmentation==0.9.0
|
104 |
+
- nltk==3.6.7
|
105 |
+
- opencv-python==4.5.5.62
|
106 |
+
- pyparsing==3.0.7
|
107 |
+
- python-dateutil==2.8.2
|
108 |
+
- pyyaml==6.0
|
109 |
+
- scikit-learn==0.24.2
|
110 |
+
- scipy==1.5.4
|
111 |
+
- sentencepiece==0.1.96
|
112 |
+
- terminaltables==3.1.10
|
113 |
+
- threadpoolctl==3.0.0
|
114 |
+
- yapf==0.32.0
|
Text2Human/models/__init__.py
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import glob
|
2 |
+
import importlib
|
3 |
+
import logging
|
4 |
+
import os.path as osp
|
5 |
+
|
6 |
+
# automatically scan and import model modules
|
7 |
+
# scan all the files under the 'models' folder and collect files ending with
|
8 |
+
# '_model.py'
|
9 |
+
model_folder = osp.dirname(osp.abspath(__file__))
|
10 |
+
model_filenames = [
|
11 |
+
osp.splitext(osp.basename(v))[0]
|
12 |
+
for v in glob.glob(f'{model_folder}/*_model.py')
|
13 |
+
]
|
14 |
+
# import all the model modules
|
15 |
+
_model_modules = [
|
16 |
+
importlib.import_module(f'models.{file_name}')
|
17 |
+
for file_name in model_filenames
|
18 |
+
]
|
19 |
+
|
20 |
+
|
21 |
+
def create_model(opt):
|
22 |
+
"""Create model.
|
23 |
+
|
24 |
+
Args:
|
25 |
+
opt (dict): Configuration. It constains:
|
26 |
+
model_type (str): Model type.
|
27 |
+
"""
|
28 |
+
model_type = opt['model_type']
|
29 |
+
|
30 |
+
# dynamically instantiation
|
31 |
+
for module in _model_modules:
|
32 |
+
model_cls = getattr(module, model_type, None)
|
33 |
+
if model_cls is not None:
|
34 |
+
break
|
35 |
+
if model_cls is None:
|
36 |
+
raise ValueError(f'Model {model_type} is not found.')
|
37 |
+
|
38 |
+
model = model_cls(opt)
|
39 |
+
|
40 |
+
logger = logging.getLogger('base')
|
41 |
+
logger.info(f'Model [{model.__class__.__name__}] is created.')
|
42 |
+
return model
|
Text2Human/models/archs/__init__.py
ADDED
File without changes
|
Text2Human/models/archs/fcn_arch.py
ADDED
@@ -0,0 +1,418 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from mmcv.cnn import ConvModule, normal_init
|
4 |
+
from mmseg.ops import resize
|
5 |
+
|
6 |
+
|
7 |
+
class BaseDecodeHead(nn.Module):
|
8 |
+
"""Base class for BaseDecodeHead.
|
9 |
+
|
10 |
+
Args:
|
11 |
+
in_channels (int|Sequence[int]): Input channels.
|
12 |
+
channels (int): Channels after modules, before conv_seg.
|
13 |
+
num_classes (int): Number of classes.
|
14 |
+
dropout_ratio (float): Ratio of dropout layer. Default: 0.1.
|
15 |
+
conv_cfg (dict|None): Config of conv layers. Default: None.
|
16 |
+
norm_cfg (dict|None): Config of norm layers. Default: None.
|
17 |
+
act_cfg (dict): Config of activation layers.
|
18 |
+
Default: dict(type='ReLU')
|
19 |
+
in_index (int|Sequence[int]): Input feature index. Default: -1
|
20 |
+
input_transform (str|None): Transformation type of input features.
|
21 |
+
Options: 'resize_concat', 'multiple_select', None.
|
22 |
+
'resize_concat': Multiple feature maps will be resize to the
|
23 |
+
same size as first one and than concat together.
|
24 |
+
Usually used in FCN head of HRNet.
|
25 |
+
'multiple_select': Multiple feature maps will be bundle into
|
26 |
+
a list and passed into decode head.
|
27 |
+
None: Only one select feature map is allowed.
|
28 |
+
Default: None.
|
29 |
+
loss_decode (dict): Config of decode loss.
|
30 |
+
Default: dict(type='CrossEntropyLoss').
|
31 |
+
ignore_index (int | None): The label index to be ignored. When using
|
32 |
+
masked BCE loss, ignore_index should be set to None. Default: 255
|
33 |
+
sampler (dict|None): The config of segmentation map sampler.
|
34 |
+
Default: None.
|
35 |
+
align_corners (bool): align_corners argument of F.interpolate.
|
36 |
+
Default: False.
|
37 |
+
"""
|
38 |
+
|
39 |
+
def __init__(self,
|
40 |
+
in_channels,
|
41 |
+
channels,
|
42 |
+
*,
|
43 |
+
num_classes,
|
44 |
+
dropout_ratio=0.1,
|
45 |
+
conv_cfg=None,
|
46 |
+
norm_cfg=dict(type='BN'),
|
47 |
+
act_cfg=dict(type='ReLU'),
|
48 |
+
in_index=-1,
|
49 |
+
input_transform=None,
|
50 |
+
ignore_index=255,
|
51 |
+
align_corners=False):
|
52 |
+
super(BaseDecodeHead, self).__init__()
|
53 |
+
self._init_inputs(in_channels, in_index, input_transform)
|
54 |
+
self.channels = channels
|
55 |
+
self.num_classes = num_classes
|
56 |
+
self.dropout_ratio = dropout_ratio
|
57 |
+
self.conv_cfg = conv_cfg
|
58 |
+
self.norm_cfg = norm_cfg
|
59 |
+
self.act_cfg = act_cfg
|
60 |
+
self.in_index = in_index
|
61 |
+
|
62 |
+
self.ignore_index = ignore_index
|
63 |
+
self.align_corners = align_corners
|
64 |
+
|
65 |
+
self.conv_seg = nn.Conv2d(channels, num_classes, kernel_size=1)
|
66 |
+
if dropout_ratio > 0:
|
67 |
+
self.dropout = nn.Dropout2d(dropout_ratio)
|
68 |
+
else:
|
69 |
+
self.dropout = None
|
70 |
+
|
71 |
+
def extra_repr(self):
|
72 |
+
"""Extra repr."""
|
73 |
+
s = f'input_transform={self.input_transform}, ' \
|
74 |
+
f'ignore_index={self.ignore_index}, ' \
|
75 |
+
f'align_corners={self.align_corners}'
|
76 |
+
return s
|
77 |
+
|
78 |
+
def _init_inputs(self, in_channels, in_index, input_transform):
|
79 |
+
"""Check and initialize input transforms.
|
80 |
+
|
81 |
+
The in_channels, in_index and input_transform must match.
|
82 |
+
Specifically, when input_transform is None, only single feature map
|
83 |
+
will be selected. So in_channels and in_index must be of type int.
|
84 |
+
When input_transform
|
85 |
+
|
86 |
+
Args:
|
87 |
+
in_channels (int|Sequence[int]): Input channels.
|
88 |
+
in_index (int|Sequence[int]): Input feature index.
|
89 |
+
input_transform (str|None): Transformation type of input features.
|
90 |
+
Options: 'resize_concat', 'multiple_select', None.
|
91 |
+
'resize_concat': Multiple feature maps will be resize to the
|
92 |
+
same size as first one and than concat together.
|
93 |
+
Usually used in FCN head of HRNet.
|
94 |
+
'multiple_select': Multiple feature maps will be bundle into
|
95 |
+
a list and passed into decode head.
|
96 |
+
None: Only one select feature map is allowed.
|
97 |
+
"""
|
98 |
+
|
99 |
+
if input_transform is not None:
|
100 |
+
assert input_transform in ['resize_concat', 'multiple_select']
|
101 |
+
self.input_transform = input_transform
|
102 |
+
self.in_index = in_index
|
103 |
+
if input_transform is not None:
|
104 |
+
assert isinstance(in_channels, (list, tuple))
|
105 |
+
assert isinstance(in_index, (list, tuple))
|
106 |
+
assert len(in_channels) == len(in_index)
|
107 |
+
if input_transform == 'resize_concat':
|
108 |
+
self.in_channels = sum(in_channels)
|
109 |
+
else:
|
110 |
+
self.in_channels = in_channels
|
111 |
+
else:
|
112 |
+
assert isinstance(in_channels, int)
|
113 |
+
assert isinstance(in_index, int)
|
114 |
+
self.in_channels = in_channels
|
115 |
+
|
116 |
+
def init_weights(self):
|
117 |
+
"""Initialize weights of classification layer."""
|
118 |
+
normal_init(self.conv_seg, mean=0, std=0.01)
|
119 |
+
|
120 |
+
def _transform_inputs(self, inputs):
|
121 |
+
"""Transform inputs for decoder.
|
122 |
+
|
123 |
+
Args:
|
124 |
+
inputs (list[Tensor]): List of multi-level img features.
|
125 |
+
|
126 |
+
Returns:
|
127 |
+
Tensor: The transformed inputs
|
128 |
+
"""
|
129 |
+
|
130 |
+
if self.input_transform == 'resize_concat':
|
131 |
+
inputs = [inputs[i] for i in self.in_index]
|
132 |
+
upsampled_inputs = [
|
133 |
+
resize(
|
134 |
+
input=x,
|
135 |
+
size=inputs[0].shape[2:],
|
136 |
+
mode='bilinear',
|
137 |
+
align_corners=self.align_corners) for x in inputs
|
138 |
+
]
|
139 |
+
inputs = torch.cat(upsampled_inputs, dim=1)
|
140 |
+
elif self.input_transform == 'multiple_select':
|
141 |
+
inputs = [inputs[i] for i in self.in_index]
|
142 |
+
else:
|
143 |
+
inputs = inputs[self.in_index]
|
144 |
+
|
145 |
+
return inputs
|
146 |
+
|
147 |
+
def forward(self, inputs):
|
148 |
+
"""Placeholder of forward function."""
|
149 |
+
pass
|
150 |
+
|
151 |
+
def cls_seg(self, feat):
|
152 |
+
"""Classify each pixel."""
|
153 |
+
if self.dropout is not None:
|
154 |
+
feat = self.dropout(feat)
|
155 |
+
output = self.conv_seg(feat)
|
156 |
+
return output
|
157 |
+
|
158 |
+
|
159 |
+
class FCNHead(BaseDecodeHead):
|
160 |
+
"""Fully Convolution Networks for Semantic Segmentation.
|
161 |
+
|
162 |
+
This head is implemented of `FCNNet <https://arxiv.org/abs/1411.4038>`_.
|
163 |
+
|
164 |
+
Args:
|
165 |
+
num_convs (int): Number of convs in the head. Default: 2.
|
166 |
+
kernel_size (int): The kernel size for convs in the head. Default: 3.
|
167 |
+
concat_input (bool): Whether concat the input and output of convs
|
168 |
+
before classification layer.
|
169 |
+
"""
|
170 |
+
|
171 |
+
def __init__(self,
|
172 |
+
num_convs=2,
|
173 |
+
kernel_size=3,
|
174 |
+
concat_input=True,
|
175 |
+
**kwargs):
|
176 |
+
assert num_convs >= 0
|
177 |
+
self.num_convs = num_convs
|
178 |
+
self.concat_input = concat_input
|
179 |
+
self.kernel_size = kernel_size
|
180 |
+
super(FCNHead, self).__init__(**kwargs)
|
181 |
+
if num_convs == 0:
|
182 |
+
assert self.in_channels == self.channels
|
183 |
+
|
184 |
+
convs = []
|
185 |
+
convs.append(
|
186 |
+
ConvModule(
|
187 |
+
self.in_channels,
|
188 |
+
self.channels,
|
189 |
+
kernel_size=kernel_size,
|
190 |
+
padding=kernel_size // 2,
|
191 |
+
conv_cfg=self.conv_cfg,
|
192 |
+
norm_cfg=self.norm_cfg,
|
193 |
+
act_cfg=self.act_cfg))
|
194 |
+
for i in range(num_convs - 1):
|
195 |
+
convs.append(
|
196 |
+
ConvModule(
|
197 |
+
self.channels,
|
198 |
+
self.channels,
|
199 |
+
kernel_size=kernel_size,
|
200 |
+
padding=kernel_size // 2,
|
201 |
+
conv_cfg=self.conv_cfg,
|
202 |
+
norm_cfg=self.norm_cfg,
|
203 |
+
act_cfg=self.act_cfg))
|
204 |
+
if num_convs == 0:
|
205 |
+
self.convs = nn.Identity()
|
206 |
+
else:
|
207 |
+
self.convs = nn.Sequential(*convs)
|
208 |
+
if self.concat_input:
|
209 |
+
self.conv_cat = ConvModule(
|
210 |
+
self.in_channels + self.channels,
|
211 |
+
self.channels,
|
212 |
+
kernel_size=kernel_size,
|
213 |
+
padding=kernel_size // 2,
|
214 |
+
conv_cfg=self.conv_cfg,
|
215 |
+
norm_cfg=self.norm_cfg,
|
216 |
+
act_cfg=self.act_cfg)
|
217 |
+
|
218 |
+
def forward(self, inputs):
|
219 |
+
"""Forward function."""
|
220 |
+
x = self._transform_inputs(inputs)
|
221 |
+
output = self.convs(x)
|
222 |
+
if self.concat_input:
|
223 |
+
output = self.conv_cat(torch.cat([x, output], dim=1))
|
224 |
+
output = self.cls_seg(output)
|
225 |
+
return output
|
226 |
+
|
227 |
+
|
228 |
+
class MultiHeadFCNHead(nn.Module):
|
229 |
+
"""Fully Convolution Networks for Semantic Segmentation.
|
230 |
+
|
231 |
+
This head is implemented of `FCNNet <https://arxiv.org/abs/1411.4038>`_.
|
232 |
+
|
233 |
+
Args:
|
234 |
+
num_convs (int): Number of convs in the head. Default: 2.
|
235 |
+
kernel_size (int): The kernel size for convs in the head. Default: 3.
|
236 |
+
concat_input (bool): Whether concat the input and output of convs
|
237 |
+
before classification layer.
|
238 |
+
"""
|
239 |
+
|
240 |
+
def __init__(self,
|
241 |
+
in_channels,
|
242 |
+
channels,
|
243 |
+
*,
|
244 |
+
num_classes,
|
245 |
+
dropout_ratio=0.1,
|
246 |
+
conv_cfg=None,
|
247 |
+
norm_cfg=dict(type='BN'),
|
248 |
+
act_cfg=dict(type='ReLU'),
|
249 |
+
in_index=-1,
|
250 |
+
input_transform=None,
|
251 |
+
ignore_index=255,
|
252 |
+
align_corners=False,
|
253 |
+
num_convs=2,
|
254 |
+
kernel_size=3,
|
255 |
+
concat_input=True,
|
256 |
+
num_head=18,
|
257 |
+
**kwargs):
|
258 |
+
super(MultiHeadFCNHead, self).__init__()
|
259 |
+
assert num_convs >= 0
|
260 |
+
self.num_convs = num_convs
|
261 |
+
self.concat_input = concat_input
|
262 |
+
self.kernel_size = kernel_size
|
263 |
+
self._init_inputs(in_channels, in_index, input_transform)
|
264 |
+
self.channels = channels
|
265 |
+
self.num_classes = num_classes
|
266 |
+
self.dropout_ratio = dropout_ratio
|
267 |
+
self.conv_cfg = conv_cfg
|
268 |
+
self.norm_cfg = norm_cfg
|
269 |
+
self.act_cfg = act_cfg
|
270 |
+
self.in_index = in_index
|
271 |
+
self.num_head = num_head
|
272 |
+
|
273 |
+
self.ignore_index = ignore_index
|
274 |
+
self.align_corners = align_corners
|
275 |
+
|
276 |
+
if dropout_ratio > 0:
|
277 |
+
self.dropout = nn.Dropout2d(dropout_ratio)
|
278 |
+
|
279 |
+
conv_seg_head_list = []
|
280 |
+
for _ in range(self.num_head):
|
281 |
+
conv_seg_head_list.append(
|
282 |
+
nn.Conv2d(channels, num_classes, kernel_size=1))
|
283 |
+
|
284 |
+
self.conv_seg_head_list = nn.ModuleList(conv_seg_head_list)
|
285 |
+
|
286 |
+
self.init_weights()
|
287 |
+
|
288 |
+
if num_convs == 0:
|
289 |
+
assert self.in_channels == self.channels
|
290 |
+
|
291 |
+
convs_list = []
|
292 |
+
conv_cat_list = []
|
293 |
+
|
294 |
+
for _ in range(self.num_head):
|
295 |
+
convs = []
|
296 |
+
convs.append(
|
297 |
+
ConvModule(
|
298 |
+
self.in_channels,
|
299 |
+
self.channels,
|
300 |
+
kernel_size=kernel_size,
|
301 |
+
padding=kernel_size // 2,
|
302 |
+
conv_cfg=self.conv_cfg,
|
303 |
+
norm_cfg=self.norm_cfg,
|
304 |
+
act_cfg=self.act_cfg))
|
305 |
+
for _ in range(num_convs - 1):
|
306 |
+
convs.append(
|
307 |
+
ConvModule(
|
308 |
+
self.channels,
|
309 |
+
self.channels,
|
310 |
+
kernel_size=kernel_size,
|
311 |
+
padding=kernel_size // 2,
|
312 |
+
conv_cfg=self.conv_cfg,
|
313 |
+
norm_cfg=self.norm_cfg,
|
314 |
+
act_cfg=self.act_cfg))
|
315 |
+
if num_convs == 0:
|
316 |
+
convs_list.append(nn.Identity())
|
317 |
+
else:
|
318 |
+
convs_list.append(nn.Sequential(*convs))
|
319 |
+
if self.concat_input:
|
320 |
+
conv_cat_list.append(
|
321 |
+
ConvModule(
|
322 |
+
self.in_channels + self.channels,
|
323 |
+
self.channels,
|
324 |
+
kernel_size=kernel_size,
|
325 |
+
padding=kernel_size // 2,
|
326 |
+
conv_cfg=self.conv_cfg,
|
327 |
+
norm_cfg=self.norm_cfg,
|
328 |
+
act_cfg=self.act_cfg))
|
329 |
+
|
330 |
+
self.convs_list = nn.ModuleList(convs_list)
|
331 |
+
self.conv_cat_list = nn.ModuleList(conv_cat_list)
|
332 |
+
|
333 |
+
def forward(self, inputs):
|
334 |
+
"""Forward function."""
|
335 |
+
x = self._transform_inputs(inputs)
|
336 |
+
|
337 |
+
output_list = []
|
338 |
+
for head_idx in range(self.num_head):
|
339 |
+
output = self.convs_list[head_idx](x)
|
340 |
+
if self.concat_input:
|
341 |
+
output = self.conv_cat_list[head_idx](
|
342 |
+
torch.cat([x, output], dim=1))
|
343 |
+
if self.dropout is not None:
|
344 |
+
output = self.dropout(output)
|
345 |
+
output = self.conv_seg_head_list[head_idx](output)
|
346 |
+
output_list.append(output)
|
347 |
+
|
348 |
+
return output_list
|
349 |
+
|
350 |
+
def _init_inputs(self, in_channels, in_index, input_transform):
|
351 |
+
"""Check and initialize input transforms.
|
352 |
+
|
353 |
+
The in_channels, in_index and input_transform must match.
|
354 |
+
Specifically, when input_transform is None, only single feature map
|
355 |
+
will be selected. So in_channels and in_index must be of type int.
|
356 |
+
When input_transform
|
357 |
+
|
358 |
+
Args:
|
359 |
+
in_channels (int|Sequence[int]): Input channels.
|
360 |
+
in_index (int|Sequence[int]): Input feature index.
|
361 |
+
input_transform (str|None): Transformation type of input features.
|
362 |
+
Options: 'resize_concat', 'multiple_select', None.
|
363 |
+
'resize_concat': Multiple feature maps will be resize to the
|
364 |
+
same size as first one and than concat together.
|
365 |
+
Usually used in FCN head of HRNet.
|
366 |
+
'multiple_select': Multiple feature maps will be bundle into
|
367 |
+
a list and passed into decode head.
|
368 |
+
None: Only one select feature map is allowed.
|
369 |
+
"""
|
370 |
+
|
371 |
+
if input_transform is not None:
|
372 |
+
assert input_transform in ['resize_concat', 'multiple_select']
|
373 |
+
self.input_transform = input_transform
|
374 |
+
self.in_index = in_index
|
375 |
+
if input_transform is not None:
|
376 |
+
assert isinstance(in_channels, (list, tuple))
|
377 |
+
assert isinstance(in_index, (list, tuple))
|
378 |
+
assert len(in_channels) == len(in_index)
|
379 |
+
if input_transform == 'resize_concat':
|
380 |
+
self.in_channels = sum(in_channels)
|
381 |
+
else:
|
382 |
+
self.in_channels = in_channels
|
383 |
+
else:
|
384 |
+
assert isinstance(in_channels, int)
|
385 |
+
assert isinstance(in_index, int)
|
386 |
+
self.in_channels = in_channels
|
387 |
+
|
388 |
+
def init_weights(self):
|
389 |
+
"""Initialize weights of classification layer."""
|
390 |
+
for conv_seg_head in self.conv_seg_head_list:
|
391 |
+
normal_init(conv_seg_head, mean=0, std=0.01)
|
392 |
+
|
393 |
+
def _transform_inputs(self, inputs):
|
394 |
+
"""Transform inputs for decoder.
|
395 |
+
|
396 |
+
Args:
|
397 |
+
inputs (list[Tensor]): List of multi-level img features.
|
398 |
+
|
399 |
+
Returns:
|
400 |
+
Tensor: The transformed inputs
|
401 |
+
"""
|
402 |
+
|
403 |
+
if self.input_transform == 'resize_concat':
|
404 |
+
inputs = [inputs[i] for i in self.in_index]
|
405 |
+
upsampled_inputs = [
|
406 |
+
resize(
|
407 |
+
input=x,
|
408 |
+
size=inputs[0].shape[2:],
|
409 |
+
mode='bilinear',
|
410 |
+
align_corners=self.align_corners) for x in inputs
|
411 |
+
]
|
412 |
+
inputs = torch.cat(upsampled_inputs, dim=1)
|
413 |
+
elif self.input_transform == 'multiple_select':
|
414 |
+
inputs = [inputs[i] for i in self.in_index]
|
415 |
+
else:
|
416 |
+
inputs = inputs[self.in_index]
|
417 |
+
|
418 |
+
return inputs
|
Text2Human/models/archs/shape_attr_embedding_arch.py
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
from torch import nn
|
4 |
+
|
5 |
+
|
6 |
+
class ShapeAttrEmbedding(nn.Module):
|
7 |
+
|
8 |
+
def __init__(self, dim, out_dim, cls_num_list):
|
9 |
+
super(ShapeAttrEmbedding, self).__init__()
|
10 |
+
|
11 |
+
for idx, cls_num in enumerate(cls_num_list):
|
12 |
+
setattr(
|
13 |
+
self, f'attr_{idx}',
|
14 |
+
nn.Sequential(
|
15 |
+
nn.Linear(cls_num, dim), nn.LeakyReLU(),
|
16 |
+
nn.Linear(dim, dim)))
|
17 |
+
self.cls_num_list = cls_num_list
|
18 |
+
self.attr_num = len(cls_num_list)
|
19 |
+
self.fusion = nn.Sequential(
|
20 |
+
nn.Linear(dim * self.attr_num, out_dim), nn.LeakyReLU(),
|
21 |
+
nn.Linear(out_dim, out_dim))
|
22 |
+
|
23 |
+
def forward(self, attr):
|
24 |
+
attr_embedding_list = []
|
25 |
+
for idx in range(self.attr_num):
|
26 |
+
attr_embed_fc = getattr(self, f'attr_{idx}')
|
27 |
+
attr_embedding_list.append(
|
28 |
+
attr_embed_fc(
|
29 |
+
F.one_hot(
|
30 |
+
attr[:, idx],
|
31 |
+
num_classes=self.cls_num_list[idx]).to(torch.float32)))
|
32 |
+
attr_embedding = torch.cat(attr_embedding_list, dim=1)
|
33 |
+
attr_embedding = self.fusion(attr_embedding)
|
34 |
+
|
35 |
+
return attr_embedding
|
Text2Human/models/archs/transformer_arch.py
ADDED
@@ -0,0 +1,273 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
|
8 |
+
|
9 |
+
class CausalSelfAttention(nn.Module):
|
10 |
+
"""
|
11 |
+
A vanilla multi-head masked self-attention layer with a projection at the end.
|
12 |
+
It is possible to use torch.nn.MultiheadAttention here but I am including an
|
13 |
+
explicit implementation here to show that there is nothing too scary here.
|
14 |
+
"""
|
15 |
+
|
16 |
+
def __init__(self, bert_n_emb, bert_n_head, attn_pdrop, resid_pdrop,
|
17 |
+
latent_shape, sampler):
|
18 |
+
super().__init__()
|
19 |
+
assert bert_n_emb % bert_n_head == 0
|
20 |
+
# key, query, value projections for all heads
|
21 |
+
self.key = nn.Linear(bert_n_emb, bert_n_emb)
|
22 |
+
self.query = nn.Linear(bert_n_emb, bert_n_emb)
|
23 |
+
self.value = nn.Linear(bert_n_emb, bert_n_emb)
|
24 |
+
# regularization
|
25 |
+
self.attn_drop = nn.Dropout(attn_pdrop)
|
26 |
+
self.resid_drop = nn.Dropout(resid_pdrop)
|
27 |
+
# output projection
|
28 |
+
self.proj = nn.Linear(bert_n_emb, bert_n_emb)
|
29 |
+
self.n_head = bert_n_head
|
30 |
+
self.causal = True if sampler == 'autoregressive' else False
|
31 |
+
if self.causal:
|
32 |
+
block_size = np.prod(latent_shape)
|
33 |
+
mask = torch.tril(torch.ones(block_size, block_size))
|
34 |
+
self.register_buffer("mask", mask.view(1, 1, block_size,
|
35 |
+
block_size))
|
36 |
+
|
37 |
+
def forward(self, x, layer_past=None):
|
38 |
+
B, T, C = x.size()
|
39 |
+
|
40 |
+
# calculate query, key, values for all heads in batch and move head forward to be the batch dim
|
41 |
+
k = self.key(x).view(B, T, self.n_head,
|
42 |
+
C // self.n_head).transpose(1,
|
43 |
+
2) # (B, nh, T, hs)
|
44 |
+
q = self.query(x).view(B, T, self.n_head,
|
45 |
+
C // self.n_head).transpose(1,
|
46 |
+
2) # (B, nh, T, hs)
|
47 |
+
v = self.value(x).view(B, T, self.n_head,
|
48 |
+
C // self.n_head).transpose(1,
|
49 |
+
2) # (B, nh, T, hs)
|
50 |
+
|
51 |
+
present = torch.stack((k, v))
|
52 |
+
if self.causal and layer_past is not None:
|
53 |
+
past_key, past_value = layer_past
|
54 |
+
k = torch.cat((past_key, k), dim=-2)
|
55 |
+
v = torch.cat((past_value, v), dim=-2)
|
56 |
+
|
57 |
+
# causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
|
58 |
+
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
|
59 |
+
|
60 |
+
if self.causal and layer_past is None:
|
61 |
+
att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float('-inf'))
|
62 |
+
|
63 |
+
att = F.softmax(att, dim=-1)
|
64 |
+
att = self.attn_drop(att)
|
65 |
+
y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
|
66 |
+
# re-assemble all head outputs side by side
|
67 |
+
y = y.transpose(1, 2).contiguous().view(B, T, C)
|
68 |
+
|
69 |
+
# output projection
|
70 |
+
y = self.resid_drop(self.proj(y))
|
71 |
+
return y, present
|
72 |
+
|
73 |
+
|
74 |
+
class Block(nn.Module):
|
75 |
+
""" an unassuming Transformer block """
|
76 |
+
|
77 |
+
def __init__(self, bert_n_emb, resid_pdrop, bert_n_head, attn_pdrop,
|
78 |
+
latent_shape, sampler):
|
79 |
+
super().__init__()
|
80 |
+
self.ln1 = nn.LayerNorm(bert_n_emb)
|
81 |
+
self.ln2 = nn.LayerNorm(bert_n_emb)
|
82 |
+
self.attn = CausalSelfAttention(bert_n_emb, bert_n_head, attn_pdrop,
|
83 |
+
resid_pdrop, latent_shape, sampler)
|
84 |
+
self.mlp = nn.Sequential(
|
85 |
+
nn.Linear(bert_n_emb, 4 * bert_n_emb),
|
86 |
+
nn.GELU(), # nice
|
87 |
+
nn.Linear(4 * bert_n_emb, bert_n_emb),
|
88 |
+
nn.Dropout(resid_pdrop),
|
89 |
+
)
|
90 |
+
|
91 |
+
def forward(self, x, layer_past=None, return_present=False):
|
92 |
+
|
93 |
+
attn, present = self.attn(self.ln1(x), layer_past)
|
94 |
+
x = x + attn
|
95 |
+
x = x + self.mlp(self.ln2(x))
|
96 |
+
|
97 |
+
if layer_past is not None or return_present:
|
98 |
+
return x, present
|
99 |
+
return x
|
100 |
+
|
101 |
+
|
102 |
+
class Transformer(nn.Module):
|
103 |
+
""" the full GPT language model, with a context size of block_size """
|
104 |
+
|
105 |
+
def __init__(self,
|
106 |
+
codebook_size,
|
107 |
+
segm_codebook_size,
|
108 |
+
bert_n_emb,
|
109 |
+
bert_n_layers,
|
110 |
+
bert_n_head,
|
111 |
+
block_size,
|
112 |
+
latent_shape,
|
113 |
+
embd_pdrop,
|
114 |
+
resid_pdrop,
|
115 |
+
attn_pdrop,
|
116 |
+
sampler='absorbing'):
|
117 |
+
super().__init__()
|
118 |
+
|
119 |
+
self.vocab_size = codebook_size + 1
|
120 |
+
self.n_embd = bert_n_emb
|
121 |
+
self.block_size = block_size
|
122 |
+
self.n_layers = bert_n_layers
|
123 |
+
self.codebook_size = codebook_size
|
124 |
+
self.segm_codebook_size = segm_codebook_size
|
125 |
+
self.causal = sampler == 'autoregressive'
|
126 |
+
if self.causal:
|
127 |
+
self.vocab_size = codebook_size
|
128 |
+
|
129 |
+
self.tok_emb = nn.Embedding(self.vocab_size, self.n_embd)
|
130 |
+
self.pos_emb = nn.Parameter(
|
131 |
+
torch.zeros(1, self.block_size, self.n_embd))
|
132 |
+
self.segm_emb = nn.Embedding(self.segm_codebook_size, self.n_embd)
|
133 |
+
self.start_tok = nn.Parameter(torch.zeros(1, 1, self.n_embd))
|
134 |
+
self.drop = nn.Dropout(embd_pdrop)
|
135 |
+
|
136 |
+
# transformer
|
137 |
+
self.blocks = nn.Sequential(*[
|
138 |
+
Block(bert_n_emb, resid_pdrop, bert_n_head, attn_pdrop,
|
139 |
+
latent_shape, sampler) for _ in range(self.n_layers)
|
140 |
+
])
|
141 |
+
# decoder head
|
142 |
+
self.ln_f = nn.LayerNorm(self.n_embd)
|
143 |
+
self.head = nn.Linear(self.n_embd, self.codebook_size, bias=False)
|
144 |
+
|
145 |
+
def get_block_size(self):
|
146 |
+
return self.block_size
|
147 |
+
|
148 |
+
def _init_weights(self, module):
|
149 |
+
if isinstance(module, (nn.Linear, nn.Embedding)):
|
150 |
+
module.weight.data.normal_(mean=0.0, std=0.02)
|
151 |
+
if isinstance(module, nn.Linear) and module.bias is not None:
|
152 |
+
module.bias.data.zero_()
|
153 |
+
elif isinstance(module, nn.LayerNorm):
|
154 |
+
module.bias.data.zero_()
|
155 |
+
module.weight.data.fill_(1.0)
|
156 |
+
|
157 |
+
def forward(self, idx, segm_tokens, t=None):
|
158 |
+
# each index maps to a (learnable) vector
|
159 |
+
token_embeddings = self.tok_emb(idx)
|
160 |
+
|
161 |
+
segm_embeddings = self.segm_emb(segm_tokens)
|
162 |
+
|
163 |
+
if self.causal:
|
164 |
+
token_embeddings = torch.cat((self.start_tok.repeat(
|
165 |
+
token_embeddings.size(0), 1, 1), token_embeddings),
|
166 |
+
dim=1)
|
167 |
+
|
168 |
+
t = token_embeddings.shape[1]
|
169 |
+
assert t <= self.block_size, "Cannot forward, model block size is exhausted."
|
170 |
+
# each position maps to a (learnable) vector
|
171 |
+
|
172 |
+
position_embeddings = self.pos_emb[:, :t, :]
|
173 |
+
|
174 |
+
x = token_embeddings + position_embeddings + segm_embeddings
|
175 |
+
x = self.drop(x)
|
176 |
+
for block in self.blocks:
|
177 |
+
x = block(x)
|
178 |
+
x = self.ln_f(x)
|
179 |
+
logits = self.head(x)
|
180 |
+
|
181 |
+
return logits
|
182 |
+
|
183 |
+
|
184 |
+
class TransformerMultiHead(nn.Module):
|
185 |
+
""" the full GPT language model, with a context size of block_size """
|
186 |
+
|
187 |
+
def __init__(self,
|
188 |
+
codebook_size,
|
189 |
+
segm_codebook_size,
|
190 |
+
texture_codebook_size,
|
191 |
+
bert_n_emb,
|
192 |
+
bert_n_layers,
|
193 |
+
bert_n_head,
|
194 |
+
block_size,
|
195 |
+
latent_shape,
|
196 |
+
embd_pdrop,
|
197 |
+
resid_pdrop,
|
198 |
+
attn_pdrop,
|
199 |
+
num_head,
|
200 |
+
sampler='absorbing'):
|
201 |
+
super().__init__()
|
202 |
+
|
203 |
+
self.vocab_size = codebook_size + 1
|
204 |
+
self.n_embd = bert_n_emb
|
205 |
+
self.block_size = block_size
|
206 |
+
self.n_layers = bert_n_layers
|
207 |
+
self.codebook_size = codebook_size
|
208 |
+
self.segm_codebook_size = segm_codebook_size
|
209 |
+
self.texture_codebook_size = texture_codebook_size
|
210 |
+
self.causal = sampler == 'autoregressive'
|
211 |
+
if self.causal:
|
212 |
+
self.vocab_size = codebook_size
|
213 |
+
|
214 |
+
self.tok_emb = nn.Embedding(self.vocab_size, self.n_embd)
|
215 |
+
self.pos_emb = nn.Parameter(
|
216 |
+
torch.zeros(1, self.block_size, self.n_embd))
|
217 |
+
self.segm_emb = nn.Embedding(self.segm_codebook_size, self.n_embd)
|
218 |
+
self.texture_emb = nn.Embedding(self.texture_codebook_size,
|
219 |
+
self.n_embd)
|
220 |
+
self.start_tok = nn.Parameter(torch.zeros(1, 1, self.n_embd))
|
221 |
+
self.drop = nn.Dropout(embd_pdrop)
|
222 |
+
|
223 |
+
# transformer
|
224 |
+
self.blocks = nn.Sequential(*[
|
225 |
+
Block(bert_n_emb, resid_pdrop, bert_n_head, attn_pdrop,
|
226 |
+
latent_shape, sampler) for _ in range(self.n_layers)
|
227 |
+
])
|
228 |
+
# decoder head
|
229 |
+
self.num_head = num_head
|
230 |
+
self.head_class_num = codebook_size // self.num_head
|
231 |
+
self.ln_f = nn.LayerNorm(self.n_embd)
|
232 |
+
self.head_list = nn.ModuleList([
|
233 |
+
nn.Linear(self.n_embd, self.head_class_num, bias=False)
|
234 |
+
for _ in range(self.num_head)
|
235 |
+
])
|
236 |
+
|
237 |
+
def get_block_size(self):
|
238 |
+
return self.block_size
|
239 |
+
|
240 |
+
def _init_weights(self, module):
|
241 |
+
if isinstance(module, (nn.Linear, nn.Embedding)):
|
242 |
+
module.weight.data.normal_(mean=0.0, std=0.02)
|
243 |
+
if isinstance(module, nn.Linear) and module.bias is not None:
|
244 |
+
module.bias.data.zero_()
|
245 |
+
elif isinstance(module, nn.LayerNorm):
|
246 |
+
module.bias.data.zero_()
|
247 |
+
module.weight.data.fill_(1.0)
|
248 |
+
|
249 |
+
def forward(self, idx, segm_tokens, texture_tokens, t=None):
|
250 |
+
# each index maps to a (learnable) vector
|
251 |
+
token_embeddings = self.tok_emb(idx)
|
252 |
+
segm_embeddings = self.segm_emb(segm_tokens)
|
253 |
+
texture_embeddings = self.texture_emb(texture_tokens)
|
254 |
+
|
255 |
+
if self.causal:
|
256 |
+
token_embeddings = torch.cat((self.start_tok.repeat(
|
257 |
+
token_embeddings.size(0), 1, 1), token_embeddings),
|
258 |
+
dim=1)
|
259 |
+
|
260 |
+
t = token_embeddings.shape[1]
|
261 |
+
assert t <= self.block_size, "Cannot forward, model block size is exhausted."
|
262 |
+
# each position maps to a (learnable) vector
|
263 |
+
|
264 |
+
position_embeddings = self.pos_emb[:, :t, :]
|
265 |
+
|
266 |
+
x = token_embeddings + position_embeddings + segm_embeddings + texture_embeddings
|
267 |
+
x = self.drop(x)
|
268 |
+
for block in self.blocks:
|
269 |
+
x = block(x)
|
270 |
+
x = self.ln_f(x)
|
271 |
+
logits_list = [self.head_list[i](x) for i in range(self.num_head)]
|
272 |
+
|
273 |
+
return logits_list
|
Text2Human/models/archs/unet_arch.py
ADDED
@@ -0,0 +1,693 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.utils.checkpoint as cp
|
4 |
+
from mmcv.cnn import (UPSAMPLE_LAYERS, ConvModule, build_activation_layer,
|
5 |
+
build_norm_layer, build_upsample_layer, constant_init,
|
6 |
+
kaiming_init)
|
7 |
+
from mmcv.runner import load_checkpoint
|
8 |
+
from mmcv.utils.parrots_wrapper import _BatchNorm
|
9 |
+
from mmseg.utils import get_root_logger
|
10 |
+
|
11 |
+
|
12 |
+
class UpConvBlock(nn.Module):
|
13 |
+
"""Upsample convolution block in decoder for UNet.
|
14 |
+
|
15 |
+
This upsample convolution block consists of one upsample module
|
16 |
+
followed by one convolution block. The upsample module expands the
|
17 |
+
high-level low-resolution feature map and the convolution block fuses
|
18 |
+
the upsampled high-level low-resolution feature map and the low-level
|
19 |
+
high-resolution feature map from encoder.
|
20 |
+
|
21 |
+
Args:
|
22 |
+
conv_block (nn.Sequential): Sequential of convolutional layers.
|
23 |
+
in_channels (int): Number of input channels of the high-level
|
24 |
+
skip_channels (int): Number of input channels of the low-level
|
25 |
+
high-resolution feature map from encoder.
|
26 |
+
out_channels (int): Number of output channels.
|
27 |
+
num_convs (int): Number of convolutional layers in the conv_block.
|
28 |
+
Default: 2.
|
29 |
+
stride (int): Stride of convolutional layer in conv_block. Default: 1.
|
30 |
+
dilation (int): Dilation rate of convolutional layer in conv_block.
|
31 |
+
Default: 1.
|
32 |
+
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
|
33 |
+
memory while slowing down the training speed. Default: False.
|
34 |
+
conv_cfg (dict | None): Config dict for convolution layer.
|
35 |
+
Default: None.
|
36 |
+
norm_cfg (dict | None): Config dict for normalization layer.
|
37 |
+
Default: dict(type='BN').
|
38 |
+
act_cfg (dict | None): Config dict for activation layer in ConvModule.
|
39 |
+
Default: dict(type='ReLU').
|
40 |
+
upsample_cfg (dict): The upsample config of the upsample module in
|
41 |
+
decoder. Default: dict(type='InterpConv'). If the size of
|
42 |
+
high-level feature map is the same as that of skip feature map
|
43 |
+
(low-level feature map from encoder), it does not need upsample the
|
44 |
+
high-level feature map and the upsample_cfg is None.
|
45 |
+
dcn (bool): Use deformable convoluton in convolutional layer or not.
|
46 |
+
Default: None.
|
47 |
+
plugins (dict): plugins for convolutional layers. Default: None.
|
48 |
+
"""
|
49 |
+
|
50 |
+
def __init__(self,
|
51 |
+
conv_block,
|
52 |
+
in_channels,
|
53 |
+
skip_channels,
|
54 |
+
out_channels,
|
55 |
+
num_convs=2,
|
56 |
+
stride=1,
|
57 |
+
dilation=1,
|
58 |
+
with_cp=False,
|
59 |
+
conv_cfg=None,
|
60 |
+
norm_cfg=dict(type='BN'),
|
61 |
+
act_cfg=dict(type='ReLU'),
|
62 |
+
upsample_cfg=dict(type='InterpConv'),
|
63 |
+
dcn=None,
|
64 |
+
plugins=None):
|
65 |
+
super(UpConvBlock, self).__init__()
|
66 |
+
assert dcn is None, 'Not implemented yet.'
|
67 |
+
assert plugins is None, 'Not implemented yet.'
|
68 |
+
|
69 |
+
self.conv_block = conv_block(
|
70 |
+
in_channels=2 * skip_channels,
|
71 |
+
out_channels=out_channels,
|
72 |
+
num_convs=num_convs,
|
73 |
+
stride=stride,
|
74 |
+
dilation=dilation,
|
75 |
+
with_cp=with_cp,
|
76 |
+
conv_cfg=conv_cfg,
|
77 |
+
norm_cfg=norm_cfg,
|
78 |
+
act_cfg=act_cfg,
|
79 |
+
dcn=None,
|
80 |
+
plugins=None)
|
81 |
+
if upsample_cfg is not None:
|
82 |
+
self.upsample = build_upsample_layer(
|
83 |
+
cfg=upsample_cfg,
|
84 |
+
in_channels=in_channels,
|
85 |
+
out_channels=skip_channels,
|
86 |
+
with_cp=with_cp,
|
87 |
+
norm_cfg=norm_cfg,
|
88 |
+
act_cfg=act_cfg)
|
89 |
+
else:
|
90 |
+
self.upsample = ConvModule(
|
91 |
+
in_channels,
|
92 |
+
skip_channels,
|
93 |
+
kernel_size=1,
|
94 |
+
stride=1,
|
95 |
+
padding=0,
|
96 |
+
conv_cfg=conv_cfg,
|
97 |
+
norm_cfg=norm_cfg,
|
98 |
+
act_cfg=act_cfg)
|
99 |
+
|
100 |
+
def forward(self, skip, x):
|
101 |
+
"""Forward function."""
|
102 |
+
|
103 |
+
x = self.upsample(x)
|
104 |
+
out = torch.cat([skip, x], dim=1)
|
105 |
+
out = self.conv_block(out)
|
106 |
+
|
107 |
+
return out
|
108 |
+
|
109 |
+
|
110 |
+
class BasicConvBlock(nn.Module):
|
111 |
+
"""Basic convolutional block for UNet.
|
112 |
+
|
113 |
+
This module consists of several plain convolutional layers.
|
114 |
+
|
115 |
+
Args:
|
116 |
+
in_channels (int): Number of input channels.
|
117 |
+
out_channels (int): Number of output channels.
|
118 |
+
num_convs (int): Number of convolutional layers. Default: 2.
|
119 |
+
stride (int): Whether use stride convolution to downsample
|
120 |
+
the input feature map. If stride=2, it only uses stride convolution
|
121 |
+
in the first convolutional layer to downsample the input feature
|
122 |
+
map. Options are 1 or 2. Default: 1.
|
123 |
+
dilation (int): Whether use dilated convolution to expand the
|
124 |
+
receptive field. Set dilation rate of each convolutional layer and
|
125 |
+
the dilation rate of the first convolutional layer is always 1.
|
126 |
+
Default: 1.
|
127 |
+
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
|
128 |
+
memory while slowing down the training speed. Default: False.
|
129 |
+
conv_cfg (dict | None): Config dict for convolution layer.
|
130 |
+
Default: None.
|
131 |
+
norm_cfg (dict | None): Config dict for normalization layer.
|
132 |
+
Default: dict(type='BN').
|
133 |
+
act_cfg (dict | None): Config dict for activation layer in ConvModule.
|
134 |
+
Default: dict(type='ReLU').
|
135 |
+
dcn (bool): Use deformable convoluton in convolutional layer or not.
|
136 |
+
Default: None.
|
137 |
+
plugins (dict): plugins for convolutional layers. Default: None.
|
138 |
+
"""
|
139 |
+
|
140 |
+
def __init__(self,
|
141 |
+
in_channels,
|
142 |
+
out_channels,
|
143 |
+
num_convs=2,
|
144 |
+
stride=1,
|
145 |
+
dilation=1,
|
146 |
+
with_cp=False,
|
147 |
+
conv_cfg=None,
|
148 |
+
norm_cfg=dict(type='BN'),
|
149 |
+
act_cfg=dict(type='ReLU'),
|
150 |
+
dcn=None,
|
151 |
+
plugins=None):
|
152 |
+
super(BasicConvBlock, self).__init__()
|
153 |
+
assert dcn is None, 'Not implemented yet.'
|
154 |
+
assert plugins is None, 'Not implemented yet.'
|
155 |
+
|
156 |
+
self.with_cp = with_cp
|
157 |
+
convs = []
|
158 |
+
for i in range(num_convs):
|
159 |
+
convs.append(
|
160 |
+
ConvModule(
|
161 |
+
in_channels=in_channels if i == 0 else out_channels,
|
162 |
+
out_channels=out_channels,
|
163 |
+
kernel_size=3,
|
164 |
+
stride=stride if i == 0 else 1,
|
165 |
+
dilation=1 if i == 0 else dilation,
|
166 |
+
padding=1 if i == 0 else dilation,
|
167 |
+
conv_cfg=conv_cfg,
|
168 |
+
norm_cfg=norm_cfg,
|
169 |
+
act_cfg=act_cfg))
|
170 |
+
|
171 |
+
self.convs = nn.Sequential(*convs)
|
172 |
+
|
173 |
+
def forward(self, x):
|
174 |
+
"""Forward function."""
|
175 |
+
|
176 |
+
if self.with_cp and x.requires_grad:
|
177 |
+
out = cp.checkpoint(self.convs, x)
|
178 |
+
else:
|
179 |
+
out = self.convs(x)
|
180 |
+
return out
|
181 |
+
|
182 |
+
|
183 |
+
class DeconvModule(nn.Module):
|
184 |
+
"""Deconvolution upsample module in decoder for UNet (2X upsample).
|
185 |
+
|
186 |
+
This module uses deconvolution to upsample feature map in the decoder
|
187 |
+
of UNet.
|
188 |
+
|
189 |
+
Args:
|
190 |
+
in_channels (int): Number of input channels.
|
191 |
+
out_channels (int): Number of output channels.
|
192 |
+
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
|
193 |
+
memory while slowing down the training speed. Default: False.
|
194 |
+
norm_cfg (dict | None): Config dict for normalization layer.
|
195 |
+
Default: dict(type='BN').
|
196 |
+
act_cfg (dict | None): Config dict for activation layer in ConvModule.
|
197 |
+
Default: dict(type='ReLU').
|
198 |
+
kernel_size (int): Kernel size of the convolutional layer. Default: 4.
|
199 |
+
"""
|
200 |
+
|
201 |
+
def __init__(self,
|
202 |
+
in_channels,
|
203 |
+
out_channels,
|
204 |
+
with_cp=False,
|
205 |
+
norm_cfg=dict(type='BN'),
|
206 |
+
act_cfg=dict(type='ReLU'),
|
207 |
+
*,
|
208 |
+
kernel_size=4,
|
209 |
+
scale_factor=2):
|
210 |
+
super(DeconvModule, self).__init__()
|
211 |
+
|
212 |
+
assert (kernel_size - scale_factor >= 0) and\
|
213 |
+
(kernel_size - scale_factor) % 2 == 0,\
|
214 |
+
f'kernel_size should be greater than or equal to scale_factor '\
|
215 |
+
f'and (kernel_size - scale_factor) should be even numbers, '\
|
216 |
+
f'while the kernel size is {kernel_size} and scale_factor is '\
|
217 |
+
f'{scale_factor}.'
|
218 |
+
|
219 |
+
stride = scale_factor
|
220 |
+
padding = (kernel_size - scale_factor) // 2
|
221 |
+
self.with_cp = with_cp
|
222 |
+
deconv = nn.ConvTranspose2d(
|
223 |
+
in_channels,
|
224 |
+
out_channels,
|
225 |
+
kernel_size=kernel_size,
|
226 |
+
stride=stride,
|
227 |
+
padding=padding)
|
228 |
+
|
229 |
+
norm_name, norm = build_norm_layer(norm_cfg, out_channels)
|
230 |
+
activate = build_activation_layer(act_cfg)
|
231 |
+
self.deconv_upsamping = nn.Sequential(deconv, norm, activate)
|
232 |
+
|
233 |
+
def forward(self, x):
|
234 |
+
"""Forward function."""
|
235 |
+
|
236 |
+
if self.with_cp and x.requires_grad:
|
237 |
+
out = cp.checkpoint(self.deconv_upsamping, x)
|
238 |
+
else:
|
239 |
+
out = self.deconv_upsamping(x)
|
240 |
+
return out
|
241 |
+
|
242 |
+
|
243 |
+
@UPSAMPLE_LAYERS.register_module()
|
244 |
+
class InterpConv(nn.Module):
|
245 |
+
"""Interpolation upsample module in decoder for UNet.
|
246 |
+
|
247 |
+
This module uses interpolation to upsample feature map in the decoder
|
248 |
+
of UNet. It consists of one interpolation upsample layer and one
|
249 |
+
convolutional layer. It can be one interpolation upsample layer followed
|
250 |
+
by one convolutional layer (conv_first=False) or one convolutional layer
|
251 |
+
followed by one interpolation upsample layer (conv_first=True).
|
252 |
+
|
253 |
+
Args:
|
254 |
+
in_channels (int): Number of input channels.
|
255 |
+
out_channels (int): Number of output channels.
|
256 |
+
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
|
257 |
+
memory while slowing down the training speed. Default: False.
|
258 |
+
norm_cfg (dict | None): Config dict for normalization layer.
|
259 |
+
Default: dict(type='BN').
|
260 |
+
act_cfg (dict | None): Config dict for activation layer in ConvModule.
|
261 |
+
Default: dict(type='ReLU').
|
262 |
+
conv_cfg (dict | None): Config dict for convolution layer.
|
263 |
+
Default: None.
|
264 |
+
conv_first (bool): Whether convolutional layer or interpolation
|
265 |
+
upsample layer first. Default: False. It means interpolation
|
266 |
+
upsample layer followed by one convolutional layer.
|
267 |
+
kernel_size (int): Kernel size of the convolutional layer. Default: 1.
|
268 |
+
stride (int): Stride of the convolutional layer. Default: 1.
|
269 |
+
padding (int): Padding of the convolutional layer. Default: 1.
|
270 |
+
upsampe_cfg (dict): Interpolation config of the upsample layer.
|
271 |
+
Default: dict(
|
272 |
+
scale_factor=2, mode='bilinear', align_corners=False).
|
273 |
+
"""
|
274 |
+
|
275 |
+
def __init__(self,
|
276 |
+
in_channels,
|
277 |
+
out_channels,
|
278 |
+
with_cp=False,
|
279 |
+
norm_cfg=dict(type='BN'),
|
280 |
+
act_cfg=dict(type='ReLU'),
|
281 |
+
*,
|
282 |
+
conv_cfg=None,
|
283 |
+
conv_first=False,
|
284 |
+
kernel_size=1,
|
285 |
+
stride=1,
|
286 |
+
padding=0,
|
287 |
+
upsampe_cfg=dict(
|
288 |
+
scale_factor=2, mode='bilinear', align_corners=False)):
|
289 |
+
super(InterpConv, self).__init__()
|
290 |
+
|
291 |
+
self.with_cp = with_cp
|
292 |
+
conv = ConvModule(
|
293 |
+
in_channels,
|
294 |
+
out_channels,
|
295 |
+
kernel_size=kernel_size,
|
296 |
+
stride=stride,
|
297 |
+
padding=padding,
|
298 |
+
conv_cfg=conv_cfg,
|
299 |
+
norm_cfg=norm_cfg,
|
300 |
+
act_cfg=act_cfg)
|
301 |
+
upsample = nn.Upsample(**upsampe_cfg)
|
302 |
+
if conv_first:
|
303 |
+
self.interp_upsample = nn.Sequential(conv, upsample)
|
304 |
+
else:
|
305 |
+
self.interp_upsample = nn.Sequential(upsample, conv)
|
306 |
+
|
307 |
+
def forward(self, x):
|
308 |
+
"""Forward function."""
|
309 |
+
|
310 |
+
if self.with_cp and x.requires_grad:
|
311 |
+
out = cp.checkpoint(self.interp_upsample, x)
|
312 |
+
else:
|
313 |
+
out = self.interp_upsample(x)
|
314 |
+
return out
|
315 |
+
|
316 |
+
|
317 |
+
class UNet(nn.Module):
|
318 |
+
"""UNet backbone.
|
319 |
+
U-Net: Convolutional Networks for Biomedical Image Segmentation.
|
320 |
+
https://arxiv.org/pdf/1505.04597.pdf
|
321 |
+
|
322 |
+
Args:
|
323 |
+
in_channels (int): Number of input image channels. Default" 3.
|
324 |
+
base_channels (int): Number of base channels of each stage.
|
325 |
+
The output channels of the first stage. Default: 64.
|
326 |
+
num_stages (int): Number of stages in encoder, normally 5. Default: 5.
|
327 |
+
strides (Sequence[int 1 | 2]): Strides of each stage in encoder.
|
328 |
+
len(strides) is equal to num_stages. Normally the stride of the
|
329 |
+
first stage in encoder is 1. If strides[i]=2, it uses stride
|
330 |
+
convolution to downsample in the correspondence encoder stage.
|
331 |
+
Default: (1, 1, 1, 1, 1).
|
332 |
+
enc_num_convs (Sequence[int]): Number of convolutional layers in the
|
333 |
+
convolution block of the correspondence encoder stage.
|
334 |
+
Default: (2, 2, 2, 2, 2).
|
335 |
+
dec_num_convs (Sequence[int]): Number of convolutional layers in the
|
336 |
+
convolution block of the correspondence decoder stage.
|
337 |
+
Default: (2, 2, 2, 2).
|
338 |
+
downsamples (Sequence[int]): Whether use MaxPool to downsample the
|
339 |
+
feature map after the first stage of encoder
|
340 |
+
(stages: [1, num_stages)). If the correspondence encoder stage use
|
341 |
+
stride convolution (strides[i]=2), it will never use MaxPool to
|
342 |
+
downsample, even downsamples[i-1]=True.
|
343 |
+
Default: (True, True, True, True).
|
344 |
+
enc_dilations (Sequence[int]): Dilation rate of each stage in encoder.
|
345 |
+
Default: (1, 1, 1, 1, 1).
|
346 |
+
dec_dilations (Sequence[int]): Dilation rate of each stage in decoder.
|
347 |
+
Default: (1, 1, 1, 1).
|
348 |
+
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
|
349 |
+
memory while slowing down the training speed. Default: False.
|
350 |
+
conv_cfg (dict | None): Config dict for convolution layer.
|
351 |
+
Default: None.
|
352 |
+
norm_cfg (dict | None): Config dict for normalization layer.
|
353 |
+
Default: dict(type='BN').
|
354 |
+
act_cfg (dict | None): Config dict for activation layer in ConvModule.
|
355 |
+
Default: dict(type='ReLU').
|
356 |
+
upsample_cfg (dict): The upsample config of the upsample module in
|
357 |
+
decoder. Default: dict(type='InterpConv').
|
358 |
+
norm_eval (bool): Whether to set norm layers to eval mode, namely,
|
359 |
+
freeze running stats (mean and var). Note: Effect on Batch Norm
|
360 |
+
and its variants only. Default: False.
|
361 |
+
dcn (bool): Use deformable convolution in convolutional layer or not.
|
362 |
+
Default: None.
|
363 |
+
plugins (dict): plugins for convolutional layers. Default: None.
|
364 |
+
|
365 |
+
Notice:
|
366 |
+
The input image size should be devisible by the whole downsample rate
|
367 |
+
of the encoder. More detail of the whole downsample rate can be found
|
368 |
+
in UNet._check_input_devisible.
|
369 |
+
|
370 |
+
"""
|
371 |
+
|
372 |
+
def __init__(self,
|
373 |
+
in_channels=3,
|
374 |
+
base_channels=64,
|
375 |
+
num_stages=5,
|
376 |
+
strides=(1, 1, 1, 1, 1),
|
377 |
+
enc_num_convs=(2, 2, 2, 2, 2),
|
378 |
+
dec_num_convs=(2, 2, 2, 2),
|
379 |
+
downsamples=(True, True, True, True),
|
380 |
+
enc_dilations=(1, 1, 1, 1, 1),
|
381 |
+
dec_dilations=(1, 1, 1, 1),
|
382 |
+
with_cp=False,
|
383 |
+
conv_cfg=None,
|
384 |
+
norm_cfg=dict(type='BN'),
|
385 |
+
act_cfg=dict(type='ReLU'),
|
386 |
+
upsample_cfg=dict(type='InterpConv'),
|
387 |
+
norm_eval=False,
|
388 |
+
dcn=None,
|
389 |
+
plugins=None):
|
390 |
+
super(UNet, self).__init__()
|
391 |
+
assert dcn is None, 'Not implemented yet.'
|
392 |
+
assert plugins is None, 'Not implemented yet.'
|
393 |
+
assert len(strides) == num_stages, \
|
394 |
+
'The length of strides should be equal to num_stages, '\
|
395 |
+
f'while the strides is {strides}, the length of '\
|
396 |
+
f'strides is {len(strides)}, and the num_stages is '\
|
397 |
+
f'{num_stages}.'
|
398 |
+
assert len(enc_num_convs) == num_stages, \
|
399 |
+
'The length of enc_num_convs should be equal to num_stages, '\
|
400 |
+
f'while the enc_num_convs is {enc_num_convs}, the length of '\
|
401 |
+
f'enc_num_convs is {len(enc_num_convs)}, and the num_stages is '\
|
402 |
+
f'{num_stages}.'
|
403 |
+
assert len(dec_num_convs) == (num_stages-1), \
|
404 |
+
'The length of dec_num_convs should be equal to (num_stages-1), '\
|
405 |
+
f'while the dec_num_convs is {dec_num_convs}, the length of '\
|
406 |
+
f'dec_num_convs is {len(dec_num_convs)}, and the num_stages is '\
|
407 |
+
f'{num_stages}.'
|
408 |
+
assert len(downsamples) == (num_stages-1), \
|
409 |
+
'The length of downsamples should be equal to (num_stages-1), '\
|
410 |
+
f'while the downsamples is {downsamples}, the length of '\
|
411 |
+
f'downsamples is {len(downsamples)}, and the num_stages is '\
|
412 |
+
f'{num_stages}.'
|
413 |
+
assert len(enc_dilations) == num_stages, \
|
414 |
+
'The length of enc_dilations should be equal to num_stages, '\
|
415 |
+
f'while the enc_dilations is {enc_dilations}, the length of '\
|
416 |
+
f'enc_dilations is {len(enc_dilations)}, and the num_stages is '\
|
417 |
+
f'{num_stages}.'
|
418 |
+
assert len(dec_dilations) == (num_stages-1), \
|
419 |
+
'The length of dec_dilations should be equal to (num_stages-1), '\
|
420 |
+
f'while the dec_dilations is {dec_dilations}, the length of '\
|
421 |
+
f'dec_dilations is {len(dec_dilations)}, and the num_stages is '\
|
422 |
+
f'{num_stages}.'
|
423 |
+
self.num_stages = num_stages
|
424 |
+
self.strides = strides
|
425 |
+
self.downsamples = downsamples
|
426 |
+
self.norm_eval = norm_eval
|
427 |
+
|
428 |
+
self.encoder = nn.ModuleList()
|
429 |
+
self.decoder = nn.ModuleList()
|
430 |
+
|
431 |
+
for i in range(num_stages):
|
432 |
+
enc_conv_block = []
|
433 |
+
if i != 0:
|
434 |
+
if strides[i] == 1 and downsamples[i - 1]:
|
435 |
+
enc_conv_block.append(nn.MaxPool2d(kernel_size=2))
|
436 |
+
upsample = (strides[i] != 1 or downsamples[i - 1])
|
437 |
+
self.decoder.append(
|
438 |
+
UpConvBlock(
|
439 |
+
conv_block=BasicConvBlock,
|
440 |
+
in_channels=base_channels * 2**i,
|
441 |
+
skip_channels=base_channels * 2**(i - 1),
|
442 |
+
out_channels=base_channels * 2**(i - 1),
|
443 |
+
num_convs=dec_num_convs[i - 1],
|
444 |
+
stride=1,
|
445 |
+
dilation=dec_dilations[i - 1],
|
446 |
+
with_cp=with_cp,
|
447 |
+
conv_cfg=conv_cfg,
|
448 |
+
norm_cfg=norm_cfg,
|
449 |
+
act_cfg=act_cfg,
|
450 |
+
upsample_cfg=upsample_cfg if upsample else None,
|
451 |
+
dcn=None,
|
452 |
+
plugins=None))
|
453 |
+
|
454 |
+
enc_conv_block.append(
|
455 |
+
BasicConvBlock(
|
456 |
+
in_channels=in_channels,
|
457 |
+
out_channels=base_channels * 2**i,
|
458 |
+
num_convs=enc_num_convs[i],
|
459 |
+
stride=strides[i],
|
460 |
+
dilation=enc_dilations[i],
|
461 |
+
with_cp=with_cp,
|
462 |
+
conv_cfg=conv_cfg,
|
463 |
+
norm_cfg=norm_cfg,
|
464 |
+
act_cfg=act_cfg,
|
465 |
+
dcn=None,
|
466 |
+
plugins=None))
|
467 |
+
self.encoder.append((nn.Sequential(*enc_conv_block)))
|
468 |
+
in_channels = base_channels * 2**i
|
469 |
+
|
470 |
+
def forward(self, x):
|
471 |
+
enc_outs = []
|
472 |
+
|
473 |
+
for enc in self.encoder:
|
474 |
+
x = enc(x)
|
475 |
+
enc_outs.append(x)
|
476 |
+
dec_outs = [x]
|
477 |
+
for i in reversed(range(len(self.decoder))):
|
478 |
+
x = self.decoder[i](enc_outs[i], x)
|
479 |
+
dec_outs.append(x)
|
480 |
+
|
481 |
+
return dec_outs
|
482 |
+
|
483 |
+
def init_weights(self, pretrained=None):
|
484 |
+
"""Initialize the weights in backbone.
|
485 |
+
|
486 |
+
Args:
|
487 |
+
pretrained (str, optional): Path to pre-trained weights.
|
488 |
+
Defaults to None.
|
489 |
+
"""
|
490 |
+
if isinstance(pretrained, str):
|
491 |
+
logger = get_root_logger()
|
492 |
+
load_checkpoint(self, pretrained, strict=False, logger=logger)
|
493 |
+
elif pretrained is None:
|
494 |
+
for m in self.modules():
|
495 |
+
if isinstance(m, nn.Conv2d):
|
496 |
+
kaiming_init(m)
|
497 |
+
elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
|
498 |
+
constant_init(m, 1)
|
499 |
+
else:
|
500 |
+
raise TypeError('pretrained must be a str or None')
|
501 |
+
|
502 |
+
|
503 |
+
class ShapeUNet(nn.Module):
|
504 |
+
"""ShapeUNet backbone with small modifications.
|
505 |
+
U-Net: Convolutional Networks for Biomedical Image Segmentation.
|
506 |
+
https://arxiv.org/pdf/1505.04597.pdf
|
507 |
+
|
508 |
+
Args:
|
509 |
+
in_channels (int): Number of input image channels. Default" 3.
|
510 |
+
base_channels (int): Number of base channels of each stage.
|
511 |
+
The output channels of the first stage. Default: 64.
|
512 |
+
num_stages (int): Number of stages in encoder, normally 5. Default: 5.
|
513 |
+
strides (Sequence[int 1 | 2]): Strides of each stage in encoder.
|
514 |
+
len(strides) is equal to num_stages. Normally the stride of the
|
515 |
+
first stage in encoder is 1. If strides[i]=2, it uses stride
|
516 |
+
convolution to downsample in the correspondance encoder stage.
|
517 |
+
Default: (1, 1, 1, 1, 1).
|
518 |
+
enc_num_convs (Sequence[int]): Number of convolutional layers in the
|
519 |
+
convolution block of the correspondance encoder stage.
|
520 |
+
Default: (2, 2, 2, 2, 2).
|
521 |
+
dec_num_convs (Sequence[int]): Number of convolutional layers in the
|
522 |
+
convolution block of the correspondance decoder stage.
|
523 |
+
Default: (2, 2, 2, 2).
|
524 |
+
downsamples (Sequence[int]): Whether use MaxPool to downsample the
|
525 |
+
feature map after the first stage of encoder
|
526 |
+
(stages: [1, num_stages)). If the correspondance encoder stage use
|
527 |
+
stride convolution (strides[i]=2), it will never use MaxPool to
|
528 |
+
downsample, even downsamples[i-1]=True.
|
529 |
+
Default: (True, True, True, True).
|
530 |
+
enc_dilations (Sequence[int]): Dilation rate of each stage in encoder.
|
531 |
+
Default: (1, 1, 1, 1, 1).
|
532 |
+
dec_dilations (Sequence[int]): Dilation rate of each stage in decoder.
|
533 |
+
Default: (1, 1, 1, 1).
|
534 |
+
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
|
535 |
+
memory while slowing down the training speed. Default: False.
|
536 |
+
conv_cfg (dict | None): Config dict for convolution layer.
|
537 |
+
Default: None.
|
538 |
+
norm_cfg (dict | None): Config dict for normalization layer.
|
539 |
+
Default: dict(type='BN').
|
540 |
+
act_cfg (dict | None): Config dict for activation layer in ConvModule.
|
541 |
+
Default: dict(type='ReLU').
|
542 |
+
upsample_cfg (dict): The upsample config of the upsample module in
|
543 |
+
decoder. Default: dict(type='InterpConv').
|
544 |
+
norm_eval (bool): Whether to set norm layers to eval mode, namely,
|
545 |
+
freeze running stats (mean and var). Note: Effect on Batch Norm
|
546 |
+
and its variants only. Default: False.
|
547 |
+
dcn (bool): Use deformable convoluton in convolutional layer or not.
|
548 |
+
Default: None.
|
549 |
+
plugins (dict): plugins for convolutional layers. Default: None.
|
550 |
+
|
551 |
+
Notice:
|
552 |
+
The input image size should be devisible by the whole downsample rate
|
553 |
+
of the encoder. More detail of the whole downsample rate can be found
|
554 |
+
in UNet._check_input_devisible.
|
555 |
+
|
556 |
+
"""
|
557 |
+
|
558 |
+
def __init__(self,
|
559 |
+
in_channels=3,
|
560 |
+
base_channels=64,
|
561 |
+
num_stages=5,
|
562 |
+
attr_embedding=128,
|
563 |
+
strides=(1, 1, 1, 1, 1),
|
564 |
+
enc_num_convs=(2, 2, 2, 2, 2),
|
565 |
+
dec_num_convs=(2, 2, 2, 2),
|
566 |
+
downsamples=(True, True, True, True),
|
567 |
+
enc_dilations=(1, 1, 1, 1, 1),
|
568 |
+
dec_dilations=(1, 1, 1, 1),
|
569 |
+
with_cp=False,
|
570 |
+
conv_cfg=None,
|
571 |
+
norm_cfg=dict(type='BN'),
|
572 |
+
act_cfg=dict(type='ReLU'),
|
573 |
+
upsample_cfg=dict(type='InterpConv'),
|
574 |
+
norm_eval=False,
|
575 |
+
dcn=None,
|
576 |
+
plugins=None):
|
577 |
+
super(ShapeUNet, self).__init__()
|
578 |
+
assert dcn is None, 'Not implemented yet.'
|
579 |
+
assert plugins is None, 'Not implemented yet.'
|
580 |
+
assert len(strides) == num_stages, \
|
581 |
+
'The length of strides should be equal to num_stages, '\
|
582 |
+
f'while the strides is {strides}, the length of '\
|
583 |
+
f'strides is {len(strides)}, and the num_stages is '\
|
584 |
+
f'{num_stages}.'
|
585 |
+
assert len(enc_num_convs) == num_stages, \
|
586 |
+
'The length of enc_num_convs should be equal to num_stages, '\
|
587 |
+
f'while the enc_num_convs is {enc_num_convs}, the length of '\
|
588 |
+
f'enc_num_convs is {len(enc_num_convs)}, and the num_stages is '\
|
589 |
+
f'{num_stages}.'
|
590 |
+
assert len(dec_num_convs) == (num_stages-1), \
|
591 |
+
'The length of dec_num_convs should be equal to (num_stages-1), '\
|
592 |
+
f'while the dec_num_convs is {dec_num_convs}, the length of '\
|
593 |
+
f'dec_num_convs is {len(dec_num_convs)}, and the num_stages is '\
|
594 |
+
f'{num_stages}.'
|
595 |
+
assert len(downsamples) == (num_stages-1), \
|
596 |
+
'The length of downsamples should be equal to (num_stages-1), '\
|
597 |
+
f'while the downsamples is {downsamples}, the length of '\
|
598 |
+
f'downsamples is {len(downsamples)}, and the num_stages is '\
|
599 |
+
f'{num_stages}.'
|
600 |
+
assert len(enc_dilations) == num_stages, \
|
601 |
+
'The length of enc_dilations should be equal to num_stages, '\
|
602 |
+
f'while the enc_dilations is {enc_dilations}, the length of '\
|
603 |
+
f'enc_dilations is {len(enc_dilations)}, and the num_stages is '\
|
604 |
+
f'{num_stages}.'
|
605 |
+
assert len(dec_dilations) == (num_stages-1), \
|
606 |
+
'The length of dec_dilations should be equal to (num_stages-1), '\
|
607 |
+
f'while the dec_dilations is {dec_dilations}, the length of '\
|
608 |
+
f'dec_dilations is {len(dec_dilations)}, and the num_stages is '\
|
609 |
+
f'{num_stages}.'
|
610 |
+
self.num_stages = num_stages
|
611 |
+
self.strides = strides
|
612 |
+
self.downsamples = downsamples
|
613 |
+
self.norm_eval = norm_eval
|
614 |
+
|
615 |
+
self.encoder = nn.ModuleList()
|
616 |
+
self.decoder = nn.ModuleList()
|
617 |
+
|
618 |
+
for i in range(num_stages):
|
619 |
+
enc_conv_block = []
|
620 |
+
if i != 0:
|
621 |
+
if strides[i] == 1 and downsamples[i - 1]:
|
622 |
+
enc_conv_block.append(nn.MaxPool2d(kernel_size=2))
|
623 |
+
upsample = (strides[i] != 1 or downsamples[i - 1])
|
624 |
+
self.decoder.append(
|
625 |
+
UpConvBlock(
|
626 |
+
conv_block=BasicConvBlock,
|
627 |
+
in_channels=base_channels * 2**i,
|
628 |
+
skip_channels=base_channels * 2**(i - 1),
|
629 |
+
out_channels=base_channels * 2**(i - 1),
|
630 |
+
num_convs=dec_num_convs[i - 1],
|
631 |
+
stride=1,
|
632 |
+
dilation=dec_dilations[i - 1],
|
633 |
+
with_cp=with_cp,
|
634 |
+
conv_cfg=conv_cfg,
|
635 |
+
norm_cfg=norm_cfg,
|
636 |
+
act_cfg=act_cfg,
|
637 |
+
upsample_cfg=upsample_cfg if upsample else None,
|
638 |
+
dcn=None,
|
639 |
+
plugins=None))
|
640 |
+
|
641 |
+
enc_conv_block.append(
|
642 |
+
BasicConvBlock(
|
643 |
+
in_channels=in_channels + attr_embedding,
|
644 |
+
out_channels=base_channels * 2**i,
|
645 |
+
num_convs=enc_num_convs[i],
|
646 |
+
stride=strides[i],
|
647 |
+
dilation=enc_dilations[i],
|
648 |
+
with_cp=with_cp,
|
649 |
+
conv_cfg=conv_cfg,
|
650 |
+
norm_cfg=norm_cfg,
|
651 |
+
act_cfg=act_cfg,
|
652 |
+
dcn=None,
|
653 |
+
plugins=None))
|
654 |
+
self.encoder.append((nn.Sequential(*enc_conv_block)))
|
655 |
+
in_channels = base_channels * 2**i
|
656 |
+
|
657 |
+
def forward(self, x, attr_embedding):
|
658 |
+
enc_outs = []
|
659 |
+
Be, Ce = attr_embedding.size()
|
660 |
+
for enc in self.encoder:
|
661 |
+
_, _, H, W = x.size()
|
662 |
+
x = enc(
|
663 |
+
torch.cat([
|
664 |
+
x,
|
665 |
+
attr_embedding.view(Be, Ce, 1, 1).expand((Be, Ce, H, W))
|
666 |
+
],
|
667 |
+
dim=1))
|
668 |
+
enc_outs.append(x)
|
669 |
+
dec_outs = [x]
|
670 |
+
for i in reversed(range(len(self.decoder))):
|
671 |
+
x = self.decoder[i](enc_outs[i], x)
|
672 |
+
dec_outs.append(x)
|
673 |
+
|
674 |
+
return dec_outs
|
675 |
+
|
676 |
+
def init_weights(self, pretrained=None):
|
677 |
+
"""Initialize the weights in backbone.
|
678 |
+
|
679 |
+
Args:
|
680 |
+
pretrained (str, optional): Path to pre-trained weights.
|
681 |
+
Defaults to None.
|
682 |
+
"""
|
683 |
+
if isinstance(pretrained, str):
|
684 |
+
logger = get_root_logger()
|
685 |
+
load_checkpoint(self, pretrained, strict=False, logger=logger)
|
686 |
+
elif pretrained is None:
|
687 |
+
for m in self.modules():
|
688 |
+
if isinstance(m, nn.Conv2d):
|
689 |
+
kaiming_init(m)
|
690 |
+
elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
|
691 |
+
constant_init(m, 1)
|
692 |
+
else:
|
693 |
+
raise TypeError('pretrained must be a str or None')
|
Text2Human/models/archs/vqgan_arch.py
ADDED
@@ -0,0 +1,1203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# pytorch_diffusion + derived encoder decoder
|
2 |
+
import math
|
3 |
+
from urllib.request import proxy_bypass
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
import torch.nn.functional as F
|
9 |
+
from einops import rearrange
|
10 |
+
|
11 |
+
|
12 |
+
class VectorQuantizer(nn.Module):
|
13 |
+
"""
|
14 |
+
Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly
|
15 |
+
avoids costly matrix multiplications and allows for post-hoc remapping of indices.
|
16 |
+
"""
|
17 |
+
|
18 |
+
# NOTE: due to a bug the beta term was applied to the wrong term. for
|
19 |
+
# backwards compatibility we use the buggy version by default, but you can
|
20 |
+
# specify legacy=False to fix it.
|
21 |
+
def __init__(self,
|
22 |
+
n_e,
|
23 |
+
e_dim,
|
24 |
+
beta,
|
25 |
+
remap=None,
|
26 |
+
unknown_index="random",
|
27 |
+
sane_index_shape=False,
|
28 |
+
legacy=True):
|
29 |
+
super().__init__()
|
30 |
+
self.n_e = n_e
|
31 |
+
self.e_dim = e_dim
|
32 |
+
self.beta = beta
|
33 |
+
self.legacy = legacy
|
34 |
+
|
35 |
+
self.embedding = nn.Embedding(self.n_e, self.e_dim)
|
36 |
+
self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
|
37 |
+
|
38 |
+
self.remap = remap
|
39 |
+
if self.remap is not None:
|
40 |
+
self.register_buffer("used", torch.tensor(np.load(self.remap)))
|
41 |
+
self.re_embed = self.used.shape[0]
|
42 |
+
self.unknown_index = unknown_index # "random" or "extra" or integer
|
43 |
+
if self.unknown_index == "extra":
|
44 |
+
self.unknown_index = self.re_embed
|
45 |
+
self.re_embed = self.re_embed + 1
|
46 |
+
print(f"Remapping {self.n_e} indices to {self.re_embed} indices. "
|
47 |
+
f"Using {self.unknown_index} for unknown indices.")
|
48 |
+
else:
|
49 |
+
self.re_embed = n_e
|
50 |
+
|
51 |
+
self.sane_index_shape = sane_index_shape
|
52 |
+
|
53 |
+
def remap_to_used(self, inds):
|
54 |
+
ishape = inds.shape
|
55 |
+
assert len(ishape) > 1
|
56 |
+
inds = inds.reshape(ishape[0], -1)
|
57 |
+
used = self.used.to(inds)
|
58 |
+
match = (inds[:, :, None] == used[None, None, ...]).long()
|
59 |
+
new = match.argmax(-1)
|
60 |
+
unknown = match.sum(2) < 1
|
61 |
+
if self.unknown_index == "random":
|
62 |
+
new[unknown] = torch.randint(
|
63 |
+
0, self.re_embed,
|
64 |
+
size=new[unknown].shape).to(device=new.device)
|
65 |
+
else:
|
66 |
+
new[unknown] = self.unknown_index
|
67 |
+
return new.reshape(ishape)
|
68 |
+
|
69 |
+
def unmap_to_all(self, inds):
|
70 |
+
ishape = inds.shape
|
71 |
+
assert len(ishape) > 1
|
72 |
+
inds = inds.reshape(ishape[0], -1)
|
73 |
+
used = self.used.to(inds)
|
74 |
+
if self.re_embed > self.used.shape[0]: # extra token
|
75 |
+
inds[inds >= self.used.shape[0]] = 0 # simply set to zero
|
76 |
+
back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
|
77 |
+
return back.reshape(ishape)
|
78 |
+
|
79 |
+
def forward(self, z, temp=None, rescale_logits=False, return_logits=False):
|
80 |
+
assert temp is None or temp == 1.0, "Only for interface compatible with Gumbel"
|
81 |
+
assert rescale_logits == False, "Only for interface compatible with Gumbel"
|
82 |
+
assert return_logits == False, "Only for interface compatible with Gumbel"
|
83 |
+
# reshape z -> (batch, height, width, channel) and flatten
|
84 |
+
z = rearrange(z, 'b c h w -> b h w c').contiguous()
|
85 |
+
z_flattened = z.view(-1, self.e_dim)
|
86 |
+
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
|
87 |
+
|
88 |
+
d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
|
89 |
+
torch.sum(self.embedding.weight**2, dim=1) - 2 * \
|
90 |
+
torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n'))
|
91 |
+
|
92 |
+
min_encoding_indices = torch.argmin(d, dim=1)
|
93 |
+
z_q = self.embedding(min_encoding_indices).view(z.shape)
|
94 |
+
perplexity = None
|
95 |
+
min_encodings = None
|
96 |
+
|
97 |
+
# compute loss for embedding
|
98 |
+
if not self.legacy:
|
99 |
+
loss = self.beta * torch.mean((z_q.detach()-z)**2) + \
|
100 |
+
torch.mean((z_q - z.detach()) ** 2)
|
101 |
+
else:
|
102 |
+
loss = torch.mean((z_q.detach()-z)**2) + self.beta * \
|
103 |
+
torch.mean((z_q - z.detach()) ** 2)
|
104 |
+
|
105 |
+
# preserve gradients
|
106 |
+
z_q = z + (z_q - z).detach()
|
107 |
+
|
108 |
+
# reshape back to match original input shape
|
109 |
+
z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous()
|
110 |
+
|
111 |
+
if self.remap is not None:
|
112 |
+
min_encoding_indices = min_encoding_indices.reshape(
|
113 |
+
z.shape[0], -1) # add batch axis
|
114 |
+
min_encoding_indices = self.remap_to_used(min_encoding_indices)
|
115 |
+
min_encoding_indices = min_encoding_indices.reshape(-1,
|
116 |
+
1) # flatten
|
117 |
+
|
118 |
+
if self.sane_index_shape:
|
119 |
+
min_encoding_indices = min_encoding_indices.reshape(
|
120 |
+
z_q.shape[0], z_q.shape[2], z_q.shape[3])
|
121 |
+
|
122 |
+
return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
|
123 |
+
|
124 |
+
def get_codebook_entry(self, indices, shape):
|
125 |
+
# shape specifying (batch, height, width, channel)
|
126 |
+
if self.remap is not None:
|
127 |
+
indices = indices.reshape(shape[0], -1) # add batch axis
|
128 |
+
indices = self.unmap_to_all(indices)
|
129 |
+
indices = indices.reshape(-1) # flatten again
|
130 |
+
|
131 |
+
# get quantized latent vectors
|
132 |
+
z_q = self.embedding(indices)
|
133 |
+
|
134 |
+
if shape is not None:
|
135 |
+
z_q = z_q.view(shape)
|
136 |
+
# reshape back to match original input shape
|
137 |
+
z_q = z_q.permute(0, 3, 1, 2).contiguous()
|
138 |
+
|
139 |
+
return z_q
|
140 |
+
|
141 |
+
|
142 |
+
class VectorQuantizerTexture(nn.Module):
|
143 |
+
"""
|
144 |
+
Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly
|
145 |
+
avoids costly matrix multiplications and allows for post-hoc remapping of indices.
|
146 |
+
"""
|
147 |
+
|
148 |
+
# NOTE: due to a bug the beta term was applied to the wrong term. for
|
149 |
+
# backwards compatibility we use the buggy version by default, but you can
|
150 |
+
# specify legacy=False to fix it.
|
151 |
+
def __init__(self,
|
152 |
+
n_e,
|
153 |
+
e_dim,
|
154 |
+
beta,
|
155 |
+
remap=None,
|
156 |
+
unknown_index="random",
|
157 |
+
sane_index_shape=False,
|
158 |
+
legacy=True):
|
159 |
+
super().__init__()
|
160 |
+
self.n_e = n_e
|
161 |
+
self.e_dim = e_dim
|
162 |
+
self.beta = beta
|
163 |
+
self.legacy = legacy
|
164 |
+
|
165 |
+
# TODO: decide number of embeddings
|
166 |
+
self.embedding_list = nn.ModuleList(
|
167 |
+
[nn.Embedding(self.n_e, self.e_dim) for i in range(18)])
|
168 |
+
for embedding in self.embedding_list:
|
169 |
+
embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
|
170 |
+
|
171 |
+
self.remap = remap
|
172 |
+
if self.remap is not None:
|
173 |
+
self.register_buffer("used", torch.tensor(np.load(self.remap)))
|
174 |
+
self.re_embed = self.used.shape[0]
|
175 |
+
self.unknown_index = unknown_index # "random" or "extra" or integer
|
176 |
+
if self.unknown_index == "extra":
|
177 |
+
self.unknown_index = self.re_embed
|
178 |
+
self.re_embed = self.re_embed + 1
|
179 |
+
print(f"Remapping {self.n_e} indices to {self.re_embed} indices. "
|
180 |
+
f"Using {self.unknown_index} for unknown indices.")
|
181 |
+
else:
|
182 |
+
self.re_embed = n_e
|
183 |
+
|
184 |
+
self.sane_index_shape = sane_index_shape
|
185 |
+
|
186 |
+
def remap_to_used(self, inds):
|
187 |
+
ishape = inds.shape
|
188 |
+
assert len(ishape) > 1
|
189 |
+
inds = inds.reshape(ishape[0], -1)
|
190 |
+
used = self.used.to(inds)
|
191 |
+
match = (inds[:, :, None] == used[None, None, ...]).long()
|
192 |
+
new = match.argmax(-1)
|
193 |
+
unknown = match.sum(2) < 1
|
194 |
+
if self.unknown_index == "random":
|
195 |
+
new[unknown] = torch.randint(
|
196 |
+
0, self.re_embed,
|
197 |
+
size=new[unknown].shape).to(device=new.device)
|
198 |
+
else:
|
199 |
+
new[unknown] = self.unknown_index
|
200 |
+
return new.reshape(ishape)
|
201 |
+
|
202 |
+
def unmap_to_all(self, inds):
|
203 |
+
ishape = inds.shape
|
204 |
+
assert len(ishape) > 1
|
205 |
+
inds = inds.reshape(ishape[0], -1)
|
206 |
+
used = self.used.to(inds)
|
207 |
+
if self.re_embed > self.used.shape[0]: # extra token
|
208 |
+
inds[inds >= self.used.shape[0]] = 0 # simply set to zero
|
209 |
+
back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
|
210 |
+
return back.reshape(ishape)
|
211 |
+
|
212 |
+
def forward(self,
|
213 |
+
z,
|
214 |
+
segm_map,
|
215 |
+
temp=None,
|
216 |
+
rescale_logits=False,
|
217 |
+
return_logits=False):
|
218 |
+
assert temp is None or temp == 1.0, "Only for interface compatible with Gumbel"
|
219 |
+
assert rescale_logits == False, "Only for interface compatible with Gumbel"
|
220 |
+
assert return_logits == False, "Only for interface compatible with Gumbel"
|
221 |
+
|
222 |
+
segm_map = F.interpolate(segm_map, size=z.size()[2:], mode='nearest')
|
223 |
+
# reshape z -> (batch, height, width, channel) and flatten
|
224 |
+
z = rearrange(z, 'b c h w -> b h w c').contiguous()
|
225 |
+
z_flattened = z.view(-1, self.e_dim)
|
226 |
+
|
227 |
+
# flatten segm_map (b, h, w)
|
228 |
+
segm_map_flatten = segm_map.view(-1)
|
229 |
+
|
230 |
+
z_q = torch.zeros_like(z_flattened)
|
231 |
+
min_encoding_indices_list = []
|
232 |
+
min_encoding_indices_continual = torch.full(
|
233 |
+
segm_map_flatten.size(),
|
234 |
+
fill_value=-1,
|
235 |
+
dtype=torch.long,
|
236 |
+
device=segm_map_flatten.device)
|
237 |
+
for codebook_idx in range(18):
|
238 |
+
min_encoding_indices = torch.full(
|
239 |
+
segm_map_flatten.size(),
|
240 |
+
fill_value=-1,
|
241 |
+
dtype=torch.long,
|
242 |
+
device=segm_map_flatten.device)
|
243 |
+
if torch.sum(segm_map_flatten == codebook_idx) > 0:
|
244 |
+
z_selected = z_flattened[segm_map_flatten == codebook_idx]
|
245 |
+
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
|
246 |
+
d_selected = torch.sum(
|
247 |
+
z_selected**2, dim=1, keepdim=True) + torch.sum(
|
248 |
+
self.embedding_list[codebook_idx].weight**2,
|
249 |
+
dim=1) - 2 * torch.einsum(
|
250 |
+
'bd,dn->bn', z_selected,
|
251 |
+
rearrange(self.embedding_list[codebook_idx].weight,
|
252 |
+
'n d -> d n'))
|
253 |
+
min_encoding_indices_selected = torch.argmin(d_selected, dim=1)
|
254 |
+
z_q_selected = self.embedding_list[codebook_idx](
|
255 |
+
min_encoding_indices_selected)
|
256 |
+
z_q[segm_map_flatten == codebook_idx] = z_q_selected
|
257 |
+
min_encoding_indices[
|
258 |
+
segm_map_flatten ==
|
259 |
+
codebook_idx] = min_encoding_indices_selected
|
260 |
+
min_encoding_indices_continual[
|
261 |
+
segm_map_flatten ==
|
262 |
+
codebook_idx] = min_encoding_indices_selected + 1024 * codebook_idx
|
263 |
+
min_encoding_indices = min_encoding_indices.reshape(
|
264 |
+
z.shape[0], z.shape[1], z.shape[2])
|
265 |
+
min_encoding_indices_list.append(min_encoding_indices)
|
266 |
+
|
267 |
+
min_encoding_indices_continual = min_encoding_indices_continual.reshape(
|
268 |
+
z.shape[0], z.shape[1], z.shape[2])
|
269 |
+
z_q = z_q.view(z.shape)
|
270 |
+
perplexity = None
|
271 |
+
|
272 |
+
# compute loss for embedding
|
273 |
+
if not self.legacy:
|
274 |
+
loss = self.beta * torch.mean((z_q.detach()-z)**2) + \
|
275 |
+
torch.mean((z_q - z.detach()) ** 2)
|
276 |
+
else:
|
277 |
+
loss = torch.mean((z_q.detach()-z)**2) + self.beta * \
|
278 |
+
torch.mean((z_q - z.detach()) ** 2)
|
279 |
+
|
280 |
+
# preserve gradients
|
281 |
+
z_q = z + (z_q - z).detach()
|
282 |
+
|
283 |
+
# reshape back to match original input shape
|
284 |
+
z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous()
|
285 |
+
|
286 |
+
return z_q, loss, (perplexity, min_encoding_indices_continual,
|
287 |
+
min_encoding_indices_list)
|
288 |
+
|
289 |
+
def get_codebook_entry(self, indices_list, segm_map, shape):
|
290 |
+
# flatten segm_map (b, h, w)
|
291 |
+
segm_map = F.interpolate(
|
292 |
+
segm_map, size=(shape[1], shape[2]), mode='nearest')
|
293 |
+
segm_map_flatten = segm_map.view(-1)
|
294 |
+
|
295 |
+
z_q = torch.zeros((shape[0] * shape[1] * shape[2]),
|
296 |
+
self.e_dim).to(segm_map.device)
|
297 |
+
for codebook_idx in range(18):
|
298 |
+
if torch.sum(segm_map_flatten == codebook_idx) > 0:
|
299 |
+
min_encoding_indices_selected = indices_list[
|
300 |
+
codebook_idx].view(-1)[segm_map_flatten == codebook_idx]
|
301 |
+
z_q_selected = self.embedding_list[codebook_idx](
|
302 |
+
min_encoding_indices_selected)
|
303 |
+
z_q[segm_map_flatten == codebook_idx] = z_q_selected
|
304 |
+
|
305 |
+
z_q = z_q.view(shape)
|
306 |
+
# reshape back to match original input shape
|
307 |
+
z_q = z_q.permute(0, 3, 1, 2).contiguous()
|
308 |
+
|
309 |
+
return z_q
|
310 |
+
|
311 |
+
|
312 |
+
def sample_patches(inputs, patch_size=3, stride=1):
|
313 |
+
"""Extract sliding local patches from an input feature tensor.
|
314 |
+
The sampled pathes are row-major.
|
315 |
+
Args:
|
316 |
+
inputs (Tensor): the input feature maps, shape: (n, c, h, w).
|
317 |
+
patch_size (int): the spatial size of sampled patches. Default: 3.
|
318 |
+
stride (int): the stride of sampling. Default: 1.
|
319 |
+
Returns:
|
320 |
+
patches (Tensor): extracted patches, shape: (n, c * patch_size *
|
321 |
+
patch_size, n_patches).
|
322 |
+
"""
|
323 |
+
|
324 |
+
patches = F.unfold(inputs, (patch_size, patch_size), stride=stride)
|
325 |
+
|
326 |
+
return patches
|
327 |
+
|
328 |
+
|
329 |
+
class VectorQuantizerSpatialTextureAware(nn.Module):
|
330 |
+
"""
|
331 |
+
Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly
|
332 |
+
avoids costly matrix multiplications and allows for post-hoc remapping of indices.
|
333 |
+
"""
|
334 |
+
|
335 |
+
# NOTE: due to a bug the beta term was applied to the wrong term. for
|
336 |
+
# backwards compatibility we use the buggy version by default, but you can
|
337 |
+
# specify legacy=False to fix it.
|
338 |
+
def __init__(self,
|
339 |
+
n_e,
|
340 |
+
e_dim,
|
341 |
+
beta,
|
342 |
+
spatial_size,
|
343 |
+
remap=None,
|
344 |
+
unknown_index="random",
|
345 |
+
sane_index_shape=False,
|
346 |
+
legacy=True):
|
347 |
+
super().__init__()
|
348 |
+
self.n_e = n_e
|
349 |
+
self.e_dim = e_dim * spatial_size * spatial_size
|
350 |
+
self.beta = beta
|
351 |
+
self.legacy = legacy
|
352 |
+
self.spatial_size = spatial_size
|
353 |
+
|
354 |
+
# TODO: decide number of embeddings
|
355 |
+
self.embedding_list = nn.ModuleList(
|
356 |
+
[nn.Embedding(self.n_e, self.e_dim) for i in range(18)])
|
357 |
+
for embedding in self.embedding_list:
|
358 |
+
embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
|
359 |
+
|
360 |
+
self.remap = remap
|
361 |
+
if self.remap is not None:
|
362 |
+
self.register_buffer("used", torch.tensor(np.load(self.remap)))
|
363 |
+
self.re_embed = self.used.shape[0]
|
364 |
+
self.unknown_index = unknown_index # "random" or "extra" or integer
|
365 |
+
if self.unknown_index == "extra":
|
366 |
+
self.unknown_index = self.re_embed
|
367 |
+
self.re_embed = self.re_embed + 1
|
368 |
+
print(f"Remapping {self.n_e} indices to {self.re_embed} indices. "
|
369 |
+
f"Using {self.unknown_index} for unknown indices.")
|
370 |
+
else:
|
371 |
+
self.re_embed = n_e
|
372 |
+
|
373 |
+
self.sane_index_shape = sane_index_shape
|
374 |
+
|
375 |
+
def forward(self,
|
376 |
+
z,
|
377 |
+
segm_map,
|
378 |
+
temp=None,
|
379 |
+
rescale_logits=False,
|
380 |
+
return_logits=False):
|
381 |
+
assert temp is None or temp == 1.0, "Only for interface compatible with Gumbel"
|
382 |
+
assert rescale_logits == False, "Only for interface compatible with Gumbel"
|
383 |
+
assert return_logits == False, "Only for interface compatible with Gumbel"
|
384 |
+
|
385 |
+
segm_map = F.interpolate(
|
386 |
+
segm_map,
|
387 |
+
size=(z.size(2) // self.spatial_size,
|
388 |
+
z.size(3) // self.spatial_size),
|
389 |
+
mode='nearest')
|
390 |
+
|
391 |
+
# reshape z -> (batch, height, width, channel) and flatten
|
392 |
+
# z = rearrange(z, 'b c h w -> b h w c').contiguous() ?
|
393 |
+
z_patches = sample_patches(
|
394 |
+
z, patch_size=self.spatial_size,
|
395 |
+
stride=self.spatial_size).permute(0, 2, 1)
|
396 |
+
z_patches_flattened = z_patches.reshape(-1, self.e_dim)
|
397 |
+
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
|
398 |
+
|
399 |
+
# flatten segm_map (b, h, w)
|
400 |
+
segm_map_flatten = segm_map.view(-1)
|
401 |
+
|
402 |
+
z_q = torch.zeros_like(z_patches_flattened)
|
403 |
+
min_encoding_indices_list = []
|
404 |
+
min_encoding_indices_continual = torch.full(
|
405 |
+
segm_map_flatten.size(),
|
406 |
+
fill_value=-1,
|
407 |
+
dtype=torch.long,
|
408 |
+
device=segm_map_flatten.device)
|
409 |
+
|
410 |
+
for codebook_idx in range(18):
|
411 |
+
min_encoding_indices = torch.full(
|
412 |
+
segm_map_flatten.size(),
|
413 |
+
fill_value=-1,
|
414 |
+
dtype=torch.long,
|
415 |
+
device=segm_map_flatten.device)
|
416 |
+
if torch.sum(segm_map_flatten == codebook_idx) > 0:
|
417 |
+
z_selected = z_patches_flattened[segm_map_flatten ==
|
418 |
+
codebook_idx]
|
419 |
+
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
|
420 |
+
d_selected = torch.sum(
|
421 |
+
z_selected**2, dim=1, keepdim=True) + torch.sum(
|
422 |
+
self.embedding_list[codebook_idx].weight**2,
|
423 |
+
dim=1) - 2 * torch.einsum(
|
424 |
+
'bd,dn->bn', z_selected,
|
425 |
+
rearrange(self.embedding_list[codebook_idx].weight,
|
426 |
+
'n d -> d n'))
|
427 |
+
min_encoding_indices_selected = torch.argmin(d_selected, dim=1)
|
428 |
+
z_q_selected = self.embedding_list[codebook_idx](
|
429 |
+
min_encoding_indices_selected)
|
430 |
+
z_q[segm_map_flatten == codebook_idx] = z_q_selected
|
431 |
+
min_encoding_indices[
|
432 |
+
segm_map_flatten ==
|
433 |
+
codebook_idx] = min_encoding_indices_selected
|
434 |
+
min_encoding_indices_continual[
|
435 |
+
segm_map_flatten ==
|
436 |
+
codebook_idx] = min_encoding_indices_selected + self.n_e * codebook_idx
|
437 |
+
min_encoding_indices = min_encoding_indices.reshape(
|
438 |
+
z_patches.shape[0], segm_map.shape[2], segm_map.shape[3])
|
439 |
+
min_encoding_indices_list.append(min_encoding_indices)
|
440 |
+
|
441 |
+
z_q = F.fold(
|
442 |
+
z_q.view(z_patches.shape).permute(0, 2, 1),
|
443 |
+
z.size()[2:],
|
444 |
+
kernel_size=(self.spatial_size, self.spatial_size),
|
445 |
+
stride=self.spatial_size)
|
446 |
+
|
447 |
+
perplexity = None
|
448 |
+
|
449 |
+
# compute loss for embedding
|
450 |
+
if not self.legacy:
|
451 |
+
loss = self.beta * torch.mean((z_q.detach()-z)**2) + \
|
452 |
+
torch.mean((z_q - z.detach()) ** 2)
|
453 |
+
else:
|
454 |
+
loss = torch.mean((z_q.detach()-z)**2) + self.beta * \
|
455 |
+
torch.mean((z_q - z.detach()) ** 2)
|
456 |
+
|
457 |
+
# preserve gradients
|
458 |
+
z_q = z + (z_q - z).detach()
|
459 |
+
|
460 |
+
return z_q, loss, (perplexity, min_encoding_indices_continual,
|
461 |
+
min_encoding_indices_list)
|
462 |
+
|
463 |
+
def get_codebook_entry(self, indices_list, segm_map, shape):
|
464 |
+
# flatten segm_map (b, h, w)
|
465 |
+
segm_map = F.interpolate(
|
466 |
+
segm_map, size=(shape[1], shape[2]), mode='nearest')
|
467 |
+
segm_map_flatten = segm_map.view(-1)
|
468 |
+
|
469 |
+
z_q = torch.zeros((shape[0] * shape[1] * shape[2]),
|
470 |
+
self.e_dim).to(segm_map.device)
|
471 |
+
for codebook_idx in range(18):
|
472 |
+
if torch.sum(segm_map_flatten == codebook_idx) > 0:
|
473 |
+
min_encoding_indices_selected = indices_list[
|
474 |
+
codebook_idx].view(-1)[segm_map_flatten == codebook_idx]
|
475 |
+
z_q_selected = self.embedding_list[codebook_idx](
|
476 |
+
min_encoding_indices_selected)
|
477 |
+
z_q[segm_map_flatten == codebook_idx] = z_q_selected
|
478 |
+
|
479 |
+
z_q = F.fold(
|
480 |
+
z_q.view(((shape[0], shape[1] * shape[2],
|
481 |
+
self.e_dim))).permute(0, 2, 1),
|
482 |
+
(shape[1] * self.spatial_size, shape[2] * self.spatial_size),
|
483 |
+
kernel_size=(self.spatial_size, self.spatial_size),
|
484 |
+
stride=self.spatial_size)
|
485 |
+
|
486 |
+
return z_q
|
487 |
+
|
488 |
+
|
489 |
+
def get_timestep_embedding(timesteps, embedding_dim):
|
490 |
+
"""
|
491 |
+
This matches the implementation in Denoising Diffusion Probabilistic Models:
|
492 |
+
From Fairseq.
|
493 |
+
Build sinusoidal embeddings.
|
494 |
+
This matches the implementation in tensor2tensor, but differs slightly
|
495 |
+
from the description in Section 3.5 of "Attention Is All You Need".
|
496 |
+
"""
|
497 |
+
assert len(timesteps.shape) == 1
|
498 |
+
|
499 |
+
half_dim = embedding_dim // 2
|
500 |
+
emb = math.log(10000) / (half_dim - 1)
|
501 |
+
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
|
502 |
+
emb = emb.to(device=timesteps.device)
|
503 |
+
emb = timesteps.float()[:, None] * emb[None, :]
|
504 |
+
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
|
505 |
+
if embedding_dim % 2 == 1: # zero pad
|
506 |
+
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
|
507 |
+
return emb
|
508 |
+
|
509 |
+
|
510 |
+
def nonlinearity(x):
|
511 |
+
# swish
|
512 |
+
return x * torch.sigmoid(x)
|
513 |
+
|
514 |
+
|
515 |
+
def Normalize(in_channels):
|
516 |
+
return torch.nn.GroupNorm(
|
517 |
+
num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
518 |
+
|
519 |
+
|
520 |
+
class Upsample(nn.Module):
|
521 |
+
|
522 |
+
def __init__(self, in_channels, with_conv):
|
523 |
+
super().__init__()
|
524 |
+
self.with_conv = with_conv
|
525 |
+
if self.with_conv:
|
526 |
+
self.conv = torch.nn.Conv2d(
|
527 |
+
in_channels, in_channels, kernel_size=3, stride=1, padding=1)
|
528 |
+
|
529 |
+
def forward(self, x):
|
530 |
+
x = torch.nn.functional.interpolate(
|
531 |
+
x, scale_factor=2.0, mode="nearest")
|
532 |
+
if self.with_conv:
|
533 |
+
x = self.conv(x)
|
534 |
+
return x
|
535 |
+
|
536 |
+
|
537 |
+
class Downsample(nn.Module):
|
538 |
+
|
539 |
+
def __init__(self, in_channels, with_conv):
|
540 |
+
super().__init__()
|
541 |
+
self.with_conv = with_conv
|
542 |
+
if self.with_conv:
|
543 |
+
# no asymmetric padding in torch conv, must do it ourselves
|
544 |
+
self.conv = torch.nn.Conv2d(
|
545 |
+
in_channels, in_channels, kernel_size=3, stride=2, padding=0)
|
546 |
+
|
547 |
+
def forward(self, x):
|
548 |
+
if self.with_conv:
|
549 |
+
pad = (0, 1, 0, 1)
|
550 |
+
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
|
551 |
+
x = self.conv(x)
|
552 |
+
else:
|
553 |
+
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
|
554 |
+
return x
|
555 |
+
|
556 |
+
|
557 |
+
class ResnetBlock(nn.Module):
|
558 |
+
|
559 |
+
def __init__(self,
|
560 |
+
*,
|
561 |
+
in_channels,
|
562 |
+
out_channels=None,
|
563 |
+
conv_shortcut=False,
|
564 |
+
dropout,
|
565 |
+
temb_channels=512):
|
566 |
+
super().__init__()
|
567 |
+
self.in_channels = in_channels
|
568 |
+
out_channels = in_channels if out_channels is None else out_channels
|
569 |
+
self.out_channels = out_channels
|
570 |
+
self.use_conv_shortcut = conv_shortcut
|
571 |
+
|
572 |
+
self.norm1 = Normalize(in_channels)
|
573 |
+
self.conv1 = torch.nn.Conv2d(
|
574 |
+
in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
575 |
+
if temb_channels > 0:
|
576 |
+
self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
|
577 |
+
self.norm2 = Normalize(out_channels)
|
578 |
+
self.dropout = torch.nn.Dropout(dropout)
|
579 |
+
self.conv2 = torch.nn.Conv2d(
|
580 |
+
out_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
581 |
+
if self.in_channels != self.out_channels:
|
582 |
+
if self.use_conv_shortcut:
|
583 |
+
self.conv_shortcut = torch.nn.Conv2d(
|
584 |
+
in_channels,
|
585 |
+
out_channels,
|
586 |
+
kernel_size=3,
|
587 |
+
stride=1,
|
588 |
+
padding=1)
|
589 |
+
else:
|
590 |
+
self.nin_shortcut = torch.nn.Conv2d(
|
591 |
+
in_channels,
|
592 |
+
out_channels,
|
593 |
+
kernel_size=1,
|
594 |
+
stride=1,
|
595 |
+
padding=0)
|
596 |
+
|
597 |
+
def forward(self, x, temb):
|
598 |
+
h = x
|
599 |
+
h = self.norm1(h)
|
600 |
+
h = nonlinearity(h)
|
601 |
+
h = self.conv1(h)
|
602 |
+
|
603 |
+
if temb is not None:
|
604 |
+
h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
|
605 |
+
|
606 |
+
h = self.norm2(h)
|
607 |
+
h = nonlinearity(h)
|
608 |
+
h = self.dropout(h)
|
609 |
+
h = self.conv2(h)
|
610 |
+
|
611 |
+
if self.in_channels != self.out_channels:
|
612 |
+
if self.use_conv_shortcut:
|
613 |
+
x = self.conv_shortcut(x)
|
614 |
+
else:
|
615 |
+
x = self.nin_shortcut(x)
|
616 |
+
|
617 |
+
return x + h
|
618 |
+
|
619 |
+
|
620 |
+
class AttnBlock(nn.Module):
|
621 |
+
|
622 |
+
def __init__(self, in_channels):
|
623 |
+
super().__init__()
|
624 |
+
self.in_channels = in_channels
|
625 |
+
|
626 |
+
self.norm = Normalize(in_channels)
|
627 |
+
self.q = torch.nn.Conv2d(
|
628 |
+
in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
629 |
+
self.k = torch.nn.Conv2d(
|
630 |
+
in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
631 |
+
self.v = torch.nn.Conv2d(
|
632 |
+
in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
633 |
+
self.proj_out = torch.nn.Conv2d(
|
634 |
+
in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
635 |
+
|
636 |
+
def forward(self, x):
|
637 |
+
h_ = x
|
638 |
+
h_ = self.norm(h_)
|
639 |
+
q = self.q(h_)
|
640 |
+
k = self.k(h_)
|
641 |
+
v = self.v(h_)
|
642 |
+
|
643 |
+
# compute attention
|
644 |
+
b, c, h, w = q.shape
|
645 |
+
q = q.reshape(b, c, h * w)
|
646 |
+
q = q.permute(0, 2, 1) # b,hw,c
|
647 |
+
k = k.reshape(b, c, h * w) # b,c,hw
|
648 |
+
w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
|
649 |
+
w_ = w_ * (int(c)**(-0.5))
|
650 |
+
w_ = torch.nn.functional.softmax(w_, dim=2)
|
651 |
+
|
652 |
+
# attend to values
|
653 |
+
v = v.reshape(b, c, h * w)
|
654 |
+
w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
|
655 |
+
h_ = torch.bmm(
|
656 |
+
v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
|
657 |
+
h_ = h_.reshape(b, c, h, w)
|
658 |
+
|
659 |
+
h_ = self.proj_out(h_)
|
660 |
+
|
661 |
+
return x + h_
|
662 |
+
|
663 |
+
|
664 |
+
class Model(nn.Module):
|
665 |
+
|
666 |
+
def __init__(self,
|
667 |
+
*,
|
668 |
+
ch,
|
669 |
+
out_ch,
|
670 |
+
ch_mult=(1, 2, 4, 8),
|
671 |
+
num_res_blocks,
|
672 |
+
attn_resolutions,
|
673 |
+
dropout=0.0,
|
674 |
+
resamp_with_conv=True,
|
675 |
+
in_channels,
|
676 |
+
resolution,
|
677 |
+
use_timestep=True):
|
678 |
+
super().__init__()
|
679 |
+
self.ch = ch
|
680 |
+
self.temb_ch = self.ch * 4
|
681 |
+
self.num_resolutions = len(ch_mult)
|
682 |
+
self.num_res_blocks = num_res_blocks
|
683 |
+
self.resolution = resolution
|
684 |
+
self.in_channels = in_channels
|
685 |
+
|
686 |
+
self.use_timestep = use_timestep
|
687 |
+
if self.use_timestep:
|
688 |
+
# timestep embedding
|
689 |
+
self.temb = nn.Module()
|
690 |
+
self.temb.dense = nn.ModuleList([
|
691 |
+
torch.nn.Linear(self.ch, self.temb_ch),
|
692 |
+
torch.nn.Linear(self.temb_ch, self.temb_ch),
|
693 |
+
])
|
694 |
+
|
695 |
+
# downsampling
|
696 |
+
self.conv_in = torch.nn.Conv2d(
|
697 |
+
in_channels, self.ch, kernel_size=3, stride=1, padding=1)
|
698 |
+
|
699 |
+
curr_res = resolution
|
700 |
+
in_ch_mult = (1, ) + tuple(ch_mult)
|
701 |
+
self.down = nn.ModuleList()
|
702 |
+
for i_level in range(self.num_resolutions):
|
703 |
+
block = nn.ModuleList()
|
704 |
+
attn = nn.ModuleList()
|
705 |
+
block_in = ch * in_ch_mult[i_level]
|
706 |
+
block_out = ch * ch_mult[i_level]
|
707 |
+
for i_block in range(self.num_res_blocks):
|
708 |
+
block.append(
|
709 |
+
ResnetBlock(
|
710 |
+
in_channels=block_in,
|
711 |
+
out_channels=block_out,
|
712 |
+
temb_channels=self.temb_ch,
|
713 |
+
dropout=dropout))
|
714 |
+
block_in = block_out
|
715 |
+
if curr_res in attn_resolutions:
|
716 |
+
attn.append(AttnBlock(block_in))
|
717 |
+
down = nn.Module()
|
718 |
+
down.block = block
|
719 |
+
down.attn = attn
|
720 |
+
if i_level != self.num_resolutions - 1:
|
721 |
+
down.downsample = Downsample(block_in, resamp_with_conv)
|
722 |
+
curr_res = curr_res // 2
|
723 |
+
self.down.append(down)
|
724 |
+
|
725 |
+
# middle
|
726 |
+
self.mid = nn.Module()
|
727 |
+
self.mid.block_1 = ResnetBlock(
|
728 |
+
in_channels=block_in,
|
729 |
+
out_channels=block_in,
|
730 |
+
temb_channels=self.temb_ch,
|
731 |
+
dropout=dropout)
|
732 |
+
self.mid.attn_1 = AttnBlock(block_in)
|
733 |
+
self.mid.block_2 = ResnetBlock(
|
734 |
+
in_channels=block_in,
|
735 |
+
out_channels=block_in,
|
736 |
+
temb_channels=self.temb_ch,
|
737 |
+
dropout=dropout)
|
738 |
+
|
739 |
+
# upsampling
|
740 |
+
self.up = nn.ModuleList()
|
741 |
+
for i_level in reversed(range(self.num_resolutions)):
|
742 |
+
block = nn.ModuleList()
|
743 |
+
attn = nn.ModuleList()
|
744 |
+
block_out = ch * ch_mult[i_level]
|
745 |
+
skip_in = ch * ch_mult[i_level]
|
746 |
+
for i_block in range(self.num_res_blocks + 1):
|
747 |
+
if i_block == self.num_res_blocks:
|
748 |
+
skip_in = ch * in_ch_mult[i_level]
|
749 |
+
block.append(
|
750 |
+
ResnetBlock(
|
751 |
+
in_channels=block_in + skip_in,
|
752 |
+
out_channels=block_out,
|
753 |
+
temb_channels=self.temb_ch,
|
754 |
+
dropout=dropout))
|
755 |
+
block_in = block_out
|
756 |
+
if curr_res in attn_resolutions:
|
757 |
+
attn.append(AttnBlock(block_in))
|
758 |
+
up = nn.Module()
|
759 |
+
up.block = block
|
760 |
+
up.attn = attn
|
761 |
+
if i_level != 0:
|
762 |
+
up.upsample = Upsample(block_in, resamp_with_conv)
|
763 |
+
curr_res = curr_res * 2
|
764 |
+
self.up.insert(0, up) # prepend to get consistent order
|
765 |
+
|
766 |
+
# end
|
767 |
+
self.norm_out = Normalize(block_in)
|
768 |
+
self.conv_out = torch.nn.Conv2d(
|
769 |
+
block_in, out_ch, kernel_size=3, stride=1, padding=1)
|
770 |
+
|
771 |
+
def forward(self, x, t=None):
|
772 |
+
#assert x.shape[2] == x.shape[3] == self.resolution
|
773 |
+
|
774 |
+
if self.use_timestep:
|
775 |
+
# timestep embedding
|
776 |
+
assert t is not None
|
777 |
+
temb = get_timestep_embedding(t, self.ch)
|
778 |
+
temb = self.temb.dense[0](temb)
|
779 |
+
temb = nonlinearity(temb)
|
780 |
+
temb = self.temb.dense[1](temb)
|
781 |
+
else:
|
782 |
+
temb = None
|
783 |
+
|
784 |
+
# downsampling
|
785 |
+
hs = [self.conv_in(x)]
|
786 |
+
for i_level in range(self.num_resolutions):
|
787 |
+
for i_block in range(self.num_res_blocks):
|
788 |
+
h = self.down[i_level].block[i_block](hs[-1], temb)
|
789 |
+
if len(self.down[i_level].attn) > 0:
|
790 |
+
h = self.down[i_level].attn[i_block](h)
|
791 |
+
hs.append(h)
|
792 |
+
if i_level != self.num_resolutions - 1:
|
793 |
+
hs.append(self.down[i_level].downsample(hs[-1]))
|
794 |
+
|
795 |
+
# middle
|
796 |
+
h = hs[-1]
|
797 |
+
h = self.mid.block_1(h, temb)
|
798 |
+
h = self.mid.attn_1(h)
|
799 |
+
h = self.mid.block_2(h, temb)
|
800 |
+
|
801 |
+
# upsampling
|
802 |
+
for i_level in reversed(range(self.num_resolutions)):
|
803 |
+
for i_block in range(self.num_res_blocks + 1):
|
804 |
+
h = self.up[i_level].block[i_block](torch.cat([h, hs.pop()],
|
805 |
+
dim=1), temb)
|
806 |
+
if len(self.up[i_level].attn) > 0:
|
807 |
+
h = self.up[i_level].attn[i_block](h)
|
808 |
+
if i_level != 0:
|
809 |
+
h = self.up[i_level].upsample(h)
|
810 |
+
|
811 |
+
# end
|
812 |
+
h = self.norm_out(h)
|
813 |
+
h = nonlinearity(h)
|
814 |
+
h = self.conv_out(h)
|
815 |
+
return h
|
816 |
+
|
817 |
+
|
818 |
+
class Encoder(nn.Module):
|
819 |
+
|
820 |
+
def __init__(self,
|
821 |
+
ch,
|
822 |
+
num_res_blocks,
|
823 |
+
attn_resolutions,
|
824 |
+
in_channels,
|
825 |
+
resolution,
|
826 |
+
z_channels,
|
827 |
+
ch_mult=(1, 2, 4, 8),
|
828 |
+
dropout=0.0,
|
829 |
+
resamp_with_conv=True,
|
830 |
+
double_z=True):
|
831 |
+
super().__init__()
|
832 |
+
self.ch = ch
|
833 |
+
self.temb_ch = 0
|
834 |
+
self.num_resolutions = len(ch_mult)
|
835 |
+
self.num_res_blocks = num_res_blocks
|
836 |
+
self.resolution = resolution
|
837 |
+
self.in_channels = in_channels
|
838 |
+
|
839 |
+
# downsampling
|
840 |
+
self.conv_in = torch.nn.Conv2d(
|
841 |
+
in_channels, self.ch, kernel_size=3, stride=1, padding=1)
|
842 |
+
|
843 |
+
curr_res = resolution
|
844 |
+
in_ch_mult = (1, ) + tuple(ch_mult)
|
845 |
+
self.down = nn.ModuleList()
|
846 |
+
for i_level in range(self.num_resolutions):
|
847 |
+
block = nn.ModuleList()
|
848 |
+
attn = nn.ModuleList()
|
849 |
+
block_in = ch * in_ch_mult[i_level]
|
850 |
+
block_out = ch * ch_mult[i_level]
|
851 |
+
for i_block in range(self.num_res_blocks):
|
852 |
+
block.append(
|
853 |
+
ResnetBlock(
|
854 |
+
in_channels=block_in,
|
855 |
+
out_channels=block_out,
|
856 |
+
temb_channels=self.temb_ch,
|
857 |
+
dropout=dropout))
|
858 |
+
block_in = block_out
|
859 |
+
if curr_res in attn_resolutions:
|
860 |
+
attn.append(AttnBlock(block_in))
|
861 |
+
down = nn.Module()
|
862 |
+
down.block = block
|
863 |
+
down.attn = attn
|
864 |
+
if i_level != self.num_resolutions - 1:
|
865 |
+
down.downsample = Downsample(block_in, resamp_with_conv)
|
866 |
+
curr_res = curr_res // 2
|
867 |
+
self.down.append(down)
|
868 |
+
|
869 |
+
# middle
|
870 |
+
self.mid = nn.Module()
|
871 |
+
self.mid.block_1 = ResnetBlock(
|
872 |
+
in_channels=block_in,
|
873 |
+
out_channels=block_in,
|
874 |
+
temb_channels=self.temb_ch,
|
875 |
+
dropout=dropout)
|
876 |
+
self.mid.attn_1 = AttnBlock(block_in)
|
877 |
+
self.mid.block_2 = ResnetBlock(
|
878 |
+
in_channels=block_in,
|
879 |
+
out_channels=block_in,
|
880 |
+
temb_channels=self.temb_ch,
|
881 |
+
dropout=dropout)
|
882 |
+
|
883 |
+
# end
|
884 |
+
self.norm_out = Normalize(block_in)
|
885 |
+
self.conv_out = torch.nn.Conv2d(
|
886 |
+
block_in,
|
887 |
+
2 * z_channels if double_z else z_channels,
|
888 |
+
kernel_size=3,
|
889 |
+
stride=1,
|
890 |
+
padding=1)
|
891 |
+
|
892 |
+
def forward(self, x):
|
893 |
+
#assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution)
|
894 |
+
|
895 |
+
# timestep embedding
|
896 |
+
temb = None
|
897 |
+
|
898 |
+
# downsampling
|
899 |
+
hs = [self.conv_in(x)]
|
900 |
+
for i_level in range(self.num_resolutions):
|
901 |
+
for i_block in range(self.num_res_blocks):
|
902 |
+
h = self.down[i_level].block[i_block](hs[-1], temb)
|
903 |
+
if len(self.down[i_level].attn) > 0:
|
904 |
+
h = self.down[i_level].attn[i_block](h)
|
905 |
+
hs.append(h)
|
906 |
+
if i_level != self.num_resolutions - 1:
|
907 |
+
hs.append(self.down[i_level].downsample(hs[-1]))
|
908 |
+
|
909 |
+
# middle
|
910 |
+
h = hs[-1]
|
911 |
+
h = self.mid.block_1(h, temb)
|
912 |
+
h = self.mid.attn_1(h)
|
913 |
+
h = self.mid.block_2(h, temb)
|
914 |
+
|
915 |
+
# end
|
916 |
+
h = self.norm_out(h)
|
917 |
+
h = nonlinearity(h)
|
918 |
+
h = self.conv_out(h)
|
919 |
+
return h
|
920 |
+
|
921 |
+
|
922 |
+
class Decoder(nn.Module):
|
923 |
+
|
924 |
+
def __init__(self,
|
925 |
+
in_channels,
|
926 |
+
resolution,
|
927 |
+
z_channels,
|
928 |
+
ch,
|
929 |
+
out_ch,
|
930 |
+
num_res_blocks,
|
931 |
+
attn_resolutions,
|
932 |
+
ch_mult=(1, 2, 4, 8),
|
933 |
+
dropout=0.0,
|
934 |
+
resamp_with_conv=True,
|
935 |
+
give_pre_end=False):
|
936 |
+
super().__init__()
|
937 |
+
self.ch = ch
|
938 |
+
self.temb_ch = 0
|
939 |
+
self.num_resolutions = len(ch_mult)
|
940 |
+
self.num_res_blocks = num_res_blocks
|
941 |
+
self.resolution = resolution
|
942 |
+
self.in_channels = in_channels
|
943 |
+
self.give_pre_end = give_pre_end
|
944 |
+
|
945 |
+
# compute in_ch_mult, block_in and curr_res at lowest res
|
946 |
+
in_ch_mult = (1, ) + tuple(ch_mult)
|
947 |
+
block_in = ch * ch_mult[self.num_resolutions - 1]
|
948 |
+
curr_res = resolution // 2**(self.num_resolutions - 1)
|
949 |
+
self.z_shape = (1, z_channels, curr_res, curr_res // 2)
|
950 |
+
print("Working with z of shape {} = {} dimensions.".format(
|
951 |
+
self.z_shape, np.prod(self.z_shape)))
|
952 |
+
|
953 |
+
# z to block_in
|
954 |
+
self.conv_in = torch.nn.Conv2d(
|
955 |
+
z_channels, block_in, kernel_size=3, stride=1, padding=1)
|
956 |
+
|
957 |
+
# middle
|
958 |
+
self.mid = nn.Module()
|
959 |
+
self.mid.block_1 = ResnetBlock(
|
960 |
+
in_channels=block_in,
|
961 |
+
out_channels=block_in,
|
962 |
+
temb_channels=self.temb_ch,
|
963 |
+
dropout=dropout)
|
964 |
+
self.mid.attn_1 = AttnBlock(block_in)
|
965 |
+
self.mid.block_2 = ResnetBlock(
|
966 |
+
in_channels=block_in,
|
967 |
+
out_channels=block_in,
|
968 |
+
temb_channels=self.temb_ch,
|
969 |
+
dropout=dropout)
|
970 |
+
|
971 |
+
# upsampling
|
972 |
+
self.up = nn.ModuleList()
|
973 |
+
for i_level in reversed(range(self.num_resolutions)):
|
974 |
+
block = nn.ModuleList()
|
975 |
+
attn = nn.ModuleList()
|
976 |
+
block_out = ch * ch_mult[i_level]
|
977 |
+
for i_block in range(self.num_res_blocks + 1):
|
978 |
+
block.append(
|
979 |
+
ResnetBlock(
|
980 |
+
in_channels=block_in,
|
981 |
+
out_channels=block_out,
|
982 |
+
temb_channels=self.temb_ch,
|
983 |
+
dropout=dropout))
|
984 |
+
block_in = block_out
|
985 |
+
if curr_res in attn_resolutions:
|
986 |
+
attn.append(AttnBlock(block_in))
|
987 |
+
up = nn.Module()
|
988 |
+
up.block = block
|
989 |
+
up.attn = attn
|
990 |
+
if i_level != 0:
|
991 |
+
up.upsample = Upsample(block_in, resamp_with_conv)
|
992 |
+
curr_res = curr_res * 2
|
993 |
+
self.up.insert(0, up) # prepend to get consistent order
|
994 |
+
|
995 |
+
# end
|
996 |
+
self.norm_out = Normalize(block_in)
|
997 |
+
self.conv_out = torch.nn.Conv2d(
|
998 |
+
block_in, out_ch, kernel_size=3, stride=1, padding=1)
|
999 |
+
|
1000 |
+
def forward(self, z, bot_h=None):
|
1001 |
+
#assert z.shape[1:] == self.z_shape[1:]
|
1002 |
+
self.last_z_shape = z.shape
|
1003 |
+
|
1004 |
+
# timestep embedding
|
1005 |
+
temb = None
|
1006 |
+
|
1007 |
+
# z to block_in
|
1008 |
+
h = self.conv_in(z)
|
1009 |
+
|
1010 |
+
# middle
|
1011 |
+
h = self.mid.block_1(h, temb)
|
1012 |
+
h = self.mid.attn_1(h)
|
1013 |
+
h = self.mid.block_2(h, temb)
|
1014 |
+
|
1015 |
+
# upsampling
|
1016 |
+
for i_level in reversed(range(self.num_resolutions)):
|
1017 |
+
for i_block in range(self.num_res_blocks + 1):
|
1018 |
+
h = self.up[i_level].block[i_block](h, temb)
|
1019 |
+
if len(self.up[i_level].attn) > 0:
|
1020 |
+
h = self.up[i_level].attn[i_block](h)
|
1021 |
+
if i_level != 0:
|
1022 |
+
h = self.up[i_level].upsample(h)
|
1023 |
+
if i_level == 4 and bot_h is not None:
|
1024 |
+
h += bot_h
|
1025 |
+
|
1026 |
+
# end
|
1027 |
+
if self.give_pre_end:
|
1028 |
+
return h
|
1029 |
+
|
1030 |
+
h = self.norm_out(h)
|
1031 |
+
h = nonlinearity(h)
|
1032 |
+
h = self.conv_out(h)
|
1033 |
+
return h
|
1034 |
+
|
1035 |
+
def get_feature_top(self, z):
|
1036 |
+
#assert z.shape[1:] == self.z_shape[1:]
|
1037 |
+
self.last_z_shape = z.shape
|
1038 |
+
|
1039 |
+
# timestep embedding
|
1040 |
+
temb = None
|
1041 |
+
|
1042 |
+
# z to block_in
|
1043 |
+
h = self.conv_in(z)
|
1044 |
+
|
1045 |
+
# middle
|
1046 |
+
h = self.mid.block_1(h, temb)
|
1047 |
+
h = self.mid.attn_1(h)
|
1048 |
+
h = self.mid.block_2(h, temb)
|
1049 |
+
|
1050 |
+
# upsampling
|
1051 |
+
for i_level in reversed(range(self.num_resolutions)):
|
1052 |
+
for i_block in range(self.num_res_blocks + 1):
|
1053 |
+
h = self.up[i_level].block[i_block](h, temb)
|
1054 |
+
if len(self.up[i_level].attn) > 0:
|
1055 |
+
h = self.up[i_level].attn[i_block](h)
|
1056 |
+
if i_level != 0:
|
1057 |
+
h = self.up[i_level].upsample(h)
|
1058 |
+
if i_level == 4:
|
1059 |
+
return h
|
1060 |
+
|
1061 |
+
def get_feature_middle(self, z, mid_h):
|
1062 |
+
#assert z.shape[1:] == self.z_shape[1:]
|
1063 |
+
self.last_z_shape = z.shape
|
1064 |
+
|
1065 |
+
# timestep embedding
|
1066 |
+
temb = None
|
1067 |
+
|
1068 |
+
# z to block_in
|
1069 |
+
h = self.conv_in(z)
|
1070 |
+
|
1071 |
+
# middle
|
1072 |
+
h = self.mid.block_1(h, temb)
|
1073 |
+
h = self.mid.attn_1(h)
|
1074 |
+
h = self.mid.block_2(h, temb)
|
1075 |
+
|
1076 |
+
# upsampling
|
1077 |
+
for i_level in reversed(range(self.num_resolutions)):
|
1078 |
+
for i_block in range(self.num_res_blocks + 1):
|
1079 |
+
h = self.up[i_level].block[i_block](h, temb)
|
1080 |
+
if len(self.up[i_level].attn) > 0:
|
1081 |
+
h = self.up[i_level].attn[i_block](h)
|
1082 |
+
if i_level != 0:
|
1083 |
+
h = self.up[i_level].upsample(h)
|
1084 |
+
if i_level == 4:
|
1085 |
+
h += mid_h
|
1086 |
+
if i_level == 3:
|
1087 |
+
return h
|
1088 |
+
|
1089 |
+
|
1090 |
+
class DecoderRes(nn.Module):
|
1091 |
+
|
1092 |
+
def __init__(self,
|
1093 |
+
in_channels,
|
1094 |
+
resolution,
|
1095 |
+
z_channels,
|
1096 |
+
ch,
|
1097 |
+
num_res_blocks,
|
1098 |
+
ch_mult=(1, 2, 4, 8),
|
1099 |
+
dropout=0.0,
|
1100 |
+
give_pre_end=False):
|
1101 |
+
super().__init__()
|
1102 |
+
self.ch = ch
|
1103 |
+
self.temb_ch = 0
|
1104 |
+
self.num_resolutions = len(ch_mult)
|
1105 |
+
self.num_res_blocks = num_res_blocks
|
1106 |
+
self.resolution = resolution
|
1107 |
+
self.in_channels = in_channels
|
1108 |
+
self.give_pre_end = give_pre_end
|
1109 |
+
|
1110 |
+
# compute in_ch_mult, block_in and curr_res at lowest res
|
1111 |
+
in_ch_mult = (1, ) + tuple(ch_mult)
|
1112 |
+
block_in = ch * ch_mult[self.num_resolutions - 1]
|
1113 |
+
curr_res = resolution // 2**(self.num_resolutions - 1)
|
1114 |
+
self.z_shape = (1, z_channels, curr_res, curr_res // 2)
|
1115 |
+
print("Working with z of shape {} = {} dimensions.".format(
|
1116 |
+
self.z_shape, np.prod(self.z_shape)))
|
1117 |
+
|
1118 |
+
# z to block_in
|
1119 |
+
self.conv_in = torch.nn.Conv2d(
|
1120 |
+
z_channels, block_in, kernel_size=3, stride=1, padding=1)
|
1121 |
+
|
1122 |
+
# middle
|
1123 |
+
self.mid = nn.Module()
|
1124 |
+
self.mid.block_1 = ResnetBlock(
|
1125 |
+
in_channels=block_in,
|
1126 |
+
out_channels=block_in,
|
1127 |
+
temb_channels=self.temb_ch,
|
1128 |
+
dropout=dropout)
|
1129 |
+
self.mid.attn_1 = AttnBlock(block_in)
|
1130 |
+
self.mid.block_2 = ResnetBlock(
|
1131 |
+
in_channels=block_in,
|
1132 |
+
out_channels=block_in,
|
1133 |
+
temb_channels=self.temb_ch,
|
1134 |
+
dropout=dropout)
|
1135 |
+
|
1136 |
+
def forward(self, z):
|
1137 |
+
#assert z.shape[1:] == self.z_shape[1:]
|
1138 |
+
self.last_z_shape = z.shape
|
1139 |
+
|
1140 |
+
# timestep embedding
|
1141 |
+
temb = None
|
1142 |
+
|
1143 |
+
# z to block_in
|
1144 |
+
h = self.conv_in(z)
|
1145 |
+
|
1146 |
+
# middle
|
1147 |
+
h = self.mid.block_1(h, temb)
|
1148 |
+
h = self.mid.attn_1(h)
|
1149 |
+
h = self.mid.block_2(h, temb)
|
1150 |
+
|
1151 |
+
return h
|
1152 |
+
|
1153 |
+
|
1154 |
+
# patch based discriminator
|
1155 |
+
class Discriminator(nn.Module):
|
1156 |
+
|
1157 |
+
def __init__(self, nc, ndf, n_layers=3):
|
1158 |
+
super().__init__()
|
1159 |
+
|
1160 |
+
layers = [
|
1161 |
+
nn.Conv2d(nc, ndf, kernel_size=4, stride=2, padding=1),
|
1162 |
+
nn.LeakyReLU(0.2, True)
|
1163 |
+
]
|
1164 |
+
ndf_mult = 1
|
1165 |
+
ndf_mult_prev = 1
|
1166 |
+
for n in range(1,
|
1167 |
+
n_layers): # gradually increase the number of filters
|
1168 |
+
ndf_mult_prev = ndf_mult
|
1169 |
+
ndf_mult = min(2**n, 8)
|
1170 |
+
layers += [
|
1171 |
+
nn.Conv2d(
|
1172 |
+
ndf * ndf_mult_prev,
|
1173 |
+
ndf * ndf_mult,
|
1174 |
+
kernel_size=4,
|
1175 |
+
stride=2,
|
1176 |
+
padding=1,
|
1177 |
+
bias=False),
|
1178 |
+
nn.BatchNorm2d(ndf * ndf_mult),
|
1179 |
+
nn.LeakyReLU(0.2, True)
|
1180 |
+
]
|
1181 |
+
|
1182 |
+
ndf_mult_prev = ndf_mult
|
1183 |
+
ndf_mult = min(2**n_layers, 8)
|
1184 |
+
|
1185 |
+
layers += [
|
1186 |
+
nn.Conv2d(
|
1187 |
+
ndf * ndf_mult_prev,
|
1188 |
+
ndf * ndf_mult,
|
1189 |
+
kernel_size=4,
|
1190 |
+
stride=1,
|
1191 |
+
padding=1,
|
1192 |
+
bias=False),
|
1193 |
+
nn.BatchNorm2d(ndf * ndf_mult),
|
1194 |
+
nn.LeakyReLU(0.2, True)
|
1195 |
+
]
|
1196 |
+
|
1197 |
+
layers += [
|
1198 |
+
nn.Conv2d(ndf * ndf_mult, 1, kernel_size=4, stride=1, padding=1)
|
1199 |
+
] # output 1 channel prediction map
|
1200 |
+
self.main = nn.Sequential(*layers)
|
1201 |
+
|
1202 |
+
def forward(self, x):
|
1203 |
+
return self.main(x)
|
Text2Human/models/hierarchy_inference_model.py
ADDED
@@ -0,0 +1,363 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import math
|
3 |
+
from collections import OrderedDict
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from torchvision.utils import save_image
|
8 |
+
|
9 |
+
from models.archs.fcn_arch import MultiHeadFCNHead
|
10 |
+
from models.archs.unet_arch import UNet
|
11 |
+
from models.archs.vqgan_arch import (Decoder, DecoderRes, Encoder,
|
12 |
+
VectorQuantizerSpatialTextureAware,
|
13 |
+
VectorQuantizerTexture)
|
14 |
+
from models.losses.accuracy import accuracy
|
15 |
+
from models.losses.cross_entropy_loss import CrossEntropyLoss
|
16 |
+
|
17 |
+
logger = logging.getLogger('base')
|
18 |
+
|
19 |
+
|
20 |
+
class VQGANTextureAwareSpatialHierarchyInferenceModel():
|
21 |
+
|
22 |
+
def __init__(self, opt):
|
23 |
+
self.opt = opt
|
24 |
+
self.device = torch.device('cuda')
|
25 |
+
self.is_train = opt['is_train']
|
26 |
+
|
27 |
+
self.top_encoder = Encoder(
|
28 |
+
ch=opt['top_ch'],
|
29 |
+
num_res_blocks=opt['top_num_res_blocks'],
|
30 |
+
attn_resolutions=opt['top_attn_resolutions'],
|
31 |
+
ch_mult=opt['top_ch_mult'],
|
32 |
+
in_channels=opt['top_in_channels'],
|
33 |
+
resolution=opt['top_resolution'],
|
34 |
+
z_channels=opt['top_z_channels'],
|
35 |
+
double_z=opt['top_double_z'],
|
36 |
+
dropout=opt['top_dropout']).to(self.device)
|
37 |
+
self.decoder = Decoder(
|
38 |
+
in_channels=opt['top_in_channels'],
|
39 |
+
resolution=opt['top_resolution'],
|
40 |
+
z_channels=opt['top_z_channels'],
|
41 |
+
ch=opt['top_ch'],
|
42 |
+
out_ch=opt['top_out_ch'],
|
43 |
+
num_res_blocks=opt['top_num_res_blocks'],
|
44 |
+
attn_resolutions=opt['top_attn_resolutions'],
|
45 |
+
ch_mult=opt['top_ch_mult'],
|
46 |
+
dropout=opt['top_dropout'],
|
47 |
+
resamp_with_conv=True,
|
48 |
+
give_pre_end=False).to(self.device)
|
49 |
+
self.top_quantize = VectorQuantizerTexture(
|
50 |
+
1024, opt['embed_dim'], beta=0.25).to(self.device)
|
51 |
+
self.top_quant_conv = torch.nn.Conv2d(opt["top_z_channels"],
|
52 |
+
opt['embed_dim'],
|
53 |
+
1).to(self.device)
|
54 |
+
self.top_post_quant_conv = torch.nn.Conv2d(opt['embed_dim'],
|
55 |
+
opt["top_z_channels"],
|
56 |
+
1).to(self.device)
|
57 |
+
self.load_top_pretrain_models()
|
58 |
+
|
59 |
+
self.bot_encoder = Encoder(
|
60 |
+
ch=opt['bot_ch'],
|
61 |
+
num_res_blocks=opt['bot_num_res_blocks'],
|
62 |
+
attn_resolutions=opt['bot_attn_resolutions'],
|
63 |
+
ch_mult=opt['bot_ch_mult'],
|
64 |
+
in_channels=opt['bot_in_channels'],
|
65 |
+
resolution=opt['bot_resolution'],
|
66 |
+
z_channels=opt['bot_z_channels'],
|
67 |
+
double_z=opt['bot_double_z'],
|
68 |
+
dropout=opt['bot_dropout']).to(self.device)
|
69 |
+
self.bot_decoder_res = DecoderRes(
|
70 |
+
in_channels=opt['bot_in_channels'],
|
71 |
+
resolution=opt['bot_resolution'],
|
72 |
+
z_channels=opt['bot_z_channels'],
|
73 |
+
ch=opt['bot_ch'],
|
74 |
+
num_res_blocks=opt['bot_num_res_blocks'],
|
75 |
+
ch_mult=opt['bot_ch_mult'],
|
76 |
+
dropout=opt['bot_dropout'],
|
77 |
+
give_pre_end=False).to(self.device)
|
78 |
+
self.bot_quantize = VectorQuantizerSpatialTextureAware(
|
79 |
+
opt['bot_n_embed'],
|
80 |
+
opt['embed_dim'],
|
81 |
+
beta=0.25,
|
82 |
+
spatial_size=opt['codebook_spatial_size']).to(self.device)
|
83 |
+
self.bot_quant_conv = torch.nn.Conv2d(opt["bot_z_channels"],
|
84 |
+
opt['embed_dim'],
|
85 |
+
1).to(self.device)
|
86 |
+
self.bot_post_quant_conv = torch.nn.Conv2d(opt['embed_dim'],
|
87 |
+
opt["bot_z_channels"],
|
88 |
+
1).to(self.device)
|
89 |
+
|
90 |
+
self.load_bot_pretrain_network()
|
91 |
+
|
92 |
+
self.guidance_encoder = UNet(
|
93 |
+
in_channels=opt['encoder_in_channels']).to(self.device)
|
94 |
+
self.index_decoder = MultiHeadFCNHead(
|
95 |
+
in_channels=opt['fc_in_channels'],
|
96 |
+
in_index=opt['fc_in_index'],
|
97 |
+
channels=opt['fc_channels'],
|
98 |
+
num_convs=opt['fc_num_convs'],
|
99 |
+
concat_input=opt['fc_concat_input'],
|
100 |
+
dropout_ratio=opt['fc_dropout_ratio'],
|
101 |
+
num_classes=opt['fc_num_classes'],
|
102 |
+
align_corners=opt['fc_align_corners'],
|
103 |
+
num_head=18).to(self.device)
|
104 |
+
|
105 |
+
self.init_training_settings()
|
106 |
+
|
107 |
+
def init_training_settings(self):
|
108 |
+
optim_params = []
|
109 |
+
for v in self.guidance_encoder.parameters():
|
110 |
+
if v.requires_grad:
|
111 |
+
optim_params.append(v)
|
112 |
+
for v in self.index_decoder.parameters():
|
113 |
+
if v.requires_grad:
|
114 |
+
optim_params.append(v)
|
115 |
+
# set up optimizers
|
116 |
+
if self.opt['optimizer'] == 'Adam':
|
117 |
+
self.optimizer = torch.optim.Adam(
|
118 |
+
optim_params,
|
119 |
+
self.opt['lr'],
|
120 |
+
weight_decay=self.opt['weight_decay'])
|
121 |
+
elif self.opt['optimizer'] == 'SGD':
|
122 |
+
self.optimizer = torch.optim.SGD(
|
123 |
+
optim_params,
|
124 |
+
self.opt['lr'],
|
125 |
+
momentum=self.opt['momentum'],
|
126 |
+
weight_decay=self.opt['weight_decay'])
|
127 |
+
self.log_dict = OrderedDict()
|
128 |
+
if self.opt['loss_function'] == 'cross_entropy':
|
129 |
+
self.loss_func = CrossEntropyLoss().to(self.device)
|
130 |
+
|
131 |
+
def load_top_pretrain_models(self):
|
132 |
+
# load pretrained vqgan for segmentation mask
|
133 |
+
top_vae_checkpoint = torch.load(self.opt['top_vae_path'])
|
134 |
+
self.top_encoder.load_state_dict(
|
135 |
+
top_vae_checkpoint['encoder'], strict=True)
|
136 |
+
self.decoder.load_state_dict(
|
137 |
+
top_vae_checkpoint['decoder'], strict=True)
|
138 |
+
self.top_quantize.load_state_dict(
|
139 |
+
top_vae_checkpoint['quantize'], strict=True)
|
140 |
+
self.top_quant_conv.load_state_dict(
|
141 |
+
top_vae_checkpoint['quant_conv'], strict=True)
|
142 |
+
self.top_post_quant_conv.load_state_dict(
|
143 |
+
top_vae_checkpoint['post_quant_conv'], strict=True)
|
144 |
+
self.top_encoder.eval()
|
145 |
+
self.top_quantize.eval()
|
146 |
+
self.top_quant_conv.eval()
|
147 |
+
self.top_post_quant_conv.eval()
|
148 |
+
|
149 |
+
def load_bot_pretrain_network(self):
|
150 |
+
checkpoint = torch.load(self.opt['bot_vae_path'])
|
151 |
+
self.bot_encoder.load_state_dict(
|
152 |
+
checkpoint['bot_encoder'], strict=True)
|
153 |
+
self.bot_decoder_res.load_state_dict(
|
154 |
+
checkpoint['bot_decoder_res'], strict=True)
|
155 |
+
self.decoder.load_state_dict(checkpoint['decoder'], strict=True)
|
156 |
+
self.bot_quantize.load_state_dict(
|
157 |
+
checkpoint['bot_quantize'], strict=True)
|
158 |
+
self.bot_quant_conv.load_state_dict(
|
159 |
+
checkpoint['bot_quant_conv'], strict=True)
|
160 |
+
self.bot_post_quant_conv.load_state_dict(
|
161 |
+
checkpoint['bot_post_quant_conv'], strict=True)
|
162 |
+
|
163 |
+
self.bot_encoder.eval()
|
164 |
+
self.bot_decoder_res.eval()
|
165 |
+
self.decoder.eval()
|
166 |
+
self.bot_quantize.eval()
|
167 |
+
self.bot_quant_conv.eval()
|
168 |
+
self.bot_post_quant_conv.eval()
|
169 |
+
|
170 |
+
def top_encode(self, x, mask):
|
171 |
+
h = self.top_encoder(x)
|
172 |
+
h = self.top_quant_conv(h)
|
173 |
+
quant, _, _ = self.top_quantize(h, mask)
|
174 |
+
quant = self.top_post_quant_conv(quant)
|
175 |
+
|
176 |
+
return quant, quant
|
177 |
+
|
178 |
+
def feed_data(self, data):
|
179 |
+
self.image = data['image'].to(self.device)
|
180 |
+
self.texture_mask = data['texture_mask'].float().to(self.device)
|
181 |
+
self.get_gt_indices()
|
182 |
+
|
183 |
+
self.texture_tokens = F.interpolate(
|
184 |
+
self.texture_mask, size=(32, 16),
|
185 |
+
mode='nearest').view(self.image.size(0), -1).long()
|
186 |
+
|
187 |
+
def bot_encode(self, x, mask):
|
188 |
+
h = self.bot_encoder(x)
|
189 |
+
h = self.bot_quant_conv(h)
|
190 |
+
_, _, (_, _, indices_list) = self.bot_quantize(h, mask)
|
191 |
+
|
192 |
+
return indices_list
|
193 |
+
|
194 |
+
def get_gt_indices(self):
|
195 |
+
self.quant_t, self.feature_t = self.top_encode(self.image,
|
196 |
+
self.texture_mask)
|
197 |
+
self.gt_indices_list = self.bot_encode(self.image, self.texture_mask)
|
198 |
+
|
199 |
+
def index_to_image(self, index_bottom_list, texture_mask):
|
200 |
+
quant_b = self.bot_quantize.get_codebook_entry(
|
201 |
+
index_bottom_list, texture_mask,
|
202 |
+
(index_bottom_list[0].size(0), index_bottom_list[0].size(1),
|
203 |
+
index_bottom_list[0].size(2),
|
204 |
+
self.opt["bot_z_channels"])) #.permute(0, 3, 1, 2)
|
205 |
+
quant_b = self.bot_post_quant_conv(quant_b)
|
206 |
+
bot_dec_res = self.bot_decoder_res(quant_b)
|
207 |
+
|
208 |
+
dec = self.decoder(self.quant_t, bot_h=bot_dec_res)
|
209 |
+
|
210 |
+
return dec
|
211 |
+
|
212 |
+
def get_vis(self, pred_img_index, rec_img_index, texture_mask, save_path):
|
213 |
+
rec_img = self.index_to_image(rec_img_index, texture_mask)
|
214 |
+
pred_img = self.index_to_image(pred_img_index, texture_mask)
|
215 |
+
|
216 |
+
base_img = self.decoder(self.quant_t)
|
217 |
+
img_cat = torch.cat([
|
218 |
+
self.image,
|
219 |
+
rec_img,
|
220 |
+
base_img,
|
221 |
+
pred_img,
|
222 |
+
], dim=3).detach()
|
223 |
+
img_cat = ((img_cat + 1) / 2)
|
224 |
+
img_cat = img_cat.clamp_(0, 1)
|
225 |
+
save_image(img_cat, save_path, nrow=1, padding=4)
|
226 |
+
|
227 |
+
def optimize_parameters(self):
|
228 |
+
self.guidance_encoder.train()
|
229 |
+
self.index_decoder.train()
|
230 |
+
|
231 |
+
self.feature_enc = self.guidance_encoder(self.feature_t)
|
232 |
+
self.memory_logits_list = self.index_decoder(self.feature_enc)
|
233 |
+
|
234 |
+
loss = 0
|
235 |
+
for i in range(18):
|
236 |
+
loss += self.loss_func(
|
237 |
+
self.memory_logits_list[i],
|
238 |
+
self.gt_indices_list[i],
|
239 |
+
ignore_index=-1)
|
240 |
+
|
241 |
+
self.optimizer.zero_grad()
|
242 |
+
loss.backward()
|
243 |
+
self.optimizer.step()
|
244 |
+
|
245 |
+
self.log_dict['loss_total'] = loss
|
246 |
+
|
247 |
+
def inference(self, data_loader, save_dir):
|
248 |
+
self.guidance_encoder.eval()
|
249 |
+
self.index_decoder.eval()
|
250 |
+
|
251 |
+
acc = 0
|
252 |
+
num = 0
|
253 |
+
|
254 |
+
for _, data in enumerate(data_loader):
|
255 |
+
self.feed_data(data)
|
256 |
+
img_name = data['img_name']
|
257 |
+
|
258 |
+
num += self.image.size(0)
|
259 |
+
|
260 |
+
texture_mask_flatten = self.texture_tokens.view(-1)
|
261 |
+
min_encodings_indices_list = [
|
262 |
+
torch.full(
|
263 |
+
texture_mask_flatten.size(),
|
264 |
+
fill_value=-1,
|
265 |
+
dtype=torch.long,
|
266 |
+
device=texture_mask_flatten.device) for _ in range(18)
|
267 |
+
]
|
268 |
+
with torch.no_grad():
|
269 |
+
self.feature_enc = self.guidance_encoder(self.feature_t)
|
270 |
+
memory_logits_list = self.index_decoder(self.feature_enc)
|
271 |
+
# memory_indices_pred = memory_logits.argmax(dim=1)
|
272 |
+
batch_acc = 0
|
273 |
+
for codebook_idx, memory_logits in enumerate(memory_logits_list):
|
274 |
+
region_of_interest = texture_mask_flatten == codebook_idx
|
275 |
+
if torch.sum(region_of_interest) > 0:
|
276 |
+
memory_indices_pred = memory_logits.argmax(dim=1).view(-1)
|
277 |
+
batch_acc += torch.sum(
|
278 |
+
memory_indices_pred[region_of_interest] ==
|
279 |
+
self.gt_indices_list[codebook_idx].view(
|
280 |
+
-1)[region_of_interest])
|
281 |
+
memory_indices_pred = memory_indices_pred
|
282 |
+
min_encodings_indices_list[codebook_idx][
|
283 |
+
region_of_interest] = memory_indices_pred[
|
284 |
+
region_of_interest]
|
285 |
+
min_encodings_indices_return_list = [
|
286 |
+
min_encodings_indices.view(self.gt_indices_list[0].size())
|
287 |
+
for min_encodings_indices in min_encodings_indices_list
|
288 |
+
]
|
289 |
+
batch_acc = batch_acc / self.gt_indices_list[codebook_idx].numel(
|
290 |
+
) * self.image.size(0)
|
291 |
+
acc += batch_acc
|
292 |
+
self.get_vis(min_encodings_indices_return_list,
|
293 |
+
self.gt_indices_list, self.texture_mask,
|
294 |
+
f'{save_dir}/{img_name[0]}')
|
295 |
+
|
296 |
+
self.guidance_encoder.train()
|
297 |
+
self.index_decoder.train()
|
298 |
+
return (acc / num).item()
|
299 |
+
|
300 |
+
def load_network(self):
|
301 |
+
checkpoint = torch.load(self.opt['pretrained_models'])
|
302 |
+
self.guidance_encoder.load_state_dict(
|
303 |
+
checkpoint['guidance_encoder'], strict=True)
|
304 |
+
self.guidance_encoder.eval()
|
305 |
+
|
306 |
+
self.index_decoder.load_state_dict(
|
307 |
+
checkpoint['index_decoder'], strict=True)
|
308 |
+
self.index_decoder.eval()
|
309 |
+
|
310 |
+
def save_network(self, save_path):
|
311 |
+
"""Save networks.
|
312 |
+
|
313 |
+
Args:
|
314 |
+
net (nn.Module): Network to be saved.
|
315 |
+
net_label (str): Network label.
|
316 |
+
current_iter (int): Current iter number.
|
317 |
+
"""
|
318 |
+
|
319 |
+
save_dict = {}
|
320 |
+
save_dict['guidance_encoder'] = self.guidance_encoder.state_dict()
|
321 |
+
save_dict['index_decoder'] = self.index_decoder.state_dict()
|
322 |
+
|
323 |
+
torch.save(save_dict, save_path)
|
324 |
+
|
325 |
+
def update_learning_rate(self, epoch):
|
326 |
+
"""Update learning rate.
|
327 |
+
|
328 |
+
Args:
|
329 |
+
current_iter (int): Current iteration.
|
330 |
+
warmup_iter (int): Warmup iter numbers. -1 for no warmup.
|
331 |
+
Default: -1.
|
332 |
+
"""
|
333 |
+
lr = self.optimizer.param_groups[0]['lr']
|
334 |
+
|
335 |
+
if self.opt['lr_decay'] == 'step':
|
336 |
+
lr = self.opt['lr'] * (
|
337 |
+
self.opt['gamma']**(epoch // self.opt['step']))
|
338 |
+
elif self.opt['lr_decay'] == 'cos':
|
339 |
+
lr = self.opt['lr'] * (
|
340 |
+
1 + math.cos(math.pi * epoch / self.opt['num_epochs'])) / 2
|
341 |
+
elif self.opt['lr_decay'] == 'linear':
|
342 |
+
lr = self.opt['lr'] * (1 - epoch / self.opt['num_epochs'])
|
343 |
+
elif self.opt['lr_decay'] == 'linear2exp':
|
344 |
+
if epoch < self.opt['turning_point'] + 1:
|
345 |
+
# learning rate decay as 95%
|
346 |
+
# at the turning point (1 / 95% = 1.0526)
|
347 |
+
lr = self.opt['lr'] * (
|
348 |
+
1 - epoch / int(self.opt['turning_point'] * 1.0526))
|
349 |
+
else:
|
350 |
+
lr *= self.opt['gamma']
|
351 |
+
elif self.opt['lr_decay'] == 'schedule':
|
352 |
+
if epoch in self.opt['schedule']:
|
353 |
+
lr *= self.opt['gamma']
|
354 |
+
else:
|
355 |
+
raise ValueError('Unknown lr mode {}'.format(self.opt['lr_decay']))
|
356 |
+
# set learning rate
|
357 |
+
for param_group in self.optimizer.param_groups:
|
358 |
+
param_group['lr'] = lr
|
359 |
+
|
360 |
+
return lr
|
361 |
+
|
362 |
+
def get_current_log(self):
|
363 |
+
return self.log_dict
|
Text2Human/models/hierarchy_vqgan_model.py
ADDED
@@ -0,0 +1,374 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import sys
|
3 |
+
from collections import OrderedDict
|
4 |
+
|
5 |
+
sys.path.append('..')
|
6 |
+
import lpips
|
7 |
+
import torch
|
8 |
+
import torch.nn.functional as F
|
9 |
+
from torchvision.utils import save_image
|
10 |
+
|
11 |
+
from models.archs.vqgan_arch import (Decoder, DecoderRes, Discriminator,
|
12 |
+
Encoder,
|
13 |
+
VectorQuantizerSpatialTextureAware,
|
14 |
+
VectorQuantizerTexture)
|
15 |
+
from models.losses.vqgan_loss import (DiffAugment, adopt_weight,
|
16 |
+
calculate_adaptive_weight, hinge_d_loss)
|
17 |
+
|
18 |
+
|
19 |
+
class HierarchyVQSpatialTextureAwareModel():
|
20 |
+
|
21 |
+
def __init__(self, opt):
|
22 |
+
self.opt = opt
|
23 |
+
self.device = torch.device('cuda')
|
24 |
+
self.top_encoder = Encoder(
|
25 |
+
ch=opt['top_ch'],
|
26 |
+
num_res_blocks=opt['top_num_res_blocks'],
|
27 |
+
attn_resolutions=opt['top_attn_resolutions'],
|
28 |
+
ch_mult=opt['top_ch_mult'],
|
29 |
+
in_channels=opt['top_in_channels'],
|
30 |
+
resolution=opt['top_resolution'],
|
31 |
+
z_channels=opt['top_z_channels'],
|
32 |
+
double_z=opt['top_double_z'],
|
33 |
+
dropout=opt['top_dropout']).to(self.device)
|
34 |
+
self.decoder = Decoder(
|
35 |
+
in_channels=opt['top_in_channels'],
|
36 |
+
resolution=opt['top_resolution'],
|
37 |
+
z_channels=opt['top_z_channels'],
|
38 |
+
ch=opt['top_ch'],
|
39 |
+
out_ch=opt['top_out_ch'],
|
40 |
+
num_res_blocks=opt['top_num_res_blocks'],
|
41 |
+
attn_resolutions=opt['top_attn_resolutions'],
|
42 |
+
ch_mult=opt['top_ch_mult'],
|
43 |
+
dropout=opt['top_dropout'],
|
44 |
+
resamp_with_conv=True,
|
45 |
+
give_pre_end=False).to(self.device)
|
46 |
+
self.top_quantize = VectorQuantizerTexture(
|
47 |
+
1024, opt['embed_dim'], beta=0.25).to(self.device)
|
48 |
+
self.top_quant_conv = torch.nn.Conv2d(opt["top_z_channels"],
|
49 |
+
opt['embed_dim'],
|
50 |
+
1).to(self.device)
|
51 |
+
self.top_post_quant_conv = torch.nn.Conv2d(opt['embed_dim'],
|
52 |
+
opt["top_z_channels"],
|
53 |
+
1).to(self.device)
|
54 |
+
self.load_top_pretrain_models()
|
55 |
+
|
56 |
+
self.bot_encoder = Encoder(
|
57 |
+
ch=opt['bot_ch'],
|
58 |
+
num_res_blocks=opt['bot_num_res_blocks'],
|
59 |
+
attn_resolutions=opt['bot_attn_resolutions'],
|
60 |
+
ch_mult=opt['bot_ch_mult'],
|
61 |
+
in_channels=opt['bot_in_channels'],
|
62 |
+
resolution=opt['bot_resolution'],
|
63 |
+
z_channels=opt['bot_z_channels'],
|
64 |
+
double_z=opt['bot_double_z'],
|
65 |
+
dropout=opt['bot_dropout']).to(self.device)
|
66 |
+
self.bot_decoder_res = DecoderRes(
|
67 |
+
in_channels=opt['bot_in_channels'],
|
68 |
+
resolution=opt['bot_resolution'],
|
69 |
+
z_channels=opt['bot_z_channels'],
|
70 |
+
ch=opt['bot_ch'],
|
71 |
+
num_res_blocks=opt['bot_num_res_blocks'],
|
72 |
+
ch_mult=opt['bot_ch_mult'],
|
73 |
+
dropout=opt['bot_dropout'],
|
74 |
+
give_pre_end=False).to(self.device)
|
75 |
+
self.bot_quantize = VectorQuantizerSpatialTextureAware(
|
76 |
+
opt['bot_n_embed'],
|
77 |
+
opt['embed_dim'],
|
78 |
+
beta=0.25,
|
79 |
+
spatial_size=opt['codebook_spatial_size']).to(self.device)
|
80 |
+
self.bot_quant_conv = torch.nn.Conv2d(opt["bot_z_channels"],
|
81 |
+
opt['embed_dim'],
|
82 |
+
1).to(self.device)
|
83 |
+
self.bot_post_quant_conv = torch.nn.Conv2d(opt['embed_dim'],
|
84 |
+
opt["bot_z_channels"],
|
85 |
+
1).to(self.device)
|
86 |
+
|
87 |
+
self.disc = Discriminator(
|
88 |
+
opt['n_channels'], opt['ndf'],
|
89 |
+
n_layers=opt['disc_layers']).to(self.device)
|
90 |
+
self.perceptual = lpips.LPIPS(net="vgg").to(self.device)
|
91 |
+
self.perceptual_weight = opt['perceptual_weight']
|
92 |
+
self.disc_start_step = opt['disc_start_step']
|
93 |
+
self.disc_weight_max = opt['disc_weight_max']
|
94 |
+
self.diff_aug = opt['diff_aug']
|
95 |
+
self.policy = "color,translation"
|
96 |
+
|
97 |
+
self.load_discriminator_models()
|
98 |
+
|
99 |
+
self.disc.train()
|
100 |
+
|
101 |
+
self.fix_decoder = opt['fix_decoder']
|
102 |
+
|
103 |
+
self.init_training_settings()
|
104 |
+
|
105 |
+
def load_top_pretrain_models(self):
|
106 |
+
# load pretrained vqgan for segmentation mask
|
107 |
+
top_vae_checkpoint = torch.load(self.opt['top_vae_path'])
|
108 |
+
self.top_encoder.load_state_dict(
|
109 |
+
top_vae_checkpoint['encoder'], strict=True)
|
110 |
+
self.decoder.load_state_dict(
|
111 |
+
top_vae_checkpoint['decoder'], strict=True)
|
112 |
+
self.top_quantize.load_state_dict(
|
113 |
+
top_vae_checkpoint['quantize'], strict=True)
|
114 |
+
self.top_quant_conv.load_state_dict(
|
115 |
+
top_vae_checkpoint['quant_conv'], strict=True)
|
116 |
+
self.top_post_quant_conv.load_state_dict(
|
117 |
+
top_vae_checkpoint['post_quant_conv'], strict=True)
|
118 |
+
self.top_encoder.eval()
|
119 |
+
self.top_quantize.eval()
|
120 |
+
self.top_quant_conv.eval()
|
121 |
+
self.top_post_quant_conv.eval()
|
122 |
+
|
123 |
+
def init_training_settings(self):
|
124 |
+
self.log_dict = OrderedDict()
|
125 |
+
self.configure_optimizers()
|
126 |
+
|
127 |
+
def configure_optimizers(self):
|
128 |
+
optim_params = []
|
129 |
+
for v in self.bot_encoder.parameters():
|
130 |
+
if v.requires_grad:
|
131 |
+
optim_params.append(v)
|
132 |
+
for v in self.bot_decoder_res.parameters():
|
133 |
+
if v.requires_grad:
|
134 |
+
optim_params.append(v)
|
135 |
+
for v in self.bot_quantize.parameters():
|
136 |
+
if v.requires_grad:
|
137 |
+
optim_params.append(v)
|
138 |
+
for v in self.bot_quant_conv.parameters():
|
139 |
+
if v.requires_grad:
|
140 |
+
optim_params.append(v)
|
141 |
+
for v in self.bot_post_quant_conv.parameters():
|
142 |
+
if v.requires_grad:
|
143 |
+
optim_params.append(v)
|
144 |
+
if not self.fix_decoder:
|
145 |
+
for name, v in self.decoder.named_parameters():
|
146 |
+
if v.requires_grad:
|
147 |
+
if 'up.0' in name:
|
148 |
+
optim_params.append(v)
|
149 |
+
if 'up.1' in name:
|
150 |
+
optim_params.append(v)
|
151 |
+
if 'up.2' in name:
|
152 |
+
optim_params.append(v)
|
153 |
+
if 'up.3' in name:
|
154 |
+
optim_params.append(v)
|
155 |
+
|
156 |
+
self.optimizer = torch.optim.Adam(optim_params, lr=self.opt['lr'])
|
157 |
+
|
158 |
+
self.disc_optimizer = torch.optim.Adam(
|
159 |
+
self.disc.parameters(), lr=self.opt['lr'])
|
160 |
+
|
161 |
+
def load_discriminator_models(self):
|
162 |
+
# load pretrained vqgan for segmentation mask
|
163 |
+
top_vae_checkpoint = torch.load(self.opt['top_vae_path'])
|
164 |
+
self.disc.load_state_dict(
|
165 |
+
top_vae_checkpoint['discriminator'], strict=True)
|
166 |
+
|
167 |
+
def save_network(self, save_path):
|
168 |
+
"""Save networks.
|
169 |
+
"""
|
170 |
+
|
171 |
+
save_dict = {}
|
172 |
+
save_dict['bot_encoder'] = self.bot_encoder.state_dict()
|
173 |
+
save_dict['bot_decoder_res'] = self.bot_decoder_res.state_dict()
|
174 |
+
save_dict['decoder'] = self.decoder.state_dict()
|
175 |
+
save_dict['bot_quantize'] = self.bot_quantize.state_dict()
|
176 |
+
save_dict['bot_quant_conv'] = self.bot_quant_conv.state_dict()
|
177 |
+
save_dict['bot_post_quant_conv'] = self.bot_post_quant_conv.state_dict(
|
178 |
+
)
|
179 |
+
save_dict['discriminator'] = self.disc.state_dict()
|
180 |
+
torch.save(save_dict, save_path)
|
181 |
+
|
182 |
+
def load_network(self):
|
183 |
+
checkpoint = torch.load(self.opt['pretrained_models'])
|
184 |
+
self.bot_encoder.load_state_dict(
|
185 |
+
checkpoint['bot_encoder'], strict=True)
|
186 |
+
self.bot_decoder_res.load_state_dict(
|
187 |
+
checkpoint['bot_decoder_res'], strict=True)
|
188 |
+
self.decoder.load_state_dict(checkpoint['decoder'], strict=True)
|
189 |
+
self.bot_quantize.load_state_dict(
|
190 |
+
checkpoint['bot_quantize'], strict=True)
|
191 |
+
self.bot_quant_conv.load_state_dict(
|
192 |
+
checkpoint['bot_quant_conv'], strict=True)
|
193 |
+
self.bot_post_quant_conv.load_state_dict(
|
194 |
+
checkpoint['bot_post_quant_conv'], strict=True)
|
195 |
+
|
196 |
+
def optimize_parameters(self, data, step):
|
197 |
+
self.bot_encoder.train()
|
198 |
+
self.bot_decoder_res.train()
|
199 |
+
if not self.fix_decoder:
|
200 |
+
self.decoder.train()
|
201 |
+
self.bot_quantize.train()
|
202 |
+
self.bot_quant_conv.train()
|
203 |
+
self.bot_post_quant_conv.train()
|
204 |
+
|
205 |
+
loss, d_loss = self.training_step(data, step)
|
206 |
+
self.optimizer.zero_grad()
|
207 |
+
loss.backward()
|
208 |
+
self.optimizer.step()
|
209 |
+
|
210 |
+
if step > self.disc_start_step:
|
211 |
+
self.disc_optimizer.zero_grad()
|
212 |
+
d_loss.backward()
|
213 |
+
self.disc_optimizer.step()
|
214 |
+
|
215 |
+
def top_encode(self, x, mask):
|
216 |
+
h = self.top_encoder(x)
|
217 |
+
h = self.top_quant_conv(h)
|
218 |
+
quant, _, _ = self.top_quantize(h, mask)
|
219 |
+
quant = self.top_post_quant_conv(quant)
|
220 |
+
return quant
|
221 |
+
|
222 |
+
def bot_encode(self, x, mask):
|
223 |
+
h = self.bot_encoder(x)
|
224 |
+
h = self.bot_quant_conv(h)
|
225 |
+
quant, emb_loss, info = self.bot_quantize(h, mask)
|
226 |
+
quant = self.bot_post_quant_conv(quant)
|
227 |
+
bot_dec_res = self.bot_decoder_res(quant)
|
228 |
+
return bot_dec_res, emb_loss, info
|
229 |
+
|
230 |
+
def decode(self, quant_top, bot_dec_res):
|
231 |
+
dec = self.decoder(quant_top, bot_h=bot_dec_res)
|
232 |
+
return dec
|
233 |
+
|
234 |
+
def forward_step(self, input, mask):
|
235 |
+
with torch.no_grad():
|
236 |
+
quant_top = self.top_encode(input, mask)
|
237 |
+
bot_dec_res, diff, _ = self.bot_encode(input, mask)
|
238 |
+
dec = self.decode(quant_top, bot_dec_res)
|
239 |
+
return dec, diff
|
240 |
+
|
241 |
+
def feed_data(self, data):
|
242 |
+
x = data['image'].float().to(self.device)
|
243 |
+
mask = data['texture_mask'].float().to(self.device)
|
244 |
+
|
245 |
+
return x, mask
|
246 |
+
|
247 |
+
def training_step(self, data, step):
|
248 |
+
x, mask = self.feed_data(data)
|
249 |
+
xrec, codebook_loss = self.forward_step(x, mask)
|
250 |
+
|
251 |
+
# get recon/perceptual loss
|
252 |
+
recon_loss = torch.abs(x.contiguous() - xrec.contiguous())
|
253 |
+
p_loss = self.perceptual(x.contiguous(), xrec.contiguous())
|
254 |
+
nll_loss = recon_loss + self.perceptual_weight * p_loss
|
255 |
+
nll_loss = torch.mean(nll_loss)
|
256 |
+
|
257 |
+
# augment for input to discriminator
|
258 |
+
if self.diff_aug:
|
259 |
+
xrec = DiffAugment(xrec, policy=self.policy)
|
260 |
+
|
261 |
+
# update generator
|
262 |
+
logits_fake = self.disc(xrec)
|
263 |
+
g_loss = -torch.mean(logits_fake)
|
264 |
+
last_layer = self.decoder.conv_out.weight
|
265 |
+
d_weight = calculate_adaptive_weight(nll_loss, g_loss, last_layer,
|
266 |
+
self.disc_weight_max)
|
267 |
+
d_weight *= adopt_weight(1, step, self.disc_start_step)
|
268 |
+
loss = nll_loss + d_weight * g_loss + codebook_loss
|
269 |
+
|
270 |
+
self.log_dict["loss"] = loss
|
271 |
+
self.log_dict["l1"] = recon_loss.mean().item()
|
272 |
+
self.log_dict["perceptual"] = p_loss.mean().item()
|
273 |
+
self.log_dict["nll_loss"] = nll_loss.item()
|
274 |
+
self.log_dict["g_loss"] = g_loss.item()
|
275 |
+
self.log_dict["d_weight"] = d_weight
|
276 |
+
self.log_dict["codebook_loss"] = codebook_loss.item()
|
277 |
+
|
278 |
+
if step > self.disc_start_step:
|
279 |
+
if self.diff_aug:
|
280 |
+
logits_real = self.disc(
|
281 |
+
DiffAugment(x.contiguous().detach(), policy=self.policy))
|
282 |
+
else:
|
283 |
+
logits_real = self.disc(x.contiguous().detach())
|
284 |
+
logits_fake = self.disc(xrec.contiguous().detach(
|
285 |
+
)) # detach so that generator isn"t also updated
|
286 |
+
d_loss = hinge_d_loss(logits_real, logits_fake)
|
287 |
+
self.log_dict["d_loss"] = d_loss
|
288 |
+
else:
|
289 |
+
d_loss = None
|
290 |
+
|
291 |
+
return loss, d_loss
|
292 |
+
|
293 |
+
@torch.no_grad()
|
294 |
+
def inference(self, data_loader, save_dir):
|
295 |
+
self.bot_encoder.eval()
|
296 |
+
self.bot_decoder_res.eval()
|
297 |
+
self.decoder.eval()
|
298 |
+
self.bot_quantize.eval()
|
299 |
+
self.bot_quant_conv.eval()
|
300 |
+
self.bot_post_quant_conv.eval()
|
301 |
+
|
302 |
+
loss_total = 0
|
303 |
+
num = 0
|
304 |
+
|
305 |
+
for _, data in enumerate(data_loader):
|
306 |
+
img_name = data['img_name'][0]
|
307 |
+
x, mask = self.feed_data(data)
|
308 |
+
xrec, _ = self.forward_step(x, mask)
|
309 |
+
|
310 |
+
recon_loss = torch.abs(x.contiguous() - xrec.contiguous())
|
311 |
+
p_loss = self.perceptual(x.contiguous(), xrec.contiguous())
|
312 |
+
nll_loss = recon_loss + self.perceptual_weight * p_loss
|
313 |
+
nll_loss = torch.mean(nll_loss)
|
314 |
+
loss_total += nll_loss
|
315 |
+
|
316 |
+
num += x.size(0)
|
317 |
+
|
318 |
+
if x.shape[1] > 3:
|
319 |
+
# colorize with random projection
|
320 |
+
assert xrec.shape[1] > 3
|
321 |
+
# convert logits to indices
|
322 |
+
xrec = torch.argmax(xrec, dim=1, keepdim=True)
|
323 |
+
xrec = F.one_hot(xrec, num_classes=x.shape[1])
|
324 |
+
xrec = xrec.squeeze(1).permute(0, 3, 1, 2).float()
|
325 |
+
x = self.to_rgb(x)
|
326 |
+
xrec = self.to_rgb(xrec)
|
327 |
+
|
328 |
+
img_cat = torch.cat([x, xrec], dim=3).detach()
|
329 |
+
img_cat = ((img_cat + 1) / 2)
|
330 |
+
img_cat = img_cat.clamp_(0, 1)
|
331 |
+
save_image(
|
332 |
+
img_cat, f'{save_dir}/{img_name}.png', nrow=1, padding=4)
|
333 |
+
|
334 |
+
return (loss_total / num).item()
|
335 |
+
|
336 |
+
def get_current_log(self):
|
337 |
+
return self.log_dict
|
338 |
+
|
339 |
+
def update_learning_rate(self, epoch):
|
340 |
+
"""Update learning rate.
|
341 |
+
|
342 |
+
Args:
|
343 |
+
current_iter (int): Current iteration.
|
344 |
+
warmup_iter (int): Warmup iter numbers. -1 for no warmup.
|
345 |
+
Default: -1.
|
346 |
+
"""
|
347 |
+
lr = self.optimizer.param_groups[0]['lr']
|
348 |
+
|
349 |
+
if self.opt['lr_decay'] == 'step':
|
350 |
+
lr = self.opt['lr'] * (
|
351 |
+
self.opt['gamma']**(epoch // self.opt['step']))
|
352 |
+
elif self.opt['lr_decay'] == 'cos':
|
353 |
+
lr = self.opt['lr'] * (
|
354 |
+
1 + math.cos(math.pi * epoch / self.opt['num_epochs'])) / 2
|
355 |
+
elif self.opt['lr_decay'] == 'linear':
|
356 |
+
lr = self.opt['lr'] * (1 - epoch / self.opt['num_epochs'])
|
357 |
+
elif self.opt['lr_decay'] == 'linear2exp':
|
358 |
+
if epoch < self.opt['turning_point'] + 1:
|
359 |
+
# learning rate decay as 95%
|
360 |
+
# at the turning point (1 / 95% = 1.0526)
|
361 |
+
lr = self.opt['lr'] * (
|
362 |
+
1 - epoch / int(self.opt['turning_point'] * 1.0526))
|
363 |
+
else:
|
364 |
+
lr *= self.opt['gamma']
|
365 |
+
elif self.opt['lr_decay'] == 'schedule':
|
366 |
+
if epoch in self.opt['schedule']:
|
367 |
+
lr *= self.opt['gamma']
|
368 |
+
else:
|
369 |
+
raise ValueError('Unknown lr mode {}'.format(self.opt['lr_decay']))
|
370 |
+
# set learning rate
|
371 |
+
for param_group in self.optimizer.param_groups:
|
372 |
+
param_group['lr'] = lr
|
373 |
+
|
374 |
+
return lr
|
Text2Human/models/losses/__init__.py
ADDED
File without changes
|
Text2Human/models/losses/accuracy.py
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
def accuracy(pred, target, topk=1, thresh=None):
|
2 |
+
"""Calculate accuracy according to the prediction and target.
|
3 |
+
|
4 |
+
Args:
|
5 |
+
pred (torch.Tensor): The model prediction, shape (N, num_class, ...)
|
6 |
+
target (torch.Tensor): The target of each prediction, shape (N, , ...)
|
7 |
+
topk (int | tuple[int], optional): If the predictions in ``topk``
|
8 |
+
matches the target, the predictions will be regarded as
|
9 |
+
correct ones. Defaults to 1.
|
10 |
+
thresh (float, optional): If not None, predictions with scores under
|
11 |
+
this threshold are considered incorrect. Default to None.
|
12 |
+
|
13 |
+
Returns:
|
14 |
+
float | tuple[float]: If the input ``topk`` is a single integer,
|
15 |
+
the function will return a single float as accuracy. If
|
16 |
+
``topk`` is a tuple containing multiple integers, the
|
17 |
+
function will return a tuple containing accuracies of
|
18 |
+
each ``topk`` number.
|
19 |
+
"""
|
20 |
+
assert isinstance(topk, (int, tuple))
|
21 |
+
if isinstance(topk, int):
|
22 |
+
topk = (topk, )
|
23 |
+
return_single = True
|
24 |
+
else:
|
25 |
+
return_single = False
|
26 |
+
|
27 |
+
maxk = max(topk)
|
28 |
+
if pred.size(0) == 0:
|
29 |
+
accu = [pred.new_tensor(0.) for i in range(len(topk))]
|
30 |
+
return accu[0] if return_single else accu
|
31 |
+
assert pred.ndim == target.ndim + 1
|
32 |
+
assert pred.size(0) == target.size(0)
|
33 |
+
assert maxk <= pred.size(1), \
|
34 |
+
f'maxk {maxk} exceeds pred dimension {pred.size(1)}'
|
35 |
+
pred_value, pred_label = pred.topk(maxk, dim=1)
|
36 |
+
# transpose to shape (maxk, N, ...)
|
37 |
+
pred_label = pred_label.transpose(0, 1)
|
38 |
+
correct = pred_label.eq(target.unsqueeze(0).expand_as(pred_label))
|
39 |
+
if thresh is not None:
|
40 |
+
# Only prediction values larger than thresh are counted as correct
|
41 |
+
correct = correct & (pred_value > thresh).t()
|
42 |
+
res = []
|
43 |
+
for k in topk:
|
44 |
+
correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
|
45 |
+
res.append(correct_k.mul_(100.0 / target.numel()))
|
46 |
+
return res[0] if return_single else res
|
Text2Human/models/losses/cross_entropy_loss.py
ADDED
@@ -0,0 +1,246 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
|
6 |
+
def reduce_loss(loss, reduction):
|
7 |
+
"""Reduce loss as specified.
|
8 |
+
|
9 |
+
Args:
|
10 |
+
loss (Tensor): Elementwise loss tensor.
|
11 |
+
reduction (str): Options are "none", "mean" and "sum".
|
12 |
+
|
13 |
+
Return:
|
14 |
+
Tensor: Reduced loss tensor.
|
15 |
+
"""
|
16 |
+
reduction_enum = F._Reduction.get_enum(reduction)
|
17 |
+
# none: 0, elementwise_mean:1, sum: 2
|
18 |
+
if reduction_enum == 0:
|
19 |
+
return loss
|
20 |
+
elif reduction_enum == 1:
|
21 |
+
return loss.mean()
|
22 |
+
elif reduction_enum == 2:
|
23 |
+
return loss.sum()
|
24 |
+
|
25 |
+
|
26 |
+
def weight_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None):
|
27 |
+
"""Apply element-wise weight and reduce loss.
|
28 |
+
|
29 |
+
Args:
|
30 |
+
loss (Tensor): Element-wise loss.
|
31 |
+
weight (Tensor): Element-wise weights.
|
32 |
+
reduction (str): Same as built-in losses of PyTorch.
|
33 |
+
avg_factor (float): Avarage factor when computing the mean of losses.
|
34 |
+
|
35 |
+
Returns:
|
36 |
+
Tensor: Processed loss values.
|
37 |
+
"""
|
38 |
+
# if weight is specified, apply element-wise weight
|
39 |
+
if weight is not None:
|
40 |
+
assert weight.dim() == loss.dim()
|
41 |
+
if weight.dim() > 1:
|
42 |
+
assert weight.size(1) == 1 or weight.size(1) == loss.size(1)
|
43 |
+
loss = loss * weight
|
44 |
+
|
45 |
+
# if avg_factor is not specified, just reduce the loss
|
46 |
+
if avg_factor is None:
|
47 |
+
loss = reduce_loss(loss, reduction)
|
48 |
+
else:
|
49 |
+
# if reduction is mean, then average the loss by avg_factor
|
50 |
+
if reduction == 'mean':
|
51 |
+
loss = loss.sum() / avg_factor
|
52 |
+
# if reduction is 'none', then do nothing, otherwise raise an error
|
53 |
+
elif reduction != 'none':
|
54 |
+
raise ValueError('avg_factor can not be used with reduction="sum"')
|
55 |
+
return loss
|
56 |
+
|
57 |
+
|
58 |
+
def cross_entropy(pred,
|
59 |
+
label,
|
60 |
+
weight=None,
|
61 |
+
class_weight=None,
|
62 |
+
reduction='mean',
|
63 |
+
avg_factor=None,
|
64 |
+
ignore_index=-100):
|
65 |
+
"""The wrapper function for :func:`F.cross_entropy`"""
|
66 |
+
# class_weight is a manual rescaling weight given to each class.
|
67 |
+
# If given, has to be a Tensor of size C element-wise losses
|
68 |
+
loss = F.cross_entropy(
|
69 |
+
pred,
|
70 |
+
label,
|
71 |
+
weight=class_weight,
|
72 |
+
reduction='none',
|
73 |
+
ignore_index=ignore_index)
|
74 |
+
|
75 |
+
# apply weights and do the reduction
|
76 |
+
if weight is not None:
|
77 |
+
weight = weight.float()
|
78 |
+
loss = weight_reduce_loss(
|
79 |
+
loss, weight=weight, reduction=reduction, avg_factor=avg_factor)
|
80 |
+
|
81 |
+
return loss
|
82 |
+
|
83 |
+
|
84 |
+
def _expand_onehot_labels(labels, label_weights, target_shape, ignore_index):
|
85 |
+
"""Expand onehot labels to match the size of prediction."""
|
86 |
+
bin_labels = labels.new_zeros(target_shape)
|
87 |
+
valid_mask = (labels >= 0) & (labels != ignore_index)
|
88 |
+
inds = torch.nonzero(valid_mask, as_tuple=True)
|
89 |
+
|
90 |
+
if inds[0].numel() > 0:
|
91 |
+
if labels.dim() == 3:
|
92 |
+
bin_labels[inds[0], labels[valid_mask], inds[1], inds[2]] = 1
|
93 |
+
else:
|
94 |
+
bin_labels[inds[0], labels[valid_mask]] = 1
|
95 |
+
|
96 |
+
valid_mask = valid_mask.unsqueeze(1).expand(target_shape).float()
|
97 |
+
if label_weights is None:
|
98 |
+
bin_label_weights = valid_mask
|
99 |
+
else:
|
100 |
+
bin_label_weights = label_weights.unsqueeze(1).expand(target_shape)
|
101 |
+
bin_label_weights *= valid_mask
|
102 |
+
|
103 |
+
return bin_labels, bin_label_weights
|
104 |
+
|
105 |
+
|
106 |
+
def binary_cross_entropy(pred,
|
107 |
+
label,
|
108 |
+
weight=None,
|
109 |
+
reduction='mean',
|
110 |
+
avg_factor=None,
|
111 |
+
class_weight=None,
|
112 |
+
ignore_index=255):
|
113 |
+
"""Calculate the binary CrossEntropy loss.
|
114 |
+
|
115 |
+
Args:
|
116 |
+
pred (torch.Tensor): The prediction with shape (N, 1).
|
117 |
+
label (torch.Tensor): The learning label of the prediction.
|
118 |
+
weight (torch.Tensor, optional): Sample-wise loss weight.
|
119 |
+
reduction (str, optional): The method used to reduce the loss.
|
120 |
+
Options are "none", "mean" and "sum".
|
121 |
+
avg_factor (int, optional): Average factor that is used to average
|
122 |
+
the loss. Defaults to None.
|
123 |
+
class_weight (list[float], optional): The weight for each class.
|
124 |
+
ignore_index (int | None): The label index to be ignored. Default: 255
|
125 |
+
|
126 |
+
Returns:
|
127 |
+
torch.Tensor: The calculated loss
|
128 |
+
"""
|
129 |
+
if pred.dim() != label.dim():
|
130 |
+
assert (pred.dim() == 2 and label.dim() == 1) or (
|
131 |
+
pred.dim() == 4 and label.dim() == 3), \
|
132 |
+
'Only pred shape [N, C], label shape [N] or pred shape [N, C, ' \
|
133 |
+
'H, W], label shape [N, H, W] are supported'
|
134 |
+
label, weight = _expand_onehot_labels(label, weight, pred.shape,
|
135 |
+
ignore_index)
|
136 |
+
|
137 |
+
# weighted element-wise losses
|
138 |
+
if weight is not None:
|
139 |
+
weight = weight.float()
|
140 |
+
loss = F.binary_cross_entropy_with_logits(
|
141 |
+
pred, label.float(), pos_weight=class_weight, reduction='none')
|
142 |
+
# do the reduction for the weighted loss
|
143 |
+
loss = weight_reduce_loss(
|
144 |
+
loss, weight, reduction=reduction, avg_factor=avg_factor)
|
145 |
+
|
146 |
+
return loss
|
147 |
+
|
148 |
+
|
149 |
+
def mask_cross_entropy(pred,
|
150 |
+
target,
|
151 |
+
label,
|
152 |
+
reduction='mean',
|
153 |
+
avg_factor=None,
|
154 |
+
class_weight=None,
|
155 |
+
ignore_index=None):
|
156 |
+
"""Calculate the CrossEntropy loss for masks.
|
157 |
+
|
158 |
+
Args:
|
159 |
+
pred (torch.Tensor): The prediction with shape (N, C), C is the number
|
160 |
+
of classes.
|
161 |
+
target (torch.Tensor): The learning label of the prediction.
|
162 |
+
label (torch.Tensor): ``label`` indicates the class label of the mask'
|
163 |
+
corresponding object. This will be used to select the mask in the
|
164 |
+
of the class which the object belongs to when the mask prediction
|
165 |
+
if not class-agnostic.
|
166 |
+
reduction (str, optional): The method used to reduce the loss.
|
167 |
+
Options are "none", "mean" and "sum".
|
168 |
+
avg_factor (int, optional): Average factor that is used to average
|
169 |
+
the loss. Defaults to None.
|
170 |
+
class_weight (list[float], optional): The weight for each class.
|
171 |
+
ignore_index (None): Placeholder, to be consistent with other loss.
|
172 |
+
Default: None.
|
173 |
+
|
174 |
+
Returns:
|
175 |
+
torch.Tensor: The calculated loss
|
176 |
+
"""
|
177 |
+
assert ignore_index is None, 'BCE loss does not support ignore_index'
|
178 |
+
# TODO: handle these two reserved arguments
|
179 |
+
assert reduction == 'mean' and avg_factor is None
|
180 |
+
num_rois = pred.size()[0]
|
181 |
+
inds = torch.arange(0, num_rois, dtype=torch.long, device=pred.device)
|
182 |
+
pred_slice = pred[inds, label].squeeze(1)
|
183 |
+
return F.binary_cross_entropy_with_logits(
|
184 |
+
pred_slice, target, weight=class_weight, reduction='mean')[None]
|
185 |
+
|
186 |
+
|
187 |
+
class CrossEntropyLoss(nn.Module):
|
188 |
+
"""CrossEntropyLoss.
|
189 |
+
|
190 |
+
Args:
|
191 |
+
use_sigmoid (bool, optional): Whether the prediction uses sigmoid
|
192 |
+
of softmax. Defaults to False.
|
193 |
+
use_mask (bool, optional): Whether to use mask cross entropy loss.
|
194 |
+
Defaults to False.
|
195 |
+
reduction (str, optional): . Defaults to 'mean'.
|
196 |
+
Options are "none", "mean" and "sum".
|
197 |
+
class_weight (list[float], optional): Weight of each class.
|
198 |
+
Defaults to None.
|
199 |
+
loss_weight (float, optional): Weight of the loss. Defaults to 1.0.
|
200 |
+
"""
|
201 |
+
|
202 |
+
def __init__(self,
|
203 |
+
use_sigmoid=False,
|
204 |
+
use_mask=False,
|
205 |
+
reduction='mean',
|
206 |
+
class_weight=None,
|
207 |
+
loss_weight=1.0):
|
208 |
+
super(CrossEntropyLoss, self).__init__()
|
209 |
+
assert (use_sigmoid is False) or (use_mask is False)
|
210 |
+
self.use_sigmoid = use_sigmoid
|
211 |
+
self.use_mask = use_mask
|
212 |
+
self.reduction = reduction
|
213 |
+
self.loss_weight = loss_weight
|
214 |
+
self.class_weight = class_weight
|
215 |
+
|
216 |
+
if self.use_sigmoid:
|
217 |
+
self.cls_criterion = binary_cross_entropy
|
218 |
+
elif self.use_mask:
|
219 |
+
self.cls_criterion = mask_cross_entropy
|
220 |
+
else:
|
221 |
+
self.cls_criterion = cross_entropy
|
222 |
+
|
223 |
+
def forward(self,
|
224 |
+
cls_score,
|
225 |
+
label,
|
226 |
+
weight=None,
|
227 |
+
avg_factor=None,
|
228 |
+
reduction_override=None,
|
229 |
+
**kwargs):
|
230 |
+
"""Forward function."""
|
231 |
+
assert reduction_override in (None, 'none', 'mean', 'sum')
|
232 |
+
reduction = (
|
233 |
+
reduction_override if reduction_override else self.reduction)
|
234 |
+
if self.class_weight is not None:
|
235 |
+
class_weight = cls_score.new_tensor(self.class_weight)
|
236 |
+
else:
|
237 |
+
class_weight = None
|
238 |
+
loss_cls = self.loss_weight * self.cls_criterion(
|
239 |
+
cls_score,
|
240 |
+
label,
|
241 |
+
weight,
|
242 |
+
class_weight=class_weight,
|
243 |
+
reduction=reduction,
|
244 |
+
avg_factor=avg_factor,
|
245 |
+
**kwargs)
|
246 |
+
return loss_cls
|
Text2Human/models/losses/segmentation_loss.py
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
import torch.nn.functional as F
|
3 |
+
|
4 |
+
|
5 |
+
class BCELoss(nn.Module):
|
6 |
+
|
7 |
+
def forward(self, prediction, target):
|
8 |
+
loss = F.binary_cross_entropy_with_logits(prediction, target)
|
9 |
+
return loss, {}
|
10 |
+
|
11 |
+
|
12 |
+
class BCELossWithQuant(nn.Module):
|
13 |
+
|
14 |
+
def __init__(self, codebook_weight=1.):
|
15 |
+
super().__init__()
|
16 |
+
self.codebook_weight = codebook_weight
|
17 |
+
|
18 |
+
def forward(self, qloss, target, prediction, split):
|
19 |
+
bce_loss = F.binary_cross_entropy_with_logits(prediction, target)
|
20 |
+
loss = bce_loss + self.codebook_weight * qloss
|
21 |
+
return loss, {
|
22 |
+
"{}/total_loss".format(split): loss.clone().detach().mean(),
|
23 |
+
"{}/bce_loss".format(split): bce_loss.detach().mean(),
|
24 |
+
"{}/quant_loss".format(split): qloss.detach().mean()
|
25 |
+
}
|
Text2Human/models/losses/vqgan_loss.py
ADDED
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
|
4 |
+
|
5 |
+
def calculate_adaptive_weight(recon_loss, g_loss, last_layer, disc_weight_max):
|
6 |
+
recon_grads = torch.autograd.grad(
|
7 |
+
recon_loss, last_layer, retain_graph=True)[0]
|
8 |
+
g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
|
9 |
+
|
10 |
+
d_weight = torch.norm(recon_grads) / (torch.norm(g_grads) + 1e-4)
|
11 |
+
d_weight = torch.clamp(d_weight, 0.0, disc_weight_max).detach()
|
12 |
+
return d_weight
|
13 |
+
|
14 |
+
|
15 |
+
def adopt_weight(weight, global_step, threshold=0, value=0.):
|
16 |
+
if global_step < threshold:
|
17 |
+
weight = value
|
18 |
+
return weight
|
19 |
+
|
20 |
+
|
21 |
+
@torch.jit.script
|
22 |
+
def hinge_d_loss(logits_real, logits_fake):
|
23 |
+
loss_real = torch.mean(F.relu(1. - logits_real))
|
24 |
+
loss_fake = torch.mean(F.relu(1. + logits_fake))
|
25 |
+
d_loss = 0.5 * (loss_real + loss_fake)
|
26 |
+
return d_loss
|
27 |
+
|
28 |
+
|
29 |
+
def DiffAugment(x, policy='', channels_first=True):
|
30 |
+
if policy:
|
31 |
+
if not channels_first:
|
32 |
+
x = x.permute(0, 3, 1, 2)
|
33 |
+
for p in policy.split(','):
|
34 |
+
for f in AUGMENT_FNS[p]:
|
35 |
+
x = f(x)
|
36 |
+
if not channels_first:
|
37 |
+
x = x.permute(0, 2, 3, 1)
|
38 |
+
x = x.contiguous()
|
39 |
+
return x
|
40 |
+
|
41 |
+
|
42 |
+
def rand_brightness(x):
|
43 |
+
x = x + (
|
44 |
+
torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5)
|
45 |
+
return x
|
46 |
+
|
47 |
+
|
48 |
+
def rand_saturation(x):
|
49 |
+
x_mean = x.mean(dim=1, keepdim=True)
|
50 |
+
x = (x - x_mean) * (torch.rand(
|
51 |
+
x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) * 2) + x_mean
|
52 |
+
return x
|
53 |
+
|
54 |
+
|
55 |
+
def rand_contrast(x):
|
56 |
+
x_mean = x.mean(dim=[1, 2, 3], keepdim=True)
|
57 |
+
x = (x - x_mean) * (torch.rand(
|
58 |
+
x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5) + x_mean
|
59 |
+
return x
|
60 |
+
|
61 |
+
|
62 |
+
def rand_translation(x, ratio=0.125):
|
63 |
+
shift_x, shift_y = int(x.size(2) * ratio +
|
64 |
+
0.5), int(x.size(3) * ratio + 0.5)
|
65 |
+
translation_x = torch.randint(
|
66 |
+
-shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device)
|
67 |
+
translation_y = torch.randint(
|
68 |
+
-shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device)
|
69 |
+
grid_batch, grid_x, grid_y = torch.meshgrid(
|
70 |
+
torch.arange(x.size(0), dtype=torch.long, device=x.device),
|
71 |
+
torch.arange(x.size(2), dtype=torch.long, device=x.device),
|
72 |
+
torch.arange(x.size(3), dtype=torch.long, device=x.device),
|
73 |
+
)
|
74 |
+
grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1)
|
75 |
+
grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1)
|
76 |
+
x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0])
|
77 |
+
x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x,
|
78 |
+
grid_y].permute(0, 3, 1, 2)
|
79 |
+
return x
|
80 |
+
|
81 |
+
|
82 |
+
def rand_cutout(x, ratio=0.5):
|
83 |
+
cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
|
84 |
+
offset_x = torch.randint(
|
85 |
+
0,
|
86 |
+
x.size(2) + (1 - cutout_size[0] % 2),
|
87 |
+
size=[x.size(0), 1, 1],
|
88 |
+
device=x.device)
|
89 |
+
offset_y = torch.randint(
|
90 |
+
0,
|
91 |
+
x.size(3) + (1 - cutout_size[1] % 2),
|
92 |
+
size=[x.size(0), 1, 1],
|
93 |
+
device=x.device)
|
94 |
+
grid_batch, grid_x, grid_y = torch.meshgrid(
|
95 |
+
torch.arange(x.size(0), dtype=torch.long, device=x.device),
|
96 |
+
torch.arange(cutout_size[0], dtype=torch.long, device=x.device),
|
97 |
+
torch.arange(cutout_size[1], dtype=torch.long, device=x.device),
|
98 |
+
)
|
99 |
+
grid_x = torch.clamp(
|
100 |
+
grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 1)
|
101 |
+
grid_y = torch.clamp(
|
102 |
+
grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1)
|
103 |
+
mask = torch.ones(
|
104 |
+
x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device)
|
105 |
+
mask[grid_batch, grid_x, grid_y] = 0
|
106 |
+
x = x * mask.unsqueeze(1)
|
107 |
+
return x
|
108 |
+
|
109 |
+
|
110 |
+
AUGMENT_FNS = {
|
111 |
+
'color': [rand_brightness, rand_saturation, rand_contrast],
|
112 |
+
'translation': [rand_translation],
|
113 |
+
'cutout': [rand_cutout],
|
114 |
+
}
|
Text2Human/models/parsing_gen_model.py
ADDED
@@ -0,0 +1,220 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import math
|
3 |
+
from collections import OrderedDict
|
4 |
+
|
5 |
+
import mmcv
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
from torchvision.utils import save_image
|
9 |
+
|
10 |
+
from models.archs.fcn_arch import FCNHead
|
11 |
+
from models.archs.shape_attr_embedding_arch import ShapeAttrEmbedding
|
12 |
+
from models.archs.unet_arch import ShapeUNet
|
13 |
+
from models.losses.accuracy import accuracy
|
14 |
+
from models.losses.cross_entropy_loss import CrossEntropyLoss
|
15 |
+
|
16 |
+
logger = logging.getLogger('base')
|
17 |
+
|
18 |
+
|
19 |
+
class ParsingGenModel():
|
20 |
+
"""Paring Generation model.
|
21 |
+
"""
|
22 |
+
|
23 |
+
def __init__(self, opt):
|
24 |
+
self.opt = opt
|
25 |
+
self.device = torch.device('cuda')
|
26 |
+
self.is_train = opt['is_train']
|
27 |
+
|
28 |
+
self.attr_embedder = ShapeAttrEmbedding(
|
29 |
+
dim=opt['embedder_dim'],
|
30 |
+
out_dim=opt['embedder_out_dim'],
|
31 |
+
cls_num_list=opt['attr_class_num']).to(self.device)
|
32 |
+
self.parsing_encoder = ShapeUNet(
|
33 |
+
in_channels=opt['encoder_in_channels']).to(self.device)
|
34 |
+
self.parsing_decoder = FCNHead(
|
35 |
+
in_channels=opt['fc_in_channels'],
|
36 |
+
in_index=opt['fc_in_index'],
|
37 |
+
channels=opt['fc_channels'],
|
38 |
+
num_convs=opt['fc_num_convs'],
|
39 |
+
concat_input=opt['fc_concat_input'],
|
40 |
+
dropout_ratio=opt['fc_dropout_ratio'],
|
41 |
+
num_classes=opt['fc_num_classes'],
|
42 |
+
align_corners=opt['fc_align_corners'],
|
43 |
+
).to(self.device)
|
44 |
+
|
45 |
+
self.init_training_settings()
|
46 |
+
|
47 |
+
self.palette = [[0, 0, 0], [255, 250, 250], [220, 220, 220],
|
48 |
+
[250, 235, 215], [255, 250, 205], [211, 211, 211],
|
49 |
+
[70, 130, 180], [127, 255, 212], [0, 100, 0],
|
50 |
+
[50, 205, 50], [255, 255, 0], [245, 222, 179],
|
51 |
+
[255, 140, 0], [255, 0, 0], [16, 78, 139],
|
52 |
+
[144, 238, 144], [50, 205, 174], [50, 155, 250],
|
53 |
+
[160, 140, 88], [213, 140, 88], [90, 140, 90],
|
54 |
+
[185, 210, 205], [130, 165, 180], [225, 141, 151]]
|
55 |
+
|
56 |
+
def init_training_settings(self):
|
57 |
+
optim_params = []
|
58 |
+
for v in self.attr_embedder.parameters():
|
59 |
+
if v.requires_grad:
|
60 |
+
optim_params.append(v)
|
61 |
+
for v in self.parsing_encoder.parameters():
|
62 |
+
if v.requires_grad:
|
63 |
+
optim_params.append(v)
|
64 |
+
for v in self.parsing_decoder.parameters():
|
65 |
+
if v.requires_grad:
|
66 |
+
optim_params.append(v)
|
67 |
+
# set up optimizers
|
68 |
+
self.optimizer = torch.optim.Adam(
|
69 |
+
optim_params,
|
70 |
+
self.opt['lr'],
|
71 |
+
weight_decay=self.opt['weight_decay'])
|
72 |
+
self.log_dict = OrderedDict()
|
73 |
+
self.entropy_loss = CrossEntropyLoss().to(self.device)
|
74 |
+
|
75 |
+
def feed_data(self, data):
|
76 |
+
self.pose = data['densepose'].to(self.device)
|
77 |
+
self.attr = data['attr'].to(self.device)
|
78 |
+
self.segm = data['segm'].to(self.device)
|
79 |
+
|
80 |
+
def optimize_parameters(self):
|
81 |
+
self.attr_embedder.train()
|
82 |
+
self.parsing_encoder.train()
|
83 |
+
self.parsing_decoder.train()
|
84 |
+
|
85 |
+
self.attr_embedding = self.attr_embedder(self.attr)
|
86 |
+
self.pose_enc = self.parsing_encoder(self.pose, self.attr_embedding)
|
87 |
+
self.seg_logits = self.parsing_decoder(self.pose_enc)
|
88 |
+
|
89 |
+
loss = self.entropy_loss(self.seg_logits, self.segm)
|
90 |
+
|
91 |
+
self.optimizer.zero_grad()
|
92 |
+
loss.backward()
|
93 |
+
self.optimizer.step()
|
94 |
+
|
95 |
+
self.log_dict['loss_total'] = loss
|
96 |
+
|
97 |
+
def get_vis(self, save_path):
|
98 |
+
img_cat = torch.cat([
|
99 |
+
self.pose,
|
100 |
+
self.segm,
|
101 |
+
], dim=3).detach()
|
102 |
+
img_cat = ((img_cat + 1) / 2)
|
103 |
+
|
104 |
+
img_cat = img_cat.clamp_(0, 1)
|
105 |
+
|
106 |
+
save_image(img_cat, save_path, nrow=1, padding=4)
|
107 |
+
|
108 |
+
def inference(self, data_loader, save_dir):
|
109 |
+
self.attr_embedder.eval()
|
110 |
+
self.parsing_encoder.eval()
|
111 |
+
self.parsing_decoder.eval()
|
112 |
+
|
113 |
+
acc = 0
|
114 |
+
num = 0
|
115 |
+
|
116 |
+
for _, data in enumerate(data_loader):
|
117 |
+
pose = data['densepose'].to(self.device)
|
118 |
+
attr = data['attr'].to(self.device)
|
119 |
+
segm = data['segm'].to(self.device)
|
120 |
+
img_name = data['img_name']
|
121 |
+
|
122 |
+
num += pose.size(0)
|
123 |
+
with torch.no_grad():
|
124 |
+
attr_embedding = self.attr_embedder(attr)
|
125 |
+
pose_enc = self.parsing_encoder(pose, attr_embedding)
|
126 |
+
seg_logits = self.parsing_decoder(pose_enc)
|
127 |
+
seg_pred = seg_logits.argmax(dim=1)
|
128 |
+
acc += accuracy(seg_logits, segm)
|
129 |
+
palette_label = self.palette_result(segm.cpu().numpy())
|
130 |
+
palette_pred = self.palette_result(seg_pred.cpu().numpy())
|
131 |
+
pose_numpy = ((pose[0] + 1) / 2. * 255.).expand(
|
132 |
+
3,
|
133 |
+
pose[0].size(1),
|
134 |
+
pose[0].size(2),
|
135 |
+
).cpu().numpy().clip(0, 255).astype(np.uint8).transpose(1, 2, 0)
|
136 |
+
concat_result = np.concatenate(
|
137 |
+
(pose_numpy, palette_pred, palette_label), axis=1)
|
138 |
+
mmcv.imwrite(concat_result, f'{save_dir}/{img_name[0]}')
|
139 |
+
|
140 |
+
self.attr_embedder.train()
|
141 |
+
self.parsing_encoder.train()
|
142 |
+
self.parsing_decoder.train()
|
143 |
+
return (acc / num).item()
|
144 |
+
|
145 |
+
def get_current_log(self):
|
146 |
+
return self.log_dict
|
147 |
+
|
148 |
+
def update_learning_rate(self, epoch):
|
149 |
+
"""Update learning rate.
|
150 |
+
|
151 |
+
Args:
|
152 |
+
current_iter (int): Current iteration.
|
153 |
+
warmup_iter (int): Warmup iter numbers. -1 for no warmup.
|
154 |
+
Default: -1.
|
155 |
+
"""
|
156 |
+
lr = self.optimizer.param_groups[0]['lr']
|
157 |
+
|
158 |
+
if self.opt['lr_decay'] == 'step':
|
159 |
+
lr = self.opt['lr'] * (
|
160 |
+
self.opt['gamma']**(epoch // self.opt['step']))
|
161 |
+
elif self.opt['lr_decay'] == 'cos':
|
162 |
+
lr = self.opt['lr'] * (
|
163 |
+
1 + math.cos(math.pi * epoch / self.opt['num_epochs'])) / 2
|
164 |
+
elif self.opt['lr_decay'] == 'linear':
|
165 |
+
lr = self.opt['lr'] * (1 - epoch / self.opt['num_epochs'])
|
166 |
+
elif self.opt['lr_decay'] == 'linear2exp':
|
167 |
+
if epoch < self.opt['turning_point'] + 1:
|
168 |
+
# learning rate decay as 95%
|
169 |
+
# at the turning point (1 / 95% = 1.0526)
|
170 |
+
lr = self.opt['lr'] * (
|
171 |
+
1 - epoch / int(self.opt['turning_point'] * 1.0526))
|
172 |
+
else:
|
173 |
+
lr *= self.opt['gamma']
|
174 |
+
elif self.opt['lr_decay'] == 'schedule':
|
175 |
+
if epoch in self.opt['schedule']:
|
176 |
+
lr *= self.opt['gamma']
|
177 |
+
else:
|
178 |
+
raise ValueError('Unknown lr mode {}'.format(self.opt['lr_decay']))
|
179 |
+
# set learning rate
|
180 |
+
for param_group in self.optimizer.param_groups:
|
181 |
+
param_group['lr'] = lr
|
182 |
+
|
183 |
+
return lr
|
184 |
+
|
185 |
+
def save_network(self, save_path):
|
186 |
+
"""Save networks.
|
187 |
+
"""
|
188 |
+
|
189 |
+
save_dict = {}
|
190 |
+
save_dict['embedder'] = self.attr_embedder.state_dict()
|
191 |
+
save_dict['encoder'] = self.parsing_encoder.state_dict()
|
192 |
+
save_dict['decoder'] = self.parsing_decoder.state_dict()
|
193 |
+
|
194 |
+
torch.save(save_dict, save_path)
|
195 |
+
|
196 |
+
def load_network(self):
|
197 |
+
checkpoint = torch.load(self.opt['pretrained_parsing_gen'])
|
198 |
+
|
199 |
+
self.attr_embedder.load_state_dict(checkpoint['embedder'], strict=True)
|
200 |
+
self.attr_embedder.eval()
|
201 |
+
|
202 |
+
self.parsing_encoder.load_state_dict(
|
203 |
+
checkpoint['encoder'], strict=True)
|
204 |
+
self.parsing_encoder.eval()
|
205 |
+
|
206 |
+
self.parsing_decoder.load_state_dict(
|
207 |
+
checkpoint['decoder'], strict=True)
|
208 |
+
self.parsing_decoder.eval()
|
209 |
+
|
210 |
+
def palette_result(self, result):
|
211 |
+
seg = result[0]
|
212 |
+
palette = np.array(self.palette)
|
213 |
+
assert palette.shape[1] == 3
|
214 |
+
assert len(palette.shape) == 2
|
215 |
+
color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
|
216 |
+
for label, color in enumerate(palette):
|
217 |
+
color_seg[seg == label, :] = color
|
218 |
+
# convert to BGR
|
219 |
+
color_seg = color_seg[..., ::-1]
|
220 |
+
return color_seg
|
Text2Human/models/sample_model.py
ADDED
@@ -0,0 +1,500 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
import torch.distributions as dists
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from torchvision.utils import save_image
|
8 |
+
|
9 |
+
from models.archs.fcn_arch import FCNHead, MultiHeadFCNHead
|
10 |
+
from models.archs.shape_attr_embedding_arch import ShapeAttrEmbedding
|
11 |
+
from models.archs.transformer_arch import TransformerMultiHead
|
12 |
+
from models.archs.unet_arch import ShapeUNet, UNet
|
13 |
+
from models.archs.vqgan_arch import (Decoder, DecoderRes, Encoder,
|
14 |
+
VectorQuantizer,
|
15 |
+
VectorQuantizerSpatialTextureAware,
|
16 |
+
VectorQuantizerTexture)
|
17 |
+
|
18 |
+
logger = logging.getLogger('base')
|
19 |
+
|
20 |
+
|
21 |
+
class BaseSampleModel():
|
22 |
+
"""Base Model"""
|
23 |
+
|
24 |
+
def __init__(self, opt):
|
25 |
+
self.opt = opt
|
26 |
+
self.device = torch.device(opt['device'])
|
27 |
+
|
28 |
+
# hierarchical VQVAE
|
29 |
+
self.decoder = Decoder(
|
30 |
+
in_channels=opt['top_in_channels'],
|
31 |
+
resolution=opt['top_resolution'],
|
32 |
+
z_channels=opt['top_z_channels'],
|
33 |
+
ch=opt['top_ch'],
|
34 |
+
out_ch=opt['top_out_ch'],
|
35 |
+
num_res_blocks=opt['top_num_res_blocks'],
|
36 |
+
attn_resolutions=opt['top_attn_resolutions'],
|
37 |
+
ch_mult=opt['top_ch_mult'],
|
38 |
+
dropout=opt['top_dropout'],
|
39 |
+
resamp_with_conv=True,
|
40 |
+
give_pre_end=False).to(self.device)
|
41 |
+
self.top_quantize = VectorQuantizerTexture(
|
42 |
+
1024, opt['embed_dim'], beta=0.25).to(self.device)
|
43 |
+
self.top_post_quant_conv = torch.nn.Conv2d(opt['embed_dim'],
|
44 |
+
opt["top_z_channels"],
|
45 |
+
1).to(self.device)
|
46 |
+
self.load_top_pretrain_models()
|
47 |
+
|
48 |
+
self.bot_decoder_res = DecoderRes(
|
49 |
+
in_channels=opt['bot_in_channels'],
|
50 |
+
resolution=opt['bot_resolution'],
|
51 |
+
z_channels=opt['bot_z_channels'],
|
52 |
+
ch=opt['bot_ch'],
|
53 |
+
num_res_blocks=opt['bot_num_res_blocks'],
|
54 |
+
ch_mult=opt['bot_ch_mult'],
|
55 |
+
dropout=opt['bot_dropout'],
|
56 |
+
give_pre_end=False).to(self.device)
|
57 |
+
self.bot_quantize = VectorQuantizerSpatialTextureAware(
|
58 |
+
opt['bot_n_embed'],
|
59 |
+
opt['embed_dim'],
|
60 |
+
beta=0.25,
|
61 |
+
spatial_size=opt['bot_codebook_spatial_size']).to(self.device)
|
62 |
+
self.bot_post_quant_conv = torch.nn.Conv2d(opt['embed_dim'],
|
63 |
+
opt["bot_z_channels"],
|
64 |
+
1).to(self.device)
|
65 |
+
self.load_bot_pretrain_network()
|
66 |
+
|
67 |
+
# top -> bot prediction
|
68 |
+
self.index_pred_guidance_encoder = UNet(
|
69 |
+
in_channels=opt['index_pred_encoder_in_channels']).to(self.device)
|
70 |
+
self.index_pred_decoder = MultiHeadFCNHead(
|
71 |
+
in_channels=opt['index_pred_fc_in_channels'],
|
72 |
+
in_index=opt['index_pred_fc_in_index'],
|
73 |
+
channels=opt['index_pred_fc_channels'],
|
74 |
+
num_convs=opt['index_pred_fc_num_convs'],
|
75 |
+
concat_input=opt['index_pred_fc_concat_input'],
|
76 |
+
dropout_ratio=opt['index_pred_fc_dropout_ratio'],
|
77 |
+
num_classes=opt['index_pred_fc_num_classes'],
|
78 |
+
align_corners=opt['index_pred_fc_align_corners'],
|
79 |
+
num_head=18).to(self.device)
|
80 |
+
self.load_index_pred_network()
|
81 |
+
|
82 |
+
# VAE for segmentation mask
|
83 |
+
self.segm_encoder = Encoder(
|
84 |
+
ch=opt['segm_ch'],
|
85 |
+
num_res_blocks=opt['segm_num_res_blocks'],
|
86 |
+
attn_resolutions=opt['segm_attn_resolutions'],
|
87 |
+
ch_mult=opt['segm_ch_mult'],
|
88 |
+
in_channels=opt['segm_in_channels'],
|
89 |
+
resolution=opt['segm_resolution'],
|
90 |
+
z_channels=opt['segm_z_channels'],
|
91 |
+
double_z=opt['segm_double_z'],
|
92 |
+
dropout=opt['segm_dropout']).to(self.device)
|
93 |
+
self.segm_quantizer = VectorQuantizer(
|
94 |
+
opt['segm_n_embed'],
|
95 |
+
opt['segm_embed_dim'],
|
96 |
+
beta=0.25,
|
97 |
+
sane_index_shape=True).to(self.device)
|
98 |
+
self.segm_quant_conv = torch.nn.Conv2d(opt["segm_z_channels"],
|
99 |
+
opt['segm_embed_dim'],
|
100 |
+
1).to(self.device)
|
101 |
+
self.load_pretrained_segm_token()
|
102 |
+
|
103 |
+
# define sampler
|
104 |
+
self.sampler_fn = TransformerMultiHead(
|
105 |
+
codebook_size=opt['codebook_size'],
|
106 |
+
segm_codebook_size=opt['segm_codebook_size'],
|
107 |
+
texture_codebook_size=opt['texture_codebook_size'],
|
108 |
+
bert_n_emb=opt['bert_n_emb'],
|
109 |
+
bert_n_layers=opt['bert_n_layers'],
|
110 |
+
bert_n_head=opt['bert_n_head'],
|
111 |
+
block_size=opt['block_size'],
|
112 |
+
latent_shape=opt['latent_shape'],
|
113 |
+
embd_pdrop=opt['embd_pdrop'],
|
114 |
+
resid_pdrop=opt['resid_pdrop'],
|
115 |
+
attn_pdrop=opt['attn_pdrop'],
|
116 |
+
num_head=opt['num_head']).to(self.device)
|
117 |
+
self.load_sampler_pretrained_network()
|
118 |
+
|
119 |
+
self.shape = tuple(opt['latent_shape'])
|
120 |
+
|
121 |
+
self.mask_id = opt['codebook_size']
|
122 |
+
self.sample_steps = opt['sample_steps']
|
123 |
+
|
124 |
+
def load_top_pretrain_models(self):
|
125 |
+
# load pretrained vqgan
|
126 |
+
top_vae_checkpoint = torch.load(self.opt['top_vae_path'],map_location=torch.device('cpu'))
|
127 |
+
|
128 |
+
self.decoder.load_state_dict(
|
129 |
+
top_vae_checkpoint['decoder'], strict=True)
|
130 |
+
self.top_quantize.load_state_dict(
|
131 |
+
top_vae_checkpoint['quantize'], strict=True)
|
132 |
+
self.top_post_quant_conv.load_state_dict(
|
133 |
+
top_vae_checkpoint['post_quant_conv'], strict=True)
|
134 |
+
|
135 |
+
self.decoder.eval()
|
136 |
+
self.top_quantize.eval()
|
137 |
+
self.top_post_quant_conv.eval()
|
138 |
+
|
139 |
+
def load_bot_pretrain_network(self):
|
140 |
+
checkpoint = torch.load(self.opt['bot_vae_path'],map_location=torch.device('cpu'))
|
141 |
+
self.bot_decoder_res.load_state_dict(
|
142 |
+
checkpoint['bot_decoder_res'], strict=True)
|
143 |
+
self.decoder.load_state_dict(checkpoint['decoder'], strict=True)
|
144 |
+
self.bot_quantize.load_state_dict(
|
145 |
+
checkpoint['bot_quantize'], strict=True)
|
146 |
+
self.bot_post_quant_conv.load_state_dict(
|
147 |
+
checkpoint['bot_post_quant_conv'], strict=True)
|
148 |
+
|
149 |
+
self.bot_decoder_res.eval()
|
150 |
+
self.decoder.eval()
|
151 |
+
self.bot_quantize.eval()
|
152 |
+
self.bot_post_quant_conv.eval()
|
153 |
+
|
154 |
+
def load_pretrained_segm_token(self):
|
155 |
+
# load pretrained vqgan for segmentation mask
|
156 |
+
segm_token_checkpoint = torch.load(self.opt['segm_token_path'],map_location=torch.device('cpu'))
|
157 |
+
self.segm_encoder.load_state_dict(
|
158 |
+
segm_token_checkpoint['encoder'], strict=True)
|
159 |
+
self.segm_quantizer.load_state_dict(
|
160 |
+
segm_token_checkpoint['quantize'], strict=True)
|
161 |
+
self.segm_quant_conv.load_state_dict(
|
162 |
+
segm_token_checkpoint['quant_conv'], strict=True)
|
163 |
+
|
164 |
+
self.segm_encoder.eval()
|
165 |
+
self.segm_quantizer.eval()
|
166 |
+
self.segm_quant_conv.eval()
|
167 |
+
|
168 |
+
def load_index_pred_network(self):
|
169 |
+
checkpoint = torch.load(self.opt['pretrained_index_network'],map_location=torch.device('cpu'))
|
170 |
+
self.index_pred_guidance_encoder.load_state_dict(
|
171 |
+
checkpoint['guidance_encoder'], strict=True)
|
172 |
+
self.index_pred_decoder.load_state_dict(
|
173 |
+
checkpoint['index_decoder'], strict=True)
|
174 |
+
|
175 |
+
self.index_pred_guidance_encoder.eval()
|
176 |
+
self.index_pred_decoder.eval()
|
177 |
+
|
178 |
+
def load_sampler_pretrained_network(self):
|
179 |
+
checkpoint = torch.load(self.opt['pretrained_sampler'],map_location=torch.device('cpu'))
|
180 |
+
self.sampler_fn.load_state_dict(checkpoint, strict=True)
|
181 |
+
self.sampler_fn.eval()
|
182 |
+
|
183 |
+
def bot_index_prediction(self, feature_top, texture_mask):
|
184 |
+
self.index_pred_guidance_encoder.eval()
|
185 |
+
self.index_pred_decoder.eval()
|
186 |
+
|
187 |
+
texture_tokens = F.interpolate(
|
188 |
+
texture_mask, (32, 16), mode='nearest').view(self.batch_size,
|
189 |
+
-1).long()
|
190 |
+
|
191 |
+
texture_mask_flatten = texture_tokens.view(-1)
|
192 |
+
min_encodings_indices_list = [
|
193 |
+
torch.full(
|
194 |
+
texture_mask_flatten.size(),
|
195 |
+
fill_value=-1,
|
196 |
+
dtype=torch.long,
|
197 |
+
device=texture_mask_flatten.device) for _ in range(18)
|
198 |
+
]
|
199 |
+
with torch.no_grad():
|
200 |
+
feature_enc = self.index_pred_guidance_encoder(feature_top)
|
201 |
+
memory_logits_list = self.index_pred_decoder(feature_enc)
|
202 |
+
for codebook_idx, memory_logits in enumerate(memory_logits_list):
|
203 |
+
region_of_interest = texture_mask_flatten == codebook_idx
|
204 |
+
if torch.sum(region_of_interest) > 0:
|
205 |
+
memory_indices_pred = memory_logits.argmax(dim=1).view(-1)
|
206 |
+
memory_indices_pred = memory_indices_pred
|
207 |
+
min_encodings_indices_list[codebook_idx][
|
208 |
+
region_of_interest] = memory_indices_pred[
|
209 |
+
region_of_interest]
|
210 |
+
min_encodings_indices_return_list = [
|
211 |
+
min_encodings_indices.view((1, 32, 16))
|
212 |
+
for min_encodings_indices in min_encodings_indices_list
|
213 |
+
]
|
214 |
+
|
215 |
+
return min_encodings_indices_return_list
|
216 |
+
|
217 |
+
def sample_and_refine(self, save_dir=None, img_name=None):
|
218 |
+
# sample 32x16 features indices
|
219 |
+
sampled_top_indices_list = self.sample_fn(
|
220 |
+
temp=1, sample_steps=self.sample_steps)
|
221 |
+
|
222 |
+
for sample_idx in range(self.batch_size):
|
223 |
+
sample_indices = [
|
224 |
+
sampled_indices_cur[sample_idx:sample_idx + 1]
|
225 |
+
for sampled_indices_cur in sampled_top_indices_list
|
226 |
+
]
|
227 |
+
top_quant = self.top_quantize.get_codebook_entry(
|
228 |
+
sample_indices, self.texture_mask[sample_idx:sample_idx + 1],
|
229 |
+
(sample_indices[0].size(0), self.shape[0], self.shape[1],
|
230 |
+
self.opt["top_z_channels"]))
|
231 |
+
|
232 |
+
top_quant = self.top_post_quant_conv(top_quant)
|
233 |
+
|
234 |
+
bot_indices_list = self.bot_index_prediction(
|
235 |
+
top_quant, self.texture_mask[sample_idx:sample_idx + 1])
|
236 |
+
|
237 |
+
quant_bot = self.bot_quantize.get_codebook_entry(
|
238 |
+
bot_indices_list, self.texture_mask[sample_idx:sample_idx + 1],
|
239 |
+
(bot_indices_list[0].size(0), bot_indices_list[0].size(1),
|
240 |
+
bot_indices_list[0].size(2),
|
241 |
+
self.opt["bot_z_channels"])) #.permute(0, 3, 1, 2)
|
242 |
+
quant_bot = self.bot_post_quant_conv(quant_bot)
|
243 |
+
bot_dec_res = self.bot_decoder_res(quant_bot)
|
244 |
+
|
245 |
+
dec = self.decoder(top_quant, bot_h=bot_dec_res)
|
246 |
+
|
247 |
+
dec = ((dec + 1) / 2)
|
248 |
+
dec = dec.clamp_(0, 1)
|
249 |
+
if save_dir is None and img_name is None:
|
250 |
+
return dec
|
251 |
+
else:
|
252 |
+
save_image(
|
253 |
+
dec,
|
254 |
+
f'{save_dir}/{img_name[sample_idx]}',
|
255 |
+
nrow=1,
|
256 |
+
padding=4)
|
257 |
+
|
258 |
+
def sample_fn(self, temp=1.0, sample_steps=None):
|
259 |
+
self.sampler_fn.eval()
|
260 |
+
|
261 |
+
x_t = torch.ones((self.batch_size, np.prod(self.shape)),
|
262 |
+
device=self.device).long() * self.mask_id
|
263 |
+
unmasked = torch.zeros_like(x_t, device=self.device).bool()
|
264 |
+
sample_steps = list(range(1, sample_steps + 1))
|
265 |
+
|
266 |
+
texture_tokens = F.interpolate(
|
267 |
+
self.texture_mask, (32, 16),
|
268 |
+
mode='nearest').view(self.batch_size, -1).long()
|
269 |
+
|
270 |
+
texture_mask_flatten = texture_tokens.view(-1)
|
271 |
+
|
272 |
+
# min_encodings_indices_list would be used to visualize the image
|
273 |
+
min_encodings_indices_list = [
|
274 |
+
torch.full(
|
275 |
+
texture_mask_flatten.size(),
|
276 |
+
fill_value=-1,
|
277 |
+
dtype=torch.long,
|
278 |
+
device=texture_mask_flatten.device) for _ in range(18)
|
279 |
+
]
|
280 |
+
|
281 |
+
for t in reversed(sample_steps):
|
282 |
+
t = torch.full((self.batch_size, ),
|
283 |
+
t,
|
284 |
+
device=self.device,
|
285 |
+
dtype=torch.long)
|
286 |
+
|
287 |
+
# where to unmask
|
288 |
+
changes = torch.rand(
|
289 |
+
x_t.shape, device=self.device) < 1 / t.float().unsqueeze(-1)
|
290 |
+
# don't unmask somewhere already unmasked
|
291 |
+
changes = torch.bitwise_xor(changes,
|
292 |
+
torch.bitwise_and(changes, unmasked))
|
293 |
+
# update mask with changes
|
294 |
+
unmasked = torch.bitwise_or(unmasked, changes)
|
295 |
+
|
296 |
+
x_0_logits_list = self.sampler_fn(
|
297 |
+
x_t, self.segm_tokens, texture_tokens, t=t)
|
298 |
+
|
299 |
+
changes_flatten = changes.view(-1)
|
300 |
+
ori_shape = x_t.shape # [b, h*w]
|
301 |
+
x_t = x_t.view(-1) # [b*h*w]
|
302 |
+
for codebook_idx, x_0_logits in enumerate(x_0_logits_list):
|
303 |
+
if torch.sum(texture_mask_flatten[changes_flatten] ==
|
304 |
+
codebook_idx) > 0:
|
305 |
+
# scale by temperature
|
306 |
+
x_0_logits = x_0_logits / temp
|
307 |
+
x_0_dist = dists.Categorical(logits=x_0_logits)
|
308 |
+
x_0_hat = x_0_dist.sample().long()
|
309 |
+
x_0_hat = x_0_hat.view(-1)
|
310 |
+
|
311 |
+
# only replace the changed indices with corresponding codebook_idx
|
312 |
+
changes_segm = torch.bitwise_and(
|
313 |
+
changes_flatten, texture_mask_flatten == codebook_idx)
|
314 |
+
|
315 |
+
# x_t would be the input to the transformer, so the index range should be continual one
|
316 |
+
x_t[changes_segm] = x_0_hat[
|
317 |
+
changes_segm] + 1024 * codebook_idx
|
318 |
+
min_encodings_indices_list[codebook_idx][
|
319 |
+
changes_segm] = x_0_hat[changes_segm]
|
320 |
+
|
321 |
+
x_t = x_t.view(ori_shape) # [b, h*w]
|
322 |
+
|
323 |
+
min_encodings_indices_return_list = [
|
324 |
+
min_encodings_indices.view(ori_shape)
|
325 |
+
for min_encodings_indices in min_encodings_indices_list
|
326 |
+
]
|
327 |
+
|
328 |
+
self.sampler_fn.train()
|
329 |
+
|
330 |
+
return min_encodings_indices_return_list
|
331 |
+
|
332 |
+
@torch.no_grad()
|
333 |
+
def get_quantized_segm(self, segm):
|
334 |
+
segm_one_hot = F.one_hot(
|
335 |
+
segm.squeeze(1).long(),
|
336 |
+
num_classes=self.opt['segm_num_segm_classes']).permute(
|
337 |
+
0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
|
338 |
+
encoded_segm_mask = self.segm_encoder(segm_one_hot)
|
339 |
+
encoded_segm_mask = self.segm_quant_conv(encoded_segm_mask)
|
340 |
+
_, _, [_, _, segm_tokens] = self.segm_quantizer(encoded_segm_mask)
|
341 |
+
|
342 |
+
return segm_tokens
|
343 |
+
|
344 |
+
|
345 |
+
class SampleFromParsingModel(BaseSampleModel):
|
346 |
+
"""SampleFromParsing model.
|
347 |
+
"""
|
348 |
+
|
349 |
+
def feed_data(self, data):
|
350 |
+
self.segm = data['segm'].to(self.device)
|
351 |
+
self.texture_mask = data['texture_mask'].to(self.device)
|
352 |
+
self.batch_size = self.segm.size(0)
|
353 |
+
|
354 |
+
self.segm_tokens = self.get_quantized_segm(self.segm)
|
355 |
+
self.segm_tokens = self.segm_tokens.view(self.batch_size, -1)
|
356 |
+
|
357 |
+
def inference(self, data_loader, save_dir):
|
358 |
+
for _, data in enumerate(data_loader):
|
359 |
+
img_name = data['img_name']
|
360 |
+
self.feed_data(data)
|
361 |
+
with torch.no_grad():
|
362 |
+
self.sample_and_refine(save_dir, img_name)
|
363 |
+
|
364 |
+
|
365 |
+
class SampleFromPoseModel(BaseSampleModel):
|
366 |
+
"""SampleFromPose model.
|
367 |
+
"""
|
368 |
+
|
369 |
+
def __init__(self, opt):
|
370 |
+
super().__init__(opt)
|
371 |
+
# pose-to-parsing
|
372 |
+
self.shape_attr_embedder = ShapeAttrEmbedding(
|
373 |
+
dim=opt['shape_embedder_dim'],
|
374 |
+
out_dim=opt['shape_embedder_out_dim'],
|
375 |
+
cls_num_list=opt['shape_attr_class_num']).to(self.device)
|
376 |
+
self.shape_parsing_encoder = ShapeUNet(
|
377 |
+
in_channels=opt['shape_encoder_in_channels']).to(self.device)
|
378 |
+
self.shape_parsing_decoder = FCNHead(
|
379 |
+
in_channels=opt['shape_fc_in_channels'],
|
380 |
+
in_index=opt['shape_fc_in_index'],
|
381 |
+
channels=opt['shape_fc_channels'],
|
382 |
+
num_convs=opt['shape_fc_num_convs'],
|
383 |
+
concat_input=opt['shape_fc_concat_input'],
|
384 |
+
dropout_ratio=opt['shape_fc_dropout_ratio'],
|
385 |
+
num_classes=opt['shape_fc_num_classes'],
|
386 |
+
align_corners=opt['shape_fc_align_corners'],
|
387 |
+
).to(self.device)
|
388 |
+
self.load_shape_generation_models()
|
389 |
+
|
390 |
+
self.palette = [[0, 0, 0], [255, 250, 250], [220, 220, 220],
|
391 |
+
[250, 235, 215], [255, 250, 205], [211, 211, 211],
|
392 |
+
[70, 130, 180], [127, 255, 212], [0, 100, 0],
|
393 |
+
[50, 205, 50], [255, 255, 0], [245, 222, 179],
|
394 |
+
[255, 140, 0], [255, 0, 0], [16, 78, 139],
|
395 |
+
[144, 238, 144], [50, 205, 174], [50, 155, 250],
|
396 |
+
[160, 140, 88], [213, 140, 88], [90, 140, 90],
|
397 |
+
[185, 210, 205], [130, 165, 180], [225, 141, 151]]
|
398 |
+
|
399 |
+
def load_shape_generation_models(self):
|
400 |
+
checkpoint = torch.load(self.opt['pretrained_parsing_gen'],map_location=torch.device('cpu'))
|
401 |
+
|
402 |
+
self.shape_attr_embedder.load_state_dict(
|
403 |
+
checkpoint['embedder'], strict=True)
|
404 |
+
self.shape_attr_embedder.eval()
|
405 |
+
|
406 |
+
self.shape_parsing_encoder.load_state_dict(
|
407 |
+
checkpoint['encoder'], strict=True)
|
408 |
+
self.shape_parsing_encoder.eval()
|
409 |
+
|
410 |
+
self.shape_parsing_decoder.load_state_dict(
|
411 |
+
checkpoint['decoder'], strict=True)
|
412 |
+
self.shape_parsing_decoder.eval()
|
413 |
+
|
414 |
+
def feed_data(self, data):
|
415 |
+
self.pose = data['densepose'].to(self.device)
|
416 |
+
self.batch_size = self.pose.size(0)
|
417 |
+
|
418 |
+
self.shape_attr = data['shape_attr'].to(self.device)
|
419 |
+
self.upper_fused_attr = data['upper_fused_attr'].to(self.device)
|
420 |
+
self.lower_fused_attr = data['lower_fused_attr'].to(self.device)
|
421 |
+
self.outer_fused_attr = data['outer_fused_attr'].to(self.device)
|
422 |
+
|
423 |
+
def inference(self, data_loader, save_dir):
|
424 |
+
for _, data in enumerate(data_loader):
|
425 |
+
img_name = data['img_name']
|
426 |
+
self.feed_data(data)
|
427 |
+
with torch.no_grad():
|
428 |
+
self.generate_parsing_map()
|
429 |
+
self.generate_quantized_segm()
|
430 |
+
self.generate_texture_map()
|
431 |
+
self.sample_and_refine(save_dir, img_name)
|
432 |
+
|
433 |
+
def generate_parsing_map(self):
|
434 |
+
with torch.no_grad():
|
435 |
+
attr_embedding = self.shape_attr_embedder(self.shape_attr)
|
436 |
+
pose_enc = self.shape_parsing_encoder(self.pose, attr_embedding)
|
437 |
+
seg_logits = self.shape_parsing_decoder(pose_enc)
|
438 |
+
self.segm = seg_logits.argmax(dim=1)
|
439 |
+
self.segm = self.segm.unsqueeze(1)
|
440 |
+
|
441 |
+
def generate_quantized_segm(self):
|
442 |
+
self.segm_tokens = self.get_quantized_segm(self.segm)
|
443 |
+
self.segm_tokens = self.segm_tokens.view(self.batch_size, -1)
|
444 |
+
|
445 |
+
def generate_texture_map(self):
|
446 |
+
upper_cls = [1., 4.]
|
447 |
+
lower_cls = [3., 5., 21.]
|
448 |
+
outer_cls = [2.]
|
449 |
+
|
450 |
+
mask_batch = []
|
451 |
+
for idx in range(self.batch_size):
|
452 |
+
mask = torch.zeros_like(self.segm[idx])
|
453 |
+
upper_fused_attr = self.upper_fused_attr[idx]
|
454 |
+
lower_fused_attr = self.lower_fused_attr[idx]
|
455 |
+
outer_fused_attr = self.outer_fused_attr[idx]
|
456 |
+
if upper_fused_attr != 17:
|
457 |
+
for cls in upper_cls:
|
458 |
+
mask[self.segm[idx] == cls] = upper_fused_attr + 1
|
459 |
+
|
460 |
+
if lower_fused_attr != 17:
|
461 |
+
for cls in lower_cls:
|
462 |
+
mask[self.segm[idx] == cls] = lower_fused_attr + 1
|
463 |
+
|
464 |
+
if outer_fused_attr != 17:
|
465 |
+
for cls in outer_cls:
|
466 |
+
mask[self.segm[idx] == cls] = outer_fused_attr + 1
|
467 |
+
|
468 |
+
mask_batch.append(mask)
|
469 |
+
self.texture_mask = torch.stack(mask_batch, dim=0).to(torch.float32)
|
470 |
+
|
471 |
+
def feed_pose_data(self, pose_img):
|
472 |
+
# for ui demo
|
473 |
+
|
474 |
+
self.pose = pose_img.to(self.device)
|
475 |
+
self.batch_size = self.pose.size(0)
|
476 |
+
|
477 |
+
def feed_shape_attributes(self, shape_attr):
|
478 |
+
# for ui demo
|
479 |
+
|
480 |
+
self.shape_attr = shape_attr.to(self.device)
|
481 |
+
|
482 |
+
def feed_texture_attributes(self, texture_attr):
|
483 |
+
# for ui demo
|
484 |
+
|
485 |
+
self.upper_fused_attr = texture_attr[0].unsqueeze(0).to(self.device)
|
486 |
+
self.lower_fused_attr = texture_attr[1].unsqueeze(0).to(self.device)
|
487 |
+
self.outer_fused_attr = texture_attr[2].unsqueeze(0).to(self.device)
|
488 |
+
|
489 |
+
def palette_result(self, result):
|
490 |
+
|
491 |
+
seg = result[0]
|
492 |
+
palette = np.array(self.palette)
|
493 |
+
assert palette.shape[1] == 3
|
494 |
+
assert len(palette.shape) == 2
|
495 |
+
color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
|
496 |
+
for label, color in enumerate(palette):
|
497 |
+
color_seg[seg == label, :] = color
|
498 |
+
# convert to BGR
|
499 |
+
# color_seg = color_seg[..., ::-1]
|
500 |
+
return color_seg
|
Text2Human/models/transformer_model.py
ADDED
@@ -0,0 +1,482 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import math
|
3 |
+
from collections import OrderedDict
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
import torch.distributions as dists
|
8 |
+
import torch.nn.functional as F
|
9 |
+
from torchvision.utils import save_image
|
10 |
+
|
11 |
+
from models.archs.transformer_arch import TransformerMultiHead
|
12 |
+
from models.archs.vqgan_arch import (Decoder, Encoder, VectorQuantizer,
|
13 |
+
VectorQuantizerTexture)
|
14 |
+
|
15 |
+
logger = logging.getLogger('base')
|
16 |
+
|
17 |
+
|
18 |
+
class TransformerTextureAwareModel():
|
19 |
+
"""Texture-Aware Diffusion based Transformer model.
|
20 |
+
"""
|
21 |
+
|
22 |
+
def __init__(self, opt):
|
23 |
+
self.opt = opt
|
24 |
+
self.device = torch.device('cuda')
|
25 |
+
self.is_train = opt['is_train']
|
26 |
+
|
27 |
+
# VQVAE for image
|
28 |
+
self.img_encoder = Encoder(
|
29 |
+
ch=opt['img_ch'],
|
30 |
+
num_res_blocks=opt['img_num_res_blocks'],
|
31 |
+
attn_resolutions=opt['img_attn_resolutions'],
|
32 |
+
ch_mult=opt['img_ch_mult'],
|
33 |
+
in_channels=opt['img_in_channels'],
|
34 |
+
resolution=opt['img_resolution'],
|
35 |
+
z_channels=opt['img_z_channels'],
|
36 |
+
double_z=opt['img_double_z'],
|
37 |
+
dropout=opt['img_dropout']).to(self.device)
|
38 |
+
self.img_decoder = Decoder(
|
39 |
+
in_channels=opt['img_in_channels'],
|
40 |
+
resolution=opt['img_resolution'],
|
41 |
+
z_channels=opt['img_z_channels'],
|
42 |
+
ch=opt['img_ch'],
|
43 |
+
out_ch=opt['img_out_ch'],
|
44 |
+
num_res_blocks=opt['img_num_res_blocks'],
|
45 |
+
attn_resolutions=opt['img_attn_resolutions'],
|
46 |
+
ch_mult=opt['img_ch_mult'],
|
47 |
+
dropout=opt['img_dropout'],
|
48 |
+
resamp_with_conv=True,
|
49 |
+
give_pre_end=False).to(self.device)
|
50 |
+
self.img_quantizer = VectorQuantizerTexture(
|
51 |
+
opt['img_n_embed'], opt['img_embed_dim'],
|
52 |
+
beta=0.25).to(self.device)
|
53 |
+
self.img_quant_conv = torch.nn.Conv2d(opt["img_z_channels"],
|
54 |
+
opt['img_embed_dim'],
|
55 |
+
1).to(self.device)
|
56 |
+
self.img_post_quant_conv = torch.nn.Conv2d(opt['img_embed_dim'],
|
57 |
+
opt["img_z_channels"],
|
58 |
+
1).to(self.device)
|
59 |
+
self.load_pretrained_image_vae()
|
60 |
+
|
61 |
+
# VAE for segmentation mask
|
62 |
+
self.segm_encoder = Encoder(
|
63 |
+
ch=opt['segm_ch'],
|
64 |
+
num_res_blocks=opt['segm_num_res_blocks'],
|
65 |
+
attn_resolutions=opt['segm_attn_resolutions'],
|
66 |
+
ch_mult=opt['segm_ch_mult'],
|
67 |
+
in_channels=opt['segm_in_channels'],
|
68 |
+
resolution=opt['segm_resolution'],
|
69 |
+
z_channels=opt['segm_z_channels'],
|
70 |
+
double_z=opt['segm_double_z'],
|
71 |
+
dropout=opt['segm_dropout']).to(self.device)
|
72 |
+
self.segm_quantizer = VectorQuantizer(
|
73 |
+
opt['segm_n_embed'],
|
74 |
+
opt['segm_embed_dim'],
|
75 |
+
beta=0.25,
|
76 |
+
sane_index_shape=True).to(self.device)
|
77 |
+
self.segm_quant_conv = torch.nn.Conv2d(opt["segm_z_channels"],
|
78 |
+
opt['segm_embed_dim'],
|
79 |
+
1).to(self.device)
|
80 |
+
self.load_pretrained_segm_vae()
|
81 |
+
|
82 |
+
# define sampler
|
83 |
+
self._denoise_fn = TransformerMultiHead(
|
84 |
+
codebook_size=opt['codebook_size'],
|
85 |
+
segm_codebook_size=opt['segm_codebook_size'],
|
86 |
+
texture_codebook_size=opt['texture_codebook_size'],
|
87 |
+
bert_n_emb=opt['bert_n_emb'],
|
88 |
+
bert_n_layers=opt['bert_n_layers'],
|
89 |
+
bert_n_head=opt['bert_n_head'],
|
90 |
+
block_size=opt['block_size'],
|
91 |
+
latent_shape=opt['latent_shape'],
|
92 |
+
embd_pdrop=opt['embd_pdrop'],
|
93 |
+
resid_pdrop=opt['resid_pdrop'],
|
94 |
+
attn_pdrop=opt['attn_pdrop'],
|
95 |
+
num_head=opt['num_head']).to(self.device)
|
96 |
+
|
97 |
+
self.num_classes = opt['codebook_size']
|
98 |
+
self.shape = tuple(opt['latent_shape'])
|
99 |
+
self.num_timesteps = 1000
|
100 |
+
|
101 |
+
self.mask_id = opt['codebook_size']
|
102 |
+
self.loss_type = opt['loss_type']
|
103 |
+
self.mask_schedule = opt['mask_schedule']
|
104 |
+
|
105 |
+
self.sample_steps = opt['sample_steps']
|
106 |
+
|
107 |
+
self.init_training_settings()
|
108 |
+
|
109 |
+
def load_pretrained_image_vae(self):
|
110 |
+
# load pretrained vqgan for segmentation mask
|
111 |
+
img_ae_checkpoint = torch.load(self.opt['img_ae_path'])
|
112 |
+
self.img_encoder.load_state_dict(
|
113 |
+
img_ae_checkpoint['encoder'], strict=True)
|
114 |
+
self.img_decoder.load_state_dict(
|
115 |
+
img_ae_checkpoint['decoder'], strict=True)
|
116 |
+
self.img_quantizer.load_state_dict(
|
117 |
+
img_ae_checkpoint['quantize'], strict=True)
|
118 |
+
self.img_quant_conv.load_state_dict(
|
119 |
+
img_ae_checkpoint['quant_conv'], strict=True)
|
120 |
+
self.img_post_quant_conv.load_state_dict(
|
121 |
+
img_ae_checkpoint['post_quant_conv'], strict=True)
|
122 |
+
self.img_encoder.eval()
|
123 |
+
self.img_decoder.eval()
|
124 |
+
self.img_quantizer.eval()
|
125 |
+
self.img_quant_conv.eval()
|
126 |
+
self.img_post_quant_conv.eval()
|
127 |
+
|
128 |
+
def load_pretrained_segm_vae(self):
|
129 |
+
# load pretrained vqgan for segmentation mask
|
130 |
+
segm_ae_checkpoint = torch.load(self.opt['segm_ae_path'])
|
131 |
+
self.segm_encoder.load_state_dict(
|
132 |
+
segm_ae_checkpoint['encoder'], strict=True)
|
133 |
+
self.segm_quantizer.load_state_dict(
|
134 |
+
segm_ae_checkpoint['quantize'], strict=True)
|
135 |
+
self.segm_quant_conv.load_state_dict(
|
136 |
+
segm_ae_checkpoint['quant_conv'], strict=True)
|
137 |
+
self.segm_encoder.eval()
|
138 |
+
self.segm_quantizer.eval()
|
139 |
+
self.segm_quant_conv.eval()
|
140 |
+
|
141 |
+
def init_training_settings(self):
|
142 |
+
optim_params = []
|
143 |
+
for v in self._denoise_fn.parameters():
|
144 |
+
if v.requires_grad:
|
145 |
+
optim_params.append(v)
|
146 |
+
# set up optimizer
|
147 |
+
self.optimizer = torch.optim.Adam(
|
148 |
+
optim_params,
|
149 |
+
self.opt['lr'],
|
150 |
+
weight_decay=self.opt['weight_decay'])
|
151 |
+
self.log_dict = OrderedDict()
|
152 |
+
|
153 |
+
@torch.no_grad()
|
154 |
+
def get_quantized_img(self, image, texture_mask):
|
155 |
+
encoded_img = self.img_encoder(image)
|
156 |
+
encoded_img = self.img_quant_conv(encoded_img)
|
157 |
+
|
158 |
+
# img_tokens_input is the continual index for the input of transformer
|
159 |
+
# img_tokens_gt_list is the index for 18 texture-aware codebooks respectively
|
160 |
+
_, _, [_, img_tokens_input, img_tokens_gt_list
|
161 |
+
] = self.img_quantizer(encoded_img, texture_mask)
|
162 |
+
|
163 |
+
# reshape the tokens
|
164 |
+
b = image.size(0)
|
165 |
+
img_tokens_input = img_tokens_input.view(b, -1)
|
166 |
+
img_tokens_gt_return_list = [
|
167 |
+
img_tokens_gt.view(b, -1) for img_tokens_gt in img_tokens_gt_list
|
168 |
+
]
|
169 |
+
|
170 |
+
return img_tokens_input, img_tokens_gt_return_list
|
171 |
+
|
172 |
+
@torch.no_grad()
|
173 |
+
def decode(self, quant):
|
174 |
+
quant = self.img_post_quant_conv(quant)
|
175 |
+
dec = self.img_decoder(quant)
|
176 |
+
return dec
|
177 |
+
|
178 |
+
@torch.no_grad()
|
179 |
+
def decode_image_indices(self, indices_list, texture_mask):
|
180 |
+
quant = self.img_quantizer.get_codebook_entry(
|
181 |
+
indices_list, texture_mask,
|
182 |
+
(indices_list[0].size(0), self.shape[0], self.shape[1],
|
183 |
+
self.opt["img_z_channels"]))
|
184 |
+
dec = self.decode(quant)
|
185 |
+
|
186 |
+
return dec
|
187 |
+
|
188 |
+
def sample_time(self, b, device, method='uniform'):
|
189 |
+
if method == 'importance':
|
190 |
+
if not (self.Lt_count > 10).all():
|
191 |
+
return self.sample_time(b, device, method='uniform')
|
192 |
+
|
193 |
+
Lt_sqrt = torch.sqrt(self.Lt_history + 1e-10) + 0.0001
|
194 |
+
Lt_sqrt[0] = Lt_sqrt[1] # Overwrite decoder term with L1.
|
195 |
+
pt_all = Lt_sqrt / Lt_sqrt.sum()
|
196 |
+
|
197 |
+
t = torch.multinomial(pt_all, num_samples=b, replacement=True)
|
198 |
+
|
199 |
+
pt = pt_all.gather(dim=0, index=t)
|
200 |
+
|
201 |
+
return t, pt
|
202 |
+
|
203 |
+
elif method == 'uniform':
|
204 |
+
t = torch.randint(
|
205 |
+
1, self.num_timesteps + 1, (b, ), device=device).long()
|
206 |
+
pt = torch.ones_like(t).float() / self.num_timesteps
|
207 |
+
return t, pt
|
208 |
+
|
209 |
+
else:
|
210 |
+
raise ValueError
|
211 |
+
|
212 |
+
def q_sample(self, x_0, x_0_gt_list, t):
|
213 |
+
# samples q(x_t | x_0)
|
214 |
+
# randomly set token to mask with probability t/T
|
215 |
+
# x_t, x_0_ignore = x_0.clone(), x_0.clone()
|
216 |
+
x_t = x_0.clone()
|
217 |
+
|
218 |
+
mask = torch.rand_like(x_t.float()) < (
|
219 |
+
t.float().unsqueeze(-1) / self.num_timesteps)
|
220 |
+
x_t[mask] = self.mask_id
|
221 |
+
# x_0_ignore[torch.bitwise_not(mask)] = -1
|
222 |
+
|
223 |
+
# for every gt token list, we also need to do the mask
|
224 |
+
x_0_gt_ignore_list = []
|
225 |
+
for x_0_gt in x_0_gt_list:
|
226 |
+
x_0_gt_ignore = x_0_gt.clone()
|
227 |
+
x_0_gt_ignore[torch.bitwise_not(mask)] = -1
|
228 |
+
x_0_gt_ignore_list.append(x_0_gt_ignore)
|
229 |
+
|
230 |
+
return x_t, x_0_gt_ignore_list, mask
|
231 |
+
|
232 |
+
def _train_loss(self, x_0, x_0_gt_list):
|
233 |
+
b, device = x_0.size(0), x_0.device
|
234 |
+
|
235 |
+
# choose what time steps to compute loss at
|
236 |
+
t, pt = self.sample_time(b, device, 'uniform')
|
237 |
+
|
238 |
+
# make x noisy and denoise
|
239 |
+
if self.mask_schedule == 'random':
|
240 |
+
x_t, x_0_gt_ignore_list, mask = self.q_sample(
|
241 |
+
x_0=x_0, x_0_gt_list=x_0_gt_list, t=t)
|
242 |
+
else:
|
243 |
+
raise NotImplementedError
|
244 |
+
|
245 |
+
# sample p(x_0 | x_t)
|
246 |
+
x_0_hat_logits_list = self._denoise_fn(
|
247 |
+
x_t, self.segm_tokens, self.texture_tokens, t=t)
|
248 |
+
|
249 |
+
# Always compute ELBO for comparison purposes
|
250 |
+
cross_entropy_loss = 0
|
251 |
+
for x_0_hat_logits, x_0_gt_ignore in zip(x_0_hat_logits_list,
|
252 |
+
x_0_gt_ignore_list):
|
253 |
+
cross_entropy_loss += F.cross_entropy(
|
254 |
+
x_0_hat_logits.permute(0, 2, 1),
|
255 |
+
x_0_gt_ignore,
|
256 |
+
ignore_index=-1,
|
257 |
+
reduction='none').sum(1)
|
258 |
+
vb_loss = cross_entropy_loss / t
|
259 |
+
vb_loss = vb_loss / pt
|
260 |
+
vb_loss = vb_loss / (math.log(2) * x_0.shape[1:].numel())
|
261 |
+
if self.loss_type == 'elbo':
|
262 |
+
loss = vb_loss
|
263 |
+
elif self.loss_type == 'mlm':
|
264 |
+
denom = mask.float().sum(1)
|
265 |
+
denom[denom == 0] = 1 # prevent divide by 0 errors.
|
266 |
+
loss = cross_entropy_loss / denom
|
267 |
+
elif self.loss_type == 'reweighted_elbo':
|
268 |
+
weight = (1 - (t / self.num_timesteps))
|
269 |
+
loss = weight * cross_entropy_loss
|
270 |
+
loss = loss / (math.log(2) * x_0.shape[1:].numel())
|
271 |
+
else:
|
272 |
+
raise ValueError
|
273 |
+
|
274 |
+
return loss.mean(), vb_loss.mean()
|
275 |
+
|
276 |
+
def feed_data(self, data):
|
277 |
+
self.image = data['image'].to(self.device)
|
278 |
+
self.segm = data['segm'].to(self.device)
|
279 |
+
self.texture_mask = data['texture_mask'].to(self.device)
|
280 |
+
self.input_indices, self.gt_indices_list = self.get_quantized_img(
|
281 |
+
self.image, self.texture_mask)
|
282 |
+
|
283 |
+
self.texture_tokens = F.interpolate(
|
284 |
+
self.texture_mask, size=self.shape,
|
285 |
+
mode='nearest').view(self.image.size(0), -1).long()
|
286 |
+
|
287 |
+
self.segm_tokens = self.get_quantized_segm(self.segm)
|
288 |
+
self.segm_tokens = self.segm_tokens.view(self.image.size(0), -1)
|
289 |
+
|
290 |
+
def optimize_parameters(self):
|
291 |
+
self._denoise_fn.train()
|
292 |
+
|
293 |
+
loss, vb_loss = self._train_loss(self.input_indices,
|
294 |
+
self.gt_indices_list)
|
295 |
+
|
296 |
+
self.optimizer.zero_grad()
|
297 |
+
loss.backward()
|
298 |
+
self.optimizer.step()
|
299 |
+
|
300 |
+
self.log_dict['loss'] = loss
|
301 |
+
self.log_dict['vb_loss'] = vb_loss
|
302 |
+
|
303 |
+
self._denoise_fn.eval()
|
304 |
+
|
305 |
+
@torch.no_grad()
|
306 |
+
def get_quantized_segm(self, segm):
|
307 |
+
segm_one_hot = F.one_hot(
|
308 |
+
segm.squeeze(1).long(),
|
309 |
+
num_classes=self.opt['segm_num_segm_classes']).permute(
|
310 |
+
0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
|
311 |
+
encoded_segm_mask = self.segm_encoder(segm_one_hot)
|
312 |
+
encoded_segm_mask = self.segm_quant_conv(encoded_segm_mask)
|
313 |
+
_, _, [_, _, segm_tokens] = self.segm_quantizer(encoded_segm_mask)
|
314 |
+
|
315 |
+
return segm_tokens
|
316 |
+
|
317 |
+
def sample_fn(self, temp=1.0, sample_steps=None):
|
318 |
+
self._denoise_fn.eval()
|
319 |
+
|
320 |
+
b, device = self.image.size(0), 'cuda'
|
321 |
+
x_t = torch.ones(
|
322 |
+
(b, np.prod(self.shape)), device=device).long() * self.mask_id
|
323 |
+
unmasked = torch.zeros_like(x_t, device=device).bool()
|
324 |
+
sample_steps = list(range(1, sample_steps + 1))
|
325 |
+
|
326 |
+
texture_mask_flatten = self.texture_tokens.view(-1)
|
327 |
+
|
328 |
+
# min_encodings_indices_list would be used to visualize the image
|
329 |
+
min_encodings_indices_list = [
|
330 |
+
torch.full(
|
331 |
+
texture_mask_flatten.size(),
|
332 |
+
fill_value=-1,
|
333 |
+
dtype=torch.long,
|
334 |
+
device=texture_mask_flatten.device) for _ in range(18)
|
335 |
+
]
|
336 |
+
|
337 |
+
for t in reversed(sample_steps):
|
338 |
+
print(f'Sample timestep {t:4d}', end='\r')
|
339 |
+
t = torch.full((b, ), t, device=device, dtype=torch.long)
|
340 |
+
|
341 |
+
# where to unmask
|
342 |
+
changes = torch.rand(
|
343 |
+
x_t.shape, device=device) < 1 / t.float().unsqueeze(-1)
|
344 |
+
# don't unmask somewhere already unmasked
|
345 |
+
changes = torch.bitwise_xor(changes,
|
346 |
+
torch.bitwise_and(changes, unmasked))
|
347 |
+
# update mask with changes
|
348 |
+
unmasked = torch.bitwise_or(unmasked, changes)
|
349 |
+
|
350 |
+
x_0_logits_list = self._denoise_fn(
|
351 |
+
x_t, self.segm_tokens, self.texture_tokens, t=t)
|
352 |
+
|
353 |
+
changes_flatten = changes.view(-1)
|
354 |
+
ori_shape = x_t.shape # [b, h*w]
|
355 |
+
x_t = x_t.view(-1) # [b*h*w]
|
356 |
+
for codebook_idx, x_0_logits in enumerate(x_0_logits_list):
|
357 |
+
if torch.sum(texture_mask_flatten[changes_flatten] ==
|
358 |
+
codebook_idx) > 0:
|
359 |
+
# scale by temperature
|
360 |
+
x_0_logits = x_0_logits / temp
|
361 |
+
x_0_dist = dists.Categorical(logits=x_0_logits)
|
362 |
+
x_0_hat = x_0_dist.sample().long()
|
363 |
+
x_0_hat = x_0_hat.view(-1)
|
364 |
+
|
365 |
+
# only replace the changed indices with corresponding codebook_idx
|
366 |
+
changes_segm = torch.bitwise_and(
|
367 |
+
changes_flatten, texture_mask_flatten == codebook_idx)
|
368 |
+
|
369 |
+
# x_t would be the input to the transformer, so the index range should be continual one
|
370 |
+
x_t[changes_segm] = x_0_hat[
|
371 |
+
changes_segm] + 1024 * codebook_idx
|
372 |
+
min_encodings_indices_list[codebook_idx][
|
373 |
+
changes_segm] = x_0_hat[changes_segm]
|
374 |
+
|
375 |
+
x_t = x_t.view(ori_shape) # [b, h*w]
|
376 |
+
|
377 |
+
min_encodings_indices_return_list = [
|
378 |
+
min_encodings_indices.view(ori_shape)
|
379 |
+
for min_encodings_indices in min_encodings_indices_list
|
380 |
+
]
|
381 |
+
|
382 |
+
self._denoise_fn.train()
|
383 |
+
|
384 |
+
return min_encodings_indices_return_list
|
385 |
+
|
386 |
+
def get_vis(self, image, gt_indices, predicted_indices, texture_mask,
|
387 |
+
save_path):
|
388 |
+
# original image
|
389 |
+
ori_img = self.decode_image_indices(gt_indices, texture_mask)
|
390 |
+
# pred image
|
391 |
+
pred_img = self.decode_image_indices(predicted_indices, texture_mask)
|
392 |
+
img_cat = torch.cat([
|
393 |
+
image,
|
394 |
+
ori_img,
|
395 |
+
pred_img,
|
396 |
+
], dim=3).detach()
|
397 |
+
img_cat = ((img_cat + 1) / 2)
|
398 |
+
img_cat = img_cat.clamp_(0, 1)
|
399 |
+
save_image(img_cat, save_path, nrow=1, padding=4)
|
400 |
+
|
401 |
+
def inference(self, data_loader, save_dir):
|
402 |
+
self._denoise_fn.eval()
|
403 |
+
|
404 |
+
for _, data in enumerate(data_loader):
|
405 |
+
img_name = data['img_name']
|
406 |
+
self.feed_data(data)
|
407 |
+
b = self.image.size(0)
|
408 |
+
with torch.no_grad():
|
409 |
+
sampled_indices_list = self.sample_fn(
|
410 |
+
temp=1, sample_steps=self.sample_steps)
|
411 |
+
for idx in range(b):
|
412 |
+
self.get_vis(self.image[idx:idx + 1], [
|
413 |
+
gt_indices[idx:idx + 1]
|
414 |
+
for gt_indices in self.gt_indices_list
|
415 |
+
], [
|
416 |
+
sampled_indices[idx:idx + 1]
|
417 |
+
for sampled_indices in sampled_indices_list
|
418 |
+
], self.texture_mask[idx:idx + 1],
|
419 |
+
f'{save_dir}/{img_name[idx]}')
|
420 |
+
|
421 |
+
self._denoise_fn.train()
|
422 |
+
|
423 |
+
def get_current_log(self):
|
424 |
+
return self.log_dict
|
425 |
+
|
426 |
+
def update_learning_rate(self, epoch, iters=None):
|
427 |
+
"""Update learning rate.
|
428 |
+
|
429 |
+
Args:
|
430 |
+
current_iter (int): Current iteration.
|
431 |
+
warmup_iter (int): Warmup iter numbers. -1 for no warmup.
|
432 |
+
Default: -1.
|
433 |
+
"""
|
434 |
+
lr = self.optimizer.param_groups[0]['lr']
|
435 |
+
|
436 |
+
if self.opt['lr_decay'] == 'step':
|
437 |
+
lr = self.opt['lr'] * (
|
438 |
+
self.opt['gamma']**(epoch // self.opt['step']))
|
439 |
+
elif self.opt['lr_decay'] == 'cos':
|
440 |
+
lr = self.opt['lr'] * (
|
441 |
+
1 + math.cos(math.pi * epoch / self.opt['num_epochs'])) / 2
|
442 |
+
elif self.opt['lr_decay'] == 'linear':
|
443 |
+
lr = self.opt['lr'] * (1 - epoch / self.opt['num_epochs'])
|
444 |
+
elif self.opt['lr_decay'] == 'linear2exp':
|
445 |
+
if epoch < self.opt['turning_point'] + 1:
|
446 |
+
# learning rate decay as 95%
|
447 |
+
# at the turning point (1 / 95% = 1.0526)
|
448 |
+
lr = self.opt['lr'] * (
|
449 |
+
1 - epoch / int(self.opt['turning_point'] * 1.0526))
|
450 |
+
else:
|
451 |
+
lr *= self.opt['gamma']
|
452 |
+
elif self.opt['lr_decay'] == 'schedule':
|
453 |
+
if epoch in self.opt['schedule']:
|
454 |
+
lr *= self.opt['gamma']
|
455 |
+
elif self.opt['lr_decay'] == 'warm_up':
|
456 |
+
if iters <= self.opt['warmup_iters']:
|
457 |
+
lr = self.opt['lr'] * float(iters) / self.opt['warmup_iters']
|
458 |
+
else:
|
459 |
+
lr = self.opt['lr']
|
460 |
+
else:
|
461 |
+
raise ValueError('Unknown lr mode {}'.format(self.opt['lr_decay']))
|
462 |
+
# set learning rate
|
463 |
+
for param_group in self.optimizer.param_groups:
|
464 |
+
param_group['lr'] = lr
|
465 |
+
|
466 |
+
return lr
|
467 |
+
|
468 |
+
def save_network(self, net, save_path):
|
469 |
+
"""Save networks.
|
470 |
+
|
471 |
+
Args:
|
472 |
+
net (nn.Module): Network to be saved.
|
473 |
+
net_label (str): Network label.
|
474 |
+
current_iter (int): Current iter number.
|
475 |
+
"""
|
476 |
+
state_dict = net.state_dict()
|
477 |
+
torch.save(state_dict, save_path)
|
478 |
+
|
479 |
+
def load_network(self):
|
480 |
+
checkpoint = torch.load(self.opt['pretrained_sampler'])
|
481 |
+
self._denoise_fn.load_state_dict(checkpoint, strict=True)
|
482 |
+
self._denoise_fn.eval()
|
Text2Human/models/vqgan_model.py
ADDED
@@ -0,0 +1,551 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import sys
|
3 |
+
from collections import OrderedDict
|
4 |
+
|
5 |
+
sys.path.append('..')
|
6 |
+
import lpips
|
7 |
+
import torch
|
8 |
+
import torch.nn.functional as F
|
9 |
+
from torchvision.utils import save_image
|
10 |
+
|
11 |
+
from models.archs.vqgan_arch import (Decoder, Discriminator, Encoder,
|
12 |
+
VectorQuantizer, VectorQuantizerTexture)
|
13 |
+
from models.losses.segmentation_loss import BCELossWithQuant
|
14 |
+
from models.losses.vqgan_loss import (DiffAugment, adopt_weight,
|
15 |
+
calculate_adaptive_weight, hinge_d_loss)
|
16 |
+
|
17 |
+
|
18 |
+
class VQModel():
|
19 |
+
|
20 |
+
def __init__(self, opt):
|
21 |
+
super().__init__()
|
22 |
+
self.opt = opt
|
23 |
+
self.device = torch.device('cuda')
|
24 |
+
self.encoder = Encoder(
|
25 |
+
ch=opt['ch'],
|
26 |
+
num_res_blocks=opt['num_res_blocks'],
|
27 |
+
attn_resolutions=opt['attn_resolutions'],
|
28 |
+
ch_mult=opt['ch_mult'],
|
29 |
+
in_channels=opt['in_channels'],
|
30 |
+
resolution=opt['resolution'],
|
31 |
+
z_channels=opt['z_channels'],
|
32 |
+
double_z=opt['double_z'],
|
33 |
+
dropout=opt['dropout']).to(self.device)
|
34 |
+
self.decoder = Decoder(
|
35 |
+
in_channels=opt['in_channels'],
|
36 |
+
resolution=opt['resolution'],
|
37 |
+
z_channels=opt['z_channels'],
|
38 |
+
ch=opt['ch'],
|
39 |
+
out_ch=opt['out_ch'],
|
40 |
+
num_res_blocks=opt['num_res_blocks'],
|
41 |
+
attn_resolutions=opt['attn_resolutions'],
|
42 |
+
ch_mult=opt['ch_mult'],
|
43 |
+
dropout=opt['dropout'],
|
44 |
+
resamp_with_conv=True,
|
45 |
+
give_pre_end=False).to(self.device)
|
46 |
+
self.quantize = VectorQuantizer(
|
47 |
+
opt['n_embed'], opt['embed_dim'], beta=0.25).to(self.device)
|
48 |
+
self.quant_conv = torch.nn.Conv2d(opt["z_channels"], opt['embed_dim'],
|
49 |
+
1).to(self.device)
|
50 |
+
self.post_quant_conv = torch.nn.Conv2d(opt['embed_dim'],
|
51 |
+
opt["z_channels"],
|
52 |
+
1).to(self.device)
|
53 |
+
|
54 |
+
def init_training_settings(self):
|
55 |
+
self.loss = BCELossWithQuant()
|
56 |
+
self.log_dict = OrderedDict()
|
57 |
+
self.configure_optimizers()
|
58 |
+
|
59 |
+
def save_network(self, save_path):
|
60 |
+
"""Save networks.
|
61 |
+
|
62 |
+
Args:
|
63 |
+
net (nn.Module): Network to be saved.
|
64 |
+
net_label (str): Network label.
|
65 |
+
current_iter (int): Current iter number.
|
66 |
+
"""
|
67 |
+
|
68 |
+
save_dict = {}
|
69 |
+
save_dict['encoder'] = self.encoder.state_dict()
|
70 |
+
save_dict['decoder'] = self.decoder.state_dict()
|
71 |
+
save_dict['quantize'] = self.quantize.state_dict()
|
72 |
+
save_dict['quant_conv'] = self.quant_conv.state_dict()
|
73 |
+
save_dict['post_quant_conv'] = self.post_quant_conv.state_dict()
|
74 |
+
save_dict['discriminator'] = self.disc.state_dict()
|
75 |
+
torch.save(save_dict, save_path)
|
76 |
+
|
77 |
+
def load_network(self):
|
78 |
+
checkpoint = torch.load(self.opt['pretrained_models'])
|
79 |
+
self.encoder.load_state_dict(checkpoint['encoder'], strict=True)
|
80 |
+
self.decoder.load_state_dict(checkpoint['decoder'], strict=True)
|
81 |
+
self.quantize.load_state_dict(checkpoint['quantize'], strict=True)
|
82 |
+
self.quant_conv.load_state_dict(checkpoint['quant_conv'], strict=True)
|
83 |
+
self.post_quant_conv.load_state_dict(
|
84 |
+
checkpoint['post_quant_conv'], strict=True)
|
85 |
+
|
86 |
+
def optimize_parameters(self, data, current_iter):
|
87 |
+
self.encoder.train()
|
88 |
+
self.decoder.train()
|
89 |
+
self.quantize.train()
|
90 |
+
self.quant_conv.train()
|
91 |
+
self.post_quant_conv.train()
|
92 |
+
|
93 |
+
loss = self.training_step(data)
|
94 |
+
self.optimizer.zero_grad()
|
95 |
+
loss.backward()
|
96 |
+
self.optimizer.step()
|
97 |
+
|
98 |
+
def encode(self, x):
|
99 |
+
h = self.encoder(x)
|
100 |
+
h = self.quant_conv(h)
|
101 |
+
quant, emb_loss, info = self.quantize(h)
|
102 |
+
return quant, emb_loss, info
|
103 |
+
|
104 |
+
def decode(self, quant):
|
105 |
+
quant = self.post_quant_conv(quant)
|
106 |
+
dec = self.decoder(quant)
|
107 |
+
return dec
|
108 |
+
|
109 |
+
def decode_code(self, code_b):
|
110 |
+
quant_b = self.quantize.embed_code(code_b)
|
111 |
+
dec = self.decode(quant_b)
|
112 |
+
return dec
|
113 |
+
|
114 |
+
def forward_step(self, input):
|
115 |
+
quant, diff, _ = self.encode(input)
|
116 |
+
dec = self.decode(quant)
|
117 |
+
return dec, diff
|
118 |
+
|
119 |
+
def feed_data(self, data):
|
120 |
+
x = data['segm']
|
121 |
+
x = F.one_hot(x, num_classes=self.opt['num_segm_classes'])
|
122 |
+
|
123 |
+
if len(x.shape) == 3:
|
124 |
+
x = x[..., None]
|
125 |
+
x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format)
|
126 |
+
return x.float().to(self.device)
|
127 |
+
|
128 |
+
def get_current_log(self):
|
129 |
+
return self.log_dict
|
130 |
+
|
131 |
+
def update_learning_rate(self, epoch):
|
132 |
+
"""Update learning rate.
|
133 |
+
|
134 |
+
Args:
|
135 |
+
current_iter (int): Current iteration.
|
136 |
+
warmup_iter (int): Warmup iter numbers. -1 for no warmup.
|
137 |
+
Default: -1.
|
138 |
+
"""
|
139 |
+
lr = self.optimizer.param_groups[0]['lr']
|
140 |
+
|
141 |
+
if self.opt['lr_decay'] == 'step':
|
142 |
+
lr = self.opt['lr'] * (
|
143 |
+
self.opt['gamma']**(epoch // self.opt['step']))
|
144 |
+
elif self.opt['lr_decay'] == 'cos':
|
145 |
+
lr = self.opt['lr'] * (
|
146 |
+
1 + math.cos(math.pi * epoch / self.opt['num_epochs'])) / 2
|
147 |
+
elif self.opt['lr_decay'] == 'linear':
|
148 |
+
lr = self.opt['lr'] * (1 - epoch / self.opt['num_epochs'])
|
149 |
+
elif self.opt['lr_decay'] == 'linear2exp':
|
150 |
+
if epoch < self.opt['turning_point'] + 1:
|
151 |
+
# learning rate decay as 95%
|
152 |
+
# at the turning point (1 / 95% = 1.0526)
|
153 |
+
lr = self.opt['lr'] * (
|
154 |
+
1 - epoch / int(self.opt['turning_point'] * 1.0526))
|
155 |
+
else:
|
156 |
+
lr *= self.opt['gamma']
|
157 |
+
elif self.opt['lr_decay'] == 'schedule':
|
158 |
+
if epoch in self.opt['schedule']:
|
159 |
+
lr *= self.opt['gamma']
|
160 |
+
else:
|
161 |
+
raise ValueError('Unknown lr mode {}'.format(self.opt['lr_decay']))
|
162 |
+
# set learning rate
|
163 |
+
for param_group in self.optimizer.param_groups:
|
164 |
+
param_group['lr'] = lr
|
165 |
+
|
166 |
+
return lr
|
167 |
+
|
168 |
+
|
169 |
+
class VQSegmentationModel(VQModel):
|
170 |
+
|
171 |
+
def __init__(self, opt):
|
172 |
+
super().__init__(opt)
|
173 |
+
self.colorize = torch.randn(3, opt['num_segm_classes'], 1,
|
174 |
+
1).to(self.device)
|
175 |
+
|
176 |
+
self.init_training_settings()
|
177 |
+
|
178 |
+
def configure_optimizers(self):
|
179 |
+
self.optimizer = torch.optim.Adam(
|
180 |
+
list(self.encoder.parameters()) + list(self.decoder.parameters()) +
|
181 |
+
list(self.quantize.parameters()) +
|
182 |
+
list(self.quant_conv.parameters()) +
|
183 |
+
list(self.post_quant_conv.parameters()),
|
184 |
+
lr=self.opt['lr'],
|
185 |
+
betas=(0.5, 0.9))
|
186 |
+
|
187 |
+
def training_step(self, data):
|
188 |
+
x = self.feed_data(data)
|
189 |
+
xrec, qloss = self.forward_step(x)
|
190 |
+
aeloss, log_dict_ae = self.loss(qloss, x, xrec, split="train")
|
191 |
+
self.log_dict.update(log_dict_ae)
|
192 |
+
return aeloss
|
193 |
+
|
194 |
+
def to_rgb(self, x):
|
195 |
+
x = F.conv2d(x, weight=self.colorize)
|
196 |
+
x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
|
197 |
+
return x
|
198 |
+
|
199 |
+
@torch.no_grad()
|
200 |
+
def inference(self, data_loader, save_dir):
|
201 |
+
self.encoder.eval()
|
202 |
+
self.decoder.eval()
|
203 |
+
self.quantize.eval()
|
204 |
+
self.quant_conv.eval()
|
205 |
+
self.post_quant_conv.eval()
|
206 |
+
|
207 |
+
loss_total = 0
|
208 |
+
loss_bce = 0
|
209 |
+
loss_quant = 0
|
210 |
+
num = 0
|
211 |
+
|
212 |
+
for _, data in enumerate(data_loader):
|
213 |
+
img_name = data['img_name'][0]
|
214 |
+
x = self.feed_data(data)
|
215 |
+
xrec, qloss = self.forward_step(x)
|
216 |
+
_, log_dict_ae = self.loss(qloss, x, xrec, split="val")
|
217 |
+
|
218 |
+
loss_total += log_dict_ae['val/total_loss']
|
219 |
+
loss_bce += log_dict_ae['val/bce_loss']
|
220 |
+
loss_quant += log_dict_ae['val/quant_loss']
|
221 |
+
|
222 |
+
num += x.size(0)
|
223 |
+
|
224 |
+
if x.shape[1] > 3:
|
225 |
+
# colorize with random projection
|
226 |
+
assert xrec.shape[1] > 3
|
227 |
+
# convert logits to indices
|
228 |
+
xrec = torch.argmax(xrec, dim=1, keepdim=True)
|
229 |
+
xrec = F.one_hot(xrec, num_classes=x.shape[1])
|
230 |
+
xrec = xrec.squeeze(1).permute(0, 3, 1, 2).float()
|
231 |
+
x = self.to_rgb(x)
|
232 |
+
xrec = self.to_rgb(xrec)
|
233 |
+
|
234 |
+
img_cat = torch.cat([x, xrec], dim=3).detach()
|
235 |
+
img_cat = ((img_cat + 1) / 2)
|
236 |
+
img_cat = img_cat.clamp_(0, 1)
|
237 |
+
save_image(
|
238 |
+
img_cat, f'{save_dir}/{img_name}.png', nrow=1, padding=4)
|
239 |
+
|
240 |
+
return (loss_total / num).item(), (loss_bce /
|
241 |
+
num).item(), (loss_quant /
|
242 |
+
num).item()
|
243 |
+
|
244 |
+
|
245 |
+
class VQImageModel(VQModel):
|
246 |
+
|
247 |
+
def __init__(self, opt):
|
248 |
+
super().__init__(opt)
|
249 |
+
self.disc = Discriminator(
|
250 |
+
opt['n_channels'], opt['ndf'],
|
251 |
+
n_layers=opt['disc_layers']).to(self.device)
|
252 |
+
self.perceptual = lpips.LPIPS(net="vgg").to(self.device)
|
253 |
+
self.perceptual_weight = opt['perceptual_weight']
|
254 |
+
self.disc_start_step = opt['disc_start_step']
|
255 |
+
self.disc_weight_max = opt['disc_weight_max']
|
256 |
+
self.diff_aug = opt['diff_aug']
|
257 |
+
self.policy = "color,translation"
|
258 |
+
|
259 |
+
self.disc.train()
|
260 |
+
|
261 |
+
self.init_training_settings()
|
262 |
+
|
263 |
+
def feed_data(self, data):
|
264 |
+
x = data['image']
|
265 |
+
|
266 |
+
return x.float().to(self.device)
|
267 |
+
|
268 |
+
def init_training_settings(self):
|
269 |
+
self.log_dict = OrderedDict()
|
270 |
+
self.configure_optimizers()
|
271 |
+
|
272 |
+
def configure_optimizers(self):
|
273 |
+
self.optimizer = torch.optim.Adam(
|
274 |
+
list(self.encoder.parameters()) + list(self.decoder.parameters()) +
|
275 |
+
list(self.quantize.parameters()) +
|
276 |
+
list(self.quant_conv.parameters()) +
|
277 |
+
list(self.post_quant_conv.parameters()),
|
278 |
+
lr=self.opt['lr'])
|
279 |
+
|
280 |
+
self.disc_optimizer = torch.optim.Adam(
|
281 |
+
self.disc.parameters(), lr=self.opt['lr'])
|
282 |
+
|
283 |
+
def training_step(self, data, step):
|
284 |
+
x = self.feed_data(data)
|
285 |
+
xrec, codebook_loss = self.forward_step(x)
|
286 |
+
|
287 |
+
# get recon/perceptual loss
|
288 |
+
recon_loss = torch.abs(x.contiguous() - xrec.contiguous())
|
289 |
+
p_loss = self.perceptual(x.contiguous(), xrec.contiguous())
|
290 |
+
nll_loss = recon_loss + self.perceptual_weight * p_loss
|
291 |
+
nll_loss = torch.mean(nll_loss)
|
292 |
+
|
293 |
+
# augment for input to discriminator
|
294 |
+
if self.diff_aug:
|
295 |
+
xrec = DiffAugment(xrec, policy=self.policy)
|
296 |
+
|
297 |
+
# update generator
|
298 |
+
logits_fake = self.disc(xrec)
|
299 |
+
g_loss = -torch.mean(logits_fake)
|
300 |
+
last_layer = self.decoder.conv_out.weight
|
301 |
+
d_weight = calculate_adaptive_weight(nll_loss, g_loss, last_layer,
|
302 |
+
self.disc_weight_max)
|
303 |
+
d_weight *= adopt_weight(1, step, self.disc_start_step)
|
304 |
+
loss = nll_loss + d_weight * g_loss + codebook_loss
|
305 |
+
|
306 |
+
self.log_dict["loss"] = loss
|
307 |
+
self.log_dict["l1"] = recon_loss.mean().item()
|
308 |
+
self.log_dict["perceptual"] = p_loss.mean().item()
|
309 |
+
self.log_dict["nll_loss"] = nll_loss.item()
|
310 |
+
self.log_dict["g_loss"] = g_loss.item()
|
311 |
+
self.log_dict["d_weight"] = d_weight
|
312 |
+
self.log_dict["codebook_loss"] = codebook_loss.item()
|
313 |
+
|
314 |
+
if step > self.disc_start_step:
|
315 |
+
if self.diff_aug:
|
316 |
+
logits_real = self.disc(
|
317 |
+
DiffAugment(x.contiguous().detach(), policy=self.policy))
|
318 |
+
else:
|
319 |
+
logits_real = self.disc(x.contiguous().detach())
|
320 |
+
logits_fake = self.disc(xrec.contiguous().detach(
|
321 |
+
)) # detach so that generator isn"t also updated
|
322 |
+
d_loss = hinge_d_loss(logits_real, logits_fake)
|
323 |
+
self.log_dict["d_loss"] = d_loss
|
324 |
+
else:
|
325 |
+
d_loss = None
|
326 |
+
|
327 |
+
return loss, d_loss
|
328 |
+
|
329 |
+
def optimize_parameters(self, data, step):
|
330 |
+
self.encoder.train()
|
331 |
+
self.decoder.train()
|
332 |
+
self.quantize.train()
|
333 |
+
self.quant_conv.train()
|
334 |
+
self.post_quant_conv.train()
|
335 |
+
|
336 |
+
loss, d_loss = self.training_step(data, step)
|
337 |
+
self.optimizer.zero_grad()
|
338 |
+
loss.backward()
|
339 |
+
self.optimizer.step()
|
340 |
+
|
341 |
+
if step > self.disc_start_step:
|
342 |
+
self.disc_optimizer.zero_grad()
|
343 |
+
d_loss.backward()
|
344 |
+
self.disc_optimizer.step()
|
345 |
+
|
346 |
+
@torch.no_grad()
|
347 |
+
def inference(self, data_loader, save_dir):
|
348 |
+
self.encoder.eval()
|
349 |
+
self.decoder.eval()
|
350 |
+
self.quantize.eval()
|
351 |
+
self.quant_conv.eval()
|
352 |
+
self.post_quant_conv.eval()
|
353 |
+
|
354 |
+
loss_total = 0
|
355 |
+
num = 0
|
356 |
+
|
357 |
+
for _, data in enumerate(data_loader):
|
358 |
+
img_name = data['img_name'][0]
|
359 |
+
x = self.feed_data(data)
|
360 |
+
xrec, _ = self.forward_step(x)
|
361 |
+
|
362 |
+
recon_loss = torch.abs(x.contiguous() - xrec.contiguous())
|
363 |
+
p_loss = self.perceptual(x.contiguous(), xrec.contiguous())
|
364 |
+
nll_loss = recon_loss + self.perceptual_weight * p_loss
|
365 |
+
nll_loss = torch.mean(nll_loss)
|
366 |
+
loss_total += nll_loss
|
367 |
+
|
368 |
+
num += x.size(0)
|
369 |
+
|
370 |
+
if x.shape[1] > 3:
|
371 |
+
# colorize with random projection
|
372 |
+
assert xrec.shape[1] > 3
|
373 |
+
# convert logits to indices
|
374 |
+
xrec = torch.argmax(xrec, dim=1, keepdim=True)
|
375 |
+
xrec = F.one_hot(xrec, num_classes=x.shape[1])
|
376 |
+
xrec = xrec.squeeze(1).permute(0, 3, 1, 2).float()
|
377 |
+
x = self.to_rgb(x)
|
378 |
+
xrec = self.to_rgb(xrec)
|
379 |
+
|
380 |
+
img_cat = torch.cat([x, xrec], dim=3).detach()
|
381 |
+
img_cat = ((img_cat + 1) / 2)
|
382 |
+
img_cat = img_cat.clamp_(0, 1)
|
383 |
+
save_image(
|
384 |
+
img_cat, f'{save_dir}/{img_name}.png', nrow=1, padding=4)
|
385 |
+
|
386 |
+
return (loss_total / num).item()
|
387 |
+
|
388 |
+
|
389 |
+
class VQImageSegmTextureModel(VQImageModel):
|
390 |
+
|
391 |
+
def __init__(self, opt):
|
392 |
+
self.opt = opt
|
393 |
+
self.device = torch.device('cuda')
|
394 |
+
self.encoder = Encoder(
|
395 |
+
ch=opt['ch'],
|
396 |
+
num_res_blocks=opt['num_res_blocks'],
|
397 |
+
attn_resolutions=opt['attn_resolutions'],
|
398 |
+
ch_mult=opt['ch_mult'],
|
399 |
+
in_channels=opt['in_channels'],
|
400 |
+
resolution=opt['resolution'],
|
401 |
+
z_channels=opt['z_channels'],
|
402 |
+
double_z=opt['double_z'],
|
403 |
+
dropout=opt['dropout']).to(self.device)
|
404 |
+
self.decoder = Decoder(
|
405 |
+
in_channels=opt['in_channels'],
|
406 |
+
resolution=opt['resolution'],
|
407 |
+
z_channels=opt['z_channels'],
|
408 |
+
ch=opt['ch'],
|
409 |
+
out_ch=opt['out_ch'],
|
410 |
+
num_res_blocks=opt['num_res_blocks'],
|
411 |
+
attn_resolutions=opt['attn_resolutions'],
|
412 |
+
ch_mult=opt['ch_mult'],
|
413 |
+
dropout=opt['dropout'],
|
414 |
+
resamp_with_conv=True,
|
415 |
+
give_pre_end=False).to(self.device)
|
416 |
+
self.quantize = VectorQuantizerTexture(
|
417 |
+
opt['n_embed'], opt['embed_dim'], beta=0.25).to(self.device)
|
418 |
+
self.quant_conv = torch.nn.Conv2d(opt["z_channels"], opt['embed_dim'],
|
419 |
+
1).to(self.device)
|
420 |
+
self.post_quant_conv = torch.nn.Conv2d(opt['embed_dim'],
|
421 |
+
opt["z_channels"],
|
422 |
+
1).to(self.device)
|
423 |
+
|
424 |
+
self.disc = Discriminator(
|
425 |
+
opt['n_channels'], opt['ndf'],
|
426 |
+
n_layers=opt['disc_layers']).to(self.device)
|
427 |
+
self.perceptual = lpips.LPIPS(net="vgg").to(self.device)
|
428 |
+
self.perceptual_weight = opt['perceptual_weight']
|
429 |
+
self.disc_start_step = opt['disc_start_step']
|
430 |
+
self.disc_weight_max = opt['disc_weight_max']
|
431 |
+
self.diff_aug = opt['diff_aug']
|
432 |
+
self.policy = "color,translation"
|
433 |
+
|
434 |
+
self.disc.train()
|
435 |
+
|
436 |
+
self.init_training_settings()
|
437 |
+
|
438 |
+
def feed_data(self, data):
|
439 |
+
x = data['image'].float().to(self.device)
|
440 |
+
mask = data['texture_mask'].float().to(self.device)
|
441 |
+
|
442 |
+
return x, mask
|
443 |
+
|
444 |
+
def training_step(self, data, step):
|
445 |
+
x, mask = self.feed_data(data)
|
446 |
+
xrec, codebook_loss = self.forward_step(x, mask)
|
447 |
+
|
448 |
+
# get recon/perceptual loss
|
449 |
+
recon_loss = torch.abs(x.contiguous() - xrec.contiguous())
|
450 |
+
p_loss = self.perceptual(x.contiguous(), xrec.contiguous())
|
451 |
+
nll_loss = recon_loss + self.perceptual_weight * p_loss
|
452 |
+
nll_loss = torch.mean(nll_loss)
|
453 |
+
|
454 |
+
# augment for input to discriminator
|
455 |
+
if self.diff_aug:
|
456 |
+
xrec = DiffAugment(xrec, policy=self.policy)
|
457 |
+
|
458 |
+
# update generator
|
459 |
+
logits_fake = self.disc(xrec)
|
460 |
+
g_loss = -torch.mean(logits_fake)
|
461 |
+
last_layer = self.decoder.conv_out.weight
|
462 |
+
d_weight = calculate_adaptive_weight(nll_loss, g_loss, last_layer,
|
463 |
+
self.disc_weight_max)
|
464 |
+
d_weight *= adopt_weight(1, step, self.disc_start_step)
|
465 |
+
loss = nll_loss + d_weight * g_loss + codebook_loss
|
466 |
+
|
467 |
+
self.log_dict["loss"] = loss
|
468 |
+
self.log_dict["l1"] = recon_loss.mean().item()
|
469 |
+
self.log_dict["perceptual"] = p_loss.mean().item()
|
470 |
+
self.log_dict["nll_loss"] = nll_loss.item()
|
471 |
+
self.log_dict["g_loss"] = g_loss.item()
|
472 |
+
self.log_dict["d_weight"] = d_weight
|
473 |
+
self.log_dict["codebook_loss"] = codebook_loss.item()
|
474 |
+
|
475 |
+
if step > self.disc_start_step:
|
476 |
+
if self.diff_aug:
|
477 |
+
logits_real = self.disc(
|
478 |
+
DiffAugment(x.contiguous().detach(), policy=self.policy))
|
479 |
+
else:
|
480 |
+
logits_real = self.disc(x.contiguous().detach())
|
481 |
+
logits_fake = self.disc(xrec.contiguous().detach(
|
482 |
+
)) # detach so that generator isn"t also updated
|
483 |
+
d_loss = hinge_d_loss(logits_real, logits_fake)
|
484 |
+
self.log_dict["d_loss"] = d_loss
|
485 |
+
else:
|
486 |
+
d_loss = None
|
487 |
+
|
488 |
+
return loss, d_loss
|
489 |
+
|
490 |
+
@torch.no_grad()
|
491 |
+
def inference(self, data_loader, save_dir):
|
492 |
+
self.encoder.eval()
|
493 |
+
self.decoder.eval()
|
494 |
+
self.quantize.eval()
|
495 |
+
self.quant_conv.eval()
|
496 |
+
self.post_quant_conv.eval()
|
497 |
+
|
498 |
+
loss_total = 0
|
499 |
+
num = 0
|
500 |
+
|
501 |
+
for _, data in enumerate(data_loader):
|
502 |
+
img_name = data['img_name'][0]
|
503 |
+
x, mask = self.feed_data(data)
|
504 |
+
xrec, _ = self.forward_step(x, mask)
|
505 |
+
|
506 |
+
recon_loss = torch.abs(x.contiguous() - xrec.contiguous())
|
507 |
+
p_loss = self.perceptual(x.contiguous(), xrec.contiguous())
|
508 |
+
nll_loss = recon_loss + self.perceptual_weight * p_loss
|
509 |
+
nll_loss = torch.mean(nll_loss)
|
510 |
+
loss_total += nll_loss
|
511 |
+
|
512 |
+
num += x.size(0)
|
513 |
+
|
514 |
+
if x.shape[1] > 3:
|
515 |
+
# colorize with random projection
|
516 |
+
assert xrec.shape[1] > 3
|
517 |
+
# convert logits to indices
|
518 |
+
xrec = torch.argmax(xrec, dim=1, keepdim=True)
|
519 |
+
xrec = F.one_hot(xrec, num_classes=x.shape[1])
|
520 |
+
xrec = xrec.squeeze(1).permute(0, 3, 1, 2).float()
|
521 |
+
x = self.to_rgb(x)
|
522 |
+
xrec = self.to_rgb(xrec)
|
523 |
+
|
524 |
+
img_cat = torch.cat([x, xrec], dim=3).detach()
|
525 |
+
img_cat = ((img_cat + 1) / 2)
|
526 |
+
img_cat = img_cat.clamp_(0, 1)
|
527 |
+
save_image(
|
528 |
+
img_cat, f'{save_dir}/{img_name}.png', nrow=1, padding=4)
|
529 |
+
|
530 |
+
return (loss_total / num).item()
|
531 |
+
|
532 |
+
def encode(self, x, mask):
|
533 |
+
h = self.encoder(x)
|
534 |
+
h = self.quant_conv(h)
|
535 |
+
quant, emb_loss, info = self.quantize(h, mask)
|
536 |
+
return quant, emb_loss, info
|
537 |
+
|
538 |
+
def decode(self, quant):
|
539 |
+
quant = self.post_quant_conv(quant)
|
540 |
+
dec = self.decoder(quant)
|
541 |
+
return dec
|
542 |
+
|
543 |
+
def decode_code(self, code_b):
|
544 |
+
quant_b = self.quantize.embed_code(code_b)
|
545 |
+
dec = self.decode(quant_b)
|
546 |
+
return dec
|
547 |
+
|
548 |
+
def forward_step(self, input, mask):
|
549 |
+
quant, diff, _ = self.encode(input, mask)
|
550 |
+
dec = self.decode(quant)
|
551 |
+
return dec, diff
|
Text2Human/sample_from_parsing.py
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import logging
|
3 |
+
import os.path as osp
|
4 |
+
import random
|
5 |
+
|
6 |
+
import torch
|
7 |
+
|
8 |
+
from data.segm_attr_dataset import DeepFashionAttrSegmDataset
|
9 |
+
from models import create_model
|
10 |
+
from utils.logger import get_root_logger
|
11 |
+
from utils.options import dict2str, dict_to_nonedict, parse
|
12 |
+
from utils.util import make_exp_dirs, set_random_seed
|
13 |
+
|
14 |
+
|
15 |
+
def main():
|
16 |
+
# options
|
17 |
+
parser = argparse.ArgumentParser()
|
18 |
+
parser.add_argument('-opt', type=str, help='Path to option YAML file.')
|
19 |
+
args = parser.parse_args()
|
20 |
+
opt = parse(args.opt, is_train=False)
|
21 |
+
|
22 |
+
# mkdir and loggers
|
23 |
+
make_exp_dirs(opt)
|
24 |
+
log_file = osp.join(opt['path']['log'], f"test_{opt['name']}.log")
|
25 |
+
logger = get_root_logger(
|
26 |
+
logger_name='base', log_level=logging.INFO, log_file=log_file)
|
27 |
+
logger.info(dict2str(opt))
|
28 |
+
|
29 |
+
# convert to NoneDict, which returns None for missing keys
|
30 |
+
opt = dict_to_nonedict(opt)
|
31 |
+
|
32 |
+
# random seed
|
33 |
+
seed = opt['manual_seed']
|
34 |
+
if seed is None:
|
35 |
+
seed = random.randint(1, 10000)
|
36 |
+
logger.info(f'Random seed: {seed}')
|
37 |
+
set_random_seed(seed)
|
38 |
+
|
39 |
+
test_dataset = DeepFashionAttrSegmDataset(
|
40 |
+
img_dir=opt['test_img_dir'],
|
41 |
+
segm_dir=opt['segm_dir'],
|
42 |
+
pose_dir=opt['pose_dir'],
|
43 |
+
ann_dir=opt['test_ann_file'])
|
44 |
+
test_loader = torch.utils.data.DataLoader(
|
45 |
+
dataset=test_dataset, batch_size=4, shuffle=False)
|
46 |
+
logger.info(f'Number of test set: {len(test_dataset)}.')
|
47 |
+
|
48 |
+
model = create_model(opt)
|
49 |
+
_ = model.inference(test_loader, opt['path']['results_root'])
|
50 |
+
|
51 |
+
|
52 |
+
if __name__ == '__main__':
|
53 |
+
main()
|
Text2Human/sample_from_pose.py
ADDED
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import logging
|
3 |
+
import os.path as osp
|
4 |
+
import random
|
5 |
+
|
6 |
+
import torch
|
7 |
+
|
8 |
+
from data.pose_attr_dataset import DeepFashionAttrPoseDataset
|
9 |
+
from models import create_model
|
10 |
+
from utils.logger import get_root_logger
|
11 |
+
from utils.options import dict2str, dict_to_nonedict, parse
|
12 |
+
from utils.util import make_exp_dirs, set_random_seed
|
13 |
+
|
14 |
+
|
15 |
+
def main():
|
16 |
+
# options
|
17 |
+
parser = argparse.ArgumentParser()
|
18 |
+
parser.add_argument('-opt', type=str, help='Path to option YAML file.')
|
19 |
+
args = parser.parse_args()
|
20 |
+
opt = parse(args.opt, is_train=False)
|
21 |
+
|
22 |
+
# mkdir and loggers
|
23 |
+
make_exp_dirs(opt)
|
24 |
+
log_file = osp.join(opt['path']['log'], f"test_{opt['name']}.log")
|
25 |
+
logger = get_root_logger(
|
26 |
+
logger_name='base', log_level=logging.INFO, log_file=log_file)
|
27 |
+
logger.info(dict2str(opt))
|
28 |
+
|
29 |
+
# convert to NoneDict, which returns None for missing keys
|
30 |
+
opt = dict_to_nonedict(opt)
|
31 |
+
|
32 |
+
# random seed
|
33 |
+
seed = opt['manual_seed']
|
34 |
+
if seed is None:
|
35 |
+
seed = random.randint(1, 10000)
|
36 |
+
logger.info(f'Random seed: {seed}')
|
37 |
+
set_random_seed(seed)
|
38 |
+
|
39 |
+
test_dataset = DeepFashionAttrPoseDataset(
|
40 |
+
pose_dir=opt['pose_dir'],
|
41 |
+
texture_ann_dir=opt['texture_ann_file'],
|
42 |
+
shape_ann_path=opt['shape_ann_path'])
|
43 |
+
test_loader = torch.utils.data.DataLoader(
|
44 |
+
dataset=test_dataset, batch_size=4, shuffle=False)
|
45 |
+
logger.info(f'Number of test set: {len(test_dataset)}.')
|
46 |
+
|
47 |
+
model = create_model(opt)
|
48 |
+
_ = model.inference(test_loader, opt['path']['results_root'])
|
49 |
+
|
50 |
+
|
51 |
+
if __name__ == '__main__':
|
52 |
+
main()
|
Text2Human/train_index_prediction.py
ADDED
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import logging
|
3 |
+
import os
|
4 |
+
import os.path as osp
|
5 |
+
import random
|
6 |
+
import time
|
7 |
+
|
8 |
+
import torch
|
9 |
+
|
10 |
+
from data.segm_attr_dataset import DeepFashionAttrSegmDataset
|
11 |
+
from models import create_model
|
12 |
+
from utils.logger import MessageLogger, get_root_logger, init_tb_logger
|
13 |
+
from utils.options import dict2str, dict_to_nonedict, parse
|
14 |
+
from utils.util import make_exp_dirs
|
15 |
+
|
16 |
+
|
17 |
+
def main():
|
18 |
+
# options
|
19 |
+
parser = argparse.ArgumentParser()
|
20 |
+
parser.add_argument('-opt', type=str, help='Path to option YAML file.')
|
21 |
+
args = parser.parse_args()
|
22 |
+
opt = parse(args.opt, is_train=True)
|
23 |
+
|
24 |
+
# mkdir and loggers
|
25 |
+
make_exp_dirs(opt)
|
26 |
+
log_file = osp.join(opt['path']['log'], f"train_{opt['name']}.log")
|
27 |
+
logger = get_root_logger(
|
28 |
+
logger_name='base', log_level=logging.INFO, log_file=log_file)
|
29 |
+
logger.info(dict2str(opt))
|
30 |
+
# initialize tensorboard logger
|
31 |
+
tb_logger = None
|
32 |
+
if opt['use_tb_logger'] and 'debug' not in opt['name']:
|
33 |
+
tb_logger = init_tb_logger(log_dir='./tb_logger/' + opt['name'])
|
34 |
+
|
35 |
+
# convert to NoneDict, which returns None for missing keys
|
36 |
+
opt = dict_to_nonedict(opt)
|
37 |
+
|
38 |
+
# set up data loader
|
39 |
+
train_dataset = DeepFashionAttrSegmDataset(
|
40 |
+
img_dir=opt['train_img_dir'],
|
41 |
+
segm_dir=opt['segm_dir'],
|
42 |
+
pose_dir=opt['pose_dir'],
|
43 |
+
ann_dir=opt['train_ann_file'],
|
44 |
+
xflip=True)
|
45 |
+
train_loader = torch.utils.data.DataLoader(
|
46 |
+
dataset=train_dataset,
|
47 |
+
batch_size=opt['batch_size'],
|
48 |
+
shuffle=True,
|
49 |
+
num_workers=opt['num_workers'],
|
50 |
+
drop_last=True)
|
51 |
+
logger.info(f'Number of train set: {len(train_dataset)}.')
|
52 |
+
opt['max_iters'] = opt['num_epochs'] * len(
|
53 |
+
train_dataset) // opt['batch_size']
|
54 |
+
|
55 |
+
val_dataset = DeepFashionAttrSegmDataset(
|
56 |
+
img_dir=opt['train_img_dir'],
|
57 |
+
segm_dir=opt['segm_dir'],
|
58 |
+
pose_dir=opt['pose_dir'],
|
59 |
+
ann_dir=opt['val_ann_file'])
|
60 |
+
val_loader = torch.utils.data.DataLoader(
|
61 |
+
dataset=val_dataset, batch_size=1, shuffle=False)
|
62 |
+
logger.info(f'Number of val set: {len(val_dataset)}.')
|
63 |
+
|
64 |
+
test_dataset = DeepFashionAttrSegmDataset(
|
65 |
+
img_dir=opt['test_img_dir'],
|
66 |
+
segm_dir=opt['segm_dir'],
|
67 |
+
pose_dir=opt['pose_dir'],
|
68 |
+
ann_dir=opt['test_ann_file'])
|
69 |
+
test_loader = torch.utils.data.DataLoader(
|
70 |
+
dataset=test_dataset, batch_size=1, shuffle=False)
|
71 |
+
logger.info(f'Number of test set: {len(test_dataset)}.')
|
72 |
+
|
73 |
+
current_iter = 0
|
74 |
+
best_epoch = None
|
75 |
+
best_acc = 0
|
76 |
+
|
77 |
+
model = create_model(opt)
|
78 |
+
|
79 |
+
data_time, iter_time = 0, 0
|
80 |
+
current_iter = 0
|
81 |
+
|
82 |
+
# create message logger (formatted outputs)
|
83 |
+
msg_logger = MessageLogger(opt, current_iter, tb_logger)
|
84 |
+
|
85 |
+
for epoch in range(opt['num_epochs']):
|
86 |
+
lr = model.update_learning_rate(epoch)
|
87 |
+
|
88 |
+
for _, batch_data in enumerate(train_loader):
|
89 |
+
data_time = time.time() - data_time
|
90 |
+
|
91 |
+
current_iter += 1
|
92 |
+
|
93 |
+
model.feed_data(batch_data)
|
94 |
+
model.optimize_parameters()
|
95 |
+
|
96 |
+
iter_time = time.time() - iter_time
|
97 |
+
if current_iter % opt['print_freq'] == 0:
|
98 |
+
log_vars = {'epoch': epoch, 'iter': current_iter}
|
99 |
+
log_vars.update({'lrs': [lr]})
|
100 |
+
log_vars.update({'time': iter_time, 'data_time': data_time})
|
101 |
+
log_vars.update(model.get_current_log())
|
102 |
+
msg_logger(log_vars)
|
103 |
+
|
104 |
+
data_time = time.time()
|
105 |
+
iter_time = time.time()
|
106 |
+
|
107 |
+
if epoch % opt['val_freq'] == 0:
|
108 |
+
save_dir = f'{opt["path"]["visualization"]}/valset/epoch_{epoch:03d}' # noqa
|
109 |
+
os.makedirs(save_dir, exist_ok=opt['debug'])
|
110 |
+
val_acc = model.inference(val_loader, save_dir)
|
111 |
+
|
112 |
+
save_dir = f'{opt["path"]["visualization"]}/testset/epoch_{epoch:03d}' # noqa
|
113 |
+
os.makedirs(save_dir, exist_ok=opt['debug'])
|
114 |
+
test_acc = model.inference(test_loader, save_dir)
|
115 |
+
|
116 |
+
logger.info(
|
117 |
+
f'Epoch: {epoch}, val_acc: {val_acc: .4f}, test_acc: {test_acc: .4f}.'
|
118 |
+
)
|
119 |
+
|
120 |
+
if test_acc > best_acc:
|
121 |
+
best_epoch = epoch
|
122 |
+
best_acc = test_acc
|
123 |
+
|
124 |
+
logger.info(f'Best epoch: {best_epoch}, '
|
125 |
+
f'Best test acc: {best_acc: .4f}.')
|
126 |
+
|
127 |
+
# save model
|
128 |
+
model.save_network(
|
129 |
+
f'{opt["path"]["models"]}/models_epoch{epoch}.pth')
|
130 |
+
|
131 |
+
|
132 |
+
if __name__ == '__main__':
|
133 |
+
main()
|
Text2Human/train_parsing_gen.py
ADDED
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import logging
|
3 |
+
import os
|
4 |
+
import os.path as osp
|
5 |
+
import random
|
6 |
+
import time
|
7 |
+
|
8 |
+
import torch
|
9 |
+
|
10 |
+
from data.parsing_generation_segm_attr_dataset import \
|
11 |
+
ParsingGenerationDeepFashionAttrSegmDataset
|
12 |
+
from models import create_model
|
13 |
+
from utils.logger import MessageLogger, get_root_logger, init_tb_logger
|
14 |
+
from utils.options import dict2str, dict_to_nonedict, parse
|
15 |
+
from utils.util import make_exp_dirs
|
16 |
+
|
17 |
+
|
18 |
+
def main():
|
19 |
+
# options
|
20 |
+
parser = argparse.ArgumentParser()
|
21 |
+
parser.add_argument('-opt', type=str, help='Path to option YAML file.')
|
22 |
+
args = parser.parse_args()
|
23 |
+
opt = parse(args.opt, is_train=True)
|
24 |
+
|
25 |
+
# mkdir and loggers
|
26 |
+
make_exp_dirs(opt)
|
27 |
+
log_file = osp.join(opt['path']['log'], f"train_{opt['name']}.log")
|
28 |
+
logger = get_root_logger(
|
29 |
+
logger_name='base', log_level=logging.INFO, log_file=log_file)
|
30 |
+
logger.info(dict2str(opt))
|
31 |
+
# initialize tensorboard logger
|
32 |
+
tb_logger = None
|
33 |
+
if opt['use_tb_logger'] and 'debug' not in opt['name']:
|
34 |
+
tb_logger = init_tb_logger(log_dir='./tb_logger/' + opt['name'])
|
35 |
+
|
36 |
+
# convert to NoneDict, which returns None for missing keys
|
37 |
+
opt = dict_to_nonedict(opt)
|
38 |
+
|
39 |
+
# set up data loader
|
40 |
+
train_dataset = ParsingGenerationDeepFashionAttrSegmDataset(
|
41 |
+
segm_dir=opt['segm_dir'],
|
42 |
+
pose_dir=opt['pose_dir'],
|
43 |
+
ann_file=opt['train_ann_file'])
|
44 |
+
train_loader = torch.utils.data.DataLoader(
|
45 |
+
dataset=train_dataset,
|
46 |
+
batch_size=opt['batch_size'],
|
47 |
+
shuffle=True,
|
48 |
+
num_workers=opt['num_workers'],
|
49 |
+
drop_last=True)
|
50 |
+
logger.info(f'Number of train set: {len(train_dataset)}.')
|
51 |
+
opt['max_iters'] = opt['num_epochs'] * len(
|
52 |
+
train_dataset) // opt['batch_size']
|
53 |
+
|
54 |
+
val_dataset = ParsingGenerationDeepFashionAttrSegmDataset(
|
55 |
+
segm_dir=opt['segm_dir'],
|
56 |
+
pose_dir=opt['pose_dir'],
|
57 |
+
ann_file=opt['val_ann_file'])
|
58 |
+
val_loader = torch.utils.data.DataLoader(
|
59 |
+
dataset=val_dataset,
|
60 |
+
batch_size=1,
|
61 |
+
shuffle=False,
|
62 |
+
num_workers=opt['num_workers'])
|
63 |
+
logger.info(f'Number of val set: {len(val_dataset)}.')
|
64 |
+
|
65 |
+
test_dataset = ParsingGenerationDeepFashionAttrSegmDataset(
|
66 |
+
segm_dir=opt['segm_dir'],
|
67 |
+
pose_dir=opt['pose_dir'],
|
68 |
+
ann_file=opt['test_ann_file'])
|
69 |
+
test_loader = torch.utils.data.DataLoader(
|
70 |
+
dataset=test_dataset,
|
71 |
+
batch_size=1,
|
72 |
+
shuffle=False,
|
73 |
+
num_workers=opt['num_workers'])
|
74 |
+
logger.info(f'Number of test set: {len(test_dataset)}.')
|
75 |
+
|
76 |
+
current_iter = 0
|
77 |
+
best_epoch = None
|
78 |
+
best_acc = 0
|
79 |
+
|
80 |
+
model = create_model(opt)
|
81 |
+
|
82 |
+
data_time, iter_time = 0, 0
|
83 |
+
current_iter = 0
|
84 |
+
|
85 |
+
# create message logger (formatted outputs)
|
86 |
+
msg_logger = MessageLogger(opt, current_iter, tb_logger)
|
87 |
+
|
88 |
+
for epoch in range(opt['num_epochs']):
|
89 |
+
lr = model.update_learning_rate(epoch)
|
90 |
+
|
91 |
+
for _, batch_data in enumerate(train_loader):
|
92 |
+
data_time = time.time() - data_time
|
93 |
+
|
94 |
+
current_iter += 1
|
95 |
+
|
96 |
+
model.feed_data(batch_data)
|
97 |
+
model.optimize_parameters()
|
98 |
+
|
99 |
+
iter_time = time.time() - iter_time
|
100 |
+
if current_iter % opt['print_freq'] == 0:
|
101 |
+
log_vars = {'epoch': epoch, 'iter': current_iter}
|
102 |
+
log_vars.update({'lrs': [lr]})
|
103 |
+
log_vars.update({'time': iter_time, 'data_time': data_time})
|
104 |
+
log_vars.update(model.get_current_log())
|
105 |
+
msg_logger(log_vars)
|
106 |
+
|
107 |
+
data_time = time.time()
|
108 |
+
iter_time = time.time()
|
109 |
+
|
110 |
+
if epoch % opt['val_freq'] == 0:
|
111 |
+
save_dir = f'{opt["path"]["visualization"]}/valset/epoch_{epoch:03d}'
|
112 |
+
os.makedirs(save_dir, exist_ok=opt['debug'])
|
113 |
+
val_acc = model.inference(val_loader, save_dir)
|
114 |
+
|
115 |
+
save_dir = f'{opt["path"]["visualization"]}/testset/epoch_{epoch:03d}'
|
116 |
+
os.makedirs(save_dir, exist_ok=opt['debug'])
|
117 |
+
test_acc = model.inference(test_loader, save_dir)
|
118 |
+
|
119 |
+
logger.info(f'Epoch: {epoch}, '
|
120 |
+
f'val_acc: {val_acc: .4f}, '
|
121 |
+
f'test_acc: {test_acc: .4f}.')
|
122 |
+
|
123 |
+
if test_acc > best_acc:
|
124 |
+
best_epoch = epoch
|
125 |
+
best_acc = test_acc
|
126 |
+
|
127 |
+
logger.info(f'Best epoch: {best_epoch}, '
|
128 |
+
f'Best test acc: {best_acc: .4f}.')
|
129 |
+
|
130 |
+
# save model
|
131 |
+
model.save_network(
|
132 |
+
f'{opt["path"]["models"]}/parsing_generation_epoch{epoch}.pth')
|
133 |
+
|
134 |
+
|
135 |
+
if __name__ == '__main__':
|
136 |
+
main()
|
Text2Human/train_parsing_token.py
ADDED
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import logging
|
3 |
+
import os
|
4 |
+
import os.path as osp
|
5 |
+
import random
|
6 |
+
import time
|
7 |
+
|
8 |
+
import torch
|
9 |
+
|
10 |
+
from data.mask_dataset import MaskDataset
|
11 |
+
from models import create_model
|
12 |
+
from utils.logger import MessageLogger, get_root_logger, init_tb_logger
|
13 |
+
from utils.options import dict2str, dict_to_nonedict, parse
|
14 |
+
from utils.util import make_exp_dirs
|
15 |
+
|
16 |
+
|
17 |
+
def main():
|
18 |
+
# options
|
19 |
+
parser = argparse.ArgumentParser()
|
20 |
+
parser.add_argument('-opt', type=str, help='Path to option YAML file.')
|
21 |
+
args = parser.parse_args()
|
22 |
+
opt = parse(args.opt, is_train=True)
|
23 |
+
|
24 |
+
# mkdir and loggers
|
25 |
+
make_exp_dirs(opt)
|
26 |
+
log_file = osp.join(opt['path']['log'], f"train_{opt['name']}.log")
|
27 |
+
logger = get_root_logger(
|
28 |
+
logger_name='base', log_level=logging.INFO, log_file=log_file)
|
29 |
+
logger.info(dict2str(opt))
|
30 |
+
# initialize tensorboard logger
|
31 |
+
tb_logger = None
|
32 |
+
if opt['use_tb_logger'] and 'debug' not in opt['name']:
|
33 |
+
tb_logger = init_tb_logger(log_dir='./tb_logger/' + opt['name'])
|
34 |
+
|
35 |
+
# convert to NoneDict, which returns None for missing keys
|
36 |
+
opt = dict_to_nonedict(opt)
|
37 |
+
|
38 |
+
# set up data loader
|
39 |
+
train_dataset = MaskDataset(
|
40 |
+
segm_dir=opt['segm_dir'], ann_dir=opt['train_ann_file'], xflip=True)
|
41 |
+
train_loader = torch.utils.data.DataLoader(
|
42 |
+
dataset=train_dataset,
|
43 |
+
batch_size=opt['batch_size'],
|
44 |
+
shuffle=True,
|
45 |
+
num_workers=opt['num_workers'],
|
46 |
+
persistent_workers=True,
|
47 |
+
drop_last=True)
|
48 |
+
logger.info(f'Number of train set: {len(train_dataset)}.')
|
49 |
+
opt['max_iters'] = opt['num_epochs'] * len(
|
50 |
+
train_dataset) // opt['batch_size']
|
51 |
+
|
52 |
+
val_dataset = MaskDataset(
|
53 |
+
segm_dir=opt['segm_dir'], ann_dir=opt['val_ann_file'])
|
54 |
+
val_loader = torch.utils.data.DataLoader(
|
55 |
+
dataset=val_dataset, batch_size=1, shuffle=False)
|
56 |
+
logger.info(f'Number of val set: {len(val_dataset)}.')
|
57 |
+
|
58 |
+
test_dataset = MaskDataset(
|
59 |
+
segm_dir=opt['segm_dir'], ann_dir=opt['test_ann_file'])
|
60 |
+
test_loader = torch.utils.data.DataLoader(
|
61 |
+
dataset=test_dataset, batch_size=1, shuffle=False)
|
62 |
+
logger.info(f'Number of test set: {len(test_dataset)}.')
|
63 |
+
|
64 |
+
current_iter = 0
|
65 |
+
best_epoch = None
|
66 |
+
best_loss = 100000
|
67 |
+
|
68 |
+
model = create_model(opt)
|
69 |
+
|
70 |
+
data_time, iter_time = 0, 0
|
71 |
+
current_iter = 0
|
72 |
+
|
73 |
+
# create message logger (formatted outputs)
|
74 |
+
msg_logger = MessageLogger(opt, current_iter, tb_logger)
|
75 |
+
|
76 |
+
for epoch in range(opt['num_epochs']):
|
77 |
+
lr = model.update_learning_rate(epoch)
|
78 |
+
|
79 |
+
for _, batch_data in enumerate(train_loader):
|
80 |
+
data_time = time.time() - data_time
|
81 |
+
|
82 |
+
current_iter += 1
|
83 |
+
|
84 |
+
model.optimize_parameters(batch_data, current_iter)
|
85 |
+
|
86 |
+
iter_time = time.time() - iter_time
|
87 |
+
if current_iter % opt['print_freq'] == 0:
|
88 |
+
log_vars = {'epoch': epoch, 'iter': current_iter}
|
89 |
+
log_vars.update({'lrs': [lr]})
|
90 |
+
log_vars.update({'time': iter_time, 'data_time': data_time})
|
91 |
+
log_vars.update(model.get_current_log())
|
92 |
+
msg_logger(log_vars)
|
93 |
+
|
94 |
+
data_time = time.time()
|
95 |
+
iter_time = time.time()
|
96 |
+
|
97 |
+
if epoch % opt['val_freq'] == 0:
|
98 |
+
save_dir = f'{opt["path"]["visualization"]}/valset/epoch_{epoch:03d}' # noqa
|
99 |
+
os.makedirs(save_dir, exist_ok=opt['debug'])
|
100 |
+
val_loss_total, _, _ = model.inference(val_loader, save_dir)
|
101 |
+
|
102 |
+
save_dir = f'{opt["path"]["visualization"]}/testset/epoch_{epoch:03d}' # noqa
|
103 |
+
os.makedirs(save_dir, exist_ok=opt['debug'])
|
104 |
+
test_loss_total, _, _ = model.inference(test_loader, save_dir)
|
105 |
+
|
106 |
+
logger.info(f'Epoch: {epoch}, '
|
107 |
+
f'val_loss_total: {val_loss_total}, '
|
108 |
+
f'test_loss_total: {test_loss_total}.')
|
109 |
+
|
110 |
+
if test_loss_total < best_loss:
|
111 |
+
best_epoch = epoch
|
112 |
+
best_loss = test_loss_total
|
113 |
+
|
114 |
+
logger.info(f'Best epoch: {best_epoch}, '
|
115 |
+
f'Best test loss: {best_loss: .4f}.')
|
116 |
+
|
117 |
+
# save model
|
118 |
+
model.save_network(f'{opt["path"]["models"]}/epoch{epoch}.pth')
|
119 |
+
|
120 |
+
|
121 |
+
if __name__ == '__main__':
|
122 |
+
main()
|
Text2Human/train_sampler.py
ADDED
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import logging
|
3 |
+
import os
|
4 |
+
import os.path as osp
|
5 |
+
import random
|
6 |
+
import time
|
7 |
+
|
8 |
+
import torch
|
9 |
+
|
10 |
+
from data.segm_attr_dataset import DeepFashionAttrSegmDataset
|
11 |
+
from models import create_model
|
12 |
+
from utils.logger import MessageLogger, get_root_logger, init_tb_logger
|
13 |
+
from utils.options import dict2str, dict_to_nonedict, parse
|
14 |
+
from utils.util import make_exp_dirs
|
15 |
+
|
16 |
+
|
17 |
+
def main():
|
18 |
+
# options
|
19 |
+
parser = argparse.ArgumentParser()
|
20 |
+
parser.add_argument('-opt', type=str, help='Path to option YAML file.')
|
21 |
+
args = parser.parse_args()
|
22 |
+
opt = parse(args.opt, is_train=True)
|
23 |
+
|
24 |
+
# mkdir and loggers
|
25 |
+
make_exp_dirs(opt)
|
26 |
+
log_file = osp.join(opt['path']['log'], f"train_{opt['name']}.log")
|
27 |
+
logger = get_root_logger(
|
28 |
+
logger_name='base', log_level=logging.INFO, log_file=log_file)
|
29 |
+
logger.info(dict2str(opt))
|
30 |
+
# initialize tensorboard logger
|
31 |
+
tb_logger = None
|
32 |
+
if opt['use_tb_logger'] and 'debug' not in opt['name']:
|
33 |
+
tb_logger = init_tb_logger(log_dir='./tb_logger/' + opt['name'])
|
34 |
+
|
35 |
+
# convert to NoneDict, which returns None for missing keys
|
36 |
+
opt = dict_to_nonedict(opt)
|
37 |
+
|
38 |
+
# set up data loader
|
39 |
+
train_dataset = DeepFashionAttrSegmDataset(
|
40 |
+
img_dir=opt['train_img_dir'],
|
41 |
+
segm_dir=opt['segm_dir'],
|
42 |
+
pose_dir=opt['pose_dir'],
|
43 |
+
ann_dir=opt['train_ann_file'],
|
44 |
+
xflip=True)
|
45 |
+
train_loader = torch.utils.data.DataLoader(
|
46 |
+
dataset=train_dataset,
|
47 |
+
batch_size=opt['batch_size'],
|
48 |
+
shuffle=True,
|
49 |
+
num_workers=opt['num_workers'],
|
50 |
+
persistent_workers=True,
|
51 |
+
drop_last=True)
|
52 |
+
logger.info(f'Number of train set: {len(train_dataset)}.')
|
53 |
+
opt['max_iters'] = opt['num_epochs'] * len(
|
54 |
+
train_dataset) // opt['batch_size']
|
55 |
+
|
56 |
+
val_dataset = DeepFashionAttrSegmDataset(
|
57 |
+
img_dir=opt['train_img_dir'],
|
58 |
+
segm_dir=opt['segm_dir'],
|
59 |
+
pose_dir=opt['pose_dir'],
|
60 |
+
ann_dir=opt['val_ann_file'])
|
61 |
+
val_loader = torch.utils.data.DataLoader(
|
62 |
+
dataset=val_dataset, batch_size=opt['batch_size'], shuffle=False)
|
63 |
+
logger.info(f'Number of val set: {len(val_dataset)}.')
|
64 |
+
|
65 |
+
test_dataset = DeepFashionAttrSegmDataset(
|
66 |
+
img_dir=opt['test_img_dir'],
|
67 |
+
segm_dir=opt['segm_dir'],
|
68 |
+
pose_dir=opt['pose_dir'],
|
69 |
+
ann_dir=opt['test_ann_file'])
|
70 |
+
test_loader = torch.utils.data.DataLoader(
|
71 |
+
dataset=test_dataset, batch_size=opt['batch_size'], shuffle=False)
|
72 |
+
logger.info(f'Number of test set: {len(test_dataset)}.')
|
73 |
+
|
74 |
+
current_iter = 0
|
75 |
+
|
76 |
+
model = create_model(opt)
|
77 |
+
|
78 |
+
data_time, iter_time = 0, 0
|
79 |
+
current_iter = 0
|
80 |
+
|
81 |
+
# create message logger (formatted outputs)
|
82 |
+
msg_logger = MessageLogger(opt, current_iter, tb_logger)
|
83 |
+
|
84 |
+
for epoch in range(opt['num_epochs']):
|
85 |
+
lr = model.update_learning_rate(epoch, current_iter)
|
86 |
+
|
87 |
+
for _, batch_data in enumerate(train_loader):
|
88 |
+
data_time = time.time() - data_time
|
89 |
+
|
90 |
+
current_iter += 1
|
91 |
+
|
92 |
+
model.feed_data(batch_data)
|
93 |
+
model.optimize_parameters()
|
94 |
+
|
95 |
+
iter_time = time.time() - iter_time
|
96 |
+
if current_iter % opt['print_freq'] == 0:
|
97 |
+
log_vars = {'epoch': epoch, 'iter': current_iter}
|
98 |
+
log_vars.update({'lrs': [lr]})
|
99 |
+
log_vars.update({'time': iter_time, 'data_time': data_time})
|
100 |
+
log_vars.update(model.get_current_log())
|
101 |
+
msg_logger(log_vars)
|
102 |
+
|
103 |
+
data_time = time.time()
|
104 |
+
iter_time = time.time()
|
105 |
+
|
106 |
+
if epoch % opt['val_freq'] == 0 and epoch != 0:
|
107 |
+
save_dir = f'{opt["path"]["visualization"]}/valset/epoch_{epoch:03d}' # noqa
|
108 |
+
os.makedirs(save_dir, exist_ok=opt['debug'])
|
109 |
+
model.inference(val_loader, save_dir)
|
110 |
+
|
111 |
+
save_dir = f'{opt["path"]["visualization"]}/testset/epoch_{epoch:03d}' # noqa
|
112 |
+
os.makedirs(save_dir, exist_ok=opt['debug'])
|
113 |
+
model.inference(test_loader, save_dir)
|
114 |
+
|
115 |
+
# save model
|
116 |
+
model.save_network(
|
117 |
+
model._denoise_fn,
|
118 |
+
f'{opt["path"]["models"]}/sampler_epoch{epoch}.pth')
|
119 |
+
|
120 |
+
|
121 |
+
if __name__ == '__main__':
|
122 |
+
main()
|
Text2Human/train_vqvae.py
ADDED
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import logging
|
3 |
+
import os
|
4 |
+
import os.path as osp
|
5 |
+
import random
|
6 |
+
import time
|
7 |
+
|
8 |
+
import torch
|
9 |
+
|
10 |
+
from data.segm_attr_dataset import DeepFashionAttrSegmDataset
|
11 |
+
from models import create_model
|
12 |
+
from utils.logger import MessageLogger, get_root_logger, init_tb_logger
|
13 |
+
from utils.options import dict2str, dict_to_nonedict, parse
|
14 |
+
from utils.util import make_exp_dirs
|
15 |
+
|
16 |
+
|
17 |
+
def main():
|
18 |
+
# options
|
19 |
+
parser = argparse.ArgumentParser()
|
20 |
+
parser.add_argument('-opt', type=str, help='Path to option YAML file.')
|
21 |
+
args = parser.parse_args()
|
22 |
+
opt = parse(args.opt, is_train=True)
|
23 |
+
|
24 |
+
# mkdir and loggers
|
25 |
+
make_exp_dirs(opt)
|
26 |
+
log_file = osp.join(opt['path']['log'], f"train_{opt['name']}.log")
|
27 |
+
logger = get_root_logger(
|
28 |
+
logger_name='base', log_level=logging.INFO, log_file=log_file)
|
29 |
+
logger.info(dict2str(opt))
|
30 |
+
# initialize tensorboard logger
|
31 |
+
tb_logger = None
|
32 |
+
if opt['use_tb_logger'] and 'debug' not in opt['name']:
|
33 |
+
tb_logger = init_tb_logger(log_dir='./tb_logger/' + opt['name'])
|
34 |
+
|
35 |
+
# convert to NoneDict, which returns None for missing keys
|
36 |
+
opt = dict_to_nonedict(opt)
|
37 |
+
|
38 |
+
# set up data loader
|
39 |
+
train_dataset = DeepFashionAttrSegmDataset(
|
40 |
+
img_dir=opt['train_img_dir'],
|
41 |
+
segm_dir=opt['segm_dir'],
|
42 |
+
pose_dir=opt['pose_dir'],
|
43 |
+
ann_dir=opt['train_ann_file'],
|
44 |
+
xflip=True)
|
45 |
+
train_loader = torch.utils.data.DataLoader(
|
46 |
+
dataset=train_dataset,
|
47 |
+
batch_size=opt['batch_size'],
|
48 |
+
shuffle=True,
|
49 |
+
num_workers=opt['num_workers'],
|
50 |
+
persistent_workers=True,
|
51 |
+
drop_last=True)
|
52 |
+
logger.info(f'Number of train set: {len(train_dataset)}.')
|
53 |
+
opt['max_iters'] = opt['num_epochs'] * len(
|
54 |
+
train_dataset) // opt['batch_size']
|
55 |
+
|
56 |
+
val_dataset = DeepFashionAttrSegmDataset(
|
57 |
+
img_dir=opt['train_img_dir'],
|
58 |
+
segm_dir=opt['segm_dir'],
|
59 |
+
pose_dir=opt['pose_dir'],
|
60 |
+
ann_dir=opt['val_ann_file'])
|
61 |
+
val_loader = torch.utils.data.DataLoader(
|
62 |
+
dataset=val_dataset, batch_size=1, shuffle=False)
|
63 |
+
logger.info(f'Number of val set: {len(val_dataset)}.')
|
64 |
+
|
65 |
+
test_dataset = DeepFashionAttrSegmDataset(
|
66 |
+
img_dir=opt['test_img_dir'],
|
67 |
+
segm_dir=opt['segm_dir'],
|
68 |
+
pose_dir=opt['pose_dir'],
|
69 |
+
ann_dir=opt['test_ann_file'])
|
70 |
+
test_loader = torch.utils.data.DataLoader(
|
71 |
+
dataset=test_dataset, batch_size=1, shuffle=False)
|
72 |
+
logger.info(f'Number of test set: {len(test_dataset)}.')
|
73 |
+
|
74 |
+
current_iter = 0
|
75 |
+
best_epoch = None
|
76 |
+
best_loss = 100000
|
77 |
+
|
78 |
+
model = create_model(opt)
|
79 |
+
|
80 |
+
data_time, iter_time = 0, 0
|
81 |
+
current_iter = 0
|
82 |
+
|
83 |
+
# create message logger (formatted outputs)
|
84 |
+
msg_logger = MessageLogger(opt, current_iter, tb_logger)
|
85 |
+
|
86 |
+
for epoch in range(opt['num_epochs']):
|
87 |
+
lr = model.update_learning_rate(epoch)
|
88 |
+
|
89 |
+
for _, batch_data in enumerate(train_loader):
|
90 |
+
data_time = time.time() - data_time
|
91 |
+
|
92 |
+
current_iter += 1
|
93 |
+
|
94 |
+
model.optimize_parameters(batch_data, current_iter)
|
95 |
+
|
96 |
+
iter_time = time.time() - iter_time
|
97 |
+
if current_iter % opt['print_freq'] == 0:
|
98 |
+
log_vars = {'epoch': epoch, 'iter': current_iter}
|
99 |
+
log_vars.update({'lrs': [lr]})
|
100 |
+
log_vars.update({'time': iter_time, 'data_time': data_time})
|
101 |
+
log_vars.update(model.get_current_log())
|
102 |
+
msg_logger(log_vars)
|
103 |
+
|
104 |
+
data_time = time.time()
|
105 |
+
iter_time = time.time()
|
106 |
+
|
107 |
+
if epoch % opt['val_freq'] == 0:
|
108 |
+
save_dir = f'{opt["path"]["visualization"]}/valset/epoch_{epoch:03d}' # noqa
|
109 |
+
os.makedirs(save_dir, exist_ok=opt['debug'])
|
110 |
+
val_loss_total = model.inference(val_loader, save_dir)
|
111 |
+
|
112 |
+
save_dir = f'{opt["path"]["visualization"]}/testset/epoch_{epoch:03d}' # noqa
|
113 |
+
os.makedirs(save_dir, exist_ok=opt['debug'])
|
114 |
+
test_loss_total = model.inference(test_loader, save_dir)
|
115 |
+
|
116 |
+
logger.info(f'Epoch: {epoch}, '
|
117 |
+
f'val_loss_total: {val_loss_total}, '
|
118 |
+
f'test_loss_total: {test_loss_total}.')
|
119 |
+
|
120 |
+
if test_loss_total < best_loss:
|
121 |
+
best_epoch = epoch
|
122 |
+
best_loss = test_loss_total
|
123 |
+
|
124 |
+
logger.info(f'Best epoch: {best_epoch}, '
|
125 |
+
f'Best test loss: {best_loss: .4f}.')
|
126 |
+
|
127 |
+
# save model
|
128 |
+
model.save_network(f'{opt["path"]["models"]}/epoch{epoch}.pth')
|
129 |
+
|
130 |
+
|
131 |
+
if __name__ == '__main__':
|
132 |
+
main()
|
Text2Human/ui/__init__.py
ADDED
File without changes
|
Text2Human/ui/mouse_event.py
ADDED
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
from PyQt5.QtCore import *
|
5 |
+
from PyQt5.QtGui import *
|
6 |
+
from PyQt5.QtWidgets import *
|
7 |
+
|
8 |
+
color_list = [
|
9 |
+
QColor(0, 0, 0),
|
10 |
+
QColor(255, 250, 250),
|
11 |
+
QColor(220, 220, 220),
|
12 |
+
QColor(250, 235, 215),
|
13 |
+
QColor(255, 250, 205),
|
14 |
+
QColor(211, 211, 211),
|
15 |
+
QColor(70, 130, 180),
|
16 |
+
QColor(127, 255, 212),
|
17 |
+
QColor(0, 100, 0),
|
18 |
+
QColor(50, 205, 50),
|
19 |
+
QColor(255, 255, 0),
|
20 |
+
QColor(245, 222, 179),
|
21 |
+
QColor(255, 140, 0),
|
22 |
+
QColor(255, 0, 0),
|
23 |
+
QColor(16, 78, 139),
|
24 |
+
QColor(144, 238, 144),
|
25 |
+
QColor(50, 205, 174),
|
26 |
+
QColor(50, 155, 250),
|
27 |
+
QColor(160, 140, 88),
|
28 |
+
QColor(213, 140, 88),
|
29 |
+
QColor(90, 140, 90),
|
30 |
+
QColor(185, 210, 205),
|
31 |
+
QColor(130, 165, 180),
|
32 |
+
QColor(225, 141, 151)
|
33 |
+
]
|
34 |
+
|
35 |
+
|
36 |
+
class GraphicsScene(QGraphicsScene):
|
37 |
+
|
38 |
+
def __init__(self, mode, size, parent=None):
|
39 |
+
QGraphicsScene.__init__(self, parent)
|
40 |
+
self.mode = mode
|
41 |
+
self.size = size
|
42 |
+
self.mouse_clicked = False
|
43 |
+
self.prev_pt = None
|
44 |
+
|
45 |
+
# self.masked_image = None
|
46 |
+
|
47 |
+
# save the points
|
48 |
+
self.mask_points = []
|
49 |
+
for i in range(len(color_list)):
|
50 |
+
self.mask_points.append([])
|
51 |
+
|
52 |
+
# save the size of points
|
53 |
+
self.size_points = []
|
54 |
+
for i in range(len(color_list)):
|
55 |
+
self.size_points.append([])
|
56 |
+
|
57 |
+
# save the history of edit
|
58 |
+
self.history = []
|
59 |
+
|
60 |
+
def reset(self):
|
61 |
+
# save the points
|
62 |
+
self.mask_points = []
|
63 |
+
for i in range(len(color_list)):
|
64 |
+
self.mask_points.append([])
|
65 |
+
# save the size of points
|
66 |
+
self.size_points = []
|
67 |
+
for i in range(len(color_list)):
|
68 |
+
self.size_points.append([])
|
69 |
+
# save the history of edit
|
70 |
+
self.history = []
|
71 |
+
|
72 |
+
self.mode = 0
|
73 |
+
self.prev_pt = None
|
74 |
+
|
75 |
+
def mousePressEvent(self, event):
|
76 |
+
self.mouse_clicked = True
|
77 |
+
|
78 |
+
def mouseReleaseEvent(self, event):
|
79 |
+
self.prev_pt = None
|
80 |
+
self.mouse_clicked = False
|
81 |
+
|
82 |
+
def mouseMoveEvent(self, event): # drawing
|
83 |
+
if self.mouse_clicked:
|
84 |
+
if self.prev_pt:
|
85 |
+
self.drawMask(self.prev_pt, event.scenePos(),
|
86 |
+
color_list[self.mode], self.size)
|
87 |
+
pts = {}
|
88 |
+
pts['prev'] = (int(self.prev_pt.x()), int(self.prev_pt.y()))
|
89 |
+
pts['curr'] = (int(event.scenePos().x()),
|
90 |
+
int(event.scenePos().y()))
|
91 |
+
|
92 |
+
self.size_points[self.mode].append(self.size)
|
93 |
+
self.mask_points[self.mode].append(pts)
|
94 |
+
self.history.append(self.mode)
|
95 |
+
self.prev_pt = event.scenePos()
|
96 |
+
else:
|
97 |
+
self.prev_pt = event.scenePos()
|
98 |
+
|
99 |
+
def drawMask(self, prev_pt, curr_pt, color, size):
|
100 |
+
lineItem = QGraphicsLineItem(QLineF(prev_pt, curr_pt))
|
101 |
+
lineItem.setPen(QPen(color, size, Qt.SolidLine)) # rect
|
102 |
+
self.addItem(lineItem)
|
103 |
+
|
104 |
+
def erase_prev_pt(self):
|
105 |
+
self.prev_pt = None
|
106 |
+
|
107 |
+
def reset_items(self):
|
108 |
+
for i in range(len(self.items())):
|
109 |
+
item = self.items()[0]
|
110 |
+
self.removeItem(item)
|
111 |
+
|
112 |
+
def undo(self):
|
113 |
+
if len(self.items()) > 1:
|
114 |
+
if len(self.items()) >= 9:
|
115 |
+
for i in range(8):
|
116 |
+
item = self.items()[0]
|
117 |
+
self.removeItem(item)
|
118 |
+
if self.history[-1] == self.mode:
|
119 |
+
self.mask_points[self.mode].pop()
|
120 |
+
self.size_points[self.mode].pop()
|
121 |
+
self.history.pop()
|
122 |
+
else:
|
123 |
+
for i in range(len(self.items()) - 1):
|
124 |
+
item = self.items()[0]
|
125 |
+
self.removeItem(item)
|
126 |
+
if self.history[-1] == self.mode:
|
127 |
+
self.mask_points[self.mode].pop()
|
128 |
+
self.size_points[self.mode].pop()
|
129 |
+
self.history.pop()
|
Text2Human/ui/ui.py
ADDED
@@ -0,0 +1,313 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from PyQt5 import QtCore, QtGui, QtWidgets
|
2 |
+
from PyQt5.QtCore import *
|
3 |
+
from PyQt5.QtGui import *
|
4 |
+
from PyQt5.QtWidgets import *
|
5 |
+
|
6 |
+
|
7 |
+
class Ui_Form(object):
|
8 |
+
|
9 |
+
def setupUi(self, Form):
|
10 |
+
Form.setObjectName("Form")
|
11 |
+
Form.resize(1250, 670)
|
12 |
+
|
13 |
+
self.pushButton_2 = QtWidgets.QPushButton(Form)
|
14 |
+
self.pushButton_2.setGeometry(QtCore.QRect(20, 60, 97, 27))
|
15 |
+
self.pushButton_2.setObjectName("pushButton_2")
|
16 |
+
|
17 |
+
self.pushButton_6 = QtWidgets.QPushButton(Form)
|
18 |
+
self.pushButton_6.setGeometry(QtCore.QRect(20, 100, 97, 27))
|
19 |
+
self.pushButton_6.setObjectName("pushButton_6")
|
20 |
+
|
21 |
+
# Generate Parsing
|
22 |
+
self.pushButton_0 = QtWidgets.QPushButton(Form)
|
23 |
+
self.pushButton_0.setGeometry(QtCore.QRect(126, 60, 150, 27))
|
24 |
+
self.pushButton_0.setObjectName("pushButton_0")
|
25 |
+
|
26 |
+
# Generate Human
|
27 |
+
self.pushButton_1 = QtWidgets.QPushButton(Form)
|
28 |
+
self.pushButton_1.setGeometry(QtCore.QRect(126, 100, 150, 27))
|
29 |
+
self.pushButton_1.setObjectName("pushButton_1")
|
30 |
+
|
31 |
+
# shape text box
|
32 |
+
self.label_heading_1 = QtWidgets.QLabel(Form)
|
33 |
+
self.label_heading_1.setText('Describe the shape.')
|
34 |
+
self.label_heading_1.setObjectName("label_heading_1")
|
35 |
+
self.label_heading_1.setGeometry(QtCore.QRect(320, 20, 200, 20))
|
36 |
+
|
37 |
+
self.message_box_1 = QtWidgets.QLineEdit(Form)
|
38 |
+
self.message_box_1.setGeometry(QtCore.QRect(320, 50, 256, 80))
|
39 |
+
self.message_box_1.setObjectName("message_box_1")
|
40 |
+
self.message_box_1.setAlignment(Qt.AlignTop)
|
41 |
+
|
42 |
+
# texture text box
|
43 |
+
self.label_heading_2 = QtWidgets.QLabel(Form)
|
44 |
+
self.label_heading_2.setText('Describe the textures.')
|
45 |
+
self.label_heading_2.setObjectName("label_heading_2")
|
46 |
+
self.label_heading_2.setGeometry(QtCore.QRect(620, 20, 200, 20))
|
47 |
+
|
48 |
+
self.message_box_2 = QtWidgets.QLineEdit(Form)
|
49 |
+
self.message_box_2.setGeometry(QtCore.QRect(620, 50, 256, 80))
|
50 |
+
self.message_box_2.setObjectName("message_box_2")
|
51 |
+
self.message_box_2.setAlignment(Qt.AlignTop)
|
52 |
+
|
53 |
+
# title icon
|
54 |
+
self.title_icon = QtWidgets.QLabel(Form)
|
55 |
+
self.title_icon.setGeometry(QtCore.QRect(30, 10, 200, 50))
|
56 |
+
self.title_icon.setPixmap(
|
57 |
+
QtGui.QPixmap('./ui/icons/icon_title.png').scaledToWidth(200))
|
58 |
+
|
59 |
+
# palette icon
|
60 |
+
self.palette_icon = QtWidgets.QLabel(Form)
|
61 |
+
self.palette_icon.setGeometry(QtCore.QRect(950, 10, 256, 128))
|
62 |
+
self.palette_icon.setPixmap(
|
63 |
+
QtGui.QPixmap('./ui/icons/icon_palette.png').scaledToWidth(256))
|
64 |
+
|
65 |
+
# top
|
66 |
+
self.pushButton_8 = QtWidgets.QPushButton(' top', Form)
|
67 |
+
self.pushButton_8.setGeometry(QtCore.QRect(940, 120, 120, 27))
|
68 |
+
self.pushButton_8.setObjectName("pushButton_8")
|
69 |
+
self.pushButton_8.setStyleSheet(
|
70 |
+
"text-align: left; padding-left: 10px;")
|
71 |
+
self.pushButton_8.setIcon(QIcon('./ui/color_blocks/class_top.png'))
|
72 |
+
# skin
|
73 |
+
self.pushButton_9 = QtWidgets.QPushButton(' skin', Form)
|
74 |
+
self.pushButton_9.setGeometry(QtCore.QRect(940, 165, 120, 27))
|
75 |
+
self.pushButton_9.setObjectName("pushButton_9")
|
76 |
+
self.pushButton_9.setStyleSheet(
|
77 |
+
"text-align: left; padding-left: 10px;")
|
78 |
+
self.pushButton_9.setIcon(QIcon('./ui/color_blocks/class_skin.png'))
|
79 |
+
# outer
|
80 |
+
self.pushButton_10 = QtWidgets.QPushButton(' outer', Form)
|
81 |
+
self.pushButton_10.setGeometry(QtCore.QRect(940, 210, 120, 27))
|
82 |
+
self.pushButton_10.setObjectName("pushButton_10")
|
83 |
+
self.pushButton_10.setStyleSheet(
|
84 |
+
"text-align: left; padding-left: 10px;")
|
85 |
+
self.pushButton_10.setIcon(QIcon('./ui/color_blocks/class_outer.png'))
|
86 |
+
# face
|
87 |
+
self.pushButton_11 = QtWidgets.QPushButton(' face', Form)
|
88 |
+
self.pushButton_11.setGeometry(QtCore.QRect(940, 255, 120, 27))
|
89 |
+
self.pushButton_11.setObjectName("pushButton_11")
|
90 |
+
self.pushButton_11.setStyleSheet(
|
91 |
+
"text-align: left; padding-left: 10px;")
|
92 |
+
self.pushButton_11.setIcon(QIcon('./ui/color_blocks/class_face.png'))
|
93 |
+
# skirt
|
94 |
+
self.pushButton_12 = QtWidgets.QPushButton(' skirt', Form)
|
95 |
+
self.pushButton_12.setGeometry(QtCore.QRect(940, 300, 120, 27))
|
96 |
+
self.pushButton_12.setObjectName("pushButton_12")
|
97 |
+
self.pushButton_12.setStyleSheet(
|
98 |
+
"text-align: left; padding-left: 10px;")
|
99 |
+
self.pushButton_12.setIcon(QIcon('./ui/color_blocks/class_skirt.png'))
|
100 |
+
# hair
|
101 |
+
self.pushButton_13 = QtWidgets.QPushButton(' hair', Form)
|
102 |
+
self.pushButton_13.setGeometry(QtCore.QRect(940, 345, 120, 27))
|
103 |
+
self.pushButton_13.setObjectName("pushButton_13")
|
104 |
+
self.pushButton_13.setStyleSheet(
|
105 |
+
"text-align: left; padding-left: 10px;")
|
106 |
+
self.pushButton_13.setIcon(QIcon('./ui/color_blocks/class_hair.png'))
|
107 |
+
# dress
|
108 |
+
self.pushButton_14 = QtWidgets.QPushButton(' dress', Form)
|
109 |
+
self.pushButton_14.setGeometry(QtCore.QRect(940, 390, 120, 27))
|
110 |
+
self.pushButton_14.setObjectName("pushButton_14")
|
111 |
+
self.pushButton_14.setStyleSheet(
|
112 |
+
"text-align: left; padding-left: 10px;")
|
113 |
+
self.pushButton_14.setIcon(QIcon('./ui/color_blocks/class_dress.png'))
|
114 |
+
# headwear
|
115 |
+
self.pushButton_15 = QtWidgets.QPushButton(' headwear', Form)
|
116 |
+
self.pushButton_15.setGeometry(QtCore.QRect(940, 435, 120, 27))
|
117 |
+
self.pushButton_15.setObjectName("pushButton_15")
|
118 |
+
self.pushButton_15.setStyleSheet(
|
119 |
+
"text-align: left; padding-left: 10px;")
|
120 |
+
self.pushButton_15.setIcon(
|
121 |
+
QIcon('./ui/color_blocks/class_headwear.png'))
|
122 |
+
# pants
|
123 |
+
self.pushButton_16 = QtWidgets.QPushButton(' pants', Form)
|
124 |
+
self.pushButton_16.setGeometry(QtCore.QRect(940, 480, 120, 27))
|
125 |
+
self.pushButton_16.setObjectName("pushButton_16")
|
126 |
+
self.pushButton_16.setStyleSheet(
|
127 |
+
"text-align: left; padding-left: 10px;")
|
128 |
+
self.pushButton_16.setIcon(QIcon('./ui/color_blocks/class_pants.png'))
|
129 |
+
# eyeglasses
|
130 |
+
self.pushButton_17 = QtWidgets.QPushButton(' eyeglass', Form)
|
131 |
+
self.pushButton_17.setGeometry(QtCore.QRect(940, 525, 120, 27))
|
132 |
+
self.pushButton_17.setObjectName("pushButton_17")
|
133 |
+
self.pushButton_17.setStyleSheet(
|
134 |
+
"text-align: left; padding-left: 10px;")
|
135 |
+
self.pushButton_17.setIcon(
|
136 |
+
QIcon('./ui/color_blocks/class_eyeglass.png'))
|
137 |
+
# rompers
|
138 |
+
self.pushButton_18 = QtWidgets.QPushButton(' rompers', Form)
|
139 |
+
self.pushButton_18.setGeometry(QtCore.QRect(940, 570, 120, 27))
|
140 |
+
self.pushButton_18.setObjectName("pushButton_18")
|
141 |
+
self.pushButton_18.setStyleSheet(
|
142 |
+
"text-align: left; padding-left: 10px;")
|
143 |
+
self.pushButton_18.setIcon(
|
144 |
+
QIcon('./ui/color_blocks/class_rompers.png'))
|
145 |
+
# footwear
|
146 |
+
self.pushButton_19 = QtWidgets.QPushButton(' footwear', Form)
|
147 |
+
self.pushButton_19.setGeometry(QtCore.QRect(940, 615, 120, 27))
|
148 |
+
self.pushButton_19.setObjectName("pushButton_19")
|
149 |
+
self.pushButton_19.setStyleSheet(
|
150 |
+
"text-align: left; padding-left: 10px;")
|
151 |
+
self.pushButton_19.setIcon(
|
152 |
+
QIcon('./ui/color_blocks/class_footwear.png'))
|
153 |
+
|
154 |
+
# leggings
|
155 |
+
self.pushButton_20 = QtWidgets.QPushButton(' leggings', Form)
|
156 |
+
self.pushButton_20.setGeometry(QtCore.QRect(1100, 120, 120, 27))
|
157 |
+
self.pushButton_20.setObjectName("pushButton_10")
|
158 |
+
self.pushButton_20.setStyleSheet(
|
159 |
+
"text-align: left; padding-left: 10px;")
|
160 |
+
self.pushButton_20.setIcon(
|
161 |
+
QIcon('./ui/color_blocks/class_leggings.png'))
|
162 |
+
|
163 |
+
# ring
|
164 |
+
self.pushButton_21 = QtWidgets.QPushButton(' ring', Form)
|
165 |
+
self.pushButton_21.setGeometry(QtCore.QRect(1100, 165, 120, 27))
|
166 |
+
self.pushButton_21.setObjectName("pushButton_2`0`")
|
167 |
+
self.pushButton_21.setStyleSheet(
|
168 |
+
"text-align: left; padding-left: 10px;")
|
169 |
+
self.pushButton_21.setIcon(QIcon('./ui/color_blocks/class_ring.png'))
|
170 |
+
|
171 |
+
# belt
|
172 |
+
self.pushButton_22 = QtWidgets.QPushButton(' belt', Form)
|
173 |
+
self.pushButton_22.setGeometry(QtCore.QRect(1100, 210, 120, 27))
|
174 |
+
self.pushButton_22.setObjectName("pushButton_2`0`")
|
175 |
+
self.pushButton_22.setStyleSheet(
|
176 |
+
"text-align: left; padding-left: 10px;")
|
177 |
+
self.pushButton_22.setIcon(QIcon('./ui/color_blocks/class_belt.png'))
|
178 |
+
|
179 |
+
# neckwear
|
180 |
+
self.pushButton_23 = QtWidgets.QPushButton(' neckwear', Form)
|
181 |
+
self.pushButton_23.setGeometry(QtCore.QRect(1100, 255, 120, 27))
|
182 |
+
self.pushButton_23.setObjectName("pushButton_2`0`")
|
183 |
+
self.pushButton_23.setStyleSheet(
|
184 |
+
"text-align: left; padding-left: 10px;")
|
185 |
+
self.pushButton_23.setIcon(
|
186 |
+
QIcon('./ui/color_blocks/class_neckwear.png'))
|
187 |
+
|
188 |
+
# wrist
|
189 |
+
self.pushButton_24 = QtWidgets.QPushButton(' wrist', Form)
|
190 |
+
self.pushButton_24.setGeometry(QtCore.QRect(1100, 300, 120, 27))
|
191 |
+
self.pushButton_24.setObjectName("pushButton_2`0`")
|
192 |
+
self.pushButton_24.setStyleSheet(
|
193 |
+
"text-align: left; padding-left: 10px;")
|
194 |
+
self.pushButton_24.setIcon(QIcon('./ui/color_blocks/class_wrist.png'))
|
195 |
+
|
196 |
+
# socks
|
197 |
+
self.pushButton_25 = QtWidgets.QPushButton(' socks', Form)
|
198 |
+
self.pushButton_25.setGeometry(QtCore.QRect(1100, 345, 120, 27))
|
199 |
+
self.pushButton_25.setObjectName("pushButton_2`0`")
|
200 |
+
self.pushButton_25.setStyleSheet(
|
201 |
+
"text-align: left; padding-left: 10px;")
|
202 |
+
self.pushButton_25.setIcon(QIcon('./ui/color_blocks/class_socks.png'))
|
203 |
+
|
204 |
+
# tie
|
205 |
+
self.pushButton_26 = QtWidgets.QPushButton(' tie', Form)
|
206 |
+
self.pushButton_26.setGeometry(QtCore.QRect(1100, 390, 120, 27))
|
207 |
+
self.pushButton_26.setObjectName("pushButton_2`0`")
|
208 |
+
self.pushButton_26.setStyleSheet(
|
209 |
+
"text-align: left; padding-left: 10px;")
|
210 |
+
self.pushButton_26.setIcon(QIcon('./ui/color_blocks/class_tie.png'))
|
211 |
+
|
212 |
+
# earstuds
|
213 |
+
self.pushButton_27 = QtWidgets.QPushButton(' necklace', Form)
|
214 |
+
self.pushButton_27.setGeometry(QtCore.QRect(1100, 435, 120, 27))
|
215 |
+
self.pushButton_27.setObjectName("pushButton_2`0`")
|
216 |
+
self.pushButton_27.setStyleSheet(
|
217 |
+
"text-align: left; padding-left: 10px;")
|
218 |
+
self.pushButton_27.setIcon(
|
219 |
+
QIcon('./ui/color_blocks/class_necklace.png'))
|
220 |
+
|
221 |
+
# necklace
|
222 |
+
self.pushButton_28 = QtWidgets.QPushButton(' earstuds', Form)
|
223 |
+
self.pushButton_28.setGeometry(QtCore.QRect(1100, 480, 120, 27))
|
224 |
+
self.pushButton_28.setObjectName("pushButton_2`0`")
|
225 |
+
self.pushButton_28.setStyleSheet(
|
226 |
+
"text-align: left; padding-left: 10px;")
|
227 |
+
self.pushButton_28.setIcon(
|
228 |
+
QIcon('./ui/color_blocks/class_earstuds.png'))
|
229 |
+
|
230 |
+
# bag
|
231 |
+
self.pushButton_29 = QtWidgets.QPushButton(' bag', Form)
|
232 |
+
self.pushButton_29.setGeometry(QtCore.QRect(1100, 525, 120, 27))
|
233 |
+
self.pushButton_29.setObjectName("pushButton_2`0`")
|
234 |
+
self.pushButton_29.setStyleSheet(
|
235 |
+
"text-align: left; padding-left: 10px;")
|
236 |
+
self.pushButton_29.setIcon(QIcon('./ui/color_blocks/class_bag.png'))
|
237 |
+
|
238 |
+
# glove
|
239 |
+
self.pushButton_30 = QtWidgets.QPushButton(' glove', Form)
|
240 |
+
self.pushButton_30.setGeometry(QtCore.QRect(1100, 570, 120, 27))
|
241 |
+
self.pushButton_30.setObjectName("pushButton_2`0`")
|
242 |
+
self.pushButton_30.setStyleSheet(
|
243 |
+
"text-align: left; padding-left: 10px;")
|
244 |
+
self.pushButton_30.setIcon(QIcon('./ui/color_blocks/class_glove.png'))
|
245 |
+
|
246 |
+
# background
|
247 |
+
self.pushButton_31 = QtWidgets.QPushButton(' background', Form)
|
248 |
+
self.pushButton_31.setGeometry(QtCore.QRect(1100, 615, 120, 27))
|
249 |
+
self.pushButton_31.setObjectName("pushButton_2`0`")
|
250 |
+
self.pushButton_31.setStyleSheet(
|
251 |
+
"text-align: left; padding-left: 10px;")
|
252 |
+
self.pushButton_31.setIcon(QIcon('./ui/color_blocks/class_bg.png'))
|
253 |
+
|
254 |
+
self.graphicsView = QtWidgets.QGraphicsView(Form)
|
255 |
+
self.graphicsView.setGeometry(QtCore.QRect(20, 140, 256, 512))
|
256 |
+
self.graphicsView.setObjectName("graphicsView")
|
257 |
+
self.graphicsView_2 = QtWidgets.QGraphicsView(Form)
|
258 |
+
self.graphicsView_2.setGeometry(QtCore.QRect(320, 140, 256, 512))
|
259 |
+
self.graphicsView_2.setObjectName("graphicsView_2")
|
260 |
+
self.graphicsView_3 = QtWidgets.QGraphicsView(Form)
|
261 |
+
self.graphicsView_3.setGeometry(QtCore.QRect(620, 140, 256, 512))
|
262 |
+
self.graphicsView_3.setObjectName("graphicsView_3")
|
263 |
+
|
264 |
+
self.retranslateUi(Form)
|
265 |
+
self.pushButton_2.clicked.connect(Form.open_densepose)
|
266 |
+
self.pushButton_6.clicked.connect(Form.save_img)
|
267 |
+
self.pushButton_8.clicked.connect(Form.top_mode)
|
268 |
+
self.pushButton_9.clicked.connect(Form.skin_mode)
|
269 |
+
self.pushButton_10.clicked.connect(Form.outer_mode)
|
270 |
+
self.pushButton_11.clicked.connect(Form.face_mode)
|
271 |
+
self.pushButton_12.clicked.connect(Form.skirt_mode)
|
272 |
+
self.pushButton_13.clicked.connect(Form.hair_mode)
|
273 |
+
self.pushButton_14.clicked.connect(Form.dress_mode)
|
274 |
+
self.pushButton_15.clicked.connect(Form.headwear_mode)
|
275 |
+
self.pushButton_16.clicked.connect(Form.pants_mode)
|
276 |
+
self.pushButton_17.clicked.connect(Form.eyeglass_mode)
|
277 |
+
self.pushButton_18.clicked.connect(Form.rompers_mode)
|
278 |
+
self.pushButton_19.clicked.connect(Form.footwear_mode)
|
279 |
+
self.pushButton_20.clicked.connect(Form.leggings_mode)
|
280 |
+
self.pushButton_21.clicked.connect(Form.ring_mode)
|
281 |
+
self.pushButton_22.clicked.connect(Form.belt_mode)
|
282 |
+
self.pushButton_23.clicked.connect(Form.neckwear_mode)
|
283 |
+
self.pushButton_24.clicked.connect(Form.wrist_mode)
|
284 |
+
self.pushButton_25.clicked.connect(Form.socks_mode)
|
285 |
+
self.pushButton_26.clicked.connect(Form.tie_mode)
|
286 |
+
self.pushButton_27.clicked.connect(Form.earstuds_mode)
|
287 |
+
self.pushButton_28.clicked.connect(Form.necklace_mode)
|
288 |
+
self.pushButton_29.clicked.connect(Form.bag_mode)
|
289 |
+
self.pushButton_30.clicked.connect(Form.glove_mode)
|
290 |
+
self.pushButton_31.clicked.connect(Form.background_mode)
|
291 |
+
self.pushButton_0.clicked.connect(Form.generate_parsing)
|
292 |
+
self.pushButton_1.clicked.connect(Form.generate_human)
|
293 |
+
|
294 |
+
QtCore.QMetaObject.connectSlotsByName(Form)
|
295 |
+
|
296 |
+
def retranslateUi(self, Form):
|
297 |
+
_translate = QtCore.QCoreApplication.translate
|
298 |
+
Form.setWindowTitle(_translate("Form", "Text2Human"))
|
299 |
+
self.pushButton_2.setText(_translate("Form", "Load Pose"))
|
300 |
+
self.pushButton_6.setText(_translate("Form", "Save Image"))
|
301 |
+
|
302 |
+
self.pushButton_0.setText(_translate("Form", "Generate Parsing"))
|
303 |
+
self.pushButton_1.setText(_translate("Form", "Generate Human"))
|
304 |
+
|
305 |
+
|
306 |
+
if __name__ == "__main__":
|
307 |
+
import sys
|
308 |
+
app = QtWidgets.QApplication(sys.argv)
|
309 |
+
Form = QtWidgets.QWidget()
|
310 |
+
ui = Ui_Form()
|
311 |
+
ui.setupUi(Form)
|
312 |
+
Form.show()
|
313 |
+
sys.exit(app.exec_())
|
Text2Human/ui_demo.py
ADDED
@@ -0,0 +1,285 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
|
3 |
+
import cv2
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
from PIL import Image
|
7 |
+
from PyQt5.QtCore import *
|
8 |
+
from PyQt5.QtGui import *
|
9 |
+
from PyQt5.QtWidgets import *
|
10 |
+
|
11 |
+
from models.sample_model import SampleFromPoseModel
|
12 |
+
from ui.mouse_event import GraphicsScene
|
13 |
+
from ui.ui import Ui_Form
|
14 |
+
from utils.language_utils import (generate_shape_attributes,
|
15 |
+
generate_texture_attributes)
|
16 |
+
from utils.options import dict_to_nonedict, parse
|
17 |
+
|
18 |
+
color_list = [(0, 0, 0), (255, 250, 250), (220, 220, 220), (250, 235, 215),
|
19 |
+
(255, 250, 205), (211, 211, 211), (70, 130, 180),
|
20 |
+
(127, 255, 212), (0, 100, 0), (50, 205, 50), (255, 255, 0),
|
21 |
+
(245, 222, 179), (255, 140, 0), (255, 0, 0), (16, 78, 139),
|
22 |
+
(144, 238, 144), (50, 205, 174), (50, 155, 250), (160, 140, 88),
|
23 |
+
(213, 140, 88), (90, 140, 90), (185, 210, 205), (130, 165, 180),
|
24 |
+
(225, 141, 151)]
|
25 |
+
|
26 |
+
|
27 |
+
class Ex(QWidget, Ui_Form):
|
28 |
+
|
29 |
+
def __init__(self, opt):
|
30 |
+
super(Ex, self).__init__()
|
31 |
+
self.setupUi(self)
|
32 |
+
self.show()
|
33 |
+
|
34 |
+
self.output_img = None
|
35 |
+
|
36 |
+
self.mat_img = None
|
37 |
+
|
38 |
+
self.mode = 0
|
39 |
+
self.size = 6
|
40 |
+
self.mask = None
|
41 |
+
self.mask_m = None
|
42 |
+
self.img = None
|
43 |
+
|
44 |
+
# about UI
|
45 |
+
self.mouse_clicked = False
|
46 |
+
self.scene = QGraphicsScene()
|
47 |
+
self.graphicsView.setScene(self.scene)
|
48 |
+
self.graphicsView.setAlignment(Qt.AlignTop | Qt.AlignLeft)
|
49 |
+
self.graphicsView.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOff)
|
50 |
+
self.graphicsView.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff)
|
51 |
+
|
52 |
+
self.ref_scene = GraphicsScene(self.mode, self.size)
|
53 |
+
self.graphicsView_2.setScene(self.ref_scene)
|
54 |
+
self.graphicsView_2.setAlignment(Qt.AlignTop | Qt.AlignLeft)
|
55 |
+
self.graphicsView_2.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOff)
|
56 |
+
self.graphicsView_2.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff)
|
57 |
+
|
58 |
+
self.result_scene = QGraphicsScene()
|
59 |
+
self.graphicsView_3.setScene(self.result_scene)
|
60 |
+
self.graphicsView_3.setAlignment(Qt.AlignTop | Qt.AlignLeft)
|
61 |
+
self.graphicsView_3.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOff)
|
62 |
+
self.graphicsView_3.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff)
|
63 |
+
|
64 |
+
self.dlg = QColorDialog(self.graphicsView)
|
65 |
+
self.color = None
|
66 |
+
|
67 |
+
self.sample_model = SampleFromPoseModel(opt)
|
68 |
+
|
69 |
+
def open_densepose(self):
|
70 |
+
fileName, _ = QFileDialog.getOpenFileName(self, "Open File",
|
71 |
+
QDir.currentPath())
|
72 |
+
if fileName:
|
73 |
+
image = QPixmap(fileName)
|
74 |
+
mat_img = Image.open(fileName)
|
75 |
+
self.pose_img = mat_img.copy()
|
76 |
+
if image.isNull():
|
77 |
+
QMessageBox.information(self, "Image Viewer",
|
78 |
+
"Cannot load %s." % fileName)
|
79 |
+
return
|
80 |
+
image = image.scaled(self.graphicsView.size(),
|
81 |
+
Qt.IgnoreAspectRatio)
|
82 |
+
|
83 |
+
if len(self.scene.items()) > 0:
|
84 |
+
self.scene.removeItem(self.scene.items()[-1])
|
85 |
+
self.scene.addPixmap(image)
|
86 |
+
|
87 |
+
self.ref_scene.clear()
|
88 |
+
self.result_scene.clear()
|
89 |
+
|
90 |
+
# load pose to model
|
91 |
+
self.pose_img = np.array(
|
92 |
+
self.pose_img.resize(
|
93 |
+
size=(256, 512),
|
94 |
+
resample=Image.LANCZOS))[:, :, 2:].transpose(
|
95 |
+
2, 0, 1).astype(np.float32)
|
96 |
+
self.pose_img = self.pose_img / 12. - 1
|
97 |
+
|
98 |
+
self.pose_img = torch.from_numpy(self.pose_img).unsqueeze(1)
|
99 |
+
|
100 |
+
self.sample_model.feed_pose_data(self.pose_img)
|
101 |
+
|
102 |
+
def generate_parsing(self):
|
103 |
+
self.ref_scene.reset_items()
|
104 |
+
self.ref_scene.reset()
|
105 |
+
|
106 |
+
shape_texts = self.message_box_1.text()
|
107 |
+
|
108 |
+
shape_attributes = generate_shape_attributes(shape_texts)
|
109 |
+
shape_attributes = torch.LongTensor(shape_attributes).unsqueeze(0)
|
110 |
+
self.sample_model.feed_shape_attributes(shape_attributes)
|
111 |
+
|
112 |
+
self.sample_model.generate_parsing_map()
|
113 |
+
self.sample_model.generate_quantized_segm()
|
114 |
+
|
115 |
+
self.colored_segm = self.sample_model.palette_result(
|
116 |
+
self.sample_model.segm[0].cpu())
|
117 |
+
|
118 |
+
self.mask_m = cv2.cvtColor(
|
119 |
+
cv2.cvtColor(self.colored_segm, cv2.COLOR_RGB2BGR),
|
120 |
+
cv2.COLOR_BGR2RGB)
|
121 |
+
|
122 |
+
qim = QImage(self.colored_segm.data.tobytes(),
|
123 |
+
self.colored_segm.shape[1], self.colored_segm.shape[0],
|
124 |
+
QImage.Format_RGB888)
|
125 |
+
|
126 |
+
image = QPixmap.fromImage(qim)
|
127 |
+
|
128 |
+
image = image.scaled(self.graphicsView.size(), Qt.IgnoreAspectRatio)
|
129 |
+
|
130 |
+
if len(self.ref_scene.items()) > 0:
|
131 |
+
self.ref_scene.removeItem(self.ref_scene.items()[-1])
|
132 |
+
self.ref_scene.addPixmap(image)
|
133 |
+
|
134 |
+
self.result_scene.clear()
|
135 |
+
|
136 |
+
def generate_human(self):
|
137 |
+
for i in range(24):
|
138 |
+
self.mask_m = self.make_mask(self.mask_m,
|
139 |
+
self.ref_scene.mask_points[i],
|
140 |
+
self.ref_scene.size_points[i],
|
141 |
+
color_list[i])
|
142 |
+
|
143 |
+
seg_map = np.full(self.mask_m.shape[:-1], -1)
|
144 |
+
|
145 |
+
# convert rgb to num
|
146 |
+
for index, color in enumerate(color_list):
|
147 |
+
seg_map[np.sum(self.mask_m == color, axis=2) == 3] = index
|
148 |
+
assert (seg_map != -1).all()
|
149 |
+
|
150 |
+
self.sample_model.segm = torch.from_numpy(seg_map).unsqueeze(
|
151 |
+
0).unsqueeze(0).to(self.sample_model.device)
|
152 |
+
self.sample_model.generate_quantized_segm()
|
153 |
+
|
154 |
+
texture_texts = self.message_box_2.text()
|
155 |
+
texture_attributes = generate_texture_attributes(texture_texts)
|
156 |
+
|
157 |
+
texture_attributes = torch.LongTensor(texture_attributes)
|
158 |
+
|
159 |
+
self.sample_model.feed_texture_attributes(texture_attributes)
|
160 |
+
|
161 |
+
self.sample_model.generate_texture_map()
|
162 |
+
result = self.sample_model.sample_and_refine()
|
163 |
+
result = result.permute(0, 2, 3, 1)
|
164 |
+
result = result.detach().cpu().numpy()
|
165 |
+
result = result * 255
|
166 |
+
|
167 |
+
result = np.asarray(result[0, :, :, :], dtype=np.uint8)
|
168 |
+
|
169 |
+
self.output_img = result
|
170 |
+
|
171 |
+
qim = QImage(result.data.tobytes(), result.shape[1], result.shape[0],
|
172 |
+
QImage.Format_RGB888)
|
173 |
+
image = QPixmap.fromImage(qim)
|
174 |
+
|
175 |
+
image = image.scaled(self.graphicsView.size(), Qt.IgnoreAspectRatio)
|
176 |
+
|
177 |
+
if len(self.result_scene.items()) > 0:
|
178 |
+
self.result_scene.removeItem(self.result_scene.items()[-1])
|
179 |
+
self.result_scene.addPixmap(image)
|
180 |
+
|
181 |
+
def top_mode(self):
|
182 |
+
self.ref_scene.mode = 1
|
183 |
+
|
184 |
+
def skin_mode(self):
|
185 |
+
self.ref_scene.mode = 15
|
186 |
+
|
187 |
+
def outer_mode(self):
|
188 |
+
self.ref_scene.mode = 2
|
189 |
+
|
190 |
+
def face_mode(self):
|
191 |
+
self.ref_scene.mode = 14
|
192 |
+
|
193 |
+
def skirt_mode(self):
|
194 |
+
self.ref_scene.mode = 3
|
195 |
+
|
196 |
+
def hair_mode(self):
|
197 |
+
self.ref_scene.mode = 13
|
198 |
+
|
199 |
+
def dress_mode(self):
|
200 |
+
self.ref_scene.mode = 4
|
201 |
+
|
202 |
+
def headwear_mode(self):
|
203 |
+
self.ref_scene.mode = 7
|
204 |
+
|
205 |
+
def pants_mode(self):
|
206 |
+
self.ref_scene.mode = 5
|
207 |
+
|
208 |
+
def eyeglass_mode(self):
|
209 |
+
self.ref_scene.mode = 8
|
210 |
+
|
211 |
+
def rompers_mode(self):
|
212 |
+
self.ref_scene.mode = 21
|
213 |
+
|
214 |
+
def footwear_mode(self):
|
215 |
+
self.ref_scene.mode = 11
|
216 |
+
|
217 |
+
def leggings_mode(self):
|
218 |
+
self.ref_scene.mode = 6
|
219 |
+
|
220 |
+
def ring_mode(self):
|
221 |
+
self.ref_scene.mode = 16
|
222 |
+
|
223 |
+
def belt_mode(self):
|
224 |
+
self.ref_scene.mode = 10
|
225 |
+
|
226 |
+
def neckwear_mode(self):
|
227 |
+
self.ref_scene.mode = 9
|
228 |
+
|
229 |
+
def wrist_mode(self):
|
230 |
+
self.ref_scene.mode = 17
|
231 |
+
|
232 |
+
def socks_mode(self):
|
233 |
+
self.ref_scene.mode = 18
|
234 |
+
|
235 |
+
def tie_mode(self):
|
236 |
+
self.ref_scene.mode = 23
|
237 |
+
|
238 |
+
def earstuds_mode(self):
|
239 |
+
self.ref_scene.mode = 22
|
240 |
+
|
241 |
+
def necklace_mode(self):
|
242 |
+
self.ref_scene.mode = 20
|
243 |
+
|
244 |
+
def bag_mode(self):
|
245 |
+
self.ref_scene.mode = 12
|
246 |
+
|
247 |
+
def glove_mode(self):
|
248 |
+
self.ref_scene.mode = 19
|
249 |
+
|
250 |
+
def background_mode(self):
|
251 |
+
self.ref_scene.mode = 0
|
252 |
+
|
253 |
+
def make_mask(self, mask, pts, sizes, color):
|
254 |
+
if len(pts) > 0:
|
255 |
+
for idx, pt in enumerate(pts):
|
256 |
+
cv2.line(mask, pt['prev'], pt['curr'], color, sizes[idx])
|
257 |
+
return mask
|
258 |
+
|
259 |
+
def save_img(self):
|
260 |
+
if type(self.output_img):
|
261 |
+
fileName, _ = QFileDialog.getSaveFileName(self, "Save File",
|
262 |
+
QDir.currentPath())
|
263 |
+
cv2.imwrite(fileName + '.png', self.output_img[:, :, ::-1])
|
264 |
+
|
265 |
+
def undo(self):
|
266 |
+
self.scene.undo()
|
267 |
+
|
268 |
+
def clear(self):
|
269 |
+
|
270 |
+
self.ref_scene.reset_items()
|
271 |
+
self.ref_scene.reset()
|
272 |
+
|
273 |
+
self.ref_scene.clear()
|
274 |
+
|
275 |
+
self.result_scene.clear()
|
276 |
+
|
277 |
+
|
278 |
+
if __name__ == '__main__':
|
279 |
+
|
280 |
+
app = QApplication(sys.argv)
|
281 |
+
opt = './configs/sample_from_pose.yml'
|
282 |
+
opt = parse(opt, is_train=False)
|
283 |
+
opt = dict_to_nonedict(opt)
|
284 |
+
ex = Ex(opt)
|
285 |
+
sys.exit(app.exec_())
|
Text2Human/ui_util/__init__.py
ADDED
File without changes
|
Text2Human/ui_util/config.py
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import logging
|
3 |
+
import os
|
4 |
+
|
5 |
+
import yaml
|
6 |
+
|
7 |
+
logger = logging.getLogger()
|
8 |
+
|
9 |
+
class Config(object):
|
10 |
+
def __init__(self, filename=None):
|
11 |
+
assert os.path.exists(filename), "ERROR: Config File doesn't exist."
|
12 |
+
try:
|
13 |
+
with open(filename, 'r') as f:
|
14 |
+
self._cfg_dict = yaml.load(f)
|
15 |
+
# parent of IOError, OSError *and* WindowsError where available
|
16 |
+
except EnvironmentError:
|
17 |
+
logger.error('Please check the file with name of "%s"', filename)
|
18 |
+
logger.info(' APP CONFIG '.center(80, '-'))
|
19 |
+
logger.info(''.center(80, '-'))
|
20 |
+
|
21 |
+
def __getattr__(self, name):
|
22 |
+
value = self._cfg_dict[name]
|
23 |
+
if isinstance(value, dict):
|
24 |
+
value = DictAsMember(value)
|
25 |
+
return value
|
Text2Human/utils/__init__.py
ADDED
File without changes
|
Text2Human/utils/language_utils.py
ADDED
@@ -0,0 +1,315 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from curses import A_ATTRIBUTES
|
2 |
+
|
3 |
+
import numpy
|
4 |
+
import torch
|
5 |
+
from pip import main
|
6 |
+
from sentence_transformers import SentenceTransformer, util
|
7 |
+
|
8 |
+
# predefined shape text
|
9 |
+
upper_length_text = [
|
10 |
+
'sleeveless', 'without sleeves', 'sleeves have been cut off', 'tank top',
|
11 |
+
'tank shirt', 'muscle shirt', 'short-sleeve', 'short sleeves',
|
12 |
+
'with short sleeves', 'medium-sleeve', 'medium sleeves',
|
13 |
+
'with medium sleeves', 'sleeves reach elbow', 'long-sleeve',
|
14 |
+
'long sleeves', 'with long sleeves'
|
15 |
+
]
|
16 |
+
upper_length_attr = {
|
17 |
+
'sleeveless': 0,
|
18 |
+
'without sleeves': 0,
|
19 |
+
'sleeves have been cut off': 0,
|
20 |
+
'tank top': 0,
|
21 |
+
'tank shirt': 0,
|
22 |
+
'muscle shirt': 0,
|
23 |
+
'short-sleeve': 1,
|
24 |
+
'with short sleeves': 1,
|
25 |
+
'short sleeves': 1,
|
26 |
+
'medium-sleeve': 2,
|
27 |
+
'with medium sleeves': 2,
|
28 |
+
'medium sleeves': 2,
|
29 |
+
'sleeves reach elbow': 2,
|
30 |
+
'long-sleeve': 3,
|
31 |
+
'long sleeves': 3,
|
32 |
+
'with long sleeves': 3
|
33 |
+
}
|
34 |
+
lower_length_text = [
|
35 |
+
'three-point', 'medium', 'short', 'covering knee', 'cropped',
|
36 |
+
'three-quarter', 'long', 'slack', 'of long length'
|
37 |
+
]
|
38 |
+
lower_length_attr = {
|
39 |
+
'three-point': 0,
|
40 |
+
'medium': 1,
|
41 |
+
'covering knee': 1,
|
42 |
+
'short': 1,
|
43 |
+
'cropped': 2,
|
44 |
+
'three-quarter': 2,
|
45 |
+
'long': 3,
|
46 |
+
'slack': 3,
|
47 |
+
'of long length': 3
|
48 |
+
}
|
49 |
+
socks_length_text = [
|
50 |
+
'socks', 'stocking', 'pantyhose', 'leggings', 'sheer hosiery'
|
51 |
+
]
|
52 |
+
socks_length_attr = {
|
53 |
+
'socks': 0,
|
54 |
+
'stocking': 1,
|
55 |
+
'pantyhose': 1,
|
56 |
+
'leggings': 1,
|
57 |
+
'sheer hosiery': 1
|
58 |
+
}
|
59 |
+
hat_text = ['hat', 'cap', 'chapeau']
|
60 |
+
eyeglasses_text = ['sunglasses']
|
61 |
+
belt_text = ['belt', 'with a dress tied around the waist']
|
62 |
+
outer_shape_text = [
|
63 |
+
'with outer clothing open', 'with outer clothing unzipped',
|
64 |
+
'covering inner clothes', 'with outer clothing zipped'
|
65 |
+
]
|
66 |
+
outer_shape_attr = {
|
67 |
+
'with outer clothing open': 0,
|
68 |
+
'with outer clothing unzipped': 0,
|
69 |
+
'covering inner clothes': 1,
|
70 |
+
'with outer clothing zipped': 1
|
71 |
+
}
|
72 |
+
|
73 |
+
upper_types = [
|
74 |
+
'T-shirt', 'shirt', 'sweater', 'hoodie', 'tops', 'blouse', 'Basic Tee'
|
75 |
+
]
|
76 |
+
outer_types = [
|
77 |
+
'jacket', 'outer clothing', 'coat', 'overcoat', 'blazer', 'outerwear',
|
78 |
+
'duffle', 'cardigan'
|
79 |
+
]
|
80 |
+
skirt_types = ['skirt']
|
81 |
+
dress_types = ['dress']
|
82 |
+
pant_types = ['jeans', 'pants', 'trousers']
|
83 |
+
rompers_types = ['rompers', 'bodysuit', 'jumpsuit']
|
84 |
+
|
85 |
+
attr_names_list = [
|
86 |
+
'gender', 'hair length', '0 upper clothing length',
|
87 |
+
'1 lower clothing length', '2 socks', '3 hat', '4 eyeglasses', '5 belt',
|
88 |
+
'6 opening of outer clothing', '7 upper clothes', '8 outer clothing',
|
89 |
+
'9 skirt', '10 dress', '11 pants', '12 rompers'
|
90 |
+
]
|
91 |
+
|
92 |
+
|
93 |
+
def generate_shape_attributes(user_shape_texts):
|
94 |
+
model = SentenceTransformer('all-MiniLM-L6-v2')
|
95 |
+
parsed_texts = user_shape_texts.split(',')
|
96 |
+
|
97 |
+
text_num = len(parsed_texts)
|
98 |
+
|
99 |
+
human_attr = [0, 0]
|
100 |
+
attr = [1, 3, 0, 0, 0, 3, 1, 1, 0, 0, 0, 0, 0]
|
101 |
+
|
102 |
+
changed = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
|
103 |
+
for text_id, text in enumerate(parsed_texts):
|
104 |
+
user_embeddings = model.encode(text)
|
105 |
+
if ('man' in text) and (text_id == 0):
|
106 |
+
human_attr[0] = 0
|
107 |
+
human_attr[1] = 0
|
108 |
+
|
109 |
+
if ('woman' in text or 'lady' in text) and (text_id == 0):
|
110 |
+
human_attr[0] = 1
|
111 |
+
human_attr[1] = 2
|
112 |
+
|
113 |
+
if (not changed[0]) and (text_id == 1):
|
114 |
+
# upper length
|
115 |
+
predefined_embeddings = model.encode(upper_length_text)
|
116 |
+
similarities = util.dot_score(user_embeddings,
|
117 |
+
predefined_embeddings)
|
118 |
+
arg_idx = torch.argmax(similarities).item()
|
119 |
+
attr[0] = upper_length_attr[upper_length_text[arg_idx]]
|
120 |
+
changed[0] = 1
|
121 |
+
|
122 |
+
if (not changed[1]) and ((text_num == 2 and text_id == 1) or
|
123 |
+
(text_num > 2 and text_id == 2)):
|
124 |
+
# lower length
|
125 |
+
predefined_embeddings = model.encode(lower_length_text)
|
126 |
+
similarities = util.dot_score(user_embeddings,
|
127 |
+
predefined_embeddings)
|
128 |
+
arg_idx = torch.argmax(similarities).item()
|
129 |
+
attr[1] = lower_length_attr[lower_length_text[arg_idx]]
|
130 |
+
changed[1] = 1
|
131 |
+
|
132 |
+
if (not changed[2]) and (text_id > 2):
|
133 |
+
# socks length
|
134 |
+
predefined_embeddings = model.encode(socks_length_text)
|
135 |
+
similarities = util.dot_score(user_embeddings,
|
136 |
+
predefined_embeddings)
|
137 |
+
arg_idx = torch.argmax(similarities).item()
|
138 |
+
if similarities[0][arg_idx] > 0.7:
|
139 |
+
attr[2] = arg_idx + 1
|
140 |
+
changed[2] = 1
|
141 |
+
|
142 |
+
if (not changed[3]) and (text_id > 2):
|
143 |
+
# hat
|
144 |
+
predefined_embeddings = model.encode(hat_text)
|
145 |
+
similarities = util.dot_score(user_embeddings,
|
146 |
+
predefined_embeddings)
|
147 |
+
if similarities[0][0] > 0.7:
|
148 |
+
attr[3] = 1
|
149 |
+
changed[3] = 1
|
150 |
+
|
151 |
+
if (not changed[4]) and (text_id > 2):
|
152 |
+
# glasses
|
153 |
+
predefined_embeddings = model.encode(eyeglasses_text)
|
154 |
+
similarities = util.dot_score(user_embeddings,
|
155 |
+
predefined_embeddings)
|
156 |
+
arg_idx = torch.argmax(similarities).item()
|
157 |
+
if similarities[0][arg_idx] > 0.7:
|
158 |
+
attr[4] = arg_idx + 1
|
159 |
+
changed[4] = 1
|
160 |
+
|
161 |
+
if (not changed[5]) and (text_id > 2):
|
162 |
+
# belt
|
163 |
+
predefined_embeddings = model.encode(belt_text)
|
164 |
+
similarities = util.dot_score(user_embeddings,
|
165 |
+
predefined_embeddings)
|
166 |
+
arg_idx = torch.argmax(similarities).item()
|
167 |
+
if similarities[0][arg_idx] > 0.7:
|
168 |
+
attr[5] = arg_idx + 1
|
169 |
+
changed[5] = 1
|
170 |
+
|
171 |
+
if (not changed[6]) and (text_id == 3):
|
172 |
+
# outer coverage
|
173 |
+
predefined_embeddings = model.encode(outer_shape_text)
|
174 |
+
similarities = util.dot_score(user_embeddings,
|
175 |
+
predefined_embeddings)
|
176 |
+
arg_idx = torch.argmax(similarities).item()
|
177 |
+
if similarities[0][arg_idx] > 0.7:
|
178 |
+
attr[6] = arg_idx
|
179 |
+
changed[6] = 1
|
180 |
+
|
181 |
+
if (not changed[10]) and (text_num == 2 and text_id == 1):
|
182 |
+
# dress_types
|
183 |
+
predefined_embeddings = model.encode(dress_types)
|
184 |
+
similarities = util.dot_score(user_embeddings,
|
185 |
+
predefined_embeddings)
|
186 |
+
similarity_skirt = util.dot_score(user_embeddings,
|
187 |
+
model.encode(skirt_types))
|
188 |
+
if similarities[0][0] > 0.5 and similarities[0][
|
189 |
+
0] > similarity_skirt[0][0]:
|
190 |
+
attr[10] = 1
|
191 |
+
attr[7] = 0
|
192 |
+
attr[8] = 0
|
193 |
+
attr[9] = 0
|
194 |
+
attr[11] = 0
|
195 |
+
attr[12] = 0
|
196 |
+
|
197 |
+
changed[0] = 1
|
198 |
+
changed[10] = 1
|
199 |
+
changed[7] = 1
|
200 |
+
changed[8] = 1
|
201 |
+
changed[9] = 1
|
202 |
+
changed[11] = 1
|
203 |
+
changed[12] = 1
|
204 |
+
|
205 |
+
if (not changed[12]) and (text_num == 2 and text_id == 1):
|
206 |
+
# rompers_types
|
207 |
+
predefined_embeddings = model.encode(rompers_types)
|
208 |
+
similarities = util.dot_score(user_embeddings,
|
209 |
+
predefined_embeddings)
|
210 |
+
max_similarity = torch.max(similarities).item()
|
211 |
+
if max_similarity > 0.6:
|
212 |
+
attr[12] = 1
|
213 |
+
attr[7] = 0
|
214 |
+
attr[8] = 0
|
215 |
+
attr[9] = 0
|
216 |
+
attr[10] = 0
|
217 |
+
attr[11] = 0
|
218 |
+
|
219 |
+
changed[12] = 1
|
220 |
+
changed[7] = 1
|
221 |
+
changed[8] = 1
|
222 |
+
changed[9] = 1
|
223 |
+
changed[10] = 1
|
224 |
+
changed[11] = 1
|
225 |
+
|
226 |
+
if (not changed[7]) and (text_num > 2 and text_id == 1):
|
227 |
+
# upper_types
|
228 |
+
predefined_embeddings = model.encode(upper_types)
|
229 |
+
similarities = util.dot_score(user_embeddings,
|
230 |
+
predefined_embeddings)
|
231 |
+
max_similarity = torch.max(similarities).item()
|
232 |
+
if max_similarity > 0.6:
|
233 |
+
attr[7] = 1
|
234 |
+
changed[7] = 1
|
235 |
+
|
236 |
+
if (not changed[8]) and (text_id == 3):
|
237 |
+
# outer_types
|
238 |
+
predefined_embeddings = model.encode(outer_types)
|
239 |
+
similarities = util.dot_score(user_embeddings,
|
240 |
+
predefined_embeddings)
|
241 |
+
arg_idx = torch.argmax(similarities).item()
|
242 |
+
if similarities[0][arg_idx] > 0.7:
|
243 |
+
attr[6] = outer_shape_attr[outer_shape_text[arg_idx]]
|
244 |
+
attr[8] = 1
|
245 |
+
changed[8] = 1
|
246 |
+
|
247 |
+
if (not changed[9]) and (text_num > 2 and text_id == 2):
|
248 |
+
# skirt_types
|
249 |
+
predefined_embeddings = model.encode(skirt_types)
|
250 |
+
similarity_skirt = util.dot_score(user_embeddings,
|
251 |
+
predefined_embeddings)
|
252 |
+
similarity_dress = util.dot_score(user_embeddings,
|
253 |
+
model.encode(dress_types))
|
254 |
+
if similarity_skirt[0][0] > 0.7 and similarity_skirt[0][
|
255 |
+
0] > similarity_dress[0][0]:
|
256 |
+
attr[9] = 1
|
257 |
+
attr[10] = 0
|
258 |
+
changed[9] = 1
|
259 |
+
changed[10] = 1
|
260 |
+
|
261 |
+
if (not changed[11]) and (text_num > 2 and text_id == 2):
|
262 |
+
# pant_types
|
263 |
+
predefined_embeddings = model.encode(pant_types)
|
264 |
+
similarities = util.dot_score(user_embeddings,
|
265 |
+
predefined_embeddings)
|
266 |
+
max_similarity = torch.max(similarities).item()
|
267 |
+
if max_similarity > 0.6:
|
268 |
+
attr[11] = 1
|
269 |
+
attr[9] = 0
|
270 |
+
attr[10] = 0
|
271 |
+
attr[12] = 0
|
272 |
+
changed[11] = 1
|
273 |
+
changed[9] = 1
|
274 |
+
changed[10] = 1
|
275 |
+
changed[12] = 1
|
276 |
+
|
277 |
+
return human_attr + attr
|
278 |
+
|
279 |
+
|
280 |
+
def generate_texture_attributes(user_text):
|
281 |
+
parsed_texts = user_text.split(',')
|
282 |
+
|
283 |
+
attr = []
|
284 |
+
for text in parsed_texts:
|
285 |
+
if ('pure color' in text) or ('solid color' in text):
|
286 |
+
attr.append(4)
|
287 |
+
elif ('spline' in text) or ('stripe' in text):
|
288 |
+
attr.append(3)
|
289 |
+
elif ('plaid' in text) or ('lattice' in text):
|
290 |
+
attr.append(5)
|
291 |
+
elif 'floral' in text:
|
292 |
+
attr.append(1)
|
293 |
+
elif 'denim' in text:
|
294 |
+
attr.append(0)
|
295 |
+
else:
|
296 |
+
attr.append(17)
|
297 |
+
|
298 |
+
if len(attr) == 1:
|
299 |
+
attr.append(attr[0])
|
300 |
+
attr.append(17)
|
301 |
+
|
302 |
+
if len(attr) == 2:
|
303 |
+
attr.append(17)
|
304 |
+
|
305 |
+
return attr
|
306 |
+
|
307 |
+
|
308 |
+
if __name__ == "__main__":
|
309 |
+
user_request = input('Enter your request: ')
|
310 |
+
while user_request != '\\q':
|
311 |
+
attr = generate_shape_attributes(user_request)
|
312 |
+
print(attr)
|
313 |
+
for attr_name, attr_value in zip(attr_names_list, attr):
|
314 |
+
print(attr_name, attr_value)
|
315 |
+
user_request = input('Enter your request: ')
|