guma radames commited on
Commit
c7d7131
0 Parent(s):

Duplicate from radames/PIFu-Clothed-Human-Digitization

Browse files

Co-authored-by: Radamés Ajna <radames@users.noreply.huggingface.co>

This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +28 -0
  2. .gitignore +47 -0
  3. PIFu/.gitignore +1 -0
  4. PIFu/LICENSE.txt +48 -0
  5. PIFu/README.md +167 -0
  6. PIFu/apps/__init__.py +0 -0
  7. PIFu/apps/crop_img.py +75 -0
  8. PIFu/apps/eval.py +153 -0
  9. PIFu/apps/eval_spaces.py +138 -0
  10. PIFu/apps/prt_util.py +142 -0
  11. PIFu/apps/render_data.py +290 -0
  12. PIFu/apps/train_color.py +191 -0
  13. PIFu/apps/train_shape.py +183 -0
  14. PIFu/env_sh.npy +0 -0
  15. PIFu/environment.yml +19 -0
  16. PIFu/lib/__init__.py +0 -0
  17. PIFu/lib/colab_util.py +114 -0
  18. PIFu/lib/data/BaseDataset.py +46 -0
  19. PIFu/lib/data/EvalDataset.py +166 -0
  20. PIFu/lib/data/TrainDataset.py +390 -0
  21. PIFu/lib/data/__init__.py +2 -0
  22. PIFu/lib/ext_transform.py +78 -0
  23. PIFu/lib/geometry.py +55 -0
  24. PIFu/lib/mesh_util.py +91 -0
  25. PIFu/lib/model/BasePIFuNet.py +76 -0
  26. PIFu/lib/model/ConvFilters.py +112 -0
  27. PIFu/lib/model/ConvPIFuNet.py +99 -0
  28. PIFu/lib/model/DepthNormalizer.py +18 -0
  29. PIFu/lib/model/HGFilters.py +146 -0
  30. PIFu/lib/model/HGPIFuNet.py +142 -0
  31. PIFu/lib/model/ResBlkPIFuNet.py +201 -0
  32. PIFu/lib/model/SurfaceClassifier.py +71 -0
  33. PIFu/lib/model/VhullPIFuNet.py +70 -0
  34. PIFu/lib/model/__init__.py +5 -0
  35. PIFu/lib/net_util.py +396 -0
  36. PIFu/lib/options.py +161 -0
  37. PIFu/lib/renderer/__init__.py +0 -0
  38. PIFu/lib/renderer/camera.py +207 -0
  39. PIFu/lib/renderer/gl/__init__.py +0 -0
  40. PIFu/lib/renderer/gl/cam_render.py +48 -0
  41. PIFu/lib/renderer/gl/data/prt.fs +153 -0
  42. PIFu/lib/renderer/gl/data/prt.vs +167 -0
  43. PIFu/lib/renderer/gl/data/prt_uv.fs +141 -0
  44. PIFu/lib/renderer/gl/data/prt_uv.vs +168 -0
  45. PIFu/lib/renderer/gl/data/quad.fs +11 -0
  46. PIFu/lib/renderer/gl/data/quad.vs +11 -0
  47. PIFu/lib/renderer/gl/framework.py +90 -0
  48. PIFu/lib/renderer/gl/glcontext.py +142 -0
  49. PIFu/lib/renderer/gl/init_gl.py +24 -0
  50. PIFu/lib/renderer/gl/prt_render.py +350 -0
