Upload 84 files
This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
- .gitattributes +6 -0
- .gitignore +120 -0
- License +58 -0
- README.md +67 -3
- assets/.DS_Store +0 -0
- assets/arch.png +3 -0
- assets/arch_2.pdf +3 -0
- assets/comparison_3.pdf +3 -0
- assets/new_ablation.pdf +3 -0
- assets/show_3.png +3 -0
- assets/table.pdf +0 -0
- configs/unit_gta2city_folder.yaml +54 -0
- data/__init__.py +0 -0
- data/aligned_dataset.py +56 -0
- data/base_data_loader.py +14 -0
- data/base_dataset.py +50 -0
- data/custom_dataset_data_loader.py +50 -0
- data/data_loader.py +7 -0
- data/image_folder.py +83 -0
- data/pair_dataset.py +95 -0
- data/single_dataset.py +36 -0
- data/syn_dataset.py +91 -0
- data/unaligned_dataset.py +141 -0
- data/unaligned_random_crop.py +85 -0
- datasets/.DS_Store +0 -0
- datasets/bibtex/cityscapes.tex +6 -0
- datasets/bibtex/facades.tex +7 -0
- datasets/bibtex/handbags.tex +13 -0
- datasets/bibtex/shoes.tex +14 -0
- datasets/combine_A_and_B.py +49 -0
- datasets/download_cyclegan_dataset.sh +14 -0
- datasets/download_pix2pix_dataset.sh +8 -0
- imgs/edges2cats.jpg +0 -0
- imgs/horse2zebra.gif +3 -0
- lib/nn/__init__.py +2 -0
- lib/nn/modules/__init__.py +12 -0
- lib/nn/modules/batchnorm.py +329 -0
- lib/nn/modules/comm.py +131 -0
- lib/nn/modules/replicate.py +94 -0
- lib/nn/modules/tests/test_numeric_batchnorm.py +56 -0
- lib/nn/modules/tests/test_sync_batchnorm.py +111 -0
- lib/nn/modules/unittest.py +29 -0
- lib/nn/parallel/__init__.py +1 -0
- lib/nn/parallel/data_parallel.py +112 -0
- lib/utils/__init__.py +1 -0
- lib/utils/data/__init__.py +3 -0
- lib/utils/data/dataloader.py +422 -0
- lib/utils/data/dataset.py +118 -0
- lib/utils/data/distributed.py +58 -0
- lib/utils/data/sampler.py +131 -0
.gitattributes
CHANGED
@@ -33,3 +33,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+assets/arch_2.pdf filter=lfs diff=lfs merge=lfs -text
+assets/arch.png filter=lfs diff=lfs merge=lfs -text
+assets/comparison_3.pdf filter=lfs diff=lfs merge=lfs -text
+assets/new_ablation.pdf filter=lfs diff=lfs merge=lfs -text
+assets/show_3.png filter=lfs diff=lfs merge=lfs -text
+imgs/horse2zebra.gif filter=lfs diff=lfs merge=lfs -text
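Each of the six new rules routes one of the uploaded binary assets (the architecture and result figures plus the horse2zebra GIF) through Git LFS, so the repository history stores a small pointer file instead of the binary itself.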
.gitignore
ADDED
@@ -0,0 +1,120 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+checkpoints/
+.DS_Store
+._.DS_Store
+.vscode
+predict/
+results/
+model/
+.pth
+.png
+.jpg
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+.python-version
+
+# celery beat schedule file
+celerybeat-schedule
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
License
ADDED
@@ -0,0 +1,58 @@
+Copyright (c) 2019, Yifan Jiang and Zhangyang Wang
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+* Redistributions of source code must retain the above copyright notice, this
+  list of conditions and the following disclaimer.
+
+* Redistributions in binary form must reproduce the above copyright notice,
+  this list of conditions and the following disclaimer in the documentation
+  and/or other materials provided with the distribution.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+
+--------------------------- LICENSE FOR EnlightenGAN --------------------------------
+BSD License
+
+For EnlightenGAN software
+Copyright (c) 2019, Yifan Jiang and Zhangyang Wang
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+* Redistributions of source code must retain the above copyright notice, this
+  list of conditions and the following disclaimer.
+
+* Redistributions in binary form must reproduce the above copyright notice,
+  this list of conditions and the following disclaimer in the documentation
+  and/or other materials provided with the distribution.
+
+----------------------------- LICENSE FOR DCGAN --------------------------------
+BSD License
+
+For dcgan.torch software
+
+Copyright (c) 2015, Facebook, Inc. All rights reserved.
+
+Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
+
+Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
+
+Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
+
+Neither the name Facebook nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
README.md
CHANGED
@@ -1,3 +1,67 @@
-
-
-
+# EnlightenGAN: Deep Light Enhancement without Paired Supervision
+[Yifan Jiang](https://yifanjiang19.github.io/), Xinyu Gong, Ding Liu, Yu Cheng, Chen Fang, Xiaohui Shen, Jianchao Yang, Pan Zhou, Zhangyang Wang
+
+[[Paper]](https://arxiv.org/abs/1906.06972) [[Supplementary Materials]](https://yifanjiang.net/files/EnlightenGAN_Supplementary.pdf)
+
+
+### Representative Results
+![representative_results](/assets/show_3.png)
+
+### Overall Architecture
+![architecture](/assets/arch.png)
+
+## Environment Preparation
+```
+python3.5
+```
+You should prepare at least three 1080 Ti GPUs or reduce the batch size.
+
+
+```pip install -r requirement.txt``` <br/>
+```mkdir model``` <br/>
+Download the VGG pretrained model from [[Google Drive 1]](https://drive.google.com/file/d/1IfCeihmPqGWJ0KHmH-mTMi_pn3z3Zo-P/view?usp=sharing) and put it into the directory `model`.
+
+### Training process
+Before starting the training process, launch `visdom.server` for visualization.
+
+```nohup python -m visdom.server -port=8097```
+
+Then run the following command:
+
+```python scripts/script.py --train```
+
+### Testing process
+
+Download the [pretrained model](https://drive.google.com/file/d/1AkV-n2MdyfuZTFvcon8Z4leyVb0i7x63/view?usp=sharing) and put it into `./checkpoints/enlightening`.
+
+Create the directories `../test_dataset/testA` and `../test_dataset/testB`. Put your test images in `../test_dataset/testA` (and keep at least one image in `../test_dataset/testB` so the program can start).
+
+Run
+
+```python scripts/script.py --predict```
+
+### Dataset preparation
+
+Training data [[Google Drive]](https://drive.google.com/drive/folders/1fwqz8-RnTfxgIIkebFG2Ej3jQFsYECh0?usp=sharing) (unpaired images collected from multiple datasets)
+
+Testing data [[Google Drive]](https://drive.google.com/open?id=1PrvL8jShZ7zj2IC3fVdDxBY1oJR72iDf) (including LIME, MEF, NPE, VV, DICP)
+
+A [[BaiduYun]](https://github.com/TAMU-VITA/EnlightenGAN/issues/28) mirror is also available now, thanks to @YHLelaine!
+
+### Faster Inference
+https://github.com/arsenyinfo/EnlightenGAN-inference from @arsenyinfo
+
+
+
+If you find this work useful, please cite
+```
+@article{jiang2021enlightengan,
+  title={Enlightengan: Deep light enhancement without paired supervision},
+  author={Jiang, Yifan and Gong, Xinyu and Liu, Ding and Cheng, Yu and Fang, Chen and Shen, Xiaohui and Yang, Jianchao and Zhou, Pan and Wang, Zhangyang},
+  journal={IEEE Transactions on Image Processing},
+  volume={30},
+  pages={2340--2349},
+  year={2021},
+  publisher={IEEE}
+}
+```
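The testing layout above trips up many first runs, so here is a minimal sketch of setting it up; `my_image.png` is a placeholder for whatever image you want to enhance.

```python
# Minimal sketch of the directory layout the testing step expects.
# "my_image.png" is a placeholder input; testB only needs one dummy
# image so the unpaired loader can start.
import os
import shutil

os.makedirs("../test_dataset/testA", exist_ok=True)
os.makedirs("../test_dataset/testB", exist_ok=True)
shutil.copy("my_image.png", "../test_dataset/testA/")
shutil.copy("my_image.png", "../test_dataset/testB/")
```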
assets/.DS_Store
ADDED
Binary file (6.15 kB)
assets/arch.png
ADDED
Git LFS Details
assets/arch_2.pdf
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:df02a7f2b894d6230a1f120aa7c112962abe39232c648a441879d9dc8cc71756
+size 1738396
assets/comparison_3.pdf
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:75a9820f1a978d9f0b6230dbd163efad5f8ca4100afe06bbed90cbe780a341d5
+size 1753489
assets/new_ablation.pdf
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:38b13a6ad0682986d9535ad2d908a5124f65d321240d0cca881f84f6ef033892
+size 1407150
assets/show_3.png
ADDED
Git LFS Details
assets/table.pdf
ADDED
Binary file (96 kB)
configs/unit_gta2city_folder.yaml
ADDED
@@ -0,0 +1,54 @@
+# Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
+# Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
+
+# logger options
+image_save_iter: 1000         # How often do you want to save output images during training
+image_display_iter: 10        # How often do you want to display output images during training
+display_size: 8               # How many images do you want to display each time
+snapshot_save_iter: 10000     # How often do you want to save trained models
+log_iter: 1                   # How often do you want to log the training stats
+
+# optimization options
+max_iter: 1000000             # maximum number of training iterations
+batch_size: 1                 # batch size
+weight_decay: 0.0001          # weight decay
+beta1: 0.5                    # Adam parameter
+beta2: 0.999                  # Adam parameter
+init: kaiming                 # initialization [gaussian/kaiming/xavier/orthogonal]
+lr: 0.0001                    # initial learning rate
+lr_policy: step               # learning rate scheduler
+step_size: 100000             # how often to decay learning rate
+gamma: 0.5                    # how much to decay learning rate
+gan_w: 1                      # weight of adversarial loss
+recon_x_w: 10                 # weight of image reconstruction loss
+recon_h_w: 0                  # weight of hidden reconstruction loss
+recon_kl_w: 0.01              # weight of KL loss for reconstruction
+recon_x_cyc_w: 10             # weight of cycle consistency loss
+recon_kl_cyc_w: 0.01          # weight of KL loss for cycle consistency
+vgg_w: 0                      # weight of domain-invariant perceptual loss
+
+# model options
+gen:
+  dim: 64                     # number of filters in the bottommost layer
+  activ: relu                 # activation function [relu/lrelu/prelu/selu/tanh]
+  n_downsample: 2             # number of downsampling layers in content encoder
+  n_res: 4                    # number of residual blocks in content encoder/decoder
+  pad_type: reflect           # padding type [zero/reflect]
+dis:
+  dim: 64                     # number of filters in the bottommost layer
+  norm: none                  # normalization layer [none/bn/in/ln]
+  activ: lrelu                # activation function [relu/lrelu/prelu/selu/tanh]
+  n_layer: 4                  # number of layers in D
+  gan_type: lsgan             # GAN loss [lsgan/nsgan]
+  num_scales: 3               # number of scales
+  pad_type: reflect           # padding type [zero/reflect]
+
+# data options
+input_dim_a: 3                # number of image channels [1/3]
+input_dim_b: 3                # number of image channels [1/3]
+num_workers: 8                # number of data loading threads
+new_size: 256                 # first resize the shortest image side to this size
+crop_image_height: 256        # random crop image of this height
+crop_image_width: 256         # random crop image of this width
+
+data_root: ./datasets/lol/    # dataset folder location
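As a quick sanity check, the config is plain YAML and can be inspected directly; a minimal sketch, assuming PyYAML is installed:

```python
# Sketch: load the UNIT config and read a few of the options defined above.
import yaml

with open("configs/unit_gta2city_folder.yaml") as f:
    cfg = yaml.safe_load(f)

print(cfg["lr"], cfg["batch_size"], cfg["gen"]["dim"])  # 0.0001 1 64
```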
data/__init__.py
ADDED
File without changes
data/aligned_dataset.py
ADDED
@@ -0,0 +1,56 @@
+import os.path
+import random
+import torchvision.transforms as transforms
+import torch
+from data.base_dataset import BaseDataset
+from data.image_folder import make_dataset
+from PIL import Image
+
+
+class AlignedDataset(BaseDataset):
+    def initialize(self, opt):
+        self.opt = opt
+        self.root = opt.dataroot
+        self.dir_AB = os.path.join(opt.dataroot, opt.phase)
+
+        self.AB_paths = sorted(make_dataset(self.dir_AB))
+
+        assert(opt.resize_or_crop == 'resize_and_crop')
+
+        transform_list = [transforms.ToTensor(),
+                          transforms.Normalize((0.5, 0.5, 0.5),
+                                               (0.5, 0.5, 0.5))]
+
+        self.transform = transforms.Compose(transform_list)
+
+    def __getitem__(self, index):
+        AB_path = self.AB_paths[index]
+        AB = Image.open(AB_path).convert('RGB')
+        AB = AB.resize((self.opt.loadSize * 2, self.opt.loadSize), Image.BICUBIC)
+        AB = self.transform(AB)
+
+        w_total = AB.size(2)
+        w = int(w_total / 2)
+        h = AB.size(1)
+        w_offset = random.randint(0, max(0, w - self.opt.fineSize - 1))
+        h_offset = random.randint(0, max(0, h - self.opt.fineSize - 1))
+
+        A = AB[:, h_offset:h_offset + self.opt.fineSize,
+               w_offset:w_offset + self.opt.fineSize]
+        B = AB[:, h_offset:h_offset + self.opt.fineSize,
+               w + w_offset:w + w_offset + self.opt.fineSize]
+
+        if (not self.opt.no_flip) and random.random() < 0.5:
+            idx = [i for i in range(A.size(2) - 1, -1, -1)]
+            idx = torch.LongTensor(idx)
+            A = A.index_select(2, idx)
+            B = B.index_select(2, idx)
+
+        return {'A': A, 'B': B,
+                'A_paths': AB_path, 'B_paths': AB_path}
+
+    def __len__(self):
+        return len(self.AB_paths)
+
+    def name(self):
+        return 'AlignedDataset'
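AlignedDataset assumes each file stores A and B side by side and cuts the same random window from both halves; a small sketch of that split on a dummy tensor:

```python
# Sketch of the A|B split performed in __getitem__, on a dummy tensor.
import torch

fineSize = 200
AB = torch.randn(3, 256, 512)   # one image holding A (left half) and B (right half)
w = AB.size(2) // 2             # width of each half
h_off, w_off = 10, 20           # stand-ins for the random crop offsets
A = AB[:, h_off:h_off + fineSize, w_off:w_off + fineSize]
B = AB[:, h_off:h_off + fineSize, w + w_off:w + w_off + fineSize]
assert A.shape == B.shape == (3, fineSize, fineSize)
```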
data/base_data_loader.py
ADDED
@@ -0,0 +1,14 @@
+
+class BaseDataLoader():
+    def __init__(self):
+        pass
+
+    def initialize(self, opt):
+        self.opt = opt
+        pass
+
+    def load_data(self):
+        return None
+
+
+
data/base_dataset.py
ADDED
@@ -0,0 +1,50 @@
+import torch.utils.data as data
+from PIL import Image
+import torchvision.transforms as transforms
+import random
+
+class BaseDataset(data.Dataset):
+    def __init__(self):
+        super(BaseDataset, self).__init__()
+
+    def name(self):
+        return 'BaseDataset'
+
+    def initialize(self, opt):
+        pass
+
+def get_transform(opt):
+    transform_list = []
+    if opt.resize_or_crop == 'resize_and_crop':
+        zoom = 1 + 0.1 * random.randint(0, 4)
+        osize = [int(400*zoom), int(600*zoom)]
+        transform_list.append(transforms.Scale(osize, Image.BICUBIC))
+        transform_list.append(transforms.RandomCrop(opt.fineSize))
+    elif opt.resize_or_crop == 'crop':
+        transform_list.append(transforms.RandomCrop(opt.fineSize))
+    elif opt.resize_or_crop == 'scale_width':
+        transform_list.append(transforms.Lambda(
+            lambda img: __scale_width(img, opt.fineSize)))
+    elif opt.resize_or_crop == 'scale_width_and_crop':
+        transform_list.append(transforms.Lambda(
+            lambda img: __scale_width(img, opt.loadSize)))
+        transform_list.append(transforms.RandomCrop(opt.fineSize))
+    # elif opt.resize_or_crop == 'no':
+    #     osize = [384, 512]
+    #     transform_list.append(transforms.Scale(osize, Image.BICUBIC))
+
+    if opt.isTrain and not opt.no_flip:
+        transform_list.append(transforms.RandomHorizontalFlip())
+
+    transform_list += [transforms.ToTensor(),
+                       transforms.Normalize((0.5, 0.5, 0.5),
+                                            (0.5, 0.5, 0.5))]
+    return transforms.Compose(transform_list)
+
+def __scale_width(img, target_width):
+    ow, oh = img.size
+    if (ow == target_width):
+        return img
+    w = target_width
+    h = int(target_width * oh / ow)
+    return img.resize((w, h), Image.BICUBIC)
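A hedged usage sketch of `get_transform`; the real `opt` comes from the project's option parser, so a bare `Namespace` stands in here:

```python
# Sketch: build the 'crop' pipeline (RandomCrop -> flip -> ToTensor -> Normalize).
from argparse import Namespace
from data.base_dataset import get_transform

opt = Namespace(resize_or_crop='crop', fineSize=256, isTrain=True, no_flip=False)
transform = get_transform(opt)
# transform(pil_image) now returns a 3 x fineSize x fineSize tensor in [-1, 1].
```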
data/custom_dataset_data_loader.py
ADDED
@@ -0,0 +1,50 @@
+import torch.utils.data
+from data.base_data_loader import BaseDataLoader
+
+
+def CreateDataset(opt):
+    dataset = None
+    if opt.dataset_mode == 'aligned':
+        from data.aligned_dataset import AlignedDataset
+        dataset = AlignedDataset()
+    elif opt.dataset_mode == 'unaligned':
+        from data.unaligned_dataset import UnalignedDataset
+        dataset = UnalignedDataset()
+    elif opt.dataset_mode == 'unaligned_random_crop':
+        from data.unaligned_random_crop import UnalignedDataset
+        dataset = UnalignedDataset()
+    elif opt.dataset_mode == 'pair':
+        from data.pair_dataset import PairDataset
+        dataset = PairDataset()
+    elif opt.dataset_mode == 'syn':
+        from data.syn_dataset import PairDataset
+        dataset = PairDataset()
+    elif opt.dataset_mode == 'single':
+        from data.single_dataset import SingleDataset
+        dataset = SingleDataset()
+    else:
+        raise ValueError("Dataset [%s] not recognized." % opt.dataset_mode)
+
+    print("dataset [%s] was created" % (dataset.name()))
+    dataset.initialize(opt)
+    return dataset
+
+
+class CustomDatasetDataLoader(BaseDataLoader):
+    def name(self):
+        return 'CustomDatasetDataLoader'
+
+    def initialize(self, opt):
+        BaseDataLoader.initialize(self, opt)
+        self.dataset = CreateDataset(opt)
+        self.dataloader = torch.utils.data.DataLoader(
+            self.dataset,
+            batch_size=opt.batchSize,
+            shuffle=not opt.serial_batches,
+            num_workers=int(opt.nThreads))
+
+    def load_data(self):
+        return self.dataloader
+
+    def __len__(self):
+        return min(len(self.dataset), self.opt.max_dataset_size)
data/data_loader.py
ADDED
@@ -0,0 +1,7 @@
+
+def CreateDataLoader(opt):
+    from data.custom_dataset_data_loader import CustomDatasetDataLoader
+    data_loader = CustomDatasetDataLoader()
+    print(data_loader.name())
+    data_loader.initialize(opt)
+    return data_loader
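A hedged end-to-end sketch of this entry point; the option names mirror the fields the classes above actually read, with placeholder values:

```python
# Sketch: drive CreateDataLoader with a minimal hand-built options object.
from argparse import Namespace
from data.data_loader import CreateDataLoader

opt = Namespace(dataset_mode='single', dataroot='../test_dataset/testA',
                phase='test', resize_or_crop='no', isTrain=False, no_flip=True,
                batchSize=1, serial_batches=True, nThreads=1,
                max_dataset_size=float('inf'))
data_loader = CreateDataLoader(opt)
for batch in data_loader.load_data():
    print(batch['A'].shape, batch['A_paths'])
```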
data/image_folder.py
ADDED
@@ -0,0 +1,83 @@
+###############################################################################
+# Code from
+# https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py
+# Modified the original code so that it also loads images from the current
+# directory as well as the subdirectories
+###############################################################################
+
+import torch.utils.data as data
+
+from PIL import Image
+import os
+import os.path
+
+IMG_EXTENSIONS = [
+    '.jpg', '.JPG', '.jpeg', '.JPEG',
+    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
+]
+
+
+def is_image_file(filename):
+    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
+
+
+def make_dataset(dir):
+    images = []
+    assert os.path.isdir(dir), '%s is not a valid directory' % dir
+
+    for root, _, fnames in sorted(os.walk(dir)):
+        for fname in fnames:
+            if is_image_file(fname):
+                path = os.path.join(root, fname)
+                images.append(path)
+
+    return images
+
+def store_dataset(dir):
+    images = []
+    all_path = []
+    assert os.path.isdir(dir), '%s is not a valid directory' % dir
+
+    for root, _, fnames in sorted(os.walk(dir)):
+        for fname in fnames:
+            if is_image_file(fname):
+                path = os.path.join(root, fname)
+                img = Image.open(path).convert('RGB')
+                images.append(img)
+                all_path.append(path)
+
+    return images, all_path
+
+
+def default_loader(path):
+    return Image.open(path).convert('RGB')
+
+
+class ImageFolder(data.Dataset):
+
+    def __init__(self, root, transform=None, return_paths=False,
+                 loader=default_loader):
+        imgs = make_dataset(root)
+        if len(imgs) == 0:
+            raise(RuntimeError("Found 0 images in: " + root + "\n"
+                               "Supported image extensions are: " +
+                               ",".join(IMG_EXTENSIONS)))
+
+        self.root = root
+        self.imgs = imgs
+        self.transform = transform
+        self.return_paths = return_paths
+        self.loader = loader
+
+    def __getitem__(self, index):
+        path = self.imgs[index]
+        img = self.loader(path)
+        if self.transform is not None:
+            img = self.transform(img)
+        if self.return_paths:
+            return img, path
+        else:
+            return img
+
+    def __len__(self):
+        return len(self.imgs)
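Note the difference between the two directory walkers: `make_dataset` returns paths only, while `store_dataset` decodes every image into memory up front. `UnalignedDataset` below relies on the latter to avoid re-opening files on every iteration, at the cost of holding the whole training set in RAM.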
data/pair_dataset.py
ADDED
@@ -0,0 +1,95 @@
+import os.path
+import torchvision.transforms as transforms
+from data.base_dataset import BaseDataset, get_transform
+from data.image_folder import make_dataset
+from PIL import Image
+import PIL
+import random
+import torch
+from pdb import set_trace as st
+
+
+class PairDataset(BaseDataset):
+    def initialize(self, opt):
+        self.opt = opt
+        self.root = opt.dataroot
+        self.dir_A = os.path.join(opt.dataroot, opt.phase + 'A')
+        self.dir_B = os.path.join(opt.dataroot, opt.phase + 'B')
+
+        self.A_paths = make_dataset(self.dir_A)
+        self.B_paths = make_dataset(self.dir_B)
+
+        self.A_paths = sorted(self.A_paths)
+        self.B_paths = sorted(self.B_paths)
+        self.A_size = len(self.A_paths)
+        self.B_size = len(self.B_paths)
+
+        transform_list = []
+
+        transform_list += [transforms.ToTensor(),
+                           transforms.Normalize((0.5, 0.5, 0.5),
+                                                (0.5, 0.5, 0.5))]
+        # transform_list = [transforms.ToTensor()]
+
+        self.transform = transforms.Compose(transform_list)
+        # self.transform = get_transform(opt)
+
+    def __getitem__(self, index):
+        A_path = self.A_paths[index % self.A_size]
+        B_path = self.B_paths[index % self.B_size]
+
+        A_img = Image.open(A_path).convert('RGB')
+        B_img = Image.open(A_path.replace("low", "normal").replace("A", "B")).convert('RGB')
+
+        A_img = self.transform(A_img)
+        B_img = self.transform(B_img)
+
+        w = A_img.size(2)
+        h = A_img.size(1)
+        w_offset = random.randint(0, max(0, w - self.opt.fineSize - 1))
+        h_offset = random.randint(0, max(0, h - self.opt.fineSize - 1))
+
+        A_img = A_img[:, h_offset:h_offset + self.opt.fineSize,
+                      w_offset:w_offset + self.opt.fineSize]
+        B_img = B_img[:, h_offset:h_offset + self.opt.fineSize,
+                      w_offset:w_offset + self.opt.fineSize]
+
+        if self.opt.resize_or_crop == 'no':
+            r, g, b = A_img[0]+1, A_img[1]+1, A_img[2]+1
+            A_gray = 1. - (0.299*r+0.587*g+0.114*b)/2.
+            A_gray = torch.unsqueeze(A_gray, 0)
+            input_img = A_img
+            # A_gray = (1./A_gray)/255.
+        else:
+            # A_gray = (1./A_gray)/255.
+            if (not self.opt.no_flip) and random.random() < 0.5:
+                idx = [i for i in range(A_img.size(2) - 1, -1, -1)]
+                idx = torch.LongTensor(idx)
+                A_img = A_img.index_select(2, idx)
+                B_img = B_img.index_select(2, idx)
+            if (not self.opt.no_flip) and random.random() < 0.5:
+                idx = [i for i in range(A_img.size(1) - 1, -1, -1)]
+                idx = torch.LongTensor(idx)
+                A_img = A_img.index_select(1, idx)
+                B_img = B_img.index_select(1, idx)
+            if (not self.opt.no_flip) and random.random() < 0.5:
+                times = random.randint(self.opt.low_times, self.opt.high_times)/100.
+                input_img = (A_img+1)/2./times
+                input_img = input_img*2-1
+            else:
+                input_img = A_img
+            r, g, b = input_img[0]+1, input_img[1]+1, input_img[2]+1
+            A_gray = 1. - (0.299*r+0.587*g+0.114*b)/2.
+            A_gray = torch.unsqueeze(A_gray, 0)
+        return {'A': A_img, 'B': B_img, 'A_gray': A_gray, 'input_img': input_img,
+                'A_paths': A_path, 'B_paths': B_path}
+
+    def __len__(self):
+        return self.A_size
+
+    def name(self):
+        return 'PairDataset'
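The `A_gray` map built above is the illumination channel the generator attends to: with images normalized to [-1, 1], each channel plus one lies in [0, 2], the Rec. 601 luma of that is halved into [0, 1], and the complement makes dark pixels weigh most. A minimal numeric sketch:

```python
# Sketch: the attention-map arithmetic from __getitem__, checked on random data.
import torch

img = torch.rand(3, 4, 4) * 2 - 1                # dummy image in [-1, 1]
r, g, b = img[0] + 1, img[1] + 1, img[2] + 1     # shift channels into [0, 2]
A_gray = 1. - (0.299 * r + 0.587 * g + 0.114 * b) / 2.
assert 0.0 <= A_gray.min() and A_gray.max() <= 1.0   # dark pixels -> values near 1
```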
data/single_dataset.py
ADDED
@@ -0,0 +1,36 @@
+import os.path
+import torchvision.transforms as transforms
+from data.base_dataset import BaseDataset, get_transform
+from data.image_folder import make_dataset
+from PIL import Image
+
+
+class SingleDataset(BaseDataset):
+    def initialize(self, opt):
+        self.opt = opt
+        self.root = opt.dataroot
+        self.dir_A = os.path.join(opt.dataroot)
+
+        self.A_paths = make_dataset(self.dir_A)
+
+        self.A_paths = sorted(self.A_paths)
+
+        self.transform = get_transform(opt)
+
+    def __getitem__(self, index):
+        A_path = self.A_paths[index]
+
+        A_img = Image.open(A_path).convert('RGB')
+        A_size = A_img.size
+        A_size = (A_size[0]//16*16, A_size[1]//16*16)  # snap both sides to a multiple of 16
+        A_img = A_img.resize(A_size, Image.BICUBIC)
+
+        A_img = self.transform(A_img)
+
+        return {'A': A_img, 'A_paths': A_path}
+
+    def __len__(self):
+        return len(self.A_paths)
+
+    def name(self):
+        return 'SingleImageDataset'
data/syn_dataset.py
ADDED
@@ -0,0 +1,91 @@
+import os.path
+import torchvision.transforms as transforms
+from data.base_dataset import BaseDataset, get_transform
+from data.image_folder import make_dataset
+from PIL import Image
+import PIL
+import random
+import torch
+from pdb import set_trace as st
+
+
+# Named PairDataset (not SynDataset) because custom_dataset_data_loader
+# imports it under that name for the 'syn' mode.
+class PairDataset(BaseDataset):
+    def initialize(self, opt):
+        self.opt = opt
+        self.root = opt.dataroot
+        self.dir_A = os.path.join(opt.dataroot, opt.phase + 'A')
+        self.dir_B = os.path.join(opt.dataroot, opt.phase + 'B')
+
+        self.A_paths = make_dataset(self.dir_A)
+        self.B_paths = make_dataset(self.dir_B)
+
+        self.A_paths = sorted(self.A_paths)
+        self.B_paths = sorted(self.B_paths)
+        self.A_size = len(self.A_paths)
+        self.B_size = len(self.B_paths)
+
+        transform_list = []
+
+        transform_list += [transforms.ToTensor(),
+                           transforms.Normalize((0.5, 0.5, 0.5),
+                                                (0.5, 0.5, 0.5))]
+        # transform_list = [transforms.ToTensor()]
+
+        self.transform = transforms.Compose(transform_list)
+        # self.transform = get_transform(opt)
+
+    def __getitem__(self, index):
+        A_path = self.A_paths[index % self.A_size]
+        B_path = self.B_paths[index % self.B_size]
+
+        B_img = Image.open(B_path).convert('RGB')
+        # B_img = Image.open(A_path.replace("low", "normal").replace("A", "B")).convert('RGB')
+
+        # A_img = self.transform(A_img)
+        B_img = self.transform(B_img)
+
+        w = B_img.size(2)
+        h = B_img.size(1)
+        w_offset = random.randint(0, max(0, w - self.opt.fineSize - 1))
+        h_offset = random.randint(0, max(0, h - self.opt.fineSize - 1))
+
+        B_img = B_img[:, h_offset:h_offset + self.opt.fineSize,
+                      w_offset:w_offset + self.opt.fineSize]
+
+        if self.opt.resize_or_crop == 'no':
+            pass
+            # r,g,b = A_img[0]+1, A_img[1]+1, A_img[2]+1
+            # A_gray = 1. - (0.299*r+0.587*g+0.114*b)/2.
+            # A_gray = torch.unsqueeze(A_gray, 0)
+            # input_img = A_img
+            # A_gray = (1./A_gray)/255.
+        else:
+            # A_gray = (1./A_gray)/255.
+            if (not self.opt.no_flip) and random.random() < 0.5:
+                idx = [i for i in range(B_img.size(2) - 1, -1, -1)]
+                idx = torch.LongTensor(idx)
+                B_img = B_img.index_select(2, idx)
+            if (not self.opt.no_flip) and random.random() < 0.5:
+                idx = [i for i in range(B_img.size(1) - 1, -1, -1)]
+                idx = torch.LongTensor(idx)
+                B_img = B_img.index_select(1, idx)
+
+            # synthesize a darkened input from the normal-light image B
+            times = random.randint(self.opt.low_times, self.opt.high_times)/100.
+            input_img = (B_img+1)/2./times
+            input_img = input_img*2-1
+            A_img = input_img
+            r, g, b = input_img[0]+1, input_img[1]+1, input_img[2]+1
+            A_gray = 1. - (0.299*r+0.587*g+0.114*b)/2.
+            A_gray = torch.unsqueeze(A_gray, 0)
+        return {'A': A_img, 'B': B_img, 'A_gray': A_gray, 'input_img': input_img,
+                'A_paths': A_path, 'B_paths': B_path}
+
+    def __len__(self):
+        return self.A_size
+
+    def name(self):
+        return 'PairDataset'
data/unaligned_dataset.py
ADDED
@@ -0,0 +1,141 @@
+import torch
+from torch import nn
+import os.path
+import torchvision.transforms as transforms
+from data.base_dataset import BaseDataset, get_transform
+from data.image_folder import make_dataset, store_dataset
+import random
+from PIL import Image
+import PIL
+from pdb import set_trace as st
+
+def pad_tensor(input):
+    height_org, width_org = input.shape[2], input.shape[3]
+    divide = 16
+
+    if width_org % divide != 0 or height_org % divide != 0:
+        width_res = width_org % divide
+        height_res = height_org % divide
+        if width_res != 0:
+            width_div = divide - width_res
+            pad_left = int(width_div / 2)
+            pad_right = int(width_div - pad_left)
+        else:
+            pad_left = 0
+            pad_right = 0
+
+        if height_res != 0:
+            height_div = divide - height_res
+            pad_top = int(height_div / 2)
+            pad_bottom = int(height_div - pad_top)
+        else:
+            pad_top = 0
+            pad_bottom = 0
+
+        padding = nn.ReflectionPad2d((pad_left, pad_right, pad_top, pad_bottom))
+        input = padding(input).data
+    else:
+        pad_left = 0
+        pad_right = 0
+        pad_top = 0
+        pad_bottom = 0
+
+    height, width = input.shape[2], input.shape[3]
+    assert width % divide == 0, 'width cannot be divided by stride'
+    assert height % divide == 0, 'height cannot be divided by stride'
+
+    return input, pad_left, pad_right, pad_top, pad_bottom
+
+def pad_tensor_back(input, pad_left, pad_right, pad_top, pad_bottom):
+    height, width = input.shape[2], input.shape[3]
+    return input[:, :, pad_top: height - pad_bottom, pad_left: width - pad_right]
+
+
+class UnalignedDataset(BaseDataset):
+    def initialize(self, opt):
+        self.opt = opt
+        self.root = opt.dataroot
+        self.dir_A = os.path.join(opt.dataroot, opt.phase + 'A')
+        self.dir_B = os.path.join(opt.dataroot, opt.phase + 'B')
+
+        # self.A_paths = make_dataset(self.dir_A)
+        # self.B_paths = make_dataset(self.dir_B)
+        self.A_imgs, self.A_paths = store_dataset(self.dir_A)
+        self.B_imgs, self.B_paths = store_dataset(self.dir_B)
+
+        # self.A_paths = sorted(self.A_paths)
+        # self.B_paths = sorted(self.B_paths)
+        self.A_size = len(self.A_paths)
+        self.B_size = len(self.B_paths)
+
+        self.transform = get_transform(opt)
+
+    def __getitem__(self, index):
+        # A_path = self.A_paths[index % self.A_size]
+        # B_path = self.B_paths[index % self.B_size]
+
+        # A_img = Image.open(A_path).convert('RGB')
+        # B_img = Image.open(B_path).convert('RGB')
+        A_img = self.A_imgs[index % self.A_size]
+        B_img = self.B_imgs[index % self.B_size]
+        A_path = self.A_paths[index % self.A_size]
+        B_path = self.B_paths[index % self.B_size]
+        # A_size = A_img.size
+        # B_size = B_img.size
+        # A_size = (A_size[0]//16*16, A_size[1]//16*16)
+        # B_size = (B_size[0]//16*16, B_size[1]//16*16)
+        # A_img = A_img.resize(A_size, Image.BICUBIC)
+        # B_img = B_img.resize(B_size, Image.BICUBIC)
+        # A_gray = A_img.convert('LA')
+        # A_gray = 255.0-A_gray
+
+        A_img = self.transform(A_img)
+        B_img = self.transform(B_img)
+
+        if self.opt.resize_or_crop == 'no':
+            r, g, b = A_img[0]+1, A_img[1]+1, A_img[2]+1
+            A_gray = 1. - (0.299*r+0.587*g+0.114*b)/2.
+            A_gray = torch.unsqueeze(A_gray, 0)
+            input_img = A_img
+            # A_gray = (1./A_gray)/255.
+        else:
+            w = A_img.size(2)
+            h = A_img.size(1)
+
+            # A_gray = (1./A_gray)/255.
+            if (not self.opt.no_flip) and random.random() < 0.5:
+                idx = [i for i in range(A_img.size(2) - 1, -1, -1)]
+                idx = torch.LongTensor(idx)
+                A_img = A_img.index_select(2, idx)
+                B_img = B_img.index_select(2, idx)
+            if (not self.opt.no_flip) and random.random() < 0.5:
+                idx = [i for i in range(A_img.size(1) - 1, -1, -1)]
+                idx = torch.LongTensor(idx)
+                A_img = A_img.index_select(1, idx)
+                B_img = B_img.index_select(1, idx)
+            if self.opt.vary == 1 and (not self.opt.no_flip) and random.random() < 0.5:
+                times = random.randint(self.opt.low_times, self.opt.high_times)/100.
+                input_img = (A_img+1)/2./times
+                input_img = input_img*2-1
+            else:
+                input_img = A_img
+            if self.opt.lighten:
+                B_img = (B_img + 1)/2.
+                B_img = (B_img - torch.min(B_img))/(torch.max(B_img) - torch.min(B_img))
+                B_img = B_img*2. - 1
+            r, g, b = input_img[0]+1, input_img[1]+1, input_img[2]+1
+            A_gray = 1. - (0.299*r+0.587*g+0.114*b)/2.
+            A_gray = torch.unsqueeze(A_gray, 0)
+        return {'A': A_img, 'B': B_img, 'A_gray': A_gray, 'input_img': input_img,
+                'A_paths': A_path, 'B_paths': B_path}
+
+    def __len__(self):
+        return max(self.A_size, self.B_size)
+
+    def name(self):
+        return 'UnalignedDataset'
+
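`pad_tensor` reflect-pads height and width up to the next multiple of 16 so arbitrary test resolutions fit the network stride, and `pad_tensor_back` crops the prediction back; a quick round-trip sketch:

```python
# Sketch: round-trip an odd-sized tensor through the padding helpers above.
import torch
from data.unaligned_dataset import pad_tensor, pad_tensor_back

x = torch.randn(1, 3, 250, 375)
y, left, right, top, bottom = pad_tensor(x)
print(y.shape)                                   # torch.Size([1, 3, 256, 384])
z = pad_tensor_back(y, left, right, top, bottom)
assert z.shape == x.shape
```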
data/unaligned_random_crop.py
ADDED
@@ -0,0 +1,85 @@
+import torch
+import os.path
+import torchvision.transforms as transforms
+from data.base_dataset import BaseDataset, get_transform
+from data.image_folder import make_dataset
+import random
+from PIL import Image
+import PIL
+from pdb import set_trace as st
+
+
+class UnalignedDataset(BaseDataset):
+    def initialize(self, opt):
+        self.opt = opt
+        self.root = opt.dataroot
+        self.dir_A = os.path.join(opt.dataroot, opt.phase + 'A')
+        self.dir_B = os.path.join(opt.dataroot, opt.phase + 'B')
+
+        self.A_paths = make_dataset(self.dir_A)
+        self.B_paths = make_dataset(self.dir_B)
+
+        self.A_paths = sorted(self.A_paths)
+        self.B_paths = sorted(self.B_paths)
+        self.A_size = len(self.A_paths)
+        self.B_size = len(self.B_paths)
+
+        transform_list = [transforms.ToTensor(),
+                          transforms.Normalize((0.5, 0.5, 0.5),
+                                               (0.5, 0.5, 0.5))]
+
+        self.transform = transforms.Compose(transform_list)
+        # self.transform = get_transform(opt)
+
+    def __getitem__(self, index):
+        A_path = self.A_paths[index % self.A_size]
+        B_path = self.B_paths[index % self.B_size]
+
+        A_img = Image.open(A_path).convert('RGB')
+        B_img = Image.open(B_path).convert('RGB')
+        A_size = A_img.size
+        B_size = B_img.size
+        A_size = (A_size[0]//16*16, A_size[1]//16*16)
+        B_size = (B_size[0]//16*16, B_size[1]//16*16)
+        A_img = A_img.resize(A_size, Image.BICUBIC)
+        B_img = B_img.resize(B_size, Image.BICUBIC)
+
+        A_img = self.transform(A_img)
+        B_img = self.transform(B_img)
+
+        if self.opt.resize_or_crop == 'no':
+            pass
+        else:
+            w = A_img.size(2)
+            h = A_img.size(1)
+            size = [8, 16, 22]
+            from random import randint
+            size_index = randint(0, 2)
+            Cropsize = size[size_index]*16
+
+            w_offset = random.randint(0, max(0, w - Cropsize - 1))
+            h_offset = random.randint(0, max(0, h - Cropsize - 1))
+
+            A_img = A_img[:, h_offset:h_offset + Cropsize,
+                          w_offset:w_offset + Cropsize]
+
+        if (not self.opt.no_flip) and random.random() < 0.5:
+            idx = [i for i in range(A_img.size(2) - 1, -1, -1)]
+            idx = torch.LongTensor(idx)
+            A_img = A_img.index_select(2, idx)
+            B_img = B_img.index_select(2, idx)
+        if (not self.opt.no_flip) and random.random() < 0.5:
+            idx = [i for i in range(A_img.size(1) - 1, -1, -1)]
+            idx = torch.LongTensor(idx)
+            A_img = A_img.index_select(1, idx)
+            B_img = B_img.index_select(1, idx)
+
+        return {'A': A_img, 'B': B_img,
+                'A_paths': A_path, 'B_paths': B_path}
+
+    def __len__(self):
+        return max(self.A_size, self.B_size)
+
+    def name(self):
+        return 'UnalignedDataset'
datasets/.DS_Store
ADDED
Binary file (6.15 kB)
datasets/bibtex/cityscapes.tex
ADDED
@@ -0,0 +1,6 @@
+@inproceedings{Cordts2016Cityscapes,
+  title={The Cityscapes Dataset for Semantic Urban Scene Understanding},
+  author={Cordts, Marius and Omran, Mohamed and Ramos, Sebastian and Rehfeld, Timo and Enzweiler, Markus and Benenson, Rodrigo and Franke, Uwe and Roth, Stefan and Schiele, Bernt},
+  booktitle={Proc. of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
+  year={2016}
+}
datasets/bibtex/facades.tex
ADDED
@@ -0,0 +1,7 @@
+@INPROCEEDINGS{Tylecek13,
+  author = {Radim Tyle{\v c}ek and Radim {\v S}{\' a}ra},
+  title = {Spatial Pattern Templates for Recognition of Objects with Regular Structure},
+  booktitle = {Proc. GCPR},
+  year = {2013},
+  address = {Saarbrucken, Germany},
+}
datasets/bibtex/handbags.tex
ADDED
@@ -0,0 +1,13 @@
+@inproceedings{zhu2016generative,
+  title={Generative Visual Manipulation on the Natural Image Manifold},
+  author={Zhu, Jun-Yan and Kr{\"a}henb{\"u}hl, Philipp and Shechtman, Eli and Efros, Alexei A.},
+  booktitle={Proceedings of European Conference on Computer Vision (ECCV)},
+  year={2016}
+}
+
+@InProceedings{xie15hed,
+  author = {Xie, Saining and Tu, Zhuowen},
+  title = {Holistically-Nested Edge Detection},
+  booktitle = {Proceedings of IEEE International Conference on Computer Vision},
+  year = {2015},
+}
datasets/bibtex/shoes.tex
ADDED
@@ -0,0 +1,14 @@
+@InProceedings{fine-grained,
+  author = {A. Yu and K. Grauman},
+  title = {{F}ine-{G}rained {V}isual {C}omparisons with {L}ocal {L}earning},
+  booktitle = {Computer Vision and Pattern Recognition (CVPR)},
+  month = {June},
+  year = {2014}
+}
+
+@InProceedings{xie15hed,
+  author = {Xie, Saining and Tu, Zhuowen},
+  title = {Holistically-Nested Edge Detection},
+  booktitle = {Proceedings of IEEE International Conference on Computer Vision},
+  year = {2015},
+}
datasets/combine_A_and_B.py
ADDED
@@ -0,0 +1,49 @@
+from pdb import set_trace as st
+import os
+import numpy as np
+import cv2
+import argparse
+
+parser = argparse.ArgumentParser('create image pairs')
+parser.add_argument('--fold_A', dest='fold_A', help='input directory for image A', type=str, default='../dataset/50kshoes_edges')
+parser.add_argument('--fold_B', dest='fold_B', help='input directory for image B', type=str, default='../dataset/50kshoes_jpg')
+parser.add_argument('--fold_AB', dest='fold_AB', help='output directory', type=str, default='../dataset/test_AB')
+parser.add_argument('--num_imgs', dest='num_imgs', help='number of images', type=int, default=1000000)
+parser.add_argument('--use_AB', dest='use_AB', help='if true: (0001_A, 0001_B) to (0001_AB)', action='store_true')
+args = parser.parse_args()
+
+for arg in vars(args):
+    print('[%s] = ' % arg, getattr(args, arg))
+
+splits = os.listdir(args.fold_A)
+
+for sp in splits:
+    img_fold_A = os.path.join(args.fold_A, sp)
+    img_fold_B = os.path.join(args.fold_B, sp)
+    img_list = os.listdir(img_fold_A)
+    if args.use_AB:
+        img_list = [img_path for img_path in img_list if '_A.' in img_path]
+
+    num_imgs = min(args.num_imgs, len(img_list))
+    print('split = %s, use %d/%d images' % (sp, num_imgs, len(img_list)))
+    img_fold_AB = os.path.join(args.fold_AB, sp)
+    if not os.path.isdir(img_fold_AB):
+        os.makedirs(img_fold_AB)
+    print('split = %s, number of images = %d' % (sp, num_imgs))
+    for n in range(num_imgs):
+        name_A = img_list[n]
+        path_A = os.path.join(img_fold_A, name_A)
+        if args.use_AB:
+            name_B = name_A.replace('_A.', '_B.')
+        else:
+            name_B = name_A
+        path_B = os.path.join(img_fold_B, name_B)
+        if os.path.isfile(path_A) and os.path.isfile(path_B):
+            name_AB = name_A
+            if args.use_AB:
+                name_AB = name_AB.replace('_A.', '.')  # remove _A
+            path_AB = os.path.join(img_fold_AB, name_AB)
+            im_A = cv2.imread(path_A, cv2.IMREAD_COLOR)  # cv2.CV_LOAD_IMAGE_COLOR was removed in OpenCV 3
+            im_B = cv2.imread(path_B, cv2.IMREAD_COLOR)
+            im_AB = np.concatenate([im_A, im_B], 1)
+            cv2.imwrite(path_AB, im_AB)
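A typical invocation (paths are placeholders) is `python datasets/combine_A_and_B.py --fold_A ./A --fold_B ./B --fold_AB ./AB`, which writes one side-by-side A|B image per matching filename, the format AlignedDataset expects.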
datasets/download_cyclegan_dataset.sh
ADDED
@@ -0,0 +1,14 @@
+FILE=$1
+
+if [[ $FILE != "ae_photos" && $FILE != "apple2orange" && $FILE != "summer2winter_yosemite" && $FILE != "horse2zebra" && $FILE != "monet2photo" && $FILE != "cezanne2photo" && $FILE != "ukiyoe2photo" && $FILE != "vangogh2photo" && $FILE != "maps" && $FILE != "cityscapes" && $FILE != "facades" && $FILE != "iphone2dslr_flower" && $FILE != "ae_photos" ]]; then
+    echo "Available datasets are: apple2orange, summer2winter_yosemite, horse2zebra, monet2photo, cezanne2photo, ukiyoe2photo, vangogh2photo, maps, cityscapes, facades, iphone2dslr_flower, ae_photos"
+    exit 1
+fi
+
+URL=https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/$FILE.zip
+ZIP_FILE=./datasets/$FILE.zip
+TARGET_DIR=./datasets/$FILE/
+wget -N $URL -O $ZIP_FILE
+mkdir $TARGET_DIR
+unzip $ZIP_FILE -d ./datasets/
+rm $ZIP_FILE
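Usage follows the guard above, e.g. `bash datasets/download_cyclegan_dataset.sh horse2zebra`, which fetches the zip from the Berkeley mirror and unpacks it into `./datasets/horse2zebra/`.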
datasets/download_pix2pix_dataset.sh
ADDED
@@ -0,0 +1,8 @@
+FILE=$1
+URL=https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/$FILE.tar.gz
+TAR_FILE=./datasets/$FILE.tar.gz
+TARGET_DIR=./datasets/$FILE/
+wget -N $URL -O $TAR_FILE
+mkdir $TARGET_DIR
+tar -zxvf $TAR_FILE -C ./datasets/
+rm $TAR_FILE
imgs/edges2cats.jpg
ADDED
imgs/horse2zebra.gif
ADDED
Git LFS Details
lib/nn/__init__.py
ADDED
@@ -0,0 +1,2 @@
+from .modules import *
+from .parallel import UserScatteredDataParallel, user_scattered_collate, async_copy_to
lib/nn/modules/__init__.py
ADDED
@@ -0,0 +1,12 @@
+# -*- coding: utf-8 -*-
+# File : __init__.py
+# Author : Jiayuan Mao
+# Email : maojiayuan@gmail.com
+# Date : 27/01/2018
+#
+# This file is part of Synchronized-BatchNorm-PyTorch.
+# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
+# Distributed under MIT License.
+
+from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d
+from .replicate import DataParallelWithCallback, patch_replication_callback
lib/nn/modules/batchnorm.py
ADDED
@@ -0,0 +1,329 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# -*- coding: utf-8 -*-
# File   : batchnorm.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.

import collections

import torch
import torch.nn.functional as F

from torch.nn.modules.batchnorm import _BatchNorm
from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast

from .comm import SyncMaster

__all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d']


def _sum_ft(tensor):
    """Sum over the first and last dimension."""
    return tensor.sum(dim=0).sum(dim=-1)


def _unsqueeze_ft(tensor):
    """Add new dimensions at the front and the tail."""
    return tensor.unsqueeze(0).unsqueeze(-1)


_ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size'])
_MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std'])


class _SynchronizedBatchNorm(_BatchNorm):
    def __init__(self, num_features, eps=1e-5, momentum=0.001, affine=True):
        super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine)

        self._sync_master = SyncMaster(self._data_parallel_master)

        self._is_parallel = False
        self._parallel_id = None
        self._slave_pipe = None

        # custom batch norm statistics
        self._moving_average_fraction = 1. - momentum
        self.register_buffer('_tmp_running_mean', torch.zeros(self.num_features))
        self.register_buffer('_tmp_running_var', torch.ones(self.num_features))
        self.register_buffer('_running_iter', torch.ones(1))
        self._tmp_running_mean = self.running_mean.clone() * self._running_iter
        self._tmp_running_var = self.running_var.clone() * self._running_iter

    def forward(self, input):
        # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation.
        if not (self._is_parallel and self.training):
            return F.batch_norm(
                input, self.running_mean, self.running_var, self.weight, self.bias,
                self.training, self.momentum, self.eps)

        # Resize the input to (B, C, -1).
        input_shape = input.size()
        input = input.view(input.size(0), self.num_features, -1)

        # Compute the sum and square-sum.
        sum_size = input.size(0) * input.size(2)
        input_sum = _sum_ft(input)
        input_ssum = _sum_ft(input ** 2)

        # Reduce-and-broadcast the statistics.
        if self._parallel_id == 0:
            mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size))
        else:
            mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size))

        # Compute the output.
        if self.affine:
            # MJY:: Fuse the multiplication for speed.
            output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)
        else:
            output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)

        # Reshape it.
        return output.view(input_shape)

    def __data_parallel_replicate__(self, ctx, copy_id):
        self._is_parallel = True
        self._parallel_id = copy_id

        # parallel_id == 0 means master device.
        if self._parallel_id == 0:
            ctx.sync_master = self._sync_master
        else:
            self._slave_pipe = ctx.sync_master.register_slave(copy_id)

    def _data_parallel_master(self, intermediates):
        """Reduce the sum and square-sum, compute the statistics, and broadcast it."""
        intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())

        to_reduce = [i[1][:2] for i in intermediates]
        to_reduce = [j for i in to_reduce for j in i]  # flatten
        target_gpus = [i[1].sum.get_device() for i in intermediates]

        sum_size = sum([i[1].sum_size for i in intermediates])
        sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)

        mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)

        broadcasted = Broadcast.apply(target_gpus, mean, inv_std)

        outputs = []
        for i, rec in enumerate(intermediates):
            outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2])))

        return outputs

    def _add_weighted(self, dest, delta, alpha=1, beta=1, bias=0):
        """return *dest* by `dest := dest*alpha + delta*beta + bias`"""
        return dest * alpha + delta * beta + bias

    def _compute_mean_std(self, sum_, ssum, size):
        """Compute the mean and standard-deviation with sum and square-sum. This method
        also maintains the moving average on the master device."""
        assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
        mean = sum_ / size
        sumvar = ssum - sum_ * mean
        unbias_var = sumvar / (size - 1)
        bias_var = sumvar / size

        self._tmp_running_mean = self._add_weighted(self._tmp_running_mean, mean.data, alpha=self._moving_average_fraction)
        self._tmp_running_var = self._add_weighted(self._tmp_running_var, unbias_var.data, alpha=self._moving_average_fraction)
        self._running_iter = self._add_weighted(self._running_iter, 1, alpha=self._moving_average_fraction)

        self.running_mean = self._tmp_running_mean / self._running_iter
        self.running_var = self._tmp_running_var / self._running_iter

        return mean, bias_var.clamp(self.eps) ** -0.5


class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
    r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a
    mini-batch.

    .. math::

        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta

    This module differs from the built-in PyTorch BatchNorm1d in that the mean and
    standard-deviation are reduced across all devices during training.

    For example, when one uses `nn.DataParallel` to wrap the network during
    training, PyTorch's implementation normalizes the tensor on each device using
    the statistics only on that device, which accelerates the computation and
    is also easy to implement, but the statistics might be inaccurate.
    Instead, in this synchronized version, the statistics will be computed
    over all training samples distributed on multiple devices.

    Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
    as the built-in PyTorch implementation.

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).

    During training, this layer keeps a running estimate of its computed mean
    and variance. The running sum is kept with a default momentum of 0.1.

    During evaluation, this running mean/variance is used for normalization.

    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm.

    Args:
        num_features: num_features from an expected input of size
            `batch_size x num_features [x width]`
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters. Default: ``True``

    Shape:
        - Input: :math:`(N, C)` or :math:`(N, C, L)`
        - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)

    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm1d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm1d(100, affine=False)
        >>> input = torch.autograd.Variable(torch.randn(20, 100))
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        if input.dim() != 2 and input.dim() != 3:
            raise ValueError('expected 2D or 3D input (got {}D input)'
                             .format(input.dim()))
        super(SynchronizedBatchNorm1d, self)._check_input_dim(input)


class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
    r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch
    of 3d inputs

    .. math::

        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta

    This module differs from the built-in PyTorch BatchNorm2d in that the mean and
    standard-deviation are reduced across all devices during training.

    For example, when one uses `nn.DataParallel` to wrap the network during
    training, PyTorch's implementation normalizes the tensor on each device using
    the statistics only on that device, which accelerates the computation and
    is also easy to implement, but the statistics might be inaccurate.
    Instead, in this synchronized version, the statistics will be computed
    over all training samples distributed on multiple devices.

    Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
    as the built-in PyTorch implementation.

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).

    During training, this layer keeps a running estimate of its computed mean
    and variance. The running sum is kept with a default momentum of 0.1.

    During evaluation, this running mean/variance is used for normalization.

    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm.

    Args:
        num_features: num_features from an expected input of
            size batch_size x num_features x height x width
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters. Default: ``True``

    Shape:
        - Input: :math:`(N, C, H, W)`
        - Output: :math:`(N, C, H, W)` (same shape as input)

    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm2d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm2d(100, affine=False)
        >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45))
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        if input.dim() != 4:
            raise ValueError('expected 4D input (got {}D input)'
                             .format(input.dim()))
        super(SynchronizedBatchNorm2d, self)._check_input_dim(input)


class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
    r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch
    of 4d inputs

    .. math::

        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta

    This module differs from the built-in PyTorch BatchNorm3d in that the mean and
    standard-deviation are reduced across all devices during training.

    For example, when one uses `nn.DataParallel` to wrap the network during
    training, PyTorch's implementation normalizes the tensor on each device using
    the statistics only on that device, which accelerates the computation and
    is also easy to implement, but the statistics might be inaccurate.
    Instead, in this synchronized version, the statistics will be computed
    over all training samples distributed on multiple devices.

    Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
    as the built-in PyTorch implementation.

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).

    During training, this layer keeps a running estimate of its computed mean
    and variance. The running sum is kept with a default momentum of 0.1.

    During evaluation, this running mean/variance is used for normalization.

    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm
    or Spatio-temporal BatchNorm.

    Args:
        num_features: num_features from an expected input of
            size batch_size x num_features x depth x height x width
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters. Default: ``True``

    Shape:
        - Input: :math:`(N, C, D, H, W)`
        - Output: :math:`(N, C, D, H, W)` (same shape as input)

    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm3d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm3d(100, affine=False)
        >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10))
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        if input.dim() != 5:
            raise ValueError('expected 5D input (got {}D input)'
                             .format(input.dim()))
        super(SynchronizedBatchNorm3d, self)._check_input_dim(input)
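
Together with `replicate.py` below, these layers are drop-in replacements for `nn.BatchNorm*d`. A minimal usage sketch, assuming the repository root is on `PYTHONPATH`, at least two GPUs are visible, and an arbitrary toy network:

import torch
import torch.nn as nn
from lib.nn.modules import SynchronizedBatchNorm2d, DataParallelWithCallback

model = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1),
    SynchronizedBatchNorm2d(16),  # statistics reduced over all replicas
    nn.ReLU(),
)
model = DataParallelWithCallback(model, device_ids=[0, 1]).cuda()
out = model(torch.randn(8, 3, 32, 32).cuda())  # BN sees the full batch of 8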
lib/nn/modules/comm.py
ADDED
@@ -0,0 +1,131 @@
# -*- coding: utf-8 -*-
# File   : comm.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.

import queue
import collections
import threading

__all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']


class FutureResult(object):
    """A thread-safe future implementation. Used only as a one-to-one pipe."""

    def __init__(self):
        self._result = None
        self._lock = threading.Lock()
        self._cond = threading.Condition(self._lock)

    def put(self, result):
        with self._lock:
            assert self._result is None, 'Previous result hasn\'t been fetched.'
            self._result = result
            self._cond.notify()

    def get(self):
        with self._lock:
            if self._result is None:
                self._cond.wait()

            res = self._result
            self._result = None
            return res


_MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
_SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])


class SlavePipe(_SlavePipeBase):
    """Pipe for master-slave communication."""

    def run_slave(self, msg):
        self.queue.put((self.identifier, msg))
        ret = self.result.get()
        self.queue.put(True)
        return ret


class SyncMaster(object):
    """An abstract `SyncMaster` object.

    - During the replication, as the data parallel will trigger a callback on each module, all slave devices should
      call `register(id)` and obtain a `SlavePipe` to communicate with the master.
    - During the forward pass, the master device invokes `run_master`; all messages from slave devices will be
      collected and passed to a registered callback.
    - After receiving the messages, the master device should gather the information and determine the message to be
      passed back to each slave device.
    """

    def __init__(self, master_callback):
        """

        Args:
            master_callback: a callback to be invoked after having collected messages from slave devices.
        """
        self._master_callback = master_callback
        self._queue = queue.Queue()
        self._registry = collections.OrderedDict()
        self._activated = False

    def register_slave(self, identifier):
        """
        Register a slave device.

        Args:
            identifier: an identifier, usually the device id.

        Returns: a `SlavePipe` object which can be used to communicate with the master device.

        """
        if self._activated:
            assert self._queue.empty(), 'Queue is not clean before next initialization.'
            self._activated = False
            self._registry.clear()
        future = FutureResult()
        self._registry[identifier] = _MasterRegistry(future)
        return SlavePipe(identifier, self._queue, future)

    def run_master(self, master_msg):
        """
        Main entry for the master device in each forward pass.
        The messages are first collected from each device (including the master device), and then
        a callback will be invoked to compute the message to be sent back to each device
        (including the master device).

        Args:
            master_msg: the message that the master wants to send to itself. This will be placed as the first
            message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.

        Returns: the message to be sent back to the master device.

        """
        self._activated = True

        intermediates = [(0, master_msg)]
        for i in range(self.nr_slaves):
            intermediates.append(self._queue.get())

        results = self._master_callback(intermediates)
        assert results[0][0] == 0, 'The first result should belong to the master.'

        for i, res in results:
            if i == 0:
                continue
            self._registry[i].result.put(res)

        for i in range(self.nr_slaves):
            assert self._queue.get() is True

        return results[0][1]

    @property
    def nr_slaves(self):
        return len(self._registry)
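
The master/slave handshake above can be exercised without any GPU. A toy sketch with plain threads, using a hypothetical callback that sums the payloads and replies with the total:

import threading
from lib.nn.modules.comm import SyncMaster

def master_callback(intermediates):
    # intermediates is [(identifier, msg), ...]; reply to every device with the total
    total = sum(msg for _, msg in intermediates)
    return [(idx, total) for idx, _ in intermediates]

master = SyncMaster(master_callback)
pipe = master.register_slave(1)  # the slave registers before the forward pass

results = {}
def slave():
    results['slave'] = pipe.run_slave(10)  # blocks until the master replies

t = threading.Thread(target=slave)
t.start()
results['master'] = master.run_master(5)  # collects both messages -> 15
t.join()
assert results['master'] == results['slave'] == 15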
lib/nn/modules/replicate.py
ADDED
@@ -0,0 +1,94 @@
# -*- coding: utf-8 -*-
# File   : replicate.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.

import functools

from torch.nn.parallel.data_parallel import DataParallel

__all__ = [
    'CallbackContext',
    'execute_replication_callbacks',
    'DataParallelWithCallback',
    'patch_replication_callback'
]


class CallbackContext(object):
    pass


def execute_replication_callbacks(modules):
    """
    Execute a replication callback `__data_parallel_replicate__` on each module created by the original replication.

    The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`

    Note that, as all copies of the module are isomorphic, we assign each sub-module a context
    (shared among multiple copies of this module on different devices).
    Through this context, different copies can share some information.

    We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback
    of any slave copies.
    """
    master_copy = modules[0]
    nr_modules = len(list(master_copy.modules()))
    ctxs = [CallbackContext() for _ in range(nr_modules)]

    for i, module in enumerate(modules):
        for j, m in enumerate(module.modules()):
            if hasattr(m, '__data_parallel_replicate__'):
                m.__data_parallel_replicate__(ctxs[j], i)


class DataParallelWithCallback(DataParallel):
    """
    Data Parallel with a replication callback.

    A replication callback `__data_parallel_replicate__` of each module will be invoked after being created by
    the original `replicate` function.
    The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`

    Examples:
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
        # sync_bn.__data_parallel_replicate__ will be invoked.
    """

    def replicate(self, module, device_ids):
        modules = super(DataParallelWithCallback, self).replicate(module, device_ids)
        execute_replication_callbacks(modules)
        return modules


def patch_replication_callback(data_parallel):
    """
    Monkey-patch an existing `DataParallel` object. Add the replication callback.
    Useful when you have a customized `DataParallel` implementation.

    Examples:
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
        > patch_replication_callback(sync_bn)
        # this is equivalent to
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
    """

    assert isinstance(data_parallel, DataParallel)

    old_replicate = data_parallel.replicate

    @functools.wraps(old_replicate)
    def new_replicate(module, device_ids):
        modules = old_replicate(module, device_ids)
        execute_replication_callbacks(modules)
        return modules

    data_parallel.replicate = new_replicate
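
The key property of `execute_replication_callbacks` is that the j-th submodule of every copy receives the same `ctx` object, with the master copy called first. A small sketch with a hypothetical `Probe` module, using two plain instances as stand-ins for `replicate()`'s output:

import torch.nn as nn
from lib.nn.modules.replicate import execute_replication_callbacks

class Probe(nn.Module):
    def __data_parallel_replicate__(self, ctx, copy_id):
        self.ctx = ctx  # the same object is handed to every copy
        ctx.seen = getattr(ctx, 'seen', []) + [copy_id]

replicas = [Probe(), Probe()]
execute_replication_callbacks(replicas)
assert replicas[0].ctx is replicas[1].ctx   # shared context
assert replicas[0].ctx.seen == [0, 1]       # master (copy 0) ran first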
lib/nn/modules/tests/test_numeric_batchnorm.py
ADDED
@@ -0,0 +1,56 @@
# -*- coding: utf-8 -*-
# File   : test_numeric_batchnorm.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.

import unittest

import torch
import torch.nn as nn
from torch.autograd import Variable

from sync_batchnorm.unittest import TorchTestCase


def handy_var(a, unbias=True):
    n = a.size(0)
    asum = a.sum(dim=0)
    as_sum = (a ** 2).sum(dim=0)  # a square sum
    sumvar = as_sum - asum * asum / n
    if unbias:
        return sumvar / (n - 1)
    else:
        return sumvar / n


class NumericTestCase(TorchTestCase):
    def testNumericBatchNorm(self):
        a = torch.rand(16, 10)
        bn = nn.BatchNorm2d(10, momentum=1, eps=1e-5, affine=False)
        bn.train()

        a_var1 = Variable(a, requires_grad=True)
        b_var1 = bn(a_var1)
        loss1 = b_var1.sum()
        loss1.backward()

        a_var2 = Variable(a, requires_grad=True)
        a_mean2 = a_var2.mean(dim=0, keepdim=True)
        a_std2 = torch.sqrt(handy_var(a_var2, unbias=False).clamp(min=1e-5))
        # a_std2 = torch.sqrt(a_var2.var(dim=0, keepdim=True, unbiased=False) + 1e-5)
        b_var2 = (a_var2 - a_mean2) / a_std2
        loss2 = b_var2.sum()
        loss2.backward()

        self.assertTensorClose(bn.running_mean, a.mean(dim=0))
        self.assertTensorClose(bn.running_var, handy_var(a))
        self.assertTensorClose(a_var1.data, a_var2.data)
        self.assertTensorClose(b_var1.data, b_var2.data)
        self.assertTensorClose(a_var1.grad, a_var2.grad)


if __name__ == '__main__':
    unittest.main()
lib/nn/modules/tests/test_sync_batchnorm.py
ADDED
@@ -0,0 +1,111 @@
# -*- coding: utf-8 -*-
# File   : test_sync_batchnorm.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.

import unittest

import torch
import torch.nn as nn
from torch.autograd import Variable

from sync_batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, DataParallelWithCallback
from sync_batchnorm.unittest import TorchTestCase


def handy_var(a, unbias=True):
    n = a.size(0)
    asum = a.sum(dim=0)
    as_sum = (a ** 2).sum(dim=0)  # a square sum
    sumvar = as_sum - asum * asum / n
    if unbias:
        return sumvar / (n - 1)
    else:
        return sumvar / n


def _find_bn(module):
    for m in module.modules():
        if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, SynchronizedBatchNorm1d, SynchronizedBatchNorm2d)):
            return m


class SyncTestCase(TorchTestCase):
    def _syncParameters(self, bn1, bn2):
        bn1.reset_parameters()
        bn2.reset_parameters()
        if bn1.affine and bn2.affine:
            bn2.weight.data.copy_(bn1.weight.data)
            bn2.bias.data.copy_(bn1.bias.data)

    def _checkBatchNormResult(self, bn1, bn2, input, is_train, cuda=False):
        """Check the forward and backward passes for the customized batch normalization."""
        bn1.train(mode=is_train)
        bn2.train(mode=is_train)

        if cuda:
            input = input.cuda()

        self._syncParameters(_find_bn(bn1), _find_bn(bn2))

        input1 = Variable(input, requires_grad=True)
        output1 = bn1(input1)
        output1.sum().backward()
        input2 = Variable(input, requires_grad=True)
        output2 = bn2(input2)
        output2.sum().backward()

        self.assertTensorClose(input1.data, input2.data)
        self.assertTensorClose(output1.data, output2.data)
        self.assertTensorClose(input1.grad, input2.grad)
        self.assertTensorClose(_find_bn(bn1).running_mean, _find_bn(bn2).running_mean)
        self.assertTensorClose(_find_bn(bn1).running_var, _find_bn(bn2).running_var)

    def testSyncBatchNormNormalTrain(self):
        bn = nn.BatchNorm1d(10)
        sync_bn = SynchronizedBatchNorm1d(10)

        self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), True)

    def testSyncBatchNormNormalEval(self):
        bn = nn.BatchNorm1d(10)
        sync_bn = SynchronizedBatchNorm1d(10)

        self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), False)

    def testSyncBatchNormSyncTrain(self):
        bn = nn.BatchNorm1d(10, eps=1e-5, affine=False)
        sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])

        bn.cuda()
        sync_bn.cuda()

        self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), True, cuda=True)

    def testSyncBatchNormSyncEval(self):
        bn = nn.BatchNorm1d(10, eps=1e-5, affine=False)
        sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])

        bn.cuda()
        sync_bn.cuda()

        self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), False, cuda=True)

    def testSyncBatchNorm2DSyncTrain(self):
        bn = nn.BatchNorm2d(10)
        sync_bn = SynchronizedBatchNorm2d(10)
        sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])

        bn.cuda()
        sync_bn.cuda()

        self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10, 16, 16), True, cuda=True)


if __name__ == '__main__':
    unittest.main()
lib/nn/modules/unittest.py
ADDED
@@ -0,0 +1,29 @@
# -*- coding: utf-8 -*-
# File   : unittest.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.

import unittest

import numpy as np
from torch.autograd import Variable


def as_numpy(v):
    if isinstance(v, Variable):
        v = v.data
    return v.cpu().numpy()


class TorchTestCase(unittest.TestCase):
    def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3):
        npa, npb = as_numpy(a), as_numpy(b)
        self.assertTrue(
            np.allclose(npa, npb, atol=atol),
            'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max())
        )
lib/nn/parallel/__init__.py
ADDED
@@ -0,0 +1 @@
from .data_parallel import UserScatteredDataParallel, user_scattered_collate, async_copy_to
lib/nn/parallel/data_parallel.py
ADDED
@@ -0,0 +1,112 @@
# -*- coding: utf8 -*-

import torch.cuda as cuda
import torch.nn as nn
import torch
import collections
from torch.nn.parallel._functions import Gather


__all__ = ['UserScatteredDataParallel', 'user_scattered_collate', 'async_copy_to']


def async_copy_to(obj, dev, main_stream=None):
    if torch.is_tensor(obj):
        v = obj.cuda(dev, non_blocking=True)
        if main_stream is not None:
            v.data.record_stream(main_stream)
        return v
    elif isinstance(obj, collections.Mapping):
        return {k: async_copy_to(o, dev, main_stream) for k, o in obj.items()}
    elif isinstance(obj, collections.Sequence):
        return [async_copy_to(o, dev, main_stream) for o in obj]
    else:
        return obj


def dict_gather(outputs, target_device, dim=0):
    """
    Gathers variables from different GPUs on a specified device
    (-1 means the CPU), with dictionary support.
    """
    def gather_map(outputs):
        out = outputs[0]
        if torch.is_tensor(out):
            # MJY(20180330) HACK:: force nr_dims > 0
            if out.dim() == 0:
                outputs = [o.unsqueeze(0) for o in outputs]
            return Gather.apply(target_device, dim, *outputs)
        elif out is None:
            return None
        elif isinstance(out, collections.Mapping):
            return {k: gather_map([o[k] for o in outputs]) for k in out}
        elif isinstance(out, collections.Sequence):
            return type(out)(map(gather_map, zip(*outputs)))
    return gather_map(outputs)


class DictGatherDataParallel(nn.DataParallel):
    def gather(self, outputs, output_device):
        return dict_gather(outputs, output_device, dim=self.dim)


class UserScatteredDataParallel(DictGatherDataParallel):
    def scatter(self, inputs, kwargs, device_ids):
        assert len(inputs) == 1
        inputs = inputs[0]
        inputs = _async_copy_stream(inputs, device_ids)
        inputs = [[i] for i in inputs]
        assert len(kwargs) == 0
        kwargs = [{} for _ in range(len(inputs))]

        return inputs, kwargs


def user_scattered_collate(batch):
    return batch


def _async_copy(inputs, device_ids):
    nr_devs = len(device_ids)
    assert type(inputs) in (tuple, list)
    assert len(inputs) == nr_devs

    outputs = []
    for i, dev in zip(inputs, device_ids):
        with cuda.device(dev):
            outputs.append(async_copy_to(i, dev))

    return tuple(outputs)


def _async_copy_stream(inputs, device_ids):
    nr_devs = len(device_ids)
    assert type(inputs) in (tuple, list)
    assert len(inputs) == nr_devs

    outputs = []
    streams = [_get_stream(d) for d in device_ids]
    for i, dev, stream in zip(inputs, device_ids, streams):
        with cuda.device(dev):
            main_stream = cuda.current_stream()
            with cuda.stream(stream):
                outputs.append(async_copy_to(i, dev, main_stream=main_stream))
            main_stream.wait_stream(stream)

    return outputs


"""Adapted from: torch/nn/parallel/_functions.py"""
# background streams used for copying
_streams = None


def _get_stream(device):
    """Gets a background stream for copying between CPU and GPU"""
    global _streams
    if device == -1:
        return None
    if _streams is None:
        _streams = [None] * cuda.device_count()
    if _streams[device] is None:
        _streams[device] = cuda.Stream(device)
    return _streams[device]
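
`UserScatteredDataParallel` expects the caller to have already split the batch, one element per device, and `user_scattered_collate` exists so a `DataLoader` hands batches through unstacked. A usage sketch, assuming two GPUs; the wrapped module and tensor shapes are placeholders:

import torch
import torch.nn as nn
from lib.nn.parallel import UserScatteredDataParallel

net = nn.Linear(10, 10)
model = UserScatteredDataParallel(net, device_ids=[0, 1]).cuda()

# One input per device; scatter() async-copies each element onto its GPU.
inputs = [torch.randn(4, 10), torch.randn(4, 10)]
outputs = model(inputs)  # per-replica (4, 10) outputs, gathered to (8, 10) on GPU 0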
lib/utils/__init__.py
ADDED
@@ -0,0 +1 @@
from .th import *
lib/utils/data/__init__.py
ADDED
@@ -0,0 +1,3 @@

from .dataset import Dataset, TensorDataset, ConcatDataset
from .dataloader import DataLoader
lib/utils/data/dataloader.py
ADDED
@@ -0,0 +1,422 @@
import torch
import torch.multiprocessing as multiprocessing
from torch._C import _set_worker_signal_handlers, _update_worker_pids, \
    _remove_worker_pids, _error_if_any_worker_fails
from .sampler import SequentialSampler, RandomSampler, BatchSampler
import signal
import functools
import collections
import re
import sys
import threading
import traceback
from torch._six import string_classes, int_classes
import numpy as np

if sys.version_info[0] == 2:
    import Queue as queue
else:
    import queue


class ExceptionWrapper(object):
    r"Wraps an exception plus traceback to communicate across threads"

    def __init__(self, exc_info):
        self.exc_type = exc_info[0]
        self.exc_msg = "".join(traceback.format_exception(*exc_info))


_use_shared_memory = False
"""Whether to use shared memory in default_collate"""


def _worker_loop(dataset, index_queue, data_queue, collate_fn, seed, init_fn, worker_id):
    global _use_shared_memory
    _use_shared_memory = True

    # Initialize C side signal handlers for SIGBUS and SIGSEGV. Python signal
    # module's handlers are executed after Python returns from C low-level
    # handlers, likely when the same fatal signal happened again already.
    # https://docs.python.org/3/library/signal.html Sec. 18.8.1.1
    _set_worker_signal_handlers()

    torch.set_num_threads(1)
    torch.manual_seed(seed)
    np.random.seed(seed)

    if init_fn is not None:
        init_fn(worker_id)

    while True:
        r = index_queue.get()
        if r is None:
            break
        idx, batch_indices = r
        try:
            samples = collate_fn([dataset[i] for i in batch_indices])
        except Exception:
            data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
        else:
            data_queue.put((idx, samples))


def _worker_manager_loop(in_queue, out_queue, done_event, pin_memory, device_id):
    if pin_memory:
        torch.cuda.set_device(device_id)

    while True:
        try:
            r = in_queue.get()
        except Exception:
            if done_event.is_set():
                return
            raise
        if r is None:
            break
        if isinstance(r[1], ExceptionWrapper):
            out_queue.put(r)
            continue
        idx, batch = r
        try:
            if pin_memory:
                batch = pin_memory_batch(batch)
        except Exception:
            out_queue.put((idx, ExceptionWrapper(sys.exc_info())))
        else:
            out_queue.put((idx, batch))

numpy_type_map = {
    'float64': torch.DoubleTensor,
    'float32': torch.FloatTensor,
    'float16': torch.HalfTensor,
    'int64': torch.LongTensor,
    'int32': torch.IntTensor,
    'int16': torch.ShortTensor,
    'int8': torch.CharTensor,
    'uint8': torch.ByteTensor,
}


def default_collate(batch):
    "Puts each data field into a tensor with outer dimension batch size"

    error_msg = "batch must contain tensors, numbers, dicts or lists; found {}"
    elem_type = type(batch[0])
    if torch.is_tensor(batch[0]):
        out = None
        if _use_shared_memory:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum([x.numel() for x in batch])
            storage = batch[0].storage()._new_shared(numel)
            out = batch[0].new(storage)
        return torch.stack(batch, 0, out=out)
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        elem = batch[0]
        if elem_type.__name__ == 'ndarray':
            # array of string classes and object
            if re.search('[SaUO]', elem.dtype.str) is not None:
                raise TypeError(error_msg.format(elem.dtype))

            return torch.stack([torch.from_numpy(b) for b in batch], 0)
        if elem.shape == ():  # scalars
            py_type = float if elem.dtype.name.startswith('float') else int
            return numpy_type_map[elem.dtype.name](list(map(py_type, batch)))
    elif isinstance(batch[0], int_classes):
        return torch.LongTensor(batch)
    elif isinstance(batch[0], float):
        return torch.DoubleTensor(batch)
    elif isinstance(batch[0], string_classes):
        return batch
    elif isinstance(batch[0], collections.Mapping):
        return {key: default_collate([d[key] for d in batch]) for key in batch[0]}
    elif isinstance(batch[0], collections.Sequence):
        transposed = zip(*batch)
        return [default_collate(samples) for samples in transposed]

    raise TypeError((error_msg.format(type(batch[0]))))


def pin_memory_batch(batch):
    if torch.is_tensor(batch):
        return batch.pin_memory()
    elif isinstance(batch, string_classes):
        return batch
    elif isinstance(batch, collections.Mapping):
        return {k: pin_memory_batch(sample) for k, sample in batch.items()}
    elif isinstance(batch, collections.Sequence):
        return [pin_memory_batch(sample) for sample in batch]
    else:
        return batch


_SIGCHLD_handler_set = False
"""Whether SIGCHLD handler is set for DataLoader worker failures. Only one
handler needs to be set for all DataLoaders in a process."""


def _set_SIGCHLD_handler():
    # Windows doesn't support SIGCHLD handler
    if sys.platform == 'win32':
        return
    # can't set signal in child threads
    if not isinstance(threading.current_thread(), threading._MainThread):
        return
    global _SIGCHLD_handler_set
    if _SIGCHLD_handler_set:
        return
    previous_handler = signal.getsignal(signal.SIGCHLD)
    if not callable(previous_handler):
        previous_handler = None

    def handler(signum, frame):
        # The following call uses `waitid` with WNOHANG from the C side. Therefore,
        # Python can still get and update the process status successfully.
        _error_if_any_worker_fails()
        if previous_handler is not None:
            previous_handler(signum, frame)

    signal.signal(signal.SIGCHLD, handler)
    _SIGCHLD_handler_set = True


class DataLoaderIter(object):
    "Iterates once over the DataLoader's dataset, as specified by the sampler"

    def __init__(self, loader):
        self.dataset = loader.dataset
        self.collate_fn = loader.collate_fn
        self.batch_sampler = loader.batch_sampler
        self.num_workers = loader.num_workers
        self.pin_memory = loader.pin_memory and torch.cuda.is_available()
        self.timeout = loader.timeout
        self.done_event = threading.Event()

        self.sample_iter = iter(self.batch_sampler)

        if self.num_workers > 0:
            self.worker_init_fn = loader.worker_init_fn
            self.index_queue = multiprocessing.SimpleQueue()
            self.worker_result_queue = multiprocessing.SimpleQueue()
            self.batches_outstanding = 0
            self.worker_pids_set = False
            self.shutdown = False
            self.send_idx = 0
            self.rcvd_idx = 0
            self.reorder_dict = {}

            base_seed = torch.LongTensor(1).random_(0, 2**31-1)[0]
            self.workers = [
                multiprocessing.Process(
                    target=_worker_loop,
                    args=(self.dataset, self.index_queue, self.worker_result_queue, self.collate_fn,
                          base_seed + i, self.worker_init_fn, i))
                for i in range(self.num_workers)]

            if self.pin_memory or self.timeout > 0:
                self.data_queue = queue.Queue()
                if self.pin_memory:
                    maybe_device_id = torch.cuda.current_device()
                else:
                    # do not initialize cuda context if not necessary
                    maybe_device_id = None
                self.worker_manager_thread = threading.Thread(
                    target=_worker_manager_loop,
                    args=(self.worker_result_queue, self.data_queue, self.done_event, self.pin_memory,
                          maybe_device_id))
                self.worker_manager_thread.daemon = True
                self.worker_manager_thread.start()
            else:
                self.data_queue = self.worker_result_queue

            for w in self.workers:
                w.daemon = True  # ensure that the worker exits on process exit
                w.start()

            _update_worker_pids(id(self), tuple(w.pid for w in self.workers))
            _set_SIGCHLD_handler()
            self.worker_pids_set = True

            # prime the prefetch loop
            for _ in range(2 * self.num_workers):
                self._put_indices()

    def __len__(self):
        return len(self.batch_sampler)

    def _get_batch(self):
        if self.timeout > 0:
            try:
                return self.data_queue.get(timeout=self.timeout)
            except queue.Empty:
                raise RuntimeError('DataLoader timed out after {} seconds'.format(self.timeout))
        else:
            return self.data_queue.get()

    def __next__(self):
        if self.num_workers == 0:  # same-process loading
            indices = next(self.sample_iter)  # may raise StopIteration
            batch = self.collate_fn([self.dataset[i] for i in indices])
            if self.pin_memory:
                batch = pin_memory_batch(batch)
            return batch

        # check if the next sample has already been generated
        if self.rcvd_idx in self.reorder_dict:
            batch = self.reorder_dict.pop(self.rcvd_idx)
            return self._process_next_batch(batch)

        if self.batches_outstanding == 0:
            self._shutdown_workers()
            raise StopIteration

        while True:
            assert (not self.shutdown and self.batches_outstanding > 0)
            idx, batch = self._get_batch()
            self.batches_outstanding -= 1
            if idx != self.rcvd_idx:
                # store out-of-order samples
                self.reorder_dict[idx] = batch
                continue
            return self._process_next_batch(batch)

    next = __next__  # Python 2 compatibility

    def __iter__(self):
        return self

    def _put_indices(self):
        assert self.batches_outstanding < 2 * self.num_workers
        indices = next(self.sample_iter, None)
        if indices is None:
            return
        self.index_queue.put((self.send_idx, indices))
        self.batches_outstanding += 1
        self.send_idx += 1

    def _process_next_batch(self, batch):
        self.rcvd_idx += 1
        self._put_indices()
        if isinstance(batch, ExceptionWrapper):
            raise batch.exc_type(batch.exc_msg)
        return batch

    def __getstate__(self):
        # TODO: add limited pickling support for sharing an iterator
        # across multiple threads for HOGWILD.
        # Probably the best way to do this is by moving the sample pushing
        # to a separate thread and then just sharing the data queue
        # but signalling the end is tricky without a non-blocking API
        raise NotImplementedError("DataLoaderIterator cannot be pickled")

    def _shutdown_workers(self):
        try:
            if not self.shutdown:
                self.shutdown = True
                self.done_event.set()
                # if worker_manager_thread is waiting to put
                while not self.data_queue.empty():
                    self.data_queue.get()
                for _ in self.workers:
                    self.index_queue.put(None)
                # done_event should be sufficient to exit worker_manager_thread,
                # but be safe here and put another None
                self.worker_result_queue.put(None)
        finally:
            # removes pids no matter what
            if self.worker_pids_set:
                _remove_worker_pids(id(self))
                self.worker_pids_set = False

    def __del__(self):
        if self.num_workers > 0:
            self._shutdown_workers()


class DataLoader(object):
    """
    Data loader. Combines a dataset and a sampler, and provides
    single- or multi-process iterators over the dataset.

    Arguments:
        dataset (Dataset): dataset from which to load the data.
        batch_size (int, optional): how many samples per batch to load
            (default: 1).
        shuffle (bool, optional): set to ``True`` to have the data reshuffled
            at every epoch (default: False).
        sampler (Sampler, optional): defines the strategy to draw samples from
            the dataset. If specified, ``shuffle`` must be False.
        batch_sampler (Sampler, optional): like sampler, but returns a batch of
            indices at a time. Mutually exclusive with batch_size, shuffle,
            sampler, and drop_last.
        num_workers (int, optional): how many subprocesses to use for data
            loading. 0 means that the data will be loaded in the main process.
            (default: 0)
        collate_fn (callable, optional): merges a list of samples to form a mini-batch.
        pin_memory (bool, optional): If ``True``, the data loader will copy tensors
            into CUDA pinned memory before returning them.
        drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
            if the dataset size is not divisible by the batch size. If ``False`` and
            the size of dataset is not divisible by the batch size, then the last batch
            will be smaller. (default: False)
        timeout (numeric, optional): if positive, the timeout value for collecting a batch
            from workers. Should always be non-negative. (default: 0)
        worker_init_fn (callable, optional): If not None, this will be called on each
            worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as
            input, after seeding and before data loading. (default: None)

    .. note:: By default, each worker will have its PyTorch seed set to
              ``base_seed + worker_id``, where ``base_seed`` is a long generated
              by the main process using its RNG. You may use ``torch.initial_seed()`` to access
              this value in :attr:`worker_init_fn`, which can be used to set other seeds
              (e.g. NumPy) before data loading.

    .. warning:: If the ``spawn`` start method is used, :attr:`worker_init_fn` cannot be an
                 unpicklable object, e.g., a lambda function.
    """

    def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None,
                 num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False,
                 timeout=0, worker_init_fn=None):
        self.dataset = dataset
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.collate_fn = collate_fn
        self.pin_memory = pin_memory
        self.drop_last = drop_last
        self.timeout = timeout
        self.worker_init_fn = worker_init_fn

        if timeout < 0:
            raise ValueError('timeout option should be non-negative')

        if batch_sampler is not None:
            if batch_size > 1 or shuffle or sampler is not None or drop_last:
                raise ValueError('batch_sampler is mutually exclusive with '
                                 'batch_size, shuffle, sampler, and drop_last')

        if sampler is not None and shuffle:
            raise ValueError('sampler is mutually exclusive with shuffle')

        if self.num_workers < 0:
            raise ValueError('num_workers cannot be negative; '
                             'use num_workers=0 to disable multiprocessing.')

        if batch_sampler is None:
            if sampler is None:
                if shuffle:
                    sampler = RandomSampler(dataset)
                else:
                    sampler = SequentialSampler(dataset)
            batch_sampler = BatchSampler(sampler, batch_size, drop_last)

        self.sampler = sampler
        self.batch_sampler = batch_sampler

    def __iter__(self):
        return DataLoaderIter(self)

    def __len__(self):
        return len(self.batch_sampler)
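
A self-contained sketch of this `DataLoader` (the `TensorDataset` comes from `dataset.py` below); the `worker_init_fn` follows the note in the docstring and derives a worker-local NumPy seed from the per-worker torch seed:

import numpy as np
import torch
from lib.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.randn(100, 3), torch.arange(0, 100).long())

def seed_numpy(worker_id):
    # per the docstring note: torch.initial_seed() is base_seed + worker_id
    np.random.seed((torch.initial_seed() + worker_id) % 2 ** 32)

loader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=2,
                    worker_init_fn=seed_numpy, drop_last=True)
for data, target in loader:
    pass  # each batch: data of shape (16, 3), target of shape (16,)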
lib/utils/data/dataset.py
ADDED
@@ -0,0 +1,118 @@
1 |
+
import bisect
|
2 |
+
import warnings
|
3 |
+
|
4 |
+
from torch._utils import _accumulate
|
5 |
+
from torch import randperm
|
6 |
+
|
7 |
+
|
8 |
+
class Dataset(object):
|
9 |
+
"""An abstract class representing a Dataset.
|
10 |
+
|
11 |
+
All other datasets should subclass it. All subclasses should override
|
12 |
+
``__len__``, that provides the size of the dataset, and ``__getitem__``,
|
13 |
+
supporting integer indexing in range from 0 to len(self) exclusive.
|
14 |
+
"""
|
15 |
+
|
16 |
+
def __getitem__(self, index):
|
17 |
+
raise NotImplementedError
|
18 |
+
|
19 |
+
def __len__(self):
|
20 |
+
raise NotImplementedError
|
21 |
+
|
22 |
+
def __add__(self, other):
|
23 |
+
return ConcatDataset([self, other])
|
24 |
+
|
25 |
+
|
26 |
+
class TensorDataset(Dataset):
|
27 |
+
"""Dataset wrapping data and target tensors.
|
28 |
+
|
29 |
+
Each sample will be retrieved by indexing both tensors along the first
|
30 |
+
dimension.
|
31 |
+
|
32 |
+
Arguments:
|
33 |
+
data_tensor (Tensor): contains sample data.
|
34 |
+
target_tensor (Tensor): contains sample targets (labels).
|
35 |
+
"""
|
36 |
+
|
37 |
+
def __init__(self, data_tensor, target_tensor):
|
38 |
+
assert data_tensor.size(0) == target_tensor.size(0)
|
39 |
+
self.data_tensor = data_tensor
|
40 |
+
self.target_tensor = target_tensor
|
41 |
+
|
42 |
+
def __getitem__(self, index):
|
43 |
+
return self.data_tensor[index], self.target_tensor[index]
|
44 |
+
|
45 |
+
def __len__(self):
|
46 |
+
return self.data_tensor.size(0)
|
47 |
+
|
48 |
+
|
49 |
+
class ConcatDataset(Dataset):
|
50 |
+
"""
|
51 |
+
Dataset to concatenate multiple datasets.
|
52 |
+
Purpose: useful to assemble different existing datasets, possibly
|
53 |
+
large-scale datasets as the concatenation operation is done in an
|
54 |
+
on-the-fly manner.
|
55 |
+
|
56 |
+
Arguments:
|
57 |
+
datasets (iterable): List of datasets to be concatenated
|
58 |
+
"""
|
59 |
+
|
60 |
+
@staticmethod
|
61 |
+
def cumsum(sequence):
|
62 |
+
r, s = [], 0
|
63 |
+
for e in sequence:
|
64 |
+
l = len(e)
|
65 |
+
r.append(l + s)
|
66 |
+
s += l
|
67 |
+
return r
|
68 |
+
|
69 |
+
def __init__(self, datasets):
|
70 |
+
super(ConcatDataset, self).__init__()
|
71 |
+
assert len(datasets) > 0, 'datasets should not be an empty iterable'
|
72 |
+
self.datasets = list(datasets)
|
73 |
+
self.cumulative_sizes = self.cumsum(self.datasets)
|
74 |
+
|
75 |
+
def __len__(self):
|
76 |
+
return self.cumulative_sizes[-1]
|
77 |
+
|
78 |
+
def __getitem__(self, idx):
|
79 |
+
dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
|
80 |
+
if dataset_idx == 0:
|
81 |
+
sample_idx = idx
|
82 |
+
else:
|
83 |
+
sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
|
84 |
+
return self.datasets[dataset_idx][sample_idx]
|
85 |
+
|
86 |
+
@property
|
87 |
+
def cummulative_sizes(self):
|
88 |
+
warnings.warn("cummulative_sizes attribute is renamed to "
|
89 |
+
"cumulative_sizes", DeprecationWarning, stacklevel=2)
|
90 |
+
return self.cumulative_sizes
|
91 |
+
|
92 |
+
|
93 |
+
class Subset(Dataset):
|
94 |
+
def __init__(self, dataset, indices):
|
95 |
+
self.dataset = dataset
|
96 |
+
self.indices = indices
|
97 |
+
|
98 |
+
def __getitem__(self, idx):
|
99 |
+
return self.dataset[self.indices[idx]]
|
100 |
+
|
101 |
+
def __len__(self):
|
102 |
+
return len(self.indices)
|
103 |
+
|
104 |
+
|
105 |
+
def random_split(dataset, lengths):
|
106 |
+
"""
|
107 |
+
Randomly split a dataset into non-overlapping new datasets of given lengths
|
108 |
+
ds
|
109 |
+
|
110 |
+
Arguments:
|
111 |
+
dataset (Dataset): Dataset to be split
|
112 |
+
lengths (iterable): lengths of splits to be produced
|
113 |
+
"""
|
114 |
+
if sum(lengths) != len(dataset):
|
115 |
+
raise ValueError("Sum of input lengths does not equal the length of the input dataset!")
|
116 |
+
|
117 |
+
indices = randperm(sum(lengths))
|
118 |
+
return [Subset(dataset, indices[offset - length:offset]) for offset, length in zip(_accumulate(lengths), lengths)]
|
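A short sketch of how these pieces compose (hypothetical tensors; the import path follows this file's location in the commit):

import torch
from lib.utils.data.dataset import TensorDataset, ConcatDataset, random_split

a = TensorDataset(torch.randn(6, 3), torch.zeros(6).long())
b = TensorDataset(torch.randn(4, 3), torch.ones(4).long())

# ConcatDataset routes an index to the right child via bisect over the
# cumulative sizes computed in __init__.
combined = ConcatDataset([a, b])
assert combined.cumulative_sizes == [6, 10]
x, y = combined[7]                      # sample 1 of dataset b

# random_split slices one shared randperm, so the Subsets are disjoint
# views; the lengths must sum to len(dataset).
train, val = random_split(combined, [8, 2])
assert len(train) == 8 and len(val) == 2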
lib/utils/data/distributed.py
ADDED
@@ -0,0 +1,58 @@
import math
import torch
from .sampler import Sampler
from torch.distributed import get_world_size, get_rank


class DistributedSampler(Sampler):
    """Sampler that restricts data loading to a subset of the dataset.

    It is especially useful in conjunction with
    :class:`torch.nn.parallel.DistributedDataParallel`. In that case, each
    process can pass a DistributedSampler instance as a DataLoader sampler,
    and load a subset of the original dataset that is exclusive to it.

    .. note::
        The dataset is assumed to be of constant size.

    Arguments:
        dataset: Dataset used for sampling.
        num_replicas (optional): Number of processes participating in
            distributed training.
        rank (optional): Rank of the current process within num_replicas.
    """

    def __init__(self, dataset, num_replicas=None, rank=None):
        if num_replicas is None:
            num_replicas = get_world_size()
        if rank is None:
            rank = get_rank()
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas

    def __iter__(self):
        # deterministically shuffle based on epoch
        g = torch.Generator()
        g.manual_seed(self.epoch)
        indices = list(torch.randperm(len(self.dataset), generator=g))

        # add extra samples to make the total evenly divisible
        indices += indices[:(self.total_size - len(indices))]
        assert len(indices) == self.total_size

        # subsample the contiguous block belonging to this rank
        offset = self.num_samples * self.rank
        indices = indices[offset:offset + self.num_samples]
        assert len(indices) == self.num_samples

        return iter(indices)

    def __len__(self):
        return self.num_samples

    def set_epoch(self, epoch):
        self.epoch = epoch
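Typical usage, sketched with explicit num_replicas/rank so the snippet does not require an initialized process group (hypothetical data; import paths follow this commit's layout). The important detail is set_epoch: the permutation is seeded by self.epoch, so without the call every epoch replays the same order:

import torch
from lib.utils.data.dataloader import DataLoader
from lib.utils.data.dataset import TensorDataset
from lib.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.randn(100, 3), torch.zeros(100).long())

# Pretend this process is rank 1 of 4; it sees ceil(100 / 4) = 25 samples.
sampler = DistributedSampler(dataset, num_replicas=4, rank=1)
loader = DataLoader(dataset, batch_size=5, sampler=sampler)

for epoch in range(3):
    sampler.set_epoch(epoch)  # reseed the shared permutation each epoch
    for data, target in loader:
        pass  # training step goes here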
lib/utils/data/sampler.py
ADDED
@@ -0,0 +1,131 @@
import torch


class Sampler(object):
    """Base class for all Samplers.

    Every Sampler subclass has to provide an ``__iter__`` method, providing a
    way to iterate over indices of dataset elements, and a ``__len__`` method
    that returns the length of the returned iterator.
    """

    def __init__(self, data_source):
        pass

    def __iter__(self):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError


class SequentialSampler(Sampler):
    """Samples elements sequentially, always in the same order.

    Arguments:
        data_source (Dataset): dataset to sample from
    """

    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        return iter(range(len(self.data_source)))

    def __len__(self):
        return len(self.data_source)


class RandomSampler(Sampler):
    """Samples elements randomly, without replacement.

    Arguments:
        data_source (Dataset): dataset to sample from
    """

    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        return iter(torch.randperm(len(self.data_source)).long())

    def __len__(self):
        return len(self.data_source)


class SubsetRandomSampler(Sampler):
    """Samples elements randomly from a given list of indices, without replacement.

    Arguments:
        indices (list): a list of indices
    """

    def __init__(self, indices):
        self.indices = indices

    def __iter__(self):
        return (self.indices[i] for i in torch.randperm(len(self.indices)))

    def __len__(self):
        return len(self.indices)


class WeightedRandomSampler(Sampler):
    """Samples elements from [0, .., len(weights)-1] with the given probabilities (weights).

    Arguments:
        weights (list): a list of weights, not necessarily summing to one
        num_samples (int): number of samples to draw
        replacement (bool): if ``True``, samples are drawn with replacement.
            If ``False``, they are drawn without replacement, which means that
            once an index is drawn it cannot be drawn again.
    """

    def __init__(self, weights, num_samples, replacement=True):
        self.weights = torch.DoubleTensor(weights)
        self.num_samples = num_samples
        self.replacement = replacement

    def __iter__(self):
        return iter(torch.multinomial(self.weights, self.num_samples, self.replacement))

    def __len__(self):
        return self.num_samples


class BatchSampler(object):
    """Wraps another sampler to yield a mini-batch of indices.

    Args:
        sampler (Sampler): Base sampler.
        batch_size (int): Size of mini-batch.
        drop_last (bool): If ``True``, the sampler will drop the last batch if
            its size would be less than ``batch_size``.

    Example:
        >>> list(BatchSampler(range(10), batch_size=3, drop_last=False))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
        >>> list(BatchSampler(range(10), batch_size=3, drop_last=True))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
    """

    def __init__(self, sampler, batch_size, drop_last):
        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last

    def __iter__(self):
        batch = []
        for idx in self.sampler:
            batch.append(idx)
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        if len(batch) > 0 and not self.drop_last:
            yield batch

    def __len__(self):
        if self.drop_last:
            return len(self.sampler) // self.batch_size
        else:
            return (len(self.sampler) + self.batch_size - 1) // self.batch_size
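Finally, a sketch of the samplers used standalone (hypothetical indices and weights; any of them can also be passed to DataLoader through its sampler argument, as in dataloader.py above):

import torch
from lib.utils.data.sampler import (SubsetRandomSampler,
                                    WeightedRandomSampler, BatchSampler)

# Iterate a fixed subset of indices in random order.
subset = SubsetRandomSampler([0, 2, 4, 6, 8])
assert sorted(int(i) for i in subset) == [0, 2, 4, 6, 8]

# Draw 20 indices where index 3 is roughly 10x as likely as the others.
weighted = WeightedRandomSampler([1., 1., 1., 10.], num_samples=20,
                                 replacement=True)
counts = torch.zeros(4)
for i in weighted:
    counts[int(i)] += 1          # index 3 dominates in expectation

# BatchSampler chunks any sampler; __len__ rounds up when drop_last=False.
batches = list(BatchSampler(subset, batch_size=3, drop_last=False))
assert len(batches) == 2         # [3 indices, 2 indices]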