ameerazam08
commited on
Upload folder using huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +2 -0
- LICENSE +13 -0
- README.md +225 -0
- assets/image/synctalk.png +3 -0
- data/Please place the data file. +0 -0
- data_utils/UNFaceFlow/core/__init__.py +0 -0
- data_utils/UNFaceFlow/core/corr.py +91 -0
- data_utils/UNFaceFlow/core/datasets.py +235 -0
- data_utils/UNFaceFlow/core/extractor.py +266 -0
- data_utils/UNFaceFlow/core/nnutils.py +233 -0
- data_utils/UNFaceFlow/core/raft.py +259 -0
- data_utils/UNFaceFlow/core/update.py +169 -0
- data_utils/UNFaceFlow/core/utils_core/__init__.py +0 -0
- data_utils/UNFaceFlow/core/utils_core/augmentor.py +246 -0
- data_utils/UNFaceFlow/core/utils_core/flow_viz.py +132 -0
- data_utils/UNFaceFlow/core/utils_core/frame_utils.py +137 -0
- data_utils/UNFaceFlow/core/utils_core/utils.py +86 -0
- data_utils/UNFaceFlow/core/warp_utils.py +118 -0
- data_utils/UNFaceFlow/data_test_flow/__init__.py +94 -0
- data_utils/UNFaceFlow/data_test_flow/base_dataset.py +98 -0
- data_utils/UNFaceFlow/data_test_flow/dd_dataset.py +108 -0
- data_utils/UNFaceFlow/data_test_flow/dd_dataset_bak.py +107 -0
- data_utils/UNFaceFlow/models/network_test_flow.py +88 -0
- data_utils/UNFaceFlow/options_test_flow.py +123 -0
- data_utils/UNFaceFlow/pretrain_model/raft-small.pth +3 -0
- data_utils/UNFaceFlow/sgd_NNRT_model_epoch19008_50000.pth +3 -0
- data_utils/UNFaceFlow/test_flow.py +62 -0
- data_utils/UNFaceFlow/utils.py +84 -0
- data_utils/blendshape_capture/face_landmarker.task +3 -0
- data_utils/blendshape_capture/main.py +86 -0
- data_utils/deepspeech_features/README.md +20 -0
- data_utils/deepspeech_features/deepspeech_features.py +275 -0
- data_utils/deepspeech_features/deepspeech_store.py +172 -0
- data_utils/deepspeech_features/extract_ds_features.py +132 -0
- data_utils/deepspeech_features/extract_wav.py +87 -0
- data_utils/deepspeech_features/fea_win.py +11 -0
- data_utils/face_parsing/79999_iter.pth +3 -0
- data_utils/face_parsing/logger.py +23 -0
- data_utils/face_parsing/model.py +285 -0
- data_utils/face_parsing/resnet.py +109 -0
- data_utils/face_parsing/test.py +148 -0
- data_utils/face_tracking/3DMM/exp_info.npy +3 -0
- data_utils/face_tracking/3DMM/keys_info.npy +3 -0
- data_utils/face_tracking/3DMM/lands_info.txt +403 -0
- data_utils/face_tracking/3DMM/sub_mesh.obj +0 -0
- data_utils/face_tracking/3DMM/topology_info.npy +3 -0
- data_utils/face_tracking/3DMM/tris.txt +0 -0
- data_utils/face_tracking/3DMM/vert_tris.txt +0 -0
- data_utils/face_tracking/__init__.py +0 -0
- data_utils/face_tracking/bundle_adjustment.py +63 -0
.gitattributes
CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
assets/image/synctalk.png filter=lfs diff=lfs merge=lfs -text
|
37 |
+
data_utils/blendshape_capture/face_landmarker.task filter=lfs diff=lfs merge=lfs -text
|
LICENSE
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Copyright (c) 2024 Peng Ziqiao
|
2 |
+
|
3 |
+
This work is licensed under the Creative Commons Attribution-NonCommercial 4.0 International License (CC BY-NC 4.0). To view a copy of this license, visit http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, and distribute the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
|
6 |
+
|
7 |
+
1. Attribution — You must give appropriate credit, provide a link to the license, and indicate if changes were made. You may do so in any reasonable manner, but not in any way that suggests the licensor endorses you or your use.
|
8 |
+
|
9 |
+
2. NonCommercial — You may not use the material for commercial purposes.
|
10 |
+
|
11 |
+
3. No additional restrictions — You may not apply legal terms or technological measures that legally restrict others from doing anything the license permits.
|
12 |
+
|
13 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
README.md
ADDED
@@ -0,0 +1,225 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# SyncTalk: The Devil😈 is in the Synchronization for Talking Head Synthesis [CVPR 2024]
|
2 |
+
The official repository of the paper [SyncTalk: The Devil is in the Synchronization for Talking Head Synthesis](https://arxiv.org/abs/2311.17590)
|
3 |
+
|
4 |
+
<p align='center'>
|
5 |
+
<b>
|
6 |
+
<a href="https://arxiv.org/abs/2311.17590">Paper</a>
|
7 |
+
|
|
8 |
+
<a href="https://ziqiaopeng.github.io/synctalk/">Project Page</a>
|
9 |
+
|
|
10 |
+
<a href="https://github.com/ZiqiaoPeng/SyncTalk">Code</a>
|
11 |
+
</b>
|
12 |
+
</p>
|
13 |
+
|
14 |
+
Colab notebook demonstration: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1Egq0_ZK5sJAAawShxC0y4JRZQuVS2X-Z?usp=sharing)
|
15 |
+
|
16 |
+
A short demo video can be found [here](./demo/short_demo.mp4).
|
17 |
+
|
18 |
+
<p align='center'>
|
19 |
+
<img src='assets/image/synctalk.png' width='1000'/>
|
20 |
+
</p>
|
21 |
+
|
22 |
+
The proposed **SyncTalk** synthesizes synchronized talking head videos, employing tri-plane hash representations to maintain subject identity. It can generate synchronized lip movements, facial expressions, and stable head poses, and restores hair details to create high-resolution videos.
|
23 |
+
|
24 |
+
## 🔥🔥🔥 News
|
25 |
+
- [2023-11-30] Update arXiv paper.
|
26 |
+
- [2024-03-04] The code and pre-trained model are released.
|
27 |
+
- [2024-03-22] The Google Colab notebook is released.
|
28 |
+
- [2024-04-14] Add Windows support.
|
29 |
+
- [2024-04-28] The preprocessing code is released.
|
30 |
+
- [2024-04-29] Fix bugs: audio encoder, blendshape capture, and face tracker.
|
31 |
+
- [2024-05-03] Try replacing NeRF with Gaussian Splatting. Code: [GS-SyncTalk](https://github.com/ZiqiaoPeng/GS-SyncTalk)
|
32 |
+
- **[2024-05-24] Introduce torso training to repair double chin.**
|
33 |
+
|
34 |
+
|
35 |
+
|
36 |
+
## For Windows
|
37 |
+
Thanks to [okgpt](https://github.com/okgptai), we have launched a Windows integration package, you can download `SyncTalk-Windows.zip` and unzip it, double-click `inference.bat` to run the demo.
|
38 |
+
|
39 |
+
Download link: [Hugging Face](https://huggingface.co/ZiqiaoPeng/SyncTalk/blob/main/SyncTalk-Windows.zip) || [Baidu Netdisk](https://pan.baidu.com/s/1g3312mZxx__T6rAFPHjrRg?pwd=6666)
|
40 |
+
|
41 |
+
## For Linux
|
42 |
+
|
43 |
+
### Installation
|
44 |
+
|
45 |
+
Tested on Ubuntu 18.04, Pytorch 1.12.1 and CUDA 11.3.
|
46 |
+
```bash
|
47 |
+
git clone https://github.com/ZiqiaoPeng/SyncTalk.git
|
48 |
+
cd SyncTalk
|
49 |
+
```
|
50 |
+
#### Install dependency
|
51 |
+
|
52 |
+
```bash
|
53 |
+
conda create -n synctalk python==3.8.8
|
54 |
+
conda activate synctalk
|
55 |
+
pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113
|
56 |
+
pip install -r requirements.txt
|
57 |
+
pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py38_cu113_pyt1121/download.html
|
58 |
+
pip install tensorflow-gpu==2.8.1
|
59 |
+
pip install ./freqencoder
|
60 |
+
pip install ./shencoder
|
61 |
+
pip install ./gridencoder
|
62 |
+
pip install ./raymarching
|
63 |
+
```
|
64 |
+
If you encounter problems installing PyTorch3D, you can use the following command to install it:
|
65 |
+
```bash
|
66 |
+
python ./scripts/install_pytorch3d.py
|
67 |
+
```
|
68 |
+
|
69 |
+
### Data Preparation
|
70 |
+
#### Pre-trained model
|
71 |
+
Please place the [May.zip](https://drive.google.com/file/d/18Q2H612CAReFxBd9kxr-i1dD8U1AUfsV/view?usp=sharing) in the **data** folder, the [trial_may.zip](https://drive.google.com/file/d/1C2639qi9jvhRygYHwPZDGs8pun3po3W7/view?usp=sharing) in the **model** folder, and then unzip them.
|
72 |
+
#### [New] Process your video
|
73 |
+
- Prepare face-parsing model.
|
74 |
+
|
75 |
+
```bash
|
76 |
+
wget https://github.com/YudongGuo/AD-NeRF/blob/master/data_util/face_parsing/79999_iter.pth?raw=true -O data_utils/face_parsing/79999_iter.pth
|
77 |
+
```
|
78 |
+
|
79 |
+
- Prepare the 3DMM model for head pose estimation.
|
80 |
+
|
81 |
+
```bash
|
82 |
+
wget https://github.com/YudongGuo/AD-NeRF/blob/master/data_util/face_tracking/3DMM/exp_info.npy?raw=true -O data_utils/face_tracking/3DMM/exp_info.npy
|
83 |
+
wget https://github.com/YudongGuo/AD-NeRF/blob/master/data_util/face_tracking/3DMM/keys_info.npy?raw=true -O data_utils/face_tracking/3DMM/keys_info.npy
|
84 |
+
wget https://github.com/YudongGuo/AD-NeRF/blob/master/data_util/face_tracking/3DMM/sub_mesh.obj?raw=true -O data_utils/face_tracking/3DMM/sub_mesh.obj
|
85 |
+
wget https://github.com/YudongGuo/AD-NeRF/blob/master/data_util/face_tracking/3DMM/topology_info.npy?raw=true -O data_utils/face_tracking/3DMM/topology_info.npy
|
86 |
+
```
|
87 |
+
|
88 |
+
- Download 3DMM model from [Basel Face Model 2009](https://faces.dmi.unibas.ch/bfm/main.php?nav=1-1-0&id=details):
|
89 |
+
|
90 |
+
```
|
91 |
+
# 1. copy 01_MorphableModel.mat to data_util/face_tracking/3DMM/
|
92 |
+
# 2.
|
93 |
+
cd data_utils/face_tracking
|
94 |
+
python convert_BFM.py
|
95 |
+
```
|
96 |
+
- Put your video under `data/<ID>/<ID>.mp4`, and then run the following command to process the video.
|
97 |
+
|
98 |
+
**[Note]** The video must be 25FPS, with all frames containing the talking person. The resolution should be about 512x512, and duration about 4-5 min.
|
99 |
+
```bash
|
100 |
+
python data_utils/process.py data/<ID>/<ID>.mp4 --asr ave
|
101 |
+
```
|
102 |
+
You can choose to use AVE, DeepSpeech or Hubert. The processed video will be saved in the **data** folder.
|
103 |
+
|
104 |
+
|
105 |
+
- [Optional] Obtain AU45 for eyes blinking
|
106 |
+
|
107 |
+
Run `FeatureExtraction` in [OpenFace](https://github.com/TadasBaltrusaitis/OpenFace), rename and move the output CSV file to `data/<ID>/au.csv`.
|
108 |
+
|
109 |
+
|
110 |
+
**[Note]** Since EmoTalk's blendshape capture is not open source, the preprocessing code here is replaced with mediapipe's blendshape capture. But according to some feedback, it doesn't work well, you can choose to replace it with AU45. If you want to compare with SyncTalk, some results from using EmoTalk capture can be obtained [here](https://drive.google.com/drive/folders/1LLFtQa2Yy2G0FaNOxwtZr0L974TXCYKh?usp=sharing) and videos from [GeneFace](https://drive.google.com/drive/folders/1vimGVNvP6d6nmmc8yAxtWuooxhJbkl68).
|
111 |
+
|
112 |
+
|
113 |
+
### Quick Start
|
114 |
+
#### Run the evaluation code
|
115 |
+
```bash
|
116 |
+
python main.py data/May --workspace model/trial_may -O --test --asr_model ave
|
117 |
+
|
118 |
+
python main.py data/May --workspace model/trial_may -O --test --asr_model ave --portrait
|
119 |
+
```
|
120 |
+
“ave” refers to our Audio Visual Encoder, “portrait” signifies pasting the generated face back onto the original image, representing higher quality.
|
121 |
+
|
122 |
+
If it runs correctly, you will get the following results.
|
123 |
+
|
124 |
+
| Setting | PSNR | LPIPS | LMD |
|
125 |
+
|--------------------------|--------|--------|-------|
|
126 |
+
| SyncTalk (w/o Portrait) | 32.201 | 0.0394 | 2.822 |
|
127 |
+
| SyncTalk (Portrait) | 37.644 | 0.0117 | 2.825 |
|
128 |
+
|
129 |
+
This is for a single subject; the paper reports the average results for multiple subjects.
|
130 |
+
|
131 |
+
#### Inference with target audio
|
132 |
+
```bash
|
133 |
+
python main.py data/May --workspace model/trial_may -O --test --test_train --asr_model ave --portrait --aud ./demo/test.wav
|
134 |
+
```
|
135 |
+
Please use files with the “.wav” extension for inference, and the inference results will be saved in “model/trial_may/results/”. If do not use Audio Visual Encoder, replace wav with the npy file path.
|
136 |
+
* DeepSpeech
|
137 |
+
|
138 |
+
```bash
|
139 |
+
python data_utils/deepspeech_features/extract_ds_features.py --input data/<name>.wav # save to data/<name>.npy
|
140 |
+
```
|
141 |
+
* HuBERT
|
142 |
+
|
143 |
+
```bash
|
144 |
+
# Borrowed from GeneFace. English pre-trained.
|
145 |
+
python data_utils/hubert.py --wav data/<name>.wav # save to data/<name>_hu.npy
|
146 |
+
```
|
147 |
+
### Train
|
148 |
+
```bash
|
149 |
+
# by default, we load data from disk on the fly.
|
150 |
+
# we can also preload all data to CPU/GPU for faster training, but this is very memory-hungry for large datasets.
|
151 |
+
# `--preload 0`: load from disk (default, slower).
|
152 |
+
# `--preload 1`: load to CPU (slightly slower)
|
153 |
+
# `--preload 2`: load to GPU (fast)
|
154 |
+
python main.py data/May --workspace model/trial_may -O --iters 60000 --asr_model ave
|
155 |
+
python main.py data/May --workspace model/trial_may -O --iters 100000 --finetune_lips --patch_size 64 --asr_model ave
|
156 |
+
|
157 |
+
# or you can use the script to train
|
158 |
+
sh ./scripts/train_may.sh
|
159 |
+
```
|
160 |
+
**[Tips]** Audio visual encoder (AVE) is suitable for characters with accurate lip sync and large lip movements such as May and Shaheen. Using AVE in the inference stage can achieve more accurate lip sync. If your training results show lip jitter, please try using deepspeech or hubert model as audio feature encoder.
|
161 |
+
|
162 |
+
```bash
|
163 |
+
# Use deepspeech model
|
164 |
+
python main.py data/May --workspace model/trial_may -O --iters 60000 --asr_model deepspeech
|
165 |
+
python main.py data/May --workspace model/trial_may -O --iters 100000 --finetune_lips --patch_size 64 --asr_model deepspeech
|
166 |
+
|
167 |
+
# Use hubert model
|
168 |
+
python main.py data/May --workspace model/trial_may -O --iters 60000 --asr_model hubert
|
169 |
+
python main.py data/May --workspace model/trial_may -O --iters 100000 --finetune_lips --patch_size 64 --asr_model hubert
|
170 |
+
```
|
171 |
+
|
172 |
+
If you want to use the OpenFace au45 as the eye parameter, please add "--au45" to the command line.
|
173 |
+
|
174 |
+
```bash
|
175 |
+
# Use OpenFace AU45
|
176 |
+
python main.py data/May --workspace model/trial_may -O --iters 60000 --asr_model ave --au45
|
177 |
+
python main.py data/May --workspace model/trial_may -O --iters 100000 --finetune_lips --patch_size 64 --asr_model ave --au45
|
178 |
+
```
|
179 |
+
|
180 |
+
### Test
|
181 |
+
```bash
|
182 |
+
python main.py data/May --workspace model/trial_may -O --test --asr_model ave --portrait
|
183 |
+
|
184 |
+
```
|
185 |
+
|
186 |
+
### Train & Test Torso [Repair Double Chin]
|
187 |
+
If your character trained only the head appeared double chin problem, you can introduce torso training. By training the torso, this problem can be solved, but **you will not be able to use the "--portrait" mode.** If you add "--portrait", the torso model will fail!
|
188 |
+
|
189 |
+
```bash
|
190 |
+
# Train
|
191 |
+
# <head>.pth should be the latest checkpoint in trial_may
|
192 |
+
python main.py data/May/ --workspace model/trial_may_torso/ -O --torso --head_ckpt <head>.pth --iters 150000 --asr_model ave
|
193 |
+
|
194 |
+
# For example
|
195 |
+
python main.py data/May/ --workspace model/trial_may_torso/ -O --torso --head_ckpt model/trial_may/ngp_ep0019.pth --iters 150000 --asr_model ave
|
196 |
+
|
197 |
+
# Test
|
198 |
+
python main.py data/May --workspace model/trial_may_torso -O --torso --test --asr_model ave # not support --portrait
|
199 |
+
|
200 |
+
# Inference with target audio
|
201 |
+
python main.py data/May --workspace model/trial_may_torso -O --torso --test --test_train --asr_model ave --aud ./demo/test.wav # not support --portrait
|
202 |
+
|
203 |
+
```
|
204 |
+
|
205 |
+
|
206 |
+
|
207 |
+
## Citation
|
208 |
+
|
209 |
+
```
|
210 |
+
@InProceedings{peng2023synctalk,
|
211 |
+
title = {SyncTalk: The Devil is in the Synchronization for Talking Head Synthesis},
|
212 |
+
author = {Ziqiao Peng and Wentao Hu and Yue Shi and Xiangyu Zhu and Xiaomei Zhang and Jun He and Hongyan Liu and Zhaoxin Fan},
|
213 |
+
booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
|
214 |
+
month = {June},
|
215 |
+
year = {2024},
|
216 |
+
}
|
217 |
+
```
|
218 |
+
|
219 |
+
## Acknowledgement
|
220 |
+
This code is developed heavily relying on [ER-NeRF](https://github.com/Fictionarry/ER-NeRF), and also [RAD-NeRF](https://github.com/ashawkey/RAD-NeRF), [GeneFace](https://github.com/yerfor/GeneFace), [DFRF](https://github.com/sstzal/DFRF), [DFA-NeRF](https://github.com/ShunyuYao/DFA-NeRF/), [AD-NeRF](https://github.com/YudongGuo/AD-NeRF), and [Deep3DFaceRecon_pytorch](https://github.com/sicxu/Deep3DFaceRecon_pytorch).
|
221 |
+
|
222 |
+
Thanks for these great projects. Thanks to [Tiandishihua](https://github.com/Tiandishihua) for helping us fix the bug that loss equals NaN.
|
223 |
+
|
224 |
+
## Disclaimer
|
225 |
+
By using the "SyncTalk", users agree to comply with all applicable laws and regulations, and acknowledge that misuse of the software, including the creation or distribution of harmful content, is strictly prohibited. The developers of the software disclaim all liability for any direct, indirect, or consequential damages arising from the use or misuse of the software.
|
assets/image/synctalk.png
ADDED
Git LFS Details
|
data/Please place the data file.
ADDED
File without changes
|
data_utils/UNFaceFlow/core/__init__.py
ADDED
File without changes
|
data_utils/UNFaceFlow/core/corr.py
ADDED
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
from utils_core.utils import bilinear_sampler, coords_grid
|
4 |
+
|
5 |
+
try:
|
6 |
+
import alt_cuda_corr
|
7 |
+
except:
|
8 |
+
# alt_cuda_corr is not compiled
|
9 |
+
pass
|
10 |
+
|
11 |
+
|
12 |
+
class CorrBlock:
|
13 |
+
def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
|
14 |
+
self.num_levels = num_levels
|
15 |
+
self.radius = radius
|
16 |
+
self.corr_pyramid = []
|
17 |
+
|
18 |
+
# all pairs correlation
|
19 |
+
corr = CorrBlock.corr(fmap1, fmap2)
|
20 |
+
|
21 |
+
batch, h1, w1, dim, h2, w2 = corr.shape
|
22 |
+
corr = corr.reshape(batch*h1*w1, dim, h2, w2)
|
23 |
+
|
24 |
+
self.corr_pyramid.append(corr)
|
25 |
+
for i in range(self.num_levels-1):
|
26 |
+
corr = F.avg_pool2d(corr, 2, stride=2)
|
27 |
+
self.corr_pyramid.append(corr)
|
28 |
+
|
29 |
+
def __call__(self, coords):
|
30 |
+
r = self.radius
|
31 |
+
coords = coords.permute(0, 2, 3, 1)
|
32 |
+
batch, h1, w1, _ = coords.shape
|
33 |
+
|
34 |
+
out_pyramid = []
|
35 |
+
for i in range(self.num_levels):
|
36 |
+
corr = self.corr_pyramid[i]
|
37 |
+
dx = torch.linspace(-r, r, 2*r+1)
|
38 |
+
dy = torch.linspace(-r, r, 2*r+1)
|
39 |
+
delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(coords.device)
|
40 |
+
|
41 |
+
centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i
|
42 |
+
delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2)
|
43 |
+
coords_lvl = centroid_lvl + delta_lvl
|
44 |
+
|
45 |
+
corr = bilinear_sampler(corr, coords_lvl)
|
46 |
+
corr = corr.view(batch, h1, w1, -1)
|
47 |
+
out_pyramid.append(corr)
|
48 |
+
|
49 |
+
out = torch.cat(out_pyramid, dim=-1)
|
50 |
+
return out.permute(0, 3, 1, 2).contiguous().float()
|
51 |
+
|
52 |
+
@staticmethod
|
53 |
+
def corr(fmap1, fmap2):
|
54 |
+
batch, dim, ht, wd = fmap1.shape
|
55 |
+
fmap1 = fmap1.view(batch, dim, ht*wd)
|
56 |
+
fmap2 = fmap2.view(batch, dim, ht*wd)
|
57 |
+
|
58 |
+
corr = torch.matmul(fmap1.transpose(1,2), fmap2)
|
59 |
+
corr = corr.view(batch, ht, wd, 1, ht, wd)
|
60 |
+
return corr / torch.sqrt(torch.tensor(dim).float())
|
61 |
+
|
62 |
+
|
63 |
+
class AlternateCorrBlock:
|
64 |
+
def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
|
65 |
+
self.num_levels = num_levels
|
66 |
+
self.radius = radius
|
67 |
+
|
68 |
+
self.pyramid = [(fmap1, fmap2)]
|
69 |
+
for i in range(self.num_levels):
|
70 |
+
fmap1 = F.avg_pool2d(fmap1, 2, stride=2)
|
71 |
+
fmap2 = F.avg_pool2d(fmap2, 2, stride=2)
|
72 |
+
self.pyramid.append((fmap1, fmap2))
|
73 |
+
|
74 |
+
def __call__(self, coords):
|
75 |
+
coords = coords.permute(0, 2, 3, 1)
|
76 |
+
B, H, W, _ = coords.shape
|
77 |
+
dim = self.pyramid[0][0].shape[1]
|
78 |
+
|
79 |
+
corr_list = []
|
80 |
+
for i in range(self.num_levels):
|
81 |
+
r = self.radius
|
82 |
+
fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1).contiguous()
|
83 |
+
fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous()
|
84 |
+
|
85 |
+
coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous()
|
86 |
+
corr, = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r)
|
87 |
+
corr_list.append(corr.squeeze(1))
|
88 |
+
|
89 |
+
corr = torch.stack(corr_list, dim=1)
|
90 |
+
corr = corr.reshape(B, -1, H, W)
|
91 |
+
return corr / torch.sqrt(torch.tensor(dim).float())
|
data_utils/UNFaceFlow/core/datasets.py
ADDED
@@ -0,0 +1,235 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Data loading based on https://github.com/NVIDIA/flownet2-pytorch
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
import torch.utils.data as data
|
6 |
+
import torch.nn.functional as F
|
7 |
+
|
8 |
+
import os
|
9 |
+
import math
|
10 |
+
import random
|
11 |
+
from glob import glob
|
12 |
+
import os.path as osp
|
13 |
+
|
14 |
+
from utils import frame_utils
|
15 |
+
from utils.augmentor import FlowAugmentor, SparseFlowAugmentor
|
16 |
+
|
17 |
+
|
18 |
+
class FlowDataset(data.Dataset):
|
19 |
+
def __init__(self, aug_params=None, sparse=False):
|
20 |
+
self.augmentor = None
|
21 |
+
self.sparse = sparse
|
22 |
+
if aug_params is not None:
|
23 |
+
if sparse:
|
24 |
+
self.augmentor = SparseFlowAugmentor(**aug_params)
|
25 |
+
else:
|
26 |
+
self.augmentor = FlowAugmentor(**aug_params)
|
27 |
+
|
28 |
+
self.is_test = False
|
29 |
+
self.init_seed = False
|
30 |
+
self.flow_list = []
|
31 |
+
self.image_list = []
|
32 |
+
self.extra_info = []
|
33 |
+
|
34 |
+
def __getitem__(self, index):
|
35 |
+
|
36 |
+
if self.is_test:
|
37 |
+
img1 = frame_utils.read_gen(self.image_list[index][0])
|
38 |
+
img2 = frame_utils.read_gen(self.image_list[index][1])
|
39 |
+
img1 = np.array(img1).astype(np.uint8)[..., :3]
|
40 |
+
img2 = np.array(img2).astype(np.uint8)[..., :3]
|
41 |
+
img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
|
42 |
+
img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
|
43 |
+
return img1, img2, self.extra_info[index]
|
44 |
+
|
45 |
+
if not self.init_seed:
|
46 |
+
worker_info = torch.utils.data.get_worker_info()
|
47 |
+
if worker_info is not None:
|
48 |
+
torch.manual_seed(worker_info.id)
|
49 |
+
np.random.seed(worker_info.id)
|
50 |
+
random.seed(worker_info.id)
|
51 |
+
self.init_seed = True
|
52 |
+
|
53 |
+
index = index % len(self.image_list)
|
54 |
+
valid = None
|
55 |
+
if self.sparse:
|
56 |
+
flow, valid = frame_utils.readFlowKITTI(self.flow_list[index])
|
57 |
+
else:
|
58 |
+
flow = frame_utils.read_gen(self.flow_list[index])
|
59 |
+
|
60 |
+
img1 = frame_utils.read_gen(self.image_list[index][0])
|
61 |
+
img2 = frame_utils.read_gen(self.image_list[index][1])
|
62 |
+
|
63 |
+
flow = np.array(flow).astype(np.float32)
|
64 |
+
img1 = np.array(img1).astype(np.uint8)
|
65 |
+
img2 = np.array(img2).astype(np.uint8)
|
66 |
+
|
67 |
+
# grayscale images
|
68 |
+
if len(img1.shape) == 2:
|
69 |
+
img1 = np.tile(img1[...,None], (1, 1, 3))
|
70 |
+
img2 = np.tile(img2[...,None], (1, 1, 3))
|
71 |
+
else:
|
72 |
+
img1 = img1[..., :3]
|
73 |
+
img2 = img2[..., :3]
|
74 |
+
|
75 |
+
if self.augmentor is not None:
|
76 |
+
if self.sparse:
|
77 |
+
img1, img2, flow, valid = self.augmentor(img1, img2, flow, valid)
|
78 |
+
else:
|
79 |
+
img1, img2, flow = self.augmentor(img1, img2, flow)
|
80 |
+
|
81 |
+
img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
|
82 |
+
img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
|
83 |
+
flow = torch.from_numpy(flow).permute(2, 0, 1).float()
|
84 |
+
|
85 |
+
if valid is not None:
|
86 |
+
valid = torch.from_numpy(valid)
|
87 |
+
else:
|
88 |
+
valid = (flow[0].abs() < 1000) & (flow[1].abs() < 1000)
|
89 |
+
|
90 |
+
return img1, img2, flow, valid.float()
|
91 |
+
|
92 |
+
|
93 |
+
def __rmul__(self, v):
|
94 |
+
self.flow_list = v * self.flow_list
|
95 |
+
self.image_list = v * self.image_list
|
96 |
+
return self
|
97 |
+
|
98 |
+
def __len__(self):
|
99 |
+
return len(self.image_list)
|
100 |
+
|
101 |
+
|
102 |
+
class MpiSintel(FlowDataset):
|
103 |
+
def __init__(self, aug_params=None, split='training', root='datasets/Sintel', dstype='clean'):
|
104 |
+
super(MpiSintel, self).__init__(aug_params)
|
105 |
+
flow_root = osp.join(root, split, 'flow')
|
106 |
+
image_root = osp.join(root, split, dstype)
|
107 |
+
|
108 |
+
if split == 'test':
|
109 |
+
self.is_test = True
|
110 |
+
|
111 |
+
for scene in os.listdir(image_root):
|
112 |
+
image_list = sorted(glob(osp.join(image_root, scene, '*.png')))
|
113 |
+
for i in range(len(image_list)-1):
|
114 |
+
self.image_list += [ [image_list[i], image_list[i+1]] ]
|
115 |
+
self.extra_info += [ (scene, i) ] # scene and frame_id
|
116 |
+
|
117 |
+
if split != 'test':
|
118 |
+
self.flow_list += sorted(glob(osp.join(flow_root, scene, '*.flo')))
|
119 |
+
|
120 |
+
|
121 |
+
class FlyingChairs(FlowDataset):
|
122 |
+
def __init__(self, aug_params=None, split='train', root='datasets/FlyingChairs_release/data'):
|
123 |
+
super(FlyingChairs, self).__init__(aug_params)
|
124 |
+
|
125 |
+
images = sorted(glob(osp.join(root, '*.ppm')))
|
126 |
+
flows = sorted(glob(osp.join(root, '*.flo')))
|
127 |
+
assert (len(images)//2 == len(flows))
|
128 |
+
|
129 |
+
split_list = np.loadtxt('chairs_split.txt', dtype=np.int32)
|
130 |
+
for i in range(len(flows)):
|
131 |
+
xid = split_list[i]
|
132 |
+
if (split=='training' and xid==1) or (split=='validation' and xid==2):
|
133 |
+
self.flow_list += [ flows[i] ]
|
134 |
+
self.image_list += [ [images[2*i], images[2*i+1]] ]
|
135 |
+
|
136 |
+
|
137 |
+
class FlyingThings3D(FlowDataset):
|
138 |
+
def __init__(self, aug_params=None, root='datasets/FlyingThings3D', dstype='frames_cleanpass'):
|
139 |
+
super(FlyingThings3D, self).__init__(aug_params)
|
140 |
+
|
141 |
+
for cam in ['left']:
|
142 |
+
for direction in ['into_future', 'into_past']:
|
143 |
+
image_dirs = sorted(glob(osp.join(root, dstype, 'TRAIN/*/*')))
|
144 |
+
image_dirs = sorted([osp.join(f, cam) for f in image_dirs])
|
145 |
+
|
146 |
+
flow_dirs = sorted(glob(osp.join(root, 'optical_flow/TRAIN/*/*')))
|
147 |
+
flow_dirs = sorted([osp.join(f, direction, cam) for f in flow_dirs])
|
148 |
+
|
149 |
+
for idir, fdir in zip(image_dirs, flow_dirs):
|
150 |
+
images = sorted(glob(osp.join(idir, '*.png')) )
|
151 |
+
flows = sorted(glob(osp.join(fdir, '*.pfm')) )
|
152 |
+
for i in range(len(flows)-1):
|
153 |
+
if direction == 'into_future':
|
154 |
+
self.image_list += [ [images[i], images[i+1]] ]
|
155 |
+
self.flow_list += [ flows[i] ]
|
156 |
+
elif direction == 'into_past':
|
157 |
+
self.image_list += [ [images[i+1], images[i]] ]
|
158 |
+
self.flow_list += [ flows[i+1] ]
|
159 |
+
|
160 |
+
|
161 |
+
class KITTI(FlowDataset):
|
162 |
+
def __init__(self, aug_params=None, split='training', root='datasets/KITTI'):
|
163 |
+
super(KITTI, self).__init__(aug_params, sparse=True)
|
164 |
+
if split == 'testing':
|
165 |
+
self.is_test = True
|
166 |
+
|
167 |
+
root = osp.join(root, split)
|
168 |
+
images1 = sorted(glob(osp.join(root, 'image_2/*_10.png')))
|
169 |
+
images2 = sorted(glob(osp.join(root, 'image_2/*_11.png')))
|
170 |
+
|
171 |
+
for img1, img2 in zip(images1, images2):
|
172 |
+
frame_id = img1.split('/')[-1]
|
173 |
+
self.extra_info += [ [frame_id] ]
|
174 |
+
self.image_list += [ [img1, img2] ]
|
175 |
+
|
176 |
+
if split == 'training':
|
177 |
+
self.flow_list = sorted(glob(osp.join(root, 'flow_occ/*_10.png')))
|
178 |
+
|
179 |
+
|
180 |
+
class HD1K(FlowDataset):
|
181 |
+
def __init__(self, aug_params=None, root='datasets/HD1k'):
|
182 |
+
super(HD1K, self).__init__(aug_params, sparse=True)
|
183 |
+
|
184 |
+
seq_ix = 0
|
185 |
+
while 1:
|
186 |
+
flows = sorted(glob(os.path.join(root, 'hd1k_flow_gt', 'flow_occ/%06d_*.png' % seq_ix)))
|
187 |
+
images = sorted(glob(os.path.join(root, 'hd1k_input', 'image_2/%06d_*.png' % seq_ix)))
|
188 |
+
|
189 |
+
if len(flows) == 0:
|
190 |
+
break
|
191 |
+
|
192 |
+
for i in range(len(flows)-1):
|
193 |
+
self.flow_list += [flows[i]]
|
194 |
+
self.image_list += [ [images[i], images[i+1]] ]
|
195 |
+
|
196 |
+
seq_ix += 1
|
197 |
+
|
198 |
+
|
199 |
+
def fetch_dataloader(args, TRAIN_DS='C+T+K+S+H'):
|
200 |
+
""" Create the data loader for the corresponding trainign set """
|
201 |
+
|
202 |
+
if args.stage == 'chairs':
|
203 |
+
aug_params = {'crop_size': args.image_size, 'min_scale': -0.1, 'max_scale': 1.0, 'do_flip': True}
|
204 |
+
train_dataset = FlyingChairs(aug_params, split='training')
|
205 |
+
|
206 |
+
elif args.stage == 'things':
|
207 |
+
aug_params = {'crop_size': args.image_size, 'min_scale': -0.4, 'max_scale': 0.8, 'do_flip': True}
|
208 |
+
clean_dataset = FlyingThings3D(aug_params, dstype='frames_cleanpass')
|
209 |
+
final_dataset = FlyingThings3D(aug_params, dstype='frames_finalpass')
|
210 |
+
train_dataset = clean_dataset + final_dataset
|
211 |
+
|
212 |
+
elif args.stage == 'sintel':
|
213 |
+
aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.6, 'do_flip': True}
|
214 |
+
things = FlyingThings3D(aug_params, dstype='frames_cleanpass')
|
215 |
+
sintel_clean = MpiSintel(aug_params, split='training', dstype='clean')
|
216 |
+
sintel_final = MpiSintel(aug_params, split='training', dstype='final')
|
217 |
+
|
218 |
+
if TRAIN_DS == 'C+T+K+S+H':
|
219 |
+
kitti = KITTI({'crop_size': args.image_size, 'min_scale': -0.3, 'max_scale': 0.5, 'do_flip': True})
|
220 |
+
hd1k = HD1K({'crop_size': args.image_size, 'min_scale': -0.5, 'max_scale': 0.2, 'do_flip': True})
|
221 |
+
train_dataset = 100*sintel_clean + 100*sintel_final + 200*kitti + 5*hd1k + things
|
222 |
+
|
223 |
+
elif TRAIN_DS == 'C+T+K/S':
|
224 |
+
train_dataset = 100*sintel_clean + 100*sintel_final + things
|
225 |
+
|
226 |
+
elif args.stage == 'kitti':
|
227 |
+
aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.4, 'do_flip': False}
|
228 |
+
train_dataset = KITTI(aug_params, split='training')
|
229 |
+
|
230 |
+
train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size,
|
231 |
+
pin_memory=False, shuffle=True, num_workers=4, drop_last=True)
|
232 |
+
|
233 |
+
print('Training with %d image pairs' % len(train_dataset))
|
234 |
+
return train_loader
|
235 |
+
|
data_utils/UNFaceFlow/core/extractor.py
ADDED
@@ -0,0 +1,266 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
|
6 |
+
class ResidualBlock(nn.Module):
|
7 |
+
def __init__(self, in_planes, planes, norm_fn='group', stride=1):
|
8 |
+
super(ResidualBlock, self).__init__()
|
9 |
+
|
10 |
+
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride)
|
11 |
+
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
|
12 |
+
self.relu = nn.ReLU(inplace=True)
|
13 |
+
|
14 |
+
num_groups = planes // 8
|
15 |
+
|
16 |
+
if norm_fn == 'group':
|
17 |
+
self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
|
18 |
+
self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
|
19 |
+
if not stride == 1:
|
20 |
+
self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
|
21 |
+
|
22 |
+
elif norm_fn == 'batch':
|
23 |
+
self.norm1 = nn.BatchNorm2d(planes)
|
24 |
+
self.norm2 = nn.BatchNorm2d(planes)
|
25 |
+
if not stride == 1:
|
26 |
+
self.norm3 = nn.BatchNorm2d(planes)
|
27 |
+
|
28 |
+
elif norm_fn == 'instance':
|
29 |
+
self.norm1 = nn.InstanceNorm2d(planes)
|
30 |
+
self.norm2 = nn.InstanceNorm2d(planes)
|
31 |
+
if not stride == 1:
|
32 |
+
self.norm3 = nn.InstanceNorm2d(planes)
|
33 |
+
|
34 |
+
elif norm_fn == 'none':
|
35 |
+
self.norm1 = nn.Sequential()
|
36 |
+
self.norm2 = nn.Sequential()
|
37 |
+
if not stride == 1:
|
38 |
+
self.norm3 = nn.Sequential()
|
39 |
+
|
40 |
+
if stride == 1:
|
41 |
+
self.downsample = None
|
42 |
+
|
43 |
+
else:
|
44 |
+
self.downsample = nn.Sequential(
|
45 |
+
nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)
|
46 |
+
|
47 |
+
|
48 |
+
def forward(self, x):
|
49 |
+
y = x
|
50 |
+
y = self.relu(self.norm1(self.conv1(y)))
|
51 |
+
y = self.relu(self.norm2(self.conv2(y)))
|
52 |
+
|
53 |
+
if self.downsample is not None:
|
54 |
+
x = self.downsample(x)
|
55 |
+
|
56 |
+
return self.relu(x+y)
|
57 |
+
|
58 |
+
|
59 |
+
|
60 |
+
class BottleneckBlock(nn.Module):
|
61 |
+
def __init__(self, in_planes, planes, norm_fn='group', stride=1):
|
62 |
+
super(BottleneckBlock, self).__init__()
|
63 |
+
|
64 |
+
self.conv1 = nn.Conv2d(in_planes, planes//4, kernel_size=1, padding=0)
|
65 |
+
self.conv2 = nn.Conv2d(planes//4, planes//4, kernel_size=3, padding=1, stride=stride)
|
66 |
+
self.conv3 = nn.Conv2d(planes//4, planes, kernel_size=1, padding=0)
|
67 |
+
self.relu = nn.ReLU(inplace=True)
|
68 |
+
|
69 |
+
num_groups = planes // 8
|
70 |
+
|
71 |
+
if norm_fn == 'group':
|
72 |
+
self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4)
|
73 |
+
self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4)
|
74 |
+
self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
|
75 |
+
if not stride == 1:
|
76 |
+
self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
|
77 |
+
|
78 |
+
elif norm_fn == 'batch':
|
79 |
+
self.norm1 = nn.BatchNorm2d(planes//4)
|
80 |
+
self.norm2 = nn.BatchNorm2d(planes//4)
|
81 |
+
self.norm3 = nn.BatchNorm2d(planes)
|
82 |
+
if not stride == 1:
|
83 |
+
self.norm4 = nn.BatchNorm2d(planes)
|
84 |
+
|
85 |
+
elif norm_fn == 'instance':
|
86 |
+
self.norm1 = nn.InstanceNorm2d(planes//4)
|
87 |
+
self.norm2 = nn.InstanceNorm2d(planes//4)
|
88 |
+
self.norm3 = nn.InstanceNorm2d(planes)
|
89 |
+
if not stride == 1:
|
90 |
+
self.norm4 = nn.InstanceNorm2d(planes)
|
91 |
+
|
92 |
+
elif norm_fn == 'none':
|
93 |
+
self.norm1 = nn.Sequential()
|
94 |
+
self.norm2 = nn.Sequential()
|
95 |
+
self.norm3 = nn.Sequential()
|
96 |
+
if not stride == 1:
|
97 |
+
self.norm4 = nn.Sequential()
|
98 |
+
|
99 |
+
if stride == 1:
|
100 |
+
self.downsample = None
|
101 |
+
|
102 |
+
else:
|
103 |
+
self.downsample = nn.Sequential(
|
104 |
+
nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4)
|
105 |
+
|
106 |
+
|
107 |
+
def forward(self, x):
|
108 |
+
y = x
|
109 |
+
y = self.relu(self.norm1(self.conv1(y)))
|
110 |
+
y = self.relu(self.norm2(self.conv2(y)))
|
111 |
+
y = self.relu(self.norm3(self.conv3(y)))
|
112 |
+
|
113 |
+
if self.downsample is not None:
|
114 |
+
x = self.downsample(x)
|
115 |
+
|
116 |
+
return self.relu(x+y)
|
117 |
+
|
118 |
+
class BasicEncoder(nn.Module):
|
119 |
+
def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
|
120 |
+
super(BasicEncoder, self).__init__()
|
121 |
+
self.norm_fn = norm_fn
|
122 |
+
|
123 |
+
if self.norm_fn == 'group':
|
124 |
+
self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)
|
125 |
+
|
126 |
+
elif self.norm_fn == 'batch':
|
127 |
+
self.norm1 = nn.BatchNorm2d(64)
|
128 |
+
|
129 |
+
elif self.norm_fn == 'instance':
|
130 |
+
self.norm1 = nn.InstanceNorm2d(64)
|
131 |
+
|
132 |
+
elif self.norm_fn == 'none':
|
133 |
+
self.norm1 = nn.Sequential()
|
134 |
+
|
135 |
+
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
|
136 |
+
self.relu1 = nn.ReLU(inplace=True)
|
137 |
+
|
138 |
+
self.in_planes = 64
|
139 |
+
self.layer1 = self._make_layer(64, stride=1)
|
140 |
+
self.layer2 = self._make_layer(96, stride=2)
|
141 |
+
self.layer3 = self._make_layer(128, stride=2)
|
142 |
+
|
143 |
+
# output convolution
|
144 |
+
self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1)
|
145 |
+
|
146 |
+
self.dropout = None
|
147 |
+
if dropout > 0:
|
148 |
+
self.dropout = nn.Dropout2d(p=dropout)
|
149 |
+
|
150 |
+
for m in self.modules():
|
151 |
+
if isinstance(m, nn.Conv2d):
|
152 |
+
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
153 |
+
elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
|
154 |
+
if m.weight is not None:
|
155 |
+
nn.init.constant_(m.weight, 1)
|
156 |
+
if m.bias is not None:
|
157 |
+
nn.init.constant_(m.bias, 0)
|
158 |
+
|
159 |
+
def _make_layer(self, dim, stride=1):
|
160 |
+
layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
|
161 |
+
layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
|
162 |
+
layers = (layer1, layer2)
|
163 |
+
|
164 |
+
self.in_planes = dim
|
165 |
+
return nn.Sequential(*layers)
|
166 |
+
|
167 |
+
|
168 |
+
def forward(self, x):
|
169 |
+
|
170 |
+
# if input is list, combine batch dimension
|
171 |
+
is_list = isinstance(x, tuple) or isinstance(x, list)
|
172 |
+
if is_list:
|
173 |
+
batch_dim = x[0].shape[0]
|
174 |
+
x = torch.cat(x, dim=0)
|
175 |
+
|
176 |
+
x = self.conv1(x)
|
177 |
+
x = self.norm1(x)
|
178 |
+
x = self.relu1(x)
|
179 |
+
|
180 |
+
x = self.layer1(x)
|
181 |
+
x = self.layer2(x)
|
182 |
+
x = self.layer3(x)
|
183 |
+
|
184 |
+
x = self.conv2(x)
|
185 |
+
|
186 |
+
if self.training and self.dropout is not None:
|
187 |
+
x = self.dropout(x)
|
188 |
+
|
189 |
+
if is_list:
|
190 |
+
x = torch.split(x, [batch_dim, batch_dim], dim=0)
|
191 |
+
|
192 |
+
return x
|
193 |
+
|
194 |
+
|
195 |
+
class SmallEncoder(nn.Module):
|
196 |
+
def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
|
197 |
+
super(SmallEncoder, self).__init__()
|
198 |
+
self.norm_fn = norm_fn
|
199 |
+
|
200 |
+
if self.norm_fn == 'group':
|
201 |
+
self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32)
|
202 |
+
|
203 |
+
elif self.norm_fn == 'batch':
|
204 |
+
self.norm1 = nn.BatchNorm2d(32)
|
205 |
+
|
206 |
+
elif self.norm_fn == 'instance':
|
207 |
+
self.norm1 = nn.InstanceNorm2d(32)
|
208 |
+
|
209 |
+
elif self.norm_fn == 'none':
|
210 |
+
self.norm1 = nn.Sequential()
|
211 |
+
|
212 |
+
self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3)
|
213 |
+
self.relu1 = nn.ReLU(inplace=True)
|
214 |
+
|
215 |
+
self.in_planes = 32
|
216 |
+
self.layer1 = self._make_layer(32, stride=1)
|
217 |
+
self.layer2 = self._make_layer(64, stride=2)
|
218 |
+
self.layer3 = self._make_layer(96, stride=2)
|
219 |
+
|
220 |
+
self.dropout = None
|
221 |
+
if dropout > 0:
|
222 |
+
self.dropout = nn.Dropout2d(p=dropout)
|
223 |
+
|
224 |
+
self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1)
|
225 |
+
|
226 |
+
for m in self.modules():
|
227 |
+
if isinstance(m, nn.Conv2d):
|
228 |
+
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
229 |
+
elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
|
230 |
+
if m.weight is not None:
|
231 |
+
nn.init.constant_(m.weight, 1)
|
232 |
+
if m.bias is not None:
|
233 |
+
nn.init.constant_(m.bias, 0)
|
234 |
+
|
235 |
+
def _make_layer(self, dim, stride=1):
|
236 |
+
layer1 = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride)
|
237 |
+
layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1)
|
238 |
+
layers = (layer1, layer2)
|
239 |
+
|
240 |
+
self.in_planes = dim
|
241 |
+
return nn.Sequential(*layers)
|
242 |
+
|
243 |
+
|
244 |
+
def forward(self, x):
|
245 |
+
|
246 |
+
# if input is list, combine batch dimension
|
247 |
+
is_list = isinstance(x, tuple) or isinstance(x, list)
|
248 |
+
if is_list:
|
249 |
+
batch_dim = x[0].shape[0]
|
250 |
+
x = torch.cat(x, dim=0)
|
251 |
+
|
252 |
+
x = self.conv1(x)
|
253 |
+
x = self.norm1(x)
|
254 |
+
x = self.relu1(x)
|
255 |
+
x = self.layer1(x)
|
256 |
+
x = self.layer2(x)
|
257 |
+
x = self.layer3(x)
|
258 |
+
x = self.conv2(x)
|
259 |
+
|
260 |
+
if self.training and self.dropout is not None:
|
261 |
+
x = self.dropout(x)
|
262 |
+
|
263 |
+
if is_list:
|
264 |
+
x = torch.split(x, [batch_dim, batch_dim], dim=0)
|
265 |
+
|
266 |
+
return x
|
data_utils/UNFaceFlow/core/nnutils.py
ADDED
@@ -0,0 +1,233 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys, os
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
|
5 |
+
|
6 |
+
def make_conv(n_in, n_out, n_blocks, kernel=3, normalization=torch.nn.BatchNorm3d, activation=torch.nn.ReLU):
|
7 |
+
blocks = []
|
8 |
+
for i in range(n_blocks):
|
9 |
+
in1 = n_in if i == 0 else n_out
|
10 |
+
blocks.append(torch.nn.Sequential(
|
11 |
+
torch.nn.Conv3d(in1, n_out, kernel_size=kernel, stride=1, padding=(kernel//2)),
|
12 |
+
normalization(n_out),
|
13 |
+
activation(inplace=True)
|
14 |
+
))
|
15 |
+
return torch.nn.Sequential(*blocks)
|
16 |
+
|
17 |
+
|
18 |
+
def make_conv_2d(n_in, n_out, n_blocks, kernel=3, normalization=torch.nn.BatchNorm2d, activation=torch.nn.ReLU):
|
19 |
+
blocks = []
|
20 |
+
for i in range(n_blocks):
|
21 |
+
in1 = n_in if i == 0 else n_out
|
22 |
+
blocks.append(torch.nn.Sequential(
|
23 |
+
torch.nn.Conv2d(in1, n_out, kernel_size=kernel, stride=1, padding=(kernel//2)),
|
24 |
+
normalization(n_out),
|
25 |
+
activation(inplace=True)
|
26 |
+
))
|
27 |
+
return torch.nn.Sequential(*blocks)
|
28 |
+
|
29 |
+
|
30 |
+
def make_downscale(n_in, n_out, kernel=4, normalization=torch.nn.BatchNorm3d, activation=torch.nn.ReLU):
|
31 |
+
block = torch.nn.Sequential(
|
32 |
+
torch.nn.Conv3d(n_in, n_out, kernel_size=kernel, stride=2, padding=(kernel-2)//2),
|
33 |
+
normalization(n_out),
|
34 |
+
activation(inplace=True)
|
35 |
+
)
|
36 |
+
return block
|
37 |
+
|
38 |
+
|
39 |
+
def make_downscale_2d(n_in, n_out, kernel=4, normalization=torch.nn.BatchNorm2d, activation=torch.nn.ReLU):
|
40 |
+
block = torch.nn.Sequential(
|
41 |
+
torch.nn.Conv2d(n_in, n_out, kernel_size=kernel, stride=2, padding=(kernel-2)//2),
|
42 |
+
normalization(n_out),
|
43 |
+
activation(inplace=True)
|
44 |
+
)
|
45 |
+
return block
|
46 |
+
|
47 |
+
|
48 |
+
def make_upscale(n_in, n_out, normalization=torch.nn.BatchNorm3d, activation=torch.nn.ReLU):
|
49 |
+
block = torch.nn.Sequential(
|
50 |
+
torch.nn.ConvTranspose3d(n_in, n_out, kernel_size=6, stride=2, padding=2),
|
51 |
+
normalization(n_out),
|
52 |
+
activation(inplace=True)
|
53 |
+
)
|
54 |
+
return block
|
55 |
+
|
56 |
+
|
57 |
+
def make_upscale_2d(n_in, n_out, kernel=4, normalization=torch.nn.BatchNorm2d, activation=torch.nn.ReLU):
|
58 |
+
block = torch.nn.Sequential(
|
59 |
+
torch.nn.ConvTranspose2d(n_in, n_out, kernel_size=kernel, stride=2, padding=(kernel-2)//2),
|
60 |
+
normalization(n_out),
|
61 |
+
activation(inplace=True)
|
62 |
+
)
|
63 |
+
return block
|
64 |
+
|
65 |
+
|
66 |
+
class ResBlock(torch.nn.Module):
|
67 |
+
def __init__(self, n_out, kernel=3, normalization=torch.nn.BatchNorm3d, activation=torch.nn.ReLU):
|
68 |
+
super().__init__()
|
69 |
+
self.block0 = torch.nn.Sequential(
|
70 |
+
torch.nn.Conv3d(n_out, n_out, kernel_size=kernel, stride=1, padding=(kernel//2)),
|
71 |
+
normalization(n_out),
|
72 |
+
activation(inplace=True)
|
73 |
+
)
|
74 |
+
|
75 |
+
self.block1 = torch.nn.Sequential(
|
76 |
+
torch.nn.Conv3d(n_out, n_out, kernel_size=kernel, stride=1, padding=(kernel//2)),
|
77 |
+
normalization(n_out),
|
78 |
+
)
|
79 |
+
|
80 |
+
self.block2 = torch.nn.ReLU()
|
81 |
+
|
82 |
+
def forward(self, x0):
|
83 |
+
x = self.block0(x0)
|
84 |
+
|
85 |
+
x = self.block1(x)
|
86 |
+
|
87 |
+
x = self.block2(x + x0)
|
88 |
+
return x
|
89 |
+
|
90 |
+
|
91 |
+
class ResBlock2d(torch.nn.Module):
|
92 |
+
def __init__(self, n_out, kernel=3, normalization=torch.nn.BatchNorm2d, activation=torch.nn.ReLU):
|
93 |
+
super().__init__()
|
94 |
+
self.block0 = torch.nn.Sequential(
|
95 |
+
torch.nn.Conv2d(n_out, n_out, kernel_size=kernel, stride=1, padding=(kernel//2)),
|
96 |
+
normalization(n_out),
|
97 |
+
activation(inplace=True)
|
98 |
+
)
|
99 |
+
|
100 |
+
self.block1 = torch.nn.Sequential(
|
101 |
+
torch.nn.Conv2d(n_out, n_out, kernel_size=kernel, stride=1, padding=(kernel//2)),
|
102 |
+
normalization(n_out),
|
103 |
+
)
|
104 |
+
|
105 |
+
self.block2 = torch.nn.ReLU()
|
106 |
+
|
107 |
+
def forward(self, x0):
|
108 |
+
x = self.block0(x0)
|
109 |
+
|
110 |
+
x = self.block1(x)
|
111 |
+
|
112 |
+
x = self.block2(x + x0)
|
113 |
+
return x
|
114 |
+
|
115 |
+
|
116 |
+
class Identity(torch.nn.Module):
|
117 |
+
def __init__(self, *args, **kwargs):
|
118 |
+
super().__init__()
|
119 |
+
|
120 |
+
def forward(self, x):
|
121 |
+
return x
|
122 |
+
|
123 |
+
|
124 |
+
def downscale_gt_flow(flow_gt, flow_mask, image_height, image_width):
|
125 |
+
flow_gt_copy = flow_gt.clone()
|
126 |
+
flow_mask_copy = flow_mask.clone()
|
127 |
+
|
128 |
+
flow_gt_copy = flow_gt_copy / 20.0
|
129 |
+
flow_mask_copy = flow_mask_copy.float()
|
130 |
+
|
131 |
+
assert image_height % 64 == 0 and image_width % 64 == 0
|
132 |
+
|
133 |
+
flow_gt2 = torch.nn.functional.interpolate(input=flow_gt_copy, size=(image_height//4, image_width//4), mode='nearest')
|
134 |
+
flow_mask2 = torch.nn.functional.interpolate(input=flow_mask_copy, size=(image_height//4, image_width//4), mode='nearest').bool()
|
135 |
+
|
136 |
+
flow_gt3 = torch.nn.functional.interpolate(input=flow_gt_copy, size=(image_height//8, image_width//8), mode='nearest')
|
137 |
+
flow_mask3 = torch.nn.functional.interpolate(input=flow_mask_copy, size=(image_height//8, image_width//8), mode='nearest').bool()
|
138 |
+
|
139 |
+
flow_gt4 = torch.nn.functional.interpolate(input=flow_gt_copy, size=(image_height//16, image_width//16), mode='nearest')
|
140 |
+
flow_mask4 = torch.nn.functional.interpolate(input=flow_mask_copy, size=(image_height//16, image_width//16), mode='nearest').bool()
|
141 |
+
|
142 |
+
flow_gt5 = torch.nn.functional.interpolate(input=flow_gt_copy, size=(image_height//32, image_width//32), mode='nearest')
|
143 |
+
flow_mask5 = torch.nn.functional.interpolate(input=flow_mask_copy, size=(image_height//32, image_width//32), mode='nearest').bool()
|
144 |
+
|
145 |
+
flow_gt6 = torch.nn.functional.interpolate(input=flow_gt_copy, size=(image_height//64, image_width//64), mode='nearest')
|
146 |
+
flow_mask6 = torch.nn.functional.interpolate(input=flow_mask_copy, size=(image_height//64, image_width//64), mode='nearest').bool()
|
147 |
+
|
148 |
+
return [flow_gt2, flow_gt3, flow_gt4, flow_gt5, flow_gt6], [flow_mask2, flow_mask3, flow_mask4, flow_mask5, flow_mask6]
|
149 |
+
|
150 |
+
|
151 |
+
def compute_baseline_mask_gt(
|
152 |
+
xy_coords_warped,
|
153 |
+
target_matches, valid_target_matches,
|
154 |
+
source_points, valid_source_points,
|
155 |
+
scene_flow_gt, scene_flow_mask, target_boundary_mask,
|
156 |
+
max_pos_flowed_source_to_target_dist, min_neg_flowed_source_to_target_dist
|
157 |
+
):
|
158 |
+
# Scene flow mask
|
159 |
+
scene_flow_mask_0 = scene_flow_mask[:, 0].type(torch.bool)
|
160 |
+
|
161 |
+
# Boundary correspondences mask
|
162 |
+
# We use the nearest neighbor interpolation, since the boundary computations
|
163 |
+
# already marks any of 4 pixels as boundary.
|
164 |
+
target_nonboundary_mask = (~target_boundary_mask).type(torch.float32)
|
165 |
+
target_matches_nonboundary_mask = torch.nn.functional.grid_sample(target_nonboundary_mask, xy_coords_warped, padding_mode='zeros', mode='nearest', align_corners=False)
|
166 |
+
target_matches_nonboundary_mask = target_matches_nonboundary_mask[:, 0, :, :] >= 0.999
|
167 |
+
|
168 |
+
# Compute groundtruth mask (oracle)
|
169 |
+
flowed_source_points = source_points + scene_flow_gt
|
170 |
+
dist = torch.norm(flowed_source_points - target_matches, p=2, dim=1)
|
171 |
+
|
172 |
+
# Combine all masks
|
173 |
+
# We mark a correspondence as positive if;
|
174 |
+
# - it is close enough to groundtruth flow
|
175 |
+
# AND
|
176 |
+
# - there exists groundtruth flow
|
177 |
+
# AND
|
178 |
+
# - the target match is valid
|
179 |
+
# AND
|
180 |
+
# - the source point is valid
|
181 |
+
# AND
|
182 |
+
# - the target match is not on the boundary
|
183 |
+
mask_pos_gt = (dist <= max_pos_flowed_source_to_target_dist) & scene_flow_mask_0 & valid_target_matches & valid_source_points & target_matches_nonboundary_mask
|
184 |
+
|
185 |
+
# We mark a correspondence as negative if;
|
186 |
+
# - there exists groundtruth flow AND it is far away enough from the groundtruth flow AND source/target points are valid
|
187 |
+
# OR
|
188 |
+
# - the target match is on the boundary AND there exists groundtruth flow AND source/target points are valid
|
189 |
+
mask_neg_gt = ((dist > min_neg_flowed_source_to_target_dist) & scene_flow_mask_0 & valid_source_points & valid_target_matches) \
|
190 |
+
| (~target_matches_nonboundary_mask & scene_flow_mask_0 & valid_source_points & valid_target_matches)
|
191 |
+
|
192 |
+
# What remains is left undecided (masked out at loss).
|
193 |
+
# For groundtruth mask we set it to zero.
|
194 |
+
valid_mask_pixels = mask_pos_gt | mask_neg_gt
|
195 |
+
mask_gt = mask_pos_gt
|
196 |
+
|
197 |
+
mask_gt = mask_gt.type(torch.float32)
|
198 |
+
|
199 |
+
return mask_gt, valid_mask_pixels
|
200 |
+
|
201 |
+
|
202 |
+
def compute_deformed_points_gt(
|
203 |
+
source_points, scene_flow_gt,
|
204 |
+
valid_solve, valid_correspondences,
|
205 |
+
deformed_points_idxs, deformed_points_subsampled
|
206 |
+
):
|
207 |
+
batch_size = source_points.shape[0]
|
208 |
+
max_warped_points = deformed_points_idxs.shape[1]
|
209 |
+
|
210 |
+
deformed_points_gt = torch.zeros((batch_size, max_warped_points, 3), dtype=source_points.dtype, device=source_points.device)
|
211 |
+
deformed_points_mask = torch.zeros((batch_size, max_warped_points, 3), dtype=source_points.dtype, device=source_points.device)
|
212 |
+
|
213 |
+
for i in range(batch_size):
|
214 |
+
if valid_solve[i]:
|
215 |
+
valid_correspondences_idxs = torch.where(valid_correspondences[i])
|
216 |
+
|
217 |
+
# Compute deformed point groundtruth.
|
218 |
+
deformed_points_i_gt = source_points[i] + scene_flow_gt[i]
|
219 |
+
deformed_points_i_gt = deformed_points_i_gt.permute(1, 2, 0)
|
220 |
+
deformed_points_i_gt = deformed_points_i_gt[valid_correspondences_idxs[0], valid_correspondences_idxs[1], :].view(-1, 3, 1)
|
221 |
+
|
222 |
+
# Filter out points randomly, if too many are still left.
|
223 |
+
if deformed_points_subsampled[i]:
|
224 |
+
sampled_idxs_i = deformed_points_idxs[i]
|
225 |
+
deformed_points_i_gt = deformed_points_i_gt[sampled_idxs_i]
|
226 |
+
|
227 |
+
num_points = deformed_points_i_gt.shape[0]
|
228 |
+
|
229 |
+
# Store the results.
|
230 |
+
deformed_points_gt[i, :num_points, :] = deformed_points_i_gt.view(1, num_points, 3)
|
231 |
+
deformed_points_mask[i, :num_points, :] = 1
|
232 |
+
|
233 |
+
return deformed_points_gt, deformed_points_mask
|
data_utils/UNFaceFlow/core/raft.py
ADDED
@@ -0,0 +1,259 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
|
6 |
+
from update import BasicUpdateBlock, SmallUpdateBlock
|
7 |
+
from extractor import BasicEncoder, SmallEncoder
|
8 |
+
from corr import CorrBlock, AlternateCorrBlock
|
9 |
+
from utils_core.utils import bilinear_sampler, coords_grid, upflow8
|
10 |
+
|
11 |
+
try:
|
12 |
+
autocast = torch.cuda.amp.autocast
|
13 |
+
except:
|
14 |
+
# dummy autocast for PyTorch < 1.6
|
15 |
+
class autocast:
|
16 |
+
def __init__(self, enabled):
|
17 |
+
pass
|
18 |
+
def __enter__(self):
|
19 |
+
pass
|
20 |
+
def __exit__(self, *args):
|
21 |
+
pass
|
22 |
+
|
23 |
+
|
24 |
+
class RAFT(nn.Module):
|
25 |
+
def __init__(self, args):
|
26 |
+
super(RAFT, self).__init__()
|
27 |
+
self.args = args
|
28 |
+
|
29 |
+
if args.small:
|
30 |
+
self.hidden_dim = hdim = 96
|
31 |
+
self.context_dim = cdim = 64
|
32 |
+
args.corr_levels = 4
|
33 |
+
args.corr_radius = 3
|
34 |
+
|
35 |
+
else:
|
36 |
+
self.hidden_dim = hdim = 128
|
37 |
+
self.context_dim = cdim = 128
|
38 |
+
args.corr_levels = 4
|
39 |
+
args.corr_radius = 4
|
40 |
+
|
41 |
+
if 'dropout' not in self.args:
|
42 |
+
self.args.dropout = 0
|
43 |
+
|
44 |
+
if 'alternate_corr' not in self.args:
|
45 |
+
self.args.alternate_corr = False
|
46 |
+
|
47 |
+
# feature network, context network, and update block
|
48 |
+
if args.small:
|
49 |
+
self.fnet = SmallEncoder(output_dim=128, norm_fn='instance', dropout=args.dropout)
|
50 |
+
self.cnet = SmallEncoder(output_dim=hdim+cdim, norm_fn='none', dropout=args.dropout)
|
51 |
+
self.update_block = SmallUpdateBlock(self.args, hidden_dim=hdim)
|
52 |
+
|
53 |
+
else:
|
54 |
+
self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=args.dropout)
|
55 |
+
self.cnet = BasicEncoder(output_dim=hdim+cdim, norm_fn='batch', dropout=args.dropout)
|
56 |
+
self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim)
|
57 |
+
|
58 |
+
def freeze_bn(self):
|
59 |
+
for m in self.modules():
|
60 |
+
if isinstance(m, nn.BatchNorm2d):
|
61 |
+
m.eval()
|
62 |
+
|
63 |
+
def initialize_flow(self, img):
|
64 |
+
""" Flow is represented as difference between two coordinate grids flow = coords1 - coords0"""
|
65 |
+
N, C, H, W = img.shape
|
66 |
+
coords0 = coords_grid(N, H//8, W//8).to(img.device)
|
67 |
+
coords1 = coords_grid(N, H//8, W//8).to(img.device)
|
68 |
+
|
69 |
+
# optical flow computed as difference: flow = coords1 - coords0
|
70 |
+
return coords0, coords1
|
71 |
+
|
72 |
+
def upsample_flow(self, flow, mask):
|
73 |
+
""" Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """
|
74 |
+
N, _, H, W = flow.shape
|
75 |
+
mask = mask.view(N, 1, 9, 8, 8, H, W)
|
76 |
+
mask = torch.softmax(mask, dim=2)
|
77 |
+
|
78 |
+
up_flow = F.unfold(8 * flow, [3,3], padding=1)
|
79 |
+
up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)
|
80 |
+
|
81 |
+
up_flow = torch.sum(mask * up_flow, dim=2)
|
82 |
+
up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
|
83 |
+
return up_flow.reshape(N, 2, 8*H, 8*W)
|
84 |
+
|
85 |
+
|
86 |
+
def forward(self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False):
|
87 |
+
""" Estimate optical flow between pair of frames """
|
88 |
+
|
89 |
+
image1 = 2 * (image1 / 255.0) - 1.0
|
90 |
+
image2 = 2 * (image2 / 255.0) - 1.0
|
91 |
+
|
92 |
+
image1 = image1.contiguous()
|
93 |
+
image2 = image2.contiguous()
|
94 |
+
|
95 |
+
hdim = self.hidden_dim
|
96 |
+
cdim = self.context_dim
|
97 |
+
|
98 |
+
# run the feature network
|
99 |
+
with autocast(enabled=self.args.mixed_precision):
|
100 |
+
fmap1, fmap2 = self.fnet([image1, image2])
|
101 |
+
|
102 |
+
fmap1 = fmap1.float()
|
103 |
+
fmap2 = fmap2.float()
|
104 |
+
# print("fmap mean: ", fmap1.mean(), fmap2.mean())
|
105 |
+
if self.args.alternate_corr:
|
106 |
+
corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
|
107 |
+
else:
|
108 |
+
corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
|
109 |
+
|
110 |
+
# run the context network
|
111 |
+
with autocast(enabled=self.args.mixed_precision):
|
112 |
+
cnet = self.cnet(image1)
|
113 |
+
net, inp = torch.split(cnet, [hdim, cdim], dim=1)
|
114 |
+
net = torch.tanh(net)
|
115 |
+
inp = torch.relu(inp)
|
116 |
+
|
117 |
+
coords0, coords1 = self.initialize_flow(image1)
|
118 |
+
|
119 |
+
if flow_init is not None:
|
120 |
+
coords1 = coords1 + flow_init
|
121 |
+
|
122 |
+
flow_predictions = []
|
123 |
+
for itr in range(iters):
|
124 |
+
coords1 = coords1.detach()
|
125 |
+
corr = corr_fn(coords1) # index correlation volume
|
126 |
+
|
127 |
+
flow = coords1 - coords0
|
128 |
+
with autocast(enabled=self.args.mixed_precision):
|
129 |
+
net, up_mask, delta_flow, feature = self.update_block(net, inp, corr, flow)
|
130 |
+
# print("delta flow mean: ", delta_flow.mean())
|
131 |
+
# F(t+1) = F(t) + \Delta(t)
|
132 |
+
coords1 = coords1 + delta_flow
|
133 |
+
|
134 |
+
# upsample predictions
|
135 |
+
if up_mask is None:
|
136 |
+
flow_up = upflow8(coords1 - coords0)
|
137 |
+
else:
|
138 |
+
flow_up = self.upsample_flow(coords1 - coords0, up_mask)
|
139 |
+
|
140 |
+
return flow_up, feature
|
141 |
+
|
142 |
+
class RAFT_ALL(nn.Module):
|
143 |
+
def __init__(self, args):
|
144 |
+
super(RAFT_ALL, self).__init__()
|
145 |
+
self.args = args
|
146 |
+
|
147 |
+
if args.small:
|
148 |
+
self.hidden_dim = hdim = 96
|
149 |
+
self.context_dim = cdim = 64
|
150 |
+
args.corr_levels = 4
|
151 |
+
args.corr_radius = 3
|
152 |
+
|
153 |
+
else:
|
154 |
+
self.hidden_dim = hdim = 128
|
155 |
+
self.context_dim = cdim = 128
|
156 |
+
args.corr_levels = 4
|
157 |
+
args.corr_radius = 4
|
158 |
+
|
159 |
+
if 'dropout' not in self.args:
|
160 |
+
self.args.dropout = 0
|
161 |
+
|
162 |
+
if 'alternate_corr' not in self.args:
|
163 |
+
self.args.alternate_corr = False
|
164 |
+
|
165 |
+
# feature network, context network, and update block
|
166 |
+
if args.small:
|
167 |
+
self.fnet = SmallEncoder(output_dim=128, norm_fn='instance', dropout=args.dropout)
|
168 |
+
self.cnet = SmallEncoder(output_dim=hdim+cdim, norm_fn='none', dropout=args.dropout)
|
169 |
+
self.update_block = SmallUpdateBlock(self.args, hidden_dim=hdim)
|
170 |
+
|
171 |
+
else:
|
172 |
+
self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=args.dropout)
|
173 |
+
self.cnet = BasicEncoder(output_dim=hdim+cdim, norm_fn='batch', dropout=args.dropout)
|
174 |
+
self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim)
|
175 |
+
|
176 |
+
def freeze_bn(self):
|
177 |
+
for m in self.modules():
|
178 |
+
if isinstance(m, nn.BatchNorm2d):
|
179 |
+
m.eval()
|
180 |
+
|
181 |
+
def initialize_flow(self, img):
|
182 |
+
""" Flow is represented as difference between two coordinate grids flow = coords1 - coords0"""
|
183 |
+
N, C, H, W = img.shape
|
184 |
+
coords0 = coords_grid(N, H//8, W//8).to(img.device)
|
185 |
+
coords1 = coords_grid(N, H//8, W//8).to(img.device)
|
186 |
+
|
187 |
+
# optical flow computed as difference: flow = coords1 - coords0
|
188 |
+
return coords0, coords1
|
189 |
+
|
190 |
+
def upsample_flow(self, flow, mask):
|
191 |
+
""" Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """
|
192 |
+
N, _, H, W = flow.shape
|
193 |
+
mask = mask.view(N, 1, 9, 8, 8, H, W)
|
194 |
+
mask = torch.softmax(mask, dim=2)
|
195 |
+
|
196 |
+
up_flow = F.unfold(8 * flow, [3,3], padding=1)
|
197 |
+
up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)
|
198 |
+
|
199 |
+
up_flow = torch.sum(mask * up_flow, dim=2)
|
200 |
+
up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
|
201 |
+
return up_flow.reshape(N, 2, 8*H, 8*W)
|
202 |
+
|
203 |
+
|
204 |
+
def forward(self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False):
|
205 |
+
""" Estimate optical flow between pair of frames """
|
206 |
+
|
207 |
+
image1 = 2 * (image1 / 255.0) - 1.0
|
208 |
+
image2 = 2 * (image2 / 255.0) - 1.0
|
209 |
+
|
210 |
+
image1 = image1.contiguous()
|
211 |
+
image2 = image2.contiguous()
|
212 |
+
|
213 |
+
hdim = self.hidden_dim
|
214 |
+
cdim = self.context_dim
|
215 |
+
|
216 |
+
# run the feature network
|
217 |
+
with autocast(enabled=self.args.mixed_precision):
|
218 |
+
fmap1, fmap2 = self.fnet([image1, image2])
|
219 |
+
|
220 |
+
fmap1 = fmap1.float()
|
221 |
+
fmap2 = fmap2.float()
|
222 |
+
# print("fmap mean: ", fmap1.mean(), fmap2.mean())
|
223 |
+
if self.args.alternate_corr:
|
224 |
+
corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
|
225 |
+
else:
|
226 |
+
corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
|
227 |
+
|
228 |
+
# run the context network
|
229 |
+
with autocast(enabled=self.args.mixed_precision):
|
230 |
+
cnet = self.cnet(image1)
|
231 |
+
net, inp = torch.split(cnet, [hdim, cdim], dim=1)
|
232 |
+
net = torch.tanh(net)
|
233 |
+
inp = torch.relu(inp)
|
234 |
+
|
235 |
+
coords0, coords1 = self.initialize_flow(image1)
|
236 |
+
|
237 |
+
if flow_init is not None:
|
238 |
+
coords1 = coords1 + flow_init
|
239 |
+
|
240 |
+
flow_predictions = []
|
241 |
+
for itr in range(iters):
|
242 |
+
coords1 = coords1.detach()
|
243 |
+
corr = corr_fn(coords1) # index correlation volume
|
244 |
+
|
245 |
+
flow = coords1 - coords0
|
246 |
+
with autocast(enabled=self.args.mixed_precision):
|
247 |
+
net, up_mask, delta_flow, feature = self.update_block(net, inp, corr, flow)
|
248 |
+
# print("delta flow mean: ", delta_flow.mean())
|
249 |
+
# F(t+1) = F(t) + \Delta(t)
|
250 |
+
coords1 = coords1 + delta_flow
|
251 |
+
|
252 |
+
# upsample predictions
|
253 |
+
if up_mask is None:
|
254 |
+
flow_up = upflow8(coords1 - coords0)
|
255 |
+
else:
|
256 |
+
flow_up = self.upsample_flow(coords1 - coords0, up_mask)
|
257 |
+
flow_predictions.append(flow_up)
|
258 |
+
|
259 |
+
return flow_predictions, feature
|
data_utils/UNFaceFlow/core/update.py
ADDED
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from nnutils import make_conv_2d, make_upscale_2d, make_downscale_2d, ResBlock2d, Identity
|
5 |
+
|
6 |
+
class FlowHead(nn.Module):
|
7 |
+
def __init__(self, input_dim=128, hidden_dim=256):
|
8 |
+
super(FlowHead, self).__init__()
|
9 |
+
self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
|
10 |
+
self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1)
|
11 |
+
self.relu = nn.ReLU(inplace=True)
|
12 |
+
|
13 |
+
def forward(self, x):
|
14 |
+
x = self.relu(self.conv1(x))
|
15 |
+
return self.conv2(x), x
|
16 |
+
|
17 |
+
class ConvGRU(nn.Module):
|
18 |
+
def __init__(self, hidden_dim=128, input_dim=192+128):
|
19 |
+
super(ConvGRU, self).__init__()
|
20 |
+
self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
|
21 |
+
self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
|
22 |
+
self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
|
23 |
+
|
24 |
+
def forward(self, h, x):
|
25 |
+
hx = torch.cat([h, x], dim=1)
|
26 |
+
|
27 |
+
z = torch.sigmoid(self.convz(hx))
|
28 |
+
r = torch.sigmoid(self.convr(hx))
|
29 |
+
q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1)))
|
30 |
+
|
31 |
+
h = (1-z) * h + z * q
|
32 |
+
return h
|
33 |
+
|
34 |
+
class SepConvGRU(nn.Module):
|
35 |
+
def __init__(self, hidden_dim=128, input_dim=192+128):
|
36 |
+
super(SepConvGRU, self).__init__()
|
37 |
+
self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
|
38 |
+
self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
|
39 |
+
self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
|
40 |
+
|
41 |
+
self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
|
42 |
+
self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
|
43 |
+
self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
|
44 |
+
|
45 |
+
|
46 |
+
def forward(self, h, x):
|
47 |
+
# horizontal
|
48 |
+
hx = torch.cat([h, x], dim=1)
|
49 |
+
z = torch.sigmoid(self.convz1(hx))
|
50 |
+
r = torch.sigmoid(self.convr1(hx))
|
51 |
+
q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1)))
|
52 |
+
h = (1-z) * h + z * q
|
53 |
+
|
54 |
+
# vertical
|
55 |
+
hx = torch.cat([h, x], dim=1)
|
56 |
+
z = torch.sigmoid(self.convz2(hx))
|
57 |
+
r = torch.sigmoid(self.convr2(hx))
|
58 |
+
q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1)))
|
59 |
+
h = (1-z) * h + z * q
|
60 |
+
|
61 |
+
return h
|
62 |
+
|
63 |
+
class SmallMotionEncoder(nn.Module):
|
64 |
+
def __init__(self, args):
|
65 |
+
super(SmallMotionEncoder, self).__init__()
|
66 |
+
cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2
|
67 |
+
self.convc1 = nn.Conv2d(cor_planes, 96, 1, padding=0)
|
68 |
+
self.convf1 = nn.Conv2d(2, 64, 7, padding=3)
|
69 |
+
self.convf2 = nn.Conv2d(64, 32, 3, padding=1)
|
70 |
+
self.conv = nn.Conv2d(128, 80, 3, padding=1)
|
71 |
+
|
72 |
+
def forward(self, flow, corr):
|
73 |
+
cor = F.relu(self.convc1(corr))
|
74 |
+
flo = F.relu(self.convf1(flow))
|
75 |
+
flo = F.relu(self.convf2(flo))
|
76 |
+
cor_flo = torch.cat([cor, flo], dim=1)
|
77 |
+
out = F.relu(self.conv(cor_flo))
|
78 |
+
return torch.cat([out, flow], dim=1)
|
79 |
+
|
80 |
+
class BasicMotionEncoder(nn.Module):
|
81 |
+
def __init__(self, args):
|
82 |
+
super(BasicMotionEncoder, self).__init__()
|
83 |
+
cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2
|
84 |
+
self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0)
|
85 |
+
self.convc2 = nn.Conv2d(256, 192, 3, padding=1)
|
86 |
+
self.convf1 = nn.Conv2d(2, 128, 7, padding=3)
|
87 |
+
self.convf2 = nn.Conv2d(128, 64, 3, padding=1)
|
88 |
+
self.conv = nn.Conv2d(64+192, 128-2, 3, padding=1)
|
89 |
+
|
90 |
+
def forward(self, flow, corr):
|
91 |
+
cor = F.relu(self.convc1(corr))
|
92 |
+
cor = F.relu(self.convc2(cor))
|
93 |
+
flo = F.relu(self.convf1(flow))
|
94 |
+
flo = F.relu(self.convf2(flo))
|
95 |
+
|
96 |
+
cor_flo = torch.cat([cor, flo], dim=1)
|
97 |
+
out = F.relu(self.conv(cor_flo))
|
98 |
+
return torch.cat([out, flow], dim=1)
|
99 |
+
|
100 |
+
class SmallUpdateBlock(nn.Module):
|
101 |
+
def __init__(self, args, hidden_dim=96):
|
102 |
+
super(SmallUpdateBlock, self).__init__()
|
103 |
+
self.encoder = SmallMotionEncoder(args)
|
104 |
+
self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82+64)
|
105 |
+
self.flow_head = FlowHead(hidden_dim, hidden_dim=128)
|
106 |
+
|
107 |
+
def forward(self, net, inp, corr, flow):
|
108 |
+
motion_features = self.encoder(flow, corr)
|
109 |
+
inp = torch.cat([inp, motion_features], dim=1)
|
110 |
+
net = self.gru(net, inp)
|
111 |
+
delta_flow, feature = self.flow_head(net)
|
112 |
+
|
113 |
+
return net, None, delta_flow, feature
|
114 |
+
|
115 |
+
class BasicUpdateBlock(nn.Module):
|
116 |
+
def __init__(self, args, hidden_dim=128, input_dim=128):
|
117 |
+
super(BasicUpdateBlock, self).__init__()
|
118 |
+
self.args = args
|
119 |
+
self.encoder = BasicMotionEncoder(args)
|
120 |
+
self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim)
|
121 |
+
self.flow_head = FlowHead(hidden_dim, hidden_dim=256)
|
122 |
+
|
123 |
+
self.mask = nn.Sequential(
|
124 |
+
nn.Conv2d(128, 256, 3, padding=1),
|
125 |
+
nn.ReLU(inplace=True),
|
126 |
+
nn.Conv2d(256, 64*9, 1, padding=0))
|
127 |
+
|
128 |
+
def forward(self, net, inp, corr, flow, upsample=True):
|
129 |
+
motion_features = self.encoder(flow, corr)
|
130 |
+
inp = torch.cat([inp, motion_features], dim=1)
|
131 |
+
|
132 |
+
net = self.gru(net, inp)
|
133 |
+
delta_flow, feature = self.flow_head(net)
|
134 |
+
|
135 |
+
# scale mask to balence gradients
|
136 |
+
mask = .25 * self.mask(net)
|
137 |
+
return net, mask, delta_flow, feature
|
138 |
+
|
139 |
+
class BasicWeightsNet(nn.Module):
|
140 |
+
def __init__(self, opt):
|
141 |
+
super(BasicUpdateBlock, self).__init__()
|
142 |
+
if opt.small:
|
143 |
+
in_dim = 128
|
144 |
+
else:
|
145 |
+
in_dim = 256
|
146 |
+
fn_0 = 16
|
147 |
+
self.input_fn = fn_0 + 2
|
148 |
+
fn_1 = 16
|
149 |
+
self.conv1 = torch.nn.Conv2d(in_channels=in_dim, out_channels=fn_0, kernel_size=3, stride=1, padding=1)
|
150 |
+
if opt.use_batch_norm:
|
151 |
+
custom_batch_norm = torch.nn.BatchNorm2d
|
152 |
+
else:
|
153 |
+
custom_batch_norm = Identity
|
154 |
+
self.model = nn.Sequential(
|
155 |
+
make_conv_2d(self.input_fn, fn_1, n_blocks=1, normalization=custom_batch_norm),
|
156 |
+
ResBlock2d(fn_1, normalization=custom_batch_norm),
|
157 |
+
ResBlock2d(fn_1, normalization=custom_batch_norm),
|
158 |
+
ResBlock2d(fn_1, normalization=custom_batch_norm),
|
159 |
+
nn.Conv2d(fn_1, 1, kernel_size=3, padding=1),
|
160 |
+
torch.nn.Sigmoid()
|
161 |
+
)
|
162 |
+
|
163 |
+
def forward(self, flow, feature):
|
164 |
+
features = self.conv1(features)
|
165 |
+
x = torch.cat([features, flow], 1)
|
166 |
+
return self.model(x)
|
167 |
+
|
168 |
+
|
169 |
+
|
data_utils/UNFaceFlow/core/utils_core/__init__.py
ADDED
File without changes
|
data_utils/UNFaceFlow/core/utils_core/augmentor.py
ADDED
@@ -0,0 +1,246 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import random
|
3 |
+
import math
|
4 |
+
from PIL import Image
|
5 |
+
|
6 |
+
import cv2
|
7 |
+
cv2.setNumThreads(0)
|
8 |
+
cv2.ocl.setUseOpenCL(False)
|
9 |
+
|
10 |
+
import torch
|
11 |
+
from torchvision.transforms import ColorJitter
|
12 |
+
import torch.nn.functional as F
|
13 |
+
|
14 |
+
|
15 |
+
class FlowAugmentor:
|
16 |
+
def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=True):
|
17 |
+
|
18 |
+
# spatial augmentation params
|
19 |
+
self.crop_size = crop_size
|
20 |
+
self.min_scale = min_scale
|
21 |
+
self.max_scale = max_scale
|
22 |
+
self.spatial_aug_prob = 0.8
|
23 |
+
self.stretch_prob = 0.8
|
24 |
+
self.max_stretch = 0.2
|
25 |
+
|
26 |
+
# flip augmentation params
|
27 |
+
self.do_flip = do_flip
|
28 |
+
self.h_flip_prob = 0.5
|
29 |
+
self.v_flip_prob = 0.1
|
30 |
+
|
31 |
+
# photometric augmentation params
|
32 |
+
self.photo_aug = ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.5/3.14)
|
33 |
+
self.asymmetric_color_aug_prob = 0.2
|
34 |
+
self.eraser_aug_prob = 0.5
|
35 |
+
|
36 |
+
def color_transform(self, img1, img2):
|
37 |
+
""" Photometric augmentation """
|
38 |
+
|
39 |
+
# asymmetric
|
40 |
+
if np.random.rand() < self.asymmetric_color_aug_prob:
|
41 |
+
img1 = np.array(self.photo_aug(Image.fromarray(img1)), dtype=np.uint8)
|
42 |
+
img2 = np.array(self.photo_aug(Image.fromarray(img2)), dtype=np.uint8)
|
43 |
+
|
44 |
+
# symmetric
|
45 |
+
else:
|
46 |
+
image_stack = np.concatenate([img1, img2], axis=0)
|
47 |
+
image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8)
|
48 |
+
img1, img2 = np.split(image_stack, 2, axis=0)
|
49 |
+
|
50 |
+
return img1, img2
|
51 |
+
|
52 |
+
def eraser_transform(self, img1, img2, bounds=[50, 100]):
|
53 |
+
""" Occlusion augmentation """
|
54 |
+
|
55 |
+
ht, wd = img1.shape[:2]
|
56 |
+
if np.random.rand() < self.eraser_aug_prob:
|
57 |
+
mean_color = np.mean(img2.reshape(-1, 3), axis=0)
|
58 |
+
for _ in range(np.random.randint(1, 3)):
|
59 |
+
x0 = np.random.randint(0, wd)
|
60 |
+
y0 = np.random.randint(0, ht)
|
61 |
+
dx = np.random.randint(bounds[0], bounds[1])
|
62 |
+
dy = np.random.randint(bounds[0], bounds[1])
|
63 |
+
img2[y0:y0+dy, x0:x0+dx, :] = mean_color
|
64 |
+
|
65 |
+
return img1, img2
|
66 |
+
|
67 |
+
def spatial_transform(self, img1, img2, flow):
|
68 |
+
# randomly sample scale
|
69 |
+
ht, wd = img1.shape[:2]
|
70 |
+
min_scale = np.maximum(
|
71 |
+
(self.crop_size[0] + 8) / float(ht),
|
72 |
+
(self.crop_size[1] + 8) / float(wd))
|
73 |
+
|
74 |
+
scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)
|
75 |
+
scale_x = scale
|
76 |
+
scale_y = scale
|
77 |
+
if np.random.rand() < self.stretch_prob:
|
78 |
+
scale_x *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)
|
79 |
+
scale_y *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)
|
80 |
+
|
81 |
+
scale_x = np.clip(scale_x, min_scale, None)
|
82 |
+
scale_y = np.clip(scale_y, min_scale, None)
|
83 |
+
|
84 |
+
if np.random.rand() < self.spatial_aug_prob:
|
85 |
+
# rescale the images
|
86 |
+
img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
|
87 |
+
img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
|
88 |
+
flow = cv2.resize(flow, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
|
89 |
+
flow = flow * [scale_x, scale_y]
|
90 |
+
|
91 |
+
if self.do_flip:
|
92 |
+
if np.random.rand() < self.h_flip_prob: # h-flip
|
93 |
+
img1 = img1[:, ::-1]
|
94 |
+
img2 = img2[:, ::-1]
|
95 |
+
flow = flow[:, ::-1] * [-1.0, 1.0]
|
96 |
+
|
97 |
+
if np.random.rand() < self.v_flip_prob: # v-flip
|
98 |
+
img1 = img1[::-1, :]
|
99 |
+
img2 = img2[::-1, :]
|
100 |
+
flow = flow[::-1, :] * [1.0, -1.0]
|
101 |
+
|
102 |
+
y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0])
|
103 |
+
x0 = np.random.randint(0, img1.shape[1] - self.crop_size[1])
|
104 |
+
|
105 |
+
img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
|
106 |
+
img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
|
107 |
+
flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
|
108 |
+
|
109 |
+
return img1, img2, flow
|
110 |
+
|
111 |
+
def __call__(self, img1, img2, flow):
|
112 |
+
img1, img2 = self.color_transform(img1, img2)
|
113 |
+
img1, img2 = self.eraser_transform(img1, img2)
|
114 |
+
img1, img2, flow = self.spatial_transform(img1, img2, flow)
|
115 |
+
|
116 |
+
img1 = np.ascontiguousarray(img1)
|
117 |
+
img2 = np.ascontiguousarray(img2)
|
118 |
+
flow = np.ascontiguousarray(flow)
|
119 |
+
|
120 |
+
return img1, img2, flow
|
121 |
+
|
122 |
+
class SparseFlowAugmentor:
|
123 |
+
def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=False):
|
124 |
+
# spatial augmentation params
|
125 |
+
self.crop_size = crop_size
|
126 |
+
self.min_scale = min_scale
|
127 |
+
self.max_scale = max_scale
|
128 |
+
self.spatial_aug_prob = 0.8
|
129 |
+
self.stretch_prob = 0.8
|
130 |
+
self.max_stretch = 0.2
|
131 |
+
|
132 |
+
# flip augmentation params
|
133 |
+
self.do_flip = do_flip
|
134 |
+
self.h_flip_prob = 0.5
|
135 |
+
self.v_flip_prob = 0.1
|
136 |
+
|
137 |
+
# photometric augmentation params
|
138 |
+
self.photo_aug = ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3/3.14)
|
139 |
+
self.asymmetric_color_aug_prob = 0.2
|
140 |
+
self.eraser_aug_prob = 0.5
|
141 |
+
|
142 |
+
def color_transform(self, img1, img2):
|
143 |
+
image_stack = np.concatenate([img1, img2], axis=0)
|
144 |
+
image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8)
|
145 |
+
img1, img2 = np.split(image_stack, 2, axis=0)
|
146 |
+
return img1, img2
|
147 |
+
|
148 |
+
def eraser_transform(self, img1, img2):
|
149 |
+
ht, wd = img1.shape[:2]
|
150 |
+
if np.random.rand() < self.eraser_aug_prob:
|
151 |
+
mean_color = np.mean(img2.reshape(-1, 3), axis=0)
|
152 |
+
for _ in range(np.random.randint(1, 3)):
|
153 |
+
x0 = np.random.randint(0, wd)
|
154 |
+
y0 = np.random.randint(0, ht)
|
155 |
+
dx = np.random.randint(50, 100)
|
156 |
+
dy = np.random.randint(50, 100)
|
157 |
+
img2[y0:y0+dy, x0:x0+dx, :] = mean_color
|
158 |
+
|
159 |
+
return img1, img2
|
160 |
+
|
161 |
+
def resize_sparse_flow_map(self, flow, valid, fx=1.0, fy=1.0):
|
162 |
+
ht, wd = flow.shape[:2]
|
163 |
+
coords = np.meshgrid(np.arange(wd), np.arange(ht))
|
164 |
+
coords = np.stack(coords, axis=-1)
|
165 |
+
|
166 |
+
coords = coords.reshape(-1, 2).astype(np.float32)
|
167 |
+
flow = flow.reshape(-1, 2).astype(np.float32)
|
168 |
+
valid = valid.reshape(-1).astype(np.float32)
|
169 |
+
|
170 |
+
coords0 = coords[valid>=1]
|
171 |
+
flow0 = flow[valid>=1]
|
172 |
+
|
173 |
+
ht1 = int(round(ht * fy))
|
174 |
+
wd1 = int(round(wd * fx))
|
175 |
+
|
176 |
+
coords1 = coords0 * [fx, fy]
|
177 |
+
flow1 = flow0 * [fx, fy]
|
178 |
+
|
179 |
+
xx = np.round(coords1[:,0]).astype(np.int32)
|
180 |
+
yy = np.round(coords1[:,1]).astype(np.int32)
|
181 |
+
|
182 |
+
v = (xx > 0) & (xx < wd1) & (yy > 0) & (yy < ht1)
|
183 |
+
xx = xx[v]
|
184 |
+
yy = yy[v]
|
185 |
+
flow1 = flow1[v]
|
186 |
+
|
187 |
+
flow_img = np.zeros([ht1, wd1, 2], dtype=np.float32)
|
188 |
+
valid_img = np.zeros([ht1, wd1], dtype=np.int32)
|
189 |
+
|
190 |
+
flow_img[yy, xx] = flow1
|
191 |
+
valid_img[yy, xx] = 1
|
192 |
+
|
193 |
+
return flow_img, valid_img
|
194 |
+
|
195 |
+
def spatial_transform(self, img1, img2, flow, valid):
|
196 |
+
# randomly sample scale
|
197 |
+
|
198 |
+
ht, wd = img1.shape[:2]
|
199 |
+
min_scale = np.maximum(
|
200 |
+
(self.crop_size[0] + 1) / float(ht),
|
201 |
+
(self.crop_size[1] + 1) / float(wd))
|
202 |
+
|
203 |
+
scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)
|
204 |
+
scale_x = np.clip(scale, min_scale, None)
|
205 |
+
scale_y = np.clip(scale, min_scale, None)
|
206 |
+
|
207 |
+
if np.random.rand() < self.spatial_aug_prob:
|
208 |
+
# rescale the images
|
209 |
+
img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
|
210 |
+
img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
|
211 |
+
flow, valid = self.resize_sparse_flow_map(flow, valid, fx=scale_x, fy=scale_y)
|
212 |
+
|
213 |
+
if self.do_flip:
|
214 |
+
if np.random.rand() < 0.5: # h-flip
|
215 |
+
img1 = img1[:, ::-1]
|
216 |
+
img2 = img2[:, ::-1]
|
217 |
+
flow = flow[:, ::-1] * [-1.0, 1.0]
|
218 |
+
valid = valid[:, ::-1]
|
219 |
+
|
220 |
+
margin_y = 20
|
221 |
+
margin_x = 50
|
222 |
+
|
223 |
+
y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0] + margin_y)
|
224 |
+
x0 = np.random.randint(-margin_x, img1.shape[1] - self.crop_size[1] + margin_x)
|
225 |
+
|
226 |
+
y0 = np.clip(y0, 0, img1.shape[0] - self.crop_size[0])
|
227 |
+
x0 = np.clip(x0, 0, img1.shape[1] - self.crop_size[1])
|
228 |
+
|
229 |
+
img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
|
230 |
+
img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
|
231 |
+
flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
|
232 |
+
valid = valid[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
|
233 |
+
return img1, img2, flow, valid
|
234 |
+
|
235 |
+
|
236 |
+
def __call__(self, img1, img2, flow, valid):
|
237 |
+
img1, img2 = self.color_transform(img1, img2)
|
238 |
+
img1, img2 = self.eraser_transform(img1, img2)
|
239 |
+
img1, img2, flow, valid = self.spatial_transform(img1, img2, flow, valid)
|
240 |
+
|
241 |
+
img1 = np.ascontiguousarray(img1)
|
242 |
+
img2 = np.ascontiguousarray(img2)
|
243 |
+
flow = np.ascontiguousarray(flow)
|
244 |
+
valid = np.ascontiguousarray(valid)
|
245 |
+
|
246 |
+
return img1, img2, flow, valid
|
data_utils/UNFaceFlow/core/utils_core/flow_viz.py
ADDED
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization
|
2 |
+
|
3 |
+
|
4 |
+
# MIT License
|
5 |
+
#
|
6 |
+
# Copyright (c) 2018 Tom Runia
|
7 |
+
#
|
8 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
9 |
+
# of this software and associated documentation files (the "Software"), to deal
|
10 |
+
# in the Software without restriction, including without limitation the rights
|
11 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
12 |
+
# copies of the Software, and to permit persons to whom the Software is
|
13 |
+
# furnished to do so, subject to conditions.
|
14 |
+
#
|
15 |
+
# Author: Tom Runia
|
16 |
+
# Date Created: 2018-08-03
|
17 |
+
|
18 |
+
import numpy as np
|
19 |
+
|
20 |
+
def make_colorwheel():
|
21 |
+
"""
|
22 |
+
Generates a color wheel for optical flow visualization as presented in:
|
23 |
+
Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
|
24 |
+
URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf
|
25 |
+
|
26 |
+
Code follows the original C++ source code of Daniel Scharstein.
|
27 |
+
Code follows the the Matlab source code of Deqing Sun.
|
28 |
+
|
29 |
+
Returns:
|
30 |
+
np.ndarray: Color wheel
|
31 |
+
"""
|
32 |
+
|
33 |
+
RY = 15
|
34 |
+
YG = 6
|
35 |
+
GC = 4
|
36 |
+
CB = 11
|
37 |
+
BM = 13
|
38 |
+
MR = 6
|
39 |
+
|
40 |
+
ncols = RY + YG + GC + CB + BM + MR
|
41 |
+
colorwheel = np.zeros((ncols, 3))
|
42 |
+
col = 0
|
43 |
+
|
44 |
+
# RY
|
45 |
+
colorwheel[0:RY, 0] = 255
|
46 |
+
colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY)
|
47 |
+
col = col+RY
|
48 |
+
# YG
|
49 |
+
colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG)
|
50 |
+
colorwheel[col:col+YG, 1] = 255
|
51 |
+
col = col+YG
|
52 |
+
# GC
|
53 |
+
colorwheel[col:col+GC, 1] = 255
|
54 |
+
colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC)
|
55 |
+
col = col+GC
|
56 |
+
# CB
|
57 |
+
colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB)
|
58 |
+
colorwheel[col:col+CB, 2] = 255
|
59 |
+
col = col+CB
|
60 |
+
# BM
|
61 |
+
colorwheel[col:col+BM, 2] = 255
|
62 |
+
colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM)
|
63 |
+
col = col+BM
|
64 |
+
# MR
|
65 |
+
colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR)
|
66 |
+
colorwheel[col:col+MR, 0] = 255
|
67 |
+
return colorwheel
|
68 |
+
|
69 |
+
|
70 |
+
def flow_uv_to_colors(u, v, convert_to_bgr=False):
|
71 |
+
"""
|
72 |
+
Applies the flow color wheel to (possibly clipped) flow components u and v.
|
73 |
+
|
74 |
+
According to the C++ source code of Daniel Scharstein
|
75 |
+
According to the Matlab source code of Deqing Sun
|
76 |
+
|
77 |
+
Args:
|
78 |
+
u (np.ndarray): Input horizontal flow of shape [H,W]
|
79 |
+
v (np.ndarray): Input vertical flow of shape [H,W]
|
80 |
+
convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
|
81 |
+
|
82 |
+
Returns:
|
83 |
+
np.ndarray: Flow visualization image of shape [H,W,3]
|
84 |
+
"""
|
85 |
+
flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)
|
86 |
+
colorwheel = make_colorwheel() # shape [55x3]
|
87 |
+
ncols = colorwheel.shape[0]
|
88 |
+
rad = np.sqrt(np.square(u) + np.square(v))
|
89 |
+
a = np.arctan2(-v, -u)/np.pi
|
90 |
+
fk = (a+1) / 2*(ncols-1)
|
91 |
+
k0 = np.floor(fk).astype(np.int32)
|
92 |
+
k1 = k0 + 1
|
93 |
+
k1[k1 == ncols] = 0
|
94 |
+
f = fk - k0
|
95 |
+
for i in range(colorwheel.shape[1]):
|
96 |
+
tmp = colorwheel[:,i]
|
97 |
+
col0 = tmp[k0] / 255.0
|
98 |
+
col1 = tmp[k1] / 255.0
|
99 |
+
col = (1-f)*col0 + f*col1
|
100 |
+
idx = (rad <= 1)
|
101 |
+
col[idx] = 1 - rad[idx] * (1-col[idx])
|
102 |
+
col[~idx] = col[~idx] * 0.75 # out of range
|
103 |
+
# Note the 2-i => BGR instead of RGB
|
104 |
+
ch_idx = 2-i if convert_to_bgr else i
|
105 |
+
flow_image[:,:,ch_idx] = np.floor(255 * col)
|
106 |
+
return flow_image
|
107 |
+
|
108 |
+
|
109 |
+
def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False):
|
110 |
+
"""
|
111 |
+
Expects a two dimensional flow image of shape.
|
112 |
+
|
113 |
+
Args:
|
114 |
+
flow_uv (np.ndarray): Flow UV image of shape [H,W,2]
|
115 |
+
clip_flow (float, optional): Clip maximum of flow values. Defaults to None.
|
116 |
+
convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
|
117 |
+
|
118 |
+
Returns:
|
119 |
+
np.ndarray: Flow visualization image of shape [H,W,3]
|
120 |
+
"""
|
121 |
+
assert flow_uv.ndim == 3, 'input flow must have three dimensions'
|
122 |
+
assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]'
|
123 |
+
if clip_flow is not None:
|
124 |
+
flow_uv = np.clip(flow_uv, 0, clip_flow)
|
125 |
+
u = flow_uv[:,:,0]
|
126 |
+
v = flow_uv[:,:,1]
|
127 |
+
rad = np.sqrt(np.square(u) + np.square(v))
|
128 |
+
rad_max = np.max(rad)
|
129 |
+
epsilon = 1e-5
|
130 |
+
u = u / (rad_max + epsilon)
|
131 |
+
v = v / (rad_max + epsilon)
|
132 |
+
return flow_uv_to_colors(u, v, convert_to_bgr)
|
data_utils/UNFaceFlow/core/utils_core/frame_utils.py
ADDED
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from PIL import Image
|
3 |
+
from os.path import *
|
4 |
+
import re
|
5 |
+
|
6 |
+
import cv2
|
7 |
+
cv2.setNumThreads(0)
|
8 |
+
cv2.ocl.setUseOpenCL(False)
|
9 |
+
|
10 |
+
TAG_CHAR = np.array([202021.25], np.float32)
|
11 |
+
|
12 |
+
def readFlow(fn):
|
13 |
+
""" Read .flo file in Middlebury format"""
|
14 |
+
# Code adapted from:
|
15 |
+
# http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy
|
16 |
+
|
17 |
+
# WARNING: this will work on little-endian architectures (eg Intel x86) only!
|
18 |
+
# print 'fn = %s'%(fn)
|
19 |
+
with open(fn, 'rb') as f:
|
20 |
+
magic = np.fromfile(f, np.float32, count=1)
|
21 |
+
if 202021.25 != magic:
|
22 |
+
print('Magic number incorrect. Invalid .flo file')
|
23 |
+
return None
|
24 |
+
else:
|
25 |
+
w = np.fromfile(f, np.int32, count=1)
|
26 |
+
h = np.fromfile(f, np.int32, count=1)
|
27 |
+
# print 'Reading %d x %d flo file\n' % (w, h)
|
28 |
+
data = np.fromfile(f, np.float32, count=2*int(w)*int(h))
|
29 |
+
# Reshape data into 3D array (columns, rows, bands)
|
30 |
+
# The reshape here is for visualization, the original code is (w,h,2)
|
31 |
+
return np.resize(data, (int(h), int(w), 2))
|
32 |
+
|
33 |
+
def readPFM(file):
|
34 |
+
file = open(file, 'rb')
|
35 |
+
|
36 |
+
color = None
|
37 |
+
width = None
|
38 |
+
height = None
|
39 |
+
scale = None
|
40 |
+
endian = None
|
41 |
+
|
42 |
+
header = file.readline().rstrip()
|
43 |
+
if header == b'PF':
|
44 |
+
color = True
|
45 |
+
elif header == b'Pf':
|
46 |
+
color = False
|
47 |
+
else:
|
48 |
+
raise Exception('Not a PFM file.')
|
49 |
+
|
50 |
+
dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline())
|
51 |
+
if dim_match:
|
52 |
+
width, height = map(int, dim_match.groups())
|
53 |
+
else:
|
54 |
+
raise Exception('Malformed PFM header.')
|
55 |
+
|
56 |
+
scale = float(file.readline().rstrip())
|
57 |
+
if scale < 0: # little-endian
|
58 |
+
endian = '<'
|
59 |
+
scale = -scale
|
60 |
+
else:
|
61 |
+
endian = '>' # big-endian
|
62 |
+
|
63 |
+
data = np.fromfile(file, endian + 'f')
|
64 |
+
shape = (height, width, 3) if color else (height, width)
|
65 |
+
|
66 |
+
data = np.reshape(data, shape)
|
67 |
+
data = np.flipud(data)
|
68 |
+
return data
|
69 |
+
|
70 |
+
def writeFlow(filename,uv,v=None):
|
71 |
+
""" Write optical flow to file.
|
72 |
+
|
73 |
+
If v is None, uv is assumed to contain both u and v channels,
|
74 |
+
stacked in depth.
|
75 |
+
Original code by Deqing Sun, adapted from Daniel Scharstein.
|
76 |
+
"""
|
77 |
+
nBands = 2
|
78 |
+
|
79 |
+
if v is None:
|
80 |
+
assert(uv.ndim == 3)
|
81 |
+
assert(uv.shape[2] == 2)
|
82 |
+
u = uv[:,:,0]
|
83 |
+
v = uv[:,:,1]
|
84 |
+
else:
|
85 |
+
u = uv
|
86 |
+
|
87 |
+
assert(u.shape == v.shape)
|
88 |
+
height,width = u.shape
|
89 |
+
f = open(filename,'wb')
|
90 |
+
# write the header
|
91 |
+
f.write(TAG_CHAR)
|
92 |
+
np.array(width).astype(np.int32).tofile(f)
|
93 |
+
np.array(height).astype(np.int32).tofile(f)
|
94 |
+
# arrange into matrix form
|
95 |
+
tmp = np.zeros((height, width*nBands))
|
96 |
+
tmp[:,np.arange(width)*2] = u
|
97 |
+
tmp[:,np.arange(width)*2 + 1] = v
|
98 |
+
tmp.astype(np.float32).tofile(f)
|
99 |
+
f.close()
|
100 |
+
|
101 |
+
|
102 |
+
def readFlowKITTI(filename):
|
103 |
+
flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH|cv2.IMREAD_COLOR)
|
104 |
+
flow = flow[:,:,::-1].astype(np.float32)
|
105 |
+
flow, valid = flow[:, :, :2], flow[:, :, 2]
|
106 |
+
flow = (flow - 2**15) / 64.0
|
107 |
+
return flow, valid
|
108 |
+
|
109 |
+
def readDispKITTI(filename):
|
110 |
+
disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0
|
111 |
+
valid = disp > 0.0
|
112 |
+
flow = np.stack([-disp, np.zeros_like(disp)], -1)
|
113 |
+
return flow, valid
|
114 |
+
|
115 |
+
|
116 |
+
def writeFlowKITTI(filename, uv):
|
117 |
+
uv = 64.0 * uv + 2**15
|
118 |
+
valid = np.ones([uv.shape[0], uv.shape[1], 1])
|
119 |
+
uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16)
|
120 |
+
cv2.imwrite(filename, uv[..., ::-1])
|
121 |
+
|
122 |
+
|
123 |
+
def read_gen(file_name, pil=False):
|
124 |
+
ext = splitext(file_name)[-1]
|
125 |
+
if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg':
|
126 |
+
return Image.open(file_name)
|
127 |
+
elif ext == '.bin' or ext == '.raw':
|
128 |
+
return np.load(file_name)
|
129 |
+
elif ext == '.flo':
|
130 |
+
return readFlow(file_name).astype(np.float32)
|
131 |
+
elif ext == '.pfm':
|
132 |
+
flow = readPFM(file_name).astype(np.float32)
|
133 |
+
if len(flow.shape) == 2:
|
134 |
+
return flow
|
135 |
+
else:
|
136 |
+
return flow[:, :, :-1]
|
137 |
+
return []
|
data_utils/UNFaceFlow/core/utils_core/utils.py
ADDED
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
import numpy as np
|
4 |
+
from scipy import interpolate
|
5 |
+
|
6 |
+
|
7 |
+
class InputPadder:
|
8 |
+
""" Pads images such that dimensions are divisible by 8 """
|
9 |
+
def __init__(self, dims, mode='sintel'):
|
10 |
+
self.ht, self.wd = dims[-2:]
|
11 |
+
pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8
|
12 |
+
pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8
|
13 |
+
if mode == 'sintel':
|
14 |
+
self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2]
|
15 |
+
else:
|
16 |
+
self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht]
|
17 |
+
|
18 |
+
def pad(self, *inputs):
|
19 |
+
return [F.pad(x, self._pad, mode='replicate') for x in inputs]
|
20 |
+
|
21 |
+
def unpad(self,x):
|
22 |
+
ht, wd = x.shape[-2:]
|
23 |
+
c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]]
|
24 |
+
return x[..., c[0]:c[1], c[2]:c[3]]
|
25 |
+
|
26 |
+
def forward_interpolate(flow):
|
27 |
+
flow = flow.detach().cpu().numpy()
|
28 |
+
dx, dy = flow[0], flow[1]
|
29 |
+
|
30 |
+
ht, wd = dx.shape
|
31 |
+
x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht))
|
32 |
+
|
33 |
+
x1 = x0 + dx
|
34 |
+
y1 = y0 + dy
|
35 |
+
|
36 |
+
x1 = x1.reshape(-1)
|
37 |
+
y1 = y1.reshape(-1)
|
38 |
+
dx = dx.reshape(-1)
|
39 |
+
dy = dy.reshape(-1)
|
40 |
+
|
41 |
+
valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht)
|
42 |
+
x1 = x1[valid]
|
43 |
+
y1 = y1[valid]
|
44 |
+
dx = dx[valid]
|
45 |
+
dy = dy[valid]
|
46 |
+
|
47 |
+
flow_x = interpolate.griddata(
|
48 |
+
(x1, y1), dx, (x0, y0), method='nearest', fill_value=0)
|
49 |
+
|
50 |
+
flow_y = interpolate.griddata(
|
51 |
+
(x1, y1), dy, (x0, y0), method='nearest', fill_value=0)
|
52 |
+
|
53 |
+
flow = np.stack([flow_x, flow_y], axis=0)
|
54 |
+
return torch.from_numpy(flow).float()
|
55 |
+
|
56 |
+
|
57 |
+
def bilinear_sampler(img, coords, mode='bilinear', mask=False):
|
58 |
+
""" Wrapper for grid_sample, uses pixel coordinates """
|
59 |
+
H, W = img.shape[-2:]
|
60 |
+
xgrid, ygrid = coords.split([1,1], dim=-1)
|
61 |
+
xgrid = 2*xgrid/(W-1) - 1
|
62 |
+
ygrid = 2*ygrid/(H-1) - 1
|
63 |
+
|
64 |
+
grid = torch.cat([xgrid, ygrid], dim=-1)
|
65 |
+
img = F.grid_sample(img, grid, align_corners=True)
|
66 |
+
|
67 |
+
if mask:
|
68 |
+
mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
|
69 |
+
return img, mask.float()
|
70 |
+
|
71 |
+
return img
|
72 |
+
|
73 |
+
|
74 |
+
def coords_grid(batch, ht, wd):
|
75 |
+
coords = torch.meshgrid(torch.arange(ht), torch.arange(wd))
|
76 |
+
coords = torch.stack(coords[::-1], dim=0).float()
|
77 |
+
return coords[None].repeat(batch, 1, 1, 1)
|
78 |
+
|
79 |
+
|
80 |
+
def upflow8(flow, mode='bilinear'):
|
81 |
+
new_size = (8 * flow.shape[2], 8 * flow.shape[3])
|
82 |
+
return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True)
|
83 |
+
|
84 |
+
def upweights8(weights, mode='bilinear'):
|
85 |
+
new_size = (8 * weights.shape[2], 8 * weights.shape[3])
|
86 |
+
return F.interpolate(weights, size=new_size, mode=mode, align_corners=True)
|
data_utils/UNFaceFlow/core/warp_utils.py
ADDED
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
|
6 |
+
def mesh_grid(B, H, W):
|
7 |
+
# mesh grid
|
8 |
+
x_base = torch.arange(0, W).repeat(B, H, 1) # BHW
|
9 |
+
y_base = torch.arange(0, H).repeat(B, W, 1).transpose(1, 2) # BHW
|
10 |
+
|
11 |
+
base_grid = torch.stack([x_base, y_base], 1) # B2HW
|
12 |
+
return base_grid
|
13 |
+
|
14 |
+
|
15 |
+
def norm_grid(v_grid):
|
16 |
+
_, _, H, W = v_grid.size()
|
17 |
+
|
18 |
+
# scale grid to [-1,1]
|
19 |
+
v_grid_norm = torch.zeros_like(v_grid)
|
20 |
+
v_grid_norm[:, 0, :, :] = 2.0 * v_grid[:, 0, :, :] / (W - 1) - 1.0
|
21 |
+
v_grid_norm[:, 1, :, :] = 2.0 * v_grid[:, 1, :, :] / (H - 1) - 1.0
|
22 |
+
return v_grid_norm.permute(0, 2, 3, 1) # BHW2
|
23 |
+
|
24 |
+
|
25 |
+
def get_corresponding_map(data):
|
26 |
+
"""
|
27 |
+
|
28 |
+
:param data: unnormalized coordinates Bx2xHxW
|
29 |
+
:return: Bx1xHxW
|
30 |
+
"""
|
31 |
+
B, _, H, W = data.size()
|
32 |
+
|
33 |
+
# x = data[:, 0, :, :].view(B, -1).clamp(0, W - 1) # BxN (N=H*W)
|
34 |
+
# y = data[:, 1, :, :].view(B, -1).clamp(0, H - 1)
|
35 |
+
|
36 |
+
x = data[:, 0, :, :].view(B, -1) # BxN (N=H*W)
|
37 |
+
y = data[:, 1, :, :].view(B, -1)
|
38 |
+
|
39 |
+
# invalid = (x < 0) | (x > W - 1) | (y < 0) | (y > H - 1) # BxN
|
40 |
+
# invalid = invalid.repeat([1, 4])
|
41 |
+
|
42 |
+
x1 = torch.floor(x)
|
43 |
+
x_floor = x1.clamp(0, W - 1)
|
44 |
+
y1 = torch.floor(y)
|
45 |
+
y_floor = y1.clamp(0, H - 1)
|
46 |
+
x0 = x1 + 1
|
47 |
+
x_ceil = x0.clamp(0, W - 1)
|
48 |
+
y0 = y1 + 1
|
49 |
+
y_ceil = y0.clamp(0, H - 1)
|
50 |
+
|
51 |
+
x_ceil_out = x0 != x_ceil
|
52 |
+
y_ceil_out = y0 != y_ceil
|
53 |
+
x_floor_out = x1 != x_floor
|
54 |
+
y_floor_out = y1 != y_floor
|
55 |
+
invalid = torch.cat([x_ceil_out | y_ceil_out,
|
56 |
+
x_ceil_out | y_floor_out,
|
57 |
+
x_floor_out | y_ceil_out,
|
58 |
+
x_floor_out | y_floor_out], dim=1)
|
59 |
+
|
60 |
+
# encode coordinates, since the scatter function can only index along one axis
|
61 |
+
corresponding_map = torch.zeros(B, H * W).type_as(data)
|
62 |
+
indices = torch.cat([x_ceil + y_ceil * W,
|
63 |
+
x_ceil + y_floor * W,
|
64 |
+
x_floor + y_ceil * W,
|
65 |
+
x_floor + y_floor * W], 1).long() # BxN (N=4*H*W)
|
66 |
+
values = torch.cat([(1 - torch.abs(x - x_ceil)) * (1 - torch.abs(y - y_ceil)),
|
67 |
+
(1 - torch.abs(x - x_ceil)) * (1 - torch.abs(y - y_floor)),
|
68 |
+
(1 - torch.abs(x - x_floor)) * (1 - torch.abs(y - y_ceil)),
|
69 |
+
(1 - torch.abs(x - x_floor)) * (1 - torch.abs(y - y_floor))],
|
70 |
+
1)
|
71 |
+
# values = torch.ones_like(values)
|
72 |
+
|
73 |
+
values[invalid] = 0
|
74 |
+
|
75 |
+
corresponding_map.scatter_add_(1, indices, values)
|
76 |
+
# decode coordinates
|
77 |
+
corresponding_map = corresponding_map.view(B, H, W)
|
78 |
+
|
79 |
+
return corresponding_map.unsqueeze(1)
|
80 |
+
|
81 |
+
|
82 |
+
def flow_warp(x, flow12, pad='border', mode='bilinear'):
|
83 |
+
B, _, H, W = x.size()
|
84 |
+
|
85 |
+
base_grid = mesh_grid(B, H, W).type_as(x) # B2HW
|
86 |
+
|
87 |
+
v_grid = norm_grid(base_grid + flow12) # BHW2
|
88 |
+
im1_recons = nn.functional.grid_sample(x, v_grid, mode=mode, padding_mode=pad)
|
89 |
+
return im1_recons
|
90 |
+
|
91 |
+
|
92 |
+
def get_occu_mask_bidirection(flow12, flow21, mask, scale=1, bias=0.5):
|
93 |
+
flow21_warped = flow_warp(flow21, flow12, pad='zeros')
|
94 |
+
flow12_diff = flow12 + flow21_warped
|
95 |
+
mag = (flow12 * flow12).sum(1, keepdim=True) + \
|
96 |
+
(flow21_warped * flow21_warped).sum(1, keepdim=True)
|
97 |
+
occ_thresh = scale * mag + bias
|
98 |
+
occu = (flow12_diff * flow12_diff).sum(1, keepdim=True) > occ_thresh
|
99 |
+
# soft_occu = 1.0 / (1 + torch.exp(diff) / 5.0)
|
100 |
+
# print("forward:", diff.max(), diff.min())
|
101 |
+
return occu
|
102 |
+
|
103 |
+
|
104 |
+
def get_occu_mask_backward(flow21, th=0.2):
|
105 |
+
B, _, H, W = flow21.size()
|
106 |
+
base_grid = mesh_grid(B, H, W).type_as(flow21) # B2HW
|
107 |
+
|
108 |
+
corr_map = get_corresponding_map(base_grid + flow21) # BHW
|
109 |
+
occu_mask = corr_map.clamp(min=0., max=1.) < th
|
110 |
+
return occu_mask.float()
|
111 |
+
|
112 |
+
def get_ssv_weights(cycle_corres, input, mask, scale_value):
|
113 |
+
vgrid = (cycle_corres.mul(scale_value) - 1.0).permute(0,2,3,1)
|
114 |
+
new_input = nn.functional.grid_sample(input, vgrid, align_corners=True, padding_mode='border')
|
115 |
+
color_diff = (((input[:, :3, :, :] - new_input[:, :3, :, :]) / 255.0) ** 2).sum(1, keepdim=True)
|
116 |
+
depth_diff = (((input[:, 3:, :, :] - new_input[:, 3:, :, :])) ** 2).sum(1, keepdim=True)
|
117 |
+
diff = torch.mul(mask.float(), color_diff + depth_diff) #(N, 1, H, W)
|
118 |
+
return torch.exp(-diff)
|
data_utils/UNFaceFlow/data_test_flow/__init__.py
ADDED
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import importlib
|
2 |
+
import torch.utils.data
|
3 |
+
from data_test_flow.dd_dataset import DDDataset
|
4 |
+
|
5 |
+
def CreateDataLoader(opt):
|
6 |
+
data_loader = CustomDatasetDataLoader()
|
7 |
+
data_loader.initialize(opt)
|
8 |
+
return data_loader
|
9 |
+
|
10 |
+
# def CreateTestDataLoader(opt):
|
11 |
+
# data_loader = CustomTestDatasetDataLoader()
|
12 |
+
# data_loader.initialize(opt)
|
13 |
+
# return data_loader
|
14 |
+
|
15 |
+
class BaseDataLoader():
|
16 |
+
def __init__(self):
|
17 |
+
pass
|
18 |
+
|
19 |
+
def initialize(self, opt):
|
20 |
+
self.opt = opt
|
21 |
+
pass
|
22 |
+
|
23 |
+
def load_data(self):
|
24 |
+
return None
|
25 |
+
|
26 |
+
class CustomDatasetDataLoader(BaseDataLoader):
|
27 |
+
def name(self):
|
28 |
+
return 'CustomDatasetDataLoader'
|
29 |
+
|
30 |
+
def initialize(self, opt):
|
31 |
+
BaseDataLoader.initialize(self, opt)
|
32 |
+
self.dataset = DDDataset()
|
33 |
+
self.dataset.initialize(opt)
|
34 |
+
'''
|
35 |
+
sampler = torch.utils.data.distributed.DistributedSampler(self.dataset)
|
36 |
+
self.dataloader = torch.utils.data.DataLoader(
|
37 |
+
self.dataset,
|
38 |
+
batch_size=opt.batch_size,
|
39 |
+
shuffle=False,
|
40 |
+
sampler=sampler)
|
41 |
+
'''
|
42 |
+
self.dataloader = torch.utils.data.DataLoader(
|
43 |
+
self.dataset,
|
44 |
+
batch_size=opt.batch_size,
|
45 |
+
shuffle=opt.shuffle,
|
46 |
+
drop_last=True,
|
47 |
+
num_workers=int(opt.num_threads))
|
48 |
+
|
49 |
+
def load_data(self):
|
50 |
+
return self
|
51 |
+
|
52 |
+
def __len__(self):
|
53 |
+
return min(len(self.dataset), self.opt.max_dataset_size)
|
54 |
+
|
55 |
+
def __iter__(self):
|
56 |
+
for i, data in enumerate(self.dataloader):
|
57 |
+
if i * self.opt.batch_size >= self.opt.max_dataset_size:
|
58 |
+
break
|
59 |
+
yield data
|
60 |
+
|
61 |
+
# class CustomTestDatasetDataLoader(BaseDataLoader):
|
62 |
+
# def name(self):
|
63 |
+
# return 'CustomDatasetDataLoader'
|
64 |
+
|
65 |
+
# def initialize(self, opt):
|
66 |
+
# BaseDataLoader.initialize(self, opt)
|
67 |
+
# self.dataset = DDDatasetTest()
|
68 |
+
# self.dataset.initialize(opt)
|
69 |
+
# '''
|
70 |
+
# sampler = torch.utils.data.distributed.DistributedSampler(self.dataset)
|
71 |
+
# self.dataloader = torch.utils.data.DataLoader(
|
72 |
+
# self.dataset,
|
73 |
+
# batch_size=opt.batch_size,
|
74 |
+
# shuffle=False,
|
75 |
+
# sampler=sampler)
|
76 |
+
# '''
|
77 |
+
# self.dataloader = torch.utils.data.DataLoader(
|
78 |
+
# self.dataset,
|
79 |
+
# batch_size=opt.batch_size,
|
80 |
+
# shuffle=opt.shuffle,
|
81 |
+
# drop_last=True,
|
82 |
+
# num_workers=int(opt.num_threads))
|
83 |
+
|
84 |
+
# def load_data(self):
|
85 |
+
# return self
|
86 |
+
|
87 |
+
# def __len__(self):
|
88 |
+
# return min(len(self.dataset), self.opt.max_dataset_size)
|
89 |
+
|
90 |
+
# def __iter__(self):
|
91 |
+
# for i, data in enumerate(self.dataloader):
|
92 |
+
# if i * self.opt.batch_size >= self.opt.max_dataset_size:
|
93 |
+
# break
|
94 |
+
# yield data
|
data_utils/UNFaceFlow/data_test_flow/base_dataset.py
ADDED
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.utils.data as data
|
2 |
+
from PIL import Image
|
3 |
+
import torchvision.transforms as transforms
|
4 |
+
|
5 |
+
|
6 |
+
class BaseDataset(data.Dataset):
|
7 |
+
def __init__(self):
|
8 |
+
super(BaseDataset, self).__init__()
|
9 |
+
|
10 |
+
def name(self):
|
11 |
+
return 'BaseDataset'
|
12 |
+
|
13 |
+
def initialize(self, opt):
|
14 |
+
pass
|
15 |
+
|
16 |
+
def __len__(self):
|
17 |
+
return 0
|
18 |
+
|
19 |
+
|
20 |
+
def get_transform(opt):
|
21 |
+
transform_list = []
|
22 |
+
if opt.resize_or_crop == 'resize_and_crop':
|
23 |
+
osize = [opt.loadSize, opt.loadSize]
|
24 |
+
transform_list.append(transforms.Resize(osize, Image.BICUBIC))
|
25 |
+
transform_list.append(transforms.RandomCrop(opt.fineSize))
|
26 |
+
elif opt.resize_or_crop == 'crop':
|
27 |
+
transform_list.append(transforms.RandomCrop(opt.fineSize))
|
28 |
+
elif opt.resize_or_crop == 'scale_width':
|
29 |
+
transform_list.append(transforms.Lambda(
|
30 |
+
lambda img: __scale_width(img, opt.fineSize)))
|
31 |
+
elif opt.resize_or_crop == 'scale_width_and_crop':
|
32 |
+
transform_list.append(transforms.Lambda(
|
33 |
+
lambda img: __scale_width(img, opt.loadSize)))
|
34 |
+
transform_list.append(transforms.RandomCrop(opt.fineSize))
|
35 |
+
elif opt.resize_or_crop == 'none':
|
36 |
+
transform_list.append(transforms.Lambda(
|
37 |
+
lambda img: __adjust(img)))
|
38 |
+
else:
|
39 |
+
raise ValueError('--resize_or_crop %s is not a valid option.' % opt.resize_or_crop)
|
40 |
+
|
41 |
+
if opt.isTrain and not opt.no_flip:
|
42 |
+
transform_list.append(transforms.RandomHorizontalFlip())
|
43 |
+
|
44 |
+
transform_list += [transforms.ToTensor(),
|
45 |
+
transforms.Normalize((0.5, 0.5, 0.5),
|
46 |
+
(0.5, 0.5, 0.5))]
|
47 |
+
return transforms.Compose(transform_list)
|
48 |
+
|
49 |
+
|
50 |
+
# just modify the width and height to be multiple of 4
|
51 |
+
def __adjust(img):
|
52 |
+
ow, oh = img.size
|
53 |
+
|
54 |
+
# the size needs to be a multiple of this number,
|
55 |
+
# because going through generator network may change img size
|
56 |
+
# and eventually cause size mismatch error
|
57 |
+
mult = 4
|
58 |
+
if ow % mult == 0 and oh % mult == 0:
|
59 |
+
return img
|
60 |
+
w = (ow - 1) // mult
|
61 |
+
w = (w + 1) * mult
|
62 |
+
h = (oh - 1) // mult
|
63 |
+
h = (h + 1) * mult
|
64 |
+
|
65 |
+
if ow != w or oh != h:
|
66 |
+
__print_size_warning(ow, oh, w, h)
|
67 |
+
|
68 |
+
return img.resize((w, h), Image.BICUBIC)
|
69 |
+
|
70 |
+
|
71 |
+
def __scale_width(img, target_width):
|
72 |
+
ow, oh = img.size
|
73 |
+
|
74 |
+
# the size needs to be a multiple of this number,
|
75 |
+
# because going through generator network may change img size
|
76 |
+
# and eventually cause size mismatch error
|
77 |
+
mult = 4
|
78 |
+
assert target_width % mult == 0, "the target width needs to be multiple of %d." % mult
|
79 |
+
if (ow == target_width and oh % mult == 0):
|
80 |
+
return img
|
81 |
+
w = target_width
|
82 |
+
target_height = int(target_width * oh / ow)
|
83 |
+
m = (target_height - 1) // mult
|
84 |
+
h = (m + 1) * mult
|
85 |
+
|
86 |
+
if target_height != h:
|
87 |
+
__print_size_warning(target_width, target_height, w, h)
|
88 |
+
|
89 |
+
return img.resize((w, h), Image.BICUBIC)
|
90 |
+
|
91 |
+
|
92 |
+
def __print_size_warning(ow, oh, w, h):
|
93 |
+
if not hasattr(__print_size_warning, 'has_printed'):
|
94 |
+
print("The image size needs to be a multiple of 4. "
|
95 |
+
"The loaded image size was (%d, %d), so it was adjusted to "
|
96 |
+
"(%d, %d). This adjustment will be done to all images "
|
97 |
+
"whose sizes are not multiples of 4" % (ow, oh, w, h))
|
98 |
+
__print_size_warning.has_printed = True
|
data_utils/UNFaceFlow/data_test_flow/dd_dataset.py
ADDED
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os.path
|
2 |
+
import torch
|
3 |
+
import torch.utils.data as data
|
4 |
+
from PIL import Image
|
5 |
+
import random
|
6 |
+
import utils
|
7 |
+
import numpy as np
|
8 |
+
import torchvision.transforms as transforms
|
9 |
+
from utils_core import flow_viz
|
10 |
+
import cv2
|
11 |
+
|
12 |
+
class DDDataset(data.Dataset):
|
13 |
+
def __init__(self):
|
14 |
+
super(DDDataset, self).__init__()
|
15 |
+
def initialize(self, opt):
|
16 |
+
self.opt = opt
|
17 |
+
self.dir_txt = opt.datapath
|
18 |
+
self.paths = []
|
19 |
+
in_file = open(self.dir_txt, "r")
|
20 |
+
k = 0
|
21 |
+
list_paths = in_file.readlines()
|
22 |
+
for line in list_paths:
|
23 |
+
#if k>=20: break
|
24 |
+
flag = False
|
25 |
+
line = line.strip()
|
26 |
+
line = line.split()
|
27 |
+
|
28 |
+
#source data
|
29 |
+
if (not os.path.exists(line[0])):
|
30 |
+
print(line[0]+" not exists")
|
31 |
+
continue
|
32 |
+
if (not os.path.exists(line[1])):
|
33 |
+
print(line[1]+" not exists")
|
34 |
+
continue
|
35 |
+
if (not os.path.exists(line[2])):
|
36 |
+
print(line[2]+" not exists")
|
37 |
+
continue
|
38 |
+
if (not os.path.exists(line[3])):
|
39 |
+
print(line[3]+" not exists")
|
40 |
+
continue
|
41 |
+
# if (not os.path.exists(line[2])):
|
42 |
+
# print(line[2]+" not exists")
|
43 |
+
# continue
|
44 |
+
|
45 |
+
# path_list = [line[0], line[1], line[2]]
|
46 |
+
path_list = [line[0], line[1], line[2], line[3]]
|
47 |
+
self.paths.append(path_list)
|
48 |
+
k += 1
|
49 |
+
in_file.close()
|
50 |
+
self.data_size = len(self.paths)
|
51 |
+
print("num data: ", len(self.paths))
|
52 |
+
|
53 |
+
def process_data(self, color, mask):
|
54 |
+
non_zero = mask.nonzero()
|
55 |
+
bound = 10
|
56 |
+
min_x = max(0, non_zero[1].min()-bound)
|
57 |
+
max_x = min(self.opt.width-1, non_zero[1].max()+bound)
|
58 |
+
min_y = max(0, non_zero[0].min()-bound)
|
59 |
+
max_y = min(self.opt.height-1, non_zero[0].max()+bound)
|
60 |
+
color = color * (mask!=0).astype(float)[:, :, None]
|
61 |
+
crop_color = color[min_y:max_y, min_x:max_x, :]
|
62 |
+
crop_color = cv2.resize(np.ascontiguousarray(crop_color), (self.opt.crop_width, self.opt.crop_height), interpolation=cv2.INTER_LINEAR)
|
63 |
+
crop_params = [[min_x], [max_x], [min_y], [max_y]]
|
64 |
+
|
65 |
+
return crop_color, crop_params
|
66 |
+
|
67 |
+
def __getitem__(self, index):
|
68 |
+
paths = self.paths[index % self.data_size]
|
69 |
+
src_color = np.array(Image.open(paths[0]))
|
70 |
+
src_color = src_color.astype(np.uint8)
|
71 |
+
raw_src_color = src_color.copy()
|
72 |
+
src_mask = np.array(Image.open(paths[1]))[:, :, 0]
|
73 |
+
cv2.imwrite("test_mask.png", src_mask)
|
74 |
+
src_mask_copy = src_mask.copy()
|
75 |
+
src_crop_color, src_crop_params = self.process_data(src_color, src_mask)
|
76 |
+
#self.write_mesh(src_X, src_Y, src_Z, "./tmp/src.obj")
|
77 |
+
#HWC --> CHW,
|
78 |
+
raw_src_color = torch.from_numpy(raw_src_color).permute(2, 0, 1).float() / 255.0
|
79 |
+
src_crop_color = torch.from_numpy(src_crop_color).permute(2, 0, 1).float() / 255.0
|
80 |
+
|
81 |
+
src_mask_copy = (src_mask_copy!=0)
|
82 |
+
src_mask_copy = torch.tensor(src_mask_copy[np.newaxis, :, :])
|
83 |
+
|
84 |
+
tar_color = np.array(Image.open(paths[2]))
|
85 |
+
tar_color = tar_color.astype(np.uint8)
|
86 |
+
raw_tar_color = tar_color.copy()
|
87 |
+
tar_mask = np.array(Image.open(paths[3]))[:, :, 0]
|
88 |
+
tar_mask_copy = tar_mask.copy()
|
89 |
+
tar_crop_color, tar_crop_params = self.process_data(tar_color, tar_mask)
|
90 |
+
|
91 |
+
raw_tar_color = torch.from_numpy(raw_tar_color).permute(2, 0, 1).float() / 255.0
|
92 |
+
tar_crop_color = torch.from_numpy(tar_crop_color).permute(2, 0, 1).float() / 255.0
|
93 |
+
|
94 |
+
tar_mask_copy = (tar_mask_copy!=0)
|
95 |
+
tar_mask_copy = torch.tensor(tar_mask_copy[np.newaxis, :, :])
|
96 |
+
|
97 |
+
Crop_param = torch.tensor(src_crop_params+tar_crop_params)
|
98 |
+
|
99 |
+
split_ = paths[0].split("/")
|
100 |
+
path1 = split_[-1][:-4] + "_" + paths[2].split("/")[-1][:-4] +".oflow"
|
101 |
+
|
102 |
+
return {"path_flow":path1, "src_crop_color":src_crop_color, "tar_crop_color":tar_crop_color, "src_color":raw_src_color, "tar_color":raw_tar_color, "src_mask":src_mask_copy, "tar_mask":tar_mask_copy, "Crop_param":Crop_param}
|
103 |
+
|
104 |
+
def __len__(self):
|
105 |
+
return self.data_size
|
106 |
+
|
107 |
+
def name(self):
|
108 |
+
return 'DDDataset'
|
data_utils/UNFaceFlow/data_test_flow/dd_dataset_bak.py
ADDED
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os.path
|
2 |
+
import torch
|
3 |
+
import torch.utils.data as data
|
4 |
+
from PIL import Image
|
5 |
+
import random
|
6 |
+
import utils
|
7 |
+
import numpy as np
|
8 |
+
import torchvision.transforms as transforms
|
9 |
+
from utils_core import flow_viz
|
10 |
+
import cv2
|
11 |
+
|
12 |
+
class DDDataset(data.Dataset):
|
13 |
+
def __init__(self):
|
14 |
+
super(DDDataset, self).__init__()
|
15 |
+
def initialize(self, opt):
|
16 |
+
self.opt = opt
|
17 |
+
self.dir_txt = opt.datapath
|
18 |
+
self.paths = []
|
19 |
+
in_file = open(self.dir_txt, "r")
|
20 |
+
k = 0
|
21 |
+
list_paths = in_file.readlines()
|
22 |
+
for line in list_paths:
|
23 |
+
#if k>=20: break
|
24 |
+
flag = False
|
25 |
+
line = line.strip()
|
26 |
+
line = line.split()
|
27 |
+
|
28 |
+
#source data
|
29 |
+
if (not os.path.exists(line[0])):
|
30 |
+
print(line[0]+" not exists")
|
31 |
+
continue
|
32 |
+
if (not os.path.exists(line[1])):
|
33 |
+
print(line[1]+" not exists")
|
34 |
+
continue
|
35 |
+
if (not os.path.exists(line[2])):
|
36 |
+
print(line[2]+" not exists")
|
37 |
+
continue
|
38 |
+
if (not os.path.exists(line[3])):
|
39 |
+
print(line[3]+" not exists")
|
40 |
+
continue
|
41 |
+
# if (not os.path.exists(line[2])):
|
42 |
+
# print(line[2]+" not exists")
|
43 |
+
# continue
|
44 |
+
|
45 |
+
# path_list = [line[0], line[1], line[2]]
|
46 |
+
path_list = [line[0], line[1], line[2], line[3]]
|
47 |
+
self.paths.append(path_list)
|
48 |
+
k += 1
|
49 |
+
in_file.close()
|
50 |
+
self.data_size = len(self.paths)
|
51 |
+
print("num data: ", len(self.paths))
|
52 |
+
|
53 |
+
def process_data(self, color, mask):
|
54 |
+
non_zero = mask.nonzero()
|
55 |
+
bound = 10
|
56 |
+
min_x = max(0, non_zero[1].min()-bound)
|
57 |
+
max_x = min(self.opt.width-1, non_zero[1].max()+bound)
|
58 |
+
min_y = max(0, non_zero[0].min()-bound)
|
59 |
+
max_y = min(self.opt.height-1, non_zero[0].max()+bound)
|
60 |
+
color = color * (mask!=0).astype(float)[:, :, None]
|
61 |
+
crop_color = color[min_y:max_y, min_x:max_x, :]
|
62 |
+
crop_color = cv2.resize(np.ascontiguousarray(crop_color), (self.opt.crop_width, self.opt.crop_height), interpolation=cv2.INTER_LINEAR)
|
63 |
+
crop_params = [[min_x], [max_x], [min_y], [max_y]]
|
64 |
+
|
65 |
+
return crop_color, crop_params
|
66 |
+
|
67 |
+
def __getitem__(self, index):
|
68 |
+
paths = self.paths[index % self.data_size]
|
69 |
+
src_color = np.array(Image.open(paths[0]))
|
70 |
+
src_color = src_color.astype(np.uint8)
|
71 |
+
raw_src_color = src_color.copy()
|
72 |
+
src_mask = np.array(Image.open(paths[1]))
|
73 |
+
src_mask_copy = src_mask.copy()
|
74 |
+
src_crop_color, src_crop_params = self.process_data(src_color, src_mask)
|
75 |
+
#self.write_mesh(src_X, src_Y, src_Z, "./tmp/src.obj")
|
76 |
+
#HWC --> CHW,
|
77 |
+
raw_src_color = torch.from_numpy(raw_src_color).permute(2, 0, 1).float() / 255.0
|
78 |
+
src_crop_color = torch.from_numpy(src_crop_color).permute(2, 0, 1).float() / 255.0
|
79 |
+
|
80 |
+
src_mask_copy = (src_mask_copy!=0)
|
81 |
+
src_mask_copy = torch.tensor(src_mask_copy[np.newaxis, :, :])
|
82 |
+
|
83 |
+
tar_color = np.array(Image.open(paths[2]))
|
84 |
+
tar_color = tar_color.astype(np.uint8)
|
85 |
+
raw_tar_color = tar_color.copy()
|
86 |
+
tar_mask = np.array(Image.open(paths[3]))
|
87 |
+
tar_mask_copy = tar_mask.copy()
|
88 |
+
tar_crop_color, tar_crop_params = self.process_data(tar_color, tar_mask)
|
89 |
+
|
90 |
+
raw_tar_color = torch.from_numpy(raw_tar_color).permute(2, 0, 1).float() / 255.0
|
91 |
+
tar_crop_color = torch.from_numpy(tar_crop_color).permute(2, 0, 1).float() / 255.0
|
92 |
+
|
93 |
+
tar_mask_copy = (tar_mask_copy!=0)
|
94 |
+
tar_mask_copy = torch.tensor(tar_mask_copy[np.newaxis, :, :])
|
95 |
+
|
96 |
+
Crop_param = torch.tensor(src_crop_params+tar_crop_params)
|
97 |
+
|
98 |
+
split_ = paths[0].split("/")
|
99 |
+
path1 = split_[-1][:-4] + "_" + paths[2].split("/")[-1][:-4] +".oflow"
|
100 |
+
|
101 |
+
return {"path_flow":path1, "src_crop_color":src_crop_color, "tar_crop_color":tar_crop_color, "src_color":raw_src_color, "tar_color":raw_tar_color, "src_mask":src_mask_copy, "tar_mask":tar_mask_copy, "Crop_param":Crop_param}
|
102 |
+
|
103 |
+
def __len__(self):
|
104 |
+
return self.data_size
|
105 |
+
|
106 |
+
def name(self):
|
107 |
+
return 'DDDataset'
|
data_utils/UNFaceFlow/models/network_test_flow.py
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from raft import RAFT
|
5 |
+
from nnutils import make_conv_2d, make_upscale_2d, make_downscale_2d, ResBlock2d, Identity
|
6 |
+
|
7 |
+
|
8 |
+
class ImportanceWeights(torch.nn.Module):
|
9 |
+
def __init__(self, opt):
|
10 |
+
super().__init__()
|
11 |
+
|
12 |
+
if opt.small:
|
13 |
+
in_dim = 128
|
14 |
+
else:
|
15 |
+
in_dim = 256
|
16 |
+
fn_0 = 16
|
17 |
+
self.input_fn = fn_0 + 3 * 2
|
18 |
+
fn_1 = 16
|
19 |
+
self.conv1 = torch.nn.Conv2d(in_channels=in_dim, out_channels=fn_0, kernel_size=3, stride=1, padding=1)
|
20 |
+
|
21 |
+
if opt.use_batch_norm:
|
22 |
+
custom_batch_norm = torch.nn.BatchNorm2d
|
23 |
+
else:
|
24 |
+
custom_batch_norm = Identity
|
25 |
+
|
26 |
+
self.model = nn.Sequential(
|
27 |
+
make_conv_2d(self.input_fn, fn_1, n_blocks=1, normalization=custom_batch_norm),
|
28 |
+
ResBlock2d(fn_1, normalization=custom_batch_norm),
|
29 |
+
ResBlock2d(fn_1, normalization=custom_batch_norm),
|
30 |
+
ResBlock2d(fn_1, normalization=custom_batch_norm),
|
31 |
+
nn.Conv2d(fn_1, 1, kernel_size=3, padding=1)
|
32 |
+
# torch.nn.Sigmoid()
|
33 |
+
)
|
34 |
+
|
35 |
+
def forward(self, x, features):
|
36 |
+
# Reduce number of channels and upscale to highest resolution
|
37 |
+
features = self.conv1(features)
|
38 |
+
x = torch.cat([features, x], 1)
|
39 |
+
assert x.shape[1] == self.input_fn
|
40 |
+
x = self.model(x)
|
41 |
+
print(x)
|
42 |
+
print(x.max(), x.min(), x.mean())
|
43 |
+
|
44 |
+
return torch.nn.Sigmoid()(x)
|
45 |
+
|
46 |
+
class NeuralNRT(nn.Module):
|
47 |
+
def __init__(self, opt, path=None, device="cuda:0"):
|
48 |
+
super(NeuralNRT, self).__init__()
|
49 |
+
self.opt = opt
|
50 |
+
self.CorresPred = RAFT(opt)
|
51 |
+
self.ImportanceW = ImportanceWeights(opt)
|
52 |
+
if path is not None:
|
53 |
+
data = torch.load(path,map_location='cpu')
|
54 |
+
if 'state_dict' in data.keys():
|
55 |
+
self.CorresPred.load_state_dict(data['state_dict'])
|
56 |
+
print("load done")
|
57 |
+
else:
|
58 |
+
self.CorresPred.load_state_dict({k.replace('module.', ''):v for k,v in data.items()})
|
59 |
+
print("load done")
|
60 |
+
def forward(self, src_im,tar_im, src_im_raw, tar_im_raw, Crop_param):
|
61 |
+
N=src_im.shape[0]
|
62 |
+
src_im = src_im*255.0
|
63 |
+
tar_im = tar_im*255.0
|
64 |
+
flow_fw_crop, feature_fw_crop = self.CorresPred(src_im, tar_im, iters=self.opt.iters)
|
65 |
+
|
66 |
+
xx = torch.arange(0, self.opt.width).view(1,-1).repeat(self.opt.height,1)
|
67 |
+
yy = torch.arange(0, self.opt.height).view(-1,1).repeat(1,self.opt.width)
|
68 |
+
xx = xx.view(1,1,self.opt.height,self.opt.width).repeat(N,1,1,1)
|
69 |
+
yy = yy.view(1,1,self.opt.height,self.opt.width).repeat(N,1,1,1)
|
70 |
+
grid = torch.cat((xx,yy),1).float()
|
71 |
+
grid = grid.to(src_im.device)
|
72 |
+
|
73 |
+
grid_crop = grid[:, :, :self.opt.crop_height, :self.opt.crop_width]
|
74 |
+
|
75 |
+
flow_fw = torch.zeros((N, 2, self.opt.height, self.opt.width), device=src_im.device)
|
76 |
+
|
77 |
+
leftup1 = torch.cat((Crop_param[:, 0:1, 0], Crop_param[:, 2:3, 0]), 1)[:, :, None, None]
|
78 |
+
leftup2 = torch.cat((Crop_param[:, 4:5, 0], Crop_param[:, 6:7, 0]), 1)[:, :, None, None]
|
79 |
+
|
80 |
+
scale1 = torch.cat(((Crop_param[:, 1:2, 0]-Crop_param[:, 0:1, 0]).float() / self.opt.crop_width, (Crop_param[:, 3:4, 0]-Crop_param[:, 2:3, 0]).float() / self.opt.crop_height), 1)[:, :, None, None]
|
81 |
+
scale2 = torch.cat(((Crop_param[:, 5:6, 0]-Crop_param[:, 4:5, 0]).float() / self.opt.crop_width, (Crop_param[:, 7:8, 0]-Crop_param[:, 6:7, 0]).float() / self.opt.crop_height), 1)[:, :, None, None]
|
82 |
+
|
83 |
+
flow_fw_crop = (scale2 - scale1) * grid_crop + scale2 * flow_fw_crop
|
84 |
+
for i in range(N):
|
85 |
+
flow_fw_cropi = F.interpolate(flow_fw_crop[i:(i+1)], ((Crop_param[i, 3, 0]-Crop_param[i, 2, 0]).item(), (Crop_param[i, 1, 0]-Crop_param[i, 0, 0]).item()), mode='bilinear', align_corners=True)
|
86 |
+
flow_fw_cropi =flow_fw_cropi + (leftup2 - leftup1)[i:(i+1), :, :, :]
|
87 |
+
flow_fw[i, :, Crop_param[i, 2, 0]:Crop_param[i, 3, 0], Crop_param[i, 0, 0]:Crop_param[i, 1, 0]] = flow_fw_cropi[0]
|
88 |
+
return flow_fw
|
data_utils/UNFaceFlow/options_test_flow.py
ADDED
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ref:https://github.com/ShunyuYao/DFA-NeRF
|
2 |
+
import argparse
|
3 |
+
class BaseOptions():
|
4 |
+
def __init__(self):
|
5 |
+
self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
6 |
+
self.initialized = False
|
7 |
+
|
8 |
+
def initialize(self):
|
9 |
+
# self.parser.add_argument('--model_save_path', type=str, default='snapshot/small_filter_wo_ct_wi_bn/real_data/combine/', help='path')
|
10 |
+
self.parser.add_argument('--model_save_path', type=str, default='snapshot/version1/', help='path')
|
11 |
+
self.parser.add_argument('--num_threads', type=int, default=2, help='number of threads')
|
12 |
+
self.parser.add_argument('--max_dataset_size', type=int, default=150000, help='max dataset size')
|
13 |
+
|
14 |
+
self.parser.add_argument('--n_epochs', type=int, default=40000, help='number of iterations')
|
15 |
+
self.parser.add_argument('--dropout', type=float, default=0.0, help='dropout')
|
16 |
+
self.parser.add_argument('--init_type', type=str, default='uniform', help='[uniform | xavier]')
|
17 |
+
self.parser.add_argument('--frequency_print_batch', type=int, default=1000, help='print messages every set iter')
|
18 |
+
self.parser.add_argument('--frequency_save_model', type=int, default=2000, help='save model every set iter')
|
19 |
+
self.parser.add_argument('--small', type=bool, default=True, help='use small model')
|
20 |
+
self.parser.add_argument('--use_batch_norm', action='store_true', help='')
|
21 |
+
self.parser.add_argument('--smooth_2nd', type=bool, default=True, help='')
|
22 |
+
|
23 |
+
|
24 |
+
#loss weight for Gauss-Newton optimization
|
25 |
+
self.parser.add_argument('--lambda_2d', type=float, default=0.001, help='weight of 2D projection loss')
|
26 |
+
self.parser.add_argument('--lambda_depth', type=float, default=1.0, help='weight of depth loss')
|
27 |
+
self.parser.add_argument('--lambda_reg', type=float, default=1.0, help='weight of regularization loss')
|
28 |
+
|
29 |
+
self.parser.add_argument('--num_adja', type=int, default=6, help='number of nodes who affect a point')
|
30 |
+
self.parser.add_argument('--num_corres', type=int, default=20000, help='number of corres')
|
31 |
+
self.parser.add_argument('--iter_num', type=int, default=3, help='GN iter num')
|
32 |
+
self.parser.add_argument('--width', type=int, default=512, help='image width')#480
|
33 |
+
self.parser.add_argument('--height', type=int, default=512, help='image height')#640
|
34 |
+
self.parser.add_argument('--crop_width', type=int, default=240, help='image width')
|
35 |
+
self.parser.add_argument('--crop_height', type=int, default=320, help='image height')
|
36 |
+
self.parser.add_argument('--max_num_edges', type=int, default=30000, help='number of edges')
|
37 |
+
self.parser.add_argument('--max_num_nodes', type=int, default=1500, help='number of edges')
|
38 |
+
self.parser.add_argument('--fdim', type=int, default=128)
|
39 |
+
|
40 |
+
#loss weight for training
|
41 |
+
self.parser.add_argument('--lambda_weights', type=float, default=0.0, help='weight of weights loss')#75
|
42 |
+
self.parser.add_argument('--lambda_corres', type=float, default=1.0, help='weight of corres loss')#0, 1
|
43 |
+
self.parser.add_argument('--lambda_graph', type=float, default=10.0, help='weight of graph loss')#1000, 5
|
44 |
+
self.parser.add_argument('--lambda_warp', type=float, default=10.0, help='weight of warp loss')#1000, 5
|
45 |
+
|
46 |
+
|
47 |
+
def parse(self):
|
48 |
+
if not self.initialized:
|
49 |
+
self.initialize()
|
50 |
+
|
51 |
+
self.opt = self.parser.parse_args()
|
52 |
+
self.opt.isTrain = self.isTrain
|
53 |
+
self.opt.isTest = self.isTest
|
54 |
+
args = vars(self.opt)
|
55 |
+
|
56 |
+
return self.opt
|
57 |
+
|
58 |
+
class TrainOptions(BaseOptions):
|
59 |
+
# Override
|
60 |
+
def initialize(self):
|
61 |
+
BaseOptions.initialize(self)
|
62 |
+
#syn_datasets/syn_new_train_data.txt
|
63 |
+
self.parser.add_argument('--datapath', type=str, default='./data/train_data.txt', help='path')
|
64 |
+
self.parser.add_argument('--pretrain_model_path', type=str, default='./pretrain_model/raft-small.pth', help='path')#
|
65 |
+
self.parser.add_argument('--lr_C', type=float, default=0.00001, help='initial learning rate')#0.01
|
66 |
+
self.parser.add_argument('--optimizer_C', type=str, default='sgd', help='[sgd | adam]')
|
67 |
+
self.parser.add_argument('--lr_W', type=float, default=0.00001, help='initial learning rate')
|
68 |
+
self.parser.add_argument('--lr_BSW', type=float, default=0.00001, help='initial learning rate')
|
69 |
+
self.parser.add_argument('--optimizer_W', type=str, default='sgd', help='[sgd | adam]')
|
70 |
+
self.parser.add_argument('--optimizer_BSW', type=str, default='sgd', help='[sgd | adam]')
|
71 |
+
self.parser.add_argument('--lr_decay_epoch', type=int, default=8000, help='multiply by a gamma every set iter')
|
72 |
+
self.parser.add_argument('--lr_decay', type=float, default=0.1, help='coefficient of lr decay')
|
73 |
+
self.parser.add_argument('--weight_decay', type=float, default=1e-4, help='0.0005coefficient of weight decay')
|
74 |
+
self.parser.add_argument('--batch_size', type=int, default=4, help='batch size')
|
75 |
+
self.parser.add_argument('--shuffle', type=bool, default=True, help='whether to shuffle data')
|
76 |
+
|
77 |
+
self.parser.add_argument('--validation', type=str, nargs='+')
|
78 |
+
#self.parser.add_argument('--image_size', type=int, nargs='+', default=[384, 512])
|
79 |
+
self.parser.add_argument('--gpus', type=int, nargs='+', default=[0,1])
|
80 |
+
self.parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
|
81 |
+
self.parser.add_argument('--iters', type=int, default=12)
|
82 |
+
|
83 |
+
self.parser.add_argument('--clip', type=float, default=1.0)
|
84 |
+
self.parser.add_argument('--gamma', type=float, default=0.8, help='exponential weighting')
|
85 |
+
self.parser.add_argument('--add_noise', action='store_true')
|
86 |
+
|
87 |
+
self.parser.add_argument('--train_bsw', type=bool, default=True, help='whether to train bsw network')
|
88 |
+
self.parser.add_argument('--train_weight', type=bool, default=True, help='whether to train weight network')
|
89 |
+
self.parser.add_argument('--train_corres', type=bool, default=True, help='whether to train corresPred network')
|
90 |
+
|
91 |
+
self.isTrain = True
|
92 |
+
self.isTest = False
|
93 |
+
|
94 |
+
class ValOptions(BaseOptions):
|
95 |
+
def initialize(self):
|
96 |
+
BaseOptions.initialize(self)
|
97 |
+
self.parser.add_argument('--batch_size', type=int, default=4, help='batch size')
|
98 |
+
self.parser.add_argument('--datapath', type=str, default='./data/val_data.txt', help='path')
|
99 |
+
self.parser.add_argument('--shuffle', type=bool, default=True, help='whether to shuffle data')
|
100 |
+
self.parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
|
101 |
+
self.parser.add_argument('--alternate_corr', action='store_true', help='use efficent correlation implementation')
|
102 |
+
self.parser.add_argument('--iters', type=int, default=12)
|
103 |
+
self.isTrain = True
|
104 |
+
self.isTest = False
|
105 |
+
|
106 |
+
class TestOptions(BaseOptions):
|
107 |
+
def initialize(self):
|
108 |
+
BaseOptions.initialize(self)
|
109 |
+
self.parser.add_argument('--batch_size', type=int, default=1, help='batch size')
|
110 |
+
self.parser.add_argument('--pretrain_model_path', type=str, default='./pretrain_model/raft-small.pth', help='path')#
|
111 |
+
|
112 |
+
# self.parser.add_argument('--datapath', type=str, default='./data/real_train_data_1128_1.txt', help='path')
|
113 |
+
# self.parser.add_argument('--datapath', type=str, default='./data_test_flow/test_data.txt', help='path')
|
114 |
+
self.parser.add_argument('--savepath', type=str, default='flow_result',
|
115 |
+
help='save path')
|
116 |
+
self.parser.add_argument('--datapath', type=str, default='/data_b/yudong/paper_code/TalkingHead-NeRF/data_guancha/guancha_flow.txt',
|
117 |
+
help='path')
|
118 |
+
self.parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
|
119 |
+
self.parser.add_argument('--alternate_corr', action='store_true', help='use efficent correlation implementation')
|
120 |
+
self.parser.add_argument('--iters', type=int, default=12)
|
121 |
+
self.parser.add_argument('--shuffle', type=bool, default=True, help='whether to shuffle data')
|
122 |
+
self.isTrain = False
|
123 |
+
self.isTest = True
|
data_utils/UNFaceFlow/pretrain_model/raft-small.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c7d41b9cc88442bb8aa911dbb33086dac55a226394b142937ff22d5578717332
|
3 |
+
size 3984814
|
data_utils/UNFaceFlow/sgd_NNRT_model_epoch19008_50000.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a8156ef276732a4cbd9a9d85d9b5653abf500976372daf7892b122971a7b8f37
|
3 |
+
size 8808087
|
data_utils/UNFaceFlow/test_flow.py
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ref:https://github.com/ShunyuYao/DFA-NeRF
|
2 |
+
import sys
|
3 |
+
import os
|
4 |
+
from tqdm import tqdm
|
5 |
+
dir_path = os.path.dirname(os.path.realpath(__file__))
|
6 |
+
sys.path.append(os.path.join(dir_path, 'core'))
|
7 |
+
from pathlib import Path
|
8 |
+
from data_test_flow import *
|
9 |
+
from models.network_test_flow import NeuralNRT
|
10 |
+
from options_test_flow import TestOptions
|
11 |
+
import torch
|
12 |
+
import numpy as np
|
13 |
+
|
14 |
+
|
15 |
+
|
16 |
+
def save_flow_numpy(filename, flow_input):
|
17 |
+
np.save(filename, flow_input)
|
18 |
+
|
19 |
+
|
20 |
+
def predict(data):
|
21 |
+
with torch.no_grad():
|
22 |
+
model.eval()
|
23 |
+
path_flow = data["path_flow"]
|
24 |
+
src_crop_im = data["src_crop_color"].cuda()
|
25 |
+
tar_crop_im = data["tar_crop_color"].cuda()
|
26 |
+
src_im = data["src_color"].cuda()
|
27 |
+
tar_im = data["tar_color"].cuda()
|
28 |
+
src_mask = data["src_mask"].cuda()
|
29 |
+
crop_param = data["Crop_param"].cuda()
|
30 |
+
B = src_mask.shape[0]
|
31 |
+
flow = model(src_crop_im, tar_crop_im, src_im, tar_im, crop_param)
|
32 |
+
for i in range(B):
|
33 |
+
flow_tmp = flow[i].cpu().numpy() * src_mask[i].cpu().numpy()
|
34 |
+
save_flow_numpy(os.path.join(save_path, os.path.basename(
|
35 |
+
path_flow[i])[:-6]+".npy"), flow_tmp)
|
36 |
+
|
37 |
+
|
38 |
+
if __name__ == "__main__":
|
39 |
+
width = 272
|
40 |
+
height = 480
|
41 |
+
|
42 |
+
test_opts = TestOptions().parse()
|
43 |
+
test_opts.pretrain_model_path = os.path.join(
|
44 |
+
dir_path, 'pretrain_model/raft-small.pth')
|
45 |
+
data_loader = CreateDataLoader(test_opts)
|
46 |
+
testloader = data_loader.load_data()
|
47 |
+
model_path = os.path.join(dir_path, 'sgd_NNRT_model_epoch19008_50000.pth')
|
48 |
+
model = NeuralNRT(test_opts, os.path.join(
|
49 |
+
dir_path, 'pretrain_model/raft-small.pth'))
|
50 |
+
state_dict = torch.load(model_path)
|
51 |
+
|
52 |
+
model.CorresPred.load_state_dict(state_dict["net_C"])
|
53 |
+
model.ImportanceW.load_state_dict(state_dict["net_W"])
|
54 |
+
|
55 |
+
model = model.cuda()
|
56 |
+
|
57 |
+
save_path = test_opts.savepath
|
58 |
+
Path(save_path).mkdir(parents=True, exist_ok=True)
|
59 |
+
total_length = len(testloader)
|
60 |
+
|
61 |
+
for batch_idx, data in tqdm(enumerate(testloader), total=total_length):
|
62 |
+
predict(data)
|
data_utils/UNFaceFlow/utils.py
ADDED
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import shutil
|
3 |
+
import numpy as np
|
4 |
+
import struct
|
5 |
+
import pickle
|
6 |
+
from scipy.sparse import coo_matrix
|
7 |
+
|
8 |
+
def load_flow(filename):
|
9 |
+
# Flow is stored row-wise in order [channels, height, width].
|
10 |
+
assert os.path.isfile(filename), "File not found: {}".format(filename)
|
11 |
+
|
12 |
+
flow = None
|
13 |
+
with open(filename, 'rb') as fin:
|
14 |
+
width = struct.unpack('I', fin.read(4))[0]
|
15 |
+
height = struct.unpack('I', fin.read(4))[0]
|
16 |
+
channels = struct.unpack('I', fin.read(4))[0]
|
17 |
+
n_elems = height * width * channels
|
18 |
+
|
19 |
+
flow = struct.unpack('f' * n_elems, fin.read(n_elems * 4))
|
20 |
+
flow = np.asarray(flow, dtype=np.float32).reshape([channels, height, width])
|
21 |
+
|
22 |
+
return flow
|
23 |
+
|
24 |
+
def load_graph_info(filename, max_edges, max_nodes):
|
25 |
+
|
26 |
+
assert os.path.isfile(filename), "File not found: {}".format(filename)
|
27 |
+
|
28 |
+
with open(filename, 'rb') as fin:
|
29 |
+
edge_total_size = struct.unpack('I', fin.read(4))[0]
|
30 |
+
edges = struct.unpack('I' * (int(edge_total_size / 4)), fin.read(edge_total_size))
|
31 |
+
edges = np.asarray(edges, dtype=np.int16).reshape(-1, 2).transpose()
|
32 |
+
nodes_total_size = struct.unpack('I', fin.read(4))[0]
|
33 |
+
nodes_ids = struct.unpack('I' * (int(nodes_total_size / 4)), fin.read(nodes_total_size))
|
34 |
+
nodes_ids = np.asarray(nodes_ids, dtype=np.int32).reshape(-1)
|
35 |
+
nodes_ids = np.sort(nodes_ids)
|
36 |
+
|
37 |
+
edges_extent = np.zeros((2, max_edges), dtype=np.int16)
|
38 |
+
edges_mask = np.zeros((max_edges), dtype=np.bool)
|
39 |
+
edges_mask[:edges.shape[1]] = 1
|
40 |
+
edges_extent[:, :edges.shape[1]] = edges
|
41 |
+
|
42 |
+
nodes_extent = np.zeros((max_nodes), dtype=np.int32)
|
43 |
+
nodes_mask = np.zeros((max_nodes), dtype=np.bool)
|
44 |
+
nodes_mask[:nodes_ids.shape[0]] = 1
|
45 |
+
nodes_extent[:nodes_ids.shape[0]] = nodes_ids
|
46 |
+
|
47 |
+
fx = struct.unpack('f', fin.read(4))[0]
|
48 |
+
fy = struct.unpack('f', fin.read(4))[0]
|
49 |
+
ox = struct.unpack('f', fin.read(4))[0]
|
50 |
+
oy = struct.unpack('f', fin.read(4))[0]
|
51 |
+
|
52 |
+
return edges_extent, edges_mask, nodes_extent, nodes_mask, fx, fy, ox, oy
|
53 |
+
|
54 |
+
def load_adja_id_info(filename, src_mask, H, W, num_adja, num_neigb):
|
55 |
+
|
56 |
+
assert os.path.isfile(filename), "File not found: {}".format(filename)
|
57 |
+
assert num_adja<=8, "Num of adja is larger than 8"
|
58 |
+
assert num_neigb<=8, "Num of neighb is larger than 8"
|
59 |
+
src_v_id = np.zeros((H*W, num_adja), dtype=np.int16)
|
60 |
+
src_neigb_id = np.zeros((H*W, num_neigb), dtype=np.int32)
|
61 |
+
with open(filename, 'rb') as fin:
|
62 |
+
neigb_id, value_id = pickle.load(fin)
|
63 |
+
assert((src_mask.sum())==value_id.shape[0])
|
64 |
+
|
65 |
+
for i in range(num_adja):
|
66 |
+
src_v_id[src_mask.reshape(-1), i] = value_id[:, i]
|
67 |
+
for i in range(num_neigb):
|
68 |
+
src_neigb_id[src_mask.reshape(-1), i] = neigb_id[:, i]
|
69 |
+
src_v_id = src_v_id.transpose().reshape(num_adja, H, W)
|
70 |
+
src_neigb_id = src_neigb_id.transpose().reshape(num_neigb, H, W)
|
71 |
+
|
72 |
+
return src_v_id, src_neigb_id
|
73 |
+
|
74 |
+
def save_flow(filename, flow_input):
|
75 |
+
flow = np.copy(flow_input)
|
76 |
+
|
77 |
+
# Flow is stored row-wise in order [channels, height, width].
|
78 |
+
assert len(flow.shape) == 3
|
79 |
+
|
80 |
+
with open(filename, 'wb') as fout:
|
81 |
+
fout.write(struct.pack('I', flow.shape[2]))
|
82 |
+
fout.write(struct.pack('I', flow.shape[1]))
|
83 |
+
fout.write(struct.pack('I', flow.shape[0]))
|
84 |
+
fout.write(struct.pack('={}f'.format(flow.size), *flow.flatten("C")))
|
data_utils/blendshape_capture/face_landmarker.task
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:64184e229b263107bc2b804c6625db1341ff2bb731874b0bcc2fe6544e0bc9ff
|
3 |
+
size 3758596
|
data_utils/blendshape_capture/main.py
ADDED
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*-coding:utf-8-*-
|
2 |
+
import argparse
|
3 |
+
import os
|
4 |
+
import random
|
5 |
+
import numpy as np
|
6 |
+
import cv2
|
7 |
+
import glob
|
8 |
+
from mediapipe import solutions
|
9 |
+
from mediapipe.framework.formats import landmark_pb2
|
10 |
+
import numpy as np
|
11 |
+
import matplotlib.pyplot as plt
|
12 |
+
import torch
|
13 |
+
import torch.nn as nn
|
14 |
+
from scipy.signal import savgol_filter
|
15 |
+
import onnxruntime as ort
|
16 |
+
from collections import OrderedDict
|
17 |
+
import mediapipe as mp
|
18 |
+
from mediapipe.tasks import python
|
19 |
+
from mediapipe.tasks.python import vision
|
20 |
+
|
21 |
+
|
22 |
+
from tqdm import tqdm
|
23 |
+
|
24 |
+
|
25 |
+
def infer_bs(root_path):
|
26 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
27 |
+
base_options = python.BaseOptions(model_asset_path="./data_utils/blendshape_capture/face_landmarker.task")
|
28 |
+
options = vision.FaceLandmarkerOptions(base_options=base_options,
|
29 |
+
output_face_blendshapes=True,
|
30 |
+
output_facial_transformation_matrixes=True,
|
31 |
+
num_faces=1)
|
32 |
+
detector = vision.FaceLandmarker.create_from_options(options)
|
33 |
+
|
34 |
+
for i in os.listdir(root_path):
|
35 |
+
if i.endswith(".mp4"):
|
36 |
+
mp4_path = os.path.join(root_path, i)
|
37 |
+
npy_path = os.path.join(root_path, "bs.npy")
|
38 |
+
if os.path.exists(npy_path):
|
39 |
+
print("npy file exists:", i.split(".")[0])
|
40 |
+
continue
|
41 |
+
else:
|
42 |
+
print("npy file not exists:", i.split(".")[0])
|
43 |
+
image_path = os.path.join(root_path, "img/temp.png")
|
44 |
+
os.makedirs(os.path.join(root_path, 'img/'), exist_ok=True)
|
45 |
+
cap = cv2.VideoCapture(mp4_path)
|
46 |
+
fps = cap.get(cv2.CAP_PROP_FPS)
|
47 |
+
frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
48 |
+
print("fps:", fps)
|
49 |
+
print("frame_count:", frame_count)
|
50 |
+
k = 0
|
51 |
+
total = frame_count
|
52 |
+
bs = np.zeros((int(total), 52), dtype=np.float32)
|
53 |
+
print("total:", total)
|
54 |
+
print("videoPath:{} fps:{},k".format(mp4_path.split('/')[-1], fps))
|
55 |
+
pbar = tqdm(total=int(total))
|
56 |
+
while (cap.isOpened()):
|
57 |
+
ret, frame = cap.read()
|
58 |
+
if ret:
|
59 |
+
cv2.imwrite(image_path, frame)
|
60 |
+
image = mp.Image.create_from_file(image_path)
|
61 |
+
result = detector.detect(image)
|
62 |
+
face_blendshapes_scores = [face_blendshapes_category.score for face_blendshapes_category in
|
63 |
+
result.face_blendshapes[0]]
|
64 |
+
blendshape_coef = np.array(face_blendshapes_scores)[1:]
|
65 |
+
blendshape_coef = np.append(blendshape_coef, 0)
|
66 |
+
bs[k] = blendshape_coef
|
67 |
+
pbar.update(1)
|
68 |
+
k += 1
|
69 |
+
else:
|
70 |
+
break
|
71 |
+
cap.release()
|
72 |
+
pbar.close()
|
73 |
+
# np.save(npy_path, bs)
|
74 |
+
# print(np.shape(bs))
|
75 |
+
output = np.zeros((bs.shape[0], bs.shape[1]))
|
76 |
+
for j in range(bs.shape[1]):
|
77 |
+
output[:, j] = savgol_filter(bs[:, j], 5, 3)
|
78 |
+
np.save(npy_path, output)
|
79 |
+
print(np.shape(output))
|
80 |
+
|
81 |
+
|
82 |
+
if __name__ == '__main__':
|
83 |
+
parser = argparse.ArgumentParser()
|
84 |
+
parser.add_argument("--path", type=str, help="idname of target person")
|
85 |
+
args = parser.parse_args()
|
86 |
+
infer_bs(args.path)
|
data_utils/deepspeech_features/README.md
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Routines for DeepSpeech features processing
|
2 |
+
Several routines for [DeepSpeech](https://github.com/mozilla/DeepSpeech) features processing, like speech features generation for [VOCA](https://github.com/TimoBolkart/voca) model.
|
3 |
+
|
4 |
+
## Installation
|
5 |
+
|
6 |
+
```
|
7 |
+
pip3 install -r requirements.txt
|
8 |
+
```
|
9 |
+
|
10 |
+
## Usage
|
11 |
+
|
12 |
+
Generate wav files:
|
13 |
+
```
|
14 |
+
python3 extract_wav.py --in-video=<you_data_dir>
|
15 |
+
```
|
16 |
+
|
17 |
+
Generate files with DeepSpeech features:
|
18 |
+
```
|
19 |
+
python3 extract_ds_features.py --input=<you_data_dir>
|
20 |
+
```
|
data_utils/deepspeech_features/deepspeech_features.py
ADDED
@@ -0,0 +1,275 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
DeepSpeech features processing routines.
|
3 |
+
NB: Based on VOCA code. See the corresponding license restrictions.
|
4 |
+
"""
|
5 |
+
|
6 |
+
__all__ = ['conv_audios_to_deepspeech']
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
import warnings
|
10 |
+
import resampy
|
11 |
+
from scipy.io import wavfile
|
12 |
+
from python_speech_features import mfcc
|
13 |
+
import tensorflow.compat.v1 as tf
|
14 |
+
tf.disable_v2_behavior()
|
15 |
+
|
16 |
+
def conv_audios_to_deepspeech(audios,
|
17 |
+
out_files,
|
18 |
+
num_frames_info,
|
19 |
+
deepspeech_pb_path,
|
20 |
+
audio_window_size=1,
|
21 |
+
audio_window_stride=1):
|
22 |
+
"""
|
23 |
+
Convert list of audio files into files with DeepSpeech features.
|
24 |
+
|
25 |
+
Parameters
|
26 |
+
----------
|
27 |
+
audios : list of str or list of None
|
28 |
+
Paths to input audio files.
|
29 |
+
out_files : list of str
|
30 |
+
Paths to output files with DeepSpeech features.
|
31 |
+
num_frames_info : list of int
|
32 |
+
List of numbers of frames.
|
33 |
+
deepspeech_pb_path : str
|
34 |
+
Path to DeepSpeech 0.1.0 frozen model.
|
35 |
+
audio_window_size : int, default 16
|
36 |
+
Audio window size.
|
37 |
+
audio_window_stride : int, default 1
|
38 |
+
Audio window stride.
|
39 |
+
"""
|
40 |
+
# deepspeech_pb_path="/disk4/keyu/DeepSpeech/deepspeech-0.9.2-models.pbmm"
|
41 |
+
graph, logits_ph, input_node_ph, input_lengths_ph = prepare_deepspeech_net(
|
42 |
+
deepspeech_pb_path)
|
43 |
+
|
44 |
+
with tf.compat.v1.Session(graph=graph) as sess:
|
45 |
+
for audio_file_path, out_file_path, num_frames in zip(audios, out_files, num_frames_info):
|
46 |
+
print(audio_file_path)
|
47 |
+
print(out_file_path)
|
48 |
+
audio_sample_rate, audio = wavfile.read(audio_file_path)
|
49 |
+
if audio.ndim != 1:
|
50 |
+
warnings.warn(
|
51 |
+
"Audio has multiple channels, the first channel is used")
|
52 |
+
audio = audio[:, 0]
|
53 |
+
ds_features = pure_conv_audio_to_deepspeech(
|
54 |
+
audio=audio,
|
55 |
+
audio_sample_rate=audio_sample_rate,
|
56 |
+
audio_window_size=audio_window_size,
|
57 |
+
audio_window_stride=audio_window_stride,
|
58 |
+
num_frames=num_frames,
|
59 |
+
net_fn=lambda x: sess.run(
|
60 |
+
logits_ph,
|
61 |
+
feed_dict={
|
62 |
+
input_node_ph: x[np.newaxis, ...],
|
63 |
+
input_lengths_ph: [x.shape[0]]}))
|
64 |
+
|
65 |
+
net_output = ds_features.reshape(-1, 29)
|
66 |
+
win_size = 16
|
67 |
+
zero_pad = np.zeros((int(win_size / 2), net_output.shape[1]))
|
68 |
+
net_output = np.concatenate(
|
69 |
+
(zero_pad, net_output, zero_pad), axis=0)
|
70 |
+
windows = []
|
71 |
+
for window_index in range(0, net_output.shape[0] - win_size, 2):
|
72 |
+
windows.append(
|
73 |
+
net_output[window_index:window_index + win_size])
|
74 |
+
print(np.array(windows).shape)
|
75 |
+
np.save(out_file_path, np.array(windows))
|
76 |
+
|
77 |
+
|
78 |
+
def prepare_deepspeech_net(deepspeech_pb_path):
|
79 |
+
"""
|
80 |
+
Load and prepare DeepSpeech network.
|
81 |
+
|
82 |
+
Parameters
|
83 |
+
----------
|
84 |
+
deepspeech_pb_path : str
|
85 |
+
Path to DeepSpeech 0.1.0 frozen model.
|
86 |
+
|
87 |
+
Returns
|
88 |
+
-------
|
89 |
+
graph : obj
|
90 |
+
ThensorFlow graph.
|
91 |
+
logits_ph : obj
|
92 |
+
ThensorFlow placeholder for `logits`.
|
93 |
+
input_node_ph : obj
|
94 |
+
ThensorFlow placeholder for `input_node`.
|
95 |
+
input_lengths_ph : obj
|
96 |
+
ThensorFlow placeholder for `input_lengths`.
|
97 |
+
"""
|
98 |
+
# Load graph and place_holders:
|
99 |
+
with tf.io.gfile.GFile(deepspeech_pb_path, "rb") as f:
|
100 |
+
graph_def = tf.compat.v1.GraphDef()
|
101 |
+
graph_def.ParseFromString(f.read())
|
102 |
+
|
103 |
+
graph = tf.compat.v1.get_default_graph()
|
104 |
+
tf.import_graph_def(graph_def, name="deepspeech")
|
105 |
+
logits_ph = graph.get_tensor_by_name("deepspeech/logits:0")
|
106 |
+
input_node_ph = graph.get_tensor_by_name("deepspeech/input_node:0")
|
107 |
+
input_lengths_ph = graph.get_tensor_by_name("deepspeech/input_lengths:0")
|
108 |
+
|
109 |
+
return graph, logits_ph, input_node_ph, input_lengths_ph
|
110 |
+
|
111 |
+
|
112 |
+
def pure_conv_audio_to_deepspeech(audio,
|
113 |
+
audio_sample_rate,
|
114 |
+
audio_window_size,
|
115 |
+
audio_window_stride,
|
116 |
+
num_frames,
|
117 |
+
net_fn):
|
118 |
+
"""
|
119 |
+
Core routine for converting audion into DeepSpeech features.
|
120 |
+
|
121 |
+
Parameters
|
122 |
+
----------
|
123 |
+
audio : np.array
|
124 |
+
Audio data.
|
125 |
+
audio_sample_rate : int
|
126 |
+
Audio sample rate.
|
127 |
+
audio_window_size : int
|
128 |
+
Audio window size.
|
129 |
+
audio_window_stride : int
|
130 |
+
Audio window stride.
|
131 |
+
num_frames : int or None
|
132 |
+
Numbers of frames.
|
133 |
+
net_fn : func
|
134 |
+
Function for DeepSpeech model call.
|
135 |
+
|
136 |
+
Returns
|
137 |
+
-------
|
138 |
+
np.array
|
139 |
+
DeepSpeech features.
|
140 |
+
"""
|
141 |
+
target_sample_rate = 16000
|
142 |
+
if audio_sample_rate != target_sample_rate:
|
143 |
+
resampled_audio = resampy.resample(
|
144 |
+
x=audio.astype(np.float),
|
145 |
+
sr_orig=audio_sample_rate,
|
146 |
+
sr_new=target_sample_rate)
|
147 |
+
else:
|
148 |
+
resampled_audio = audio.astype(np.float32)
|
149 |
+
input_vector = conv_audio_to_deepspeech_input_vector(
|
150 |
+
audio=resampled_audio.astype(np.int16),
|
151 |
+
sample_rate=target_sample_rate,
|
152 |
+
num_cepstrum=26,
|
153 |
+
num_context=9)
|
154 |
+
|
155 |
+
network_output = net_fn(input_vector)
|
156 |
+
# print(network_output.shape)
|
157 |
+
|
158 |
+
deepspeech_fps = 50
|
159 |
+
video_fps = 50 # Change this option if video fps is different
|
160 |
+
audio_len_s = float(audio.shape[0]) / audio_sample_rate
|
161 |
+
if num_frames is None:
|
162 |
+
num_frames = int(round(audio_len_s * video_fps))
|
163 |
+
else:
|
164 |
+
video_fps = num_frames / audio_len_s
|
165 |
+
network_output = interpolate_features(
|
166 |
+
features=network_output[:, 0],
|
167 |
+
input_rate=deepspeech_fps,
|
168 |
+
output_rate=video_fps,
|
169 |
+
output_len=num_frames)
|
170 |
+
|
171 |
+
# Make windows:
|
172 |
+
zero_pad = np.zeros((int(audio_window_size / 2), network_output.shape[1]))
|
173 |
+
network_output = np.concatenate(
|
174 |
+
(zero_pad, network_output, zero_pad), axis=0)
|
175 |
+
windows = []
|
176 |
+
for window_index in range(0, network_output.shape[0] - audio_window_size, audio_window_stride):
|
177 |
+
windows.append(
|
178 |
+
network_output[window_index:window_index + audio_window_size])
|
179 |
+
|
180 |
+
return np.array(windows)
|
181 |
+
|
182 |
+
|
183 |
+
def conv_audio_to_deepspeech_input_vector(audio,
|
184 |
+
sample_rate,
|
185 |
+
num_cepstrum,
|
186 |
+
num_context):
|
187 |
+
"""
|
188 |
+
Convert audio raw data into DeepSpeech input vector.
|
189 |
+
|
190 |
+
Parameters
|
191 |
+
----------
|
192 |
+
audio : np.array
|
193 |
+
Audio data.
|
194 |
+
audio_sample_rate : int
|
195 |
+
Audio sample rate.
|
196 |
+
num_cepstrum : int
|
197 |
+
Number of cepstrum.
|
198 |
+
num_context : int
|
199 |
+
Number of context.
|
200 |
+
|
201 |
+
Returns
|
202 |
+
-------
|
203 |
+
np.array
|
204 |
+
DeepSpeech input vector.
|
205 |
+
"""
|
206 |
+
# Get mfcc coefficients:
|
207 |
+
features = mfcc(
|
208 |
+
signal=audio,
|
209 |
+
samplerate=sample_rate,
|
210 |
+
numcep=num_cepstrum)
|
211 |
+
|
212 |
+
# We only keep every second feature (BiRNN stride = 2):
|
213 |
+
features = features[::2]
|
214 |
+
|
215 |
+
# One stride per time step in the input:
|
216 |
+
num_strides = len(features)
|
217 |
+
|
218 |
+
# Add empty initial and final contexts:
|
219 |
+
empty_context = np.zeros((num_context, num_cepstrum), dtype=features.dtype)
|
220 |
+
features = np.concatenate((empty_context, features, empty_context))
|
221 |
+
|
222 |
+
# Create a view into the array with overlapping strides of size
|
223 |
+
# numcontext (past) + 1 (present) + numcontext (future):
|
224 |
+
window_size = 2 * num_context + 1
|
225 |
+
train_inputs = np.lib.stride_tricks.as_strided(
|
226 |
+
features,
|
227 |
+
shape=(num_strides, window_size, num_cepstrum),
|
228 |
+
strides=(features.strides[0],
|
229 |
+
features.strides[0], features.strides[1]),
|
230 |
+
writeable=False)
|
231 |
+
|
232 |
+
# Flatten the second and third dimensions:
|
233 |
+
train_inputs = np.reshape(train_inputs, [num_strides, -1])
|
234 |
+
|
235 |
+
train_inputs = np.copy(train_inputs)
|
236 |
+
train_inputs = (train_inputs - np.mean(train_inputs)) / \
|
237 |
+
np.std(train_inputs)
|
238 |
+
|
239 |
+
return train_inputs
|
240 |
+
|
241 |
+
|
242 |
+
def interpolate_features(features,
|
243 |
+
input_rate,
|
244 |
+
output_rate,
|
245 |
+
output_len):
|
246 |
+
"""
|
247 |
+
Interpolate DeepSpeech features.
|
248 |
+
|
249 |
+
Parameters
|
250 |
+
----------
|
251 |
+
features : np.array
|
252 |
+
DeepSpeech features.
|
253 |
+
input_rate : int
|
254 |
+
input rate (FPS).
|
255 |
+
output_rate : int
|
256 |
+
Output rate (FPS).
|
257 |
+
output_len : int
|
258 |
+
Output data length.
|
259 |
+
|
260 |
+
Returns
|
261 |
+
-------
|
262 |
+
np.array
|
263 |
+
Interpolated data.
|
264 |
+
"""
|
265 |
+
input_len = features.shape[0]
|
266 |
+
num_features = features.shape[1]
|
267 |
+
input_timestamps = np.arange(input_len) / float(input_rate)
|
268 |
+
output_timestamps = np.arange(output_len) / float(output_rate)
|
269 |
+
output_features = np.zeros((output_len, num_features))
|
270 |
+
for feature_idx in range(num_features):
|
271 |
+
output_features[:, feature_idx] = np.interp(
|
272 |
+
x=output_timestamps,
|
273 |
+
xp=input_timestamps,
|
274 |
+
fp=features[:, feature_idx])
|
275 |
+
return output_features
|
data_utils/deepspeech_features/deepspeech_store.py
ADDED
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Routines for loading DeepSpeech model.
|
3 |
+
"""
|
4 |
+
|
5 |
+
__all__ = ['get_deepspeech_model_file']
|
6 |
+
|
7 |
+
import os
|
8 |
+
import zipfile
|
9 |
+
import logging
|
10 |
+
import hashlib
|
11 |
+
|
12 |
+
|
13 |
+
deepspeech_features_repo_url = 'https://github.com/osmr/deepspeech_features'
|
14 |
+
|
15 |
+
|
16 |
+
def get_deepspeech_model_file(local_model_store_dir_path=os.path.join("~", ".tensorflow", "models")):
|
17 |
+
"""
|
18 |
+
Return location for the pretrained on local file system. This function will download from online model zoo when
|
19 |
+
model cannot be found or has mismatch. The root directory will be created if it doesn't exist.
|
20 |
+
|
21 |
+
Parameters
|
22 |
+
----------
|
23 |
+
local_model_store_dir_path : str, default $TENSORFLOW_HOME/models
|
24 |
+
Location for keeping the model parameters.
|
25 |
+
|
26 |
+
Returns
|
27 |
+
-------
|
28 |
+
file_path
|
29 |
+
Path to the requested pretrained model file.
|
30 |
+
"""
|
31 |
+
sha1_hash = "b90017e816572ddce84f5843f1fa21e6a377975e"
|
32 |
+
file_name = "deepspeech-0_1_0-b90017e8.pb"
|
33 |
+
local_model_store_dir_path = os.path.expanduser(local_model_store_dir_path)
|
34 |
+
file_path = os.path.join(local_model_store_dir_path, file_name)
|
35 |
+
if os.path.exists(file_path):
|
36 |
+
if _check_sha1(file_path, sha1_hash):
|
37 |
+
return file_path
|
38 |
+
else:
|
39 |
+
logging.warning("Mismatch in the content of model file detected. Downloading again.")
|
40 |
+
else:
|
41 |
+
logging.info("Model file not found. Downloading to {}.".format(file_path))
|
42 |
+
|
43 |
+
if not os.path.exists(local_model_store_dir_path):
|
44 |
+
os.makedirs(local_model_store_dir_path)
|
45 |
+
|
46 |
+
zip_file_path = file_path + ".zip"
|
47 |
+
_download(
|
48 |
+
url="{repo_url}/releases/download/{repo_release_tag}/{file_name}.zip".format(
|
49 |
+
repo_url=deepspeech_features_repo_url,
|
50 |
+
repo_release_tag="v0.0.1",
|
51 |
+
file_name=file_name),
|
52 |
+
path=zip_file_path,
|
53 |
+
overwrite=True)
|
54 |
+
with zipfile.ZipFile(zip_file_path) as zf:
|
55 |
+
zf.extractall(local_model_store_dir_path)
|
56 |
+
os.remove(zip_file_path)
|
57 |
+
|
58 |
+
if _check_sha1(file_path, sha1_hash):
|
59 |
+
return file_path
|
60 |
+
else:
|
61 |
+
raise ValueError("Downloaded file has different hash. Please try again.")
|
62 |
+
|
63 |
+
|
64 |
+
def _download(url, path=None, overwrite=False, sha1_hash=None, retries=5, verify_ssl=True):
|
65 |
+
"""
|
66 |
+
Download an given URL
|
67 |
+
|
68 |
+
Parameters
|
69 |
+
----------
|
70 |
+
url : str
|
71 |
+
URL to download
|
72 |
+
path : str, optional
|
73 |
+
Destination path to store downloaded file. By default stores to the
|
74 |
+
current directory with same name as in url.
|
75 |
+
overwrite : bool, optional
|
76 |
+
Whether to overwrite destination file if already exists.
|
77 |
+
sha1_hash : str, optional
|
78 |
+
Expected sha1 hash in hexadecimal digits. Will ignore existing file when hash is specified
|
79 |
+
but doesn't match.
|
80 |
+
retries : integer, default 5
|
81 |
+
The number of times to attempt the download in case of failure or non 200 return codes
|
82 |
+
verify_ssl : bool, default True
|
83 |
+
Verify SSL certificates.
|
84 |
+
|
85 |
+
Returns
|
86 |
+
-------
|
87 |
+
str
|
88 |
+
The file path of the downloaded file.
|
89 |
+
"""
|
90 |
+
import warnings
|
91 |
+
try:
|
92 |
+
import requests
|
93 |
+
except ImportError:
|
94 |
+
class requests_failed_to_import(object):
|
95 |
+
pass
|
96 |
+
requests = requests_failed_to_import
|
97 |
+
|
98 |
+
if path is None:
|
99 |
+
fname = url.split("/")[-1]
|
100 |
+
# Empty filenames are invalid
|
101 |
+
assert fname, "Can't construct file-name from this URL. Please set the `path` option manually."
|
102 |
+
else:
|
103 |
+
path = os.path.expanduser(path)
|
104 |
+
if os.path.isdir(path):
|
105 |
+
fname = os.path.join(path, url.split("/")[-1])
|
106 |
+
else:
|
107 |
+
fname = path
|
108 |
+
assert retries >= 0, "Number of retries should be at least 0"
|
109 |
+
|
110 |
+
if not verify_ssl:
|
111 |
+
warnings.warn(
|
112 |
+
"Unverified HTTPS request is being made (verify_ssl=False). "
|
113 |
+
"Adding certificate verification is strongly advised.")
|
114 |
+
|
115 |
+
if overwrite or not os.path.exists(fname) or (sha1_hash and not _check_sha1(fname, sha1_hash)):
|
116 |
+
dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname)))
|
117 |
+
if not os.path.exists(dirname):
|
118 |
+
os.makedirs(dirname)
|
119 |
+
while retries + 1 > 0:
|
120 |
+
# Disable pyling too broad Exception
|
121 |
+
# pylint: disable=W0703
|
122 |
+
try:
|
123 |
+
print("Downloading {} from {}...".format(fname, url))
|
124 |
+
r = requests.get(url, stream=True, verify=verify_ssl)
|
125 |
+
if r.status_code != 200:
|
126 |
+
raise RuntimeError("Failed downloading url {}".format(url))
|
127 |
+
with open(fname, "wb") as f:
|
128 |
+
for chunk in r.iter_content(chunk_size=1024):
|
129 |
+
if chunk: # filter out keep-alive new chunks
|
130 |
+
f.write(chunk)
|
131 |
+
if sha1_hash and not _check_sha1(fname, sha1_hash):
|
132 |
+
raise UserWarning("File {} is downloaded but the content hash does not match."
|
133 |
+
" The repo may be outdated or download may be incomplete. "
|
134 |
+
"If the `repo_url` is overridden, consider switching to "
|
135 |
+
"the default repo.".format(fname))
|
136 |
+
break
|
137 |
+
except Exception as e:
|
138 |
+
retries -= 1
|
139 |
+
if retries <= 0:
|
140 |
+
raise e
|
141 |
+
else:
|
142 |
+
print("download failed, retrying, {} attempt{} left"
|
143 |
+
.format(retries, "s" if retries > 1 else ""))
|
144 |
+
|
145 |
+
return fname
|
146 |
+
|
147 |
+
|
148 |
+
def _check_sha1(filename, sha1_hash):
|
149 |
+
"""
|
150 |
+
Check whether the sha1 hash of the file content matches the expected hash.
|
151 |
+
|
152 |
+
Parameters
|
153 |
+
----------
|
154 |
+
filename : str
|
155 |
+
Path to the file.
|
156 |
+
sha1_hash : str
|
157 |
+
Expected sha1 hash in hexadecimal digits.
|
158 |
+
|
159 |
+
Returns
|
160 |
+
-------
|
161 |
+
bool
|
162 |
+
Whether the file content matches the expected hash.
|
163 |
+
"""
|
164 |
+
sha1 = hashlib.sha1()
|
165 |
+
with open(filename, "rb") as f:
|
166 |
+
while True:
|
167 |
+
data = f.read(1048576)
|
168 |
+
if not data:
|
169 |
+
break
|
170 |
+
sha1.update(data)
|
171 |
+
|
172 |
+
return sha1.hexdigest() == sha1_hash
|
data_utils/deepspeech_features/extract_ds_features.py
ADDED
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Script for extracting DeepSpeech features from audio file.
|
3 |
+
"""
|
4 |
+
|
5 |
+
import os
|
6 |
+
import argparse
|
7 |
+
import numpy as np
|
8 |
+
import pandas as pd
|
9 |
+
from deepspeech_store import get_deepspeech_model_file
|
10 |
+
from deepspeech_features import conv_audios_to_deepspeech
|
11 |
+
|
12 |
+
|
13 |
+
def parse_args():
|
14 |
+
"""
|
15 |
+
Create python script parameters.
|
16 |
+
Returns
|
17 |
+
-------
|
18 |
+
ArgumentParser
|
19 |
+
Resulted args.
|
20 |
+
"""
|
21 |
+
parser = argparse.ArgumentParser(
|
22 |
+
description="Extract DeepSpeech features from audio file",
|
23 |
+
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
24 |
+
parser.add_argument(
|
25 |
+
"--input",
|
26 |
+
type=str,
|
27 |
+
required=True,
|
28 |
+
help="path to input audio file or directory")
|
29 |
+
parser.add_argument(
|
30 |
+
"--output",
|
31 |
+
type=str,
|
32 |
+
help="path to output file with DeepSpeech features")
|
33 |
+
parser.add_argument(
|
34 |
+
"--deepspeech",
|
35 |
+
type=str,
|
36 |
+
help="path to DeepSpeech 0.1.0 frozen model")
|
37 |
+
parser.add_argument(
|
38 |
+
"--metainfo",
|
39 |
+
type=str,
|
40 |
+
help="path to file with meta-information")
|
41 |
+
|
42 |
+
args = parser.parse_args()
|
43 |
+
return args
|
44 |
+
|
45 |
+
|
46 |
+
def extract_features(in_audios,
|
47 |
+
out_files,
|
48 |
+
deepspeech_pb_path,
|
49 |
+
metainfo_file_path=None):
|
50 |
+
"""
|
51 |
+
Real extract audio from video file.
|
52 |
+
Parameters
|
53 |
+
----------
|
54 |
+
in_audios : list of str
|
55 |
+
Paths to input audio files.
|
56 |
+
out_files : list of str
|
57 |
+
Paths to output files with DeepSpeech features.
|
58 |
+
deepspeech_pb_path : str
|
59 |
+
Path to DeepSpeech 0.1.0 frozen model.
|
60 |
+
metainfo_file_path : str, default None
|
61 |
+
Path to file with meta-information.
|
62 |
+
"""
|
63 |
+
#deepspeech_pb_path="/disk4/keyu/DeepSpeech/deepspeech-0.9.2-models.pbmm"
|
64 |
+
if metainfo_file_path is None:
|
65 |
+
num_frames_info = [None] * len(in_audios)
|
66 |
+
else:
|
67 |
+
train_df = pd.read_csv(
|
68 |
+
metainfo_file_path,
|
69 |
+
sep="\t",
|
70 |
+
index_col=False,
|
71 |
+
dtype={"Id": np.int, "File": np.unicode, "Count": np.int})
|
72 |
+
num_frames_info = train_df["Count"].values
|
73 |
+
assert (len(num_frames_info) == len(in_audios))
|
74 |
+
|
75 |
+
for i, in_audio in enumerate(in_audios):
|
76 |
+
if not out_files[i]:
|
77 |
+
file_stem, _ = os.path.splitext(in_audio)
|
78 |
+
out_files[i] = file_stem + "_ds.npy"
|
79 |
+
#print(out_files[i])
|
80 |
+
conv_audios_to_deepspeech(
|
81 |
+
audios=in_audios,
|
82 |
+
out_files=out_files,
|
83 |
+
num_frames_info=num_frames_info,
|
84 |
+
deepspeech_pb_path=deepspeech_pb_path)
|
85 |
+
|
86 |
+
|
87 |
+
def main():
|
88 |
+
"""
|
89 |
+
Main body of script.
|
90 |
+
"""
|
91 |
+
args = parse_args()
|
92 |
+
in_audio = os.path.expanduser(args.input)
|
93 |
+
if not os.path.exists(in_audio):
|
94 |
+
raise Exception("Input file/directory doesn't exist: {}".format(in_audio))
|
95 |
+
deepspeech_pb_path = args.deepspeech
|
96 |
+
#add
|
97 |
+
deepspeech_pb_path = True
|
98 |
+
args.deepspeech = '~/.tensorflow/models/deepspeech-0_1_0-b90017e8.pb'
|
99 |
+
#deepspeech_pb_path="/disk4/keyu/DeepSpeech/deepspeech-0.9.2-models.pbmm"
|
100 |
+
if deepspeech_pb_path is None:
|
101 |
+
deepspeech_pb_path = ""
|
102 |
+
if deepspeech_pb_path:
|
103 |
+
deepspeech_pb_path = os.path.expanduser(args.deepspeech)
|
104 |
+
if not os.path.exists(deepspeech_pb_path):
|
105 |
+
deepspeech_pb_path = get_deepspeech_model_file()
|
106 |
+
if os.path.isfile(in_audio):
|
107 |
+
extract_features(
|
108 |
+
in_audios=[in_audio],
|
109 |
+
out_files=[args.output],
|
110 |
+
deepspeech_pb_path=deepspeech_pb_path,
|
111 |
+
metainfo_file_path=args.metainfo)
|
112 |
+
else:
|
113 |
+
audio_file_paths = []
|
114 |
+
for file_name in os.listdir(in_audio):
|
115 |
+
if not os.path.isfile(os.path.join(in_audio, file_name)):
|
116 |
+
continue
|
117 |
+
_, file_ext = os.path.splitext(file_name)
|
118 |
+
if file_ext.lower() == ".wav":
|
119 |
+
audio_file_path = os.path.join(in_audio, file_name)
|
120 |
+
audio_file_paths.append(audio_file_path)
|
121 |
+
audio_file_paths = sorted(audio_file_paths)
|
122 |
+
out_file_paths = [""] * len(audio_file_paths)
|
123 |
+
extract_features(
|
124 |
+
in_audios=audio_file_paths,
|
125 |
+
out_files=out_file_paths,
|
126 |
+
deepspeech_pb_path=deepspeech_pb_path,
|
127 |
+
metainfo_file_path=args.metainfo)
|
128 |
+
|
129 |
+
|
130 |
+
if __name__ == "__main__":
|
131 |
+
main()
|
132 |
+
|
data_utils/deepspeech_features/extract_wav.py
ADDED
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Script for extracting audio (16-bit, mono, 22000 Hz) from video file.
|
3 |
+
"""
|
4 |
+
|
5 |
+
import os
|
6 |
+
import argparse
|
7 |
+
import subprocess
|
8 |
+
|
9 |
+
|
10 |
+
def parse_args():
|
11 |
+
"""
|
12 |
+
Create python script parameters.
|
13 |
+
|
14 |
+
Returns
|
15 |
+
-------
|
16 |
+
ArgumentParser
|
17 |
+
Resulted args.
|
18 |
+
"""
|
19 |
+
parser = argparse.ArgumentParser(
|
20 |
+
description="Extract audio from video file",
|
21 |
+
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
22 |
+
parser.add_argument(
|
23 |
+
"--in-video",
|
24 |
+
type=str,
|
25 |
+
required=True,
|
26 |
+
help="path to input video file or directory")
|
27 |
+
parser.add_argument(
|
28 |
+
"--out-audio",
|
29 |
+
type=str,
|
30 |
+
help="path to output audio file")
|
31 |
+
|
32 |
+
args = parser.parse_args()
|
33 |
+
return args
|
34 |
+
|
35 |
+
|
36 |
+
def extract_audio(in_video,
|
37 |
+
out_audio):
|
38 |
+
"""
|
39 |
+
Real extract audio from video file.
|
40 |
+
|
41 |
+
Parameters
|
42 |
+
----------
|
43 |
+
in_video : str
|
44 |
+
Path to input video file.
|
45 |
+
out_audio : str
|
46 |
+
Path to output audio file.
|
47 |
+
"""
|
48 |
+
if not out_audio:
|
49 |
+
file_stem, _ = os.path.splitext(in_video)
|
50 |
+
out_audio = file_stem + ".wav"
|
51 |
+
# command1 = "ffmpeg -i {in_video} -vn -acodec copy {aac_audio}"
|
52 |
+
# command2 = "ffmpeg -i {aac_audio} -vn -acodec pcm_s16le -ac 1 -ar 22000 {out_audio}"
|
53 |
+
# command = "ffmpeg -i {in_video} -vn -acodec pcm_s16le -ac 1 -ar 22000 {out_audio}"
|
54 |
+
command = "ffmpeg -i {in_video} -vn -acodec pcm_s16le -ac 1 -ar 16000 {out_audio}"
|
55 |
+
subprocess.call([command.format(in_video=in_video, out_audio=out_audio)], shell=True)
|
56 |
+
|
57 |
+
|
58 |
+
def main():
|
59 |
+
"""
|
60 |
+
Main body of script.
|
61 |
+
"""
|
62 |
+
args = parse_args()
|
63 |
+
in_video = os.path.expanduser(args.in_video)
|
64 |
+
if not os.path.exists(in_video):
|
65 |
+
raise Exception("Input file/directory doesn't exist: {}".format(in_video))
|
66 |
+
if os.path.isfile(in_video):
|
67 |
+
extract_audio(
|
68 |
+
in_video=in_video,
|
69 |
+
out_audio=args.out_audio)
|
70 |
+
else:
|
71 |
+
video_file_paths = []
|
72 |
+
for file_name in os.listdir(in_video):
|
73 |
+
if not os.path.isfile(os.path.join(in_video, file_name)):
|
74 |
+
continue
|
75 |
+
_, file_ext = os.path.splitext(file_name)
|
76 |
+
if file_ext.lower() in (".mp4", ".mkv", ".avi"):
|
77 |
+
video_file_path = os.path.join(in_video, file_name)
|
78 |
+
video_file_paths.append(video_file_path)
|
79 |
+
video_file_paths = sorted(video_file_paths)
|
80 |
+
for video_file_path in video_file_paths:
|
81 |
+
extract_audio(
|
82 |
+
in_video=video_file_path,
|
83 |
+
out_audio="")
|
84 |
+
|
85 |
+
|
86 |
+
if __name__ == "__main__":
|
87 |
+
main()
|
data_utils/deepspeech_features/fea_win.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
|
3 |
+
net_output = np.load('french.ds.npy').reshape(-1, 29)
|
4 |
+
win_size = 16
|
5 |
+
zero_pad = np.zeros((int(win_size / 2), net_output.shape[1]))
|
6 |
+
net_output = np.concatenate((zero_pad, net_output, zero_pad), axis=0)
|
7 |
+
windows = []
|
8 |
+
for window_index in range(0, net_output.shape[0] - win_size, 2):
|
9 |
+
windows.append(net_output[window_index:window_index + win_size])
|
10 |
+
print(np.array(windows).shape)
|
11 |
+
np.save('aud_french.npy', np.array(windows))
|
data_utils/face_parsing/79999_iter.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:468e13ca13a9b43cc0881a9f99083a430e9c0a38abd935431d1c28ee94b26567
|
3 |
+
size 53289463
|
data_utils/face_parsing/logger.py
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python
|
2 |
+
# -*- encoding: utf-8 -*-
|
3 |
+
|
4 |
+
|
5 |
+
import os.path as osp
|
6 |
+
import time
|
7 |
+
import sys
|
8 |
+
import logging
|
9 |
+
|
10 |
+
import torch.distributed as dist
|
11 |
+
|
12 |
+
|
13 |
+
def setup_logger(logpth):
|
14 |
+
logfile = 'BiSeNet-{}.log'.format(time.strftime('%Y-%m-%d-%H-%M-%S'))
|
15 |
+
logfile = osp.join(logpth, logfile)
|
16 |
+
FORMAT = '%(levelname)s %(filename)s(%(lineno)d): %(message)s'
|
17 |
+
log_level = logging.INFO
|
18 |
+
if dist.is_initialized() and not dist.get_rank()==0:
|
19 |
+
log_level = logging.ERROR
|
20 |
+
logging.basicConfig(level=log_level, format=FORMAT, filename=logfile)
|
21 |
+
logging.root.addHandler(logging.StreamHandler())
|
22 |
+
|
23 |
+
|
data_utils/face_parsing/model.py
ADDED
@@ -0,0 +1,285 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python
|
2 |
+
# -*- encoding: utf-8 -*-
|
3 |
+
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
import torch.nn.functional as F
|
8 |
+
import torchvision
|
9 |
+
|
10 |
+
from resnet import Resnet18
|
11 |
+
# from modules.bn import InPlaceABNSync as BatchNorm2d
|
12 |
+
|
13 |
+
|
14 |
+
class ConvBNReLU(nn.Module):
|
15 |
+
def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1, *args, **kwargs):
|
16 |
+
super(ConvBNReLU, self).__init__()
|
17 |
+
self.conv = nn.Conv2d(in_chan,
|
18 |
+
out_chan,
|
19 |
+
kernel_size = ks,
|
20 |
+
stride = stride,
|
21 |
+
padding = padding,
|
22 |
+
bias = False)
|
23 |
+
self.bn = nn.BatchNorm2d(out_chan)
|
24 |
+
self.init_weight()
|
25 |
+
|
26 |
+
def forward(self, x):
|
27 |
+
x = self.conv(x)
|
28 |
+
x = F.relu(self.bn(x))
|
29 |
+
return x
|
30 |
+
|
31 |
+
def init_weight(self):
|
32 |
+
for ly in self.children():
|
33 |
+
if isinstance(ly, nn.Conv2d):
|
34 |
+
nn.init.kaiming_normal_(ly.weight, a=1)
|
35 |
+
if not ly.bias is None: nn.init.constant_(ly.bias, 0)
|
36 |
+
|
37 |
+
class BiSeNetOutput(nn.Module):
|
38 |
+
def __init__(self, in_chan, mid_chan, n_classes, *args, **kwargs):
|
39 |
+
super(BiSeNetOutput, self).__init__()
|
40 |
+
self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1)
|
41 |
+
self.conv_out = nn.Conv2d(mid_chan, n_classes, kernel_size=1, bias=False)
|
42 |
+
self.init_weight()
|
43 |
+
|
44 |
+
def forward(self, x):
|
45 |
+
x = self.conv(x)
|
46 |
+
x = self.conv_out(x)
|
47 |
+
return x
|
48 |
+
|
49 |
+
def init_weight(self):
|
50 |
+
for ly in self.children():
|
51 |
+
if isinstance(ly, nn.Conv2d):
|
52 |
+
nn.init.kaiming_normal_(ly.weight, a=1)
|
53 |
+
if not ly.bias is None: nn.init.constant_(ly.bias, 0)
|
54 |
+
|
55 |
+
def get_params(self):
|
56 |
+
wd_params, nowd_params = [], []
|
57 |
+
for name, module in self.named_modules():
|
58 |
+
if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
|
59 |
+
wd_params.append(module.weight)
|
60 |
+
if not module.bias is None:
|
61 |
+
nowd_params.append(module.bias)
|
62 |
+
elif isinstance(module, nn.BatchNorm2d):
|
63 |
+
nowd_params += list(module.parameters())
|
64 |
+
return wd_params, nowd_params
|
65 |
+
|
66 |
+
|
67 |
+
class AttentionRefinementModule(nn.Module):
|
68 |
+
def __init__(self, in_chan, out_chan, *args, **kwargs):
|
69 |
+
super(AttentionRefinementModule, self).__init__()
|
70 |
+
self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1)
|
71 |
+
self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size= 1, bias=False)
|
72 |
+
self.bn_atten = nn.BatchNorm2d(out_chan)
|
73 |
+
self.sigmoid_atten = nn.Sigmoid()
|
74 |
+
self.init_weight()
|
75 |
+
|
76 |
+
def forward(self, x):
|
77 |
+
feat = self.conv(x)
|
78 |
+
atten = F.avg_pool2d(feat, feat.size()[2:])
|
79 |
+
atten = self.conv_atten(atten)
|
80 |
+
atten = self.bn_atten(atten)
|
81 |
+
atten = self.sigmoid_atten(atten)
|
82 |
+
out = torch.mul(feat, atten)
|
83 |
+
return out
|
84 |
+
|
85 |
+
def init_weight(self):
|
86 |
+
for ly in self.children():
|
87 |
+
if isinstance(ly, nn.Conv2d):
|
88 |
+
nn.init.kaiming_normal_(ly.weight, a=1)
|
89 |
+
if not ly.bias is None: nn.init.constant_(ly.bias, 0)
|
90 |
+
|
91 |
+
|
92 |
+
class ContextPath(nn.Module):
|
93 |
+
def __init__(self, *args, **kwargs):
|
94 |
+
super(ContextPath, self).__init__()
|
95 |
+
self.resnet = Resnet18()
|
96 |
+
self.arm16 = AttentionRefinementModule(256, 128)
|
97 |
+
self.arm32 = AttentionRefinementModule(512, 128)
|
98 |
+
self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
|
99 |
+
self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
|
100 |
+
self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0)
|
101 |
+
|
102 |
+
self.init_weight()
|
103 |
+
|
104 |
+
def forward(self, x):
|
105 |
+
H0, W0 = x.size()[2:]
|
106 |
+
feat8, feat16, feat32 = self.resnet(x)
|
107 |
+
H8, W8 = feat8.size()[2:]
|
108 |
+
H16, W16 = feat16.size()[2:]
|
109 |
+
H32, W32 = feat32.size()[2:]
|
110 |
+
|
111 |
+
avg = F.avg_pool2d(feat32, feat32.size()[2:])
|
112 |
+
avg = self.conv_avg(avg)
|
113 |
+
avg_up = F.interpolate(avg, (H32, W32), mode='nearest')
|
114 |
+
|
115 |
+
feat32_arm = self.arm32(feat32)
|
116 |
+
feat32_sum = feat32_arm + avg_up
|
117 |
+
feat32_up = F.interpolate(feat32_sum, (H16, W16), mode='nearest')
|
118 |
+
feat32_up = self.conv_head32(feat32_up)
|
119 |
+
|
120 |
+
feat16_arm = self.arm16(feat16)
|
121 |
+
feat16_sum = feat16_arm + feat32_up
|
122 |
+
feat16_up = F.interpolate(feat16_sum, (H8, W8), mode='nearest')
|
123 |
+
feat16_up = self.conv_head16(feat16_up)
|
124 |
+
|
125 |
+
return feat8, feat16_up, feat32_up # x8, x8, x16
|
126 |
+
|
127 |
+
def init_weight(self):
|
128 |
+
for ly in self.children():
|
129 |
+
if isinstance(ly, nn.Conv2d):
|
130 |
+
nn.init.kaiming_normal_(ly.weight, a=1)
|
131 |
+
if not ly.bias is None: nn.init.constant_(ly.bias, 0)
|
132 |
+
|
133 |
+
def get_params(self):
|
134 |
+
wd_params, nowd_params = [], []
|
135 |
+
for name, module in self.named_modules():
|
136 |
+
if isinstance(module, (nn.Linear, nn.Conv2d)):
|
137 |
+
wd_params.append(module.weight)
|
138 |
+
if not module.bias is None:
|
139 |
+
nowd_params.append(module.bias)
|
140 |
+
elif isinstance(module, nn.BatchNorm2d):
|
141 |
+
nowd_params += list(module.parameters())
|
142 |
+
return wd_params, nowd_params
|
143 |
+
|
144 |
+
|
145 |
+
### This is not used, since I replace this with the resnet feature with the same size
|
146 |
+
class SpatialPath(nn.Module):
|
147 |
+
def __init__(self, *args, **kwargs):
|
148 |
+
super(SpatialPath, self).__init__()
|
149 |
+
self.conv1 = ConvBNReLU(3, 64, ks=7, stride=2, padding=3)
|
150 |
+
self.conv2 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
|
151 |
+
self.conv3 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
|
152 |
+
self.conv_out = ConvBNReLU(64, 128, ks=1, stride=1, padding=0)
|
153 |
+
self.init_weight()
|
154 |
+
|
155 |
+
def forward(self, x):
|
156 |
+
feat = self.conv1(x)
|
157 |
+
feat = self.conv2(feat)
|
158 |
+
feat = self.conv3(feat)
|
159 |
+
feat = self.conv_out(feat)
|
160 |
+
return feat
|
161 |
+
|
162 |
+
def init_weight(self):
|
163 |
+
for ly in self.children():
|
164 |
+
if isinstance(ly, nn.Conv2d):
|
165 |
+
nn.init.kaiming_normal_(ly.weight, a=1)
|
166 |
+
if not ly.bias is None: nn.init.constant_(ly.bias, 0)
|
167 |
+
|
168 |
+
def get_params(self):
|
169 |
+
wd_params, nowd_params = [], []
|
170 |
+
for name, module in self.named_modules():
|
171 |
+
if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
|
172 |
+
wd_params.append(module.weight)
|
173 |
+
if not module.bias is None:
|
174 |
+
nowd_params.append(module.bias)
|
175 |
+
elif isinstance(module, nn.BatchNorm2d):
|
176 |
+
nowd_params += list(module.parameters())
|
177 |
+
return wd_params, nowd_params
|
178 |
+
|
179 |
+
|
180 |
+
class FeatureFusionModule(nn.Module):
|
181 |
+
def __init__(self, in_chan, out_chan, *args, **kwargs):
|
182 |
+
super(FeatureFusionModule, self).__init__()
|
183 |
+
self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0)
|
184 |
+
self.conv1 = nn.Conv2d(out_chan,
|
185 |
+
out_chan//4,
|
186 |
+
kernel_size = 1,
|
187 |
+
stride = 1,
|
188 |
+
padding = 0,
|
189 |
+
bias = False)
|
190 |
+
self.conv2 = nn.Conv2d(out_chan//4,
|
191 |
+
out_chan,
|
192 |
+
kernel_size = 1,
|
193 |
+
stride = 1,
|
194 |
+
padding = 0,
|
195 |
+
bias = False)
|
196 |
+
self.relu = nn.ReLU(inplace=True)
|
197 |
+
self.sigmoid = nn.Sigmoid()
|
198 |
+
self.init_weight()
|
199 |
+
|
200 |
+
def forward(self, fsp, fcp):
|
201 |
+
fcat = torch.cat([fsp, fcp], dim=1)
|
202 |
+
feat = self.convblk(fcat)
|
203 |
+
atten = F.avg_pool2d(feat, feat.size()[2:])
|
204 |
+
atten = self.conv1(atten)
|
205 |
+
atten = self.relu(atten)
|
206 |
+
atten = self.conv2(atten)
|
207 |
+
atten = self.sigmoid(atten)
|
208 |
+
feat_atten = torch.mul(feat, atten)
|
209 |
+
feat_out = feat_atten + feat
|
210 |
+
return feat_out
|
211 |
+
|
212 |
+
def init_weight(self):
|
213 |
+
for ly in self.children():
|
214 |
+
if isinstance(ly, nn.Conv2d):
|
215 |
+
nn.init.kaiming_normal_(ly.weight, a=1)
|
216 |
+
if not ly.bias is None: nn.init.constant_(ly.bias, 0)
|
217 |
+
|
218 |
+
def get_params(self):
|
219 |
+
wd_params, nowd_params = [], []
|
220 |
+
for name, module in self.named_modules():
|
221 |
+
if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
|
222 |
+
wd_params.append(module.weight)
|
223 |
+
if not module.bias is None:
|
224 |
+
nowd_params.append(module.bias)
|
225 |
+
elif isinstance(module, nn.BatchNorm2d):
|
226 |
+
nowd_params += list(module.parameters())
|
227 |
+
return wd_params, nowd_params
|
228 |
+
|
229 |
+
|
230 |
+
class BiSeNet(nn.Module):
|
231 |
+
def __init__(self, n_classes, *args, **kwargs):
|
232 |
+
super(BiSeNet, self).__init__()
|
233 |
+
self.cp = ContextPath()
|
234 |
+
## here self.sp is deleted
|
235 |
+
self.ffm = FeatureFusionModule(256, 256)
|
236 |
+
self.conv_out = BiSeNetOutput(256, 256, n_classes)
|
237 |
+
self.conv_out16 = BiSeNetOutput(128, 64, n_classes)
|
238 |
+
self.conv_out32 = BiSeNetOutput(128, 64, n_classes)
|
239 |
+
self.init_weight()
|
240 |
+
|
241 |
+
def forward(self, x):
|
242 |
+
H, W = x.size()[2:]
|
243 |
+
feat_res8, feat_cp8, feat_cp16 = self.cp(x) # here return res3b1 feature
|
244 |
+
feat_sp = feat_res8 # use res3b1 feature to replace spatial path feature
|
245 |
+
feat_fuse = self.ffm(feat_sp, feat_cp8)
|
246 |
+
|
247 |
+
feat_out = self.conv_out(feat_fuse)
|
248 |
+
feat_out16 = self.conv_out16(feat_cp8)
|
249 |
+
feat_out32 = self.conv_out32(feat_cp16)
|
250 |
+
|
251 |
+
feat_out = F.interpolate(feat_out, (H, W), mode='bilinear', align_corners=True)
|
252 |
+
feat_out16 = F.interpolate(feat_out16, (H, W), mode='bilinear', align_corners=True)
|
253 |
+
feat_out32 = F.interpolate(feat_out32, (H, W), mode='bilinear', align_corners=True)
|
254 |
+
|
255 |
+
# return feat_out, feat_out16, feat_out32
|
256 |
+
return feat_out
|
257 |
+
|
258 |
+
def init_weight(self):
|
259 |
+
for ly in self.children():
|
260 |
+
if isinstance(ly, nn.Conv2d):
|
261 |
+
nn.init.kaiming_normal_(ly.weight, a=1)
|
262 |
+
if not ly.bias is None: nn.init.constant_(ly.bias, 0)
|
263 |
+
|
264 |
+
def get_params(self):
|
265 |
+
wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = [], [], [], []
|
266 |
+
for name, child in self.named_children():
|
267 |
+
child_wd_params, child_nowd_params = child.get_params()
|
268 |
+
if isinstance(child, FeatureFusionModule) or isinstance(child, BiSeNetOutput):
|
269 |
+
lr_mul_wd_params += child_wd_params
|
270 |
+
lr_mul_nowd_params += child_nowd_params
|
271 |
+
else:
|
272 |
+
wd_params += child_wd_params
|
273 |
+
nowd_params += child_nowd_params
|
274 |
+
return wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params
|
275 |
+
|
276 |
+
|
277 |
+
if __name__ == "__main__":
|
278 |
+
net = BiSeNet(19)
|
279 |
+
net.cuda()
|
280 |
+
net.eval()
|
281 |
+
in_ten = torch.randn(16, 3, 640, 480).cuda()
|
282 |
+
out, out16, out32 = net(in_ten)
|
283 |
+
print(out.shape)
|
284 |
+
|
285 |
+
net.get_params()
|
data_utils/face_parsing/resnet.py
ADDED
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python
|
2 |
+
# -*- encoding: utf-8 -*-
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
import torch.utils.model_zoo as modelzoo
|
8 |
+
|
9 |
+
# from modules.bn import InPlaceABNSync as BatchNorm2d
|
10 |
+
|
11 |
+
resnet18_url = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
|
12 |
+
|
13 |
+
|
14 |
+
def conv3x3(in_planes, out_planes, stride=1):
|
15 |
+
"""3x3 convolution with padding"""
|
16 |
+
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
|
17 |
+
padding=1, bias=False)
|
18 |
+
|
19 |
+
|
20 |
+
class BasicBlock(nn.Module):
|
21 |
+
def __init__(self, in_chan, out_chan, stride=1):
|
22 |
+
super(BasicBlock, self).__init__()
|
23 |
+
self.conv1 = conv3x3(in_chan, out_chan, stride)
|
24 |
+
self.bn1 = nn.BatchNorm2d(out_chan)
|
25 |
+
self.conv2 = conv3x3(out_chan, out_chan)
|
26 |
+
self.bn2 = nn.BatchNorm2d(out_chan)
|
27 |
+
self.relu = nn.ReLU(inplace=True)
|
28 |
+
self.downsample = None
|
29 |
+
if in_chan != out_chan or stride != 1:
|
30 |
+
self.downsample = nn.Sequential(
|
31 |
+
nn.Conv2d(in_chan, out_chan,
|
32 |
+
kernel_size=1, stride=stride, bias=False),
|
33 |
+
nn.BatchNorm2d(out_chan),
|
34 |
+
)
|
35 |
+
|
36 |
+
def forward(self, x):
|
37 |
+
residual = self.conv1(x)
|
38 |
+
residual = F.relu(self.bn1(residual))
|
39 |
+
residual = self.conv2(residual)
|
40 |
+
residual = self.bn2(residual)
|
41 |
+
|
42 |
+
shortcut = x
|
43 |
+
if self.downsample is not None:
|
44 |
+
shortcut = self.downsample(x)
|
45 |
+
|
46 |
+
out = shortcut + residual
|
47 |
+
out = self.relu(out)
|
48 |
+
return out
|
49 |
+
|
50 |
+
|
51 |
+
def create_layer_basic(in_chan, out_chan, bnum, stride=1):
|
52 |
+
layers = [BasicBlock(in_chan, out_chan, stride=stride)]
|
53 |
+
for i in range(bnum-1):
|
54 |
+
layers.append(BasicBlock(out_chan, out_chan, stride=1))
|
55 |
+
return nn.Sequential(*layers)
|
56 |
+
|
57 |
+
|
58 |
+
class Resnet18(nn.Module):
|
59 |
+
def __init__(self):
|
60 |
+
super(Resnet18, self).__init__()
|
61 |
+
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
|
62 |
+
bias=False)
|
63 |
+
self.bn1 = nn.BatchNorm2d(64)
|
64 |
+
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
65 |
+
self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1)
|
66 |
+
self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2)
|
67 |
+
self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2)
|
68 |
+
self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2)
|
69 |
+
self.init_weight()
|
70 |
+
|
71 |
+
def forward(self, x):
|
72 |
+
x = self.conv1(x)
|
73 |
+
x = F.relu(self.bn1(x))
|
74 |
+
x = self.maxpool(x)
|
75 |
+
|
76 |
+
x = self.layer1(x)
|
77 |
+
feat8 = self.layer2(x) # 1/8
|
78 |
+
feat16 = self.layer3(feat8) # 1/16
|
79 |
+
feat32 = self.layer4(feat16) # 1/32
|
80 |
+
return feat8, feat16, feat32
|
81 |
+
|
82 |
+
def init_weight(self):
|
83 |
+
state_dict = modelzoo.load_url(resnet18_url)
|
84 |
+
self_state_dict = self.state_dict()
|
85 |
+
for k, v in state_dict.items():
|
86 |
+
if 'fc' in k: continue
|
87 |
+
self_state_dict.update({k: v})
|
88 |
+
self.load_state_dict(self_state_dict)
|
89 |
+
|
90 |
+
def get_params(self):
|
91 |
+
wd_params, nowd_params = [], []
|
92 |
+
for name, module in self.named_modules():
|
93 |
+
if isinstance(module, (nn.Linear, nn.Conv2d)):
|
94 |
+
wd_params.append(module.weight)
|
95 |
+
if not module.bias is None:
|
96 |
+
nowd_params.append(module.bias)
|
97 |
+
elif isinstance(module, nn.BatchNorm2d):
|
98 |
+
nowd_params += list(module.parameters())
|
99 |
+
return wd_params, nowd_params
|
100 |
+
|
101 |
+
|
102 |
+
if __name__ == "__main__":
|
103 |
+
net = Resnet18()
|
104 |
+
x = torch.randn(16, 3, 224, 224)
|
105 |
+
out = net(x)
|
106 |
+
print(out[0].size())
|
107 |
+
print(out[1].size())
|
108 |
+
print(out[2].size())
|
109 |
+
net.get_params()
|
data_utils/face_parsing/test.py
ADDED
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python
|
2 |
+
# -*- encoding: utf-8 -*-
|
3 |
+
import numpy as np
|
4 |
+
from model import BiSeNet
|
5 |
+
|
6 |
+
import torch
|
7 |
+
|
8 |
+
import os
|
9 |
+
import os.path as osp
|
10 |
+
|
11 |
+
from PIL import Image
|
12 |
+
import torchvision.transforms as transforms
|
13 |
+
import cv2
|
14 |
+
from pathlib import Path
|
15 |
+
import configargparse
|
16 |
+
import tqdm
|
17 |
+
|
18 |
+
# import ttach as tta
|
19 |
+
|
20 |
+
def vis_parsing_maps(im, parsing_anno, stride, save_im=False, save_path='vis_results/parsing_map_on_im.jpg',
|
21 |
+
img_size=(512, 512)):
|
22 |
+
im = np.array(im)
|
23 |
+
vis_im = im.copy().astype(np.uint8)
|
24 |
+
vis_parsing_anno = parsing_anno.copy().astype(np.uint8)
|
25 |
+
vis_parsing_anno = cv2.resize(
|
26 |
+
vis_parsing_anno, None, fx=stride, fy=stride, interpolation=cv2.INTER_NEAREST)
|
27 |
+
vis_parsing_anno_color = np.zeros(
|
28 |
+
(vis_parsing_anno.shape[0], vis_parsing_anno.shape[1], 3)) + np.array([255, 255, 255]) # + 255
|
29 |
+
vis_parsing_anno_color_face = np.zeros(
|
30 |
+
(vis_parsing_anno.shape[0], vis_parsing_anno.shape[1], 3)) + np.array([255, 255, 255]) # + 255
|
31 |
+
|
32 |
+
num_of_class = np.max(vis_parsing_anno)
|
33 |
+
# print(num_of_class)
|
34 |
+
for pi in range(1, 14):
|
35 |
+
index = np.where(vis_parsing_anno == pi)
|
36 |
+
vis_parsing_anno_color[index[0], index[1], :] = np.array([255, 0, 0])
|
37 |
+
for pi in range(14, 16):
|
38 |
+
index = np.where(vis_parsing_anno == pi)
|
39 |
+
vis_parsing_anno_color[index[0], index[1], :] = np.array([0, 255, 0])
|
40 |
+
for pi in range(16, 17):
|
41 |
+
index = np.where(vis_parsing_anno == pi)
|
42 |
+
vis_parsing_anno_color[index[0], index[1], :] = np.array([0, 0, 255])
|
43 |
+
for pi in range(17, num_of_class+1):
|
44 |
+
index = np.where(vis_parsing_anno == pi)
|
45 |
+
vis_parsing_anno_color[index[0], index[1], :] = np.array([255, 0, 0])
|
46 |
+
|
47 |
+
vis_parsing_anno_color = vis_parsing_anno_color.astype(np.uint8)
|
48 |
+
index = np.where(vis_parsing_anno == num_of_class-1)
|
49 |
+
vis_im = cv2.resize(vis_parsing_anno_color, img_size,
|
50 |
+
interpolation=cv2.INTER_NEAREST)
|
51 |
+
if save_im:
|
52 |
+
cv2.imwrite(save_path, vis_im)
|
53 |
+
|
54 |
+
for pi in range(1, 7):
|
55 |
+
index = np.where(vis_parsing_anno == pi)
|
56 |
+
vis_parsing_anno_color_face[index[0], index[1], :] = np.array([255, 0, 0])
|
57 |
+
for pi in range(10, 14):
|
58 |
+
index = np.where(vis_parsing_anno == pi)
|
59 |
+
vis_parsing_anno_color_face[index[0], index[1], :] = np.array([255, 0, 0])
|
60 |
+
pad = 5
|
61 |
+
vis_parsing_anno_color_face = vis_parsing_anno_color_face.astype(np.uint8)
|
62 |
+
face_part = (vis_parsing_anno_color_face[..., 0] == 255) & (vis_parsing_anno_color_face[..., 1] == 0) & (vis_parsing_anno_color_face[..., 2] == 0)
|
63 |
+
face_coords = np.stack(np.nonzero(face_part), axis=-1)
|
64 |
+
sorted_inds = np.lexsort((-face_coords[:, 0], face_coords[:, 1]))
|
65 |
+
sorted_face_coords = face_coords[sorted_inds]
|
66 |
+
u, uid, ucnt = np.unique(sorted_face_coords[:, 1], return_index=True, return_counts=True)
|
67 |
+
bottom_face_coords = sorted_face_coords[uid] + np.array([pad, 0])
|
68 |
+
rows, cols, _ = vis_parsing_anno_color_face.shape
|
69 |
+
|
70 |
+
# 为了保证新的坐标在图片范围内
|
71 |
+
bottom_face_coords[:, 0] = np.clip(bottom_face_coords[:, 0], 0, rows - 1)
|
72 |
+
|
73 |
+
y_min = np.min(bottom_face_coords[:, 1])
|
74 |
+
y_max = np.max(bottom_face_coords[:, 1])
|
75 |
+
|
76 |
+
# 计算1和2部分的开始和结束位置
|
77 |
+
y_range = y_max - y_min
|
78 |
+
height_per_part = y_range // 4
|
79 |
+
|
80 |
+
start_y_part1 = y_min + height_per_part
|
81 |
+
end_y_part1 = start_y_part1 + height_per_part
|
82 |
+
|
83 |
+
start_y_part2 = end_y_part1
|
84 |
+
end_y_part2 = start_y_part2 + height_per_part
|
85 |
+
|
86 |
+
for coord in bottom_face_coords:
|
87 |
+
x, y = coord
|
88 |
+
start_x = max(x - pad, 0)
|
89 |
+
end_x = min(x + pad, rows)
|
90 |
+
if start_y_part1 <= y <= end_y_part1 or start_y_part2 <= y <= end_y_part2:
|
91 |
+
vis_parsing_anno_color_face[start_x:end_x, y] = [255, 0, 0]
|
92 |
+
# else:
|
93 |
+
# start_x = max(x - 2*pad, 0)
|
94 |
+
# end_x = max(x - pad, 0)
|
95 |
+
# vis_parsing_anno_color_face[start_x:end_x+1, y] = [255, 255, 255]
|
96 |
+
|
97 |
+
vis_im = cv2.GaussianBlur(vis_parsing_anno_color_face, (9, 9), cv2.BORDER_DEFAULT)
|
98 |
+
|
99 |
+
vis_im = cv2.resize(vis_im, img_size,
|
100 |
+
interpolation=cv2.INTER_NEAREST)
|
101 |
+
|
102 |
+
cv2.imwrite(save_path.replace('.png', '_face.png'), vis_im)
|
103 |
+
|
104 |
+
|
105 |
+
def evaluate(respth='./res/test_res', dspth='./data', cp='model_final_diss.pth'):
|
106 |
+
|
107 |
+
Path(respth).mkdir(parents=True, exist_ok=True)
|
108 |
+
|
109 |
+
print(f'[INFO] loading model...')
|
110 |
+
n_classes = 19
|
111 |
+
net = BiSeNet(n_classes=n_classes)
|
112 |
+
net.cuda()
|
113 |
+
net.load_state_dict(torch.load(cp))
|
114 |
+
net.eval()
|
115 |
+
|
116 |
+
to_tensor = transforms.Compose([
|
117 |
+
transforms.ToTensor(),
|
118 |
+
transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
|
119 |
+
])
|
120 |
+
|
121 |
+
image_paths = os.listdir(dspth)
|
122 |
+
|
123 |
+
with torch.no_grad():
|
124 |
+
for image_path in tqdm.tqdm(image_paths):
|
125 |
+
if image_path.endswith('.jpg') or image_path.endswith('.png'):
|
126 |
+
img = Image.open(osp.join(dspth, image_path))
|
127 |
+
ori_size = img.size
|
128 |
+
image = img.resize((512, 512), Image.BILINEAR)
|
129 |
+
image = image.convert("RGB")
|
130 |
+
img = to_tensor(image)
|
131 |
+
|
132 |
+
# test-time augmentation.
|
133 |
+
inputs = torch.unsqueeze(img, 0) # [1, 3, 512, 512]
|
134 |
+
outputs = net(inputs.cuda())
|
135 |
+
parsing = outputs.mean(0).cpu().numpy().argmax(0)
|
136 |
+
image_path = int(image_path[:-4])
|
137 |
+
image_path = str(image_path) + '.png'
|
138 |
+
|
139 |
+
vis_parsing_maps(image, parsing, stride=1, save_im=True, save_path=osp.join(respth, image_path), img_size=ori_size)
|
140 |
+
|
141 |
+
|
142 |
+
if __name__ == "__main__":
|
143 |
+
parser = configargparse.ArgumentParser()
|
144 |
+
parser.add_argument('--respath', type=str, default='./result/', help='result path for label')
|
145 |
+
parser.add_argument('--imgpath', type=str, default='./imgs/', help='path for input images')
|
146 |
+
parser.add_argument('--modelpath', type=str, default='data_utils/face_parsing/79999_iter.pth')
|
147 |
+
args = parser.parse_args()
|
148 |
+
evaluate(respth=args.respath, dspth=args.imgpath, cp=args.modelpath)
|
data_utils/face_tracking/3DMM/exp_info.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c3196029ed038eb9a461df6c782125fc4d1ec1545f2e5f361891471136b6cbb6
|
3 |
+
size 33264853
|
data_utils/face_tracking/3DMM/keys_info.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:028d3c383bae129c4bdcac880e22c551a5d2436ec7db7a26f5c57148d12469e6
|
3 |
+
size 7375
|
data_utils/face_tracking/3DMM/lands_info.txt
ADDED
@@ -0,0 +1,403 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
136
|
2 |
+
19
|
3 |
+
155
|
4 |
+
22
|
5 |
+
177
|
6 |
+
19
|
7 |
+
196
|
8 |
+
15
|
9 |
+
211
|
10 |
+
12
|
11 |
+
223
|
12 |
+
8
|
13 |
+
231
|
14 |
+
6
|
15 |
+
237
|
16 |
+
4
|
17 |
+
241
|
18 |
+
6
|
19 |
+
247
|
20 |
+
4
|
21 |
+
251
|
22 |
+
6
|
23 |
+
257
|
24 |
+
8
|
25 |
+
265
|
26 |
+
12
|
27 |
+
277
|
28 |
+
15
|
29 |
+
292
|
30 |
+
19
|
31 |
+
311
|
32 |
+
22
|
33 |
+
333
|
34 |
+
19
|
35 |
+
352
|
36 |
+
1
|
37 |
+
353
|
38 |
+
1
|
39 |
+
354
|
40 |
+
1
|
41 |
+
355
|
42 |
+
1
|
43 |
+
356
|
44 |
+
1
|
45 |
+
357
|
46 |
+
1
|
47 |
+
358
|
48 |
+
1
|
49 |
+
359
|
50 |
+
1
|
51 |
+
360
|
52 |
+
1
|
53 |
+
361
|
54 |
+
1
|
55 |
+
362
|
56 |
+
1
|
57 |
+
363
|
58 |
+
1
|
59 |
+
364
|
60 |
+
1
|
61 |
+
365
|
62 |
+
1
|
63 |
+
366
|
64 |
+
1
|
65 |
+
367
|
66 |
+
1
|
67 |
+
368
|
68 |
+
1
|
69 |
+
369
|
70 |
+
1
|
71 |
+
370
|
72 |
+
1
|
73 |
+
371
|
74 |
+
1
|
75 |
+
372
|
76 |
+
1
|
77 |
+
373
|
78 |
+
1
|
79 |
+
374
|
80 |
+
1
|
81 |
+
375
|
82 |
+
1
|
83 |
+
376
|
84 |
+
1
|
85 |
+
377
|
86 |
+
1
|
87 |
+
378
|
88 |
+
1
|
89 |
+
379
|
90 |
+
1
|
91 |
+
380
|
92 |
+
1
|
93 |
+
381
|
94 |
+
1
|
95 |
+
382
|
96 |
+
1
|
97 |
+
383
|
98 |
+
1
|
99 |
+
384
|
100 |
+
1
|
101 |
+
385
|
102 |
+
1
|
103 |
+
386
|
104 |
+
1
|
105 |
+
387
|
106 |
+
1
|
107 |
+
388
|
108 |
+
1
|
109 |
+
389
|
110 |
+
1
|
111 |
+
390
|
112 |
+
1
|
113 |
+
391
|
114 |
+
1
|
115 |
+
392
|
116 |
+
1
|
117 |
+
393
|
118 |
+
1
|
119 |
+
394
|
120 |
+
1
|
121 |
+
395
|
122 |
+
1
|
123 |
+
396
|
124 |
+
1
|
125 |
+
397
|
126 |
+
1
|
127 |
+
398
|
128 |
+
1
|
129 |
+
399
|
130 |
+
1
|
131 |
+
400
|
132 |
+
1
|
133 |
+
401
|
134 |
+
1
|
135 |
+
402
|
136 |
+
1
|
137 |
+
16655
|
138 |
+
16901
|
139 |
+
17155
|
140 |
+
17412
|
141 |
+
17669
|
142 |
+
17926
|
143 |
+
18183
|
144 |
+
18440
|
145 |
+
18826
|
146 |
+
19083
|
147 |
+
19340
|
148 |
+
19726
|
149 |
+
19983
|
150 |
+
20240
|
151 |
+
20625
|
152 |
+
21010
|
153 |
+
21396
|
154 |
+
157
|
155 |
+
671
|
156 |
+
16922
|
157 |
+
17177
|
158 |
+
17435
|
159 |
+
17821
|
160 |
+
18208
|
161 |
+
18594
|
162 |
+
18980
|
163 |
+
19366
|
164 |
+
19752
|
165 |
+
20139
|
166 |
+
20525
|
167 |
+
20911
|
168 |
+
21168
|
169 |
+
21555
|
170 |
+
188
|
171 |
+
575
|
172 |
+
961
|
173 |
+
1477
|
174 |
+
1863
|
175 |
+
2249
|
176 |
+
2636
|
177 |
+
3280
|
178 |
+
16411
|
179 |
+
16948
|
180 |
+
17589
|
181 |
+
18232
|
182 |
+
18876
|
183 |
+
19262
|
184 |
+
19648
|
185 |
+
20163
|
186 |
+
20678
|
187 |
+
21192
|
188 |
+
21707
|
189 |
+
340
|
190 |
+
855
|
191 |
+
1370
|
192 |
+
1756
|
193 |
+
2142
|
194 |
+
2657
|
195 |
+
3043
|
196 |
+
3429
|
197 |
+
16363
|
198 |
+
16973
|
199 |
+
17871
|
200 |
+
18644
|
201 |
+
19416
|
202 |
+
20189
|
203 |
+
20833
|
204 |
+
21733
|
205 |
+
752
|
206 |
+
1523
|
207 |
+
2037
|
208 |
+
2681
|
209 |
+
3323
|
210 |
+
3708
|
211 |
+
4222
|
212 |
+
31497
|
213 |
+
31491
|
214 |
+
31484
|
215 |
+
31555
|
216 |
+
31626
|
217 |
+
31730
|
218 |
+
31865
|
219 |
+
3224
|
220 |
+
3737
|
221 |
+
4250
|
222 |
+
4764
|
223 |
+
5150
|
224 |
+
32139
|
225 |
+
32192
|
226 |
+
32271
|
227 |
+
32368
|
228 |
+
32436
|
229 |
+
32521
|
230 |
+
32600
|
231 |
+
32655
|
232 |
+
32445
|
233 |
+
32465
|
234 |
+
32506
|
235 |
+
32546
|
236 |
+
32585
|
237 |
+
32640
|
238 |
+
32716
|
239 |
+
32733
|
240 |
+
32750
|
241 |
+
32785
|
242 |
+
32914
|
243 |
+
32913
|
244 |
+
32912
|
245 |
+
32911
|
246 |
+
32910
|
247 |
+
32909
|
248 |
+
33076
|
249 |
+
33057
|
250 |
+
33038
|
251 |
+
33001
|
252 |
+
33357
|
253 |
+
33333
|
254 |
+
33287
|
255 |
+
33243
|
256 |
+
33202
|
257 |
+
33144
|
258 |
+
33675
|
259 |
+
33612
|
260 |
+
33524
|
261 |
+
33420
|
262 |
+
33348
|
263 |
+
33260
|
264 |
+
33179
|
265 |
+
33123
|
266 |
+
34322
|
267 |
+
34316
|
268 |
+
34309
|
269 |
+
34227
|
270 |
+
34147
|
271 |
+
34034
|
272 |
+
33897
|
273 |
+
13269
|
274 |
+
12750
|
275 |
+
12231
|
276 |
+
11713
|
277 |
+
11325
|
278 |
+
27304
|
279 |
+
26767
|
280 |
+
25869
|
281 |
+
25094
|
282 |
+
24318
|
283 |
+
23543
|
284 |
+
22897
|
285 |
+
21991
|
286 |
+
15699
|
287 |
+
14922
|
288 |
+
14404
|
289 |
+
13758
|
290 |
+
13110
|
291 |
+
12721
|
292 |
+
12203
|
293 |
+
27231
|
294 |
+
26742
|
295 |
+
26103
|
296 |
+
25456
|
297 |
+
24810
|
298 |
+
24422
|
299 |
+
24034
|
300 |
+
23517
|
301 |
+
23000
|
302 |
+
22482
|
303 |
+
21965
|
304 |
+
16061
|
305 |
+
15544
|
306 |
+
15027
|
307 |
+
14639
|
308 |
+
14251
|
309 |
+
13734
|
310 |
+
13346
|
311 |
+
12958
|
312 |
+
26716
|
313 |
+
26465
|
314 |
+
26207
|
315 |
+
25819
|
316 |
+
25432
|
317 |
+
25044
|
318 |
+
24656
|
319 |
+
24268
|
320 |
+
23880
|
321 |
+
23493
|
322 |
+
23105
|
323 |
+
22717
|
324 |
+
22458
|
325 |
+
22071
|
326 |
+
16167
|
327 |
+
15780
|
328 |
+
15392
|
329 |
+
14876
|
330 |
+
14488
|
331 |
+
14100
|
332 |
+
13713
|
333 |
+
13067
|
334 |
+
26939
|
335 |
+
26695
|
336 |
+
26443
|
337 |
+
26184
|
338 |
+
25925
|
339 |
+
25666
|
340 |
+
25407
|
341 |
+
25148
|
342 |
+
24760
|
343 |
+
24501
|
344 |
+
24242
|
345 |
+
23854
|
346 |
+
23595
|
347 |
+
23336
|
348 |
+
22947
|
349 |
+
22558
|
350 |
+
22170
|
351 |
+
16136
|
352 |
+
15618
|
353 |
+
27932
|
354 |
+
28270
|
355 |
+
28552
|
356 |
+
28771
|
357 |
+
28990
|
358 |
+
29567
|
359 |
+
29780
|
360 |
+
30000
|
361 |
+
30316
|
362 |
+
30627
|
363 |
+
8155
|
364 |
+
8173
|
365 |
+
8184
|
366 |
+
8190
|
367 |
+
6516
|
368 |
+
7363
|
369 |
+
8203
|
370 |
+
9043
|
371 |
+
9884
|
372 |
+
1828
|
373 |
+
4016
|
374 |
+
5177
|
375 |
+
6341
|
376 |
+
4804
|
377 |
+
3771
|
378 |
+
9955
|
379 |
+
11094
|
380 |
+
12255
|
381 |
+
14323
|
382 |
+
12526
|
383 |
+
11495
|
384 |
+
5262
|
385 |
+
6024
|
386 |
+
7375
|
387 |
+
8215
|
388 |
+
9055
|
389 |
+
10394
|
390 |
+
11179
|
391 |
+
9674
|
392 |
+
8835
|
393 |
+
8235
|
394 |
+
7635
|
395 |
+
6793
|
396 |
+
5779
|
397 |
+
7384
|
398 |
+
8225
|
399 |
+
9064
|
400 |
+
10536
|
401 |
+
8828
|
402 |
+
8228
|
403 |
+
7628
|
data_utils/face_tracking/3DMM/sub_mesh.obj
ADDED
The diff for this file is too large to render.
See raw diff
|
|
data_utils/face_tracking/3DMM/topology_info.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:2edff6a6ad574d2dddf0d0815e0beabfea9369d7c2d6e53e0ba81f809b81e963
|
3 |
+
size 4145201
|
data_utils/face_tracking/3DMM/tris.txt
ADDED
The diff for this file is too large to render.
See raw diff
|
|
data_utils/face_tracking/3DMM/vert_tris.txt
ADDED
The diff for this file is too large to render.
See raw diff
|
|
data_utils/face_tracking/__init__.py
ADDED
File without changes
|
data_utils/face_tracking/bundle_adjustment.py
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import os
|
3 |
+
from util import *
|
4 |
+
import argparse
|
5 |
+
|
6 |
+
|
7 |
+
def set_requires_grad(tensor_list):
|
8 |
+
for tensor in tensor_list:
|
9 |
+
tensor.requires_grad = True
|
10 |
+
|
11 |
+
|
12 |
+
parser = argparse.ArgumentParser()
|
13 |
+
|
14 |
+
parser.add_argument(
|
15 |
+
"--path", type=str, default="", help="idname of target person")
|
16 |
+
parser.add_argument('--img_h', type=int, default=512, help='height if image')
|
17 |
+
parser.add_argument('--img_w', type=int, default=512, help='width of image')
|
18 |
+
args = parser.parse_args()
|
19 |
+
id_dir = args.path
|
20 |
+
|
21 |
+
params_dict = torch.load(os.path.join(id_dir, 'track_params.pt'))
|
22 |
+
euler_angle = params_dict['euler'].cuda()
|
23 |
+
trans = params_dict['trans'].cuda() / 1000.0
|
24 |
+
focal_len = params_dict['focal'].cuda()
|
25 |
+
|
26 |
+
track_xys = torch.as_tensor(
|
27 |
+
np.load(os.path.join(id_dir, 'track_xys.npy'))).float().cuda()
|
28 |
+
num_frames = track_xys.shape[0]
|
29 |
+
point_num = track_xys.shape[1]
|
30 |
+
|
31 |
+
pts = torch.zeros((point_num, 3), dtype=torch.float32).cuda()
|
32 |
+
set_requires_grad([euler_angle, trans, pts])
|
33 |
+
|
34 |
+
cxy = torch.Tensor((args.img_w/2.0, args.img_h/2.0)).float().cuda()
|
35 |
+
|
36 |
+
optimizer_pts = torch.optim.Adam([pts], lr=1e-2)
|
37 |
+
iter_num = 500
|
38 |
+
for iter in range(iter_num):
|
39 |
+
proj_pts = forward_transform(pts.unsqueeze(0).expand(
|
40 |
+
num_frames, -1, -1), euler_angle, trans, focal_len, cxy)
|
41 |
+
loss = cal_lan_loss(proj_pts[..., :2], track_xys)
|
42 |
+
optimizer_pts.zero_grad()
|
43 |
+
loss.backward()
|
44 |
+
optimizer_pts.step()
|
45 |
+
|
46 |
+
|
47 |
+
optimizer_ba = torch.optim.Adam([pts, euler_angle, trans], lr=1e-4)
|
48 |
+
|
49 |
+
|
50 |
+
iter_num = 8000
|
51 |
+
for iter in range(iter_num):
|
52 |
+
proj_pts = forward_transform(pts.unsqueeze(0).expand(
|
53 |
+
num_frames, -1, -1), euler_angle, trans, focal_len, cxy)
|
54 |
+
loss_lan = cal_lan_loss(proj_pts[..., :2], track_xys)
|
55 |
+
loss = loss_lan
|
56 |
+
optimizer_ba.zero_grad()
|
57 |
+
loss.backward()
|
58 |
+
optimizer_ba.step()
|
59 |
+
|
60 |
+
torch.save({'euler': euler_angle.detach().cpu(),
|
61 |
+
'trans': trans.detach().cpu(),
|
62 |
+
'focal': focal_len.detach().cpu()}, os.path.join(id_dir, 'bundle_adjustment.pt'))
|
63 |
+
print('bundle adjustment params saved')
|