.gitattributes ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ftz filter=lfs diff=lfs merge=lfs -text
6
+ *.gz filter=lfs diff=lfs merge=lfs -text
7
+ *.h5 filter=lfs diff=lfs merge=lfs -text
8
+ *.joblib filter=lfs diff=lfs merge=lfs -text
9
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
10
+ *.model filter=lfs diff=lfs merge=lfs -text
11
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
12
+ *.onnx filter=lfs diff=lfs merge=lfs -text
13
+ *.ot filter=lfs diff=lfs merge=lfs -text
14
+ *.parquet filter=lfs diff=lfs merge=lfs -text
15
+ *.pb filter=lfs diff=lfs merge=lfs -text
16
+ *.pt filter=lfs diff=lfs merge=lfs -text
17
+ *.pth filter=lfs diff=lfs merge=lfs -text
18
+ *.rar filter=lfs diff=lfs merge=lfs -text
19
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
20
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
21
+ *.tflite filter=lfs diff=lfs merge=lfs -text
22
+ *.tgz filter=lfs diff=lfs merge=lfs -text
23
+ *.wasm filter=lfs diff=lfs merge=lfs -text
24
+ *.xz filter=lfs diff=lfs merge=lfs -text
25
+ *.zip filter=lfs diff=lfs merge=lfs -text
26
+ *.zstandard filter=lfs diff=lfs merge=lfs -text
27
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
28
+ *.glb filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ results/
2
+ # Python build
3
+ .eggs/
4
+ gradio.egg-info/*
5
+ !gradio.egg-info/requires.txt
6
+ !gradio.egg-info/PKG-INFO
7
+ dist/
8
+ *.pyc
9
+ __pycache__/
10
+ *.py[cod]
11
+ *$py.class
12
+ build/
13
+
14
+ # JS build
15
+ gradio/templates/frontend
16
+ # Secrets
17
+ .env
18
+
19
+ # Gradio run artifacts
20
+ *.db
21
+ *.sqlite3
22
+ gradio/launches.json
23
+ flagged/
24
+ # gradio_cached_examples/
25
+
26
+ # Tests
27
+ .coverage
28
+ coverage.xml
29
+ test.txt
30
+
31
+ # Demos
32
+ demo/tmp.zip
33
+ demo/files/*.avi
34
+ demo/files/*.mp4
35
+
36
+ # Etc
37
+ .idea/*
38
+ .DS_Store
39
+ *.bak
40
+ workspace.code-workspace
41
+ *.h5
42
+ .vscode/
43
+
44
+ # log files
45
+ .pnpm-debug.log
46
+ venv/
47
+ *.db-journal
PIFu/.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ checkpoints/*
PIFu/LICENSE.txt ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2019 Shunsuke Saito, Zeng Huang, and Ryota Natsume
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
22
+
23
+ anyabagomo
24
+
25
+ -------------------- LICENSE FOR ResBlk Image Encoder -----------------------
26
+ Copyright (c) 2017, Jun-Yan Zhu and Taesung Park
27
+ All rights reserved.
28
+
29
+ Redistribution and use in source and binary forms, with or without
30
+ modification, are permitted provided that the following conditions are met:
31
+
32
+ * Redistributions of source code must retain the above copyright notice, this
33
+ list of conditions and the following disclaimer.
34
+
35
+ * Redistributions in binary form must reproduce the above copyright notice,
36
+ this list of conditions and the following disclaimer in the documentation
37
+ and/or other materials provided with the distribution.
38
+
39
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
40
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
41
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
42
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
43
+ FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
44
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
45
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
46
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
47
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
48
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
PIFu/README.md ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # PIFu: Pixel-Aligned Implicit Function for High-Resolution Clothed Human Digitization
2
+
3
+ [![report](https://img.shields.io/badge/arxiv-report-red)](https://arxiv.org/abs/1905.05172) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1GFSsqP2BWz4gtq0e-nki00ZHSirXwFyY)
4
+
5
+ News:
6
+ * \[2020/05/04\] Added EGL rendering option for training data generation. Now you can create your own training data with headless machines!
7
+ * \[2020/04/13\] Demo with Google Colab (incl. visualization) is available. Special thanks to [@nanopoteto](https://github.com/nanopoteto)!!!
8
+ * \[2020/02/26\] License is updated to MIT license! Enjoy!
9
+
10
+ This repository contains a pytorch implementation of "[PIFu: Pixel-Aligned Implicit Function for High-Resolution Clothed Human Digitization](https://arxiv.org/abs/1905.05172)".
11
+
12
+ [Project Page](https://shunsukesaito.github.io/PIFu/)
13
+ ![Teaser Image](https://shunsukesaito.github.io/PIFu/resources/images/teaser.png)
14
+
15
+ If you find the code useful in your research, please consider citing the paper.
16
+
17
+ ```
18
+ @InProceedings{saito2019pifu,
19
+ author = {Saito, Shunsuke and Huang, Zeng and Natsume, Ryota and Morishima, Shigeo and Kanazawa, Angjoo and Li, Hao},
20
+ title = {PIFu: Pixel-Aligned Implicit Function for High-Resolution Clothed Human Digitization},
21
+ booktitle = {The IEEE International Conference on Computer Vision (ICCV)},
22
+ month = {October},
23
+ year = {2019}
24
+ }
25
+ ```
26
+
27
+
28
+ This codebase provides:
29
+ - test code
30
+ - training code
31
+ - data generation code
32
+
33
+ ## Requirements
34
+ - Python 3
35
+ - [PyTorch](https://pytorch.org/) tested on 1.4.0
36
+ - json
37
+ - PIL
38
+ - skimage
39
+ - tqdm
40
+ - numpy
41
+ - cv2
42
+
43
+ for training and data generation
44
+ - [trimesh](https://trimsh.org/) with [pyembree](https://github.com/scopatz/pyembree)
45
+ - [pyexr](https://github.com/tvogels/pyexr)
46
+ - PyOpenGL
47
+ - freeglut (use `sudo apt-get install freeglut3-dev` for ubuntu users)
48
+ - (optional) egl related packages for rendering with headless machines. (use `apt install libgl1-mesa-dri libegl1-mesa libgbm1` for ubuntu users)
49
+
50
+ Warning: I found that outdated NVIDIA drivers may cause errors with EGL. If you want to try out the EGL version, please update your NVIDIA driver to the latest!!
51
+
52
+ ## Windows demo installation instuction
53
+
54
+ - Install [miniconda](https://docs.conda.io/en/latest/miniconda.html)
55
+ - Add `conda` to PATH
56
+ - Install [git bash](https://git-scm.com/downloads)
57
+ - Launch `Git\bin\bash.exe`
58
+ - `eval "$(conda shell.bash hook)"` then `conda activate my_env` because of [this](https://github.com/conda/conda-build/issues/3371)
59
+ - Automatic `env create -f environment.yml` (look [this](https://github.com/conda/conda/issues/3417))
60
+ - OR manually setup [environment](https://towardsdatascience.com/a-guide-to-conda-environments-bc6180fc533)
61
+ - `conda create —name pifu python` where `pifu` is name of your environment
62
+ - `conda activate`
63
+ - `conda install pytorch torchvision cudatoolkit=10.1 -c pytorch`
64
+ - `conda install pillow`
65
+ - `conda install scikit-image`
66
+ - `conda install tqdm`
67
+ - `conda install -c menpo opencv`
68
+ - Download [wget.exe](https://eternallybored.org/misc/wget/)
69
+ - Place it into `Git\mingw64\bin`
70
+ - `sh ./scripts/download_trained_model.sh`
71
+ - Remove background from your image ([this](https://www.remove.bg/), for example)
72
+ - Create black-white mask .png
73
+ - Replace original from sample_images/
74
+ - Try it out - `sh ./scripts/test.sh`
75
+ - Download [Meshlab](http://www.meshlab.net/) because of [this](https://github.com/shunsukesaito/PIFu/issues/1)
76
+ - Open .obj file in Meshlab
77
+
78
+
79
+ ## Demo
80
+ Warning: The released model is trained with mostly upright standing scans with weak perspectie projection and the pitch angle of 0 degree. Reconstruction quality may degrade for images highly deviated from trainining data.
81
+ 1. run the following script to download the pretrained models from the following link and copy them under `./PIFu/checkpoints/`.
82
+ ```
83
+ sh ./scripts/download_trained_model.sh
84
+ ```
85
+
86
+ 2. run the following script. the script creates a textured `.obj` file under `./PIFu/eval_results/`. You may need to use `./apps/crop_img.py` to roughly align an input image and the corresponding mask to the training data for better performance. For background removal, you can use any off-the-shelf tools such as [removebg](https://www.remove.bg/).
87
+ ```
88
+ sh ./scripts/test.sh
89
+ ```
90
+
91
+ ## Demo on Google Colab
92
+ If you do not have a setup to run PIFu, we offer Google Colab version to give it a try, allowing you to run PIFu in the cloud, free of charge. Try our Colab demo using the following notebook:
93
+ [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1GFSsqP2BWz4gtq0e-nki00ZHSirXwFyY)
94
+
95
+ ## Data Generation (Linux Only)
96
+ While we are unable to release the full training data due to the restriction of commertial scans, we provide rendering code using free models in [RenderPeople](https://renderpeople.com/free-3d-people/).
97
+ This tutorial uses `rp_dennis_posed_004` model. Please download the model from [this link](https://renderpeople.com/sample/free/rp_dennis_posed_004_OBJ.zip) and unzip the content under a folder named `rp_dennis_posed_004_OBJ`. The same process can be applied to other RenderPeople data.
98
+
99
+ Warning: the following code becomes extremely slow without [pyembree](https://github.com/scopatz/pyembree). Please make sure you install pyembree.
100
+
101
+ 1. run the following script to compute spherical harmonics coefficients for [precomputed radiance transfer (PRT)](https://sites.fas.harvard.edu/~cs278/papers/prt.pdf). In a nutshell, PRT is used to account for accurate light transport including ambient occlusion without compromising online rendering time, which significantly improves the photorealism compared with [a common sperical harmonics rendering using surface normals](https://cseweb.ucsd.edu/~ravir/papers/envmap/envmap.pdf). This process has to be done once for each obj file.
102
+ ```
103
+ python -m apps.prt_util -i {path_to_rp_dennis_posed_004_OBJ}
104
+ ```
105
+
106
+ 2. run the following script. Under the specified data path, the code creates folders named `GEO`, `RENDER`, `MASK`, `PARAM`, `UV_RENDER`, `UV_MASK`, `UV_NORMAL`, and `UV_POS`. Note that you may need to list validation subjects to exclude from training in `{path_to_training_data}/val.txt` (this tutorial has only one subject and leave it empty). If you wish to render images with headless servers equipped with NVIDIA GPU, add -e to enable EGL rendering.
107
+ ```
108
+ python -m apps.render_data -i {path_to_rp_dennis_posed_004_OBJ} -o {path_to_training_data} [-e]
109
+ ```
110
+
111
+ ## Training (Linux Only)
112
+
113
+ Warning: the following code becomes extremely slow without [pyembree](https://github.com/scopatz/pyembree). Please make sure you install pyembree.
114
+
115
+ 1. run the following script to train the shape module. The intermediate results and checkpoints are saved under `./results` and `./checkpoints` respectively. You can add `--batch_size` and `--num_sample_input` flags to adjust the batch size and the number of sampled points based on available GPU memory.
116
+ ```
117
+ python -m apps.train_shape --dataroot {path_to_training_data} --random_flip --random_scale --random_trans
118
+ ```
119
+
120
+ 2. run the following script to train the color module.
121
+ ```
122
+ python -m apps.train_color --dataroot {path_to_training_data} --num_sample_inout 0 --num_sample_color 5000 --sigma 0.1 --random_flip --random_scale --random_trans
123
+ ```
124
+
125
+ ## Related Research
126
+ **[Monocular Real-Time Volumetric Performance Capture (ECCV 2020)](https://project-splinter.github.io/)**
127
+ *Ruilong Li\*, Yuliang Xiu\*, Shunsuke Saito, Zeng Huang, Kyle Olszewski, Hao Li*
128
+
129
+ The first real-time PIFu by accelerating reconstruction and rendering!!
130
+
131
+ **[PIFuHD: Multi-Level Pixel-Aligned Implicit Function for High-Resolution 3D Human Digitization (CVPR 2020)](https://shunsukesaito.github.io/PIFuHD/)**
132
+ *Shunsuke Saito, Tomas Simon, Jason Saragih, Hanbyul Joo*
133
+
134
+ We further improve the quality of reconstruction by leveraging multi-level approach!
135
+
136
+ **[ARCH: Animatable Reconstruction of Clothed Humans (CVPR 2020)](https://arxiv.org/pdf/2004.04572.pdf)**
137
+ *Zeng Huang, Yuanlu Xu, Christoph Lassner, Hao Li, Tony Tung*
138
+
139
+ Learning PIFu in canonical space for animatable avatar generation!
140
+
141
+ **[Robust 3D Self-portraits in Seconds (CVPR 2020)](http://www.liuyebin.com/portrait/portrait.html)**
142
+ *Zhe Li, Tao Yu, Chuanyu Pan, Zerong Zheng, Yebin Liu*
143
+
144
+ They extend PIFu to RGBD + introduce "PIFusion" utilizing PIFu reconstruction for non-rigid fusion.
145
+
146
+ **[Learning to Infer Implicit Surfaces without 3d Supervision (NeurIPS 2019)](http://papers.nips.cc/paper/9039-learning-to-infer-implicit-surfaces-without-3d-supervision.pdf)**
147
+ *Shichen Liu, Shunsuke Saito, Weikai Chen, Hao Li*
148
+
149
+ We answer to the question of "how can we learn implicit function if we don't have 3D ground truth?"
150
+
151
+ **[SiCloPe: Silhouette-Based Clothed People (CVPR 2019, best paper finalist)](https://arxiv.org/pdf/1901.00049.pdf)**
152
+ *Ryota Natsume\*, Shunsuke Saito\*, Zeng Huang, Weikai Chen, Chongyang Ma, Hao Li, Shigeo Morishima*
153
+
154
+ Our first attempt to reconstruct 3D clothed human body with texture from a single image!
155
+
156
+ **[Deep Volumetric Video from Very Sparse Multi-view Performance Capture (ECCV 2018)](http://openaccess.thecvf.com/content_ECCV_2018/papers/Zeng_Huang_Deep_Volumetric_Video_ECCV_2018_paper.pdf)**
157
+ *Zeng Huang, Tianye Li, Weikai Chen, Yajie Zhao, Jun Xing, Chloe LeGendre, Linjie Luo, Chongyang Ma, Hao Li*
158
+
159
+ Implict surface learning for sparse view human performance capture!
160
+
161
+ ------
162
+
163
+
164
+
165
+ For commercial queries, please contact:
166
+
167
+ Hao Li: hao@hao-li.com ccto: saitos@usc.edu Baker!!
PIFu/apps/__init__.py ADDED
File without changes
PIFu/apps/crop_img.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import numpy as np
4
+
5
+ from pathlib import Path
6
+ import argparse
7
+
8
+ def get_bbox(msk):
9
+ rows = np.any(msk, axis=1)
10
+ cols = np.any(msk, axis=0)
11
+ rmin, rmax = np.where(rows)[0][[0,-1]]
12
+ cmin, cmax = np.where(cols)[0][[0,-1]]
13
+
14
+ return rmin, rmax, cmin, cmax
15
+
16
+ def process_img(img, msk, bbox=None):
17
+ if bbox is None:
18
+ bbox = get_bbox(msk > 100)
19
+ cx = (bbox[3] + bbox[2])//2
20
+ cy = (bbox[1] + bbox[0])//2
21
+
22
+ w = img.shape[1]
23
+ h = img.shape[0]
24
+ height = int(1.138*(bbox[1] - bbox[0]))
25
+ hh = height//2
26
+
27
+ # crop
28
+ dw = min(cx, w-cx, hh)
29
+ if cy-hh < 0:
30
+ img = cv2.copyMakeBorder(img,hh-cy,0,0,0,cv2.BORDER_CONSTANT,value=[0,0,0])
31
+ msk = cv2.copyMakeBorder(msk,hh-cy,0,0,0,cv2.BORDER_CONSTANT,value=0)
32
+ cy = hh
33
+ if cy+hh > h:
34
+ img = cv2.copyMakeBorder(img,0,cy+hh-h,0,0,cv2.BORDER_CONSTANT,value=[0,0,0])
35
+ msk = cv2.copyMakeBorder(msk,0,cy+hh-h,0,0,cv2.BORDER_CONSTANT,value=0)
36
+ img = img[cy-hh:(cy+hh),cx-dw:cx+dw,:]
37
+ msk = msk[cy-hh:(cy+hh),cx-dw:cx+dw]
38
+ dw = img.shape[0] - img.shape[1]
39
+ if dw != 0:
40
+ img = cv2.copyMakeBorder(img,0,0,dw//2,dw//2,cv2.BORDER_CONSTANT,value=[0,0,0])
41
+ msk = cv2.copyMakeBorder(msk,0,0,dw//2,dw//2,cv2.BORDER_CONSTANT,value=0)
42
+ img = cv2.resize(img, (512, 512))
43
+ msk = cv2.resize(msk, (512, 512))
44
+
45
+ kernel = np.ones((3,3),np.uint8)
46
+ msk = cv2.erode((255*(msk > 100)).astype(np.uint8), kernel, iterations = 1)
47
+
48
+ return img, msk
49
+
50
+ def main():
51
+ '''
52
+ given foreground mask, this script crops and resizes an input image and mask for processing.
53
+ '''
54
+ parser = argparse.ArgumentParser()
55
+ parser.add_argument('-i', '--input_image', type=str, help='if the image has alpha channel, it will be used as mask')
56
+ parser.add_argument('-m', '--input_mask', type=str)
57
+ parser.add_argument('-o', '--out_path', type=str, default='./sample_images')
58
+ args = parser.parse_args()
59
+
60
+ img = cv2.imread(args.input_image, cv2.IMREAD_UNCHANGED)
61
+ if img.shape[2] == 4:
62
+ msk = img[:,:,3:]
63
+ img = img[:,:,:3]
64
+ else:
65
+ msk = cv2.imread(args.input_mask, cv2.IMREAD_GRAYSCALE)
66
+
67
+ img_new, msk_new = process_img(img, msk)
68
+
69
+ img_name = Path(args.input_image).stem
70
+
71
+ cv2.imwrite(os.path.join(args.out_path, img_name + '.png'), img_new)
72
+ cv2.imwrite(os.path.join(args.out_path, img_name + '_mask.png'), msk_new)
73
+
74
+ if __name__ == "__main__":
75
+ main()
PIFu/apps/eval.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tqdm
2
+ import glob
3
+ import torchvision.transforms as transforms
4
+ from PIL import Image
5
+ from lib.model import *
6
+ from lib.train_util import *
7
+ from lib.sample_util import *
8
+ from lib.mesh_util import *
9
+ # from lib.options import BaseOptions
10
+ from torch.utils.data import DataLoader
11
+ import torch
12
+ import numpy as np
13
+ import json
14
+ import time
15
+ import sys
16
+ import os
17
+
18
+ sys.path.insert(0, os.path.abspath(
19
+ os.path.join(os.path.dirname(__file__), '..')))
20
+ ROOT_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
21
+
22
+
23
+ # # get options
24
+ # opt = BaseOptions().parse()
25
+
26
+ class Evaluator:
27
+ def __init__(self, opt, projection_mode='orthogonal'):
28
+ self.opt = opt
29
+ self.load_size = self.opt.loadSize
30
+ self.to_tensor = transforms.Compose([
31
+ transforms.Resize(self.load_size),
32
+ transforms.ToTensor(),
33
+ transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
34
+ ])
35
+ # set cuda
36
+ cuda = torch.device(
37
+ 'cuda:%d' % opt.gpu_id) if torch.cuda.is_available() else torch.device('cpu')
38
+
39
+ # create net
40
+ netG = HGPIFuNet(opt, projection_mode).to(device=cuda)
41
+ print('Using Network: ', netG.name)
42
+
43
+ if opt.load_netG_checkpoint_path:
44
+ netG.load_state_dict(torch.load(
45
+ opt.load_netG_checkpoint_path, map_location=cuda))
46
+
47
+ if opt.load_netC_checkpoint_path is not None:
48
+ print('loading for net C ...', opt.load_netC_checkpoint_path)
49
+ netC = ResBlkPIFuNet(opt).to(device=cuda)
50
+ netC.load_state_dict(torch.load(
51
+ opt.load_netC_checkpoint_path, map_location=cuda))
52
+ else:
53
+ netC = None
54
+
55
+ os.makedirs(opt.results_path, exist_ok=True)
56
+ os.makedirs('%s/%s' % (opt.results_path, opt.name), exist_ok=True)
57
+
58
+ opt_log = os.path.join(opt.results_path, opt.name, 'opt.txt')
59
+ with open(opt_log, 'w') as outfile:
60
+ outfile.write(json.dumps(vars(opt), indent=2))
61
+
62
+ self.cuda = cuda
63
+ self.netG = netG
64
+ self.netC = netC
65
+
66
+ def load_image(self, image_path, mask_path):
67
+ # Name
68
+ img_name = os.path.splitext(os.path.basename(image_path))[0]
69
+ # Calib
70
+ B_MIN = np.array([-1, -1, -1])
71
+ B_MAX = np.array([1, 1, 1])
72
+ projection_matrix = np.identity(4)
73
+ projection_matrix[1, 1] = -1
74
+ calib = torch.Tensor(projection_matrix).float()
75
+ # Mask
76
+ mask = Image.open(mask_path).convert('L')
77
+ mask = transforms.Resize(self.load_size)(mask)
78
+ mask = transforms.ToTensor()(mask).float()
79
+ # image
80
+ image = Image.open(image_path).convert('RGB')
81
+ image = self.to_tensor(image)
82
+ image = mask.expand_as(image) * image
83
+ return {
84
+ 'name': img_name,
85
+ 'img': image.unsqueeze(0),
86
+ 'calib': calib.unsqueeze(0),
87
+ 'mask': mask.unsqueeze(0),
88
+ 'b_min': B_MIN,
89
+ 'b_max': B_MAX,
90
+ }
91
+
92
+ def load_image_from_memory(self, image_path, mask_path, img_name):
93
+ # Calib
94
+ B_MIN = np.array([-1, -1, -1])
95
+ B_MAX = np.array([1, 1, 1])
96
+ projection_matrix = np.identity(4)
97
+ projection_matrix[1, 1] = -1
98
+ calib = torch.Tensor(projection_matrix).float()
99
+ # Mask
100
+ mask = Image.fromarray(mask_path).convert('L')
101
+ mask = transforms.Resize(self.load_size)(mask)
102
+ mask = transforms.ToTensor()(mask).float()
103
+ # image
104
+ image = Image.fromarray(image_path).convert('RGB')
105
+ image = self.to_tensor(image)
106
+ image = mask.expand_as(image) * image
107
+ return {
108
+ 'name': img_name,
109
+ 'img': image.unsqueeze(0),
110
+ 'calib': calib.unsqueeze(0),
111
+ 'mask': mask.unsqueeze(0),
112
+ 'b_min': B_MIN,
113
+ 'b_max': B_MAX,
114
+ }
115
+
116
+ def eval(self, data, use_octree=False):
117
+ '''
118
+ Evaluate a data point
119
+ :param data: a dict containing at least ['name'], ['image'], ['calib'], ['b_min'] and ['b_max'] tensors.
120
+ :return:
121
+ '''
122
+ opt = self.opt
123
+ with torch.no_grad():
124
+ self.netG.eval()
125
+ if self.netC:
126
+ self.netC.eval()
127
+ save_path = '%s/%s/result_%s.obj' % (
128
+ opt.results_path, opt.name, data['name'])
129
+ if self.netC:
130
+ gen_mesh_color(opt, self.netG, self.netC, self.cuda,
131
+ data, save_path, use_octree=use_octree)
132
+ else:
133
+ gen_mesh(opt, self.netG, self.cuda, data,
134
+ save_path, use_octree=use_octree)
135
+
136
+
137
+ if __name__ == '__main__':
138
+ evaluator = Evaluator(opt)
139
+
140
+ test_images = glob.glob(os.path.join(opt.test_folder_path, '*'))
141
+ test_images = [f for f in test_images if (
142
+ 'png' in f or 'jpg' in f) and (not 'mask' in f)]
143
+ test_masks = [f[:-4]+'_mask.png' for f in test_images]
144
+
145
+ print("num; ", len(test_masks))
146
+
147
+ for image_path, mask_path in tqdm.tqdm(zip(test_images, test_masks)):
148
+ try:
149
+ print(image_path, mask_path)
150
+ data = evaluator.load_image(image_path, mask_path)
151
+ evaluator.eval(data, True)
152
+ except Exception as e:
153
+ print("error:", e.args)
PIFu/apps/eval_spaces.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import os
3
+
4
+ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
5
+ ROOT_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
6
+
7
+ import time
8
+ import json
9
+ import numpy as np
10
+ import torch
11
+ from torch.utils.data import DataLoader
12
+
13
+ from lib.options import BaseOptions
14
+ from lib.mesh_util import *
15
+ from lib.sample_util import *
16
+ from lib.train_util import *
17
+ from lib.model import *
18
+
19
+ from PIL import Image
20
+ import torchvision.transforms as transforms
21
+
22
+ import trimesh
23
+ from datetime import datetime
24
+
25
+ # get options
26
+ opt = BaseOptions().parse()
27
+
28
+ class Evaluator:
29
+ def __init__(self, opt, projection_mode='orthogonal'):
30
+ self.opt = opt
31
+ self.load_size = self.opt.loadSize
32
+ self.to_tensor = transforms.Compose([
33
+ transforms.Resize(self.load_size),
34
+ transforms.ToTensor(),
35
+ transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
36
+ ])
37
+ # set cuda
38
+ cuda = torch.device('cuda:%d' % opt.gpu_id) if torch.cuda.is_available() else torch.device('cpu')
39
+ print("CUDDAAAAA ???", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "NO ONLY CPU")
40
+
41
+ # create net
42
+ netG = HGPIFuNet(opt, projection_mode).to(device=cuda)
43
+ print('Using Network: ', netG.name)
44
+
45
+ if opt.load_netG_checkpoint_path:
46
+ netG.load_state_dict(torch.load(opt.load_netG_checkpoint_path, map_location=cuda))
47
+
48
+ if opt.load_netC_checkpoint_path is not None:
49
+ print('loading for net C ...', opt.load_netC_checkpoint_path)
50
+ netC = ResBlkPIFuNet(opt).to(device=cuda)
51
+ netC.load_state_dict(torch.load(opt.load_netC_checkpoint_path, map_location=cuda))
52
+ else:
53
+ netC = None
54
+
55
+ os.makedirs(opt.results_path, exist_ok=True)
56
+ os.makedirs('%s/%s' % (opt.results_path, opt.name), exist_ok=True)
57
+
58
+ opt_log = os.path.join(opt.results_path, opt.name, 'opt.txt')
59
+ with open(opt_log, 'w') as outfile:
60
+ outfile.write(json.dumps(vars(opt), indent=2))
61
+
62
+ self.cuda = cuda
63
+ self.netG = netG
64
+ self.netC = netC
65
+
66
+ def load_image(self, image_path, mask_path):
67
+ # Name
68
+ img_name = os.path.splitext(os.path.basename(image_path))[0]
69
+ # Calib
70
+ B_MIN = np.array([-1, -1, -1])
71
+ B_MAX = np.array([1, 1, 1])
72
+ projection_matrix = np.identity(4)
73
+ projection_matrix[1, 1] = -1
74
+ calib = torch.Tensor(projection_matrix).float()
75
+ # Mask
76
+ mask = Image.open(mask_path).convert('L')
77
+ mask = transforms.Resize(self.load_size)(mask)
78
+ mask = transforms.ToTensor()(mask).float()
79
+ # image
80
+ image = Image.open(image_path).convert('RGB')
81
+ image = self.to_tensor(image)
82
+ image = mask.expand_as(image) * image
83
+ return {
84
+ 'name': img_name,
85
+ 'img': image.unsqueeze(0),
86
+ 'calib': calib.unsqueeze(0),
87
+ 'mask': mask.unsqueeze(0),
88
+ 'b_min': B_MIN,
89
+ 'b_max': B_MAX,
90
+ }
91
+
92
+ def eval(self, data, use_octree=False):
93
+ '''
94
+ Evaluate a data point
95
+ :param data: a dict containing at least ['name'], ['image'], ['calib'], ['b_min'] and ['b_max'] tensors.
96
+ :return:
97
+ '''
98
+ opt = self.opt
99
+ with torch.no_grad():
100
+ self.netG.eval()
101
+ if self.netC:
102
+ self.netC.eval()
103
+ save_path = '%s/%s/result_%s.obj' % (opt.results_path, opt.name, data['name'])
104
+ if self.netC:
105
+ gen_mesh_color(opt, self.netG, self.netC, self.cuda, data, save_path, use_octree=use_octree)
106
+ else:
107
+ gen_mesh(opt, self.netG, self.cuda, data, save_path, use_octree=use_octree)
108
+
109
+
110
+ if __name__ == '__main__':
111
+ evaluator = Evaluator(opt)
112
+
113
+ results_path = opt.results_path
114
+ name = opt.name
115
+ test_image_path = opt.img_path
116
+ test_mask_path = test_image_path[:-4] +'_mask.png'
117
+ test_img_name = os.path.splitext(os.path.basename(test_image_path))[0]
118
+ print("test_image: ", test_image_path)
119
+ print("test_mask: ", test_mask_path)
120
+
121
+ try:
122
+ time = datetime.now()
123
+ print("evaluating" , time)
124
+ data = evaluator.load_image(test_image_path, test_mask_path)
125
+ evaluator.eval(data, False)
126
+ print("done evaluating" , datetime.now() - time)
127
+ except Exception as e:
128
+ print("error:", e.args)
129
+
130
+ try:
131
+ mesh = trimesh.load(f'{results_path}/{name}/result_{test_img_name}.obj')
132
+ mesh.apply_transform([[1, 0, 0, 0],
133
+ [0, 1, 0, 0],
134
+ [0, 0, -1, 0],
135
+ [0, 0, 0, 1]])
136
+ mesh.export(file_obj=f'{results_path}/{name}/result_{test_img_name}.glb')
137
+ except Exception as e:
138
+ print("error generating MESH", e)
PIFu/apps/prt_util.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import trimesh
3
+ import numpy as np
4
+ import math
5
+ from scipy.special import sph_harm
6
+ import argparse
7
+ from tqdm import tqdm
8
+
9
+ def factratio(N, D):
10
+ if N >= D:
11
+ prod = 1.0
12
+ for i in range(D+1, N+1):
13
+ prod *= i
14
+ return prod
15
+ else:
16
+ prod = 1.0
17
+ for i in range(N+1, D+1):
18
+ prod *= i
19
+ return 1.0 / prod
20
+
21
+ def KVal(M, L):
22
+ return math.sqrt(((2 * L + 1) / (4 * math.pi)) * (factratio(L - M, L + M)))
23
+
24
+ def AssociatedLegendre(M, L, x):
25
+ if M < 0 or M > L or np.max(np.abs(x)) > 1.0:
26
+ return np.zeros_like(x)
27
+
28
+ pmm = np.ones_like(x)
29
+ if M > 0:
30
+ somx2 = np.sqrt((1.0 + x) * (1.0 - x))
31
+ fact = 1.0
32
+ for i in range(1, M+1):
33
+ pmm = -pmm * fact * somx2
34
+ fact = fact + 2
35
+
36
+ if L == M:
37
+ return pmm
38
+ else:
39
+ pmmp1 = x * (2 * M + 1) * pmm
40
+ if L == M+1:
41
+ return pmmp1
42
+ else:
43
+ pll = np.zeros_like(x)
44
+ for i in range(M+2, L+1):
45
+ pll = (x * (2 * i - 1) * pmmp1 - (i + M - 1) * pmm) / (i - M)
46
+ pmm = pmmp1
47
+ pmmp1 = pll
48
+ return pll
49
+
50
+ def SphericalHarmonic(M, L, theta, phi):
51
+ if M > 0:
52
+ return math.sqrt(2.0) * KVal(M, L) * np.cos(M * phi) * AssociatedLegendre(M, L, np.cos(theta))
53
+ elif M < 0:
54
+ return math.sqrt(2.0) * KVal(-M, L) * np.sin(-M * phi) * AssociatedLegendre(-M, L, np.cos(theta))
55
+ else:
56
+ return KVal(0, L) * AssociatedLegendre(0, L, np.cos(theta))
57
+
58
+ def save_obj(mesh_path, verts):
59
+ file = open(mesh_path, 'w')
60
+ for v in verts:
61
+ file.write('v %.4f %.4f %.4f\n' % (v[0], v[1], v[2]))
62
+ file.close()
63
+
64
+ def sampleSphericalDirections(n):
65
+ xv = np.random.rand(n,n)
66
+ yv = np.random.rand(n,n)
67
+ theta = np.arccos(1-2 * xv)
68
+ phi = 2.0 * math.pi * yv
69
+
70
+ phi = phi.reshape(-1)
71
+ theta = theta.reshape(-1)
72
+
73
+ vx = -np.sin(theta) * np.cos(phi)
74
+ vy = -np.sin(theta) * np.sin(phi)
75
+ vz = np.cos(theta)
76
+ return np.stack([vx, vy, vz], 1), phi, theta
77
+
78
+ def getSHCoeffs(order, phi, theta):
79
+ shs = []
80
+ for n in range(0, order+1):
81
+ for m in range(-n,n+1):
82
+ s = SphericalHarmonic(m, n, theta, phi)
83
+ shs.append(s)
84
+
85
+ return np.stack(shs, 1)
86
+
87
+ def computePRT(mesh_path, n, order):
88
+ mesh = trimesh.load(mesh_path, process=False)
89
+ vectors_orig, phi, theta = sampleSphericalDirections(n)
90
+ SH_orig = getSHCoeffs(order, phi, theta)
91
+
92
+ w = 4.0 * math.pi / (n*n)
93
+
94
+ origins = mesh.vertices
95
+ normals = mesh.vertex_normals
96
+ n_v = origins.shape[0]
97
+
98
+ origins = np.repeat(origins[:,None], n, axis=1).reshape(-1,3)
99
+ normals = np.repeat(normals[:,None], n, axis=1).reshape(-1,3)
100
+ PRT_all = None
101
+ for i in tqdm(range(n)):
102
+ SH = np.repeat(SH_orig[None,(i*n):((i+1)*n)], n_v, axis=0).reshape(-1,SH_orig.shape[1])
103
+ vectors = np.repeat(vectors_orig[None,(i*n):((i+1)*n)], n_v, axis=0).reshape(-1,3)
104
+
105
+ dots = (vectors * normals).sum(1)
106
+ front = (dots > 0.0)
107
+
108
+ delta = 1e-3*min(mesh.bounding_box.extents)
109
+ hits = mesh.ray.intersects_any(origins + delta * normals, vectors)
110
+ nohits = np.logical_and(front, np.logical_not(hits))
111
+
112
+ PRT = (nohits.astype(np.float) * dots)[:,None] * SH
113
+
114
+ if PRT_all is not None:
115
+ PRT_all += (PRT.reshape(-1, n, SH.shape[1]).sum(1))
116
+ else:
117
+ PRT_all = (PRT.reshape(-1, n, SH.shape[1]).sum(1))
118
+
119
+ PRT = w * PRT_all
120
+
121
+ # NOTE: trimesh sometimes break the original vertex order, but topology will not change.
122
+ # when loading PRT in other program, use the triangle list from trimesh.
123
+ return PRT, mesh.faces
124
+
125
+ def testPRT(dir_path, n=40):
126
+ if dir_path[-1] == '/':
127
+ dir_path = dir_path[:-1]
128
+ sub_name = dir_path.split('/')[-1][:-4]
129
+ obj_path = os.path.join(dir_path, sub_name + '_100k.obj')
130
+ os.makedirs(os.path.join(dir_path, 'bounce'), exist_ok=True)
131
+
132
+ PRT, F = computePRT(obj_path, n, 2)
133
+ np.savetxt(os.path.join(dir_path, 'bounce', 'bounce0.txt'), PRT, fmt='%.8f')
134
+ np.save(os.path.join(dir_path, 'bounce', 'face.npy'), F)
135
+
136
+ if __name__ == '__main__':
137
+ parser = argparse.ArgumentParser()
138
+ parser.add_argument('-i', '--input', type=str, default='/home/shunsuke/Downloads/rp_dennis_posed_004_OBJ')
139
+ parser.add_argument('-n', '--n_sample', type=int, default=40, help='squared root of number of sampling. the higher, the more accurate, but slower')
140
+ args = parser.parse_args()
141
+
142
+ testPRT(args.input)
PIFu/apps/render_data.py ADDED
@@ -0,0 +1,290 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #from data.config import raw_dataset, render_dataset, archive_dataset, model_list, zip_path
2
+
3
+ from lib.renderer.camera import Camera
4
+ import numpy as np
5
+ from lib.renderer.mesh import load_obj_mesh, compute_tangent, compute_normal, load_obj_mesh_mtl
6
+ from lib.renderer.camera import Camera
7
+ import os
8
+ import cv2
9
+ import time
10
+ import math
11
+ import random
12
+ import pyexr
13
+ import argparse
14
+ from tqdm import tqdm
15
+
16
+
17
+ def make_rotate(rx, ry, rz):
18
+ sinX = np.sin(rx)
19
+ sinY = np.sin(ry)
20
+ sinZ = np.sin(rz)
21
+
22
+ cosX = np.cos(rx)
23
+ cosY = np.cos(ry)
24
+ cosZ = np.cos(rz)
25
+
26
+ Rx = np.zeros((3,3))
27
+ Rx[0, 0] = 1.0
28
+ Rx[1, 1] = cosX
29
+ Rx[1, 2] = -sinX
30
+ Rx[2, 1] = sinX
31
+ Rx[2, 2] = cosX
32
+
33
+ Ry = np.zeros((3,3))
34
+ Ry[0, 0] = cosY
35
+ Ry[0, 2] = sinY
36
+ Ry[1, 1] = 1.0
37
+ Ry[2, 0] = -sinY
38
+ Ry[2, 2] = cosY
39
+
40
+ Rz = np.zeros((3,3))
41
+ Rz[0, 0] = cosZ
42
+ Rz[0, 1] = -sinZ
43
+ Rz[1, 0] = sinZ
44
+ Rz[1, 1] = cosZ
45
+ Rz[2, 2] = 1.0
46
+
47
+ R = np.matmul(np.matmul(Rz,Ry),Rx)
48
+ return R
49
+
50
+ def rotateSH(SH, R):
51
+ SHn = SH
52
+
53
+ # 1st order
54
+ SHn[1] = R[1,1]*SH[1] - R[1,2]*SH[2] + R[1,0]*SH[3]
55
+ SHn[2] = -R[2,1]*SH[1] + R[2,2]*SH[2] - R[2,0]*SH[3]
56
+ SHn[3] = R[0,1]*SH[1] - R[0,2]*SH[2] + R[0,0]*SH[3]
57
+
58
+ # 2nd order
59
+ SHn[4:,0] = rotateBand2(SH[4:,0],R)
60
+ SHn[4:,1] = rotateBand2(SH[4:,1],R)
61
+ SHn[4:,2] = rotateBand2(SH[4:,2],R)
62
+
63
+ return SHn
64
+
65
+ def rotateBand2(x, R):
66
+ s_c3 = 0.94617469575
67
+ s_c4 = -0.31539156525
68
+ s_c5 = 0.54627421529
69
+
70
+ s_c_scale = 1.0/0.91529123286551084
71
+ s_c_scale_inv = 0.91529123286551084
72
+
73
+ s_rc2 = 1.5853309190550713*s_c_scale
74
+ s_c4_div_c3 = s_c4/s_c3
75
+ s_c4_div_c3_x2 = (s_c4/s_c3)*2.0
76
+
77
+ s_scale_dst2 = s_c3 * s_c_scale_inv
78
+ s_scale_dst4 = s_c5 * s_c_scale_inv
79
+
80
+ sh0 = x[3] + x[4] + x[4] - x[1]
81
+ sh1 = x[0] + s_rc2*x[2] + x[3] + x[4]
82
+ sh2 = x[0]
83
+ sh3 = -x[3]
84
+ sh4 = -x[1]
85
+
86
+ r2x = R[0][0] + R[0][1]
87
+ r2y = R[1][0] + R[1][1]
88
+ r2z = R[2][0] + R[2][1]
89
+
90
+ r3x = R[0][0] + R[0][2]
91
+ r3y = R[1][0] + R[1][2]
92
+ r3z = R[2][0] + R[2][2]
93
+
94
+ r4x = R[0][1] + R[0][2]
95
+ r4y = R[1][1] + R[1][2]
96
+ r4z = R[2][1] + R[2][2]
97
+
98
+ sh0_x = sh0 * R[0][0]
99
+ sh0_y = sh0 * R[1][0]
100
+ d0 = sh0_x * R[1][0]
101
+ d1 = sh0_y * R[2][0]
102
+ d2 = sh0 * (R[2][0] * R[2][0] + s_c4_div_c3)
103
+ d3 = sh0_x * R[2][0]
104
+ d4 = sh0_x * R[0][0] - sh0_y * R[1][0]
105
+
106
+ sh1_x = sh1 * R[0][2]
107
+ sh1_y = sh1 * R[1][2]
108
+ d0 += sh1_x * R[1][2]
109
+ d1 += sh1_y * R[2][2]
110
+ d2 += sh1 * (R[2][2] * R[2][2] + s_c4_div_c3)
111
+ d3 += sh1_x * R[2][2]
112
+ d4 += sh1_x * R[0][2] - sh1_y * R[1][2]
113
+
114
+ sh2_x = sh2 * r2x
115
+ sh2_y = sh2 * r2y
116
+ d0 += sh2_x * r2y
117
+ d1 += sh2_y * r2z
118
+ d2 += sh2 * (r2z * r2z + s_c4_div_c3_x2)
119
+ d3 += sh2_x * r2z
120
+ d4 += sh2_x * r2x - sh2_y * r2y
121
+
122
+ sh3_x = sh3 * r3x
123
+ sh3_y = sh3 * r3y
124
+ d0 += sh3_x * r3y
125
+ d1 += sh3_y * r3z
126
+ d2 += sh3 * (r3z * r3z + s_c4_div_c3_x2)
127
+ d3 += sh3_x * r3z
128
+ d4 += sh3_x * r3x - sh3_y * r3y
129
+
130
+ sh4_x = sh4 * r4x
131
+ sh4_y = sh4 * r4y
132
+ d0 += sh4_x * r4y
133
+ d1 += sh4_y * r4z
134
+ d2 += sh4 * (r4z * r4z + s_c4_div_c3_x2)
135
+ d3 += sh4_x * r4z
136
+ d4 += sh4_x * r4x - sh4_y * r4y
137
+
138
+ dst = x
139
+ dst[0] = d0
140
+ dst[1] = -d1
141
+ dst[2] = d2 * s_scale_dst2
142
+ dst[3] = -d3
143
+ dst[4] = d4 * s_scale_dst4
144
+
145
+ return dst
146
+
147
+ def render_prt_ortho(out_path, folder_name, subject_name, shs, rndr, rndr_uv, im_size, angl_step=4, n_light=1, pitch=[0]):
148
+ cam = Camera(width=im_size, height=im_size)
149
+ cam.ortho_ratio = 0.4 * (512 / im_size)
150
+ cam.near = -100
151
+ cam.far = 100
152
+ cam.sanity_check()
153
+
154
+ # set path for obj, prt
155
+ mesh_file = os.path.join(folder_name, subject_name + '_100k.obj')
156
+ if not os.path.exists(mesh_file):
157
+ print('ERROR: obj file does not exist!!', mesh_file)
158
+ return
159
+ prt_file = os.path.join(folder_name, 'bounce', 'bounce0.txt')
160
+ if not os.path.exists(prt_file):
161
+ print('ERROR: prt file does not exist!!!', prt_file)
162
+ return
163
+ face_prt_file = os.path.join(folder_name, 'bounce', 'face.npy')
164
+ if not os.path.exists(face_prt_file):
165
+ print('ERROR: face prt file does not exist!!!', prt_file)
166
+ return
167
+ text_file = os.path.join(folder_name, 'tex', subject_name + '_dif_2k.jpg')
168
+ if not os.path.exists(text_file):
169
+ print('ERROR: dif file does not exist!!', text_file)
170
+ return
171
+
172
+ texture_image = cv2.imread(text_file)
173
+ texture_image = cv2.cvtColor(texture_image, cv2.COLOR_BGR2RGB)
174
+
175
+ vertices, faces, normals, faces_normals, textures, face_textures = load_obj_mesh(mesh_file, with_normal=True, with_texture=True)
176
+ vmin = vertices.min(0)
177
+ vmax = vertices.max(0)
178
+ up_axis = 1 if (vmax-vmin).argmax() == 1 else 2
179
+
180
+ vmed = np.median(vertices, 0)
181
+ vmed[up_axis] = 0.5*(vmax[up_axis]+vmin[up_axis])
182
+ y_scale = 180/(vmax[up_axis] - vmin[up_axis])
183
+
184
+ rndr.set_norm_mat(y_scale, vmed)
185
+ rndr_uv.set_norm_mat(y_scale, vmed)
186
+
187
+ tan, bitan = compute_tangent(vertices, faces, normals, textures, face_textures)
188
+ prt = np.loadtxt(prt_file)
189
+ face_prt = np.load(face_prt_file)
190
+ rndr.set_mesh(vertices, faces, normals, faces_normals, textures, face_textures, prt, face_prt, tan, bitan)
191
+ rndr.set_albedo(texture_image)
192
+
193
+ rndr_uv.set_mesh(vertices, faces, normals, faces_normals, textures, face_textures, prt, face_prt, tan, bitan)
194
+ rndr_uv.set_albedo(texture_image)
195
+
196
+ os.makedirs(os.path.join(out_path, 'GEO', 'OBJ', subject_name),exist_ok=True)
197
+ os.makedirs(os.path.join(out_path, 'PARAM', subject_name),exist_ok=True)
198
+ os.makedirs(os.path.join(out_path, 'RENDER', subject_name),exist_ok=True)
199
+ os.makedirs(os.path.join(out_path, 'MASK', subject_name),exist_ok=True)
200
+ os.makedirs(os.path.join(out_path, 'UV_RENDER', subject_name),exist_ok=True)
201
+ os.makedirs(os.path.join(out_path, 'UV_MASK', subject_name),exist_ok=True)
202
+ os.makedirs(os.path.join(out_path, 'UV_POS', subject_name),exist_ok=True)
203
+ os.makedirs(os.path.join(out_path, 'UV_NORMAL', subject_name),exist_ok=True)
204
+
205
+ if not os.path.exists(os.path.join(out_path, 'val.txt')):
206
+ f = open(os.path.join(out_path, 'val.txt'), 'w')
207
+ f.close()
208
+
209
+ # copy obj file
210
+ cmd = 'cp %s %s' % (mesh_file, os.path.join(out_path, 'GEO', 'OBJ', subject_name))
211
+ print(cmd)
212
+ os.system(cmd)
213
+
214
+ for p in pitch:
215
+ for y in tqdm(range(0, 360, angl_step)):
216
+ R = np.matmul(make_rotate(math.radians(p), 0, 0), make_rotate(0, math.radians(y), 0))
217
+ if up_axis == 2:
218
+ R = np.matmul(R, make_rotate(math.radians(90),0,0))
219
+
220
+ rndr.rot_matrix = R
221
+ rndr_uv.rot_matrix = R
222
+ rndr.set_camera(cam)
223
+ rndr_uv.set_camera(cam)
224
+
225
+ for j in range(n_light):
226
+ sh_id = random.randint(0,shs.shape[0]-1)
227
+ sh = shs[sh_id]
228
+ sh_angle = 0.2*np.pi*(random.random()-0.5)
229
+ sh = rotateSH(sh, make_rotate(0, sh_angle, 0).T)
230
+
231
+ dic = {'sh': sh, 'ortho_ratio': cam.ortho_ratio, 'scale': y_scale, 'center': vmed, 'R': R}
232
+
233
+ rndr.set_sh(sh)
234
+ rndr.analytic = False
235
+ rndr.use_inverse_depth = False
236
+ rndr.display()
237
+
238
+ out_all_f = rndr.get_color(0)
239
+ out_mask = out_all_f[:,:,3]
240
+ out_all_f = cv2.cvtColor(out_all_f, cv2.COLOR_RGBA2BGR)
241
+
242
+ np.save(os.path.join(out_path, 'PARAM', subject_name, '%d_%d_%02d.npy'%(y,p,j)),dic)
243
+ cv2.imwrite(os.path.join(out_path, 'RENDER', subject_name, '%d_%d_%02d.jpg'%(y,p,j)),255.0*out_all_f)
244
+ cv2.imwrite(os.path.join(out_path, 'MASK', subject_name, '%d_%d_%02d.png'%(y,p,j)),255.0*out_mask)
245
+
246
+ rndr_uv.set_sh(sh)
247
+ rndr_uv.analytic = False
248
+ rndr_uv.use_inverse_depth = False
249
+ rndr_uv.display()
250
+
251
+ uv_color = rndr_uv.get_color(0)
252
+ uv_color = cv2.cvtColor(uv_color, cv2.COLOR_RGBA2BGR)
253
+ cv2.imwrite(os.path.join(out_path, 'UV_RENDER', subject_name, '%d_%d_%02d.jpg'%(y,p,j)),255.0*uv_color)
254
+
255
+ if y == 0 and j == 0 and p == pitch[0]:
256
+ uv_pos = rndr_uv.get_color(1)
257
+ uv_mask = uv_pos[:,:,3]
258
+ cv2.imwrite(os.path.join(out_path, 'UV_MASK', subject_name, '00.png'),255.0*uv_mask)
259
+
260
+ data = {'default': uv_pos[:,:,:3]} # default is a reserved name
261
+ pyexr.write(os.path.join(out_path, 'UV_POS', subject_name, '00.exr'), data)
262
+
263
+ uv_nml = rndr_uv.get_color(2)
264
+ uv_nml = cv2.cvtColor(uv_nml, cv2.COLOR_RGBA2BGR)
265
+ cv2.imwrite(os.path.join(out_path, 'UV_NORMAL', subject_name, '00.png'),255.0*uv_nml)
266
+
267
+
268
+ if __name__ == '__main__':
269
+ shs = np.load('./env_sh.npy')
270
+
271
+ parser = argparse.ArgumentParser()
272
+ parser.add_argument('-i', '--input', type=str, default='/home/shunsuke/Downloads/rp_dennis_posed_004_OBJ')
273
+ parser.add_argument('-o', '--out_dir', type=str, default='/home/shunsuke/Documents/hf_human')
274
+ parser.add_argument('-m', '--ms_rate', type=int, default=1, help='higher ms rate results in less aliased output. MESA renderer only supports ms_rate=1.')
275
+ parser.add_argument('-e', '--egl', action='store_true', help='egl rendering option. use this when rendering with headless server with NVIDIA GPU')
276
+ parser.add_argument('-s', '--size', type=int, default=512, help='rendering image size')
277
+ args = parser.parse_args()
278
+
279
+ # NOTE: GL context has to be created before any other OpenGL function loads.
280
+ from lib.renderer.gl.init_gl import initialize_GL_context
281
+ initialize_GL_context(width=args.size, height=args.size, egl=args.egl)
282
+
283
+ from lib.renderer.gl.prt_render import PRTRender
284
+ rndr = PRTRender(width=args.size, height=args.size, ms_rate=args.ms_rate, egl=args.egl)
285
+ rndr_uv = PRTRender(width=args.size, height=args.size, uv_mode=True, egl=args.egl)
286
+
287
+ if args.input[-1] == '/':
288
+ args.input = args.input[:-1]
289
+ subject_name = args.input.split('/')[-1][:-4]
290
+ render_prt_ortho(args.out_dir, args.input, subject_name, shs, rndr, rndr_uv, args.size, 1, 1, pitch=[0])
PIFu/apps/train_color.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import os
3
+
4
+ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
5
+ ROOT_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
6
+
7
+ import time
8
+ import json
9
+ import numpy as np
10
+ import cv2
11
+ import random
12
+ import torch
13
+ import torch.nn as nn
14
+ from torch.utils.data import DataLoader
15
+ from tqdm import tqdm
16
+
17
+ from lib.options import BaseOptions
18
+ from lib.mesh_util import *
19
+ from lib.sample_util import *
20
+ from lib.train_util import *
21
+ from lib.data import *
22
+ from lib.model import *
23
+ from lib.geometry import index
24
+
25
+ # get options
26
+ opt = BaseOptions().parse()
27
+
28
+ def train_color(opt):
29
+ # set cuda
30
+ cuda = torch.device('cuda:%d' % opt.gpu_id)
31
+
32
+ train_dataset = TrainDataset(opt, phase='train')
33
+ test_dataset = TrainDataset(opt, phase='test')
34
+
35
+ projection_mode = train_dataset.projection_mode
36
+
37
+ # create data loader
38
+ train_data_loader = DataLoader(train_dataset,
39
+ batch_size=opt.batch_size, shuffle=not opt.serial_batches,
40
+ num_workers=opt.num_threads, pin_memory=opt.pin_memory)
41
+
42
+ print('train data size: ', len(train_data_loader))
43
+
44
+ # NOTE: batch size should be 1 and use all the points for evaluation
45
+ test_data_loader = DataLoader(test_dataset,
46
+ batch_size=1, shuffle=False,
47
+ num_workers=opt.num_threads, pin_memory=opt.pin_memory)
48
+ print('test data size: ', len(test_data_loader))
49
+
50
+ # create net
51
+ netG = HGPIFuNet(opt, projection_mode).to(device=cuda)
52
+
53
+ lr = opt.learning_rate
54
+
55
+ # Always use resnet for color regression
56
+ netC = ResBlkPIFuNet(opt).to(device=cuda)
57
+ optimizerC = torch.optim.Adam(netC.parameters(), lr=opt.learning_rate)
58
+
59
+ def set_train():
60
+ netG.eval()
61
+ netC.train()
62
+
63
+ def set_eval():
64
+ netG.eval()
65
+ netC.eval()
66
+
67
+ print('Using NetworkG: ', netG.name, 'networkC: ', netC.name)
68
+
69
+ # load checkpoints
70
+ if opt.load_netG_checkpoint_path is not None:
71
+ print('loading for net G ...', opt.load_netG_checkpoint_path)
72
+ netG.load_state_dict(torch.load(opt.load_netG_checkpoint_path, map_location=cuda))
73
+ else:
74
+ model_path_G = '%s/%s/netG_latest' % (opt.checkpoints_path, opt.name)
75
+ print('loading for net G ...', model_path_G)
76
+ netG.load_state_dict(torch.load(model_path_G, map_location=cuda))
77
+
78
+ if opt.load_netC_checkpoint_path is not None:
79
+ print('loading for net C ...', opt.load_netC_checkpoint_path)
80
+ netC.load_state_dict(torch.load(opt.load_netC_checkpoint_path, map_location=cuda))
81
+
82
+ if opt.continue_train:
83
+ if opt.resume_epoch < 0:
84
+ model_path_C = '%s/%s/netC_latest' % (opt.checkpoints_path, opt.name)
85
+ else:
86
+ model_path_C = '%s/%s/netC_epoch_%d' % (opt.checkpoints_path, opt.name, opt.resume_epoch)
87
+
88
+ print('Resuming from ', model_path_C)
89
+ netC.load_state_dict(torch.load(model_path_C, map_location=cuda))
90
+
91
+ os.makedirs(opt.checkpoints_path, exist_ok=True)
92
+ os.makedirs(opt.results_path, exist_ok=True)
93
+ os.makedirs('%s/%s' % (opt.checkpoints_path, opt.name), exist_ok=True)
94
+ os.makedirs('%s/%s' % (opt.results_path, opt.name), exist_ok=True)
95
+
96
+ opt_log = os.path.join(opt.results_path, opt.name, 'opt.txt')
97
+ with open(opt_log, 'w') as outfile:
98
+ outfile.write(json.dumps(vars(opt), indent=2))
99
+
100
+ # training
101
+ start_epoch = 0 if not opt.continue_train else max(opt.resume_epoch,0)
102
+ for epoch in range(start_epoch, opt.num_epoch):
103
+ epoch_start_time = time.time()
104
+
105
+ set_train()
106
+ iter_data_time = time.time()
107
+ for train_idx, train_data in enumerate(train_data_loader):
108
+ iter_start_time = time.time()
109
+ # retrieve the data
110
+ image_tensor = train_data['img'].to(device=cuda)
111
+ calib_tensor = train_data['calib'].to(device=cuda)
112
+ color_sample_tensor = train_data['color_samples'].to(device=cuda)
113
+
114
+ image_tensor, calib_tensor = reshape_multiview_tensors(image_tensor, calib_tensor)
115
+
116
+ if opt.num_views > 1:
117
+ color_sample_tensor = reshape_sample_tensor(color_sample_tensor, opt.num_views)
118
+
119
+ rgb_tensor = train_data['rgbs'].to(device=cuda)
120
+
121
+ with torch.no_grad():
122
+ netG.filter(image_tensor)
123
+ resC, error = netC.forward(image_tensor, netG.get_im_feat(), color_sample_tensor, calib_tensor, labels=rgb_tensor)
124
+
125
+ optimizerC.zero_grad()
126
+ error.backward()
127
+ optimizerC.step()
128
+
129
+ iter_net_time = time.time()
130
+ eta = ((iter_net_time - epoch_start_time) / (train_idx + 1)) * len(train_data_loader) - (
131
+ iter_net_time - epoch_start_time)
132
+
133
+ if train_idx % opt.freq_plot == 0:
134
+ print(
135
+ 'Name: {0} | Epoch: {1} | {2}/{3} | Err: {4:.06f} | LR: {5:.06f} | dataT: {6:.05f} | netT: {7:.05f} | ETA: {8:02d}:{9:02d}'.format(
136
+ opt.name, epoch, train_idx, len(train_data_loader),
137
+ error.item(),
138
+ lr,
139
+ iter_start_time - iter_data_time,
140
+ iter_net_time - iter_start_time, int(eta // 60),
141
+ int(eta - 60 * (eta // 60))))
142
+
143
+ if train_idx % opt.freq_save == 0 and train_idx != 0:
144
+ torch.save(netC.state_dict(), '%s/%s/netC_latest' % (opt.checkpoints_path, opt.name))
145
+ torch.save(netC.state_dict(), '%s/%s/netC_epoch_%d' % (opt.checkpoints_path, opt.name, epoch))
146
+
147
+ if train_idx % opt.freq_save_ply == 0:
148
+ save_path = '%s/%s/pred_col.ply' % (opt.results_path, opt.name)
149
+ rgb = resC[0].transpose(0, 1).cpu() * 0.5 + 0.5
150
+ points = color_sample_tensor[0].transpose(0, 1).cpu()
151
+ save_samples_rgb(save_path, points.detach().numpy(), rgb.detach().numpy())
152
+
153
+ iter_data_time = time.time()
154
+
155
+ #### test
156
+ with torch.no_grad():
157
+ set_eval()
158
+
159
+ if not opt.no_num_eval:
160
+ test_losses = {}
161
+ print('calc error (test) ...')
162
+ test_color_error = calc_error_color(opt, netG, netC, cuda, test_dataset, 100)
163
+ print('eval test | color error:', test_color_error)
164
+ test_losses['test_color'] = test_color_error
165
+
166
+ print('calc error (train) ...')
167
+ train_dataset.is_train = False
168
+ train_color_error = calc_error_color(opt, netG, netC, cuda, train_dataset, 100)
169
+ train_dataset.is_train = True
170
+ print('eval train | color error:', train_color_error)
171
+ test_losses['train_color'] = train_color_error
172
+
173
+ if not opt.no_gen_mesh:
174
+ print('generate mesh (test) ...')
175
+ for gen_idx in tqdm(range(opt.num_gen_mesh_test)):
176
+ test_data = random.choice(test_dataset)
177
+ save_path = '%s/%s/test_eval_epoch%d_%s.obj' % (
178
+ opt.results_path, opt.name, epoch, test_data['name'])
179
+ gen_mesh_color(opt, netG, netC, cuda, test_data, save_path)
180
+
181
+ print('generate mesh (train) ...')
182
+ train_dataset.is_train = False
183
+ for gen_idx in tqdm(range(opt.num_gen_mesh_test)):
184
+ train_data = random.choice(train_dataset)
185
+ save_path = '%s/%s/train_eval_epoch%d_%s.obj' % (
186
+ opt.results_path, opt.name, epoch, train_data['name'])
187
+ gen_mesh_color(opt, netG, netC, cuda, train_data, save_path)
188
+ train_dataset.is_train = True
189
+
190
+ if __name__ == '__main__':
191
+ train_color(opt)
PIFu/apps/train_shape.py ADDED
@@ -0,0 +1,183 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import os
3
+
4
+ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
5
+ ROOT_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
6
+
7
+ import time
8
+ import json
9
+ import numpy as np
10
+ import cv2
11
+ import random
12
+ import torch
13
+ from torch.utils.data import DataLoader
14
+ from tqdm import tqdm
15
+
16
+ from lib.options import BaseOptions
17
+ from lib.mesh_util import *
18
+ from lib.sample_util import *
19
+ from lib.train_util import *
20
+ from lib.data import *
21
+ from lib.model import *
22
+ from lib.geometry import index
23
+
24
+ # get options
25
+ opt = BaseOptions().parse()
26
+
27
+ def train(opt):
28
+ # set cuda
29
+ cuda = torch.device('cuda:%d' % opt.gpu_id)
30
+
31
+ train_dataset = TrainDataset(opt, phase='train')
32
+ test_dataset = TrainDataset(opt, phase='test')
33
+
34
+ projection_mode = train_dataset.projection_mode
35
+
36
+ # create data loader
37
+ train_data_loader = DataLoader(train_dataset,
38
+ batch_size=opt.batch_size, shuffle=not opt.serial_batches,
39
+ num_workers=opt.num_threads, pin_memory=opt.pin_memory)
40
+
41
+ print('train data size: ', len(train_data_loader))
42
+
43
+ # NOTE: batch size should be 1 and use all the points for evaluation
44
+ test_data_loader = DataLoader(test_dataset,
45
+ batch_size=1, shuffle=False,
46
+ num_workers=opt.num_threads, pin_memory=opt.pin_memory)
47
+ print('test data size: ', len(test_data_loader))
48
+
49
+ # create net
50
+ netG = HGPIFuNet(opt, projection_mode).to(device=cuda)
51
+ optimizerG = torch.optim.RMSprop(netG.parameters(), lr=opt.learning_rate, momentum=0, weight_decay=0)
52
+ lr = opt.learning_rate
53
+ print('Using Network: ', netG.name)
54
+
55
+ def set_train():
56
+ netG.train()
57
+
58
+ def set_eval():
59
+ netG.eval()
60
+
61
+ # load checkpoints
62
+ if opt.load_netG_checkpoint_path is not None:
63
+ print('loading for net G ...', opt.load_netG_checkpoint_path)
64
+ netG.load_state_dict(torch.load(opt.load_netG_checkpoint_path, map_location=cuda))
65
+
66
+ if opt.continue_train:
67
+ if opt.resume_epoch < 0:
68
+ model_path = '%s/%s/netG_latest' % (opt.checkpoints_path, opt.name)
69
+ else:
70
+ model_path = '%s/%s/netG_epoch_%d' % (opt.checkpoints_path, opt.name, opt.resume_epoch)
71
+ print('Resuming from ', model_path)
72
+ netG.load_state_dict(torch.load(model_path, map_location=cuda))
73
+
74
+ os.makedirs(opt.checkpoints_path, exist_ok=True)
75
+ os.makedirs(opt.results_path, exist_ok=True)
76
+ os.makedirs('%s/%s' % (opt.checkpoints_path, opt.name), exist_ok=True)
77
+ os.makedirs('%s/%s' % (opt.results_path, opt.name), exist_ok=True)
78
+
79
+ opt_log = os.path.join(opt.results_path, opt.name, 'opt.txt')
80
+ with open(opt_log, 'w') as outfile:
81
+ outfile.write(json.dumps(vars(opt), indent=2))
82
+
83
+ # training
84
+ start_epoch = 0 if not opt.continue_train else max(opt.resume_epoch,0)
85
+ for epoch in range(start_epoch, opt.num_epoch):
86
+ epoch_start_time = time.time()
87
+
88
+ set_train()
89
+ iter_data_time = time.time()
90
+ for train_idx, train_data in enumerate(train_data_loader):
91
+ iter_start_time = time.time()
92
+
93
+ # retrieve the data
94
+ image_tensor = train_data['img'].to(device=cuda)
95
+ calib_tensor = train_data['calib'].to(device=cuda)
96
+ sample_tensor = train_data['samples'].to(device=cuda)
97
+
98
+ image_tensor, calib_tensor = reshape_multiview_tensors(image_tensor, calib_tensor)
99
+
100
+ if opt.num_views > 1:
101
+ sample_tensor = reshape_sample_tensor(sample_tensor, opt.num_views)
102
+
103
+ label_tensor = train_data['labels'].to(device=cuda)
104
+
105
+ res, error = netG.forward(image_tensor, sample_tensor, calib_tensor, labels=label_tensor)
106
+
107
+ optimizerG.zero_grad()
108
+ error.backward()
109
+ optimizerG.step()
110
+
111
+ iter_net_time = time.time()
112
+ eta = ((iter_net_time - epoch_start_time) / (train_idx + 1)) * len(train_data_loader) - (
113
+ iter_net_time - epoch_start_time)
114
+
115
+ if train_idx % opt.freq_plot == 0:
116
+ print(
117
+ 'Name: {0} | Epoch: {1} | {2}/{3} | Err: {4:.06f} | LR: {5:.06f} | Sigma: {6:.02f} | dataT: {7:.05f} | netT: {8:.05f} | ETA: {9:02d}:{10:02d}'.format(
118
+ opt.name, epoch, train_idx, len(train_data_loader), error.item(), lr, opt.sigma,
119
+ iter_start_time - iter_data_time,
120
+ iter_net_time - iter_start_time, int(eta // 60),
121
+ int(eta - 60 * (eta // 60))))
122
+
123
+ if train_idx % opt.freq_save == 0 and train_idx != 0:
124
+ torch.save(netG.state_dict(), '%s/%s/netG_latest' % (opt.checkpoints_path, opt.name))
125
+ torch.save(netG.state_dict(), '%s/%s/netG_epoch_%d' % (opt.checkpoints_path, opt.name, epoch))
126
+
127
+ if train_idx % opt.freq_save_ply == 0:
128
+ save_path = '%s/%s/pred.ply' % (opt.results_path, opt.name)
129
+ r = res[0].cpu()
130
+ points = sample_tensor[0].transpose(0, 1).cpu()
131
+ save_samples_truncted_prob(save_path, points.detach().numpy(), r.detach().numpy())
132
+
133
+ iter_data_time = time.time()
134
+
135
+ # update learning rate
136
+ lr = adjust_learning_rate(optimizerG, epoch, lr, opt.schedule, opt.gamma)
137
+
138
+ #### test
139
+ with torch.no_grad():
140
+ set_eval()
141
+
142
+ if not opt.no_num_eval:
143
+ test_losses = {}
144
+ print('calc error (test) ...')
145
+ test_errors = calc_error(opt, netG, cuda, test_dataset, 100)
146
+ print('eval test MSE: {0:06f} IOU: {1:06f} prec: {2:06f} recall: {3:06f}'.format(*test_errors))
147
+ MSE, IOU, prec, recall = test_errors
148
+ test_losses['MSE(test)'] = MSE
149
+ test_losses['IOU(test)'] = IOU
150
+ test_losses['prec(test)'] = prec
151
+ test_losses['recall(test)'] = recall
152
+
153
+ print('calc error (train) ...')
154
+ train_dataset.is_train = False
155
+ train_errors = calc_error(opt, netG, cuda, train_dataset, 100)
156
+ train_dataset.is_train = True
157
+ print('eval train MSE: {0:06f} IOU: {1:06f} prec: {2:06f} recall: {3:06f}'.format(*train_errors))
158
+ MSE, IOU, prec, recall = train_errors
159
+ test_losses['MSE(train)'] = MSE
160
+ test_losses['IOU(train)'] = IOU
161
+ test_losses['prec(train)'] = prec
162
+ test_losses['recall(train)'] = recall
163
+
164
+ if not opt.no_gen_mesh:
165
+ print('generate mesh (test) ...')
166
+ for gen_idx in tqdm(range(opt.num_gen_mesh_test)):
167
+ test_data = random.choice(test_dataset)
168
+ save_path = '%s/%s/test_eval_epoch%d_%s.obj' % (
169
+ opt.results_path, opt.name, epoch, test_data['name'])
170
+ gen_mesh(opt, netG, cuda, test_data, save_path)
171
+
172
+ print('generate mesh (train) ...')
173
+ train_dataset.is_train = False
174
+ for gen_idx in tqdm(range(opt.num_gen_mesh_test)):
175
+ train_data = random.choice(train_dataset)
176
+ save_path = '%s/%s/train_eval_epoch%d_%s.obj' % (
177
+ opt.results_path, opt.name, epoch, train_data['name'])
178
+ gen_mesh(opt, netG, cuda, train_data, save_path)
179
+ train_dataset.is_train = True
180
+
181
+
182
+ if __name__ == '__main__':
183
+ train(opt)
PIFu/env_sh.npy ADDED
Binary file (52 kB). View file
 
PIFu/environment.yml ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: PIFu
2
+ channels:
3
+ - pytorch
4
+ - defaults
5
+ dependencies:
6
+ - opencv
7
+ - pytorch
8
+ - json
9
+ - pyexr
10
+ - cv2
11
+ - PIL
12
+ - skimage
13
+ - tqdm
14
+ - pyembree
15
+ - shapely
16
+ - rtree
17
+ - xxhash
18
+ - trimesh
19
+ - PyOpenGL
PIFu/lib/__init__.py ADDED
File without changes
PIFu/lib/colab_util.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import os
3
+ import torch
4
+ from skimage.io import imread
5
+ import numpy as np
6
+ import cv2
7
+ from tqdm import tqdm_notebook as tqdm
8
+ import base64
9
+ from IPython.display import HTML
10
+
11
+ # Util function for loading meshes
12
+ from pytorch3d.io import load_objs_as_meshes
13
+
14
+ from IPython.display import HTML
15
+ from base64 import b64encode
16
+
17
+ # Data structures and functions for rendering
18
+ from pytorch3d.structures import Meshes
19
+ from pytorch3d.renderer import (
20
+ look_at_view_transform,
21
+ OpenGLOrthographicCameras,
22
+ PointLights,
23
+ DirectionalLights,
24
+ Materials,
25
+ RasterizationSettings,
26
+ MeshRenderer,
27
+ MeshRasterizer,
28
+ SoftPhongShader,
29
+ HardPhongShader,
30
+ TexturesVertex
31
+ )
32
+
33
+ def set_renderer():
34
+ # Setup
35
+ device = torch.device("cuda:0")
36
+ torch.cuda.set_device(device)
37
+
38
+ # Initialize an OpenGL perspective camera.
39
+ R, T = look_at_view_transform(2.0, 0, 180)
40
+ cameras = OpenGLOrthographicCameras(device=device, R=R, T=T)
41
+
42
+ raster_settings = RasterizationSettings(
43
+ image_size=512,
44
+ blur_radius=0.0,
45
+ faces_per_pixel=1,
46
+ bin_size = None,
47
+ max_faces_per_bin = None
48
+ )
49
+
50
+ lights = PointLights(device=device, location=((2.0, 2.0, 2.0),))
51
+
52
+ renderer = MeshRenderer(
53
+ rasterizer=MeshRasterizer(
54
+ cameras=cameras,
55
+ raster_settings=raster_settings
56
+ ),
57
+ shader=HardPhongShader(
58
+ device=device,
59
+ cameras=cameras,
60
+ lights=lights
61
+ )
62
+ )
63
+ return renderer
64
+
65
+ def get_verts_rgb_colors(obj_path):
66
+ rgb_colors = []
67
+
68
+ f = open(obj_path)
69
+ lines = f.readlines()
70
+ for line in lines:
71
+ ls = line.split(' ')
72
+ if len(ls) == 7:
73
+ rgb_colors.append(ls[-3:])
74
+
75
+ return np.array(rgb_colors, dtype='float32')[None, :, :]
76
+
77
+ def generate_video_from_obj(obj_path, video_path, renderer):
78
+ # Setup
79
+ device = torch.device("cuda:0")
80
+ torch.cuda.set_device(device)
81
+
82
+ # Load obj file
83
+ verts_rgb_colors = get_verts_rgb_colors(obj_path)
84
+ verts_rgb_colors = torch.from_numpy(verts_rgb_colors).to(device)
85
+ textures = TexturesVertex(verts_features=verts_rgb_colors)
86
+ wo_textures = TexturesVertex(verts_features=torch.ones_like(verts_rgb_colors)*0.75)
87
+
88
+ # Load obj
89
+ mesh = load_objs_as_meshes([obj_path], device=device)
90
+
91
+ # Set mesh
92
+ vers = mesh._verts_list
93
+ faces = mesh._faces_list
94
+ mesh_w_tex = Meshes(vers, faces, textures)
95
+ mesh_wo_tex = Meshes(vers, faces, wo_textures)
96
+
97
+ # create VideoWriter
98
+ fourcc = cv2. VideoWriter_fourcc(*'MP4V')
99
+ out = cv2.VideoWriter(video_path, fourcc, 20.0, (1024,512))
100
+
101
+ for i in tqdm(range(90)):
102
+ R, T = look_at_view_transform(1.8, 0, i*4, device=device)
103
+ images_w_tex = renderer(mesh_w_tex, R=R, T=T)
104
+ images_w_tex = np.clip(images_w_tex[0, ..., :3].cpu().numpy(), 0.0, 1.0)[:, :, ::-1] * 255
105
+ images_wo_tex = renderer(mesh_wo_tex, R=R, T=T)
106
+ images_wo_tex = np.clip(images_wo_tex[0, ..., :3].cpu().numpy(), 0.0, 1.0)[:, :, ::-1] * 255
107
+ image = np.concatenate([images_w_tex, images_wo_tex], axis=1)
108
+ out.write(image.astype('uint8'))
109
+ out.release()
110
+
111
+ def video(path):
112
+ mp4 = open(path,'rb').read()
113
+ data_url = "data:video/mp4;base64," + b64encode(mp4).decode()
114
+ return HTML('<video width=500 controls loop> <source src="%s" type="video/mp4"></video>' % data_url)
PIFu/lib/data/BaseDataset.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch.utils.data import Dataset
2
+ import random
3
+
4
+
5
+ class BaseDataset(Dataset):
6
+ '''
7
+ This is the Base Datasets.
8
+ Itself does nothing and is not runnable.
9
+ Check self.get_item function to see what it should return.
10
+ '''
11
+
12
+ @staticmethod
13
+ def modify_commandline_options(parser, is_train):
14
+ return parser
15
+
16
+ def __init__(self, opt, phase='train'):
17
+ self.opt = opt
18
+ self.is_train = self.phase == 'train'
19
+ self.projection_mode = 'orthogonal' # Declare projection mode here
20
+
21
+ def __len__(self):
22
+ return 0
23
+
24
+ def get_item(self, index):
25
+ # In case of a missing file or IO error, switch to a random sample instead
26
+ try:
27
+ res = {
28
+ 'name': None, # name of this subject
29
+ 'b_min': None, # Bounding box (x_min, y_min, z_min) of target space
30
+ 'b_max': None, # Bounding box (x_max, y_max, z_max) of target space
31
+
32
+ 'samples': None, # [3, N] samples
33
+ 'labels': None, # [1, N] labels
34
+
35
+ 'img': None, # [num_views, C, H, W] input images
36
+ 'calib': None, # [num_views, 4, 4] calibration matrix
37
+ 'extrinsic': None, # [num_views, 4, 4] extrinsic matrix
38
+ 'mask': None, # [num_views, 1, H, W] segmentation masks
39
+ }
40
+ return res
41
+ except:
42
+ print("Requested index %s has missing files. Using a random sample instead." % index)
43
+ return self.get_item(index=random.randint(0, self.__len__() - 1))
44
+
45
+ def __getitem__(self, index):
46
+ return self.get_item(index)
PIFu/lib/data/EvalDataset.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch.utils.data import Dataset
2
+ import numpy as np
3
+ import os
4
+ import random
5
+ import torchvision.transforms as transforms
6
+ from PIL import Image, ImageOps
7
+ import cv2
8
+ import torch
9
+ from PIL.ImageFilter import GaussianBlur
10
+ import trimesh
11
+ import cv2
12
+
13
+
14
+ class EvalDataset(Dataset):
15
+ @staticmethod
16
+ def modify_commandline_options(parser):
17
+ return parser
18
+
19
+ def __init__(self, opt, root=None):
20
+ self.opt = opt
21
+ self.projection_mode = 'orthogonal'
22
+
23
+ # Path setup
24
+ self.root = self.opt.dataroot
25
+ if root is not None:
26
+ self.root = root
27
+ self.RENDER = os.path.join(self.root, 'RENDER')
28
+ self.MASK = os.path.join(self.root, 'MASK')
29
+ self.PARAM = os.path.join(self.root, 'PARAM')
30
+ self.OBJ = os.path.join(self.root, 'GEO', 'OBJ')
31
+
32
+ self.phase = 'val'
33
+ self.load_size = self.opt.loadSize
34
+
35
+ self.num_views = self.opt.num_views
36
+
37
+ self.max_view_angle = 360
38
+ self.interval = 1
39
+ self.subjects = self.get_subjects()
40
+
41
+ # PIL to tensor
42
+ self.to_tensor = transforms.Compose([
43
+ transforms.Resize(self.load_size),
44
+ transforms.ToTensor(),
45
+ transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
46
+ ])
47
+
48
+ def get_subjects(self):
49
+ var_file = os.path.join(self.root, 'val.txt')
50
+ if os.path.exists(var_file):
51
+ var_subjects = np.loadtxt(var_file, dtype=str)
52
+ return sorted(list(var_subjects))
53
+ all_subjects = os.listdir(self.RENDER)
54
+ return sorted(list(all_subjects))
55
+
56
+ def __len__(self):
57
+ return len(self.subjects) * self.max_view_angle // self.interval
58
+
59
+ def get_render(self, subject, num_views, view_id=None, random_sample=False):
60
+ '''
61
+ Return the render data
62
+ :param subject: subject name
63
+ :param num_views: how many views to return
64
+ :param view_id: the first view_id. If None, select a random one.
65
+ :return:
66
+ 'img': [num_views, C, W, H] images
67
+ 'calib': [num_views, 4, 4] calibration matrix
68
+ 'extrinsic': [num_views, 4, 4] extrinsic matrix
69
+ 'mask': [num_views, 1, W, H] masks
70
+ '''
71
+ # For now we only have pitch = 00. Hard code it here
72
+ pitch = 0
73
+ # Select a random view_id from self.max_view_angle if not given
74
+ if view_id is None:
75
+ view_id = np.random.randint(self.max_view_angle)
76
+ # The ids are an even distribution of num_views around view_id
77
+ view_ids = [(view_id + self.max_view_angle // num_views * offset) % self.max_view_angle
78
+ for offset in range(num_views)]
79
+ if random_sample:
80
+ view_ids = np.random.choice(self.max_view_angle, num_views, replace=False)
81
+
82
+ calib_list = []
83
+ render_list = []
84
+ mask_list = []
85
+ extrinsic_list = []
86
+
87
+ for vid in view_ids:
88
+ param_path = os.path.join(self.PARAM, subject, '%d_%02d.npy' % (vid, pitch))
89
+ render_path = os.path.join(self.RENDER, subject, '%d_%02d.jpg' % (vid, pitch))
90
+ mask_path = os.path.join(self.MASK, subject, '%d_%02d.png' % (vid, pitch))
91
+
92
+ # loading calibration data
93
+ param = np.load(param_path)
94
+ # pixel unit / world unit
95
+ ortho_ratio = param.item().get('ortho_ratio')
96
+ # world unit / model unit
97
+ scale = param.item().get('scale')
98
+ # camera center world coordinate
99
+ center = param.item().get('center')
100
+ # model rotation
101
+ R = param.item().get('R')
102
+
103
+ translate = -np.matmul(R, center).reshape(3, 1)
104
+ extrinsic = np.concatenate([R, translate], axis=1)
105
+ extrinsic = np.concatenate([extrinsic, np.array([0, 0, 0, 1]).reshape(1, 4)], 0)
106
+ # Match camera space to image pixel space
107
+ scale_intrinsic = np.identity(4)
108
+ scale_intrinsic[0, 0] = scale / ortho_ratio
109
+ scale_intrinsic[1, 1] = -scale / ortho_ratio
110
+ scale_intrinsic[2, 2] = -scale / ortho_ratio
111
+ # Match image pixel space to image uv space
112
+ uv_intrinsic = np.identity(4)
113
+ uv_intrinsic[0, 0] = 1.0 / float(self.opt.loadSize // 2)
114
+ uv_intrinsic[1, 1] = 1.0 / float(self.opt.loadSize // 2)
115
+ uv_intrinsic[2, 2] = 1.0 / float(self.opt.loadSize // 2)
116
+ # Transform under image pixel space
117
+ trans_intrinsic = np.identity(4)
118
+
119
+ mask = Image.open(mask_path).convert('L')
120
+ render = Image.open(render_path).convert('RGB')
121
+
122
+ intrinsic = np.matmul(trans_intrinsic, np.matmul(uv_intrinsic, scale_intrinsic))
123
+ calib = torch.Tensor(np.matmul(intrinsic, extrinsic)).float()
124
+ extrinsic = torch.Tensor(extrinsic).float()
125
+
126
+ mask = transforms.Resize(self.load_size)(mask)
127
+ mask = transforms.ToTensor()(mask).float()
128
+ mask_list.append(mask)
129
+
130
+ render = self.to_tensor(render)
131
+ render = mask.expand_as(render) * render
132
+
133
+ render_list.append(render)
134
+ calib_list.append(calib)
135
+ extrinsic_list.append(extrinsic)
136
+
137
+ return {
138
+ 'img': torch.stack(render_list, dim=0),
139
+ 'calib': torch.stack(calib_list, dim=0),
140
+ 'extrinsic': torch.stack(extrinsic_list, dim=0),
141
+ 'mask': torch.stack(mask_list, dim=0)
142
+ }
143
+
144
+ def get_item(self, index):
145
+ # In case of a missing file or IO error, switch to a random sample instead
146
+ try:
147
+ sid = index % len(self.subjects)
148
+ vid = (index // len(self.subjects)) * self.interval
149
+ # name of the subject 'rp_xxxx_xxx'
150
+ subject = self.subjects[sid]
151
+ res = {
152
+ 'name': subject,
153
+ 'mesh_path': os.path.join(self.OBJ, subject + '.obj'),
154
+ 'sid': sid,
155
+ 'vid': vid,
156
+ }
157
+ render_data = self.get_render(subject, num_views=self.num_views, view_id=vid,
158
+ random_sample=self.opt.random_multiview)
159
+ res.update(render_data)
160
+ return res
161
+ except Exception as e:
162
+ print(e)
163
+ return self.get_item(index=random.randint(0, self.__len__() - 1))
164
+
165
+ def __getitem__(self, index):
166
+ return self.get_item(index)
PIFu/lib/data/TrainDataset.py ADDED
@@ -0,0 +1,390 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch.utils.data import Dataset
2
+ import numpy as np
3
+ import os
4
+ import random
5
+ import torchvision.transforms as transforms
6
+ from PIL import Image, ImageOps
7
+ import cv2
8
+ import torch
9
+ from PIL.ImageFilter import GaussianBlur
10
+ import trimesh
11
+ import logging
12
+
13
+ log = logging.getLogger('trimesh')
14
+ log.setLevel(40)
15
+
16
+ def load_trimesh(root_dir):
17
+ folders = os.listdir(root_dir)
18
+ meshs = {}
19
+ for i, f in enumerate(folders):
20
+ sub_name = f
21
+ meshs[sub_name] = trimesh.load(os.path.join(root_dir, f, '%s_100k.obj' % sub_name))
22
+
23
+ return meshs
24
+
25
+ def save_samples_truncted_prob(fname, points, prob):
26
+ '''
27
+ Save the visualization of sampling to a ply file.
28
+ Red points represent positive predictions.
29
+ Green points represent negative predictions.
30
+ :param fname: File name to save
31
+ :param points: [N, 3] array of points
32
+ :param prob: [N, 1] array of predictions in the range [0~1]
33
+ :return:
34
+ '''
35
+ r = (prob > 0.5).reshape([-1, 1]) * 255
36
+ g = (prob < 0.5).reshape([-1, 1]) * 255
37
+ b = np.zeros(r.shape)
38
+
39
+ to_save = np.concatenate([points, r, g, b], axis=-1)
40
+ return np.savetxt(fname,
41
+ to_save,
42
+ fmt='%.6f %.6f %.6f %d %d %d',
43
+ comments='',
44
+ header=(
45
+ 'ply\nformat ascii 1.0\nelement vertex {:d}\nproperty float x\nproperty float y\nproperty float z\nproperty uchar red\nproperty uchar green\nproperty uchar blue\nend_header').format(
46
+ points.shape[0])
47
+ )
48
+
49
+
50
+ class TrainDataset(Dataset):
51
+ @staticmethod
52
+ def modify_commandline_options(parser, is_train):
53
+ return parser
54
+
55
+ def __init__(self, opt, phase='train'):
56
+ self.opt = opt
57
+ self.projection_mode = 'orthogonal'
58
+
59
+ # Path setup
60
+ self.root = self.opt.dataroot
61
+ self.RENDER = os.path.join(self.root, 'RENDER')
62
+ self.MASK = os.path.join(self.root, 'MASK')
63
+ self.PARAM = os.path.join(self.root, 'PARAM')
64
+ self.UV_MASK = os.path.join(self.root, 'UV_MASK')
65
+ self.UV_NORMAL = os.path.join(self.root, 'UV_NORMAL')
66
+ self.UV_RENDER = os.path.join(self.root, 'UV_RENDER')
67
+ self.UV_POS = os.path.join(self.root, 'UV_POS')
68
+ self.OBJ = os.path.join(self.root, 'GEO', 'OBJ')
69
+
70
+ self.B_MIN = np.array([-128, -28, -128])
71
+ self.B_MAX = np.array([128, 228, 128])
72
+
73
+ self.is_train = (phase == 'train')
74
+ self.load_size = self.opt.loadSize
75
+
76
+ self.num_views = self.opt.num_views
77
+
78
+ self.num_sample_inout = self.opt.num_sample_inout
79
+ self.num_sample_color = self.opt.num_sample_color
80
+
81
+ self.yaw_list = list(range(0,360,1))
82
+ self.pitch_list = [0]
83
+ self.subjects = self.get_subjects()
84
+
85
+ # PIL to tensor
86
+ self.to_tensor = transforms.Compose([
87
+ transforms.Resize(self.load_size),
88
+ transforms.ToTensor(),
89
+ transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
90
+ ])
91
+
92
+ # augmentation
93
+ self.aug_trans = transforms.Compose([
94
+ transforms.ColorJitter(brightness=opt.aug_bri, contrast=opt.aug_con, saturation=opt.aug_sat,
95
+ hue=opt.aug_hue)
96
+ ])
97
+
98
+ self.mesh_dic = load_trimesh(self.OBJ)
99
+
100
+ def get_subjects(self):
101
+ all_subjects = os.listdir(self.RENDER)
102
+ var_subjects = np.loadtxt(os.path.join(self.root, 'val.txt'), dtype=str)
103
+ if len(var_subjects) == 0:
104
+ return all_subjects
105
+
106
+ if self.is_train:
107
+ return sorted(list(set(all_subjects) - set(var_subjects)))
108
+ else:
109
+ return sorted(list(var_subjects))
110
+
111
+ def __len__(self):
112
+ return len(self.subjects) * len(self.yaw_list) * len(self.pitch_list)
113
+
114
+ def get_render(self, subject, num_views, yid=0, pid=0, random_sample=False):
115
+ '''
116
+ Return the render data
117
+ :param subject: subject name
118
+ :param num_views: how many views to return
119
+ :param view_id: the first view_id. If None, select a random one.
120
+ :return:
121
+ 'img': [num_views, C, W, H] images
122
+ 'calib': [num_views, 4, 4] calibration matrix
123
+ 'extrinsic': [num_views, 4, 4] extrinsic matrix
124
+ 'mask': [num_views, 1, W, H] masks
125
+ '''
126
+ pitch = self.pitch_list[pid]
127
+
128
+ # The ids are an even distribution of num_views around view_id
129
+ view_ids = [self.yaw_list[(yid + len(self.yaw_list) // num_views * offset) % len(self.yaw_list)]
130
+ for offset in range(num_views)]
131
+ if random_sample:
132
+ view_ids = np.random.choice(self.yaw_list, num_views, replace=False)
133
+
134
+ calib_list = []
135
+ render_list = []
136
+ mask_list = []
137
+ extrinsic_list = []
138
+
139
+ for vid in view_ids:
140
+ param_path = os.path.join(self.PARAM, subject, '%d_%d_%02d.npy' % (vid, pitch, 0))
141
+ render_path = os.path.join(self.RENDER, subject, '%d_%d_%02d.jpg' % (vid, pitch, 0))
142
+ mask_path = os.path.join(self.MASK, subject, '%d_%d_%02d.png' % (vid, pitch, 0))
143
+
144
+ # loading calibration data
145
+ param = np.load(param_path, allow_pickle=True)
146
+ # pixel unit / world unit
147
+ ortho_ratio = param.item().get('ortho_ratio')
148
+ # world unit / model unit
149
+ scale = param.item().get('scale')
150
+ # camera center world coordinate
151
+ center = param.item().get('center')
152
+ # model rotation
153
+ R = param.item().get('R')
154
+
155
+ translate = -np.matmul(R, center).reshape(3, 1)
156
+ extrinsic = np.concatenate([R, translate], axis=1)
157
+ extrinsic = np.concatenate([extrinsic, np.array([0, 0, 0, 1]).reshape(1, 4)], 0)
158
+ # Match camera space to image pixel space
159
+ scale_intrinsic = np.identity(4)
160
+ scale_intrinsic[0, 0] = scale / ortho_ratio
161
+ scale_intrinsic[1, 1] = -scale / ortho_ratio
162
+ scale_intrinsic[2, 2] = scale / ortho_ratio
163
+ # Match image pixel space to image uv space
164
+ uv_intrinsic = np.identity(4)
165
+ uv_intrinsic[0, 0] = 1.0 / float(self.opt.loadSize // 2)
166
+ uv_intrinsic[1, 1] = 1.0 / float(self.opt.loadSize // 2)
167
+ uv_intrinsic[2, 2] = 1.0 / float(self.opt.loadSize // 2)
168
+ # Transform under image pixel space
169
+ trans_intrinsic = np.identity(4)
170
+
171
+ mask = Image.open(mask_path).convert('L')
172
+ render = Image.open(render_path).convert('RGB')
173
+
174
+ if self.is_train:
175
+ # Pad images
176
+ pad_size = int(0.1 * self.load_size)
177
+ render = ImageOps.expand(render, pad_size, fill=0)
178
+ mask = ImageOps.expand(mask, pad_size, fill=0)
179
+
180
+ w, h = render.size
181
+ th, tw = self.load_size, self.load_size
182
+
183
+ # random flip
184
+ if self.opt.random_flip and np.random.rand() > 0.5:
185
+ scale_intrinsic[0, 0] *= -1
186
+ render = transforms.RandomHorizontalFlip(p=1.0)(render)
187
+ mask = transforms.RandomHorizontalFlip(p=1.0)(mask)
188
+
189
+ # random scale
190
+ if self.opt.random_scale:
191
+ rand_scale = random.uniform(0.9, 1.1)
192
+ w = int(rand_scale * w)
193
+ h = int(rand_scale * h)
194
+ render = render.resize((w, h), Image.BILINEAR)
195
+ mask = mask.resize((w, h), Image.NEAREST)
196
+ scale_intrinsic *= rand_scale
197
+ scale_intrinsic[3, 3] = 1
198
+
199
+ # random translate in the pixel space
200
+ if self.opt.random_trans:
201
+ dx = random.randint(-int(round((w - tw) / 10.)),
202
+ int(round((w - tw) / 10.)))
203
+ dy = random.randint(-int(round((h - th) / 10.)),
204
+ int(round((h - th) / 10.)))
205
+ else:
206
+ dx = 0
207
+ dy = 0
208
+
209
+ trans_intrinsic[0, 3] = -dx / float(self.opt.loadSize // 2)
210
+ trans_intrinsic[1, 3] = -dy / float(self.opt.loadSize // 2)
211
+
212
+ x1 = int(round((w - tw) / 2.)) + dx
213
+ y1 = int(round((h - th) / 2.)) + dy
214
+
215
+ render = render.crop((x1, y1, x1 + tw, y1 + th))
216
+ mask = mask.crop((x1, y1, x1 + tw, y1 + th))
217
+
218
+ render = self.aug_trans(render)
219
+
220
+ # random blur
221
+ if self.opt.aug_blur > 0.00001:
222
+ blur = GaussianBlur(np.random.uniform(0, self.opt.aug_blur))
223
+ render = render.filter(blur)
224
+
225
+ intrinsic = np.matmul(trans_intrinsic, np.matmul(uv_intrinsic, scale_intrinsic))
226
+ calib = torch.Tensor(np.matmul(intrinsic, extrinsic)).float()
227
+ extrinsic = torch.Tensor(extrinsic).float()
228
+
229
+ mask = transforms.Resize(self.load_size)(mask)
230
+ mask = transforms.ToTensor()(mask).float()
231
+ mask_list.append(mask)
232
+
233
+ render = self.to_tensor(render)
234
+ render = mask.expand_as(render) * render
235
+
236
+ render_list.append(render)
237
+ calib_list.append(calib)
238
+ extrinsic_list.append(extrinsic)
239
+
240
+ return {
241
+ 'img': torch.stack(render_list, dim=0),
242
+ 'calib': torch.stack(calib_list, dim=0),
243
+ 'extrinsic': torch.stack(extrinsic_list, dim=0),
244
+ 'mask': torch.stack(mask_list, dim=0)
245
+ }
246
+
247
+ def select_sampling_method(self, subject):
248
+ if not self.is_train:
249
+ random.seed(1991)
250
+ np.random.seed(1991)
251
+ torch.manual_seed(1991)
252
+ mesh = self.mesh_dic[subject]
253
+ surface_points, _ = trimesh.sample.sample_surface(mesh, 4 * self.num_sample_inout)
254
+ sample_points = surface_points + np.random.normal(scale=self.opt.sigma, size=surface_points.shape)
255
+
256
+ # add random points within image space
257
+ length = self.B_MAX - self.B_MIN
258
+ random_points = np.random.rand(self.num_sample_inout // 4, 3) * length + self.B_MIN
259
+ sample_points = np.concatenate([sample_points, random_points], 0)
260
+ np.random.shuffle(sample_points)
261
+
262
+ inside = mesh.contains(sample_points)
263
+ inside_points = sample_points[inside]
264
+ outside_points = sample_points[np.logical_not(inside)]
265
+
266
+ nin = inside_points.shape[0]
267
+ inside_points = inside_points[
268
+ :self.num_sample_inout // 2] if nin > self.num_sample_inout // 2 else inside_points
269
+ outside_points = outside_points[
270
+ :self.num_sample_inout // 2] if nin > self.num_sample_inout // 2 else outside_points[
271
+ :(self.num_sample_inout - nin)]
272
+
273
+ samples = np.concatenate([inside_points, outside_points], 0).T
274
+ labels = np.concatenate([np.ones((1, inside_points.shape[0])), np.zeros((1, outside_points.shape[0]))], 1)
275
+
276
+ # save_samples_truncted_prob('out.ply', samples.T, labels.T)
277
+ # exit()
278
+
279
+ samples = torch.Tensor(samples).float()
280
+ labels = torch.Tensor(labels).float()
281
+
282
+ del mesh
283
+
284
+ return {
285
+ 'samples': samples,
286
+ 'labels': labels
287
+ }
288
+
289
+
290
+ def get_color_sampling(self, subject, yid, pid=0):
291
+ yaw = self.yaw_list[yid]
292
+ pitch = self.pitch_list[pid]
293
+ uv_render_path = os.path.join(self.UV_RENDER, subject, '%d_%d_%02d.jpg' % (yaw, pitch, 0))
294
+ uv_mask_path = os.path.join(self.UV_MASK, subject, '%02d.png' % (0))
295
+ uv_pos_path = os.path.join(self.UV_POS, subject, '%02d.exr' % (0))
296
+ uv_normal_path = os.path.join(self.UV_NORMAL, subject, '%02d.png' % (0))
297
+
298
+ # Segmentation mask for the uv render.
299
+ # [H, W] bool
300
+ uv_mask = cv2.imread(uv_mask_path)
301
+ uv_mask = uv_mask[:, :, 0] != 0
302
+ # UV render. each pixel is the color of the point.
303
+ # [H, W, 3] 0 ~ 1 float
304
+ uv_render = cv2.imread(uv_render_path)
305
+ uv_render = cv2.cvtColor(uv_render, cv2.COLOR_BGR2RGB) / 255.0
306
+
307
+ # Normal render. each pixel is the surface normal of the point.
308
+ # [H, W, 3] -1 ~ 1 float
309
+ uv_normal = cv2.imread(uv_normal_path)
310
+ uv_normal = cv2.cvtColor(uv_normal, cv2.COLOR_BGR2RGB) / 255.0
311
+ uv_normal = 2.0 * uv_normal - 1.0
312
+ # Position render. each pixel is the xyz coordinates of the point
313
+ uv_pos = cv2.imread(uv_pos_path, 2 | 4)[:, :, ::-1]
314
+
315
+ ### In these few lines we flattern the masks, positions, and normals
316
+ uv_mask = uv_mask.reshape((-1))
317
+ uv_pos = uv_pos.reshape((-1, 3))
318
+ uv_render = uv_render.reshape((-1, 3))
319
+ uv_normal = uv_normal.reshape((-1, 3))
320
+
321
+ surface_points = uv_pos[uv_mask]
322
+ surface_colors = uv_render[uv_mask]
323
+ surface_normal = uv_normal[uv_mask]
324
+
325
+ if self.num_sample_color:
326
+ sample_list = random.sample(range(0, surface_points.shape[0] - 1), self.num_sample_color)
327
+ surface_points = surface_points[sample_list].T
328
+ surface_colors = surface_colors[sample_list].T
329
+ surface_normal = surface_normal[sample_list].T
330
+
331
+ # Samples are around the true surface with an offset
332
+ normal = torch.Tensor(surface_normal).float()
333
+ samples = torch.Tensor(surface_points).float() \
334
+ + torch.normal(mean=torch.zeros((1, normal.size(1))), std=self.opt.sigma).expand_as(normal) * normal
335
+
336
+ # Normalized to [-1, 1]
337
+ rgbs_color = 2.0 * torch.Tensor(surface_colors).float() - 1.0
338
+
339
+ return {
340
+ 'color_samples': samples,
341
+ 'rgbs': rgbs_color
342
+ }
343
+
344
+ def get_item(self, index):
345
+ # In case of a missing file or IO error, switch to a random sample instead
346
+ # try:
347
+ sid = index % len(self.subjects)
348
+ tmp = index // len(self.subjects)
349
+ yid = tmp % len(self.yaw_list)
350
+ pid = tmp // len(self.yaw_list)
351
+
352
+ # name of the subject 'rp_xxxx_xxx'
353
+ subject = self.subjects[sid]
354
+ res = {
355
+ 'name': subject,
356
+ 'mesh_path': os.path.join(self.OBJ, subject + '.obj'),
357
+ 'sid': sid,
358
+ 'yid': yid,
359
+ 'pid': pid,
360
+ 'b_min': self.B_MIN,
361
+ 'b_max': self.B_MAX,
362
+ }
363
+ render_data = self.get_render(subject, num_views=self.num_views, yid=yid, pid=pid,
364
+ random_sample=self.opt.random_multiview)
365
+ res.update(render_data)
366
+
367
+ if self.opt.num_sample_inout:
368
+ sample_data = self.select_sampling_method(subject)
369
+ res.update(sample_data)
370
+
371
+ # img = np.uint8((np.transpose(render_data['img'][0].numpy(), (1, 2, 0)) * 0.5 + 0.5)[:, :, ::-1] * 255.0)
372
+ # rot = render_data['calib'][0,:3, :3]
373
+ # trans = render_data['calib'][0,:3, 3:4]
374
+ # pts = torch.addmm(trans, rot, sample_data['samples'][:, sample_data['labels'][0] > 0.5]) # [3, N]
375
+ # pts = 0.5 * (pts.numpy().T + 1.0) * render_data['img'].size(2)
376
+ # for p in pts:
377
+ # img = cv2.circle(img, (p[0], p[1]), 2, (0,255,0), -1)
378
+ # cv2.imshow('test', img)
379
+ # cv2.waitKey(1)
380
+
381
+ if self.num_sample_color:
382
+ color_data = self.get_color_sampling(subject, yid=yid, pid=pid)
383
+ res.update(color_data)
384
+ return res
385
+ # except Exception as e:
386
+ # print(e)
387
+ # return self.get_item(index=random.randint(0, self.__len__() - 1))
388
+
389
+ def __getitem__(self, index):
390
+ return self.get_item(index)
PIFu/lib/data/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .EvalDataset import EvalDataset
2
+ from .TrainDataset import TrainDataset
PIFu/lib/ext_transform.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+
3
+ import numpy as np
4
+ from skimage.filters import gaussian
5
+ import torch
6
+ from PIL import Image, ImageFilter
7
+
8
+
9
+ class RandomVerticalFlip(object):
10
+ def __call__(self, img):
11
+ if random.random() < 0.5:
12
+ return img.transpose(Image.FLIP_TOP_BOTTOM)
13
+ return img
14
+
15
+
16
+ class DeNormalize(object):
17
+ def __init__(self, mean, std):
18
+ self.mean = mean
19
+ self.std = std
20
+
21
+ def __call__(self, tensor):
22
+ for t, m, s in zip(tensor, self.mean, self.std):
23
+ t.mul_(s).add_(m)
24
+ return tensor
25
+
26
+
27
+ class MaskToTensor(object):
28
+ def __call__(self, img):
29
+ return torch.from_numpy(np.array(img, dtype=np.int32)).long()
30
+
31
+
32
+ class FreeScale(object):
33
+ def __init__(self, size, interpolation=Image.BILINEAR):
34
+ self.size = tuple(reversed(size)) # size: (h, w)
35
+ self.interpolation = interpolation
36
+
37
+ def __call__(self, img):
38
+ return img.resize(self.size, self.interpolation)
39
+
40
+
41
+ class FlipChannels(object):
42
+ def __call__(self, img):
43
+ img = np.array(img)[:, :, ::-1]
44
+ return Image.fromarray(img.astype(np.uint8))
45
+
46
+
47
+ class RandomGaussianBlur(object):
48
+ def __call__(self, img):
49
+ sigma = 0.15 + random.random() * 1.15
50
+ blurred_img = gaussian(np.array(img), sigma=sigma, multichannel=True)
51
+ blurred_img *= 255
52
+ return Image.fromarray(blurred_img.astype(np.uint8))
53
+
54
+ # Lighting data augmentation take from here - https://github.com/eladhoffer/convNet.pytorch/blob/master/preprocess.py
55
+
56
+
57
+ class Lighting(object):
58
+ """Lighting noise(AlexNet - style PCA - based noise)"""
59
+
60
+ def __init__(self, alphastd,
61
+ eigval=(0.2175, 0.0188, 0.0045),
62
+ eigvec=((-0.5675, 0.7192, 0.4009),
63
+ (-0.5808, -0.0045, -0.8140),
64
+ (-0.5836, -0.6948, 0.4203))):
65
+ self.alphastd = alphastd
66
+ self.eigval = torch.Tensor(eigval)
67
+ self.eigvec = torch.Tensor(eigvec)
68
+
69
+ def __call__(self, img):
70
+ if self.alphastd == 0:
71
+ return img
72
+
73
+ alpha = img.new().resize_(3).normal_(0, self.alphastd)
74
+ rgb = self.eigvec.type_as(img).clone()\
75
+ .mul(alpha.view(1, 3).expand(3, 3))\
76
+ .mul(self.eigval.view(1, 3).expand(3, 3))\
77
+ .sum(1).squeeze()
78
+ return img.add(rgb.view(3, 1, 1).expand_as(img))
PIFu/lib/geometry.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+
4
+ def index(feat, uv):
5
+ '''
6
+
7
+ :param feat: [B, C, H, W] image features
8
+ :param uv: [B, 2, N] uv coordinates in the image plane, range [-1, 1]
9
+ :return: [B, C, N] image features at the uv coordinates
10
+ '''
11
+ uv = uv.transpose(1, 2) # [B, N, 2]
12
+ uv = uv.unsqueeze(2) # [B, N, 1, 2]
13
+ # NOTE: for newer PyTorch, it seems that training results are degraded due to implementation diff in F.grid_sample
14
+ # for old versions, simply remove the aligned_corners argument.
15
+ samples = torch.nn.functional.grid_sample(feat, uv, align_corners=True) # [B, C, N, 1]
16
+ return samples[:, :, :, 0] # [B, C, N]
17
+
18
+
19
+ def orthogonal(points, calibrations, transforms=None):
20
+ '''
21
+ Compute the orthogonal projections of 3D points into the image plane by given projection matrix
22
+ :param points: [B, 3, N] Tensor of 3D points
23
+ :param calibrations: [B, 4, 4] Tensor of projection matrix
24
+ :param transforms: [B, 2, 3] Tensor of image transform matrix
25
+ :return: xyz: [B, 3, N] Tensor of xyz coordinates in the image plane
26
+ '''
27
+ rot = calibrations[:, :3, :3]
28
+ trans = calibrations[:, :3, 3:4]
29
+ pts = torch.baddbmm(trans, rot, points) # [B, 3, N]
30
+ if transforms is not None:
31
+ scale = transforms[:2, :2]
32
+ shift = transforms[:2, 2:3]
33
+ pts[:, :2, :] = torch.baddbmm(shift, scale, pts[:, :2, :])
34
+ return pts
35
+
36
+
37
+ def perspective(points, calibrations, transforms=None):
38
+ '''
39
+ Compute the perspective projections of 3D points into the image plane by given projection matrix
40
+ :param points: [Bx3xN] Tensor of 3D points
41
+ :param calibrations: [Bx4x4] Tensor of projection matrix
42
+ :param transforms: [Bx2x3] Tensor of image transform matrix
43
+ :return: xy: [Bx2xN] Tensor of xy coordinates in the image plane
44
+ '''
45
+ rot = calibrations[:, :3, :3]
46
+ trans = calibrations[:, :3, 3:4]
47
+ homo = torch.baddbmm(trans, rot, points) # [B, 3, N]
48
+ xy = homo[:, :2, :] / homo[:, 2:3, :]
49
+ if transforms is not None:
50
+ scale = transforms[:2, :2]
51
+ shift = transforms[:2, 2:3]
52
+ xy = torch.baddbmm(shift, scale, xy)
53
+
54
+ xyz = torch.cat([xy, homo[:, 2:3, :]], 1)
55
+ return xyz
PIFu/lib/mesh_util.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from skimage import measure
2
+ import numpy as np
3
+ import torch
4
+ from .sdf import create_grid, eval_grid_octree, eval_grid
5
+ from skimage import measure
6
+
7
+
8
+ def reconstruction(net, cuda, calib_tensor,
9
+ resolution, b_min, b_max,
10
+ use_octree=False, num_samples=10000, transform=None):
11
+ '''
12
+ Reconstruct meshes from sdf predicted by the network.
13
+ :param net: a BasePixImpNet object. call image filter beforehead.
14
+ :param cuda: cuda device
15
+ :param calib_tensor: calibration tensor
16
+ :param resolution: resolution of the grid cell
17
+ :param b_min: bounding box corner [x_min, y_min, z_min]
18
+ :param b_max: bounding box corner [x_max, y_max, z_max]
19
+ :param use_octree: whether to use octree acceleration
20
+ :param num_samples: how many points to query each gpu iteration
21
+ :return: marching cubes results.
22
+ '''
23
+ # First we create a grid by resolution
24
+ # and transforming matrix for grid coordinates to real world xyz
25
+ coords, mat = create_grid(resolution, resolution, resolution,
26
+ b_min, b_max, transform=transform)
27
+
28
+ # Then we define the lambda function for cell evaluation
29
+ def eval_func(points):
30
+ points = np.expand_dims(points, axis=0)
31
+ points = np.repeat(points, net.num_views, axis=0)
32
+ samples = torch.from_numpy(points).to(device=cuda).float()
33
+ net.query(samples, calib_tensor)
34
+ pred = net.get_preds()[0][0]
35
+ return pred.detach().cpu().numpy()
36
+
37
+ # Then we evaluate the grid
38
+ if use_octree:
39
+ sdf = eval_grid_octree(coords, eval_func, num_samples=num_samples)
40
+ else:
41
+ sdf = eval_grid(coords, eval_func, num_samples=num_samples)
42
+
43
+ # Finally we do marching cubes
44
+ try:
45
+ verts, faces, normals, values = measure.marching_cubes_lewiner(sdf, 0.5)
46
+ # transform verts into world coordinate system
47
+ verts = np.matmul(mat[:3, :3], verts.T) + mat[:3, 3:4]
48
+ verts = verts.T
49
+ return verts, faces, normals, values
50
+ except:
51
+ print('error cannot marching cubes')
52
+ return -1
53
+
54
+
55
+ def save_obj_mesh(mesh_path, verts, faces):
56
+ file = open(mesh_path, 'w')
57
+
58
+ for v in verts:
59
+ file.write('v %.4f %.4f %.4f\n' % (v[0], v[1], v[2]))
60
+ for f in faces:
61
+ f_plus = f + 1
62
+ file.write('f %d %d %d\n' % (f_plus[0], f_plus[2], f_plus[1]))
63
+ file.close()
64
+
65
+
66
+ def save_obj_mesh_with_color(mesh_path, verts, faces, colors):
67
+ file = open(mesh_path, 'w')
68
+
69
+ for idx, v in enumerate(verts):
70
+ c = colors[idx]
71
+ file.write('v %.4f %.4f %.4f %.4f %.4f %.4f\n' % (v[0], v[1], v[2], c[0], c[1], c[2]))
72
+ for f in faces:
73
+ f_plus = f + 1
74
+ file.write('f %d %d %d\n' % (f_plus[0], f_plus[2], f_plus[1]))
75
+ file.close()
76
+
77
+
78
+ def save_obj_mesh_with_uv(mesh_path, verts, faces, uvs):
79
+ file = open(mesh_path, 'w')
80
+
81
+ for idx, v in enumerate(verts):
82
+ vt = uvs[idx]
83
+ file.write('v %.4f %.4f %.4f\n' % (v[0], v[1], v[2]))
84
+ file.write('vt %.4f %.4f\n' % (vt[0], vt[1]))
85
+
86
+ for f in faces:
87
+ f_plus = f + 1
88
+ file.write('f %d/%d %d/%d %d/%d\n' % (f_plus[0], f_plus[0],
89
+ f_plus[2], f_plus[2],
90
+ f_plus[1], f_plus[1]))
91
+ file.close()
PIFu/lib/model/BasePIFuNet.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from ..geometry import index, orthogonal, perspective
6
+
7
+ class BasePIFuNet(nn.Module):
8
+ def __init__(self,
9
+ projection_mode='orthogonal',
10
+ error_term=nn.MSELoss(),
11
+ ):
12
+ """
13
+ :param projection_mode:
14
+ Either orthogonal or perspective.
15
+ It will call the corresponding function for projection.
16
+ :param error_term:
17
+ nn Loss between the predicted [B, Res, N] and the label [B, Res, N]
18
+ """
19
+ super(BasePIFuNet, self).__init__()
20
+ self.name = 'base'
21
+
22
+ self.error_term = error_term
23
+
24
+ self.index = index
25
+ self.projection = orthogonal if projection_mode == 'orthogonal' else perspective
26
+
27
+ self.preds = None
28
+ self.labels = None
29
+
30
+ def forward(self, points, images, calibs, transforms=None):
31
+ '''
32
+ :param points: [B, 3, N] world space coordinates of points
33
+ :param images: [B, C, H, W] input images
34
+ :param calibs: [B, 3, 4] calibration matrices for each image
35
+ :param transforms: Optional [B, 2, 3] image space coordinate transforms
36
+ :return: [B, Res, N] predictions for each point
37
+ '''
38
+ self.filter(images)
39
+ self.query(points, calibs, transforms)
40
+ return self.get_preds()
41
+
42
+ def filter(self, images):
43
+ '''
44
+ Filter the input images
45
+ store all intermediate features.
46
+ :param images: [B, C, H, W] input images
47
+ '''
48
+ None
49
+
50
+ def query(self, points, calibs, transforms=None, labels=None):
51
+ '''
52
+ Given 3D points, query the network predictions for each point.
53
+ Image features should be pre-computed before this call.
54
+ store all intermediate features.
55
+ query() function may behave differently during training/testing.
56
+ :param points: [B, 3, N] world space coordinates of points
57
+ :param calibs: [B, 3, 4] calibration matrices for each image
58
+ :param transforms: Optional [B, 2, 3] image space coordinate transforms
59
+ :param labels: Optional [B, Res, N] gt labeling
60
+ :return: [B, Res, N] predictions for each point
61
+ '''
62
+ None
63
+
64
+ def get_preds(self):
65
+ '''
66
+ Get the predictions from the last query
67
+ :return: [B, Res, N] network prediction for the last query
68
+ '''
69
+ return self.preds
70
+
71
+ def get_error(self):
72
+ '''
73
+ Get the network loss from the last query
74
+ :return: loss term
75
+ '''
76
+ return self.error_term(self.preds, self.labels)
PIFu/lib/model/ConvFilters.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import torchvision.models.resnet as resnet
5
+ import torchvision.models.vgg as vgg
6
+
7
+
8
+ class MultiConv(nn.Module):
9
+ def __init__(self, filter_channels):
10
+ super(MultiConv, self).__init__()
11
+ self.filters = []
12
+
13
+ for l in range(0, len(filter_channels) - 1):
14
+ self.filters.append(
15
+ nn.Conv2d(filter_channels[l], filter_channels[l + 1], kernel_size=4, stride=2))
16
+ self.add_module("conv%d" % l, self.filters[l])
17
+
18
+ def forward(self, image):
19
+ '''
20
+ :param image: [BxC_inxHxW] tensor of input image
21
+ :return: list of [BxC_outxHxW] tensors of output features
22
+ '''
23
+ y = image
24
+ # y = F.relu(self.bn0(self.conv0(y)), True)
25
+ feat_pyramid = [y]
26
+ for i, f in enumerate(self.filters):
27
+ y = f(y)
28
+ if i != len(self.filters) - 1:
29
+ y = F.leaky_relu(y)
30
+ # y = F.max_pool2d(y, kernel_size=2, stride=2)
31
+ feat_pyramid.append(y)
32
+ return feat_pyramid
33
+
34
+
35
+ class Vgg16(torch.nn.Module):
36
+ def __init__(self):
37
+ super(Vgg16, self).__init__()
38
+ vgg_pretrained_features = vgg.vgg16(pretrained=True).features
39
+ self.slice1 = torch.nn.Sequential()
40
+ self.slice2 = torch.nn.Sequential()
41
+ self.slice3 = torch.nn.Sequential()
42
+ self.slice4 = torch.nn.Sequential()
43
+ self.slice5 = torch.nn.Sequential()
44
+
45
+ for x in range(4):
46
+ self.slice1.add_module(str(x), vgg_pretrained_features[x])
47
+ for x in range(4, 9):
48
+ self.slice2.add_module(str(x), vgg_pretrained_features[x])
49
+ for x in range(9, 16):
50
+ self.slice3.add_module(str(x), vgg_pretrained_features[x])
51
+ for x in range(16, 23):
52
+ self.slice4.add_module(str(x), vgg_pretrained_features[x])
53
+ for x in range(23, 30):
54
+ self.slice5.add_module(str(x), vgg_pretrained_features[x])
55
+
56
+ def forward(self, X):
57
+ h = self.slice1(X)
58
+ h_relu1_2 = h
59
+ h = self.slice2(h)
60
+ h_relu2_2 = h
61
+ h = self.slice3(h)
62
+ h_relu3_3 = h
63
+ h = self.slice4(h)
64
+ h_relu4_3 = h
65
+ h = self.slice5(h)
66
+ h_relu5_3 = h
67
+
68
+ return [h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3]
69
+
70
+
71
+ class ResNet(nn.Module):
72
+ def __init__(self, model='resnet18'):
73
+ super(ResNet, self).__init__()
74
+
75
+ if model == 'resnet18':
76
+ net = resnet.resnet18(pretrained=True)
77
+ elif model == 'resnet34':
78
+ net = resnet.resnet34(pretrained=True)
79
+ elif model == 'resnet50':
80
+ net = resnet.resnet50(pretrained=True)
81
+ else:
82
+ raise NameError('Unknown Fan Filter setting!')
83
+
84
+ self.conv1 = net.conv1
85
+
86
+ self.pool = net.maxpool
87
+ self.layer0 = nn.Sequential(net.conv1, net.bn1, net.relu)
88
+ self.layer1 = net.layer1
89
+ self.layer2 = net.layer2
90
+ self.layer3 = net.layer3
91
+ self.layer4 = net.layer4
92
+
93
+ def forward(self, image):
94
+ '''
95
+ :param image: [BxC_inxHxW] tensor of input image
96
+ :return: list of [BxC_outxHxW] tensors of output features
97
+ '''
98
+
99
+ y = image
100
+ feat_pyramid = []
101
+ y = self.layer0(y)
102
+ feat_pyramid.append(y)
103
+ y = self.layer1(self.pool(y))
104
+ feat_pyramid.append(y)
105
+ y = self.layer2(y)
106
+ feat_pyramid.append(y)
107
+ y = self.layer3(y)
108
+ feat_pyramid.append(y)
109
+ y = self.layer4(y)
110
+ feat_pyramid.append(y)
111
+
112
+ return feat_pyramid
PIFu/lib/model/ConvPIFuNet.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from .BasePIFuNet import BasePIFuNet
5
+ from .SurfaceClassifier import SurfaceClassifier
6
+ from .DepthNormalizer import DepthNormalizer
7
+ from .ConvFilters import *
8
+ from ..net_util import init_net
9
+
10
+ class ConvPIFuNet(BasePIFuNet):
11
+ '''
12
+ Conv Piximp network is the standard 3-phase network that we will use.
13
+ The image filter is a pure multi-layer convolutional network,
14
+ while during feature extraction phase all features in the pyramid at the projected location
15
+ will be aggregated.
16
+ It does the following:
17
+ 1. Compute image feature pyramids and store it in self.im_feat_list
18
+ 2. Calculate calibration and indexing on each of the feat, and append them together
19
+ 3. Classification.
20
+ '''
21
+
22
+ def __init__(self,
23
+ opt,
24
+ projection_mode='orthogonal',
25
+ error_term=nn.MSELoss(),
26
+ ):
27
+ super(ConvPIFuNet, self).__init__(
28
+ projection_mode=projection_mode,
29
+ error_term=error_term)
30
+
31
+ self.name = 'convpifu'
32
+
33
+ self.opt = opt
34
+ self.num_views = self.opt.num_views
35
+
36
+ self.image_filter = self.define_imagefilter(opt)
37
+
38
+ self.surface_classifier = SurfaceClassifier(
39
+ filter_channels=self.opt.mlp_dim,
40
+ num_views=self.opt.num_views,
41
+ no_residual=self.opt.no_residual,
42
+ last_op=nn.Sigmoid())
43
+
44
+ self.normalizer = DepthNormalizer(opt)
45
+
46
+ # This is a list of [B x Feat_i x H x W] features
47
+ self.im_feat_list = []
48
+
49
+ init_net(self)
50
+
51
+ def define_imagefilter(self, opt):
52
+ net = None
53
+ if opt.netIMF == 'multiconv':
54
+ net = MultiConv(opt.enc_dim)
55
+ elif 'resnet' in opt.netIMF:
56
+ net = ResNet(model=opt.netIMF)
57
+ elif opt.netIMF == 'vgg16':
58
+ net = Vgg16()
59
+ else:
60
+ raise NotImplementedError('model name [%s] is not recognized' % opt.imf_type)
61
+
62
+ return net
63
+
64
+ def filter(self, images):
65
+ '''
66
+ Filter the input images
67
+ store all intermediate features.
68
+ :param images: [B, C, H, W] input images
69
+ '''
70
+ self.im_feat_list = self.image_filter(images)
71
+
72
+ def query(self, points, calibs, transforms=None, labels=None):
73
+ '''
74
+ Given 3D points, query the network predictions for each point.
75
+ Image features should be pre-computed before this call.
76
+ store all intermediate features.
77
+ query() function may behave differently during training/testing.
78
+ :param points: [B, 3, N] world space coordinates of points
79
+ :param calibs: [B, 3, 4] calibration matrices for each image
80
+ :param transforms: Optional [B, 2, 3] image space coordinate transforms
81
+ :param labels: Optional [B, Res, N] gt labeling
82
+ :return: [B, Res, N] predictions for each point
83
+ '''
84
+ if labels is not None:
85
+ self.labels = labels
86
+
87
+ xyz = self.projection(points, calibs, transforms)
88
+ xy = xyz[:, :2, :]
89
+ z = xyz[:, 2:3, :]
90
+
91
+ z_feat = self.normalizer(z)
92
+
93
+ # This is a list of [B, Feat_i, N] features
94
+ point_local_feat_list = [self.index(im_feat, xy) for im_feat in self.im_feat_list]
95
+ point_local_feat_list.append(z_feat)
96
+ # [B, Feat_all, N]
97
+ point_local_feat = torch.cat(point_local_feat_list, 1)
98
+
99
+ self.preds = self.surface_classifier(point_local_feat)
PIFu/lib/model/DepthNormalizer.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
+ class DepthNormalizer(nn.Module):
7
+ def __init__(self, opt):
8
+ super(DepthNormalizer, self).__init__()
9
+ self.opt = opt
10
+
11
+ def forward(self, z, calibs=None, index_feat=None):
12
+ '''
13
+ Normalize z_feature
14
+ :param z_feat: [B, 1, N] depth value for z in the image coordinate system
15
+ :return:
16
+ '''
17
+ z_feat = z * (self.opt.loadSize // 2) / self.opt.z_size
18
+ return z_feat
PIFu/lib/model/HGFilters.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from ..net_util import *
5
+
6
+
7
+ class HourGlass(nn.Module):
8
+ def __init__(self, num_modules, depth, num_features, norm='batch'):
9
+ super(HourGlass, self).__init__()
10
+ self.num_modules = num_modules
11
+ self.depth = depth
12
+ self.features = num_features
13
+ self.norm = norm
14
+
15
+ self._generate_network(self.depth)
16
+
17
+ def _generate_network(self, level):
18
+ self.add_module('b1_' + str(level), ConvBlock(self.features, self.features, norm=self.norm))
19
+
20
+ self.add_module('b2_' + str(level), ConvBlock(self.features, self.features, norm=self.norm))
21
+
22
+ if level > 1:
23
+ self._generate_network(level - 1)
24
+ else:
25
+ self.add_module('b2_plus_' + str(level), ConvBlock(self.features, self.features, norm=self.norm))
26
+
27
+ self.add_module('b3_' + str(level), ConvBlock(self.features, self.features, norm=self.norm))
28
+
29
+ def _forward(self, level, inp):
30
+ # Upper branch
31
+ up1 = inp
32
+ up1 = self._modules['b1_' + str(level)](up1)
33
+
34
+ # Lower branch
35
+ low1 = F.avg_pool2d(inp, 2, stride=2)
36
+ low1 = self._modules['b2_' + str(level)](low1)
37
+
38
+ if level > 1:
39
+ low2 = self._forward(level - 1, low1)
40
+ else:
41
+ low2 = low1
42
+ low2 = self._modules['b2_plus_' + str(level)](low2)
43
+
44
+ low3 = low2
45
+ low3 = self._modules['b3_' + str(level)](low3)
46
+
47
+ # NOTE: for newer PyTorch (1.3~), it seems that training results are degraded due to implementation diff in F.grid_sample
48
+ # if the pretrained model behaves weirdly, switch with the commented line.
49
+ # NOTE: I also found that "bicubic" works better.
50
+ up2 = F.interpolate(low3, scale_factor=2, mode='bicubic', align_corners=True)
51
+ # up2 = F.interpolate(low3, scale_factor=2, mode='nearest)
52
+
53
+ return up1 + up2
54
+
55
+ def forward(self, x):
56
+ return self._forward(self.depth, x)
57
+
58
+
59
+ class HGFilter(nn.Module):
60
+ def __init__(self, opt):
61
+ super(HGFilter, self).__init__()
62
+ self.num_modules = opt.num_stack
63
+
64
+ self.opt = opt
65
+
66
+ # Base part
67
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
68
+
69
+ if self.opt.norm == 'batch':
70
+ self.bn1 = nn.BatchNorm2d(64)
71
+ elif self.opt.norm == 'group':
72
+ self.bn1 = nn.GroupNorm(32, 64)
73
+
74
+ if self.opt.hg_down == 'conv64':
75
+ self.conv2 = ConvBlock(64, 64, self.opt.norm)
76
+ self.down_conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)
77
+ elif self.opt.hg_down == 'conv128':
78
+ self.conv2 = ConvBlock(64, 128, self.opt.norm)
79
+ self.down_conv2 = nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1)
80
+ elif self.opt.hg_down == 'ave_pool':
81
+ self.conv2 = ConvBlock(64, 128, self.opt.norm)
82
+ else:
83
+ raise NameError('Unknown Fan Filter setting!')
84
+
85
+ self.conv3 = ConvBlock(128, 128, self.opt.norm)
86
+ self.conv4 = ConvBlock(128, 256, self.opt.norm)
87
+
88
+ # Stacking part
89
+ for hg_module in range(self.num_modules):
90
+ self.add_module('m' + str(hg_module), HourGlass(1, opt.num_hourglass, 256, self.opt.norm))
91
+
92
+ self.add_module('top_m_' + str(hg_module), ConvBlock(256, 256, self.opt.norm))
93
+ self.add_module('conv_last' + str(hg_module),
94
+ nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0))
95
+ if self.opt.norm == 'batch':
96
+ self.add_module('bn_end' + str(hg_module), nn.BatchNorm2d(256))
97
+ elif self.opt.norm == 'group':
98
+ self.add_module('bn_end' + str(hg_module), nn.GroupNorm(32, 256))
99
+
100
+ self.add_module('l' + str(hg_module), nn.Conv2d(256,
101
+ opt.hourglass_dim, kernel_size=1, stride=1, padding=0))
102
+
103
+ if hg_module < self.num_modules - 1:
104
+ self.add_module(
105
+ 'bl' + str(hg_module), nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0))
106
+ self.add_module('al' + str(hg_module), nn.Conv2d(opt.hourglass_dim,
107
+ 256, kernel_size=1, stride=1, padding=0))
108
+
109
+ def forward(self, x):
110
+ x = F.relu(self.bn1(self.conv1(x)), True)
111
+ tmpx = x
112
+ if self.opt.hg_down == 'ave_pool':
113
+ x = F.avg_pool2d(self.conv2(x), 2, stride=2)
114
+ elif self.opt.hg_down in ['conv64', 'conv128']:
115
+ x = self.conv2(x)
116
+ x = self.down_conv2(x)
117
+ else:
118
+ raise NameError('Unknown Fan Filter setting!')
119
+
120
+ normx = x
121
+
122
+ x = self.conv3(x)
123
+ x = self.conv4(x)
124
+
125
+ previous = x
126
+
127
+ outputs = []
128
+ for i in range(self.num_modules):
129
+ hg = self._modules['m' + str(i)](previous)
130
+
131
+ ll = hg
132
+ ll = self._modules['top_m_' + str(i)](ll)
133
+
134
+ ll = F.relu(self._modules['bn_end' + str(i)]
135
+ (self._modules['conv_last' + str(i)](ll)), True)
136
+
137
+ # Predict heatmaps
138
+ tmp_out = self._modules['l' + str(i)](ll)
139
+ outputs.append(tmp_out)
140
+
141
+ if i < self.num_modules - 1:
142
+ ll = self._modules['bl' + str(i)](ll)
143
+ tmp_out_ = self._modules['al' + str(i)](tmp_out)
144
+ previous = previous + ll + tmp_out_
145
+
146
+ return outputs, tmpx.detach(), normx
PIFu/lib/model/HGPIFuNet.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from .BasePIFuNet import BasePIFuNet
5
+ from .SurfaceClassifier import SurfaceClassifier
6
+ from .DepthNormalizer import DepthNormalizer
7
+ from .HGFilters import *
8
+ from ..net_util import init_net
9
+
10
+
11
+ class HGPIFuNet(BasePIFuNet):
12
+ '''
13
+ HG PIFu network uses Hourglass stacks as the image filter.
14
+ It does the following:
15
+ 1. Compute image feature stacks and store it in self.im_feat_list
16
+ self.im_feat_list[-1] is the last stack (output stack)
17
+ 2. Calculate calibration
18
+ 3. If training, it index on every intermediate stacks,
19
+ If testing, it index on the last stack.
20
+ 4. Classification.
21
+ 5. During training, error is calculated on all stacks.
22
+ '''
23
+
24
+ def __init__(self,
25
+ opt,
26
+ projection_mode='orthogonal',
27
+ error_term=nn.MSELoss(),
28
+ ):
29
+ super(HGPIFuNet, self).__init__(
30
+ projection_mode=projection_mode,
31
+ error_term=error_term)
32
+
33
+ self.name = 'hgpifu'
34
+
35
+ self.opt = opt
36
+ self.num_views = self.opt.num_views
37
+
38
+ self.image_filter = HGFilter(opt)
39
+
40
+ self.surface_classifier = SurfaceClassifier(
41
+ filter_channels=self.opt.mlp_dim,
42
+ num_views=self.opt.num_views,
43
+ no_residual=self.opt.no_residual,
44
+ last_op=nn.Sigmoid())
45
+
46
+ self.normalizer = DepthNormalizer(opt)
47
+
48
+ # This is a list of [B x Feat_i x H x W] features
49
+ self.im_feat_list = []
50
+ self.tmpx = None
51
+ self.normx = None
52
+
53
+ self.intermediate_preds_list = []
54
+
55
+ init_net(self)
56
+
57
+ def filter(self, images):
58
+ '''
59
+ Filter the input images
60
+ store all intermediate features.
61
+ :param images: [B, C, H, W] input images
62
+ '''
63
+ self.im_feat_list, self.tmpx, self.normx = self.image_filter(images)
64
+ # If it is not in training, only produce the last im_feat
65
+ if not self.training:
66
+ self.im_feat_list = [self.im_feat_list[-1]]
67
+
68
+ def query(self, points, calibs, transforms=None, labels=None):
69
+ '''
70
+ Given 3D points, query the network predictions for each point.
71
+ Image features should be pre-computed before this call.
72
+ store all intermediate features.
73
+ query() function may behave differently during training/testing.
74
+ :param points: [B, 3, N] world space coordinates of points
75
+ :param calibs: [B, 3, 4] calibration matrices for each image
76
+ :param transforms: Optional [B, 2, 3] image space coordinate transforms
77
+ :param labels: Optional [B, Res, N] gt labeling
78
+ :return: [B, Res, N] predictions for each point
79
+ '''
80
+ if labels is not None:
81
+ self.labels = labels
82
+
83
+ xyz = self.projection(points, calibs, transforms)
84
+ xy = xyz[:, :2, :]
85
+ z = xyz[:, 2:3, :]
86
+
87
+ in_img = (xy[:, 0] >= -1.0) & (xy[:, 0] <= 1.0) & (xy[:, 1] >= -1.0) & (xy[:, 1] <= 1.0)
88
+
89
+ z_feat = self.normalizer(z, calibs=calibs)
90
+
91
+ if self.opt.skip_hourglass:
92
+ tmpx_local_feature = self.index(self.tmpx, xy)
93
+
94
+ self.intermediate_preds_list = []
95
+
96
+ for im_feat in self.im_feat_list:
97
+ # [B, Feat_i + z, N]
98
+ point_local_feat_list = [self.index(im_feat, xy), z_feat]
99
+
100
+ if self.opt.skip_hourglass:
101
+ point_local_feat_list.append(tmpx_local_feature)
102
+
103
+ point_local_feat = torch.cat(point_local_feat_list, 1)
104
+
105
+ # out of image plane is always set to 0
106
+ pred = in_img[:,None].float() * self.surface_classifier(point_local_feat)
107
+ self.intermediate_preds_list.append(pred)
108
+
109
+ self.preds = self.intermediate_preds_list[-1]
110
+
111
+ def get_im_feat(self):
112
+ '''
113
+ Get the image filter
114
+ :return: [B, C_feat, H, W] image feature after filtering
115
+ '''
116
+ return self.im_feat_list[-1]
117
+
118
+ def get_error(self):
119
+ '''
120
+ Hourglass has its own intermediate supervision scheme
121
+ '''
122
+ error = 0
123
+ for preds in self.intermediate_preds_list:
124
+ error += self.error_term(preds, self.labels)
125
+ error /= len(self.intermediate_preds_list)
126
+
127
+ return error
128
+
129
+ def forward(self, images, points, calibs, transforms=None, labels=None):
130
+ # Get image feature
131
+ self.filter(images)
132
+
133
+ # Phase 2: point query
134
+ self.query(points=points, calibs=calibs, transforms=transforms, labels=labels)
135
+
136
+ # get the prediction
137
+ res = self.get_preds()
138
+
139
+ # get the error
140
+ error = self.get_error()
141
+
142
+ return res, error
PIFu/lib/model/ResBlkPIFuNet.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from .BasePIFuNet import BasePIFuNet
5
+ import functools
6
+ from .SurfaceClassifier import SurfaceClassifier
7
+ from .DepthNormalizer import DepthNormalizer
8
+ from ..net_util import *
9
+
10
+
11
+ class ResBlkPIFuNet(BasePIFuNet):
12
+ def __init__(self, opt,
13
+ projection_mode='orthogonal'):
14
+ if opt.color_loss_type == 'l1':
15
+ error_term = nn.L1Loss()
16
+ elif opt.color_loss_type == 'mse':
17
+ error_term = nn.MSELoss()
18
+
19
+ super(ResBlkPIFuNet, self).__init__(
20
+ projection_mode=projection_mode,
21
+ error_term=error_term)
22
+
23
+ self.name = 'respifu'
24
+ self.opt = opt
25
+
26
+ norm_type = get_norm_layer(norm_type=opt.norm_color)
27
+ self.image_filter = ResnetFilter(opt, norm_layer=norm_type)
28
+
29
+ self.surface_classifier = SurfaceClassifier(
30
+ filter_channels=self.opt.mlp_dim_color,
31
+ num_views=self.opt.num_views,
32
+ no_residual=self.opt.no_residual,
33
+ last_op=nn.Tanh())
34
+
35
+ self.normalizer = DepthNormalizer(opt)
36
+
37
+ init_net(self)
38
+
39
+ def filter(self, images):
40
+ '''
41
+ Filter the input images
42
+ store all intermediate features.
43
+ :param images: [B, C, H, W] input images
44
+ '''
45
+ self.im_feat = self.image_filter(images)
46
+
47
+ def attach(self, im_feat):
48
+ self.im_feat = torch.cat([im_feat, self.im_feat], 1)
49
+
50
+ def query(self, points, calibs, transforms=None, labels=None):
51
+ '''
52
+ Given 3D points, query the network predictions for each point.
53
+ Image features should be pre-computed before this call.
54
+ store all intermediate features.
55
+ query() function may behave differently during training/testing.
56
+ :param points: [B, 3, N] world space coordinates of points
57
+ :param calibs: [B, 3, 4] calibration matrices for each image
58
+ :param transforms: Optional [B, 2, 3] image space coordinate transforms
59
+ :param labels: Optional [B, Res, N] gt labeling
60
+ :return: [B, Res, N] predictions for each point
61
+ '''
62
+ if labels is not None:
63
+ self.labels = labels
64
+
65
+ xyz = self.projection(points, calibs, transforms)
66
+ xy = xyz[:, :2, :]
67
+ z = xyz[:, 2:3, :]
68
+
69
+ z_feat = self.normalizer(z)
70
+
71
+ # This is a list of [B, Feat_i, N] features
72
+ point_local_feat_list = [self.index(self.im_feat, xy), z_feat]
73
+ # [B, Feat_all, N]
74
+ point_local_feat = torch.cat(point_local_feat_list, 1)
75
+
76
+ self.preds = self.surface_classifier(point_local_feat)
77
+
78
+ def forward(self, images, im_feat, points, calibs, transforms=None, labels=None):
79
+ self.filter(images)
80
+
81
+ self.attach(im_feat)
82
+
83
+ self.query(points, calibs, transforms, labels)
84
+
85
+ res = self.get_preds()
86
+ error = self.get_error()
87
+
88
+ return res, error
89
+
90
+ class ResnetBlock(nn.Module):
91
+ """Define a Resnet block"""
92
+
93
+ def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias, last=False):
94
+ """Initialize the Resnet block
95
+ A resnet block is a conv block with skip connections
96
+ We construct a conv block with build_conv_block function,
97
+ and implement skip connections in <forward> function.
98
+ Original Resnet paper: https://arxiv.org/pdf/1512.03385.pdf
99
+ """
100
+ super(ResnetBlock, self).__init__()
101
+ self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias, last)
102
+
103
+ def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias, last=False):
104
+ """Construct a convolutional block.
105
+ Parameters:
106
+ dim (int) -- the number of channels in the conv layer.
107
+ padding_type (str) -- the name of padding layer: reflect | replicate | zero
108
+ norm_layer -- normalization layer
109
+ use_dropout (bool) -- if use dropout layers.
110
+ use_bias (bool) -- if the conv layer uses bias or not
111
+ Returns a conv block (with a conv layer, a normalization layer, and a non-linearity layer (ReLU))
112
+ """
113
+ conv_block = []
114
+ p = 0
115
+ if padding_type == 'reflect':
116
+ conv_block += [nn.ReflectionPad2d(1)]
117
+ elif padding_type == 'replicate':
118
+ conv_block += [nn.ReplicationPad2d(1)]
119
+ elif padding_type == 'zero':
120
+ p = 1
121
+ else:
122
+ raise NotImplementedError('padding [%s] is not implemented' % padding_type)
123
+
124
+ conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim), nn.ReLU(True)]
125
+ if use_dropout:
126
+ conv_block += [nn.Dropout(0.5)]
127
+
128
+ p = 0
129
+ if padding_type == 'reflect':
130
+ conv_block += [nn.ReflectionPad2d(1)]
131
+ elif padding_type == 'replicate':
132
+ conv_block += [nn.ReplicationPad2d(1)]
133
+ elif padding_type == 'zero':
134
+ p = 1
135
+ else:
136
+ raise NotImplementedError('padding [%s] is not implemented' % padding_type)
137
+ if last:
138
+ conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias)]
139
+ else:
140
+ conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim)]
141
+
142
+ return nn.Sequential(*conv_block)
143
+
144
+ def forward(self, x):
145
+ """Forward function (with skip connections)"""
146
+ out = x + self.conv_block(x) # add skip connections
147
+ return out
148
+
149
+
150
+ class ResnetFilter(nn.Module):
151
+ """Resnet-based generator that consists of Resnet blocks between a few downsampling/upsampling operations.
152
+ We adapt Torch code and idea from Justin Johnson's neural style transfer project(https://github.com/jcjohnson/fast-neural-style)
153
+ """
154
+
155
+ def __init__(self, opt, input_nc=3, output_nc=256, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False,
156
+ n_blocks=6, padding_type='reflect'):
157
+ """Construct a Resnet-based generator
158
+ Parameters:
159
+ input_nc (int) -- the number of channels in input images
160
+ output_nc (int) -- the number of channels in output images
161
+ ngf (int) -- the number of filters in the last conv layer
162
+ norm_layer -- normalization layer
163
+ use_dropout (bool) -- if use dropout layers
164
+ n_blocks (int) -- the number of ResNet blocks
165
+ padding_type (str) -- the name of padding layer in conv layers: reflect | replicate | zero
166
+ """
167
+ assert (n_blocks >= 0)
168
+ super(ResnetFilter, self).__init__()
169
+ if type(norm_layer) == functools.partial:
170
+ use_bias = norm_layer.func == nn.InstanceNorm2d
171
+ else:
172
+ use_bias = norm_layer == nn.InstanceNorm2d
173
+
174
+ model = [nn.ReflectionPad2d(3),
175
+ nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
176
+ norm_layer(ngf),
177
+ nn.ReLU(True)]
178
+
179
+ n_downsampling = 2
180
+ for i in range(n_downsampling): # add downsampling layers
181
+ mult = 2 ** i
182
+ model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
183
+ norm_layer(ngf * mult * 2),
184
+ nn.ReLU(True)]
185
+
186
+ mult = 2 ** n_downsampling
187
+ for i in range(n_blocks): # add ResNet blocks
188
+ if i == n_blocks - 1:
189
+ model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer,
190
+ use_dropout=use_dropout, use_bias=use_bias, last=True)]
191
+ else:
192
+ model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer,
193
+ use_dropout=use_dropout, use_bias=use_bias)]
194
+
195
+ if opt.use_tanh:
196
+ model += [nn.Tanh()]
197
+ self.model = nn.Sequential(*model)
198
+
199
+ def forward(self, input):
200
+ """Standard forward"""
201
+ return self.model(input)
PIFu/lib/model/SurfaceClassifier.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
+ class SurfaceClassifier(nn.Module):
7
+ def __init__(self, filter_channels, num_views=1, no_residual=True, last_op=None):
8
+ super(SurfaceClassifier, self).__init__()
9
+
10
+ self.filters = []
11
+ self.num_views = num_views
12
+ self.no_residual = no_residual
13
+ filter_channels = filter_channels
14
+ self.last_op = last_op
15
+
16
+ if self.no_residual:
17
+ for l in range(0, len(filter_channels) - 1):
18
+ self.filters.append(nn.Conv1d(
19
+ filter_channels[l],
20
+ filter_channels[l + 1],
21
+ 1))
22
+ self.add_module("conv%d" % l, self.filters[l])
23
+ else:
24
+ for l in range(0, len(filter_channels) - 1):
25
+ if 0 != l:
26
+ self.filters.append(
27
+ nn.Conv1d(
28
+ filter_channels[l] + filter_channels[0],
29
+ filter_channels[l + 1],
30
+ 1))
31
+ else:
32
+ self.filters.append(nn.Conv1d(
33
+ filter_channels[l],
34
+ filter_channels[l + 1],
35
+ 1))
36
+
37
+ self.add_module("conv%d" % l, self.filters[l])
38
+
39
+ def forward(self, feature):
40
+ '''
41
+
42
+ :param feature: list of [BxC_inxHxW] tensors of image features
43
+ :param xy: [Bx3xN] tensor of (x,y) coodinates in the image plane
44
+ :return: [BxC_outxN] tensor of features extracted at the coordinates
45
+ '''
46
+
47
+ y = feature
48
+ tmpy = feature
49
+ for i, f in enumerate(self.filters):
50
+ if self.no_residual:
51
+ y = self._modules['conv' + str(i)](y)
52
+ else:
53
+ y = self._modules['conv' + str(i)](
54
+ y if i == 0
55
+ else torch.cat([y, tmpy], 1)
56
+ )
57
+ if i != len(self.filters) - 1:
58
+ y = F.leaky_relu(y)
59
+
60
+ if self.num_views > 1 and i == len(self.filters) // 2:
61
+ y = y.view(
62
+ -1, self.num_views, y.shape[1], y.shape[2]
63
+ ).mean(dim=1)
64
+ tmpy = feature.view(
65
+ -1, self.num_views, feature.shape[1], feature.shape[2]
66
+ ).mean(dim=1)
67
+
68
+ if self.last_op:
69
+ y = self.last_op(y)
70
+
71
+ return y
PIFu/lib/model/VhullPIFuNet.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from .BasePIFuNet import BasePIFuNet
5
+
6
+
7
+ class VhullPIFuNet(BasePIFuNet):
8
+ '''
9
+ Vhull Piximp network is a minimal network demonstrating how the template works
10
+ also, it helps debugging the training/test schemes
11
+ It does the following:
12
+ 1. Compute the masks of images and stores under self.im_feats
13
+ 2. Calculate calibration and indexing
14
+ 3. Return if the points fall into the intersection of all masks
15
+ '''
16
+
17
+ def __init__(self,
18
+ num_views,
19
+ projection_mode='orthogonal',
20
+ error_term=nn.MSELoss(),
21
+ ):
22
+ super(VhullPIFuNet, self).__init__(
23
+ projection_mode=projection_mode,
24
+ error_term=error_term)
25
+ self.name = 'vhull'
26
+
27
+ self.num_views = num_views
28
+
29
+ self.im_feat = None
30
+
31
+ def filter(self, images):
32
+ '''
33
+ Filter the input images
34
+ store all intermediate features.
35
+ :param images: [B, C, H, W] input images
36
+ '''
37
+ # If the image has alpha channel, use the alpha channel
38
+ if images.shape[1] > 3:
39
+ self.im_feat = images[:, 3:4, :, :]
40
+ # Else, tell if it's not white
41
+ else:
42
+ self.im_feat = images[:, 0:1, :, :]
43
+
44
+ def query(self, points, calibs, transforms=None, labels=None):
45
+ '''
46
+ Given 3D points, query the network predictions for each point.
47
+ Image features should be pre-computed before this call.
48
+ store all intermediate features.
49
+ query() function may behave differently during training/testing.
50
+ :param points: [B, 3, N] world space coordinates of points
51
+ :param calibs: [B, 3, 4] calibration matrices for each image
52
+ :param transforms: Optional [B, 2, 3] image space coordinate transforms
53
+ :param labels: Optional [B, Res, N] gt labeling
54
+ :return: [B, Res, N] predictions for each point
55
+ '''
56
+ if labels is not None:
57
+ self.labels = labels
58
+
59
+ xyz = self.projection(points, calibs, transforms)
60
+ xy = xyz[:, :2, :]
61
+
62
+ point_local_feat = self.index(self.im_feat, xy)
63
+ local_shape = point_local_feat.shape
64
+ point_feat = point_local_feat.view(
65
+ local_shape[0] // self.num_views,
66
+ local_shape[1] * self.num_views,
67
+ -1)
68
+ pred = torch.prod(point_feat, dim=1)
69
+
70
+ self.preds = pred.unsqueeze(1)
PIFu/lib/model/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ from .BasePIFuNet import BasePIFuNet
2
+ from .VhullPIFuNet import VhullPIFuNet
3
+ from .ConvPIFuNet import ConvPIFuNet
4
+ from .HGPIFuNet import HGPIFuNet
5
+ from .ResBlkPIFuNet import ResBlkPIFuNet
PIFu/lib/net_util.py ADDED
@@ -0,0 +1,396 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.nn import init
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import functools
6
+
7
+ import numpy as np
8
+ from .mesh_util import *
9
+ from .sample_util import *
10
+ from .geometry import index
11
+ import cv2
12
+ from PIL import Image
13
+ from tqdm import tqdm
14
+
15
+
16
+ def reshape_multiview_tensors(image_tensor, calib_tensor):
17
+ # Careful here! Because we put single view and multiview together,
18
+ # the returned tensor.shape is 5-dim: [B, num_views, C, W, H]
19
+ # So we need to convert it back to 4-dim [B*num_views, C, W, H]
20
+ # Don't worry classifier will handle multi-view cases
21
+ image_tensor = image_tensor.view(
22
+ image_tensor.shape[0] * image_tensor.shape[1],
23
+ image_tensor.shape[2],
24
+ image_tensor.shape[3],
25
+ image_tensor.shape[4]
26
+ )
27
+ calib_tensor = calib_tensor.view(
28
+ calib_tensor.shape[0] * calib_tensor.shape[1],
29
+ calib_tensor.shape[2],
30
+ calib_tensor.shape[3]
31
+ )
32
+
33
+ return image_tensor, calib_tensor
34
+
35
+
36
+ def reshape_sample_tensor(sample_tensor, num_views):
37
+ if num_views == 1:
38
+ return sample_tensor
39
+ # Need to repeat sample_tensor along the batch dim num_views times
40
+ sample_tensor = sample_tensor.unsqueeze(dim=1)
41
+ sample_tensor = sample_tensor.repeat(1, num_views, 1, 1)
42
+ sample_tensor = sample_tensor.view(
43
+ sample_tensor.shape[0] * sample_tensor.shape[1],
44
+ sample_tensor.shape[2],
45
+ sample_tensor.shape[3]
46
+ )
47
+ return sample_tensor
48
+
49
+
50
+ def gen_mesh(opt, net, cuda, data, save_path, use_octree=True):
51
+ image_tensor = data['img'].to(device=cuda)
52
+ calib_tensor = data['calib'].to(device=cuda)
53
+
54
+ net.filter(image_tensor)
55
+
56
+ b_min = data['b_min']
57
+ b_max = data['b_max']
58
+ try:
59
+ save_img_path = save_path[:-4] + '.png'
60
+ save_img_list = []
61
+ for v in range(image_tensor.shape[0]):
62
+ save_img = (np.transpose(image_tensor[v].detach().cpu().numpy(), (1, 2, 0)) * 0.5 + 0.5)[:, :, ::-1] * 255.0
63
+ save_img_list.append(save_img)
64
+ save_img = np.concatenate(save_img_list, axis=1)
65
+ Image.fromarray(np.uint8(save_img[:,:,::-1])).save(save_img_path)
66
+
67
+ verts, faces, _, _ = reconstruction(
68
+ net, cuda, calib_tensor, opt.resolution, b_min, b_max, use_octree=use_octree)
69
+ verts_tensor = torch.from_numpy(verts.T).unsqueeze(0).to(device=cuda).float()
70
+ xyz_tensor = net.projection(verts_tensor, calib_tensor[:1])
71
+ uv = xyz_tensor[:, :2, :]
72
+ color = index(image_tensor[:1], uv).detach().cpu().numpy()[0].T
73
+ color = color * 0.5 + 0.5
74
+ save_obj_mesh_with_color(save_path, verts, faces, color)
75
+ except Exception as e:
76
+ print(e)
77
+ print('Can not create marching cubes at this time.')
78
+
79
+ def gen_mesh_color(opt, netG, netC, cuda, data, save_path, use_octree=True):
80
+ image_tensor = data['img'].to(device=cuda)
81
+ calib_tensor = data['calib'].to(device=cuda)
82
+
83
+ netG.filter(image_tensor)
84
+ netC.filter(image_tensor)
85
+ netC.attach(netG.get_im_feat())
86
+
87
+ b_min = data['b_min']
88
+ b_max = data['b_max']
89
+ try:
90
+ save_img_path = save_path[:-4] + '.png'
91
+ save_img_list = []
92
+ for v in range(image_tensor.shape[0]):
93
+ save_img = (np.transpose(image_tensor[v].detach().cpu().numpy(), (1, 2, 0)) * 0.5 + 0.5)[:, :, ::-1] * 255.0
94
+ save_img_list.append(save_img)
95
+ save_img = np.concatenate(save_img_list, axis=1)
96
+ Image.fromarray(np.uint8(save_img[:,:,::-1])).save(save_img_path)
97
+
98
+ verts, faces, _, _ = reconstruction(
99
+ netG, cuda, calib_tensor, opt.resolution, b_min, b_max, use_octree=use_octree)
100
+
101
+ # Now Getting colors
102
+ verts_tensor = torch.from_numpy(verts.T).unsqueeze(0).to(device=cuda).float()
103
+ verts_tensor = reshape_sample_tensor(verts_tensor, opt.num_views)
104
+
105
+ color = np.zeros(verts.shape)
106
+ interval = opt.num_sample_color
107
+ for i in range(len(color) // interval):
108
+ left = i * interval
109
+ right = i * interval + interval
110
+ if i == len(color) // interval - 1:
111
+ right = -1
112
+ netC.query(verts_tensor[:, :, left:right], calib_tensor)
113
+ rgb = netC.get_preds()[0].detach().cpu().numpy() * 0.5 + 0.5
114
+ color[left:right] = rgb.T
115
+
116
+ save_obj_mesh_with_color(save_path, verts, faces, color)
117
+ except Exception as e:
118
+ print(e)
119
+ print('Can not create marching cubes at this time.')
120
+
121
+ def adjust_learning_rate(optimizer, epoch, lr, schedule, gamma):
122
+ """Sets the learning rate to the initial LR decayed by schedule"""
123
+ if epoch in schedule:
124
+ lr *= gamma
125
+ for param_group in optimizer.param_groups:
126
+ param_group['lr'] = lr
127
+ return lr
128
+
129
+
130
+ def compute_acc(pred, gt, thresh=0.5):
131
+ '''
132
+ return:
133
+ IOU, precision, and recall
134
+ '''
135
+ with torch.no_grad():
136
+ vol_pred = pred > thresh
137
+ vol_gt = gt > thresh
138
+
139
+ union = vol_pred | vol_gt
140
+ inter = vol_pred & vol_gt
141
+
142
+ true_pos = inter.sum().float()
143
+
144
+ union = union.sum().float()
145
+ if union == 0:
146
+ union = 1
147
+ vol_pred = vol_pred.sum().float()
148
+ if vol_pred == 0:
149
+ vol_pred = 1
150
+ vol_gt = vol_gt.sum().float()
151
+ if vol_gt == 0:
152
+ vol_gt = 1
153
+ return true_pos / union, true_pos / vol_pred, true_pos / vol_gt
154
+
155
+
156
+ def calc_error(opt, net, cuda, dataset, num_tests):
157
+ if num_tests > len(dataset):
158
+ num_tests = len(dataset)
159
+ with torch.no_grad():
160
+ erorr_arr, IOU_arr, prec_arr, recall_arr = [], [], [], []
161
+ for idx in tqdm(range(num_tests)):
162
+ data = dataset[idx * len(dataset) // num_tests]
163
+ # retrieve the data
164
+ image_tensor = data['img'].to(device=cuda)
165
+ calib_tensor = data['calib'].to(device=cuda)
166
+ sample_tensor = data['samples'].to(device=cuda).unsqueeze(0)
167
+ if opt.num_views > 1:
168
+ sample_tensor = reshape_sample_tensor(sample_tensor, opt.num_views)
169
+ label_tensor = data['labels'].to(device=cuda).unsqueeze(0)
170
+
171
+ res, error = net.forward(image_tensor, sample_tensor, calib_tensor, labels=label_tensor)
172
+
173
+ IOU, prec, recall = compute_acc(res, label_tensor)
174
+
175
+ # print(
176
+ # '{0}/{1} | Error: {2:06f} IOU: {3:06f} prec: {4:06f} recall: {5:06f}'
177
+ # .format(idx, num_tests, error.item(), IOU.item(), prec.item(), recall.item()))
178
+ erorr_arr.append(error.item())
179
+ IOU_arr.append(IOU.item())
180
+ prec_arr.append(prec.item())
181
+ recall_arr.append(recall.item())
182
+
183
+ return np.average(erorr_arr), np.average(IOU_arr), np.average(prec_arr), np.average(recall_arr)
184
+
185
+ def calc_error_color(opt, netG, netC, cuda, dataset, num_tests):
186
+ if num_tests > len(dataset):
187
+ num_tests = len(dataset)
188
+ with torch.no_grad():
189
+ error_color_arr = []
190
+
191
+ for idx in tqdm(range(num_tests)):
192
+ data = dataset[idx * len(dataset) // num_tests]
193
+ # retrieve the data
194
+ image_tensor = data['img'].to(device=cuda)
195
+ calib_tensor = data['calib'].to(device=cuda)
196
+ color_sample_tensor = data['color_samples'].to(device=cuda).unsqueeze(0)
197
+
198
+ if opt.num_views > 1:
199
+ color_sample_tensor = reshape_sample_tensor(color_sample_tensor, opt.num_views)
200
+
201
+ rgb_tensor = data['rgbs'].to(device=cuda).unsqueeze(0)
202
+
203
+ netG.filter(image_tensor)
204
+ _, errorC = netC.forward(image_tensor, netG.get_im_feat(), color_sample_tensor, calib_tensor, labels=rgb_tensor)
205
+
206
+ # print('{0}/{1} | Error inout: {2:06f} | Error color: {3:06f}'
207
+ # .format(idx, num_tests, errorG.item(), errorC.item()))
208
+ error_color_arr.append(errorC.item())
209
+
210
+ return np.average(error_color_arr)
211
+
212
+
213
+ def conv3x3(in_planes, out_planes, strd=1, padding=1, bias=False):
214
+ "3x3 convolution with padding"
215
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3,
216
+ stride=strd, padding=padding, bias=bias)
217
+
218
+ def init_weights(net, init_type='normal', init_gain=0.02):
219
+ """Initialize network weights.
220
+
221
+ Parameters:
222
+ net (network) -- network to be initialized
223
+ init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
224
+ init_gain (float) -- scaling factor for normal, xavier and orthogonal.
225
+
226
+ We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
227
+ work better for some applications. Feel free to try yourself.
228
+ """
229
+
230
+ def init_func(m): # define the initialization function
231
+ classname = m.__class__.__name__
232
+ if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
233
+ if init_type == 'normal':
234
+ init.normal_(m.weight.data, 0.0, init_gain)
235
+ elif init_type == 'xavier':
236
+ init.xavier_normal_(m.weight.data, gain=init_gain)
237
+ elif init_type == 'kaiming':
238
+ init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
239
+ elif init_type == 'orthogonal':
240
+ init.orthogonal_(m.weight.data, gain=init_gain)
241
+ else:
242
+ raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
243
+ if hasattr(m, 'bias') and m.bias is not None:
244
+ init.constant_(m.bias.data, 0.0)
245
+ elif classname.find(
246
+ 'BatchNorm2d') != -1: # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
247
+ init.normal_(m.weight.data, 1.0, init_gain)
248
+ init.constant_(m.bias.data, 0.0)
249
+
250
+ print('initialize network with %s' % init_type)
251
+ net.apply(init_func) # apply the initialization function <init_func>
252
+
253
+
254
+ def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]):
255
+ """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights
256
+ Parameters:
257
+ net (network) -- the network to be initialized
258
+ init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
259
+ gain (float) -- scaling factor for normal, xavier and orthogonal.
260
+ gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
261
+
262
+ Return an initialized network.
263
+ """
264
+ if len(gpu_ids) > 0:
265
+ assert (torch.cuda.is_available())
266
+ net.to(gpu_ids[0])
267
+ net = torch.nn.DataParallel(net, gpu_ids) # multi-GPUs
268
+ init_weights(net, init_type, init_gain=init_gain)
269
+ return net
270
+
271
+
272
+ def imageSpaceRotation(xy, rot):
273
+ '''
274
+ args:
275
+ xy: (B, 2, N) input
276
+ rot: (B, 2) x,y axis rotation angles
277
+
278
+ rotation center will be always image center (other rotation center can be represented by additional z translation)
279
+ '''
280
+ disp = rot.unsqueeze(2).sin().expand_as(xy)
281
+ return (disp * xy).sum(dim=1)
282
+
283
+
284
+ def cal_gradient_penalty(netD, real_data, fake_data, device, type='mixed', constant=1.0, lambda_gp=10.0):
285
+ """Calculate the gradient penalty loss, used in WGAN-GP paper https://arxiv.org/abs/1704.00028
286
+
287
+ Arguments:
288
+ netD (network) -- discriminator network
289
+ real_data (tensor array) -- real images
290
+ fake_data (tensor array) -- generated images from the generator
291
+ device (str) -- GPU / CPU: from torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')
292
+ type (str) -- if we mix real and fake data or not [real | fake | mixed].
293
+ constant (float) -- the constant used in formula ( | |gradient||_2 - constant)^2
294
+ lambda_gp (float) -- weight for this loss
295
+
296
+ Returns the gradient penalty loss
297
+ """
298
+ if lambda_gp > 0.0:
299
+ if type == 'real': # either use real images, fake images, or a linear interpolation of two.
300
+ interpolatesv = real_data
301
+ elif type == 'fake':
302
+ interpolatesv = fake_data
303
+ elif type == 'mixed':
304
+ alpha = torch.rand(real_data.shape[0], 1)
305
+ alpha = alpha.expand(real_data.shape[0], real_data.nelement() // real_data.shape[0]).contiguous().view(
306
+ *real_data.shape)
307
+ alpha = alpha.to(device)
308
+ interpolatesv = alpha * real_data + ((1 - alpha) * fake_data)
309
+ else:
310
+ raise NotImplementedError('{} not implemented'.format(type))
311
+ interpolatesv.requires_grad_(True)
312
+ disc_interpolates = netD(interpolatesv)
313
+ gradients = torch.autograd.grad(outputs=disc_interpolates, inputs=interpolatesv,
314
+ grad_outputs=torch.ones(disc_interpolates.size()).to(device),
315
+ create_graph=True, retain_graph=True, only_inputs=True)
316
+ gradients = gradients[0].view(real_data.size(0), -1) # flat the data
317
+ gradient_penalty = (((gradients + 1e-16).norm(2, dim=1) - constant) ** 2).mean() * lambda_gp # added eps
318
+ return gradient_penalty, gradients
319
+ else:
320
+ return 0.0, None
321
+
322
+ def get_norm_layer(norm_type='instance'):
323
+ """Return a normalization layer
324
+ Parameters:
325
+ norm_type (str) -- the name of the normalization layer: batch | instance | none
326
+ For BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev).
327
+ For InstanceNorm, we do not use learnable affine parameters. We do not track running statistics.
328
+ """
329
+ if norm_type == 'batch':
330
+ norm_layer = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
331
+ elif norm_type == 'instance':
332
+ norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
333
+ elif norm_type == 'group':
334
+ norm_layer = functools.partial(nn.GroupNorm, 32)
335
+ elif norm_type == 'none':
336
+ norm_layer = None
337
+ else:
338
+ raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
339
+ return norm_layer
340
+
341
+ class Flatten(nn.Module):
342
+ def forward(self, input):
343
+ return input.view(input.size(0), -1)
344
+
345
+ class ConvBlock(nn.Module):
346
+ def __init__(self, in_planes, out_planes, norm='batch'):
347
+ super(ConvBlock, self).__init__()
348
+ self.conv1 = conv3x3(in_planes, int(out_planes / 2))
349
+ self.conv2 = conv3x3(int(out_planes / 2), int(out_planes / 4))
350
+ self.conv3 = conv3x3(int(out_planes / 4), int(out_planes / 4))
351
+
352
+ if norm == 'batch':
353
+ self.bn1 = nn.BatchNorm2d(in_planes)
354
+ self.bn2 = nn.BatchNorm2d(int(out_planes / 2))
355
+ self.bn3 = nn.BatchNorm2d(int(out_planes / 4))
356
+ self.bn4 = nn.BatchNorm2d(in_planes)
357
+ elif norm == 'group':
358
+ self.bn1 = nn.GroupNorm(32, in_planes)
359
+ self.bn2 = nn.GroupNorm(32, int(out_planes / 2))
360
+ self.bn3 = nn.GroupNorm(32, int(out_planes / 4))
361
+ self.bn4 = nn.GroupNorm(32, in_planes)
362
+
363
+ if in_planes != out_planes:
364
+ self.downsample = nn.Sequential(
365
+ self.bn4,
366
+ nn.ReLU(True),
367
+ nn.Conv2d(in_planes, out_planes,
368
+ kernel_size=1, stride=1, bias=False),
369
+ )
370
+ else:
371
+ self.downsample = None
372
+
373
+ def forward(self, x):
374
+ residual = x
375
+
376
+ out1 = self.bn1(x)
377
+ out1 = F.relu(out1, True)
378
+ out1 = self.conv1(out1)
379
+
380
+ out2 = self.bn2(out1)
381
+ out2 = F.relu(out2, True)
382
+ out2 = self.conv2(out2)
383
+
384
+ out3 = self.bn3(out2)
385
+ out3 = F.relu(out3, True)
386
+ out3 = self.conv3(out3)
387
+
388
+ out3 = torch.cat((out1, out2, out3), 1)
389
+
390
+ if self.downsample is not None:
391
+ residual = self.downsample(residual)
392
+
393
+ out3 += residual
394
+
395
+ return out3
396
+
PIFu/lib/options.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+
4
+
5
+ class BaseOptions():
6
+ def __init__(self):
7
+ self.initialized = False
8
+ argparse
9
+ def initialize(self, parser):
10
+ # Datasets related
11
+ g_data = parser.add_argument_group('Data')
12
+ g_data.add_argument('--dataroot', type=str, default='./data',
13
+ help='path to images (data folder)')
14
+
15
+ g_data.add_argument('--loadSize', type=int, default=512, help='load size of input image')
16
+
17
+ # Experiment related
18
+ g_exp = parser.add_argument_group('Experiment')
19
+ g_exp.add_argument('--name', type=str, default='example',
20
+ help='name of the experiment. It decides where to store samples and models')
21
+ g_exp.add_argument('--debug', action='store_true', help='debug mode or not')
22
+
23
+ g_exp.add_argument('--num_views', type=int, default=1, help='How many views to use for multiview network.')
24
+ g_exp.add_argument('--random_multiview', action='store_true', help='Select random multiview combination.')
25
+
26
+ # Training related
27
+ g_train = parser.add_argument_group('Training')
28
+ g_train.add_argument('--gpu_id', type=int, default=0, help='gpu id for cuda')
29
+ g_train.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2, -1 for CPU mode')
30
+
31
+ g_train.add_argument('--num_threads', default=1, type=int, help='# sthreads for loading data')
32
+ g_train.add_argument('--serial_batches', action='store_true',
33
+ help='if true, takes images in order to make batches, otherwise takes them randomly')
34
+ g_train.add_argument('--pin_memory', action='store_true', help='pin_memory')
35
+
36
+ g_train.add_argument('--batch_size', type=int, default=2, help='input batch size')
37
+ g_train.add_argument('--learning_rate', type=float, default=1e-3, help='adam learning rate')
38
+ g_train.add_argument('--learning_rateC', type=float, default=1e-3, help='adam learning rate')
39
+ g_train.add_argument('--num_epoch', type=int, default=100, help='num epoch to train')
40
+
41
+ g_train.add_argument('--freq_plot', type=int, default=10, help='freqency of the error plot')
42
+ g_train.add_argument('--freq_save', type=int, default=50, help='freqency of the save_checkpoints')
43
+ g_train.add_argument('--freq_save_ply', type=int, default=100, help='freqency of the save ply')
44
+
45
+ g_train.add_argument('--no_gen_mesh', action='store_true')
46
+ g_train.add_argument('--no_num_eval', action='store_true')
47
+
48
+ g_train.add_argument('--resume_epoch', type=int, default=-1, help='epoch resuming the training')
49
+ g_train.add_argument('--continue_train', action='store_true', help='continue training: load the latest model')
50
+
51
+ # Testing related
52
+ g_test = parser.add_argument_group('Testing')
53
+ g_test.add_argument('--resolution', type=int, default=256, help='# of grid in mesh reconstruction')
54
+ g_test.add_argument('--test_folder_path', type=str, default=None, help='the folder of test image')
55
+
56
+ # Sampling related
57
+ g_sample = parser.add_argument_group('Sampling')
58
+ g_sample.add_argument('--sigma', type=float, default=5.0, help='perturbation standard deviation for positions')
59
+
60
+ g_sample.add_argument('--num_sample_inout', type=int, default=5000, help='# of sampling points')
61
+ g_sample.add_argument('--num_sample_color', type=int, default=0, help='# of sampling points')
62
+
63
+ g_sample.add_argument('--z_size', type=float, default=200.0, help='z normalization factor')
64
+
65
+ # Model related
66
+ g_model = parser.add_argument_group('Model')
67
+ # General
68
+ g_model.add_argument('--norm', type=str, default='group',
69
+ help='instance normalization or batch normalization or group normalization')
70
+ g_model.add_argument('--norm_color', type=str, default='instance',
71
+ help='instance normalization or batch normalization or group normalization')
72
+
73
+ # hg filter specify
74
+ g_model.add_argument('--num_stack', type=int, default=4, help='# of hourglass')
75
+ g_model.add_argument('--num_hourglass', type=int, default=2, help='# of stacked layer of hourglass')
76
+ g_model.add_argument('--skip_hourglass', action='store_true', help='skip connection in hourglass')
77
+ g_model.add_argument('--hg_down', type=str, default='ave_pool', help='ave pool || conv64 || conv128')
78
+ g_model.add_argument('--hourglass_dim', type=int, default='256', help='256 | 512')
79
+
80
+ # Classification General
81
+ g_model.add_argument('--mlp_dim', nargs='+', default=[257, 1024, 512, 256, 128, 1], type=int,
82
+ help='# of dimensions of mlp')
83
+ g_model.add_argument('--mlp_dim_color', nargs='+', default=[513, 1024, 512, 256, 128, 3],
84
+ type=int, help='# of dimensions of color mlp')
85
+
86
+ g_model.add_argument('--use_tanh', action='store_true',
87
+ help='using tanh after last conv of image_filter network')
88
+
89
+ # for train
90
+ parser.add_argument('--random_flip', action='store_true', help='if random flip')
91
+ parser.add_argument('--random_trans', action='store_true', help='if random flip')
92
+ parser.add_argument('--random_scale', action='store_true', help='if random flip')
93
+ parser.add_argument('--no_residual', action='store_true', help='no skip connection in mlp')
94
+ parser.add_argument('--schedule', type=int, nargs='+', default=[60, 80],
95
+ help='Decrease learning rate at these epochs.')
96
+ parser.add_argument('--gamma', type=float, default=0.1, help='LR is multiplied by gamma on schedule.')
97
+ parser.add_argument('--color_loss_type', type=str, default='l1', help='mse | l1')
98
+
99
+ # for eval
100
+ parser.add_argument('--val_test_error', action='store_true', help='validate errors of test data')
101
+ parser.add_argument('--val_train_error', action='store_true', help='validate errors of train data')
102
+ parser.add_argument('--gen_test_mesh', action='store_true', help='generate test mesh')
103
+ parser.add_argument('--gen_train_mesh', action='store_true', help='generate train mesh')
104
+ parser.add_argument('--all_mesh', action='store_true', help='generate meshs from all hourglass output')
105
+ parser.add_argument('--num_gen_mesh_test', type=int, default=1,
106
+ help='how many meshes to generate during testing')
107
+
108
+ # path
109
+ parser.add_argument('--checkpoints_path', type=str, default='./checkpoints', help='path to save checkpoints')
110
+ parser.add_argument('--load_netG_checkpoint_path', type=str, default=None, help='path to save checkpoints')
111
+ parser.add_argument('--load_netC_checkpoint_path', type=str, default=None, help='path to save checkpoints')
112
+ parser.add_argument('--results_path', type=str, default='./results', help='path to save results ply')
113
+ parser.add_argument('--load_checkpoint_path', type=str, help='path to save results ply')
114
+ parser.add_argument('--single', type=str, default='', help='single data for training')
115
+ # for single image reconstruction
116
+ parser.add_argument('--mask_path', type=str, help='path for input mask')
117
+ parser.add_argument('--img_path', type=str, help='path for input image')
118
+
119
+ # aug
120
+ group_aug = parser.add_argument_group('aug')
121
+ group_aug.add_argument('--aug_alstd', type=float, default=0.0, help='augmentation pca lighting alpha std')
122
+ group_aug.add_argument('--aug_bri', type=float, default=0.0, help='augmentation brightness')
123
+ group_aug.add_argument('--aug_con', type=float, default=0.0, help='augmentation contrast')
124
+ group_aug.add_argument('--aug_sat', type=float, default=0.0, help='augmentation saturation')
125
+ group_aug.add_argument('--aug_hue', type=float, default=0.0, help='augmentation hue')
126
+ group_aug.add_argument('--aug_blur', type=float, default=0.0, help='augmentation blur')
127
+
128
+ # special tasks
129
+ self.initialized = True
130
+ return parser
131
+
132
+ def gather_options(self):
133
+ # initialize parser with basic options
134
+ if not self.initialized:
135
+ parser = argparse.ArgumentParser(
136
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter)
137
+ parser = self.initialize(parser)
138
+
139
+ self.parser = parser
140
+
141
+ return parser.parse_args()
142
+
143
+ def print_options(self, opt):
144
+ message = ''
145
+ message += '----------------- Options ---------------\n'
146
+ for k, v in sorted(vars(opt).items()):
147
+ comment = ''
148
+ default = self.parser.get_default(k)
149
+ if v != default:
150
+ comment = '\t[default: %s]' % str(default)
151
+ message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
152
+ message += '----------------- End -------------------'
153
+ print(message)
154
+
155
+ def parse(self):
156
+ opt = self.gather_options()
157
+ return opt
158
+
159
+ def parse_to_dict(self):
160
+ opt = self.gather_options()
161
+ return opt.__dict__
PIFu/lib/renderer/__init__.py ADDED
File without changes
PIFu/lib/renderer/camera.py ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+
4
+ from .glm import ortho
5
+
6
+
7
+ class Camera:
8
+ def __init__(self, width=1600, height=1200):
9
+ # Focal Length
10
+ # equivalent 50mm
11
+ focal = np.sqrt(width * width + height * height)
12
+ self.focal_x = focal
13
+ self.focal_y = focal
14
+ # Principal Point Offset
15
+ self.principal_x = width / 2
16
+ self.principal_y = height / 2
17
+ # Axis Skew
18
+ self.skew = 0
19
+ # Image Size
20
+ self.width = width
21
+ self.height = height
22
+
23
+ self.near = 1
24
+ self.far = 10
25
+
26
+ # Camera Center
27
+ self.center = np.array([0, 0, 1.6])
28
+ self.direction = np.array([0, 0, -1])
29
+ self.right = np.array([1, 0, 0])
30
+ self.up = np.array([0, 1, 0])
31
+
32
+ self.ortho_ratio = None
33
+
34
+ def sanity_check(self):
35
+ self.center = self.center.reshape([-1])
36
+ self.direction = self.direction.reshape([-1])
37
+ self.right = self.right.reshape([-1])
38
+ self.up = self.up.reshape([-1])
39
+
40
+ assert len(self.center) == 3
41
+ assert len(self.direction) == 3
42
+ assert len(self.right) == 3
43
+ assert len(self.up) == 3
44
+
45
+ @staticmethod
46
+ def normalize_vector(v):
47
+ v_norm = np.linalg.norm(v)
48
+ return v if v_norm == 0 else v / v_norm
49
+
50
+ def get_real_z_value(self, z):
51
+ z_near = self.near
52
+ z_far = self.far
53
+ z_n = 2.0 * z - 1.0
54
+ z_e = 2.0 * z_near * z_far / (z_far + z_near - z_n * (z_far - z_near))
55
+ return z_e
56
+
57
+ def get_rotation_matrix(self):
58
+ rot_mat = np.eye(3)
59
+ s = self.right
60
+ s = self.normalize_vector(s)
61
+ rot_mat[0, :] = s
62
+ u = self.up
63
+ u = self.normalize_vector(u)
64
+ rot_mat[1, :] = -u
65
+ rot_mat[2, :] = self.normalize_vector(self.direction)
66
+
67
+ return rot_mat
68
+
69
+ def get_translation_vector(self):
70
+ rot_mat = self.get_rotation_matrix()
71
+ trans = -np.dot(rot_mat, self.center)
72
+ return trans
73
+
74
+ def get_intrinsic_matrix(self):
75
+ int_mat = np.eye(3)
76
+
77
+ int_mat[0, 0] = self.focal_x
78
+ int_mat[1, 1] = self.focal_y
79
+ int_mat[0, 1] = self.skew
80
+ int_mat[0, 2] = self.principal_x
81
+ int_mat[1, 2] = self.principal_y
82
+
83
+ return int_mat
84
+
85
+ def get_projection_matrix(self):
86
+ ext_mat = self.get_extrinsic_matrix()
87
+ int_mat = self.get_intrinsic_matrix()
88
+
89
+ return np.matmul(int_mat, ext_mat)
90
+
91
+ def get_extrinsic_matrix(self):
92
+ rot_mat = self.get_rotation_matrix()
93
+ int_mat = self.get_intrinsic_matrix()
94
+ trans = self.get_translation_vector()
95
+
96
+ extrinsic = np.eye(4)
97
+ extrinsic[:3, :3] = rot_mat
98
+ extrinsic[:3, 3] = trans
99
+
100
+ return extrinsic[:3, :]
101
+
102
+ def set_rotation_matrix(self, rot_mat):
103
+ self.direction = rot_mat[2, :]
104
+ self.up = -rot_mat[1, :]
105
+ self.right = rot_mat[0, :]
106
+
107
+ def set_intrinsic_matrix(self, int_mat):
108
+ self.focal_x = int_mat[0, 0]
109
+ self.focal_y = int_mat[1, 1]
110
+ self.skew = int_mat[0, 1]
111
+ self.principal_x = int_mat[0, 2]
112
+ self.principal_y = int_mat[1, 2]
113
+
114
+ def set_projection_matrix(self, proj_mat):
115
+ res = cv2.decomposeProjectionMatrix(proj_mat)
116
+ int_mat, rot_mat, camera_center_homo = res[0], res[1], res[2]
117
+ camera_center = camera_center_homo[0:3] / camera_center_homo[3]
118
+ camera_center = camera_center.reshape(-1)
119
+ int_mat = int_mat / int_mat[2][2]
120
+
121
+ self.set_intrinsic_matrix(int_mat)
122
+ self.set_rotation_matrix(rot_mat)
123
+ self.center = camera_center
124
+
125
+ self.sanity_check()
126
+
127
+ def get_gl_matrix(self):
128
+ z_near = self.near
129
+ z_far = self.far
130
+ rot_mat = self.get_rotation_matrix()
131
+ int_mat = self.get_intrinsic_matrix()
132
+ trans = self.get_translation_vector()
133
+
134
+ extrinsic = np.eye(4)
135
+ extrinsic[:3, :3] = rot_mat
136
+ extrinsic[:3, 3] = trans
137
+ axis_adj = np.eye(4)
138
+ axis_adj[2, 2] = -1
139
+ axis_adj[1, 1] = -1
140
+ model_view = np.matmul(axis_adj, extrinsic)
141
+
142
+ projective = np.zeros([4, 4])
143
+ projective[:2, :2] = int_mat[:2, :2]
144
+ projective[:2, 2:3] = -int_mat[:2, 2:3]
145
+ projective[3, 2] = -1
146
+ projective[2, 2] = (z_near + z_far)
147
+ projective[2, 3] = (z_near * z_far)
148
+
149
+ if self.ortho_ratio is None:
150
+ ndc = ortho(0, self.width, 0, self.height, z_near, z_far)
151
+ perspective = np.matmul(ndc, projective)
152
+ else:
153
+ perspective = ortho(-self.width * self.ortho_ratio / 2, self.width * self.ortho_ratio / 2,
154
+ -self.height * self.ortho_ratio / 2, self.height * self.ortho_ratio / 2,
155
+ z_near, z_far)
156
+
157
+ return perspective, model_view
158
+
159
+
160
+ def KRT_from_P(proj_mat, normalize_K=True):
161
+ res = cv2.decomposeProjectionMatrix(proj_mat)
162
+ K, Rot, camera_center_homog = res[0], res[1], res[2]
163
+ camera_center = camera_center_homog[0:3] / camera_center_homog[3]
164
+ trans = -Rot.dot(camera_center)
165
+ if normalize_K:
166
+ K = K / K[2][2]
167
+ return K, Rot, trans
168
+
169
+
170
+ def MVP_from_P(proj_mat, width, height, near=0.1, far=10000):
171
+ '''
172
+ Convert OpenCV camera calibration matrix to OpenGL projection and model view matrix
173
+ :param proj_mat: OpenCV camera projeciton matrix
174
+ :param width: Image width
175
+ :param height: Image height
176
+ :param near: Z near value
177
+ :param far: Z far value
178
+ :return: OpenGL projection matrix and model view matrix
179
+ '''
180
+ res = cv2.decomposeProjectionMatrix(proj_mat)
181
+ K, Rot, camera_center_homog = res[0], res[1], res[2]
182
+ camera_center = camera_center_homog[0:3] / camera_center_homog[3]
183
+ trans = -Rot.dot(camera_center)
184
+ K = K / K[2][2]
185
+
186
+ extrinsic = np.eye(4)
187
+ extrinsic[:3, :3] = Rot
188
+ extrinsic[:3, 3:4] = trans
189
+ axis_adj = np.eye(4)
190
+ axis_adj[2, 2] = -1
191
+ axis_adj[1, 1] = -1
192
+ model_view = np.matmul(axis_adj, extrinsic)
193
+
194
+ zFar = far
195
+ zNear = near
196
+ projective = np.zeros([4, 4])
197
+ projective[:2, :2] = K[:2, :2]
198
+ projective[:2, 2:3] = -K[:2, 2:3]
199
+ projective[3, 2] = -1
200
+ projective[2, 2] = (zNear + zFar)
201
+ projective[2, 3] = (zNear * zFar)
202
+
203
+ ndc = ortho(0, width, 0, height, zNear, zFar)
204
+
205
+ perspective = np.matmul(ndc, projective)
206
+
207
+ return perspective, model_view
PIFu/lib/renderer/gl/__init__.py ADDED
File without changes
PIFu/lib/renderer/gl/cam_render.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .render import Render
2
+
3
+ GLUT = None
4
+
5
+ class CamRender(Render):
6
+ def __init__(self, width=1600, height=1200, name='Cam Renderer',
7
+ program_files=['simple.fs', 'simple.vs'], color_size=1, ms_rate=1, egl=False):
8
+ Render.__init__(self, width, height, name, program_files, color_size, ms_rate=ms_rate, egl=egl)
9
+ self.camera = None
10
+
11
+ if not egl:
12
+ global GLUT
13
+ import OpenGL.GLUT as GLUT
14
+ GLUT.glutDisplayFunc(self.display)
15
+ GLUT.glutKeyboardFunc(self.keyboard)
16
+
17
+ def set_camera(self, camera):
18
+ self.camera = camera
19
+ self.projection_matrix, self.model_view_matrix = camera.get_gl_matrix()
20
+
21
+ def keyboard(self, key, x, y):
22
+ # up
23
+ eps = 1
24
+ # print(key)
25
+ if key == b'w':
26
+ self.camera.center += eps * self.camera.direction
27
+ elif key == b's':
28
+ self.camera.center -= eps * self.camera.direction
29
+ if key == b'a':
30
+ self.camera.center -= eps * self.camera.right
31
+ elif key == b'd':
32
+ self.camera.center += eps * self.camera.right
33
+ if key == b' ':
34
+ self.camera.center += eps * self.camera.up
35
+ elif key == b'x':
36
+ self.camera.center -= eps * self.camera.up
37
+ elif key == b'i':
38
+ self.camera.near += 0.1 * eps
39
+ self.camera.far += 0.1 * eps
40
+ elif key == b'o':
41
+ self.camera.near -= 0.1 * eps
42
+ self.camera.far -= 0.1 * eps
43
+
44
+ self.projection_matrix, self.model_view_matrix = self.camera.get_gl_matrix()
45
+
46
+ def show(self):
47
+ if GLUT is not None:
48
+ GLUT.glutMainLoop()
PIFu/lib/renderer/gl/data/prt.fs ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #version 330
2
+
3
+ uniform vec3 SHCoeffs[9];
4
+ uniform uint analytic;
5
+
6
+ uniform uint hasNormalMap;
7
+ uniform uint hasAlbedoMap;
8
+
9
+ uniform sampler2D AlbedoMap;
10
+ uniform sampler2D NormalMap;
11
+
12
+ in VertexData {
13
+ vec3 Position;
14
+ vec3 Depth;
15
+ vec3 ModelNormal;
16
+ vec2 Texcoord;
17
+ vec3 Tangent;
18
+ vec3 Bitangent;
19
+ vec3 PRT1;
20
+ vec3 PRT2;
21
+ vec3 PRT3;
22
+ } VertexIn;
23
+
24
+ layout (location = 0) out vec4 FragColor;
25
+ layout (location = 1) out vec4 FragNormal;
26
+ layout (location = 2) out vec4 FragPosition;
27
+ layout (location = 3) out vec4 FragAlbedo;
28
+ layout (location = 4) out vec4 FragShading;
29
+ layout (location = 5) out vec4 FragPRT1;
30
+ layout (location = 6) out vec4 FragPRT2;
31
+ layout (location = 7) out vec4 FragPRT3;
32
+
33
+ vec4 gammaCorrection(vec4 vec, float g)
34
+ {
35
+ return vec4(pow(vec.x, 1.0/g), pow(vec.y, 1.0/g), pow(vec.z, 1.0/g), vec.w);
36
+ }
37
+
38
+ vec3 gammaCorrection(vec3 vec, float g)
39
+ {
40
+ return vec3(pow(vec.x, 1.0/g), pow(vec.y, 1.0/g), pow(vec.z, 1.0/g));
41
+ }
42
+
43
+ void evaluateH(vec3 n, out float H[9])
44
+ {
45
+ float c1 = 0.429043, c2 = 0.511664,
46
+ c3 = 0.743125, c4 = 0.886227, c5 = 0.247708;
47
+
48
+ H[0] = c4;
49
+ H[1] = 2.0 * c2 * n[1];
50
+ H[2] = 2.0 * c2 * n[2];
51
+ H[3] = 2.0 * c2 * n[0];
52
+ H[4] = 2.0 * c1 * n[0] * n[1];
53
+ H[5] = 2.0 * c1 * n[1] * n[2];
54
+ H[6] = c3 * n[2] * n[2] - c5;
55
+ H[7] = 2.0 * c1 * n[2] * n[0];
56
+ H[8] = c1 * (n[0] * n[0] - n[1] * n[1]);
57
+ }
58
+
59
+ vec3 evaluateLightingModel(vec3 normal)
60
+ {
61
+ float H[9];
62
+ evaluateH(normal, H);
63
+ vec3 res = vec3(0.0);
64
+ for (int i = 0; i < 9; i++) {
65
+ res += H[i] * SHCoeffs[i];
66
+ }
67
+ return res;
68
+ }
69
+
70
+ // nC: coarse geometry normal, nH: fine normal from normal map
71
+ vec3 evaluateLightingModelHybrid(vec3 nC, vec3 nH, mat3 prt)
72
+ {
73
+ float HC[9], HH[9];
74
+ evaluateH(nC, HC);
75
+ evaluateH(nH, HH);
76
+
77
+ vec3 res = vec3(0.0);
78
+ vec3 shadow = vec3(0.0);
79
+ vec3 unshadow = vec3(0.0);
80
+ for(int i = 0; i < 3; ++i){
81
+ for(int j = 0; j < 3; ++j){
82
+ int id = i*3+j;
83
+ res += HH[id]* SHCoeffs[id];
84
+ shadow += prt[i][j] * SHCoeffs[id];
85
+ unshadow += HC[id] * SHCoeffs[id];
86
+ }
87
+ }
88
+ vec3 ratio = clamp(shadow/unshadow,0.0,1.0);
89
+ res = ratio * res;
90
+
91
+ return res;
92
+ }
93
+
94
+ vec3 evaluateLightingModelPRT(mat3 prt)
95
+ {
96
+ vec3 res = vec3(0.0);
97
+ for(int i = 0; i < 3; ++i){
98
+ for(int j = 0; j < 3; ++j){
99
+ res += prt[i][j] * SHCoeffs[i*3+j];
100
+ }
101
+ }
102
+
103
+ return res;
104
+ }
105
+
106
+ void main()
107
+ {
108
+ vec2 uv = VertexIn.Texcoord;
109
+ vec3 nC = normalize(VertexIn.ModelNormal);
110
+ vec3 nml = nC;
111
+ mat3 prt = mat3(VertexIn.PRT1, VertexIn.PRT2, VertexIn.PRT3);
112
+
113
+ if(hasAlbedoMap == uint(0))
114
+ FragAlbedo = vec4(1.0);
115
+ else
116
+ FragAlbedo = texture(AlbedoMap, uv);//gammaCorrection(texture(AlbedoMap, uv), 1.0/2.2);
117
+
118
+ if(hasNormalMap == uint(0))
119
+ {
120
+ if(analytic == uint(0))
121
+ FragShading = vec4(evaluateLightingModelPRT(prt), 1.0f);
122
+ else
123
+ FragShading = vec4(evaluateLightingModel(nC), 1.0f);
124
+ }
125
+ else
126
+ {
127
+ vec3 n_tan = normalize(texture(NormalMap, uv).rgb*2.0-vec3(1.0));
128
+
129
+ mat3 TBN = mat3(normalize(VertexIn.Tangent),normalize(VertexIn.Bitangent),nC);
130
+ vec3 nH = normalize(TBN * n_tan);
131
+
132
+ if(analytic == uint(0))
133
+ FragShading = vec4(evaluateLightingModelHybrid(nC,nH,prt),1.0f);
134
+ else
135
+ FragShading = vec4(evaluateLightingModel(nH), 1.0f);
136
+
137
+ nml = nH;
138
+ }
139
+
140
+ FragShading = gammaCorrection(FragShading, 2.2);
141
+ FragColor = clamp(FragAlbedo * FragShading, 0.0, 1.0);
142
+ FragNormal = vec4(0.5*(nml+vec3(1.0)), 1.0);
143
+ FragPosition = vec4(VertexIn.Position,VertexIn.Depth.x);
144
+ FragShading = vec4(clamp(0.5*FragShading.xyz, 0.0, 1.0),1.0);
145
+ // FragColor = gammaCorrection(clamp(FragAlbedo * FragShading, 0.0, 1.0),2.2);
146
+ // FragNormal = vec4(0.5*(nml+vec3(1.0)), 1.0);
147
+ // FragPosition = vec4(VertexIn.Position,VertexIn.Depth.x);
148
+ // FragShading = vec4(gammaCorrection(clamp(0.5*FragShading.xyz, 0.0, 1.0),2.2),1.0);
149
+ // FragAlbedo = gammaCorrection(FragAlbedo,2.2);
150
+ FragPRT1 = vec4(VertexIn.PRT1,1.0);
151
+ FragPRT2 = vec4(VertexIn.PRT2,1.0);
152
+ FragPRT3 = vec4(VertexIn.PRT3,1.0);
153
+ }
PIFu/lib/renderer/gl/data/prt.vs ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #version 330
2
+
3
+ layout (location = 0) in vec3 a_Position;
4
+ layout (location = 1) in vec3 a_Normal;
5
+ layout (location = 2) in vec2 a_TextureCoord;
6
+ layout (location = 3) in vec3 a_Tangent;
7
+ layout (location = 4) in vec3 a_Bitangent;
8
+ layout (location = 5) in vec3 a_PRT1;
9
+ layout (location = 6) in vec3 a_PRT2;
10
+ layout (location = 7) in vec3 a_PRT3;
11
+
12
+ out VertexData {
13
+ vec3 Position;
14
+ vec3 Depth;
15
+ vec3 ModelNormal;
16
+ vec2 Texcoord;
17
+ vec3 Tangent;
18
+ vec3 Bitangent;
19
+ vec3 PRT1;
20
+ vec3 PRT2;
21
+ vec3 PRT3;
22
+ } VertexOut;
23
+
24
+ uniform mat3 RotMat;
25
+ uniform mat4 NormMat;
26
+ uniform mat4 ModelMat;
27
+ uniform mat4 PerspMat;
28
+
29
+ float s_c3 = 0.94617469575; // (3*sqrt(5))/(4*sqrt(pi))
30
+ float s_c4 = -0.31539156525;// (-sqrt(5))/(4*sqrt(pi))
31
+ float s_c5 = 0.54627421529; // (sqrt(15))/(4*sqrt(pi))
32
+
33
+ float s_c_scale = 1.0/0.91529123286551084;
34
+ float s_c_scale_inv = 0.91529123286551084;
35
+
36
+ float s_rc2 = 1.5853309190550713*s_c_scale;
37
+ float s_c4_div_c3 = s_c4/s_c3;
38
+ float s_c4_div_c3_x2 = (s_c4/s_c3)*2.0;
39
+
40
+ float s_scale_dst2 = s_c3 * s_c_scale_inv;
41
+ float s_scale_dst4 = s_c5 * s_c_scale_inv;
42
+
43
+ void OptRotateBand0(float x[1], mat3 R, out float dst[1])
44
+ {
45
+ dst[0] = x[0];
46
+ }
47
+
48
+ // 9 multiplies
49
+ void OptRotateBand1(float x[3], mat3 R, out float dst[3])
50
+ {
51
+ // derived from SlowRotateBand1
52
+ dst[0] = ( R[1][1])*x[0] + (-R[1][2])*x[1] + ( R[1][0])*x[2];
53
+ dst[1] = (-R[2][1])*x[0] + ( R[2][2])*x[1] + (-R[2][0])*x[2];
54
+ dst[2] = ( R[0][1])*x[0] + (-R[0][2])*x[1] + ( R[0][0])*x[2];
55
+ }
56
+
57
+ // 48 multiplies
58
+ void OptRotateBand2(float x[5], mat3 R, out float dst[5])
59
+ {
60
+ // Sparse matrix multiply
61
+ float sh0 = x[3] + x[4] + x[4] - x[1];
62
+ float sh1 = x[0] + s_rc2*x[2] + x[3] + x[4];
63
+ float sh2 = x[0];
64
+ float sh3 = -x[3];
65
+ float sh4 = -x[1];
66
+
67
+ // Rotations. R0 and R1 just use the raw matrix columns
68
+ float r2x = R[0][0] + R[0][1];
69
+ float r2y = R[1][0] + R[1][1];
70
+ float r2z = R[2][0] + R[2][1];
71
+
72
+ float r3x = R[0][0] + R[0][2];
73
+ float r3y = R[1][0] + R[1][2];
74
+ float r3z = R[2][0] + R[2][2];
75
+
76
+ float r4x = R[0][1] + R[0][2];
77
+ float r4y = R[1][1] + R[1][2];
78
+ float r4z = R[2][1] + R[2][2];
79
+
80
+ // dense matrix multiplication one column at a time
81
+
82
+ // column 0
83
+ float sh0_x = sh0 * R[0][0];
84
+ float sh0_y = sh0 * R[1][0];
85
+ float d0 = sh0_x * R[1][0];
86
+ float d1 = sh0_y * R[2][0];
87
+ float d2 = sh0 * (R[2][0] * R[2][0] + s_c4_div_c3);
88
+ float d3 = sh0_x * R[2][0];
89
+ float d4 = sh0_x * R[0][0] - sh0_y * R[1][0];
90
+
91
+ // column 1
92
+ float sh1_x = sh1 * R[0][2];
93
+ float sh1_y = sh1 * R[1][2];
94
+ d0 += sh1_x * R[1][2];
95
+ d1 += sh1_y * R[2][2];
96
+ d2 += sh1 * (R[2][2] * R[2][2] + s_c4_div_c3);
97
+ d3 += sh1_x * R[2][2];
98
+ d4 += sh1_x * R[0][2] - sh1_y * R[1][2];
99
+
100
+ // column 2
101
+ float sh2_x = sh2 * r2x;
102
+ float sh2_y = sh2 * r2y;
103
+ d0 += sh2_x * r2y;
104
+ d1 += sh2_y * r2z;
105
+ d2 += sh2 * (r2z * r2z + s_c4_div_c3_x2);
106
+ d3 += sh2_x * r2z;
107
+ d4 += sh2_x * r2x - sh2_y * r2y;
108
+
109
+ // column 3
110
+ float sh3_x = sh3 * r3x;
111
+ float sh3_y = sh3 * r3y;
112
+ d0 += sh3_x * r3y;
113
+ d1 += sh3_y * r3z;
114
+ d2 += sh3 * (r3z * r3z + s_c4_div_c3_x2);
115
+ d3 += sh3_x * r3z;
116
+ d4 += sh3_x * r3x - sh3_y * r3y;
117
+
118
+ // column 4
119
+ float sh4_x = sh4 * r4x;
120
+ float sh4_y = sh4 * r4y;
121
+ d0 += sh4_x * r4y;
122
+ d1 += sh4_y * r4z;
123
+ d2 += sh4 * (r4z * r4z + s_c4_div_c3_x2);
124
+ d3 += sh4_x * r4z;
125
+ d4 += sh4_x * r4x - sh4_y * r4y;
126
+
127
+ // extra multipliers
128
+ dst[0] = d0;
129
+ dst[1] = -d1;
130
+ dst[2] = d2 * s_scale_dst2;
131
+ dst[3] = -d3;
132
+ dst[4] = d4 * s_scale_dst4;
133
+ }
134
+
135
+ void main()
136
+ {
137
+ // normalization
138
+ vec3 pos = (NormMat * vec4(a_Position,1.0)).xyz;
139
+
140
+ mat3 R = mat3(ModelMat) * RotMat;
141
+ VertexOut.ModelNormal = (R * a_Normal);
142
+ VertexOut.Position = R * pos;
143
+ VertexOut.Texcoord = a_TextureCoord;
144
+ VertexOut.Tangent = (R * a_Tangent);
145
+ VertexOut.Bitangent = (R * a_Bitangent);
146
+ float PRT0, PRT1[3], PRT2[5];
147
+ PRT0 = a_PRT1[0];
148
+ PRT1[0] = a_PRT1[1];
149
+ PRT1[1] = a_PRT1[2];
150
+ PRT1[2] = a_PRT2[0];
151
+ PRT2[0] = a_PRT2[1];
152
+ PRT2[1] = a_PRT2[2];
153
+ PRT2[2] = a_PRT3[0];
154
+ PRT2[3] = a_PRT3[1];
155
+ PRT2[4] = a_PRT3[2];
156
+
157
+ OptRotateBand1(PRT1, R, PRT1);
158
+ OptRotateBand2(PRT2, R, PRT2);
159
+
160
+ VertexOut.PRT1 = vec3(PRT0,PRT1[0],PRT1[1]);
161
+ VertexOut.PRT2 = vec3(PRT1[2],PRT2[0],PRT2[1]);
162
+ VertexOut.PRT3 = vec3(PRT2[2],PRT2[3],PRT2[4]);
163
+
164
+ gl_Position = PerspMat * ModelMat * vec4(RotMat * pos, 1.0);
165
+
166
+ VertexOut.Depth = vec3(gl_Position.z / gl_Position.w);
167
+ }
PIFu/lib/renderer/gl/data/prt_uv.fs ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #version 330
2
+
3
+ uniform vec3 SHCoeffs[9];
4
+ uniform uint analytic;
5
+
6
+ uniform uint hasNormalMap;
7
+ uniform uint hasAlbedoMap;
8
+
9
+ uniform sampler2D AlbedoMap;
10
+ uniform sampler2D NormalMap;
11
+
12
+ in VertexData {
13
+ vec3 Position;
14
+ vec3 ModelNormal;
15
+ vec3 CameraNormal;
16
+ vec2 Texcoord;
17
+ vec3 Tangent;
18
+ vec3 Bitangent;
19
+ vec3 PRT1;
20
+ vec3 PRT2;
21
+ vec3 PRT3;
22
+ } VertexIn;
23
+
24
+ layout (location = 0) out vec4 FragColor;
25
+ layout (location = 1) out vec4 FragPosition;
26
+ layout (location = 2) out vec4 FragNormal;
27
+
28
+ vec4 gammaCorrection(vec4 vec, float g)
29
+ {
30
+ return vec4(pow(vec.x, 1.0/g), pow(vec.y, 1.0/g), pow(vec.z, 1.0/g), vec.w);
31
+ }
32
+
33
+ vec3 gammaCorrection(vec3 vec, float g)
34
+ {
35
+ return vec3(pow(vec.x, 1.0/g), pow(vec.y, 1.0/g), pow(vec.z, 1.0/g));
36
+ }
37
+
38
+ void evaluateH(vec3 n, out float H[9])
39
+ {
40
+ float c1 = 0.429043, c2 = 0.511664,
41
+ c3 = 0.743125, c4 = 0.886227, c5 = 0.247708;
42
+
43
+ H[0] = c4;
44
+ H[1] = 2.0 * c2 * n[1];
45
+ H[2] = 2.0 * c2 * n[2];
46
+ H[3] = 2.0 * c2 * n[0];
47
+ H[4] = 2.0 * c1 * n[0] * n[1];
48
+ H[5] = 2.0 * c1 * n[1] * n[2];
49
+ H[6] = c3 * n[2] * n[2] - c5;
50
+ H[7] = 2.0 * c1 * n[2] * n[0];
51
+ H[8] = c1 * (n[0] * n[0] - n[1] * n[1]);
52
+ }
53
+
54
+ vec3 evaluateLightingModel(vec3 normal)
55
+ {
56
+ float H[9];
57
+ evaluateH(normal, H);
58
+ vec3 res = vec3(0.0);
59
+ for (int i = 0; i < 9; i++) {
60
+ res += H[i] * SHCoeffs[i];
61
+ }
62
+ return res;
63
+ }
64
+
65
+ // nC: coarse geometry normal, nH: fine normal from normal map
66
+ vec3 evaluateLightingModelHybrid(vec3 nC, vec3 nH, mat3 prt)
67
+ {
68
+ float HC[9], HH[9];
69
+ evaluateH(nC, HC);
70
+ evaluateH(nH, HH);
71
+
72
+ vec3 res = vec3(0.0);
73
+ vec3 shadow = vec3(0.0);
74
+ vec3 unshadow = vec3(0.0);
75
+ for(int i = 0; i < 3; ++i){
76
+ for(int j = 0; j < 3; ++j){
77
+ int id = i*3+j;
78
+ res += HH[id]* SHCoeffs[id];
79
+ shadow += prt[i][j] * SHCoeffs[id];
80
+ unshadow += HC[id] * SHCoeffs[id];
81
+ }
82
+ }
83
+ vec3 ratio = clamp(shadow/unshadow,0.0,1.0);
84
+ res = ratio * res;
85
+
86
+ return res;
87
+ }
88
+
89
+ vec3 evaluateLightingModelPRT(mat3 prt)
90
+ {
91
+ vec3 res = vec3(0.0);
92
+ for(int i = 0; i < 3; ++i){
93
+ for(int j = 0; j < 3; ++j){
94
+ res += prt[i][j] * SHCoeffs[i*3+j];
95
+ }
96
+ }
97
+
98
+ return res;
99
+ }
100
+
101
+ void main()
102
+ {
103
+ vec2 uv = VertexIn.Texcoord;
104
+ vec3 nM = normalize(VertexIn.ModelNormal);
105
+ vec3 nC = normalize(VertexIn.CameraNormal);
106
+ vec3 nml = nC;
107
+ mat3 prt = mat3(VertexIn.PRT1, VertexIn.PRT2, VertexIn.PRT3);
108
+
109
+ vec4 albedo, shading;
110
+ if(hasAlbedoMap == uint(0))
111
+ albedo = vec4(1.0);
112
+ else
113
+ albedo = texture(AlbedoMap, uv);//gammaCorrection(texture(AlbedoMap, uv), 1.0/2.2);
114
+
115
+ if(hasNormalMap == uint(0))
116
+ {
117
+ if(analytic == uint(0))
118
+ shading = vec4(evaluateLightingModelPRT(prt), 1.0f);
119
+ else
120
+ shading = vec4(evaluateLightingModel(nC), 1.0f);
121
+ }
122
+ else
123
+ {
124
+ vec3 n_tan = normalize(texture(NormalMap, uv).rgb*2.0-vec3(1.0));
125
+
126
+ mat3 TBN = mat3(normalize(VertexIn.Tangent),normalize(VertexIn.Bitangent),nC);
127
+ vec3 nH = normalize(TBN * n_tan);
128
+
129
+ if(analytic == uint(0))
130
+ shading = vec4(evaluateLightingModelHybrid(nC,nH,prt),1.0f);
131
+ else
132
+ shading = vec4(evaluateLightingModel(nH), 1.0f);
133
+
134
+ nml = nH;
135
+ }
136
+
137
+ shading = gammaCorrection(shading, 2.2);
138
+ FragColor = clamp(albedo * shading, 0.0, 1.0);
139
+ FragPosition = vec4(VertexIn.Position,1.0);
140
+ FragNormal = vec4(0.5*(nM+vec3(1.0)),1.0);
141
+ }
PIFu/lib/renderer/gl/data/prt_uv.vs ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #version 330
2
+
3
+ layout (location = 0) in vec3 a_Position;
4
+ layout (location = 1) in vec3 a_Normal;
5
+ layout (location = 2) in vec2 a_TextureCoord;
6
+ layout (location = 3) in vec3 a_Tangent;
7
+ layout (location = 4) in vec3 a_Bitangent;
8
+ layout (location = 5) in vec3 a_PRT1;
9
+ layout (location = 6) in vec3 a_PRT2;
10
+ layout (location = 7) in vec3 a_PRT3;
11
+
12
+ out VertexData {
13
+ vec3 Position;
14
+ vec3 ModelNormal;
15
+ vec3 CameraNormal;
16
+ vec2 Texcoord;
17
+ vec3 Tangent;
18
+ vec3 Bitangent;
19
+ vec3 PRT1;
20
+ vec3 PRT2;
21
+ vec3 PRT3;
22
+ } VertexOut;
23
+
24
+ uniform mat3 RotMat;
25
+ uniform mat4 NormMat;
26
+ uniform mat4 ModelMat;
27
+ uniform mat4 PerspMat;
28
+
29
+ #define pi 3.1415926535897932384626433832795
30
+
31
+ float s_c3 = 0.94617469575; // (3*sqrt(5))/(4*sqrt(pi))
32
+ float s_c4 = -0.31539156525;// (-sqrt(5))/(4*sqrt(pi))
33
+ float s_c5 = 0.54627421529; // (sqrt(15))/(4*sqrt(pi))
34
+
35
+ float s_c_scale = 1.0/0.91529123286551084;
36
+ float s_c_scale_inv = 0.91529123286551084;
37
+
38
+ float s_rc2 = 1.5853309190550713*s_c_scale;
39
+ float s_c4_div_c3 = s_c4/s_c3;
40
+ float s_c4_div_c3_x2 = (s_c4/s_c3)*2.0;
41
+
42
+ float s_scale_dst2 = s_c3 * s_c_scale_inv;
43
+ float s_scale_dst4 = s_c5 * s_c_scale_inv;
44
+
45
+ void OptRotateBand0(float x[1], mat3 R, out float dst[1])
46
+ {
47
+ dst[0] = x[0];
48
+ }
49
+
50
+ // 9 multiplies
51
+ void OptRotateBand1(float x[3], mat3 R, out float dst[3])
52
+ {
53
+ // derived from SlowRotateBand1
54
+ dst[0] = ( R[1][1])*x[0] + (-R[1][2])*x[1] + ( R[1][0])*x[2];
55
+ dst[1] = (-R[2][1])*x[0] + ( R[2][2])*x[1] + (-R[2][0])*x[2];
56
+ dst[2] = ( R[0][1])*x[0] + (-R[0][2])*x[1] + ( R[0][0])*x[2];
57
+ }
58
+
59
+ // 48 multiplies
60
+ void OptRotateBand2(float x[5], mat3 R, out float dst[5])
61
+ {
62
+ // Sparse matrix multiply
63
+ float sh0 = x[3] + x[4] + x[4] - x[1];
64
+ float sh1 = x[0] + s_rc2*x[2] + x[3] + x[4];
65
+ float sh2 = x[0];
66
+ float sh3 = -x[3];
67
+ float sh4 = -x[1];
68
+
69
+ // Rotations. R0 and R1 just use the raw matrix columns
70
+ float r2x = R[0][0] + R[0][1];
71
+ float r2y = R[1][0] + R[1][1];
72
+ float r2z = R[2][0] + R[2][1];
73
+
74
+ float r3x = R[0][0] + R[0][2];
75
+ float r3y = R[1][0] + R[1][2];
76
+ float r3z = R[2][0] + R[2][2];
77
+
78
+ float r4x = R[0][1] + R[0][2];
79
+ float r4y = R[1][1] + R[1][2];
80
+ float r4z = R[2][1] + R[2][2];
81
+
82
+ // dense matrix multiplication one column at a time
83
+
84
+ // column 0
85
+ float sh0_x = sh0 * R[0][0];
86
+ float sh0_y = sh0 * R[1][0];
87
+ float d0 = sh0_x * R[1][0];
88
+ float d1 = sh0_y * R[2][0];
89
+ float d2 = sh0 * (R[2][0] * R[2][0] + s_c4_div_c3);
90
+ float d3 = sh0_x * R[2][0];
91
+ float d4 = sh0_x * R[0][0] - sh0_y * R[1][0];
92
+
93
+ // column 1
94
+ float sh1_x = sh1 * R[0][2];
95
+ float sh1_y = sh1 * R[1][2];
96
+ d0 += sh1_x * R[1][2];
97
+ d1 += sh1_y * R[2][2];
98
+ d2 += sh1 * (R[2][2] * R[2][2] + s_c4_div_c3);
99
+ d3 += sh1_x * R[2][2];
100
+ d4 += sh1_x * R[0][2] - sh1_y * R[1][2];
101
+
102
+ // column 2
103
+ float sh2_x = sh2 * r2x;
104
+ float sh2_y = sh2 * r2y;
105
+ d0 += sh2_x * r2y;
106
+ d1 += sh2_y * r2z;
107
+ d2 += sh2 * (r2z * r2z + s_c4_div_c3_x2);
108
+ d3 += sh2_x * r2z;
109
+ d4 += sh2_x * r2x - sh2_y * r2y;
110
+
111
+ // column 3
112
+ float sh3_x = sh3 * r3x;
113
+ float sh3_y = sh3 * r3y;
114
+ d0 += sh3_x * r3y;
115
+ d1 += sh3_y * r3z;
116
+ d2 += sh3 * (r3z * r3z + s_c4_div_c3_x2);
117
+ d3 += sh3_x * r3z;
118
+ d4 += sh3_x * r3x - sh3_y * r3y;
119
+
120
+ // column 4
121
+ float sh4_x = sh4 * r4x;
122
+ float sh4_y = sh4 * r4y;
123
+ d0 += sh4_x * r4y;
124
+ d1 += sh4_y * r4z;
125
+ d2 += sh4 * (r4z * r4z + s_c4_div_c3_x2);
126
+ d3 += sh4_x * r4z;
127
+ d4 += sh4_x * r4x - sh4_y * r4y;
128
+
129
+ // extra multipliers
130
+ dst[0] = d0;
131
+ dst[1] = -d1;
132
+ dst[2] = d2 * s_scale_dst2;
133
+ dst[3] = -d3;
134
+ dst[4] = d4 * s_scale_dst4;
135
+ }
136
+
137
+ void main()
138
+ {
139
+ // normalization
140
+ mat3 R = mat3(ModelMat) * RotMat;
141
+ VertexOut.ModelNormal = a_Normal;
142
+ VertexOut.CameraNormal = (R * a_Normal);
143
+ VertexOut.Position = a_Position;
144
+ VertexOut.Texcoord = a_TextureCoord;
145
+ VertexOut.Tangent = (R * a_Tangent);
146
+ VertexOut.Bitangent = (R * a_Bitangent);
147
+ float PRT0, PRT1[3], PRT2[5];
148
+ PRT0 = a_PRT1[0];
149
+ PRT1[0] = a_PRT1[1];
150
+ PRT1[1] = a_PRT1[2];
151
+ PRT1[2] = a_PRT2[0];
152
+ PRT2[0] = a_PRT2[1];
153
+ PRT2[1] = a_PRT2[2];
154
+ PRT2[2] = a_PRT3[0];
155
+ PRT2[3] = a_PRT3[1];
156
+ PRT2[4] = a_PRT3[2];
157
+
158
+ OptRotateBand1(PRT1, R, PRT1);
159
+ OptRotateBand2(PRT2, R, PRT2);
160
+
161
+ VertexOut.PRT1 = vec3(PRT0,PRT1[0],PRT1[1]);
162
+ VertexOut.PRT2 = vec3(PRT1[2],PRT2[0],PRT2[1]);
163
+ VertexOut.PRT3 = vec3(PRT2[2],PRT2[3],PRT2[4]);
164
+
165
+ gl_Position = vec4(a_TextureCoord, 0.0, 1.0) - vec4(0.5, 0.5, 0, 0);
166
+ gl_Position[0] *= 2.0;
167
+ gl_Position[1] *= 2.0;
168
+ }
PIFu/lib/renderer/gl/data/quad.fs ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #version 330 core
2
+ out vec4 FragColor;
3
+
4
+ in vec2 TexCoord;
5
+
6
+ uniform sampler2D screenTexture;
7
+
8
+ void main()
9
+ {
10
+ FragColor = texture(screenTexture, TexCoord);
11
+ }
PIFu/lib/renderer/gl/data/quad.vs ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #version 330 core
2
+ layout (location = 0) in vec2 aPos;
3
+ layout (location = 1) in vec2 aTexCoord;
4
+
5
+ out vec2 TexCoord;
6
+
7
+ void main()
8
+ {
9
+ gl_Position = vec4(aPos.x, aPos.y, 0.0, 1.0);
10
+ TexCoord = aTexCoord;
11
+ }
PIFu/lib/renderer/gl/framework.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Mario Rosasco, 2016
2
+ # adapted from framework.cpp, Copyright (C) 2010-2012 by Jason L. McKesson
3
+ # This file is licensed under the MIT License.
4
+ #
5
+ # NB: Unlike in the framework.cpp organization, the main loop is contained
6
+ # in the tutorial files, not in this framework file. Additionally, a copy of
7
+ # this module file must exist in the same directory as the tutorial files
8
+ # to be imported properly.
9
+
10
+ import os
11
+ from OpenGL.GL import *
12
+
13
+ # Function that creates and compiles shaders according to the given type (a GL enum value) and
14
+ # shader program (a file containing a GLSL program).
15
+ def loadShader(shaderType, shaderFile):
16
+ # check if file exists, get full path name
17
+ strFilename = findFileOrThrow(shaderFile)
18
+ shaderData = None
19
+ with open(strFilename, 'r') as f:
20
+ shaderData = f.read()
21
+
22
+ shader = glCreateShader(shaderType)
23
+ glShaderSource(shader, shaderData) # note that this is a simpler function call than in C
24
+
25
+ # This shader compilation is more explicit than the one used in
26
+ # framework.cpp, which relies on a glutil wrapper function.
27
+ # This is made explicit here mainly to decrease dependence on pyOpenGL
28
+ # utilities and wrappers, which docs caution may change in future versions.
29
+ glCompileShader(shader)
30
+
31
+ status = glGetShaderiv(shader, GL_COMPILE_STATUS)
32
+ if status == GL_FALSE:
33
+ # Note that getting the error log is much simpler in Python than in C/C++
34
+ # and does not require explicit handling of the string buffer
35
+ strInfoLog = glGetShaderInfoLog(shader)
36
+ strShaderType = ""
37
+ if shaderType is GL_VERTEX_SHADER:
38
+ strShaderType = "vertex"
39
+ elif shaderType is GL_GEOMETRY_SHADER:
40
+ strShaderType = "geometry"
41
+ elif shaderType is GL_FRAGMENT_SHADER:
42
+ strShaderType = "fragment"
43
+
44
+ print("Compilation failure for " + strShaderType + " shader:\n" + str(strInfoLog))
45
+
46
+ return shader
47
+
48
+
49
+ # Function that accepts a list of shaders, compiles them, and returns a handle to the compiled program
50
+ def createProgram(shaderList):
51
+ program = glCreateProgram()
52
+
53
+ for shader in shaderList:
54
+ glAttachShader(program, shader)
55
+
56
+ glLinkProgram(program)
57
+
58
+ status = glGetProgramiv(program, GL_LINK_STATUS)
59
+ if status == GL_FALSE:
60
+ # Note that getting the error log is much simpler in Python than in C/C++
61
+ # and does not require explicit handling of the string buffer
62
+ strInfoLog = glGetProgramInfoLog(program)
63
+ print("Linker failure: \n" + str(strInfoLog))
64
+
65
+ for shader in shaderList:
66
+ glDetachShader(program, shader)
67
+
68
+ return program
69
+
70
+
71
+ # Helper function to locate and open the target file (passed in as a string).
72
+ # Returns the full path to the file as a string.
73
+ def findFileOrThrow(strBasename):
74
+ # Keep constant names in C-style convention, for readability
75
+ # when comparing to C(/C++) code.
76
+ if os.path.isfile(strBasename):
77
+ return strBasename
78
+
79
+ LOCAL_FILE_DIR = "data" + os.sep
80
+ GLOBAL_FILE_DIR = os.path.dirname(os.path.abspath(__file__)) + os.sep + "data" + os.sep
81
+
82
+ strFilename = LOCAL_FILE_DIR + strBasename
83
+ if os.path.isfile(strFilename):
84
+ return strFilename
85
+
86
+ strFilename = GLOBAL_FILE_DIR + strBasename
87
+ if os.path.isfile(strFilename):
88
+ return strFilename
89
+
90
+ raise IOError('Could not find target file ' + strBasename)
PIFu/lib/renderer/gl/glcontext.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Headless GPU-accelerated OpenGL context creation on Google Colaboratory.
2
+
3
+ Typical usage:
4
+
5
+ # Optional PyOpenGL configuratiopn can be done here.
6
+ # import OpenGL
7
+ # OpenGL.ERROR_CHECKING = True
8
+
9
+ # 'glcontext' must be imported before any OpenGL.* API.
10
+ from lucid.misc.gl.glcontext import create_opengl_context
11
+
12
+ # Now it's safe to import OpenGL and EGL functions
13
+ import OpenGL.GL as gl
14
+
15
+ # create_opengl_context() creates a GL context that is attached to an
16
+ # offscreen surface of the specified size. Note that rendering to buffers
17
+ # of other sizes and formats is still possible with OpenGL Framebuffers.
18
+ #
19
+ # Users are expected to directly use the EGL API in case more advanced
20
+ # context management is required.
21
+ width, height = 640, 480
22
+ create_opengl_context((width, height))
23
+
24
+ # OpenGL context is available here.
25
+
26
+ """
27
+
28
+ from __future__ import print_function
29
+
30
+ # pylint: disable=unused-import,g-import-not-at-top,g-statement-before-imports
31
+
32
+ try:
33
+ import OpenGL
34
+ except:
35
+ print('This module depends on PyOpenGL.')
36
+ print('Please run "\033[1m!pip install -q pyopengl\033[0m" '
37
+ 'prior importing this module.')
38
+ raise
39
+
40
+ import ctypes
41
+ from ctypes import pointer, util
42
+ import os
43
+
44
+ os.environ['PYOPENGL_PLATFORM'] = 'egl'
45
+
46
+ # OpenGL loading workaround.
47
+ #
48
+ # * PyOpenGL tries to load libGL, but we need libOpenGL, see [1,2].
49
+ # This could have been solved by a symlink libGL->libOpenGL, but:
50
+ #
51
+ # * Python 2.7 can't find libGL and linEGL due to a bug (see [3])
52
+ # in ctypes.util, that was only wixed in Python 3.6.
53
+ #
54
+ # So, the only solution I've found is to monkeypatch ctypes.util
55
+ # [1] https://devblogs.nvidia.com/egl-eye-opengl-visualization-without-x-server/
56
+ # [2] https://devblogs.nvidia.com/linking-opengl-server-side-rendering/
57
+ # [3] https://bugs.python.org/issue9998
58
+ _find_library_old = ctypes.util.find_library
59
+ try:
60
+
61
+ def _find_library_new(name):
62
+ return {
63
+ 'GL': 'libOpenGL.so',
64
+ 'EGL': 'libEGL.so',
65
+ }.get(name, _find_library_old(name))
66
+ util.find_library = _find_library_new
67
+ import OpenGL.GL as gl
68
+ import OpenGL.EGL as egl
69
+ from OpenGL import error
70
+ from OpenGL.EGL.EXT.device_base import egl_get_devices
71
+ from OpenGL.raw.EGL.EXT.platform_device import EGL_PLATFORM_DEVICE_EXT
72
+ except:
73
+ print('Unable to load OpenGL libraries. '
74
+ 'Make sure you use GPU-enabled backend.')
75
+ print('Press "Runtime->Change runtime type" and set '
76
+ '"Hardware accelerator" to GPU.')
77
+ raise
78
+ finally:
79
+ util.find_library = _find_library_old
80
+
81
+ def create_initialized_headless_egl_display():
82
+ """Creates an initialized EGL display directly on a device."""
83
+ for device in egl_get_devices():
84
+ display = egl.eglGetPlatformDisplayEXT(EGL_PLATFORM_DEVICE_EXT, device, None)
85
+
86
+ if display != egl.EGL_NO_DISPLAY and egl.eglGetError() == egl.EGL_SUCCESS:
87
+ # `eglInitialize` may or may not raise an exception on failure depending
88
+ # on how PyOpenGL is configured. We therefore catch a `GLError` and also
89
+ # manually check the output of `eglGetError()` here.
90
+ try:
91
+ initialized = egl.eglInitialize(display, None, None)
92
+ except error.GLError:
93
+ pass
94
+ else:
95
+ if initialized == egl.EGL_TRUE and egl.eglGetError() == egl.EGL_SUCCESS:
96
+ return display
97
+ return egl.EGL_NO_DISPLAY
98
+
99
+ def create_opengl_context(surface_size=(640, 480)):
100
+ """Create offscreen OpenGL context and make it current.
101
+
102
+ Users are expected to directly use EGL API in case more advanced
103
+ context management is required.
104
+
105
+ Args:
106
+ surface_size: (width, height), size of the offscreen rendering surface.
107
+ """
108
+ egl_display = create_initialized_headless_egl_display()
109
+ if egl_display == egl.EGL_NO_DISPLAY:
110
+ raise ImportError('Cannot initialize a headless EGL display.')
111
+
112
+ major, minor = egl.EGLint(), egl.EGLint()
113
+ egl.eglInitialize(egl_display, pointer(major), pointer(minor))
114
+
115
+ config_attribs = [
116
+ egl.EGL_SURFACE_TYPE, egl.EGL_PBUFFER_BIT, egl.EGL_BLUE_SIZE, 8,
117
+ egl.EGL_GREEN_SIZE, 8, egl.EGL_RED_SIZE, 8, egl.EGL_DEPTH_SIZE, 24,
118
+ egl.EGL_RENDERABLE_TYPE, egl.EGL_OPENGL_BIT, egl.EGL_NONE
119
+ ]
120
+ config_attribs = (egl.EGLint * len(config_attribs))(*config_attribs)
121
+
122
+ num_configs = egl.EGLint()
123
+ egl_cfg = egl.EGLConfig()
124
+ egl.eglChooseConfig(egl_display, config_attribs, pointer(egl_cfg), 1,
125
+ pointer(num_configs))
126
+
127
+ width, height = surface_size
128
+ pbuffer_attribs = [
129
+ egl.EGL_WIDTH,
130
+ width,
131
+ egl.EGL_HEIGHT,
132
+ height,
133
+ egl.EGL_NONE,
134
+ ]
135
+ pbuffer_attribs = (egl.EGLint * len(pbuffer_attribs))(*pbuffer_attribs)
136
+ egl_surf = egl.eglCreatePbufferSurface(egl_display, egl_cfg, pbuffer_attribs)
137
+
138
+ egl.eglBindAPI(egl.EGL_OPENGL_API)
139
+
140
+ egl_context = egl.eglCreateContext(egl_display, egl_cfg, egl.EGL_NO_CONTEXT,
141
+ None)
142
+ egl.eglMakeCurrent(egl_display, egl_surf, egl_surf, egl_context)
PIFu/lib/renderer/gl/init_gl.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _glut_window = None
2
+ _context_inited = None
3
+
4
+ def initialize_GL_context(width=512, height=512, egl=False):
5
+ '''
6
+ default context uses GLUT
7
+ '''
8
+ if not egl:
9
+ import OpenGL.GLUT as GLUT
10
+ display_mode = GLUT.GLUT_DOUBLE | GLUT.GLUT_RGB | GLUT.GLUT_DEPTH
11
+ global _glut_window
12
+ if _glut_window is None:
13
+ GLUT.glutInit()
14
+ GLUT.glutInitDisplayMode(display_mode)
15
+ GLUT.glutInitWindowSize(width, height)
16
+ GLUT.glutInitWindowPosition(0, 0)
17
+ _glut_window = GLUT.glutCreateWindow("My Render.")
18
+ else:
19
+ from .glcontext import create_opengl_context
20
+ global _context_inited
21
+ if _context_inited is None:
22
+ create_opengl_context((width, height))
23
+ _context_inited = True
24
+
PIFu/lib/renderer/gl/prt_render.py ADDED
@@ -0,0 +1,350 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import random
3
+
4
+ from .framework import *
5
+ from .cam_render import CamRender
6
+
7
+ class PRTRender(CamRender):
8
+ def __init__(self, width=1600, height=1200, name='PRT Renderer', uv_mode=False, ms_rate=1, egl=False):
9
+ program_files = ['prt.vs', 'prt.fs'] if not uv_mode else ['prt_uv.vs', 'prt_uv.fs']
10
+ CamRender.__init__(self, width, height, name, program_files=program_files, color_size=8, ms_rate=ms_rate, egl=egl)
11
+
12
+ # WARNING: this differs from vertex_buffer and vertex_data in Render
13
+ self.vert_buffer = {}
14
+ self.vert_data = {}
15
+
16
+ self.norm_buffer = {}
17
+ self.norm_data = {}
18
+
19
+ self.tan_buffer = {}
20
+ self.tan_data = {}
21
+
22
+ self.btan_buffer = {}
23
+ self.btan_data = {}
24
+
25
+ self.prt1_buffer = {}
26
+ self.prt1_data = {}
27
+ self.prt2_buffer = {}
28
+ self.prt2_data = {}
29
+ self.prt3_buffer = {}
30
+ self.prt3_data = {}
31
+
32
+ self.uv_buffer = {}
33
+ self.uv_data = {}
34
+
35
+ self.render_texture_mat = {}
36
+
37
+ self.vertex_dim = {}
38
+ self.n_vertices = {}
39
+
40
+ self.norm_mat_unif = glGetUniformLocation(self.program, 'NormMat')
41
+ self.normalize_matrix = np.eye(4)
42
+
43
+ self.shcoeff_unif = glGetUniformLocation(self.program, 'SHCoeffs')
44
+ self.shcoeffs = np.zeros((9,3))
45
+ self.shcoeffs[0,:] = 1.0
46
+ #self.shcoeffs[1:,:] = np.random.rand(8,3)
47
+
48
+ self.hasAlbedoUnif = glGetUniformLocation(self.program, 'hasAlbedoMap')
49
+ self.hasNormalUnif = glGetUniformLocation(self.program, 'hasNormalMap')
50
+
51
+ self.analyticUnif = glGetUniformLocation(self.program, 'analytic')
52
+ self.analytic = False
53
+
54
+ self.rot_mat_unif = glGetUniformLocation(self.program, 'RotMat')
55
+ self.rot_matrix = np.eye(3)
56
+
57
+ def set_texture(self, mat_name, smplr_name, texture):
58
+ # texture_image: H x W x 3
59
+ width = texture.shape[1]
60
+ height = texture.shape[0]
61
+ texture = np.flip(texture, 0)
62
+ img_data = np.fromstring(texture.tostring(), np.uint8)
63
+
64
+ if mat_name not in self.render_texture_mat:
65
+ self.render_texture_mat[mat_name] = {}
66
+ if smplr_name in self.render_texture_mat[mat_name].keys():
67
+ glDeleteTextures([self.render_texture_mat[mat_name][smplr_name]])
68
+ del self.render_texture_mat[mat_name][smplr_name]
69
+ self.render_texture_mat[mat_name][smplr_name] = glGenTextures(1)
70
+ glActiveTexture(GL_TEXTURE0)
71
+
72
+ glPixelStorei(GL_UNPACK_ALIGNMENT, 1)
73
+ glBindTexture(GL_TEXTURE_2D, self.render_texture_mat[mat_name][smplr_name])
74
+
75
+ glTexImage2D(GL_TEXTURE_2D, 0, GL_RGB, width, height, 0, GL_RGB, GL_UNSIGNED_BYTE, img_data)
76
+
77
+ glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAX_LEVEL, 3)
78
+ glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_S, GL_CLAMP_TO_EDGE)
79
+ glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_T, GL_CLAMP_TO_EDGE)
80
+ glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_LINEAR)
81
+ glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_LINEAR_MIPMAP_LINEAR)
82
+
83
+ glGenerateMipmap(GL_TEXTURE_2D)
84
+
85
+ def set_albedo(self, texture_image, mat_name='all'):
86
+ self.set_texture(mat_name, 'AlbedoMap', texture_image)
87
+
88
+ def set_normal_map(self, texture_image, mat_name='all'):
89
+ self.set_texture(mat_name, 'NormalMap', texture_image)
90
+
91
+ def set_mesh(self, vertices, faces, norms, faces_nml, uvs, faces_uvs, prt, faces_prt, tans, bitans, mat_name='all'):
92
+ self.vert_data[mat_name] = vertices[faces.reshape([-1])]
93
+ self.n_vertices[mat_name] = self.vert_data[mat_name].shape[0]
94
+ self.vertex_dim[mat_name] = self.vert_data[mat_name].shape[1]
95
+
96
+ if mat_name not in self.vert_buffer.keys():
97
+ self.vert_buffer[mat_name] = glGenBuffers(1)
98
+ glBindBuffer(GL_ARRAY_BUFFER, self.vert_buffer[mat_name])
99
+ glBufferData(GL_ARRAY_BUFFER, self.vert_data[mat_name], GL_STATIC_DRAW)
100
+
101
+ self.uv_data[mat_name] = uvs[faces_uvs.reshape([-1])]
102
+ if mat_name not in self.uv_buffer.keys():
103
+ self.uv_buffer[mat_name] = glGenBuffers(1)
104
+ glBindBuffer(GL_ARRAY_BUFFER, self.uv_buffer[mat_name])
105
+ glBufferData(GL_ARRAY_BUFFER, self.uv_data[mat_name], GL_STATIC_DRAW)
106
+
107
+ self.norm_data[mat_name] = norms[faces_nml.reshape([-1])]
108
+ if mat_name not in self.norm_buffer.keys():
109
+ self.norm_buffer[mat_name] = glGenBuffers(1)
110
+ glBindBuffer(GL_ARRAY_BUFFER, self.norm_buffer[mat_name])
111
+ glBufferData(GL_ARRAY_BUFFER, self.norm_data[mat_name], GL_STATIC_DRAW)
112
+
113
+ self.tan_data[mat_name] = tans[faces_nml.reshape([-1])]
114
+ if mat_name not in self.tan_buffer.keys():
115
+ self.tan_buffer[mat_name] = glGenBuffers(1)
116
+ glBindBuffer(GL_ARRAY_BUFFER, self.tan_buffer[mat_name])
117
+ glBufferData(GL_ARRAY_BUFFER, self.tan_data[mat_name], GL_STATIC_DRAW)
118
+
119
+ self.btan_data[mat_name] = bitans[faces_nml.reshape([-1])]
120
+ if mat_name not in self.btan_buffer.keys():
121
+ self.btan_buffer[mat_name] = glGenBuffers(1)
122
+ glBindBuffer(GL_ARRAY_BUFFER, self.btan_buffer[mat_name])
123
+ glBufferData(GL_ARRAY_BUFFER, self.btan_data[mat_name], GL_STATIC_DRAW)
124
+
125
+ self.prt1_data[mat_name] = prt[faces_prt.reshape([-1])][:,:3]
126
+ self.prt2_data[mat_name] = prt[faces_prt.reshape([-1])][:,3:6]
127
+ self.prt3_data[mat_name] = prt[faces_prt.reshape([-1])][:,6:]
128
+
129
+ if mat_name not in self.prt1_buffer.keys():
130
+ self.prt1_buffer[mat_name] = glGenBuffers(1)
131
+ if mat_name not in self.prt2_buffer.keys():
132
+ self.prt2_buffer[mat_name] = glGenBuffers(1)
133
+ if mat_name not in self.prt3_buffer.keys():
134
+ self.prt3_buffer[mat_name] = glGenBuffers(1)
135
+ glBindBuffer(GL_ARRAY_BUFFER, self.prt1_buffer[mat_name])
136
+ glBufferData(GL_ARRAY_BUFFER, self.prt1_data[mat_name], GL_STATIC_DRAW)
137
+ glBindBuffer(GL_ARRAY_BUFFER, self.prt2_buffer[mat_name])
138
+ glBufferData(GL_ARRAY_BUFFER, self.prt2_data[mat_name], GL_STATIC_DRAW)
139
+ glBindBuffer(GL_ARRAY_BUFFER, self.prt3_buffer[mat_name])
140
+ glBufferData(GL_ARRAY_BUFFER, self.prt3_data[mat_name], GL_STATIC_DRAW)
141
+
142
+ glBindBuffer(GL_ARRAY_BUFFER, 0)
143
+
144
+ def set_mesh_mtl(self, vertices, faces, norms, faces_nml, uvs, faces_uvs, tans, bitans, prt):
145
+ for key in faces:
146
+ self.vert_data[key] = vertices[faces[key].reshape([-1])]
147
+ self.n_vertices[key] = self.vert_data[key].shape[0]
148
+ self.vertex_dim[key] = self.vert_data[key].shape[1]
149
+
150
+ if key not in self.vert_buffer.keys():
151
+ self.vert_buffer[key] = glGenBuffers(1)
152
+ glBindBuffer(GL_ARRAY_BUFFER, self.vert_buffer[key])
153
+ glBufferData(GL_ARRAY_BUFFER, self.vert_data[key], GL_STATIC_DRAW)
154
+
155
+ self.uv_data[key] = uvs[faces_uvs[key].reshape([-1])]
156
+ if key not in self.uv_buffer.keys():
157
+ self.uv_buffer[key] = glGenBuffers(1)
158
+ glBindBuffer(GL_ARRAY_BUFFER, self.uv_buffer[key])
159
+ glBufferData(GL_ARRAY_BUFFER, self.uv_data[key], GL_STATIC_DRAW)
160
+
161
+ self.norm_data[key] = norms[faces_nml[key].reshape([-1])]
162
+ if key not in self.norm_buffer.keys():
163
+ self.norm_buffer[key] = glGenBuffers(1)
164
+ glBindBuffer(GL_ARRAY_BUFFER, self.norm_buffer[key])
165
+ glBufferData(GL_ARRAY_BUFFER, self.norm_data[key], GL_STATIC_DRAW)
166
+
167
+ self.tan_data[key] = tans[faces_nml[key].reshape([-1])]
168
+ if key not in self.tan_buffer.keys():
169
+ self.tan_buffer[key] = glGenBuffers(1)
170
+ glBindBuffer(GL_ARRAY_BUFFER, self.tan_buffer[key])
171
+ glBufferData(GL_ARRAY_BUFFER, self.tan_data[key], GL_STATIC_DRAW)
172
+
173
+ self.btan_data[key] = bitans[faces_nml[key].reshape([-1])]
174
+ if key not in self.btan_buffer.keys():
175
+ self.btan_buffer[key] = glGenBuffers(1)
176
+ glBindBuffer(GL_ARRAY_BUFFER, self.btan_buffer[key])
177
+ glBufferData(GL_ARRAY_BUFFER, self.btan_data[key], GL_STATIC_DRAW)
178
+
179
+ self.prt1_data[key] = prt[faces[key].reshape([-1])][:,:3]
180
+ self.prt2_data[key] = prt[faces[key].reshape([-1])][:,3:6]
181
+ self.prt3_data[key] = prt[faces[key].reshape([-1])][:,6:]
182
+
183
+ if key not in self.prt1_buffer.keys():
184
+ self.prt1_buffer[key] = glGenBuffers(1)
185
+ if key not in self.prt2_buffer.keys():
186
+ self.prt2_buffer[key] = glGenBuffers(1)
187
+ if key not in self.prt3_buffer.keys():
188
+ self.prt3_buffer[key] = glGenBuffers(1)
189
+ glBindBuffer(GL_ARRAY_BUFFER, self.prt1_buffer[key])
190
+ glBufferData(GL_ARRAY_BUFFER, self.prt1_data[key], GL_STATIC_DRAW)
191
+ glBindBuffer(GL_ARRAY_BUFFER, self.prt2_buffer[key])
192
+ glBufferData(GL_ARRAY_BUFFER, self.prt2_data[key], GL_STATIC_DRAW)
193
+ glBindBuffer(GL_ARRAY_BUFFER, self.prt3_buffer[key])
194
+ glBufferData(GL_ARRAY_BUFFER, self.prt3_data[key], GL_STATIC_DRAW)
195
+
196
+ glBindBuffer(GL_ARRAY_BUFFER, 0)
197
+
198
+ def cleanup(self):
199
+
200
+ glBindBuffer(GL_ARRAY_BUFFER, 0)
201
+ for key in self.vert_data:
202
+ glDeleteBuffers(1, [self.vert_buffer[key]])
203
+ glDeleteBuffers(1, [self.norm_buffer[key]])
204
+ glDeleteBuffers(1, [self.uv_buffer[key]])
205
+
206
+ glDeleteBuffers(1, [self.tan_buffer[key]])
207
+ glDeleteBuffers(1, [self.btan_buffer[key]])
208
+ glDeleteBuffers(1, [self.prt1_buffer[key]])
209
+ glDeleteBuffers(1, [self.prt2_buffer[key]])
210
+ glDeleteBuffers(1, [self.prt3_buffer[key]])
211
+
212
+ glDeleteBuffers(1, [])
213
+
214
+ for smplr in self.render_texture_mat[key]:
215
+ glDeleteTextures([self.render_texture_mat[key][smplr]])
216
+
217
+ self.vert_buffer = {}
218
+ self.vert_data = {}
219
+
220
+ self.norm_buffer = {}
221
+ self.norm_data = {}
222
+
223
+ self.tan_buffer = {}
224
+ self.tan_data = {}
225
+
226
+ self.btan_buffer = {}
227
+ self.btan_data = {}
228
+
229
+ self.prt1_buffer = {}
230
+ self.prt1_data = {}
231
+
232
+ self.prt2_buffer = {}
233
+ self.prt2_data = {}
234
+
235
+ self.prt3_buffer = {}
236
+ self.prt3_data = {}
237
+
238
+ self.uv_buffer = {}
239
+ self.uv_data = {}
240
+
241
+ self.render_texture_mat = {}
242
+
243
+ self.vertex_dim = {}
244
+ self.n_vertices = {}
245
+
246
+ def randomize_sh(self):
247
+ self.shcoeffs[0,:] = 0.8
248
+ self.shcoeffs[1:,:] = 1.0*np.random.rand(8,3)
249
+
250
+ def set_sh(self, sh):
251
+ self.shcoeffs = sh
252
+
253
+ def set_norm_mat(self, scale, center):
254
+ N = np.eye(4)
255
+ N[:3, :3] = scale*np.eye(3)
256
+ N[:3, 3] = -scale*center
257
+
258
+ self.normalize_matrix = N
259
+
260
+ def draw(self):
261
+ self.draw_init()
262
+
263
+ glDisable(GL_BLEND)
264
+ #glBlendFunc(GL_SRC_ALPHA, GL_ONE_MINUS_SRC_ALPHA)
265
+ glEnable(GL_MULTISAMPLE)
266
+
267
+ glUseProgram(self.program)
268
+ glUniformMatrix4fv(self.norm_mat_unif, 1, GL_FALSE, self.normalize_matrix.transpose())
269
+ glUniformMatrix4fv(self.model_mat_unif, 1, GL_FALSE, self.model_view_matrix.transpose())
270
+ glUniformMatrix4fv(self.persp_mat_unif, 1, GL_FALSE, self.projection_matrix.transpose())
271
+
272
+ if 'AlbedoMap' in self.render_texture_mat['all']:
273
+ glUniform1ui(self.hasAlbedoUnif, GLuint(1))
274
+ else:
275
+ glUniform1ui(self.hasAlbedoUnif, GLuint(0))
276
+
277
+ if 'NormalMap' in self.render_texture_mat['all']:
278
+ glUniform1ui(self.hasNormalUnif, GLuint(1))
279
+ else:
280
+ glUniform1ui(self.hasNormalUnif, GLuint(0))
281
+
282
+ glUniform1ui(self.analyticUnif, GLuint(1) if self.analytic else GLuint(0))
283
+
284
+ glUniform3fv(self.shcoeff_unif, 9, self.shcoeffs)
285
+
286
+ glUniformMatrix3fv(self.rot_mat_unif, 1, GL_FALSE, self.rot_matrix.transpose())
287
+
288
+ for mat in self.vert_buffer:
289
+ # Handle vertex buffer
290
+ glBindBuffer(GL_ARRAY_BUFFER, self.vert_buffer[mat])
291
+ glEnableVertexAttribArray(0)
292
+ glVertexAttribPointer(0, self.vertex_dim[mat], GL_DOUBLE, GL_FALSE, 0, None)
293
+
294
+ # Handle normal buffer
295
+ glBindBuffer(GL_ARRAY_BUFFER, self.norm_buffer[mat])
296
+ glEnableVertexAttribArray(1)
297
+ glVertexAttribPointer(1, 3, GL_DOUBLE, GL_FALSE, 0, None)
298
+
299
+ # Handle uv buffer
300
+ glBindBuffer(GL_ARRAY_BUFFER, self.uv_buffer[mat])
301
+ glEnableVertexAttribArray(2)
302
+ glVertexAttribPointer(2, 2, GL_DOUBLE, GL_FALSE, 0, None)
303
+
304
+ # Handle tan buffer
305
+ glBindBuffer(GL_ARRAY_BUFFER, self.tan_buffer[mat])
306
+ glEnableVertexAttribArray(3)
307
+ glVertexAttribPointer(3, 3, GL_DOUBLE, GL_FALSE, 0, None)
308
+
309
+ # Handle btan buffer
310
+ glBindBuffer(GL_ARRAY_BUFFER, self.btan_buffer[mat])
311
+ glEnableVertexAttribArray(4)
312
+ glVertexAttribPointer(4, 3, GL_DOUBLE, GL_FALSE, 0, None)
313
+
314
+ # Handle PTR buffer
315
+ glBindBuffer(GL_ARRAY_BUFFER, self.prt1_buffer[mat])
316
+ glEnableVertexAttribArray(5)
317
+ glVertexAttribPointer(5, 3, GL_DOUBLE, GL_FALSE, 0, None)
318
+
319
+ glBindBuffer(GL_ARRAY_BUFFER, self.prt2_buffer[mat])
320
+ glEnableVertexAttribArray(6)
321
+ glVertexAttribPointer(6, 3, GL_DOUBLE, GL_FALSE, 0, None)
322
+
323
+ glBindBuffer(GL_ARRAY_BUFFER, self.prt3_buffer[mat])
324
+ glEnableVertexAttribArray(7)
325
+ glVertexAttribPointer(7, 3, GL_DOUBLE, GL_FALSE, 0, None)
326
+
327
+ for i, smplr in enumerate(self.render_texture_mat[mat]):
328
+ glActiveTexture(GL_TEXTURE0 + i)
329
+ glBindTexture(GL_TEXTURE_2D, self.render_texture_mat[mat][smplr])
330
+ glUniform1i(glGetUniformLocation(self.program, smplr), i)
331
+
332
+ glDrawArrays(GL_TRIANGLES, 0, self.n_vertices[mat])
333
+
334
+ glDisableVertexAttribArray(7)
335
+ glDisableVertexAttribArray(6)
336
+ glDisableVertexAttribArray(5)
337
+ glDisableVertexAttribArray(4)
338
+ glDisableVertexAttribArray(3)
339
+ glDisableVertexAttribArray(2)
340
+ glDisableVertexAttribArray(1)
341
+ glDisableVertexAttribArray(0)
342
+
343
+ glBindBuffer(GL_ARRAY_BUFFER, 0)
344
+
345
+ glUseProgram(0)
346
+
347
+ glDisable(GL_BLEND)
348
+ glDisable(GL_MULTISAMPLE)
349
+
350
+ self.draw_end()