ameerazam08 commited on
Commit
1f99b24
·
verified ·
1 Parent(s): b8bf677

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +2 -0
  2. LICENSE +13 -0
  3. README.md +225 -0
  4. assets/image/synctalk.png +3 -0
  5. data/Please place the data file. +0 -0
  6. data_utils/UNFaceFlow/core/__init__.py +0 -0
  7. data_utils/UNFaceFlow/core/corr.py +91 -0
  8. data_utils/UNFaceFlow/core/datasets.py +235 -0
  9. data_utils/UNFaceFlow/core/extractor.py +266 -0
  10. data_utils/UNFaceFlow/core/nnutils.py +233 -0
  11. data_utils/UNFaceFlow/core/raft.py +259 -0
  12. data_utils/UNFaceFlow/core/update.py +169 -0
  13. data_utils/UNFaceFlow/core/utils_core/__init__.py +0 -0
  14. data_utils/UNFaceFlow/core/utils_core/augmentor.py +246 -0
  15. data_utils/UNFaceFlow/core/utils_core/flow_viz.py +132 -0
  16. data_utils/UNFaceFlow/core/utils_core/frame_utils.py +137 -0
  17. data_utils/UNFaceFlow/core/utils_core/utils.py +86 -0
  18. data_utils/UNFaceFlow/core/warp_utils.py +118 -0
  19. data_utils/UNFaceFlow/data_test_flow/__init__.py +94 -0
  20. data_utils/UNFaceFlow/data_test_flow/base_dataset.py +98 -0
  21. data_utils/UNFaceFlow/data_test_flow/dd_dataset.py +108 -0
  22. data_utils/UNFaceFlow/data_test_flow/dd_dataset_bak.py +107 -0
  23. data_utils/UNFaceFlow/models/network_test_flow.py +88 -0
  24. data_utils/UNFaceFlow/options_test_flow.py +123 -0
  25. data_utils/UNFaceFlow/pretrain_model/raft-small.pth +3 -0
  26. data_utils/UNFaceFlow/sgd_NNRT_model_epoch19008_50000.pth +3 -0
  27. data_utils/UNFaceFlow/test_flow.py +62 -0
  28. data_utils/UNFaceFlow/utils.py +84 -0
  29. data_utils/blendshape_capture/face_landmarker.task +3 -0
  30. data_utils/blendshape_capture/main.py +86 -0
  31. data_utils/deepspeech_features/README.md +20 -0
  32. data_utils/deepspeech_features/deepspeech_features.py +275 -0
  33. data_utils/deepspeech_features/deepspeech_store.py +172 -0
  34. data_utils/deepspeech_features/extract_ds_features.py +132 -0
  35. data_utils/deepspeech_features/extract_wav.py +87 -0
  36. data_utils/deepspeech_features/fea_win.py +11 -0
  37. data_utils/face_parsing/79999_iter.pth +3 -0
  38. data_utils/face_parsing/logger.py +23 -0
  39. data_utils/face_parsing/model.py +285 -0
  40. data_utils/face_parsing/resnet.py +109 -0
  41. data_utils/face_parsing/test.py +148 -0
  42. data_utils/face_tracking/3DMM/exp_info.npy +3 -0
  43. data_utils/face_tracking/3DMM/keys_info.npy +3 -0
  44. data_utils/face_tracking/3DMM/lands_info.txt +403 -0
  45. data_utils/face_tracking/3DMM/sub_mesh.obj +0 -0
  46. data_utils/face_tracking/3DMM/topology_info.npy +3 -0
  47. data_utils/face_tracking/3DMM/tris.txt +0 -0
  48. data_utils/face_tracking/3DMM/vert_tris.txt +0 -0
  49. data_utils/face_tracking/__init__.py +0 -0
  50. data_utils/face_tracking/bundle_adjustment.py +63 -0
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ assets/image/synctalk.png filter=lfs diff=lfs merge=lfs -text
37
+ data_utils/blendshape_capture/face_landmarker.task filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Copyright (c) 2024 Peng Ziqiao
2
+
3
+ This work is licensed under the Creative Commons Attribution-NonCommercial 4.0 International License (CC BY-NC 4.0). To view a copy of this license, visit http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, and distribute the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
6
+
7
+ 1. Attribution — You must give appropriate credit, provide a link to the license, and indicate if changes were made. You may do so in any reasonable manner, but not in any way that suggests the licensor endorses you or your use.
8
+
9
+ 2. NonCommercial — You may not use the material for commercial purposes.
10
+
11
+ 3. No additional restrictions — You may not apply legal terms or technological measures that legally restrict others from doing anything the license permits.
12
+
13
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
README.md ADDED
@@ -0,0 +1,225 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SyncTalk: The Devil😈 is in the Synchronization for Talking Head Synthesis [CVPR 2024]
2
+ The official repository of the paper [SyncTalk: The Devil is in the Synchronization for Talking Head Synthesis](https://arxiv.org/abs/2311.17590)
3
+
4
+ <p align='center'>
5
+ <b>
6
+ <a href="https://arxiv.org/abs/2311.17590">Paper</a>
7
+ |
8
+ <a href="https://ziqiaopeng.github.io/synctalk/">Project Page</a>
9
+ |
10
+ <a href="https://github.com/ZiqiaoPeng/SyncTalk">Code</a>
11
+ </b>
12
+ </p>
13
+
14
+ Colab notebook demonstration: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1Egq0_ZK5sJAAawShxC0y4JRZQuVS2X-Z?usp=sharing)
15
+
16
+ A short demo video can be found [here](./demo/short_demo.mp4).
17
+
18
+ <p align='center'>
19
+ <img src='assets/image/synctalk.png' width='1000'/>
20
+ </p>
21
+
22
+ The proposed **SyncTalk** synthesizes synchronized talking head videos, employing tri-plane hash representations to maintain subject identity. It can generate synchronized lip movements, facial expressions, and stable head poses, and restores hair details to create high-resolution videos.
23
+
24
+ ## 🔥🔥🔥 News
25
+ - [2023-11-30] Update arXiv paper.
26
+ - [2024-03-04] The code and pre-trained model are released.
27
+ - [2024-03-22] The Google Colab notebook is released.
28
+ - [2024-04-14] Add Windows support.
29
+ - [2024-04-28] The preprocessing code is released.
30
+ - [2024-04-29] Fix bugs: audio encoder, blendshape capture, and face tracker.
31
+ - [2024-05-03] Try replacing NeRF with Gaussian Splatting. Code: [GS-SyncTalk](https://github.com/ZiqiaoPeng/GS-SyncTalk)
32
+ - **[2024-05-24] Introduce torso training to repair double chin.**
33
+
34
+
35
+
36
+ ## For Windows
37
+ Thanks to [okgpt](https://github.com/okgptai), we have launched a Windows integration package, you can download `SyncTalk-Windows.zip` and unzip it, double-click `inference.bat` to run the demo.
38
+
39
+ Download link: [Hugging Face](https://huggingface.co/ZiqiaoPeng/SyncTalk/blob/main/SyncTalk-Windows.zip) || [Baidu Netdisk](https://pan.baidu.com/s/1g3312mZxx__T6rAFPHjrRg?pwd=6666)
40
+
41
+ ## For Linux
42
+
43
+ ### Installation
44
+
45
+ Tested on Ubuntu 18.04, Pytorch 1.12.1 and CUDA 11.3.
46
+ ```bash
47
+ git clone https://github.com/ZiqiaoPeng/SyncTalk.git
48
+ cd SyncTalk
49
+ ```
50
+ #### Install dependency
51
+
52
+ ```bash
53
+ conda create -n synctalk python==3.8.8
54
+ conda activate synctalk
55
+ pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113
56
+ pip install -r requirements.txt
57
+ pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py38_cu113_pyt1121/download.html
58
+ pip install tensorflow-gpu==2.8.1
59
+ pip install ./freqencoder
60
+ pip install ./shencoder
61
+ pip install ./gridencoder
62
+ pip install ./raymarching
63
+ ```
64
+ If you encounter problems installing PyTorch3D, you can use the following command to install it:
65
+ ```bash
66
+ python ./scripts/install_pytorch3d.py
67
+ ```
68
+
69
+ ### Data Preparation
70
+ #### Pre-trained model
71
+ Please place the [May.zip](https://drive.google.com/file/d/18Q2H612CAReFxBd9kxr-i1dD8U1AUfsV/view?usp=sharing) in the **data** folder, the [trial_may.zip](https://drive.google.com/file/d/1C2639qi9jvhRygYHwPZDGs8pun3po3W7/view?usp=sharing) in the **model** folder, and then unzip them.
72
+ #### [New] Process your video
73
+ - Prepare face-parsing model.
74
+
75
+ ```bash
76
+ wget https://github.com/YudongGuo/AD-NeRF/blob/master/data_util/face_parsing/79999_iter.pth?raw=true -O data_utils/face_parsing/79999_iter.pth
77
+ ```
78
+
79
+ - Prepare the 3DMM model for head pose estimation.
80
+
81
+ ```bash
82
+ wget https://github.com/YudongGuo/AD-NeRF/blob/master/data_util/face_tracking/3DMM/exp_info.npy?raw=true -O data_utils/face_tracking/3DMM/exp_info.npy
83
+ wget https://github.com/YudongGuo/AD-NeRF/blob/master/data_util/face_tracking/3DMM/keys_info.npy?raw=true -O data_utils/face_tracking/3DMM/keys_info.npy
84
+ wget https://github.com/YudongGuo/AD-NeRF/blob/master/data_util/face_tracking/3DMM/sub_mesh.obj?raw=true -O data_utils/face_tracking/3DMM/sub_mesh.obj
85
+ wget https://github.com/YudongGuo/AD-NeRF/blob/master/data_util/face_tracking/3DMM/topology_info.npy?raw=true -O data_utils/face_tracking/3DMM/topology_info.npy
86
+ ```
87
+
88
+ - Download 3DMM model from [Basel Face Model 2009](https://faces.dmi.unibas.ch/bfm/main.php?nav=1-1-0&id=details):
89
+
90
+ ```
91
+ # 1. copy 01_MorphableModel.mat to data_util/face_tracking/3DMM/
92
+ # 2.
93
+ cd data_utils/face_tracking
94
+ python convert_BFM.py
95
+ ```
96
+ - Put your video under `data/<ID>/<ID>.mp4`, and then run the following command to process the video.
97
+
98
+ **[Note]** The video must be 25FPS, with all frames containing the talking person. The resolution should be about 512x512, and duration about 4-5 min.
99
+ ```bash
100
+ python data_utils/process.py data/<ID>/<ID>.mp4 --asr ave
101
+ ```
102
+ You can choose to use AVE, DeepSpeech or Hubert. The processed video will be saved in the **data** folder.
103
+
104
+
105
+ - [Optional] Obtain AU45 for eyes blinking
106
+
107
+ Run `FeatureExtraction` in [OpenFace](https://github.com/TadasBaltrusaitis/OpenFace), rename and move the output CSV file to `data/<ID>/au.csv`.
108
+
109
+
110
+ **[Note]** Since EmoTalk's blendshape capture is not open source, the preprocessing code here is replaced with mediapipe's blendshape capture. But according to some feedback, it doesn't work well, you can choose to replace it with AU45. If you want to compare with SyncTalk, some results from using EmoTalk capture can be obtained [here](https://drive.google.com/drive/folders/1LLFtQa2Yy2G0FaNOxwtZr0L974TXCYKh?usp=sharing) and videos from [GeneFace](https://drive.google.com/drive/folders/1vimGVNvP6d6nmmc8yAxtWuooxhJbkl68).
111
+
112
+
113
+ ### Quick Start
114
+ #### Run the evaluation code
115
+ ```bash
116
+ python main.py data/May --workspace model/trial_may -O --test --asr_model ave
117
+
118
+ python main.py data/May --workspace model/trial_may -O --test --asr_model ave --portrait
119
+ ```
120
+ “ave” refers to our Audio Visual Encoder, “portrait” signifies pasting the generated face back onto the original image, representing higher quality.
121
+
122
+ If it runs correctly, you will get the following results.
123
+
124
+ | Setting | PSNR | LPIPS | LMD |
125
+ |--------------------------|--------|--------|-------|
126
+ | SyncTalk (w/o Portrait) | 32.201 | 0.0394 | 2.822 |
127
+ | SyncTalk (Portrait) | 37.644 | 0.0117 | 2.825 |
128
+
129
+ This is for a single subject; the paper reports the average results for multiple subjects.
130
+
131
+ #### Inference with target audio
132
+ ```bash
133
+ python main.py data/May --workspace model/trial_may -O --test --test_train --asr_model ave --portrait --aud ./demo/test.wav
134
+ ```
135
+ Please use files with the “.wav” extension for inference, and the inference results will be saved in “model/trial_may/results/”. If do not use Audio Visual Encoder, replace wav with the npy file path.
136
+ * DeepSpeech
137
+
138
+ ```bash
139
+ python data_utils/deepspeech_features/extract_ds_features.py --input data/<name>.wav # save to data/<name>.npy
140
+ ```
141
+ * HuBERT
142
+
143
+ ```bash
144
+ # Borrowed from GeneFace. English pre-trained.
145
+ python data_utils/hubert.py --wav data/<name>.wav # save to data/<name>_hu.npy
146
+ ```
147
+ ### Train
148
+ ```bash
149
+ # by default, we load data from disk on the fly.
150
+ # we can also preload all data to CPU/GPU for faster training, but this is very memory-hungry for large datasets.
151
+ # `--preload 0`: load from disk (default, slower).
152
+ # `--preload 1`: load to CPU (slightly slower)
153
+ # `--preload 2`: load to GPU (fast)
154
+ python main.py data/May --workspace model/trial_may -O --iters 60000 --asr_model ave
155
+ python main.py data/May --workspace model/trial_may -O --iters 100000 --finetune_lips --patch_size 64 --asr_model ave
156
+
157
+ # or you can use the script to train
158
+ sh ./scripts/train_may.sh
159
+ ```
160
+ **[Tips]** Audio visual encoder (AVE) is suitable for characters with accurate lip sync and large lip movements such as May and Shaheen. Using AVE in the inference stage can achieve more accurate lip sync. If your training results show lip jitter, please try using deepspeech or hubert model as audio feature encoder.
161
+
162
+ ```bash
163
+ # Use deepspeech model
164
+ python main.py data/May --workspace model/trial_may -O --iters 60000 --asr_model deepspeech
165
+ python main.py data/May --workspace model/trial_may -O --iters 100000 --finetune_lips --patch_size 64 --asr_model deepspeech
166
+
167
+ # Use hubert model
168
+ python main.py data/May --workspace model/trial_may -O --iters 60000 --asr_model hubert
169
+ python main.py data/May --workspace model/trial_may -O --iters 100000 --finetune_lips --patch_size 64 --asr_model hubert
170
+ ```
171
+
172
+ If you want to use the OpenFace au45 as the eye parameter, please add "--au45" to the command line.
173
+
174
+ ```bash
175
+ # Use OpenFace AU45
176
+ python main.py data/May --workspace model/trial_may -O --iters 60000 --asr_model ave --au45
177
+ python main.py data/May --workspace model/trial_may -O --iters 100000 --finetune_lips --patch_size 64 --asr_model ave --au45
178
+ ```
179
+
180
+ ### Test
181
+ ```bash
182
+ python main.py data/May --workspace model/trial_may -O --test --asr_model ave --portrait
183
+
184
+ ```
185
+
186
+ ### Train & Test Torso [Repair Double Chin]
187
+ If your character trained only the head appeared double chin problem, you can introduce torso training. By training the torso, this problem can be solved, but **you will not be able to use the "--portrait" mode.** If you add "--portrait", the torso model will fail!
188
+
189
+ ```bash
190
+ # Train
191
+ # <head>.pth should be the latest checkpoint in trial_may
192
+ python main.py data/May/ --workspace model/trial_may_torso/ -O --torso --head_ckpt <head>.pth --iters 150000 --asr_model ave
193
+
194
+ # For example
195
+ python main.py data/May/ --workspace model/trial_may_torso/ -O --torso --head_ckpt model/trial_may/ngp_ep0019.pth --iters 150000 --asr_model ave
196
+
197
+ # Test
198
+ python main.py data/May --workspace model/trial_may_torso -O --torso --test --asr_model ave # not support --portrait
199
+
200
+ # Inference with target audio
201
+ python main.py data/May --workspace model/trial_may_torso -O --torso --test --test_train --asr_model ave --aud ./demo/test.wav # not support --portrait
202
+
203
+ ```
204
+
205
+
206
+
207
+ ## Citation
208
+
209
+ ```
210
+ @InProceedings{peng2023synctalk,
211
+ title = {SyncTalk: The Devil is in the Synchronization for Talking Head Synthesis},
212
+ author = {Ziqiao Peng and Wentao Hu and Yue Shi and Xiangyu Zhu and Xiaomei Zhang and Jun He and Hongyan Liu and Zhaoxin Fan},
213
+ booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
214
+ month = {June},
215
+ year = {2024},
216
+ }
217
+ ```
218
+
219
+ ## Acknowledgement
220
+ This code is developed heavily relying on [ER-NeRF](https://github.com/Fictionarry/ER-NeRF), and also [RAD-NeRF](https://github.com/ashawkey/RAD-NeRF), [GeneFace](https://github.com/yerfor/GeneFace), [DFRF](https://github.com/sstzal/DFRF), [DFA-NeRF](https://github.com/ShunyuYao/DFA-NeRF/), [AD-NeRF](https://github.com/YudongGuo/AD-NeRF), and [Deep3DFaceRecon_pytorch](https://github.com/sicxu/Deep3DFaceRecon_pytorch).
221
+
222
+ Thanks for these great projects. Thanks to [Tiandishihua](https://github.com/Tiandishihua) for helping us fix the bug that loss equals NaN.
223
+
224
+ ## Disclaimer
225
+ By using the "SyncTalk", users agree to comply with all applicable laws and regulations, and acknowledge that misuse of the software, including the creation or distribution of harmful content, is strictly prohibited. The developers of the software disclaim all liability for any direct, indirect, or consequential damages arising from the use or misuse of the software.
assets/image/synctalk.png ADDED

Git LFS Details

  • SHA256: 1c6f87ed6137d6aeb639aa3d13ec28ce8df15b3f8d926d3f5662d8de3bab7300
  • Pointer size: 132 Bytes
  • Size of remote file: 3.49 MB
data/Please place the data file. ADDED
File without changes
data_utils/UNFaceFlow/core/__init__.py ADDED
File without changes
data_utils/UNFaceFlow/core/corr.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from utils_core.utils import bilinear_sampler, coords_grid
4
+
5
+ try:
6
+ import alt_cuda_corr
7
+ except:
8
+ # alt_cuda_corr is not compiled
9
+ pass
10
+
11
+
12
+ class CorrBlock:
13
+ def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
14
+ self.num_levels = num_levels
15
+ self.radius = radius
16
+ self.corr_pyramid = []
17
+
18
+ # all pairs correlation
19
+ corr = CorrBlock.corr(fmap1, fmap2)
20
+
21
+ batch, h1, w1, dim, h2, w2 = corr.shape
22
+ corr = corr.reshape(batch*h1*w1, dim, h2, w2)
23
+
24
+ self.corr_pyramid.append(corr)
25
+ for i in range(self.num_levels-1):
26
+ corr = F.avg_pool2d(corr, 2, stride=2)
27
+ self.corr_pyramid.append(corr)
28
+
29
+ def __call__(self, coords):
30
+ r = self.radius
31
+ coords = coords.permute(0, 2, 3, 1)
32
+ batch, h1, w1, _ = coords.shape
33
+
34
+ out_pyramid = []
35
+ for i in range(self.num_levels):
36
+ corr = self.corr_pyramid[i]
37
+ dx = torch.linspace(-r, r, 2*r+1)
38
+ dy = torch.linspace(-r, r, 2*r+1)
39
+ delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(coords.device)
40
+
41
+ centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i
42
+ delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2)
43
+ coords_lvl = centroid_lvl + delta_lvl
44
+
45
+ corr = bilinear_sampler(corr, coords_lvl)
46
+ corr = corr.view(batch, h1, w1, -1)
47
+ out_pyramid.append(corr)
48
+
49
+ out = torch.cat(out_pyramid, dim=-1)
50
+ return out.permute(0, 3, 1, 2).contiguous().float()
51
+
52
+ @staticmethod
53
+ def corr(fmap1, fmap2):
54
+ batch, dim, ht, wd = fmap1.shape
55
+ fmap1 = fmap1.view(batch, dim, ht*wd)
56
+ fmap2 = fmap2.view(batch, dim, ht*wd)
57
+
58
+ corr = torch.matmul(fmap1.transpose(1,2), fmap2)
59
+ corr = corr.view(batch, ht, wd, 1, ht, wd)
60
+ return corr / torch.sqrt(torch.tensor(dim).float())
61
+
62
+
63
+ class AlternateCorrBlock:
64
+ def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
65
+ self.num_levels = num_levels
66
+ self.radius = radius
67
+
68
+ self.pyramid = [(fmap1, fmap2)]
69
+ for i in range(self.num_levels):
70
+ fmap1 = F.avg_pool2d(fmap1, 2, stride=2)
71
+ fmap2 = F.avg_pool2d(fmap2, 2, stride=2)
72
+ self.pyramid.append((fmap1, fmap2))
73
+
74
+ def __call__(self, coords):
75
+ coords = coords.permute(0, 2, 3, 1)
76
+ B, H, W, _ = coords.shape
77
+ dim = self.pyramid[0][0].shape[1]
78
+
79
+ corr_list = []
80
+ for i in range(self.num_levels):
81
+ r = self.radius
82
+ fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1).contiguous()
83
+ fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous()
84
+
85
+ coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous()
86
+ corr, = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r)
87
+ corr_list.append(corr.squeeze(1))
88
+
89
+ corr = torch.stack(corr_list, dim=1)
90
+ corr = corr.reshape(B, -1, H, W)
91
+ return corr / torch.sqrt(torch.tensor(dim).float())
data_utils/UNFaceFlow/core/datasets.py ADDED
@@ -0,0 +1,235 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Data loading based on https://github.com/NVIDIA/flownet2-pytorch
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.utils.data as data
6
+ import torch.nn.functional as F
7
+
8
+ import os
9
+ import math
10
+ import random
11
+ from glob import glob
12
+ import os.path as osp
13
+
14
+ from utils import frame_utils
15
+ from utils.augmentor import FlowAugmentor, SparseFlowAugmentor
16
+
17
+
18
+ class FlowDataset(data.Dataset):
19
+ def __init__(self, aug_params=None, sparse=False):
20
+ self.augmentor = None
21
+ self.sparse = sparse
22
+ if aug_params is not None:
23
+ if sparse:
24
+ self.augmentor = SparseFlowAugmentor(**aug_params)
25
+ else:
26
+ self.augmentor = FlowAugmentor(**aug_params)
27
+
28
+ self.is_test = False
29
+ self.init_seed = False
30
+ self.flow_list = []
31
+ self.image_list = []
32
+ self.extra_info = []
33
+
34
+ def __getitem__(self, index):
35
+
36
+ if self.is_test:
37
+ img1 = frame_utils.read_gen(self.image_list[index][0])
38
+ img2 = frame_utils.read_gen(self.image_list[index][1])
39
+ img1 = np.array(img1).astype(np.uint8)[..., :3]
40
+ img2 = np.array(img2).astype(np.uint8)[..., :3]
41
+ img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
42
+ img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
43
+ return img1, img2, self.extra_info[index]
44
+
45
+ if not self.init_seed:
46
+ worker_info = torch.utils.data.get_worker_info()
47
+ if worker_info is not None:
48
+ torch.manual_seed(worker_info.id)
49
+ np.random.seed(worker_info.id)
50
+ random.seed(worker_info.id)
51
+ self.init_seed = True
52
+
53
+ index = index % len(self.image_list)
54
+ valid = None
55
+ if self.sparse:
56
+ flow, valid = frame_utils.readFlowKITTI(self.flow_list[index])
57
+ else:
58
+ flow = frame_utils.read_gen(self.flow_list[index])
59
+
60
+ img1 = frame_utils.read_gen(self.image_list[index][0])
61
+ img2 = frame_utils.read_gen(self.image_list[index][1])
62
+
63
+ flow = np.array(flow).astype(np.float32)
64
+ img1 = np.array(img1).astype(np.uint8)
65
+ img2 = np.array(img2).astype(np.uint8)
66
+
67
+ # grayscale images
68
+ if len(img1.shape) == 2:
69
+ img1 = np.tile(img1[...,None], (1, 1, 3))
70
+ img2 = np.tile(img2[...,None], (1, 1, 3))
71
+ else:
72
+ img1 = img1[..., :3]
73
+ img2 = img2[..., :3]
74
+
75
+ if self.augmentor is not None:
76
+ if self.sparse:
77
+ img1, img2, flow, valid = self.augmentor(img1, img2, flow, valid)
78
+ else:
79
+ img1, img2, flow = self.augmentor(img1, img2, flow)
80
+
81
+ img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
82
+ img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
83
+ flow = torch.from_numpy(flow).permute(2, 0, 1).float()
84
+
85
+ if valid is not None:
86
+ valid = torch.from_numpy(valid)
87
+ else:
88
+ valid = (flow[0].abs() < 1000) & (flow[1].abs() < 1000)
89
+
90
+ return img1, img2, flow, valid.float()
91
+
92
+
93
+ def __rmul__(self, v):
94
+ self.flow_list = v * self.flow_list
95
+ self.image_list = v * self.image_list
96
+ return self
97
+
98
+ def __len__(self):
99
+ return len(self.image_list)
100
+
101
+
102
+ class MpiSintel(FlowDataset):
103
+ def __init__(self, aug_params=None, split='training', root='datasets/Sintel', dstype='clean'):
104
+ super(MpiSintel, self).__init__(aug_params)
105
+ flow_root = osp.join(root, split, 'flow')
106
+ image_root = osp.join(root, split, dstype)
107
+
108
+ if split == 'test':
109
+ self.is_test = True
110
+
111
+ for scene in os.listdir(image_root):
112
+ image_list = sorted(glob(osp.join(image_root, scene, '*.png')))
113
+ for i in range(len(image_list)-1):
114
+ self.image_list += [ [image_list[i], image_list[i+1]] ]
115
+ self.extra_info += [ (scene, i) ] # scene and frame_id
116
+
117
+ if split != 'test':
118
+ self.flow_list += sorted(glob(osp.join(flow_root, scene, '*.flo')))
119
+
120
+
121
+ class FlyingChairs(FlowDataset):
122
+ def __init__(self, aug_params=None, split='train', root='datasets/FlyingChairs_release/data'):
123
+ super(FlyingChairs, self).__init__(aug_params)
124
+
125
+ images = sorted(glob(osp.join(root, '*.ppm')))
126
+ flows = sorted(glob(osp.join(root, '*.flo')))
127
+ assert (len(images)//2 == len(flows))
128
+
129
+ split_list = np.loadtxt('chairs_split.txt', dtype=np.int32)
130
+ for i in range(len(flows)):
131
+ xid = split_list[i]
132
+ if (split=='training' and xid==1) or (split=='validation' and xid==2):
133
+ self.flow_list += [ flows[i] ]
134
+ self.image_list += [ [images[2*i], images[2*i+1]] ]
135
+
136
+
137
+ class FlyingThings3D(FlowDataset):
138
+ def __init__(self, aug_params=None, root='datasets/FlyingThings3D', dstype='frames_cleanpass'):
139
+ super(FlyingThings3D, self).__init__(aug_params)
140
+
141
+ for cam in ['left']:
142
+ for direction in ['into_future', 'into_past']:
143
+ image_dirs = sorted(glob(osp.join(root, dstype, 'TRAIN/*/*')))
144
+ image_dirs = sorted([osp.join(f, cam) for f in image_dirs])
145
+
146
+ flow_dirs = sorted(glob(osp.join(root, 'optical_flow/TRAIN/*/*')))
147
+ flow_dirs = sorted([osp.join(f, direction, cam) for f in flow_dirs])
148
+
149
+ for idir, fdir in zip(image_dirs, flow_dirs):
150
+ images = sorted(glob(osp.join(idir, '*.png')) )
151
+ flows = sorted(glob(osp.join(fdir, '*.pfm')) )
152
+ for i in range(len(flows)-1):
153
+ if direction == 'into_future':
154
+ self.image_list += [ [images[i], images[i+1]] ]
155
+ self.flow_list += [ flows[i] ]
156
+ elif direction == 'into_past':
157
+ self.image_list += [ [images[i+1], images[i]] ]
158
+ self.flow_list += [ flows[i+1] ]
159
+
160
+
161
+ class KITTI(FlowDataset):
162
+ def __init__(self, aug_params=None, split='training', root='datasets/KITTI'):
163
+ super(KITTI, self).__init__(aug_params, sparse=True)
164
+ if split == 'testing':
165
+ self.is_test = True
166
+
167
+ root = osp.join(root, split)
168
+ images1 = sorted(glob(osp.join(root, 'image_2/*_10.png')))
169
+ images2 = sorted(glob(osp.join(root, 'image_2/*_11.png')))
170
+
171
+ for img1, img2 in zip(images1, images2):
172
+ frame_id = img1.split('/')[-1]
173
+ self.extra_info += [ [frame_id] ]
174
+ self.image_list += [ [img1, img2] ]
175
+
176
+ if split == 'training':
177
+ self.flow_list = sorted(glob(osp.join(root, 'flow_occ/*_10.png')))
178
+
179
+
180
+ class HD1K(FlowDataset):
181
+ def __init__(self, aug_params=None, root='datasets/HD1k'):
182
+ super(HD1K, self).__init__(aug_params, sparse=True)
183
+
184
+ seq_ix = 0
185
+ while 1:
186
+ flows = sorted(glob(os.path.join(root, 'hd1k_flow_gt', 'flow_occ/%06d_*.png' % seq_ix)))
187
+ images = sorted(glob(os.path.join(root, 'hd1k_input', 'image_2/%06d_*.png' % seq_ix)))
188
+
189
+ if len(flows) == 0:
190
+ break
191
+
192
+ for i in range(len(flows)-1):
193
+ self.flow_list += [flows[i]]
194
+ self.image_list += [ [images[i], images[i+1]] ]
195
+
196
+ seq_ix += 1
197
+
198
+
199
+ def fetch_dataloader(args, TRAIN_DS='C+T+K+S+H'):
200
+ """ Create the data loader for the corresponding trainign set """
201
+
202
+ if args.stage == 'chairs':
203
+ aug_params = {'crop_size': args.image_size, 'min_scale': -0.1, 'max_scale': 1.0, 'do_flip': True}
204
+ train_dataset = FlyingChairs(aug_params, split='training')
205
+
206
+ elif args.stage == 'things':
207
+ aug_params = {'crop_size': args.image_size, 'min_scale': -0.4, 'max_scale': 0.8, 'do_flip': True}
208
+ clean_dataset = FlyingThings3D(aug_params, dstype='frames_cleanpass')
209
+ final_dataset = FlyingThings3D(aug_params, dstype='frames_finalpass')
210
+ train_dataset = clean_dataset + final_dataset
211
+
212
+ elif args.stage == 'sintel':
213
+ aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.6, 'do_flip': True}
214
+ things = FlyingThings3D(aug_params, dstype='frames_cleanpass')
215
+ sintel_clean = MpiSintel(aug_params, split='training', dstype='clean')
216
+ sintel_final = MpiSintel(aug_params, split='training', dstype='final')
217
+
218
+ if TRAIN_DS == 'C+T+K+S+H':
219
+ kitti = KITTI({'crop_size': args.image_size, 'min_scale': -0.3, 'max_scale': 0.5, 'do_flip': True})
220
+ hd1k = HD1K({'crop_size': args.image_size, 'min_scale': -0.5, 'max_scale': 0.2, 'do_flip': True})
221
+ train_dataset = 100*sintel_clean + 100*sintel_final + 200*kitti + 5*hd1k + things
222
+
223
+ elif TRAIN_DS == 'C+T+K/S':
224
+ train_dataset = 100*sintel_clean + 100*sintel_final + things
225
+
226
+ elif args.stage == 'kitti':
227
+ aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.4, 'do_flip': False}
228
+ train_dataset = KITTI(aug_params, split='training')
229
+
230
+ train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size,
231
+ pin_memory=False, shuffle=True, num_workers=4, drop_last=True)
232
+
233
+ print('Training with %d image pairs' % len(train_dataset))
234
+ return train_loader
235
+
data_utils/UNFaceFlow/core/extractor.py ADDED
@@ -0,0 +1,266 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
+ class ResidualBlock(nn.Module):
7
+ def __init__(self, in_planes, planes, norm_fn='group', stride=1):
8
+ super(ResidualBlock, self).__init__()
9
+
10
+ self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride)
11
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
12
+ self.relu = nn.ReLU(inplace=True)
13
+
14
+ num_groups = planes // 8
15
+
16
+ if norm_fn == 'group':
17
+ self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
18
+ self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
19
+ if not stride == 1:
20
+ self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
21
+
22
+ elif norm_fn == 'batch':
23
+ self.norm1 = nn.BatchNorm2d(planes)
24
+ self.norm2 = nn.BatchNorm2d(planes)
25
+ if not stride == 1:
26
+ self.norm3 = nn.BatchNorm2d(planes)
27
+
28
+ elif norm_fn == 'instance':
29
+ self.norm1 = nn.InstanceNorm2d(planes)
30
+ self.norm2 = nn.InstanceNorm2d(planes)
31
+ if not stride == 1:
32
+ self.norm3 = nn.InstanceNorm2d(planes)
33
+
34
+ elif norm_fn == 'none':
35
+ self.norm1 = nn.Sequential()
36
+ self.norm2 = nn.Sequential()
37
+ if not stride == 1:
38
+ self.norm3 = nn.Sequential()
39
+
40
+ if stride == 1:
41
+ self.downsample = None
42
+
43
+ else:
44
+ self.downsample = nn.Sequential(
45
+ nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)
46
+
47
+
48
+ def forward(self, x):
49
+ y = x
50
+ y = self.relu(self.norm1(self.conv1(y)))
51
+ y = self.relu(self.norm2(self.conv2(y)))
52
+
53
+ if self.downsample is not None:
54
+ x = self.downsample(x)
55
+
56
+ return self.relu(x+y)
57
+
58
+
59
+
60
+ class BottleneckBlock(nn.Module):
61
+ def __init__(self, in_planes, planes, norm_fn='group', stride=1):
62
+ super(BottleneckBlock, self).__init__()
63
+
64
+ self.conv1 = nn.Conv2d(in_planes, planes//4, kernel_size=1, padding=0)
65
+ self.conv2 = nn.Conv2d(planes//4, planes//4, kernel_size=3, padding=1, stride=stride)
66
+ self.conv3 = nn.Conv2d(planes//4, planes, kernel_size=1, padding=0)
67
+ self.relu = nn.ReLU(inplace=True)
68
+
69
+ num_groups = planes // 8
70
+
71
+ if norm_fn == 'group':
72
+ self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4)
73
+ self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4)
74
+ self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
75
+ if not stride == 1:
76
+ self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
77
+
78
+ elif norm_fn == 'batch':
79
+ self.norm1 = nn.BatchNorm2d(planes//4)
80
+ self.norm2 = nn.BatchNorm2d(planes//4)
81
+ self.norm3 = nn.BatchNorm2d(planes)
82
+ if not stride == 1:
83
+ self.norm4 = nn.BatchNorm2d(planes)
84
+
85
+ elif norm_fn == 'instance':
86
+ self.norm1 = nn.InstanceNorm2d(planes//4)
87
+ self.norm2 = nn.InstanceNorm2d(planes//4)
88
+ self.norm3 = nn.InstanceNorm2d(planes)
89
+ if not stride == 1:
90
+ self.norm4 = nn.InstanceNorm2d(planes)
91
+
92
+ elif norm_fn == 'none':
93
+ self.norm1 = nn.Sequential()
94
+ self.norm2 = nn.Sequential()
95
+ self.norm3 = nn.Sequential()
96
+ if not stride == 1:
97
+ self.norm4 = nn.Sequential()
98
+
99
+ if stride == 1:
100
+ self.downsample = None
101
+
102
+ else:
103
+ self.downsample = nn.Sequential(
104
+ nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4)
105
+
106
+
107
+ def forward(self, x):
108
+ y = x
109
+ y = self.relu(self.norm1(self.conv1(y)))
110
+ y = self.relu(self.norm2(self.conv2(y)))
111
+ y = self.relu(self.norm3(self.conv3(y)))
112
+
113
+ if self.downsample is not None:
114
+ x = self.downsample(x)
115
+
116
+ return self.relu(x+y)
117
+
118
+ class BasicEncoder(nn.Module):
119
+ def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
120
+ super(BasicEncoder, self).__init__()
121
+ self.norm_fn = norm_fn
122
+
123
+ if self.norm_fn == 'group':
124
+ self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)
125
+
126
+ elif self.norm_fn == 'batch':
127
+ self.norm1 = nn.BatchNorm2d(64)
128
+
129
+ elif self.norm_fn == 'instance':
130
+ self.norm1 = nn.InstanceNorm2d(64)
131
+
132
+ elif self.norm_fn == 'none':
133
+ self.norm1 = nn.Sequential()
134
+
135
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
136
+ self.relu1 = nn.ReLU(inplace=True)
137
+
138
+ self.in_planes = 64
139
+ self.layer1 = self._make_layer(64, stride=1)
140
+ self.layer2 = self._make_layer(96, stride=2)
141
+ self.layer3 = self._make_layer(128, stride=2)
142
+
143
+ # output convolution
144
+ self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1)
145
+
146
+ self.dropout = None
147
+ if dropout > 0:
148
+ self.dropout = nn.Dropout2d(p=dropout)
149
+
150
+ for m in self.modules():
151
+ if isinstance(m, nn.Conv2d):
152
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
153
+ elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
154
+ if m.weight is not None:
155
+ nn.init.constant_(m.weight, 1)
156
+ if m.bias is not None:
157
+ nn.init.constant_(m.bias, 0)
158
+
159
+ def _make_layer(self, dim, stride=1):
160
+ layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
161
+ layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
162
+ layers = (layer1, layer2)
163
+
164
+ self.in_planes = dim
165
+ return nn.Sequential(*layers)
166
+
167
+
168
+ def forward(self, x):
169
+
170
+ # if input is list, combine batch dimension
171
+ is_list = isinstance(x, tuple) or isinstance(x, list)
172
+ if is_list:
173
+ batch_dim = x[0].shape[0]
174
+ x = torch.cat(x, dim=0)
175
+
176
+ x = self.conv1(x)
177
+ x = self.norm1(x)
178
+ x = self.relu1(x)
179
+
180
+ x = self.layer1(x)
181
+ x = self.layer2(x)
182
+ x = self.layer3(x)
183
+
184
+ x = self.conv2(x)
185
+
186
+ if self.training and self.dropout is not None:
187
+ x = self.dropout(x)
188
+
189
+ if is_list:
190
+ x = torch.split(x, [batch_dim, batch_dim], dim=0)
191
+
192
+ return x
193
+
194
+
195
+ class SmallEncoder(nn.Module):
196
+ def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
197
+ super(SmallEncoder, self).__init__()
198
+ self.norm_fn = norm_fn
199
+
200
+ if self.norm_fn == 'group':
201
+ self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32)
202
+
203
+ elif self.norm_fn == 'batch':
204
+ self.norm1 = nn.BatchNorm2d(32)
205
+
206
+ elif self.norm_fn == 'instance':
207
+ self.norm1 = nn.InstanceNorm2d(32)
208
+
209
+ elif self.norm_fn == 'none':
210
+ self.norm1 = nn.Sequential()
211
+
212
+ self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3)
213
+ self.relu1 = nn.ReLU(inplace=True)
214
+
215
+ self.in_planes = 32
216
+ self.layer1 = self._make_layer(32, stride=1)
217
+ self.layer2 = self._make_layer(64, stride=2)
218
+ self.layer3 = self._make_layer(96, stride=2)
219
+
220
+ self.dropout = None
221
+ if dropout > 0:
222
+ self.dropout = nn.Dropout2d(p=dropout)
223
+
224
+ self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1)
225
+
226
+ for m in self.modules():
227
+ if isinstance(m, nn.Conv2d):
228
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
229
+ elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
230
+ if m.weight is not None:
231
+ nn.init.constant_(m.weight, 1)
232
+ if m.bias is not None:
233
+ nn.init.constant_(m.bias, 0)
234
+
235
+ def _make_layer(self, dim, stride=1):
236
+ layer1 = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride)
237
+ layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1)
238
+ layers = (layer1, layer2)
239
+
240
+ self.in_planes = dim
241
+ return nn.Sequential(*layers)
242
+
243
+
244
+ def forward(self, x):
245
+
246
+ # if input is list, combine batch dimension
247
+ is_list = isinstance(x, tuple) or isinstance(x, list)
248
+ if is_list:
249
+ batch_dim = x[0].shape[0]
250
+ x = torch.cat(x, dim=0)
251
+
252
+ x = self.conv1(x)
253
+ x = self.norm1(x)
254
+ x = self.relu1(x)
255
+ x = self.layer1(x)
256
+ x = self.layer2(x)
257
+ x = self.layer3(x)
258
+ x = self.conv2(x)
259
+
260
+ if self.training and self.dropout is not None:
261
+ x = self.dropout(x)
262
+
263
+ if is_list:
264
+ x = torch.split(x, [batch_dim, batch_dim], dim=0)
265
+
266
+ return x
data_utils/UNFaceFlow/core/nnutils.py ADDED
@@ -0,0 +1,233 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys, os
2
+ import numpy as np
3
+ import torch
4
+
5
+
6
+ def make_conv(n_in, n_out, n_blocks, kernel=3, normalization=torch.nn.BatchNorm3d, activation=torch.nn.ReLU):
7
+ blocks = []
8
+ for i in range(n_blocks):
9
+ in1 = n_in if i == 0 else n_out
10
+ blocks.append(torch.nn.Sequential(
11
+ torch.nn.Conv3d(in1, n_out, kernel_size=kernel, stride=1, padding=(kernel//2)),
12
+ normalization(n_out),
13
+ activation(inplace=True)
14
+ ))
15
+ return torch.nn.Sequential(*blocks)
16
+
17
+
18
+ def make_conv_2d(n_in, n_out, n_blocks, kernel=3, normalization=torch.nn.BatchNorm2d, activation=torch.nn.ReLU):
19
+ blocks = []
20
+ for i in range(n_blocks):
21
+ in1 = n_in if i == 0 else n_out
22
+ blocks.append(torch.nn.Sequential(
23
+ torch.nn.Conv2d(in1, n_out, kernel_size=kernel, stride=1, padding=(kernel//2)),
24
+ normalization(n_out),
25
+ activation(inplace=True)
26
+ ))
27
+ return torch.nn.Sequential(*blocks)
28
+
29
+
30
+ def make_downscale(n_in, n_out, kernel=4, normalization=torch.nn.BatchNorm3d, activation=torch.nn.ReLU):
31
+ block = torch.nn.Sequential(
32
+ torch.nn.Conv3d(n_in, n_out, kernel_size=kernel, stride=2, padding=(kernel-2)//2),
33
+ normalization(n_out),
34
+ activation(inplace=True)
35
+ )
36
+ return block
37
+
38
+
39
+ def make_downscale_2d(n_in, n_out, kernel=4, normalization=torch.nn.BatchNorm2d, activation=torch.nn.ReLU):
40
+ block = torch.nn.Sequential(
41
+ torch.nn.Conv2d(n_in, n_out, kernel_size=kernel, stride=2, padding=(kernel-2)//2),
42
+ normalization(n_out),
43
+ activation(inplace=True)
44
+ )
45
+ return block
46
+
47
+
48
+ def make_upscale(n_in, n_out, normalization=torch.nn.BatchNorm3d, activation=torch.nn.ReLU):
49
+ block = torch.nn.Sequential(
50
+ torch.nn.ConvTranspose3d(n_in, n_out, kernel_size=6, stride=2, padding=2),
51
+ normalization(n_out),
52
+ activation(inplace=True)
53
+ )
54
+ return block
55
+
56
+
57
+ def make_upscale_2d(n_in, n_out, kernel=4, normalization=torch.nn.BatchNorm2d, activation=torch.nn.ReLU):
58
+ block = torch.nn.Sequential(
59
+ torch.nn.ConvTranspose2d(n_in, n_out, kernel_size=kernel, stride=2, padding=(kernel-2)//2),
60
+ normalization(n_out),
61
+ activation(inplace=True)
62
+ )
63
+ return block
64
+
65
+
66
+ class ResBlock(torch.nn.Module):
67
+ def __init__(self, n_out, kernel=3, normalization=torch.nn.BatchNorm3d, activation=torch.nn.ReLU):
68
+ super().__init__()
69
+ self.block0 = torch.nn.Sequential(
70
+ torch.nn.Conv3d(n_out, n_out, kernel_size=kernel, stride=1, padding=(kernel//2)),
71
+ normalization(n_out),
72
+ activation(inplace=True)
73
+ )
74
+
75
+ self.block1 = torch.nn.Sequential(
76
+ torch.nn.Conv3d(n_out, n_out, kernel_size=kernel, stride=1, padding=(kernel//2)),
77
+ normalization(n_out),
78
+ )
79
+
80
+ self.block2 = torch.nn.ReLU()
81
+
82
+ def forward(self, x0):
83
+ x = self.block0(x0)
84
+
85
+ x = self.block1(x)
86
+
87
+ x = self.block2(x + x0)
88
+ return x
89
+
90
+
91
+ class ResBlock2d(torch.nn.Module):
92
+ def __init__(self, n_out, kernel=3, normalization=torch.nn.BatchNorm2d, activation=torch.nn.ReLU):
93
+ super().__init__()
94
+ self.block0 = torch.nn.Sequential(
95
+ torch.nn.Conv2d(n_out, n_out, kernel_size=kernel, stride=1, padding=(kernel//2)),
96
+ normalization(n_out),
97
+ activation(inplace=True)
98
+ )
99
+
100
+ self.block1 = torch.nn.Sequential(
101
+ torch.nn.Conv2d(n_out, n_out, kernel_size=kernel, stride=1, padding=(kernel//2)),
102
+ normalization(n_out),
103
+ )
104
+
105
+ self.block2 = torch.nn.ReLU()
106
+
107
+ def forward(self, x0):
108
+ x = self.block0(x0)
109
+
110
+ x = self.block1(x)
111
+
112
+ x = self.block2(x + x0)
113
+ return x
114
+
115
+
116
+ class Identity(torch.nn.Module):
117
+ def __init__(self, *args, **kwargs):
118
+ super().__init__()
119
+
120
+ def forward(self, x):
121
+ return x
122
+
123
+
124
+ def downscale_gt_flow(flow_gt, flow_mask, image_height, image_width):
125
+ flow_gt_copy = flow_gt.clone()
126
+ flow_mask_copy = flow_mask.clone()
127
+
128
+ flow_gt_copy = flow_gt_copy / 20.0
129
+ flow_mask_copy = flow_mask_copy.float()
130
+
131
+ assert image_height % 64 == 0 and image_width % 64 == 0
132
+
133
+ flow_gt2 = torch.nn.functional.interpolate(input=flow_gt_copy, size=(image_height//4, image_width//4), mode='nearest')
134
+ flow_mask2 = torch.nn.functional.interpolate(input=flow_mask_copy, size=(image_height//4, image_width//4), mode='nearest').bool()
135
+
136
+ flow_gt3 = torch.nn.functional.interpolate(input=flow_gt_copy, size=(image_height//8, image_width//8), mode='nearest')
137
+ flow_mask3 = torch.nn.functional.interpolate(input=flow_mask_copy, size=(image_height//8, image_width//8), mode='nearest').bool()
138
+
139
+ flow_gt4 = torch.nn.functional.interpolate(input=flow_gt_copy, size=(image_height//16, image_width//16), mode='nearest')
140
+ flow_mask4 = torch.nn.functional.interpolate(input=flow_mask_copy, size=(image_height//16, image_width//16), mode='nearest').bool()
141
+
142
+ flow_gt5 = torch.nn.functional.interpolate(input=flow_gt_copy, size=(image_height//32, image_width//32), mode='nearest')
143
+ flow_mask5 = torch.nn.functional.interpolate(input=flow_mask_copy, size=(image_height//32, image_width//32), mode='nearest').bool()
144
+
145
+ flow_gt6 = torch.nn.functional.interpolate(input=flow_gt_copy, size=(image_height//64, image_width//64), mode='nearest')
146
+ flow_mask6 = torch.nn.functional.interpolate(input=flow_mask_copy, size=(image_height//64, image_width//64), mode='nearest').bool()
147
+
148
+ return [flow_gt2, flow_gt3, flow_gt4, flow_gt5, flow_gt6], [flow_mask2, flow_mask3, flow_mask4, flow_mask5, flow_mask6]
149
+
150
+
151
+ def compute_baseline_mask_gt(
152
+ xy_coords_warped,
153
+ target_matches, valid_target_matches,
154
+ source_points, valid_source_points,
155
+ scene_flow_gt, scene_flow_mask, target_boundary_mask,
156
+ max_pos_flowed_source_to_target_dist, min_neg_flowed_source_to_target_dist
157
+ ):
158
+ # Scene flow mask
159
+ scene_flow_mask_0 = scene_flow_mask[:, 0].type(torch.bool)
160
+
161
+ # Boundary correspondences mask
162
+ # We use the nearest neighbor interpolation, since the boundary computations
163
+ # already marks any of 4 pixels as boundary.
164
+ target_nonboundary_mask = (~target_boundary_mask).type(torch.float32)
165
+ target_matches_nonboundary_mask = torch.nn.functional.grid_sample(target_nonboundary_mask, xy_coords_warped, padding_mode='zeros', mode='nearest', align_corners=False)
166
+ target_matches_nonboundary_mask = target_matches_nonboundary_mask[:, 0, :, :] >= 0.999
167
+
168
+ # Compute groundtruth mask (oracle)
169
+ flowed_source_points = source_points + scene_flow_gt
170
+ dist = torch.norm(flowed_source_points - target_matches, p=2, dim=1)
171
+
172
+ # Combine all masks
173
+ # We mark a correspondence as positive if;
174
+ # - it is close enough to groundtruth flow
175
+ # AND
176
+ # - there exists groundtruth flow
177
+ # AND
178
+ # - the target match is valid
179
+ # AND
180
+ # - the source point is valid
181
+ # AND
182
+ # - the target match is not on the boundary
183
+ mask_pos_gt = (dist <= max_pos_flowed_source_to_target_dist) & scene_flow_mask_0 & valid_target_matches & valid_source_points & target_matches_nonboundary_mask
184
+
185
+ # We mark a correspondence as negative if;
186
+ # - there exists groundtruth flow AND it is far away enough from the groundtruth flow AND source/target points are valid
187
+ # OR
188
+ # - the target match is on the boundary AND there exists groundtruth flow AND source/target points are valid
189
+ mask_neg_gt = ((dist > min_neg_flowed_source_to_target_dist) & scene_flow_mask_0 & valid_source_points & valid_target_matches) \
190
+ | (~target_matches_nonboundary_mask & scene_flow_mask_0 & valid_source_points & valid_target_matches)
191
+
192
+ # What remains is left undecided (masked out at loss).
193
+ # For groundtruth mask we set it to zero.
194
+ valid_mask_pixels = mask_pos_gt | mask_neg_gt
195
+ mask_gt = mask_pos_gt
196
+
197
+ mask_gt = mask_gt.type(torch.float32)
198
+
199
+ return mask_gt, valid_mask_pixels
200
+
201
+
202
+ def compute_deformed_points_gt(
203
+ source_points, scene_flow_gt,
204
+ valid_solve, valid_correspondences,
205
+ deformed_points_idxs, deformed_points_subsampled
206
+ ):
207
+ batch_size = source_points.shape[0]
208
+ max_warped_points = deformed_points_idxs.shape[1]
209
+
210
+ deformed_points_gt = torch.zeros((batch_size, max_warped_points, 3), dtype=source_points.dtype, device=source_points.device)
211
+ deformed_points_mask = torch.zeros((batch_size, max_warped_points, 3), dtype=source_points.dtype, device=source_points.device)
212
+
213
+ for i in range(batch_size):
214
+ if valid_solve[i]:
215
+ valid_correspondences_idxs = torch.where(valid_correspondences[i])
216
+
217
+ # Compute deformed point groundtruth.
218
+ deformed_points_i_gt = source_points[i] + scene_flow_gt[i]
219
+ deformed_points_i_gt = deformed_points_i_gt.permute(1, 2, 0)
220
+ deformed_points_i_gt = deformed_points_i_gt[valid_correspondences_idxs[0], valid_correspondences_idxs[1], :].view(-1, 3, 1)
221
+
222
+ # Filter out points randomly, if too many are still left.
223
+ if deformed_points_subsampled[i]:
224
+ sampled_idxs_i = deformed_points_idxs[i]
225
+ deformed_points_i_gt = deformed_points_i_gt[sampled_idxs_i]
226
+
227
+ num_points = deformed_points_i_gt.shape[0]
228
+
229
+ # Store the results.
230
+ deformed_points_gt[i, :num_points, :] = deformed_points_i_gt.view(1, num_points, 3)
231
+ deformed_points_mask[i, :num_points, :] = 1
232
+
233
+ return deformed_points_gt, deformed_points_mask
data_utils/UNFaceFlow/core/raft.py ADDED
@@ -0,0 +1,259 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+ from update import BasicUpdateBlock, SmallUpdateBlock
7
+ from extractor import BasicEncoder, SmallEncoder
8
+ from corr import CorrBlock, AlternateCorrBlock
9
+ from utils_core.utils import bilinear_sampler, coords_grid, upflow8
10
+
11
+ try:
12
+ autocast = torch.cuda.amp.autocast
13
+ except:
14
+ # dummy autocast for PyTorch < 1.6
15
+ class autocast:
16
+ def __init__(self, enabled):
17
+ pass
18
+ def __enter__(self):
19
+ pass
20
+ def __exit__(self, *args):
21
+ pass
22
+
23
+
24
+ class RAFT(nn.Module):
25
+ def __init__(self, args):
26
+ super(RAFT, self).__init__()
27
+ self.args = args
28
+
29
+ if args.small:
30
+ self.hidden_dim = hdim = 96
31
+ self.context_dim = cdim = 64
32
+ args.corr_levels = 4
33
+ args.corr_radius = 3
34
+
35
+ else:
36
+ self.hidden_dim = hdim = 128
37
+ self.context_dim = cdim = 128
38
+ args.corr_levels = 4
39
+ args.corr_radius = 4
40
+
41
+ if 'dropout' not in self.args:
42
+ self.args.dropout = 0
43
+
44
+ if 'alternate_corr' not in self.args:
45
+ self.args.alternate_corr = False
46
+
47
+ # feature network, context network, and update block
48
+ if args.small:
49
+ self.fnet = SmallEncoder(output_dim=128, norm_fn='instance', dropout=args.dropout)
50
+ self.cnet = SmallEncoder(output_dim=hdim+cdim, norm_fn='none', dropout=args.dropout)
51
+ self.update_block = SmallUpdateBlock(self.args, hidden_dim=hdim)
52
+
53
+ else:
54
+ self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=args.dropout)
55
+ self.cnet = BasicEncoder(output_dim=hdim+cdim, norm_fn='batch', dropout=args.dropout)
56
+ self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim)
57
+
58
+ def freeze_bn(self):
59
+ for m in self.modules():
60
+ if isinstance(m, nn.BatchNorm2d):
61
+ m.eval()
62
+
63
+ def initialize_flow(self, img):
64
+ """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0"""
65
+ N, C, H, W = img.shape
66
+ coords0 = coords_grid(N, H//8, W//8).to(img.device)
67
+ coords1 = coords_grid(N, H//8, W//8).to(img.device)
68
+
69
+ # optical flow computed as difference: flow = coords1 - coords0
70
+ return coords0, coords1
71
+
72
+ def upsample_flow(self, flow, mask):
73
+ """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """
74
+ N, _, H, W = flow.shape
75
+ mask = mask.view(N, 1, 9, 8, 8, H, W)
76
+ mask = torch.softmax(mask, dim=2)
77
+
78
+ up_flow = F.unfold(8 * flow, [3,3], padding=1)
79
+ up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)
80
+
81
+ up_flow = torch.sum(mask * up_flow, dim=2)
82
+ up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
83
+ return up_flow.reshape(N, 2, 8*H, 8*W)
84
+
85
+
86
+ def forward(self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False):
87
+ """ Estimate optical flow between pair of frames """
88
+
89
+ image1 = 2 * (image1 / 255.0) - 1.0
90
+ image2 = 2 * (image2 / 255.0) - 1.0
91
+
92
+ image1 = image1.contiguous()
93
+ image2 = image2.contiguous()
94
+
95
+ hdim = self.hidden_dim
96
+ cdim = self.context_dim
97
+
98
+ # run the feature network
99
+ with autocast(enabled=self.args.mixed_precision):
100
+ fmap1, fmap2 = self.fnet([image1, image2])
101
+
102
+ fmap1 = fmap1.float()
103
+ fmap2 = fmap2.float()
104
+ # print("fmap mean: ", fmap1.mean(), fmap2.mean())
105
+ if self.args.alternate_corr:
106
+ corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
107
+ else:
108
+ corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
109
+
110
+ # run the context network
111
+ with autocast(enabled=self.args.mixed_precision):
112
+ cnet = self.cnet(image1)
113
+ net, inp = torch.split(cnet, [hdim, cdim], dim=1)
114
+ net = torch.tanh(net)
115
+ inp = torch.relu(inp)
116
+
117
+ coords0, coords1 = self.initialize_flow(image1)
118
+
119
+ if flow_init is not None:
120
+ coords1 = coords1 + flow_init
121
+
122
+ flow_predictions = []
123
+ for itr in range(iters):
124
+ coords1 = coords1.detach()
125
+ corr = corr_fn(coords1) # index correlation volume
126
+
127
+ flow = coords1 - coords0
128
+ with autocast(enabled=self.args.mixed_precision):
129
+ net, up_mask, delta_flow, feature = self.update_block(net, inp, corr, flow)
130
+ # print("delta flow mean: ", delta_flow.mean())
131
+ # F(t+1) = F(t) + \Delta(t)
132
+ coords1 = coords1 + delta_flow
133
+
134
+ # upsample predictions
135
+ if up_mask is None:
136
+ flow_up = upflow8(coords1 - coords0)
137
+ else:
138
+ flow_up = self.upsample_flow(coords1 - coords0, up_mask)
139
+
140
+ return flow_up, feature
141
+
142
+ class RAFT_ALL(nn.Module):
143
+ def __init__(self, args):
144
+ super(RAFT_ALL, self).__init__()
145
+ self.args = args
146
+
147
+ if args.small:
148
+ self.hidden_dim = hdim = 96
149
+ self.context_dim = cdim = 64
150
+ args.corr_levels = 4
151
+ args.corr_radius = 3
152
+
153
+ else:
154
+ self.hidden_dim = hdim = 128
155
+ self.context_dim = cdim = 128
156
+ args.corr_levels = 4
157
+ args.corr_radius = 4
158
+
159
+ if 'dropout' not in self.args:
160
+ self.args.dropout = 0
161
+
162
+ if 'alternate_corr' not in self.args:
163
+ self.args.alternate_corr = False
164
+
165
+ # feature network, context network, and update block
166
+ if args.small:
167
+ self.fnet = SmallEncoder(output_dim=128, norm_fn='instance', dropout=args.dropout)
168
+ self.cnet = SmallEncoder(output_dim=hdim+cdim, norm_fn='none', dropout=args.dropout)
169
+ self.update_block = SmallUpdateBlock(self.args, hidden_dim=hdim)
170
+
171
+ else:
172
+ self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=args.dropout)
173
+ self.cnet = BasicEncoder(output_dim=hdim+cdim, norm_fn='batch', dropout=args.dropout)
174
+ self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim)
175
+
176
+ def freeze_bn(self):
177
+ for m in self.modules():
178
+ if isinstance(m, nn.BatchNorm2d):
179
+ m.eval()
180
+
181
+ def initialize_flow(self, img):
182
+ """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0"""
183
+ N, C, H, W = img.shape
184
+ coords0 = coords_grid(N, H//8, W//8).to(img.device)
185
+ coords1 = coords_grid(N, H//8, W//8).to(img.device)
186
+
187
+ # optical flow computed as difference: flow = coords1 - coords0
188
+ return coords0, coords1
189
+
190
+ def upsample_flow(self, flow, mask):
191
+ """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """
192
+ N, _, H, W = flow.shape
193
+ mask = mask.view(N, 1, 9, 8, 8, H, W)
194
+ mask = torch.softmax(mask, dim=2)
195
+
196
+ up_flow = F.unfold(8 * flow, [3,3], padding=1)
197
+ up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)
198
+
199
+ up_flow = torch.sum(mask * up_flow, dim=2)
200
+ up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
201
+ return up_flow.reshape(N, 2, 8*H, 8*W)
202
+
203
+
204
+ def forward(self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False):
205
+ """ Estimate optical flow between pair of frames """
206
+
207
+ image1 = 2 * (image1 / 255.0) - 1.0
208
+ image2 = 2 * (image2 / 255.0) - 1.0
209
+
210
+ image1 = image1.contiguous()
211
+ image2 = image2.contiguous()
212
+
213
+ hdim = self.hidden_dim
214
+ cdim = self.context_dim
215
+
216
+ # run the feature network
217
+ with autocast(enabled=self.args.mixed_precision):
218
+ fmap1, fmap2 = self.fnet([image1, image2])
219
+
220
+ fmap1 = fmap1.float()
221
+ fmap2 = fmap2.float()
222
+ # print("fmap mean: ", fmap1.mean(), fmap2.mean())
223
+ if self.args.alternate_corr:
224
+ corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
225
+ else:
226
+ corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
227
+
228
+ # run the context network
229
+ with autocast(enabled=self.args.mixed_precision):
230
+ cnet = self.cnet(image1)
231
+ net, inp = torch.split(cnet, [hdim, cdim], dim=1)
232
+ net = torch.tanh(net)
233
+ inp = torch.relu(inp)
234
+
235
+ coords0, coords1 = self.initialize_flow(image1)
236
+
237
+ if flow_init is not None:
238
+ coords1 = coords1 + flow_init
239
+
240
+ flow_predictions = []
241
+ for itr in range(iters):
242
+ coords1 = coords1.detach()
243
+ corr = corr_fn(coords1) # index correlation volume
244
+
245
+ flow = coords1 - coords0
246
+ with autocast(enabled=self.args.mixed_precision):
247
+ net, up_mask, delta_flow, feature = self.update_block(net, inp, corr, flow)
248
+ # print("delta flow mean: ", delta_flow.mean())
249
+ # F(t+1) = F(t) + \Delta(t)
250
+ coords1 = coords1 + delta_flow
251
+
252
+ # upsample predictions
253
+ if up_mask is None:
254
+ flow_up = upflow8(coords1 - coords0)
255
+ else:
256
+ flow_up = self.upsample_flow(coords1 - coords0, up_mask)
257
+ flow_predictions.append(flow_up)
258
+
259
+ return flow_predictions, feature
data_utils/UNFaceFlow/core/update.py ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from nnutils import make_conv_2d, make_upscale_2d, make_downscale_2d, ResBlock2d, Identity
5
+
6
+ class FlowHead(nn.Module):
7
+ def __init__(self, input_dim=128, hidden_dim=256):
8
+ super(FlowHead, self).__init__()
9
+ self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
10
+ self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1)
11
+ self.relu = nn.ReLU(inplace=True)
12
+
13
+ def forward(self, x):
14
+ x = self.relu(self.conv1(x))
15
+ return self.conv2(x), x
16
+
17
+ class ConvGRU(nn.Module):
18
+ def __init__(self, hidden_dim=128, input_dim=192+128):
19
+ super(ConvGRU, self).__init__()
20
+ self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
21
+ self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
22
+ self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
23
+
24
+ def forward(self, h, x):
25
+ hx = torch.cat([h, x], dim=1)
26
+
27
+ z = torch.sigmoid(self.convz(hx))
28
+ r = torch.sigmoid(self.convr(hx))
29
+ q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1)))
30
+
31
+ h = (1-z) * h + z * q
32
+ return h
33
+
34
+ class SepConvGRU(nn.Module):
35
+ def __init__(self, hidden_dim=128, input_dim=192+128):
36
+ super(SepConvGRU, self).__init__()
37
+ self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
38
+ self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
39
+ self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
40
+
41
+ self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
42
+ self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
43
+ self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
44
+
45
+
46
+ def forward(self, h, x):
47
+ # horizontal
48
+ hx = torch.cat([h, x], dim=1)
49
+ z = torch.sigmoid(self.convz1(hx))
50
+ r = torch.sigmoid(self.convr1(hx))
51
+ q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1)))
52
+ h = (1-z) * h + z * q
53
+
54
+ # vertical
55
+ hx = torch.cat([h, x], dim=1)
56
+ z = torch.sigmoid(self.convz2(hx))
57
+ r = torch.sigmoid(self.convr2(hx))
58
+ q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1)))
59
+ h = (1-z) * h + z * q
60
+
61
+ return h
62
+
63
+ class SmallMotionEncoder(nn.Module):
64
+ def __init__(self, args):
65
+ super(SmallMotionEncoder, self).__init__()
66
+ cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2
67
+ self.convc1 = nn.Conv2d(cor_planes, 96, 1, padding=0)
68
+ self.convf1 = nn.Conv2d(2, 64, 7, padding=3)
69
+ self.convf2 = nn.Conv2d(64, 32, 3, padding=1)
70
+ self.conv = nn.Conv2d(128, 80, 3, padding=1)
71
+
72
+ def forward(self, flow, corr):
73
+ cor = F.relu(self.convc1(corr))
74
+ flo = F.relu(self.convf1(flow))
75
+ flo = F.relu(self.convf2(flo))
76
+ cor_flo = torch.cat([cor, flo], dim=1)
77
+ out = F.relu(self.conv(cor_flo))
78
+ return torch.cat([out, flow], dim=1)
79
+
80
+ class BasicMotionEncoder(nn.Module):
81
+ def __init__(self, args):
82
+ super(BasicMotionEncoder, self).__init__()
83
+ cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2
84
+ self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0)
85
+ self.convc2 = nn.Conv2d(256, 192, 3, padding=1)
86
+ self.convf1 = nn.Conv2d(2, 128, 7, padding=3)
87
+ self.convf2 = nn.Conv2d(128, 64, 3, padding=1)
88
+ self.conv = nn.Conv2d(64+192, 128-2, 3, padding=1)
89
+
90
+ def forward(self, flow, corr):
91
+ cor = F.relu(self.convc1(corr))
92
+ cor = F.relu(self.convc2(cor))
93
+ flo = F.relu(self.convf1(flow))
94
+ flo = F.relu(self.convf2(flo))
95
+
96
+ cor_flo = torch.cat([cor, flo], dim=1)
97
+ out = F.relu(self.conv(cor_flo))
98
+ return torch.cat([out, flow], dim=1)
99
+
100
+ class SmallUpdateBlock(nn.Module):
101
+ def __init__(self, args, hidden_dim=96):
102
+ super(SmallUpdateBlock, self).__init__()
103
+ self.encoder = SmallMotionEncoder(args)
104
+ self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82+64)
105
+ self.flow_head = FlowHead(hidden_dim, hidden_dim=128)
106
+
107
+ def forward(self, net, inp, corr, flow):
108
+ motion_features = self.encoder(flow, corr)
109
+ inp = torch.cat([inp, motion_features], dim=1)
110
+ net = self.gru(net, inp)
111
+ delta_flow, feature = self.flow_head(net)
112
+
113
+ return net, None, delta_flow, feature
114
+
115
+ class BasicUpdateBlock(nn.Module):
116
+ def __init__(self, args, hidden_dim=128, input_dim=128):
117
+ super(BasicUpdateBlock, self).__init__()
118
+ self.args = args
119
+ self.encoder = BasicMotionEncoder(args)
120
+ self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim)
121
+ self.flow_head = FlowHead(hidden_dim, hidden_dim=256)
122
+
123
+ self.mask = nn.Sequential(
124
+ nn.Conv2d(128, 256, 3, padding=1),
125
+ nn.ReLU(inplace=True),
126
+ nn.Conv2d(256, 64*9, 1, padding=0))
127
+
128
+ def forward(self, net, inp, corr, flow, upsample=True):
129
+ motion_features = self.encoder(flow, corr)
130
+ inp = torch.cat([inp, motion_features], dim=1)
131
+
132
+ net = self.gru(net, inp)
133
+ delta_flow, feature = self.flow_head(net)
134
+
135
+ # scale mask to balence gradients
136
+ mask = .25 * self.mask(net)
137
+ return net, mask, delta_flow, feature
138
+
139
+ class BasicWeightsNet(nn.Module):
140
+ def __init__(self, opt):
141
+ super(BasicUpdateBlock, self).__init__()
142
+ if opt.small:
143
+ in_dim = 128
144
+ else:
145
+ in_dim = 256
146
+ fn_0 = 16
147
+ self.input_fn = fn_0 + 2
148
+ fn_1 = 16
149
+ self.conv1 = torch.nn.Conv2d(in_channels=in_dim, out_channels=fn_0, kernel_size=3, stride=1, padding=1)
150
+ if opt.use_batch_norm:
151
+ custom_batch_norm = torch.nn.BatchNorm2d
152
+ else:
153
+ custom_batch_norm = Identity
154
+ self.model = nn.Sequential(
155
+ make_conv_2d(self.input_fn, fn_1, n_blocks=1, normalization=custom_batch_norm),
156
+ ResBlock2d(fn_1, normalization=custom_batch_norm),
157
+ ResBlock2d(fn_1, normalization=custom_batch_norm),
158
+ ResBlock2d(fn_1, normalization=custom_batch_norm),
159
+ nn.Conv2d(fn_1, 1, kernel_size=3, padding=1),
160
+ torch.nn.Sigmoid()
161
+ )
162
+
163
+ def forward(self, flow, feature):
164
+ features = self.conv1(features)
165
+ x = torch.cat([features, flow], 1)
166
+ return self.model(x)
167
+
168
+
169
+
data_utils/UNFaceFlow/core/utils_core/__init__.py ADDED
File without changes
data_utils/UNFaceFlow/core/utils_core/augmentor.py ADDED
@@ -0,0 +1,246 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import random
3
+ import math
4
+ from PIL import Image
5
+
6
+ import cv2
7
+ cv2.setNumThreads(0)
8
+ cv2.ocl.setUseOpenCL(False)
9
+
10
+ import torch
11
+ from torchvision.transforms import ColorJitter
12
+ import torch.nn.functional as F
13
+
14
+
15
+ class FlowAugmentor:
16
+ def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=True):
17
+
18
+ # spatial augmentation params
19
+ self.crop_size = crop_size
20
+ self.min_scale = min_scale
21
+ self.max_scale = max_scale
22
+ self.spatial_aug_prob = 0.8
23
+ self.stretch_prob = 0.8
24
+ self.max_stretch = 0.2
25
+
26
+ # flip augmentation params
27
+ self.do_flip = do_flip
28
+ self.h_flip_prob = 0.5
29
+ self.v_flip_prob = 0.1
30
+
31
+ # photometric augmentation params
32
+ self.photo_aug = ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.5/3.14)
33
+ self.asymmetric_color_aug_prob = 0.2
34
+ self.eraser_aug_prob = 0.5
35
+
36
+ def color_transform(self, img1, img2):
37
+ """ Photometric augmentation """
38
+
39
+ # asymmetric
40
+ if np.random.rand() < self.asymmetric_color_aug_prob:
41
+ img1 = np.array(self.photo_aug(Image.fromarray(img1)), dtype=np.uint8)
42
+ img2 = np.array(self.photo_aug(Image.fromarray(img2)), dtype=np.uint8)
43
+
44
+ # symmetric
45
+ else:
46
+ image_stack = np.concatenate([img1, img2], axis=0)
47
+ image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8)
48
+ img1, img2 = np.split(image_stack, 2, axis=0)
49
+
50
+ return img1, img2
51
+
52
+ def eraser_transform(self, img1, img2, bounds=[50, 100]):
53
+ """ Occlusion augmentation """
54
+
55
+ ht, wd = img1.shape[:2]
56
+ if np.random.rand() < self.eraser_aug_prob:
57
+ mean_color = np.mean(img2.reshape(-1, 3), axis=0)
58
+ for _ in range(np.random.randint(1, 3)):
59
+ x0 = np.random.randint(0, wd)
60
+ y0 = np.random.randint(0, ht)
61
+ dx = np.random.randint(bounds[0], bounds[1])
62
+ dy = np.random.randint(bounds[0], bounds[1])
63
+ img2[y0:y0+dy, x0:x0+dx, :] = mean_color
64
+
65
+ return img1, img2
66
+
67
+ def spatial_transform(self, img1, img2, flow):
68
+ # randomly sample scale
69
+ ht, wd = img1.shape[:2]
70
+ min_scale = np.maximum(
71
+ (self.crop_size[0] + 8) / float(ht),
72
+ (self.crop_size[1] + 8) / float(wd))
73
+
74
+ scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)
75
+ scale_x = scale
76
+ scale_y = scale
77
+ if np.random.rand() < self.stretch_prob:
78
+ scale_x *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)
79
+ scale_y *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)
80
+
81
+ scale_x = np.clip(scale_x, min_scale, None)
82
+ scale_y = np.clip(scale_y, min_scale, None)
83
+
84
+ if np.random.rand() < self.spatial_aug_prob:
85
+ # rescale the images
86
+ img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
87
+ img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
88
+ flow = cv2.resize(flow, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
89
+ flow = flow * [scale_x, scale_y]
90
+
91
+ if self.do_flip:
92
+ if np.random.rand() < self.h_flip_prob: # h-flip
93
+ img1 = img1[:, ::-1]
94
+ img2 = img2[:, ::-1]
95
+ flow = flow[:, ::-1] * [-1.0, 1.0]
96
+
97
+ if np.random.rand() < self.v_flip_prob: # v-flip
98
+ img1 = img1[::-1, :]
99
+ img2 = img2[::-1, :]
100
+ flow = flow[::-1, :] * [1.0, -1.0]
101
+
102
+ y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0])
103
+ x0 = np.random.randint(0, img1.shape[1] - self.crop_size[1])
104
+
105
+ img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
106
+ img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
107
+ flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
108
+
109
+ return img1, img2, flow
110
+
111
+ def __call__(self, img1, img2, flow):
112
+ img1, img2 = self.color_transform(img1, img2)
113
+ img1, img2 = self.eraser_transform(img1, img2)
114
+ img1, img2, flow = self.spatial_transform(img1, img2, flow)
115
+
116
+ img1 = np.ascontiguousarray(img1)
117
+ img2 = np.ascontiguousarray(img2)
118
+ flow = np.ascontiguousarray(flow)
119
+
120
+ return img1, img2, flow
121
+
122
+ class SparseFlowAugmentor:
123
+ def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=False):
124
+ # spatial augmentation params
125
+ self.crop_size = crop_size
126
+ self.min_scale = min_scale
127
+ self.max_scale = max_scale
128
+ self.spatial_aug_prob = 0.8
129
+ self.stretch_prob = 0.8
130
+ self.max_stretch = 0.2
131
+
132
+ # flip augmentation params
133
+ self.do_flip = do_flip
134
+ self.h_flip_prob = 0.5
135
+ self.v_flip_prob = 0.1
136
+
137
+ # photometric augmentation params
138
+ self.photo_aug = ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3/3.14)
139
+ self.asymmetric_color_aug_prob = 0.2
140
+ self.eraser_aug_prob = 0.5
141
+
142
+ def color_transform(self, img1, img2):
143
+ image_stack = np.concatenate([img1, img2], axis=0)
144
+ image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8)
145
+ img1, img2 = np.split(image_stack, 2, axis=0)
146
+ return img1, img2
147
+
148
+ def eraser_transform(self, img1, img2):
149
+ ht, wd = img1.shape[:2]
150
+ if np.random.rand() < self.eraser_aug_prob:
151
+ mean_color = np.mean(img2.reshape(-1, 3), axis=0)
152
+ for _ in range(np.random.randint(1, 3)):
153
+ x0 = np.random.randint(0, wd)
154
+ y0 = np.random.randint(0, ht)
155
+ dx = np.random.randint(50, 100)
156
+ dy = np.random.randint(50, 100)
157
+ img2[y0:y0+dy, x0:x0+dx, :] = mean_color
158
+
159
+ return img1, img2
160
+
161
+ def resize_sparse_flow_map(self, flow, valid, fx=1.0, fy=1.0):
162
+ ht, wd = flow.shape[:2]
163
+ coords = np.meshgrid(np.arange(wd), np.arange(ht))
164
+ coords = np.stack(coords, axis=-1)
165
+
166
+ coords = coords.reshape(-1, 2).astype(np.float32)
167
+ flow = flow.reshape(-1, 2).astype(np.float32)
168
+ valid = valid.reshape(-1).astype(np.float32)
169
+
170
+ coords0 = coords[valid>=1]
171
+ flow0 = flow[valid>=1]
172
+
173
+ ht1 = int(round(ht * fy))
174
+ wd1 = int(round(wd * fx))
175
+
176
+ coords1 = coords0 * [fx, fy]
177
+ flow1 = flow0 * [fx, fy]
178
+
179
+ xx = np.round(coords1[:,0]).astype(np.int32)
180
+ yy = np.round(coords1[:,1]).astype(np.int32)
181
+
182
+ v = (xx > 0) & (xx < wd1) & (yy > 0) & (yy < ht1)
183
+ xx = xx[v]
184
+ yy = yy[v]
185
+ flow1 = flow1[v]
186
+
187
+ flow_img = np.zeros([ht1, wd1, 2], dtype=np.float32)
188
+ valid_img = np.zeros([ht1, wd1], dtype=np.int32)
189
+
190
+ flow_img[yy, xx] = flow1
191
+ valid_img[yy, xx] = 1
192
+
193
+ return flow_img, valid_img
194
+
195
+ def spatial_transform(self, img1, img2, flow, valid):
196
+ # randomly sample scale
197
+
198
+ ht, wd = img1.shape[:2]
199
+ min_scale = np.maximum(
200
+ (self.crop_size[0] + 1) / float(ht),
201
+ (self.crop_size[1] + 1) / float(wd))
202
+
203
+ scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)
204
+ scale_x = np.clip(scale, min_scale, None)
205
+ scale_y = np.clip(scale, min_scale, None)
206
+
207
+ if np.random.rand() < self.spatial_aug_prob:
208
+ # rescale the images
209
+ img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
210
+ img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
211
+ flow, valid = self.resize_sparse_flow_map(flow, valid, fx=scale_x, fy=scale_y)
212
+
213
+ if self.do_flip:
214
+ if np.random.rand() < 0.5: # h-flip
215
+ img1 = img1[:, ::-1]
216
+ img2 = img2[:, ::-1]
217
+ flow = flow[:, ::-1] * [-1.0, 1.0]
218
+ valid = valid[:, ::-1]
219
+
220
+ margin_y = 20
221
+ margin_x = 50
222
+
223
+ y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0] + margin_y)
224
+ x0 = np.random.randint(-margin_x, img1.shape[1] - self.crop_size[1] + margin_x)
225
+
226
+ y0 = np.clip(y0, 0, img1.shape[0] - self.crop_size[0])
227
+ x0 = np.clip(x0, 0, img1.shape[1] - self.crop_size[1])
228
+
229
+ img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
230
+ img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
231
+ flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
232
+ valid = valid[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
233
+ return img1, img2, flow, valid
234
+
235
+
236
+ def __call__(self, img1, img2, flow, valid):
237
+ img1, img2 = self.color_transform(img1, img2)
238
+ img1, img2 = self.eraser_transform(img1, img2)
239
+ img1, img2, flow, valid = self.spatial_transform(img1, img2, flow, valid)
240
+
241
+ img1 = np.ascontiguousarray(img1)
242
+ img2 = np.ascontiguousarray(img2)
243
+ flow = np.ascontiguousarray(flow)
244
+ valid = np.ascontiguousarray(valid)
245
+
246
+ return img1, img2, flow, valid
data_utils/UNFaceFlow/core/utils_core/flow_viz.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization
2
+
3
+
4
+ # MIT License
5
+ #
6
+ # Copyright (c) 2018 Tom Runia
7
+ #
8
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
9
+ # of this software and associated documentation files (the "Software"), to deal
10
+ # in the Software without restriction, including without limitation the rights
11
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12
+ # copies of the Software, and to permit persons to whom the Software is
13
+ # furnished to do so, subject to conditions.
14
+ #
15
+ # Author: Tom Runia
16
+ # Date Created: 2018-08-03
17
+
18
+ import numpy as np
19
+
20
+ def make_colorwheel():
21
+ """
22
+ Generates a color wheel for optical flow visualization as presented in:
23
+ Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
24
+ URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf
25
+
26
+ Code follows the original C++ source code of Daniel Scharstein.
27
+ Code follows the the Matlab source code of Deqing Sun.
28
+
29
+ Returns:
30
+ np.ndarray: Color wheel
31
+ """
32
+
33
+ RY = 15
34
+ YG = 6
35
+ GC = 4
36
+ CB = 11
37
+ BM = 13
38
+ MR = 6
39
+
40
+ ncols = RY + YG + GC + CB + BM + MR
41
+ colorwheel = np.zeros((ncols, 3))
42
+ col = 0
43
+
44
+ # RY
45
+ colorwheel[0:RY, 0] = 255
46
+ colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY)
47
+ col = col+RY
48
+ # YG
49
+ colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG)
50
+ colorwheel[col:col+YG, 1] = 255
51
+ col = col+YG
52
+ # GC
53
+ colorwheel[col:col+GC, 1] = 255
54
+ colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC)
55
+ col = col+GC
56
+ # CB
57
+ colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB)
58
+ colorwheel[col:col+CB, 2] = 255
59
+ col = col+CB
60
+ # BM
61
+ colorwheel[col:col+BM, 2] = 255
62
+ colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM)
63
+ col = col+BM
64
+ # MR
65
+ colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR)
66
+ colorwheel[col:col+MR, 0] = 255
67
+ return colorwheel
68
+
69
+
70
+ def flow_uv_to_colors(u, v, convert_to_bgr=False):
71
+ """
72
+ Applies the flow color wheel to (possibly clipped) flow components u and v.
73
+
74
+ According to the C++ source code of Daniel Scharstein
75
+ According to the Matlab source code of Deqing Sun
76
+
77
+ Args:
78
+ u (np.ndarray): Input horizontal flow of shape [H,W]
79
+ v (np.ndarray): Input vertical flow of shape [H,W]
80
+ convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
81
+
82
+ Returns:
83
+ np.ndarray: Flow visualization image of shape [H,W,3]
84
+ """
85
+ flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)
86
+ colorwheel = make_colorwheel() # shape [55x3]
87
+ ncols = colorwheel.shape[0]
88
+ rad = np.sqrt(np.square(u) + np.square(v))
89
+ a = np.arctan2(-v, -u)/np.pi
90
+ fk = (a+1) / 2*(ncols-1)
91
+ k0 = np.floor(fk).astype(np.int32)
92
+ k1 = k0 + 1
93
+ k1[k1 == ncols] = 0
94
+ f = fk - k0
95
+ for i in range(colorwheel.shape[1]):
96
+ tmp = colorwheel[:,i]
97
+ col0 = tmp[k0] / 255.0
98
+ col1 = tmp[k1] / 255.0
99
+ col = (1-f)*col0 + f*col1
100
+ idx = (rad <= 1)
101
+ col[idx] = 1 - rad[idx] * (1-col[idx])
102
+ col[~idx] = col[~idx] * 0.75 # out of range
103
+ # Note the 2-i => BGR instead of RGB
104
+ ch_idx = 2-i if convert_to_bgr else i
105
+ flow_image[:,:,ch_idx] = np.floor(255 * col)
106
+ return flow_image
107
+
108
+
109
+ def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False):
110
+ """
111
+ Expects a two dimensional flow image of shape.
112
+
113
+ Args:
114
+ flow_uv (np.ndarray): Flow UV image of shape [H,W,2]
115
+ clip_flow (float, optional): Clip maximum of flow values. Defaults to None.
116
+ convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
117
+
118
+ Returns:
119
+ np.ndarray: Flow visualization image of shape [H,W,3]
120
+ """
121
+ assert flow_uv.ndim == 3, 'input flow must have three dimensions'
122
+ assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]'
123
+ if clip_flow is not None:
124
+ flow_uv = np.clip(flow_uv, 0, clip_flow)
125
+ u = flow_uv[:,:,0]
126
+ v = flow_uv[:,:,1]
127
+ rad = np.sqrt(np.square(u) + np.square(v))
128
+ rad_max = np.max(rad)
129
+ epsilon = 1e-5
130
+ u = u / (rad_max + epsilon)
131
+ v = v / (rad_max + epsilon)
132
+ return flow_uv_to_colors(u, v, convert_to_bgr)
data_utils/UNFaceFlow/core/utils_core/frame_utils.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from PIL import Image
3
+ from os.path import *
4
+ import re
5
+
6
+ import cv2
7
+ cv2.setNumThreads(0)
8
+ cv2.ocl.setUseOpenCL(False)
9
+
10
+ TAG_CHAR = np.array([202021.25], np.float32)
11
+
12
+ def readFlow(fn):
13
+ """ Read .flo file in Middlebury format"""
14
+ # Code adapted from:
15
+ # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy
16
+
17
+ # WARNING: this will work on little-endian architectures (eg Intel x86) only!
18
+ # print 'fn = %s'%(fn)
19
+ with open(fn, 'rb') as f:
20
+ magic = np.fromfile(f, np.float32, count=1)
21
+ if 202021.25 != magic:
22
+ print('Magic number incorrect. Invalid .flo file')
23
+ return None
24
+ else:
25
+ w = np.fromfile(f, np.int32, count=1)
26
+ h = np.fromfile(f, np.int32, count=1)
27
+ # print 'Reading %d x %d flo file\n' % (w, h)
28
+ data = np.fromfile(f, np.float32, count=2*int(w)*int(h))
29
+ # Reshape data into 3D array (columns, rows, bands)
30
+ # The reshape here is for visualization, the original code is (w,h,2)
31
+ return np.resize(data, (int(h), int(w), 2))
32
+
33
+ def readPFM(file):
34
+ file = open(file, 'rb')
35
+
36
+ color = None
37
+ width = None
38
+ height = None
39
+ scale = None
40
+ endian = None
41
+
42
+ header = file.readline().rstrip()
43
+ if header == b'PF':
44
+ color = True
45
+ elif header == b'Pf':
46
+ color = False
47
+ else:
48
+ raise Exception('Not a PFM file.')
49
+
50
+ dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline())
51
+ if dim_match:
52
+ width, height = map(int, dim_match.groups())
53
+ else:
54
+ raise Exception('Malformed PFM header.')
55
+
56
+ scale = float(file.readline().rstrip())
57
+ if scale < 0: # little-endian
58
+ endian = '<'
59
+ scale = -scale
60
+ else:
61
+ endian = '>' # big-endian
62
+
63
+ data = np.fromfile(file, endian + 'f')
64
+ shape = (height, width, 3) if color else (height, width)
65
+
66
+ data = np.reshape(data, shape)
67
+ data = np.flipud(data)
68
+ return data
69
+
70
+ def writeFlow(filename,uv,v=None):
71
+ """ Write optical flow to file.
72
+
73
+ If v is None, uv is assumed to contain both u and v channels,
74
+ stacked in depth.
75
+ Original code by Deqing Sun, adapted from Daniel Scharstein.
76
+ """
77
+ nBands = 2
78
+
79
+ if v is None:
80
+ assert(uv.ndim == 3)
81
+ assert(uv.shape[2] == 2)
82
+ u = uv[:,:,0]
83
+ v = uv[:,:,1]
84
+ else:
85
+ u = uv
86
+
87
+ assert(u.shape == v.shape)
88
+ height,width = u.shape
89
+ f = open(filename,'wb')
90
+ # write the header
91
+ f.write(TAG_CHAR)
92
+ np.array(width).astype(np.int32).tofile(f)
93
+ np.array(height).astype(np.int32).tofile(f)
94
+ # arrange into matrix form
95
+ tmp = np.zeros((height, width*nBands))
96
+ tmp[:,np.arange(width)*2] = u
97
+ tmp[:,np.arange(width)*2 + 1] = v
98
+ tmp.astype(np.float32).tofile(f)
99
+ f.close()
100
+
101
+
102
+ def readFlowKITTI(filename):
103
+ flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH|cv2.IMREAD_COLOR)
104
+ flow = flow[:,:,::-1].astype(np.float32)
105
+ flow, valid = flow[:, :, :2], flow[:, :, 2]
106
+ flow = (flow - 2**15) / 64.0
107
+ return flow, valid
108
+
109
+ def readDispKITTI(filename):
110
+ disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0
111
+ valid = disp > 0.0
112
+ flow = np.stack([-disp, np.zeros_like(disp)], -1)
113
+ return flow, valid
114
+
115
+
116
+ def writeFlowKITTI(filename, uv):
117
+ uv = 64.0 * uv + 2**15
118
+ valid = np.ones([uv.shape[0], uv.shape[1], 1])
119
+ uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16)
120
+ cv2.imwrite(filename, uv[..., ::-1])
121
+
122
+
123
+ def read_gen(file_name, pil=False):
124
+ ext = splitext(file_name)[-1]
125
+ if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg':
126
+ return Image.open(file_name)
127
+ elif ext == '.bin' or ext == '.raw':
128
+ return np.load(file_name)
129
+ elif ext == '.flo':
130
+ return readFlow(file_name).astype(np.float32)
131
+ elif ext == '.pfm':
132
+ flow = readPFM(file_name).astype(np.float32)
133
+ if len(flow.shape) == 2:
134
+ return flow
135
+ else:
136
+ return flow[:, :, :-1]
137
+ return []
data_utils/UNFaceFlow/core/utils_core/utils.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import numpy as np
4
+ from scipy import interpolate
5
+
6
+
7
+ class InputPadder:
8
+ """ Pads images such that dimensions are divisible by 8 """
9
+ def __init__(self, dims, mode='sintel'):
10
+ self.ht, self.wd = dims[-2:]
11
+ pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8
12
+ pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8
13
+ if mode == 'sintel':
14
+ self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2]
15
+ else:
16
+ self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht]
17
+
18
+ def pad(self, *inputs):
19
+ return [F.pad(x, self._pad, mode='replicate') for x in inputs]
20
+
21
+ def unpad(self,x):
22
+ ht, wd = x.shape[-2:]
23
+ c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]]
24
+ return x[..., c[0]:c[1], c[2]:c[3]]
25
+
26
+ def forward_interpolate(flow):
27
+ flow = flow.detach().cpu().numpy()
28
+ dx, dy = flow[0], flow[1]
29
+
30
+ ht, wd = dx.shape
31
+ x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht))
32
+
33
+ x1 = x0 + dx
34
+ y1 = y0 + dy
35
+
36
+ x1 = x1.reshape(-1)
37
+ y1 = y1.reshape(-1)
38
+ dx = dx.reshape(-1)
39
+ dy = dy.reshape(-1)
40
+
41
+ valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht)
42
+ x1 = x1[valid]
43
+ y1 = y1[valid]
44
+ dx = dx[valid]
45
+ dy = dy[valid]
46
+
47
+ flow_x = interpolate.griddata(
48
+ (x1, y1), dx, (x0, y0), method='nearest', fill_value=0)
49
+
50
+ flow_y = interpolate.griddata(
51
+ (x1, y1), dy, (x0, y0), method='nearest', fill_value=0)
52
+
53
+ flow = np.stack([flow_x, flow_y], axis=0)
54
+ return torch.from_numpy(flow).float()
55
+
56
+
57
+ def bilinear_sampler(img, coords, mode='bilinear', mask=False):
58
+ """ Wrapper for grid_sample, uses pixel coordinates """
59
+ H, W = img.shape[-2:]
60
+ xgrid, ygrid = coords.split([1,1], dim=-1)
61
+ xgrid = 2*xgrid/(W-1) - 1
62
+ ygrid = 2*ygrid/(H-1) - 1
63
+
64
+ grid = torch.cat([xgrid, ygrid], dim=-1)
65
+ img = F.grid_sample(img, grid, align_corners=True)
66
+
67
+ if mask:
68
+ mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
69
+ return img, mask.float()
70
+
71
+ return img
72
+
73
+
74
+ def coords_grid(batch, ht, wd):
75
+ coords = torch.meshgrid(torch.arange(ht), torch.arange(wd))
76
+ coords = torch.stack(coords[::-1], dim=0).float()
77
+ return coords[None].repeat(batch, 1, 1, 1)
78
+
79
+
80
+ def upflow8(flow, mode='bilinear'):
81
+ new_size = (8 * flow.shape[2], 8 * flow.shape[3])
82
+ return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True)
83
+
84
+ def upweights8(weights, mode='bilinear'):
85
+ new_size = (8 * weights.shape[2], 8 * weights.shape[3])
86
+ return F.interpolate(weights, size=new_size, mode=mode, align_corners=True)
data_utils/UNFaceFlow/core/warp_utils.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
+ def mesh_grid(B, H, W):
7
+ # mesh grid
8
+ x_base = torch.arange(0, W).repeat(B, H, 1) # BHW
9
+ y_base = torch.arange(0, H).repeat(B, W, 1).transpose(1, 2) # BHW
10
+
11
+ base_grid = torch.stack([x_base, y_base], 1) # B2HW
12
+ return base_grid
13
+
14
+
15
+ def norm_grid(v_grid):
16
+ _, _, H, W = v_grid.size()
17
+
18
+ # scale grid to [-1,1]
19
+ v_grid_norm = torch.zeros_like(v_grid)
20
+ v_grid_norm[:, 0, :, :] = 2.0 * v_grid[:, 0, :, :] / (W - 1) - 1.0
21
+ v_grid_norm[:, 1, :, :] = 2.0 * v_grid[:, 1, :, :] / (H - 1) - 1.0
22
+ return v_grid_norm.permute(0, 2, 3, 1) # BHW2
23
+
24
+
25
+ def get_corresponding_map(data):
26
+ """
27
+
28
+ :param data: unnormalized coordinates Bx2xHxW
29
+ :return: Bx1xHxW
30
+ """
31
+ B, _, H, W = data.size()
32
+
33
+ # x = data[:, 0, :, :].view(B, -1).clamp(0, W - 1) # BxN (N=H*W)
34
+ # y = data[:, 1, :, :].view(B, -1).clamp(0, H - 1)
35
+
36
+ x = data[:, 0, :, :].view(B, -1) # BxN (N=H*W)
37
+ y = data[:, 1, :, :].view(B, -1)
38
+
39
+ # invalid = (x < 0) | (x > W - 1) | (y < 0) | (y > H - 1) # BxN
40
+ # invalid = invalid.repeat([1, 4])
41
+
42
+ x1 = torch.floor(x)
43
+ x_floor = x1.clamp(0, W - 1)
44
+ y1 = torch.floor(y)
45
+ y_floor = y1.clamp(0, H - 1)
46
+ x0 = x1 + 1
47
+ x_ceil = x0.clamp(0, W - 1)
48
+ y0 = y1 + 1
49
+ y_ceil = y0.clamp(0, H - 1)
50
+
51
+ x_ceil_out = x0 != x_ceil
52
+ y_ceil_out = y0 != y_ceil
53
+ x_floor_out = x1 != x_floor
54
+ y_floor_out = y1 != y_floor
55
+ invalid = torch.cat([x_ceil_out | y_ceil_out,
56
+ x_ceil_out | y_floor_out,
57
+ x_floor_out | y_ceil_out,
58
+ x_floor_out | y_floor_out], dim=1)
59
+
60
+ # encode coordinates, since the scatter function can only index along one axis
61
+ corresponding_map = torch.zeros(B, H * W).type_as(data)
62
+ indices = torch.cat([x_ceil + y_ceil * W,
63
+ x_ceil + y_floor * W,
64
+ x_floor + y_ceil * W,
65
+ x_floor + y_floor * W], 1).long() # BxN (N=4*H*W)
66
+ values = torch.cat([(1 - torch.abs(x - x_ceil)) * (1 - torch.abs(y - y_ceil)),
67
+ (1 - torch.abs(x - x_ceil)) * (1 - torch.abs(y - y_floor)),
68
+ (1 - torch.abs(x - x_floor)) * (1 - torch.abs(y - y_ceil)),
69
+ (1 - torch.abs(x - x_floor)) * (1 - torch.abs(y - y_floor))],
70
+ 1)
71
+ # values = torch.ones_like(values)
72
+
73
+ values[invalid] = 0
74
+
75
+ corresponding_map.scatter_add_(1, indices, values)
76
+ # decode coordinates
77
+ corresponding_map = corresponding_map.view(B, H, W)
78
+
79
+ return corresponding_map.unsqueeze(1)
80
+
81
+
82
+ def flow_warp(x, flow12, pad='border', mode='bilinear'):
83
+ B, _, H, W = x.size()
84
+
85
+ base_grid = mesh_grid(B, H, W).type_as(x) # B2HW
86
+
87
+ v_grid = norm_grid(base_grid + flow12) # BHW2
88
+ im1_recons = nn.functional.grid_sample(x, v_grid, mode=mode, padding_mode=pad)
89
+ return im1_recons
90
+
91
+
92
+ def get_occu_mask_bidirection(flow12, flow21, mask, scale=1, bias=0.5):
93
+ flow21_warped = flow_warp(flow21, flow12, pad='zeros')
94
+ flow12_diff = flow12 + flow21_warped
95
+ mag = (flow12 * flow12).sum(1, keepdim=True) + \
96
+ (flow21_warped * flow21_warped).sum(1, keepdim=True)
97
+ occ_thresh = scale * mag + bias
98
+ occu = (flow12_diff * flow12_diff).sum(1, keepdim=True) > occ_thresh
99
+ # soft_occu = 1.0 / (1 + torch.exp(diff) / 5.0)
100
+ # print("forward:", diff.max(), diff.min())
101
+ return occu
102
+
103
+
104
+ def get_occu_mask_backward(flow21, th=0.2):
105
+ B, _, H, W = flow21.size()
106
+ base_grid = mesh_grid(B, H, W).type_as(flow21) # B2HW
107
+
108
+ corr_map = get_corresponding_map(base_grid + flow21) # BHW
109
+ occu_mask = corr_map.clamp(min=0., max=1.) < th
110
+ return occu_mask.float()
111
+
112
+ def get_ssv_weights(cycle_corres, input, mask, scale_value):
113
+ vgrid = (cycle_corres.mul(scale_value) - 1.0).permute(0,2,3,1)
114
+ new_input = nn.functional.grid_sample(input, vgrid, align_corners=True, padding_mode='border')
115
+ color_diff = (((input[:, :3, :, :] - new_input[:, :3, :, :]) / 255.0) ** 2).sum(1, keepdim=True)
116
+ depth_diff = (((input[:, 3:, :, :] - new_input[:, 3:, :, :])) ** 2).sum(1, keepdim=True)
117
+ diff = torch.mul(mask.float(), color_diff + depth_diff) #(N, 1, H, W)
118
+ return torch.exp(-diff)
data_utils/UNFaceFlow/data_test_flow/__init__.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import importlib
2
+ import torch.utils.data
3
+ from data_test_flow.dd_dataset import DDDataset
4
+
5
+ def CreateDataLoader(opt):
6
+ data_loader = CustomDatasetDataLoader()
7
+ data_loader.initialize(opt)
8
+ return data_loader
9
+
10
+ # def CreateTestDataLoader(opt):
11
+ # data_loader = CustomTestDatasetDataLoader()
12
+ # data_loader.initialize(opt)
13
+ # return data_loader
14
+
15
+ class BaseDataLoader():
16
+ def __init__(self):
17
+ pass
18
+
19
+ def initialize(self, opt):
20
+ self.opt = opt
21
+ pass
22
+
23
+ def load_data(self):
24
+ return None
25
+
26
+ class CustomDatasetDataLoader(BaseDataLoader):
27
+ def name(self):
28
+ return 'CustomDatasetDataLoader'
29
+
30
+ def initialize(self, opt):
31
+ BaseDataLoader.initialize(self, opt)
32
+ self.dataset = DDDataset()
33
+ self.dataset.initialize(opt)
34
+ '''
35
+ sampler = torch.utils.data.distributed.DistributedSampler(self.dataset)
36
+ self.dataloader = torch.utils.data.DataLoader(
37
+ self.dataset,
38
+ batch_size=opt.batch_size,
39
+ shuffle=False,
40
+ sampler=sampler)
41
+ '''
42
+ self.dataloader = torch.utils.data.DataLoader(
43
+ self.dataset,
44
+ batch_size=opt.batch_size,
45
+ shuffle=opt.shuffle,
46
+ drop_last=True,
47
+ num_workers=int(opt.num_threads))
48
+
49
+ def load_data(self):
50
+ return self
51
+
52
+ def __len__(self):
53
+ return min(len(self.dataset), self.opt.max_dataset_size)
54
+
55
+ def __iter__(self):
56
+ for i, data in enumerate(self.dataloader):
57
+ if i * self.opt.batch_size >= self.opt.max_dataset_size:
58
+ break
59
+ yield data
60
+
61
+ # class CustomTestDatasetDataLoader(BaseDataLoader):
62
+ # def name(self):
63
+ # return 'CustomDatasetDataLoader'
64
+
65
+ # def initialize(self, opt):
66
+ # BaseDataLoader.initialize(self, opt)
67
+ # self.dataset = DDDatasetTest()
68
+ # self.dataset.initialize(opt)
69
+ # '''
70
+ # sampler = torch.utils.data.distributed.DistributedSampler(self.dataset)
71
+ # self.dataloader = torch.utils.data.DataLoader(
72
+ # self.dataset,
73
+ # batch_size=opt.batch_size,
74
+ # shuffle=False,
75
+ # sampler=sampler)
76
+ # '''
77
+ # self.dataloader = torch.utils.data.DataLoader(
78
+ # self.dataset,
79
+ # batch_size=opt.batch_size,
80
+ # shuffle=opt.shuffle,
81
+ # drop_last=True,
82
+ # num_workers=int(opt.num_threads))
83
+
84
+ # def load_data(self):
85
+ # return self
86
+
87
+ # def __len__(self):
88
+ # return min(len(self.dataset), self.opt.max_dataset_size)
89
+
90
+ # def __iter__(self):
91
+ # for i, data in enumerate(self.dataloader):
92
+ # if i * self.opt.batch_size >= self.opt.max_dataset_size:
93
+ # break
94
+ # yield data
data_utils/UNFaceFlow/data_test_flow/base_dataset.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.utils.data as data
2
+ from PIL import Image
3
+ import torchvision.transforms as transforms
4
+
5
+
6
+ class BaseDataset(data.Dataset):
7
+ def __init__(self):
8
+ super(BaseDataset, self).__init__()
9
+
10
+ def name(self):
11
+ return 'BaseDataset'
12
+
13
+ def initialize(self, opt):
14
+ pass
15
+
16
+ def __len__(self):
17
+ return 0
18
+
19
+
20
+ def get_transform(opt):
21
+ transform_list = []
22
+ if opt.resize_or_crop == 'resize_and_crop':
23
+ osize = [opt.loadSize, opt.loadSize]
24
+ transform_list.append(transforms.Resize(osize, Image.BICUBIC))
25
+ transform_list.append(transforms.RandomCrop(opt.fineSize))
26
+ elif opt.resize_or_crop == 'crop':
27
+ transform_list.append(transforms.RandomCrop(opt.fineSize))
28
+ elif opt.resize_or_crop == 'scale_width':
29
+ transform_list.append(transforms.Lambda(
30
+ lambda img: __scale_width(img, opt.fineSize)))
31
+ elif opt.resize_or_crop == 'scale_width_and_crop':
32
+ transform_list.append(transforms.Lambda(
33
+ lambda img: __scale_width(img, opt.loadSize)))
34
+ transform_list.append(transforms.RandomCrop(opt.fineSize))
35
+ elif opt.resize_or_crop == 'none':
36
+ transform_list.append(transforms.Lambda(
37
+ lambda img: __adjust(img)))
38
+ else:
39
+ raise ValueError('--resize_or_crop %s is not a valid option.' % opt.resize_or_crop)
40
+
41
+ if opt.isTrain and not opt.no_flip:
42
+ transform_list.append(transforms.RandomHorizontalFlip())
43
+
44
+ transform_list += [transforms.ToTensor(),
45
+ transforms.Normalize((0.5, 0.5, 0.5),
46
+ (0.5, 0.5, 0.5))]
47
+ return transforms.Compose(transform_list)
48
+
49
+
50
+ # just modify the width and height to be multiple of 4
51
+ def __adjust(img):
52
+ ow, oh = img.size
53
+
54
+ # the size needs to be a multiple of this number,
55
+ # because going through generator network may change img size
56
+ # and eventually cause size mismatch error
57
+ mult = 4
58
+ if ow % mult == 0 and oh % mult == 0:
59
+ return img
60
+ w = (ow - 1) // mult
61
+ w = (w + 1) * mult
62
+ h = (oh - 1) // mult
63
+ h = (h + 1) * mult
64
+
65
+ if ow != w or oh != h:
66
+ __print_size_warning(ow, oh, w, h)
67
+
68
+ return img.resize((w, h), Image.BICUBIC)
69
+
70
+
71
+ def __scale_width(img, target_width):
72
+ ow, oh = img.size
73
+
74
+ # the size needs to be a multiple of this number,
75
+ # because going through generator network may change img size
76
+ # and eventually cause size mismatch error
77
+ mult = 4
78
+ assert target_width % mult == 0, "the target width needs to be multiple of %d." % mult
79
+ if (ow == target_width and oh % mult == 0):
80
+ return img
81
+ w = target_width
82
+ target_height = int(target_width * oh / ow)
83
+ m = (target_height - 1) // mult
84
+ h = (m + 1) * mult
85
+
86
+ if target_height != h:
87
+ __print_size_warning(target_width, target_height, w, h)
88
+
89
+ return img.resize((w, h), Image.BICUBIC)
90
+
91
+
92
+ def __print_size_warning(ow, oh, w, h):
93
+ if not hasattr(__print_size_warning, 'has_printed'):
94
+ print("The image size needs to be a multiple of 4. "
95
+ "The loaded image size was (%d, %d), so it was adjusted to "
96
+ "(%d, %d). This adjustment will be done to all images "
97
+ "whose sizes are not multiples of 4" % (ow, oh, w, h))
98
+ __print_size_warning.has_printed = True
data_utils/UNFaceFlow/data_test_flow/dd_dataset.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os.path
2
+ import torch
3
+ import torch.utils.data as data
4
+ from PIL import Image
5
+ import random
6
+ import utils
7
+ import numpy as np
8
+ import torchvision.transforms as transforms
9
+ from utils_core import flow_viz
10
+ import cv2
11
+
12
+ class DDDataset(data.Dataset):
13
+ def __init__(self):
14
+ super(DDDataset, self).__init__()
15
+ def initialize(self, opt):
16
+ self.opt = opt
17
+ self.dir_txt = opt.datapath
18
+ self.paths = []
19
+ in_file = open(self.dir_txt, "r")
20
+ k = 0
21
+ list_paths = in_file.readlines()
22
+ for line in list_paths:
23
+ #if k>=20: break
24
+ flag = False
25
+ line = line.strip()
26
+ line = line.split()
27
+
28
+ #source data
29
+ if (not os.path.exists(line[0])):
30
+ print(line[0]+" not exists")
31
+ continue
32
+ if (not os.path.exists(line[1])):
33
+ print(line[1]+" not exists")
34
+ continue
35
+ if (not os.path.exists(line[2])):
36
+ print(line[2]+" not exists")
37
+ continue
38
+ if (not os.path.exists(line[3])):
39
+ print(line[3]+" not exists")
40
+ continue
41
+ # if (not os.path.exists(line[2])):
42
+ # print(line[2]+" not exists")
43
+ # continue
44
+
45
+ # path_list = [line[0], line[1], line[2]]
46
+ path_list = [line[0], line[1], line[2], line[3]]
47
+ self.paths.append(path_list)
48
+ k += 1
49
+ in_file.close()
50
+ self.data_size = len(self.paths)
51
+ print("num data: ", len(self.paths))
52
+
53
+ def process_data(self, color, mask):
54
+ non_zero = mask.nonzero()
55
+ bound = 10
56
+ min_x = max(0, non_zero[1].min()-bound)
57
+ max_x = min(self.opt.width-1, non_zero[1].max()+bound)
58
+ min_y = max(0, non_zero[0].min()-bound)
59
+ max_y = min(self.opt.height-1, non_zero[0].max()+bound)
60
+ color = color * (mask!=0).astype(float)[:, :, None]
61
+ crop_color = color[min_y:max_y, min_x:max_x, :]
62
+ crop_color = cv2.resize(np.ascontiguousarray(crop_color), (self.opt.crop_width, self.opt.crop_height), interpolation=cv2.INTER_LINEAR)
63
+ crop_params = [[min_x], [max_x], [min_y], [max_y]]
64
+
65
+ return crop_color, crop_params
66
+
67
+ def __getitem__(self, index):
68
+ paths = self.paths[index % self.data_size]
69
+ src_color = np.array(Image.open(paths[0]))
70
+ src_color = src_color.astype(np.uint8)
71
+ raw_src_color = src_color.copy()
72
+ src_mask = np.array(Image.open(paths[1]))[:, :, 0]
73
+ cv2.imwrite("test_mask.png", src_mask)
74
+ src_mask_copy = src_mask.copy()
75
+ src_crop_color, src_crop_params = self.process_data(src_color, src_mask)
76
+ #self.write_mesh(src_X, src_Y, src_Z, "./tmp/src.obj")
77
+ #HWC --> CHW,
78
+ raw_src_color = torch.from_numpy(raw_src_color).permute(2, 0, 1).float() / 255.0
79
+ src_crop_color = torch.from_numpy(src_crop_color).permute(2, 0, 1).float() / 255.0
80
+
81
+ src_mask_copy = (src_mask_copy!=0)
82
+ src_mask_copy = torch.tensor(src_mask_copy[np.newaxis, :, :])
83
+
84
+ tar_color = np.array(Image.open(paths[2]))
85
+ tar_color = tar_color.astype(np.uint8)
86
+ raw_tar_color = tar_color.copy()
87
+ tar_mask = np.array(Image.open(paths[3]))[:, :, 0]
88
+ tar_mask_copy = tar_mask.copy()
89
+ tar_crop_color, tar_crop_params = self.process_data(tar_color, tar_mask)
90
+
91
+ raw_tar_color = torch.from_numpy(raw_tar_color).permute(2, 0, 1).float() / 255.0
92
+ tar_crop_color = torch.from_numpy(tar_crop_color).permute(2, 0, 1).float() / 255.0
93
+
94
+ tar_mask_copy = (tar_mask_copy!=0)
95
+ tar_mask_copy = torch.tensor(tar_mask_copy[np.newaxis, :, :])
96
+
97
+ Crop_param = torch.tensor(src_crop_params+tar_crop_params)
98
+
99
+ split_ = paths[0].split("/")
100
+ path1 = split_[-1][:-4] + "_" + paths[2].split("/")[-1][:-4] +".oflow"
101
+
102
+ return {"path_flow":path1, "src_crop_color":src_crop_color, "tar_crop_color":tar_crop_color, "src_color":raw_src_color, "tar_color":raw_tar_color, "src_mask":src_mask_copy, "tar_mask":tar_mask_copy, "Crop_param":Crop_param}
103
+
104
+ def __len__(self):
105
+ return self.data_size
106
+
107
+ def name(self):
108
+ return 'DDDataset'
data_utils/UNFaceFlow/data_test_flow/dd_dataset_bak.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os.path
2
+ import torch
3
+ import torch.utils.data as data
4
+ from PIL import Image
5
+ import random
6
+ import utils
7
+ import numpy as np
8
+ import torchvision.transforms as transforms
9
+ from utils_core import flow_viz
10
+ import cv2
11
+
12
+ class DDDataset(data.Dataset):
13
+ def __init__(self):
14
+ super(DDDataset, self).__init__()
15
+ def initialize(self, opt):
16
+ self.opt = opt
17
+ self.dir_txt = opt.datapath
18
+ self.paths = []
19
+ in_file = open(self.dir_txt, "r")
20
+ k = 0
21
+ list_paths = in_file.readlines()
22
+ for line in list_paths:
23
+ #if k>=20: break
24
+ flag = False
25
+ line = line.strip()
26
+ line = line.split()
27
+
28
+ #source data
29
+ if (not os.path.exists(line[0])):
30
+ print(line[0]+" not exists")
31
+ continue
32
+ if (not os.path.exists(line[1])):
33
+ print(line[1]+" not exists")
34
+ continue
35
+ if (not os.path.exists(line[2])):
36
+ print(line[2]+" not exists")
37
+ continue
38
+ if (not os.path.exists(line[3])):
39
+ print(line[3]+" not exists")
40
+ continue
41
+ # if (not os.path.exists(line[2])):
42
+ # print(line[2]+" not exists")
43
+ # continue
44
+
45
+ # path_list = [line[0], line[1], line[2]]
46
+ path_list = [line[0], line[1], line[2], line[3]]
47
+ self.paths.append(path_list)
48
+ k += 1
49
+ in_file.close()
50
+ self.data_size = len(self.paths)
51
+ print("num data: ", len(self.paths))
52
+
53
+ def process_data(self, color, mask):
54
+ non_zero = mask.nonzero()
55
+ bound = 10
56
+ min_x = max(0, non_zero[1].min()-bound)
57
+ max_x = min(self.opt.width-1, non_zero[1].max()+bound)
58
+ min_y = max(0, non_zero[0].min()-bound)
59
+ max_y = min(self.opt.height-1, non_zero[0].max()+bound)
60
+ color = color * (mask!=0).astype(float)[:, :, None]
61
+ crop_color = color[min_y:max_y, min_x:max_x, :]
62
+ crop_color = cv2.resize(np.ascontiguousarray(crop_color), (self.opt.crop_width, self.opt.crop_height), interpolation=cv2.INTER_LINEAR)
63
+ crop_params = [[min_x], [max_x], [min_y], [max_y]]
64
+
65
+ return crop_color, crop_params
66
+
67
+ def __getitem__(self, index):
68
+ paths = self.paths[index % self.data_size]
69
+ src_color = np.array(Image.open(paths[0]))
70
+ src_color = src_color.astype(np.uint8)
71
+ raw_src_color = src_color.copy()
72
+ src_mask = np.array(Image.open(paths[1]))
73
+ src_mask_copy = src_mask.copy()
74
+ src_crop_color, src_crop_params = self.process_data(src_color, src_mask)
75
+ #self.write_mesh(src_X, src_Y, src_Z, "./tmp/src.obj")
76
+ #HWC --> CHW,
77
+ raw_src_color = torch.from_numpy(raw_src_color).permute(2, 0, 1).float() / 255.0
78
+ src_crop_color = torch.from_numpy(src_crop_color).permute(2, 0, 1).float() / 255.0
79
+
80
+ src_mask_copy = (src_mask_copy!=0)
81
+ src_mask_copy = torch.tensor(src_mask_copy[np.newaxis, :, :])
82
+
83
+ tar_color = np.array(Image.open(paths[2]))
84
+ tar_color = tar_color.astype(np.uint8)
85
+ raw_tar_color = tar_color.copy()
86
+ tar_mask = np.array(Image.open(paths[3]))
87
+ tar_mask_copy = tar_mask.copy()
88
+ tar_crop_color, tar_crop_params = self.process_data(tar_color, tar_mask)
89
+
90
+ raw_tar_color = torch.from_numpy(raw_tar_color).permute(2, 0, 1).float() / 255.0
91
+ tar_crop_color = torch.from_numpy(tar_crop_color).permute(2, 0, 1).float() / 255.0
92
+
93
+ tar_mask_copy = (tar_mask_copy!=0)
94
+ tar_mask_copy = torch.tensor(tar_mask_copy[np.newaxis, :, :])
95
+
96
+ Crop_param = torch.tensor(src_crop_params+tar_crop_params)
97
+
98
+ split_ = paths[0].split("/")
99
+ path1 = split_[-1][:-4] + "_" + paths[2].split("/")[-1][:-4] +".oflow"
100
+
101
+ return {"path_flow":path1, "src_crop_color":src_crop_color, "tar_crop_color":tar_crop_color, "src_color":raw_src_color, "tar_color":raw_tar_color, "src_mask":src_mask_copy, "tar_mask":tar_mask_copy, "Crop_param":Crop_param}
102
+
103
+ def __len__(self):
104
+ return self.data_size
105
+
106
+ def name(self):
107
+ return 'DDDataset'
data_utils/UNFaceFlow/models/network_test_flow.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from raft import RAFT
5
+ from nnutils import make_conv_2d, make_upscale_2d, make_downscale_2d, ResBlock2d, Identity
6
+
7
+
8
+ class ImportanceWeights(torch.nn.Module):
9
+ def __init__(self, opt):
10
+ super().__init__()
11
+
12
+ if opt.small:
13
+ in_dim = 128
14
+ else:
15
+ in_dim = 256
16
+ fn_0 = 16
17
+ self.input_fn = fn_0 + 3 * 2
18
+ fn_1 = 16
19
+ self.conv1 = torch.nn.Conv2d(in_channels=in_dim, out_channels=fn_0, kernel_size=3, stride=1, padding=1)
20
+
21
+ if opt.use_batch_norm:
22
+ custom_batch_norm = torch.nn.BatchNorm2d
23
+ else:
24
+ custom_batch_norm = Identity
25
+
26
+ self.model = nn.Sequential(
27
+ make_conv_2d(self.input_fn, fn_1, n_blocks=1, normalization=custom_batch_norm),
28
+ ResBlock2d(fn_1, normalization=custom_batch_norm),
29
+ ResBlock2d(fn_1, normalization=custom_batch_norm),
30
+ ResBlock2d(fn_1, normalization=custom_batch_norm),
31
+ nn.Conv2d(fn_1, 1, kernel_size=3, padding=1)
32
+ # torch.nn.Sigmoid()
33
+ )
34
+
35
+ def forward(self, x, features):
36
+ # Reduce number of channels and upscale to highest resolution
37
+ features = self.conv1(features)
38
+ x = torch.cat([features, x], 1)
39
+ assert x.shape[1] == self.input_fn
40
+ x = self.model(x)
41
+ print(x)
42
+ print(x.max(), x.min(), x.mean())
43
+
44
+ return torch.nn.Sigmoid()(x)
45
+
46
+ class NeuralNRT(nn.Module):
47
+ def __init__(self, opt, path=None, device="cuda:0"):
48
+ super(NeuralNRT, self).__init__()
49
+ self.opt = opt
50
+ self.CorresPred = RAFT(opt)
51
+ self.ImportanceW = ImportanceWeights(opt)
52
+ if path is not None:
53
+ data = torch.load(path,map_location='cpu')
54
+ if 'state_dict' in data.keys():
55
+ self.CorresPred.load_state_dict(data['state_dict'])
56
+ print("load done")
57
+ else:
58
+ self.CorresPred.load_state_dict({k.replace('module.', ''):v for k,v in data.items()})
59
+ print("load done")
60
+ def forward(self, src_im,tar_im, src_im_raw, tar_im_raw, Crop_param):
61
+ N=src_im.shape[0]
62
+ src_im = src_im*255.0
63
+ tar_im = tar_im*255.0
64
+ flow_fw_crop, feature_fw_crop = self.CorresPred(src_im, tar_im, iters=self.opt.iters)
65
+
66
+ xx = torch.arange(0, self.opt.width).view(1,-1).repeat(self.opt.height,1)
67
+ yy = torch.arange(0, self.opt.height).view(-1,1).repeat(1,self.opt.width)
68
+ xx = xx.view(1,1,self.opt.height,self.opt.width).repeat(N,1,1,1)
69
+ yy = yy.view(1,1,self.opt.height,self.opt.width).repeat(N,1,1,1)
70
+ grid = torch.cat((xx,yy),1).float()
71
+ grid = grid.to(src_im.device)
72
+
73
+ grid_crop = grid[:, :, :self.opt.crop_height, :self.opt.crop_width]
74
+
75
+ flow_fw = torch.zeros((N, 2, self.opt.height, self.opt.width), device=src_im.device)
76
+
77
+ leftup1 = torch.cat((Crop_param[:, 0:1, 0], Crop_param[:, 2:3, 0]), 1)[:, :, None, None]
78
+ leftup2 = torch.cat((Crop_param[:, 4:5, 0], Crop_param[:, 6:7, 0]), 1)[:, :, None, None]
79
+
80
+ scale1 = torch.cat(((Crop_param[:, 1:2, 0]-Crop_param[:, 0:1, 0]).float() / self.opt.crop_width, (Crop_param[:, 3:4, 0]-Crop_param[:, 2:3, 0]).float() / self.opt.crop_height), 1)[:, :, None, None]
81
+ scale2 = torch.cat(((Crop_param[:, 5:6, 0]-Crop_param[:, 4:5, 0]).float() / self.opt.crop_width, (Crop_param[:, 7:8, 0]-Crop_param[:, 6:7, 0]).float() / self.opt.crop_height), 1)[:, :, None, None]
82
+
83
+ flow_fw_crop = (scale2 - scale1) * grid_crop + scale2 * flow_fw_crop
84
+ for i in range(N):
85
+ flow_fw_cropi = F.interpolate(flow_fw_crop[i:(i+1)], ((Crop_param[i, 3, 0]-Crop_param[i, 2, 0]).item(), (Crop_param[i, 1, 0]-Crop_param[i, 0, 0]).item()), mode='bilinear', align_corners=True)
86
+ flow_fw_cropi =flow_fw_cropi + (leftup2 - leftup1)[i:(i+1), :, :, :]
87
+ flow_fw[i, :, Crop_param[i, 2, 0]:Crop_param[i, 3, 0], Crop_param[i, 0, 0]:Crop_param[i, 1, 0]] = flow_fw_cropi[0]
88
+ return flow_fw
data_utils/UNFaceFlow/options_test_flow.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ref:https://github.com/ShunyuYao/DFA-NeRF
2
+ import argparse
3
+ class BaseOptions():
4
+ def __init__(self):
5
+ self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
6
+ self.initialized = False
7
+
8
+ def initialize(self):
9
+ # self.parser.add_argument('--model_save_path', type=str, default='snapshot/small_filter_wo_ct_wi_bn/real_data/combine/', help='path')
10
+ self.parser.add_argument('--model_save_path', type=str, default='snapshot/version1/', help='path')
11
+ self.parser.add_argument('--num_threads', type=int, default=2, help='number of threads')
12
+ self.parser.add_argument('--max_dataset_size', type=int, default=150000, help='max dataset size')
13
+
14
+ self.parser.add_argument('--n_epochs', type=int, default=40000, help='number of iterations')
15
+ self.parser.add_argument('--dropout', type=float, default=0.0, help='dropout')
16
+ self.parser.add_argument('--init_type', type=str, default='uniform', help='[uniform | xavier]')
17
+ self.parser.add_argument('--frequency_print_batch', type=int, default=1000, help='print messages every set iter')
18
+ self.parser.add_argument('--frequency_save_model', type=int, default=2000, help='save model every set iter')
19
+ self.parser.add_argument('--small', type=bool, default=True, help='use small model')
20
+ self.parser.add_argument('--use_batch_norm', action='store_true', help='')
21
+ self.parser.add_argument('--smooth_2nd', type=bool, default=True, help='')
22
+
23
+
24
+ #loss weight for Gauss-Newton optimization
25
+ self.parser.add_argument('--lambda_2d', type=float, default=0.001, help='weight of 2D projection loss')
26
+ self.parser.add_argument('--lambda_depth', type=float, default=1.0, help='weight of depth loss')
27
+ self.parser.add_argument('--lambda_reg', type=float, default=1.0, help='weight of regularization loss')
28
+
29
+ self.parser.add_argument('--num_adja', type=int, default=6, help='number of nodes who affect a point')
30
+ self.parser.add_argument('--num_corres', type=int, default=20000, help='number of corres')
31
+ self.parser.add_argument('--iter_num', type=int, default=3, help='GN iter num')
32
+ self.parser.add_argument('--width', type=int, default=512, help='image width')#480
33
+ self.parser.add_argument('--height', type=int, default=512, help='image height')#640
34
+ self.parser.add_argument('--crop_width', type=int, default=240, help='image width')
35
+ self.parser.add_argument('--crop_height', type=int, default=320, help='image height')
36
+ self.parser.add_argument('--max_num_edges', type=int, default=30000, help='number of edges')
37
+ self.parser.add_argument('--max_num_nodes', type=int, default=1500, help='number of edges')
38
+ self.parser.add_argument('--fdim', type=int, default=128)
39
+
40
+ #loss weight for training
41
+ self.parser.add_argument('--lambda_weights', type=float, default=0.0, help='weight of weights loss')#75
42
+ self.parser.add_argument('--lambda_corres', type=float, default=1.0, help='weight of corres loss')#0, 1
43
+ self.parser.add_argument('--lambda_graph', type=float, default=10.0, help='weight of graph loss')#1000, 5
44
+ self.parser.add_argument('--lambda_warp', type=float, default=10.0, help='weight of warp loss')#1000, 5
45
+
46
+
47
+ def parse(self):
48
+ if not self.initialized:
49
+ self.initialize()
50
+
51
+ self.opt = self.parser.parse_args()
52
+ self.opt.isTrain = self.isTrain
53
+ self.opt.isTest = self.isTest
54
+ args = vars(self.opt)
55
+
56
+ return self.opt
57
+
58
+ class TrainOptions(BaseOptions):
59
+ # Override
60
+ def initialize(self):
61
+ BaseOptions.initialize(self)
62
+ #syn_datasets/syn_new_train_data.txt
63
+ self.parser.add_argument('--datapath', type=str, default='./data/train_data.txt', help='path')
64
+ self.parser.add_argument('--pretrain_model_path', type=str, default='./pretrain_model/raft-small.pth', help='path')#
65
+ self.parser.add_argument('--lr_C', type=float, default=0.00001, help='initial learning rate')#0.01
66
+ self.parser.add_argument('--optimizer_C', type=str, default='sgd', help='[sgd | adam]')
67
+ self.parser.add_argument('--lr_W', type=float, default=0.00001, help='initial learning rate')
68
+ self.parser.add_argument('--lr_BSW', type=float, default=0.00001, help='initial learning rate')
69
+ self.parser.add_argument('--optimizer_W', type=str, default='sgd', help='[sgd | adam]')
70
+ self.parser.add_argument('--optimizer_BSW', type=str, default='sgd', help='[sgd | adam]')
71
+ self.parser.add_argument('--lr_decay_epoch', type=int, default=8000, help='multiply by a gamma every set iter')
72
+ self.parser.add_argument('--lr_decay', type=float, default=0.1, help='coefficient of lr decay')
73
+ self.parser.add_argument('--weight_decay', type=float, default=1e-4, help='0.0005coefficient of weight decay')
74
+ self.parser.add_argument('--batch_size', type=int, default=4, help='batch size')
75
+ self.parser.add_argument('--shuffle', type=bool, default=True, help='whether to shuffle data')
76
+
77
+ self.parser.add_argument('--validation', type=str, nargs='+')
78
+ #self.parser.add_argument('--image_size', type=int, nargs='+', default=[384, 512])
79
+ self.parser.add_argument('--gpus', type=int, nargs='+', default=[0,1])
80
+ self.parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
81
+ self.parser.add_argument('--iters', type=int, default=12)
82
+
83
+ self.parser.add_argument('--clip', type=float, default=1.0)
84
+ self.parser.add_argument('--gamma', type=float, default=0.8, help='exponential weighting')
85
+ self.parser.add_argument('--add_noise', action='store_true')
86
+
87
+ self.parser.add_argument('--train_bsw', type=bool, default=True, help='whether to train bsw network')
88
+ self.parser.add_argument('--train_weight', type=bool, default=True, help='whether to train weight network')
89
+ self.parser.add_argument('--train_corres', type=bool, default=True, help='whether to train corresPred network')
90
+
91
+ self.isTrain = True
92
+ self.isTest = False
93
+
94
+ class ValOptions(BaseOptions):
95
+ def initialize(self):
96
+ BaseOptions.initialize(self)
97
+ self.parser.add_argument('--batch_size', type=int, default=4, help='batch size')
98
+ self.parser.add_argument('--datapath', type=str, default='./data/val_data.txt', help='path')
99
+ self.parser.add_argument('--shuffle', type=bool, default=True, help='whether to shuffle data')
100
+ self.parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
101
+ self.parser.add_argument('--alternate_corr', action='store_true', help='use efficent correlation implementation')
102
+ self.parser.add_argument('--iters', type=int, default=12)
103
+ self.isTrain = True
104
+ self.isTest = False
105
+
106
+ class TestOptions(BaseOptions):
107
+ def initialize(self):
108
+ BaseOptions.initialize(self)
109
+ self.parser.add_argument('--batch_size', type=int, default=1, help='batch size')
110
+ self.parser.add_argument('--pretrain_model_path', type=str, default='./pretrain_model/raft-small.pth', help='path')#
111
+
112
+ # self.parser.add_argument('--datapath', type=str, default='./data/real_train_data_1128_1.txt', help='path')
113
+ # self.parser.add_argument('--datapath', type=str, default='./data_test_flow/test_data.txt', help='path')
114
+ self.parser.add_argument('--savepath', type=str, default='flow_result',
115
+ help='save path')
116
+ self.parser.add_argument('--datapath', type=str, default='/data_b/yudong/paper_code/TalkingHead-NeRF/data_guancha/guancha_flow.txt',
117
+ help='path')
118
+ self.parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
119
+ self.parser.add_argument('--alternate_corr', action='store_true', help='use efficent correlation implementation')
120
+ self.parser.add_argument('--iters', type=int, default=12)
121
+ self.parser.add_argument('--shuffle', type=bool, default=True, help='whether to shuffle data')
122
+ self.isTrain = False
123
+ self.isTest = True
data_utils/UNFaceFlow/pretrain_model/raft-small.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c7d41b9cc88442bb8aa911dbb33086dac55a226394b142937ff22d5578717332
3
+ size 3984814
data_utils/UNFaceFlow/sgd_NNRT_model_epoch19008_50000.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a8156ef276732a4cbd9a9d85d9b5653abf500976372daf7892b122971a7b8f37
3
+ size 8808087
data_utils/UNFaceFlow/test_flow.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ref:https://github.com/ShunyuYao/DFA-NeRF
2
+ import sys
3
+ import os
4
+ from tqdm import tqdm
5
+ dir_path = os.path.dirname(os.path.realpath(__file__))
6
+ sys.path.append(os.path.join(dir_path, 'core'))
7
+ from pathlib import Path
8
+ from data_test_flow import *
9
+ from models.network_test_flow import NeuralNRT
10
+ from options_test_flow import TestOptions
11
+ import torch
12
+ import numpy as np
13
+
14
+
15
+
16
+ def save_flow_numpy(filename, flow_input):
17
+ np.save(filename, flow_input)
18
+
19
+
20
+ def predict(data):
21
+ with torch.no_grad():
22
+ model.eval()
23
+ path_flow = data["path_flow"]
24
+ src_crop_im = data["src_crop_color"].cuda()
25
+ tar_crop_im = data["tar_crop_color"].cuda()
26
+ src_im = data["src_color"].cuda()
27
+ tar_im = data["tar_color"].cuda()
28
+ src_mask = data["src_mask"].cuda()
29
+ crop_param = data["Crop_param"].cuda()
30
+ B = src_mask.shape[0]
31
+ flow = model(src_crop_im, tar_crop_im, src_im, tar_im, crop_param)
32
+ for i in range(B):
33
+ flow_tmp = flow[i].cpu().numpy() * src_mask[i].cpu().numpy()
34
+ save_flow_numpy(os.path.join(save_path, os.path.basename(
35
+ path_flow[i])[:-6]+".npy"), flow_tmp)
36
+
37
+
38
+ if __name__ == "__main__":
39
+ width = 272
40
+ height = 480
41
+
42
+ test_opts = TestOptions().parse()
43
+ test_opts.pretrain_model_path = os.path.join(
44
+ dir_path, 'pretrain_model/raft-small.pth')
45
+ data_loader = CreateDataLoader(test_opts)
46
+ testloader = data_loader.load_data()
47
+ model_path = os.path.join(dir_path, 'sgd_NNRT_model_epoch19008_50000.pth')
48
+ model = NeuralNRT(test_opts, os.path.join(
49
+ dir_path, 'pretrain_model/raft-small.pth'))
50
+ state_dict = torch.load(model_path)
51
+
52
+ model.CorresPred.load_state_dict(state_dict["net_C"])
53
+ model.ImportanceW.load_state_dict(state_dict["net_W"])
54
+
55
+ model = model.cuda()
56
+
57
+ save_path = test_opts.savepath
58
+ Path(save_path).mkdir(parents=True, exist_ok=True)
59
+ total_length = len(testloader)
60
+
61
+ for batch_idx, data in tqdm(enumerate(testloader), total=total_length):
62
+ predict(data)
data_utils/UNFaceFlow/utils.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import shutil
3
+ import numpy as np
4
+ import struct
5
+ import pickle
6
+ from scipy.sparse import coo_matrix
7
+
8
+ def load_flow(filename):
9
+ # Flow is stored row-wise in order [channels, height, width].
10
+ assert os.path.isfile(filename), "File not found: {}".format(filename)
11
+
12
+ flow = None
13
+ with open(filename, 'rb') as fin:
14
+ width = struct.unpack('I', fin.read(4))[0]
15
+ height = struct.unpack('I', fin.read(4))[0]
16
+ channels = struct.unpack('I', fin.read(4))[0]
17
+ n_elems = height * width * channels
18
+
19
+ flow = struct.unpack('f' * n_elems, fin.read(n_elems * 4))
20
+ flow = np.asarray(flow, dtype=np.float32).reshape([channels, height, width])
21
+
22
+ return flow
23
+
24
+ def load_graph_info(filename, max_edges, max_nodes):
25
+
26
+ assert os.path.isfile(filename), "File not found: {}".format(filename)
27
+
28
+ with open(filename, 'rb') as fin:
29
+ edge_total_size = struct.unpack('I', fin.read(4))[0]
30
+ edges = struct.unpack('I' * (int(edge_total_size / 4)), fin.read(edge_total_size))
31
+ edges = np.asarray(edges, dtype=np.int16).reshape(-1, 2).transpose()
32
+ nodes_total_size = struct.unpack('I', fin.read(4))[0]
33
+ nodes_ids = struct.unpack('I' * (int(nodes_total_size / 4)), fin.read(nodes_total_size))
34
+ nodes_ids = np.asarray(nodes_ids, dtype=np.int32).reshape(-1)
35
+ nodes_ids = np.sort(nodes_ids)
36
+
37
+ edges_extent = np.zeros((2, max_edges), dtype=np.int16)
38
+ edges_mask = np.zeros((max_edges), dtype=np.bool)
39
+ edges_mask[:edges.shape[1]] = 1
40
+ edges_extent[:, :edges.shape[1]] = edges
41
+
42
+ nodes_extent = np.zeros((max_nodes), dtype=np.int32)
43
+ nodes_mask = np.zeros((max_nodes), dtype=np.bool)
44
+ nodes_mask[:nodes_ids.shape[0]] = 1
45
+ nodes_extent[:nodes_ids.shape[0]] = nodes_ids
46
+
47
+ fx = struct.unpack('f', fin.read(4))[0]
48
+ fy = struct.unpack('f', fin.read(4))[0]
49
+ ox = struct.unpack('f', fin.read(4))[0]
50
+ oy = struct.unpack('f', fin.read(4))[0]
51
+
52
+ return edges_extent, edges_mask, nodes_extent, nodes_mask, fx, fy, ox, oy
53
+
54
+ def load_adja_id_info(filename, src_mask, H, W, num_adja, num_neigb):
55
+
56
+ assert os.path.isfile(filename), "File not found: {}".format(filename)
57
+ assert num_adja<=8, "Num of adja is larger than 8"
58
+ assert num_neigb<=8, "Num of neighb is larger than 8"
59
+ src_v_id = np.zeros((H*W, num_adja), dtype=np.int16)
60
+ src_neigb_id = np.zeros((H*W, num_neigb), dtype=np.int32)
61
+ with open(filename, 'rb') as fin:
62
+ neigb_id, value_id = pickle.load(fin)
63
+ assert((src_mask.sum())==value_id.shape[0])
64
+
65
+ for i in range(num_adja):
66
+ src_v_id[src_mask.reshape(-1), i] = value_id[:, i]
67
+ for i in range(num_neigb):
68
+ src_neigb_id[src_mask.reshape(-1), i] = neigb_id[:, i]
69
+ src_v_id = src_v_id.transpose().reshape(num_adja, H, W)
70
+ src_neigb_id = src_neigb_id.transpose().reshape(num_neigb, H, W)
71
+
72
+ return src_v_id, src_neigb_id
73
+
74
+ def save_flow(filename, flow_input):
75
+ flow = np.copy(flow_input)
76
+
77
+ # Flow is stored row-wise in order [channels, height, width].
78
+ assert len(flow.shape) == 3
79
+
80
+ with open(filename, 'wb') as fout:
81
+ fout.write(struct.pack('I', flow.shape[2]))
82
+ fout.write(struct.pack('I', flow.shape[1]))
83
+ fout.write(struct.pack('I', flow.shape[0]))
84
+ fout.write(struct.pack('={}f'.format(flow.size), *flow.flatten("C")))
data_utils/blendshape_capture/face_landmarker.task ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:64184e229b263107bc2b804c6625db1341ff2bb731874b0bcc2fe6544e0bc9ff
3
+ size 3758596
data_utils/blendshape_capture/main.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*-coding:utf-8-*-
2
+ import argparse
3
+ import os
4
+ import random
5
+ import numpy as np
6
+ import cv2
7
+ import glob
8
+ from mediapipe import solutions
9
+ from mediapipe.framework.formats import landmark_pb2
10
+ import numpy as np
11
+ import matplotlib.pyplot as plt
12
+ import torch
13
+ import torch.nn as nn
14
+ from scipy.signal import savgol_filter
15
+ import onnxruntime as ort
16
+ from collections import OrderedDict
17
+ import mediapipe as mp
18
+ from mediapipe.tasks import python
19
+ from mediapipe.tasks.python import vision
20
+
21
+
22
+ from tqdm import tqdm
23
+
24
+
25
+ def infer_bs(root_path):
26
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
27
+ base_options = python.BaseOptions(model_asset_path="./data_utils/blendshape_capture/face_landmarker.task")
28
+ options = vision.FaceLandmarkerOptions(base_options=base_options,
29
+ output_face_blendshapes=True,
30
+ output_facial_transformation_matrixes=True,
31
+ num_faces=1)
32
+ detector = vision.FaceLandmarker.create_from_options(options)
33
+
34
+ for i in os.listdir(root_path):
35
+ if i.endswith(".mp4"):
36
+ mp4_path = os.path.join(root_path, i)
37
+ npy_path = os.path.join(root_path, "bs.npy")
38
+ if os.path.exists(npy_path):
39
+ print("npy file exists:", i.split(".")[0])
40
+ continue
41
+ else:
42
+ print("npy file not exists:", i.split(".")[0])
43
+ image_path = os.path.join(root_path, "img/temp.png")
44
+ os.makedirs(os.path.join(root_path, 'img/'), exist_ok=True)
45
+ cap = cv2.VideoCapture(mp4_path)
46
+ fps = cap.get(cv2.CAP_PROP_FPS)
47
+ frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
48
+ print("fps:", fps)
49
+ print("frame_count:", frame_count)
50
+ k = 0
51
+ total = frame_count
52
+ bs = np.zeros((int(total), 52), dtype=np.float32)
53
+ print("total:", total)
54
+ print("videoPath:{} fps:{},k".format(mp4_path.split('/')[-1], fps))
55
+ pbar = tqdm(total=int(total))
56
+ while (cap.isOpened()):
57
+ ret, frame = cap.read()
58
+ if ret:
59
+ cv2.imwrite(image_path, frame)
60
+ image = mp.Image.create_from_file(image_path)
61
+ result = detector.detect(image)
62
+ face_blendshapes_scores = [face_blendshapes_category.score for face_blendshapes_category in
63
+ result.face_blendshapes[0]]
64
+ blendshape_coef = np.array(face_blendshapes_scores)[1:]
65
+ blendshape_coef = np.append(blendshape_coef, 0)
66
+ bs[k] = blendshape_coef
67
+ pbar.update(1)
68
+ k += 1
69
+ else:
70
+ break
71
+ cap.release()
72
+ pbar.close()
73
+ # np.save(npy_path, bs)
74
+ # print(np.shape(bs))
75
+ output = np.zeros((bs.shape[0], bs.shape[1]))
76
+ for j in range(bs.shape[1]):
77
+ output[:, j] = savgol_filter(bs[:, j], 5, 3)
78
+ np.save(npy_path, output)
79
+ print(np.shape(output))
80
+
81
+
82
+ if __name__ == '__main__':
83
+ parser = argparse.ArgumentParser()
84
+ parser.add_argument("--path", type=str, help="idname of target person")
85
+ args = parser.parse_args()
86
+ infer_bs(args.path)
data_utils/deepspeech_features/README.md ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Routines for DeepSpeech features processing
2
+ Several routines for [DeepSpeech](https://github.com/mozilla/DeepSpeech) features processing, like speech features generation for [VOCA](https://github.com/TimoBolkart/voca) model.
3
+
4
+ ## Installation
5
+
6
+ ```
7
+ pip3 install -r requirements.txt
8
+ ```
9
+
10
+ ## Usage
11
+
12
+ Generate wav files:
13
+ ```
14
+ python3 extract_wav.py --in-video=<you_data_dir>
15
+ ```
16
+
17
+ Generate files with DeepSpeech features:
18
+ ```
19
+ python3 extract_ds_features.py --input=<you_data_dir>
20
+ ```
data_utils/deepspeech_features/deepspeech_features.py ADDED
@@ -0,0 +1,275 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ DeepSpeech features processing routines.
3
+ NB: Based on VOCA code. See the corresponding license restrictions.
4
+ """
5
+
6
+ __all__ = ['conv_audios_to_deepspeech']
7
+
8
+ import numpy as np
9
+ import warnings
10
+ import resampy
11
+ from scipy.io import wavfile
12
+ from python_speech_features import mfcc
13
+ import tensorflow.compat.v1 as tf
14
+ tf.disable_v2_behavior()
15
+
16
+ def conv_audios_to_deepspeech(audios,
17
+ out_files,
18
+ num_frames_info,
19
+ deepspeech_pb_path,
20
+ audio_window_size=1,
21
+ audio_window_stride=1):
22
+ """
23
+ Convert list of audio files into files with DeepSpeech features.
24
+
25
+ Parameters
26
+ ----------
27
+ audios : list of str or list of None
28
+ Paths to input audio files.
29
+ out_files : list of str
30
+ Paths to output files with DeepSpeech features.
31
+ num_frames_info : list of int
32
+ List of numbers of frames.
33
+ deepspeech_pb_path : str
34
+ Path to DeepSpeech 0.1.0 frozen model.
35
+ audio_window_size : int, default 16
36
+ Audio window size.
37
+ audio_window_stride : int, default 1
38
+ Audio window stride.
39
+ """
40
+ # deepspeech_pb_path="/disk4/keyu/DeepSpeech/deepspeech-0.9.2-models.pbmm"
41
+ graph, logits_ph, input_node_ph, input_lengths_ph = prepare_deepspeech_net(
42
+ deepspeech_pb_path)
43
+
44
+ with tf.compat.v1.Session(graph=graph) as sess:
45
+ for audio_file_path, out_file_path, num_frames in zip(audios, out_files, num_frames_info):
46
+ print(audio_file_path)
47
+ print(out_file_path)
48
+ audio_sample_rate, audio = wavfile.read(audio_file_path)
49
+ if audio.ndim != 1:
50
+ warnings.warn(
51
+ "Audio has multiple channels, the first channel is used")
52
+ audio = audio[:, 0]
53
+ ds_features = pure_conv_audio_to_deepspeech(
54
+ audio=audio,
55
+ audio_sample_rate=audio_sample_rate,
56
+ audio_window_size=audio_window_size,
57
+ audio_window_stride=audio_window_stride,
58
+ num_frames=num_frames,
59
+ net_fn=lambda x: sess.run(
60
+ logits_ph,
61
+ feed_dict={
62
+ input_node_ph: x[np.newaxis, ...],
63
+ input_lengths_ph: [x.shape[0]]}))
64
+
65
+ net_output = ds_features.reshape(-1, 29)
66
+ win_size = 16
67
+ zero_pad = np.zeros((int(win_size / 2), net_output.shape[1]))
68
+ net_output = np.concatenate(
69
+ (zero_pad, net_output, zero_pad), axis=0)
70
+ windows = []
71
+ for window_index in range(0, net_output.shape[0] - win_size, 2):
72
+ windows.append(
73
+ net_output[window_index:window_index + win_size])
74
+ print(np.array(windows).shape)
75
+ np.save(out_file_path, np.array(windows))
76
+
77
+
78
+ def prepare_deepspeech_net(deepspeech_pb_path):
79
+ """
80
+ Load and prepare DeepSpeech network.
81
+
82
+ Parameters
83
+ ----------
84
+ deepspeech_pb_path : str
85
+ Path to DeepSpeech 0.1.0 frozen model.
86
+
87
+ Returns
88
+ -------
89
+ graph : obj
90
+ ThensorFlow graph.
91
+ logits_ph : obj
92
+ ThensorFlow placeholder for `logits`.
93
+ input_node_ph : obj
94
+ ThensorFlow placeholder for `input_node`.
95
+ input_lengths_ph : obj
96
+ ThensorFlow placeholder for `input_lengths`.
97
+ """
98
+ # Load graph and place_holders:
99
+ with tf.io.gfile.GFile(deepspeech_pb_path, "rb") as f:
100
+ graph_def = tf.compat.v1.GraphDef()
101
+ graph_def.ParseFromString(f.read())
102
+
103
+ graph = tf.compat.v1.get_default_graph()
104
+ tf.import_graph_def(graph_def, name="deepspeech")
105
+ logits_ph = graph.get_tensor_by_name("deepspeech/logits:0")
106
+ input_node_ph = graph.get_tensor_by_name("deepspeech/input_node:0")
107
+ input_lengths_ph = graph.get_tensor_by_name("deepspeech/input_lengths:0")
108
+
109
+ return graph, logits_ph, input_node_ph, input_lengths_ph
110
+
111
+
112
+ def pure_conv_audio_to_deepspeech(audio,
113
+ audio_sample_rate,
114
+ audio_window_size,
115
+ audio_window_stride,
116
+ num_frames,
117
+ net_fn):
118
+ """
119
+ Core routine for converting audion into DeepSpeech features.
120
+
121
+ Parameters
122
+ ----------
123
+ audio : np.array
124
+ Audio data.
125
+ audio_sample_rate : int
126
+ Audio sample rate.
127
+ audio_window_size : int
128
+ Audio window size.
129
+ audio_window_stride : int
130
+ Audio window stride.
131
+ num_frames : int or None
132
+ Numbers of frames.
133
+ net_fn : func
134
+ Function for DeepSpeech model call.
135
+
136
+ Returns
137
+ -------
138
+ np.array
139
+ DeepSpeech features.
140
+ """
141
+ target_sample_rate = 16000
142
+ if audio_sample_rate != target_sample_rate:
143
+ resampled_audio = resampy.resample(
144
+ x=audio.astype(np.float),
145
+ sr_orig=audio_sample_rate,
146
+ sr_new=target_sample_rate)
147
+ else:
148
+ resampled_audio = audio.astype(np.float32)
149
+ input_vector = conv_audio_to_deepspeech_input_vector(
150
+ audio=resampled_audio.astype(np.int16),
151
+ sample_rate=target_sample_rate,
152
+ num_cepstrum=26,
153
+ num_context=9)
154
+
155
+ network_output = net_fn(input_vector)
156
+ # print(network_output.shape)
157
+
158
+ deepspeech_fps = 50
159
+ video_fps = 50 # Change this option if video fps is different
160
+ audio_len_s = float(audio.shape[0]) / audio_sample_rate
161
+ if num_frames is None:
162
+ num_frames = int(round(audio_len_s * video_fps))
163
+ else:
164
+ video_fps = num_frames / audio_len_s
165
+ network_output = interpolate_features(
166
+ features=network_output[:, 0],
167
+ input_rate=deepspeech_fps,
168
+ output_rate=video_fps,
169
+ output_len=num_frames)
170
+
171
+ # Make windows:
172
+ zero_pad = np.zeros((int(audio_window_size / 2), network_output.shape[1]))
173
+ network_output = np.concatenate(
174
+ (zero_pad, network_output, zero_pad), axis=0)
175
+ windows = []
176
+ for window_index in range(0, network_output.shape[0] - audio_window_size, audio_window_stride):
177
+ windows.append(
178
+ network_output[window_index:window_index + audio_window_size])
179
+
180
+ return np.array(windows)
181
+
182
+
183
+ def conv_audio_to_deepspeech_input_vector(audio,
184
+ sample_rate,
185
+ num_cepstrum,
186
+ num_context):
187
+ """
188
+ Convert audio raw data into DeepSpeech input vector.
189
+
190
+ Parameters
191
+ ----------
192
+ audio : np.array
193
+ Audio data.
194
+ audio_sample_rate : int
195
+ Audio sample rate.
196
+ num_cepstrum : int
197
+ Number of cepstrum.
198
+ num_context : int
199
+ Number of context.
200
+
201
+ Returns
202
+ -------
203
+ np.array
204
+ DeepSpeech input vector.
205
+ """
206
+ # Get mfcc coefficients:
207
+ features = mfcc(
208
+ signal=audio,
209
+ samplerate=sample_rate,
210
+ numcep=num_cepstrum)
211
+
212
+ # We only keep every second feature (BiRNN stride = 2):
213
+ features = features[::2]
214
+
215
+ # One stride per time step in the input:
216
+ num_strides = len(features)
217
+
218
+ # Add empty initial and final contexts:
219
+ empty_context = np.zeros((num_context, num_cepstrum), dtype=features.dtype)
220
+ features = np.concatenate((empty_context, features, empty_context))
221
+
222
+ # Create a view into the array with overlapping strides of size
223
+ # numcontext (past) + 1 (present) + numcontext (future):
224
+ window_size = 2 * num_context + 1
225
+ train_inputs = np.lib.stride_tricks.as_strided(
226
+ features,
227
+ shape=(num_strides, window_size, num_cepstrum),
228
+ strides=(features.strides[0],
229
+ features.strides[0], features.strides[1]),
230
+ writeable=False)
231
+
232
+ # Flatten the second and third dimensions:
233
+ train_inputs = np.reshape(train_inputs, [num_strides, -1])
234
+
235
+ train_inputs = np.copy(train_inputs)
236
+ train_inputs = (train_inputs - np.mean(train_inputs)) / \
237
+ np.std(train_inputs)
238
+
239
+ return train_inputs
240
+
241
+
242
+ def interpolate_features(features,
243
+ input_rate,
244
+ output_rate,
245
+ output_len):
246
+ """
247
+ Interpolate DeepSpeech features.
248
+
249
+ Parameters
250
+ ----------
251
+ features : np.array
252
+ DeepSpeech features.
253
+ input_rate : int
254
+ input rate (FPS).
255
+ output_rate : int
256
+ Output rate (FPS).
257
+ output_len : int
258
+ Output data length.
259
+
260
+ Returns
261
+ -------
262
+ np.array
263
+ Interpolated data.
264
+ """
265
+ input_len = features.shape[0]
266
+ num_features = features.shape[1]
267
+ input_timestamps = np.arange(input_len) / float(input_rate)
268
+ output_timestamps = np.arange(output_len) / float(output_rate)
269
+ output_features = np.zeros((output_len, num_features))
270
+ for feature_idx in range(num_features):
271
+ output_features[:, feature_idx] = np.interp(
272
+ x=output_timestamps,
273
+ xp=input_timestamps,
274
+ fp=features[:, feature_idx])
275
+ return output_features
data_utils/deepspeech_features/deepspeech_store.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Routines for loading DeepSpeech model.
3
+ """
4
+
5
+ __all__ = ['get_deepspeech_model_file']
6
+
7
+ import os
8
+ import zipfile
9
+ import logging
10
+ import hashlib
11
+
12
+
13
+ deepspeech_features_repo_url = 'https://github.com/osmr/deepspeech_features'
14
+
15
+
16
+ def get_deepspeech_model_file(local_model_store_dir_path=os.path.join("~", ".tensorflow", "models")):
17
+ """
18
+ Return location for the pretrained on local file system. This function will download from online model zoo when
19
+ model cannot be found or has mismatch. The root directory will be created if it doesn't exist.
20
+
21
+ Parameters
22
+ ----------
23
+ local_model_store_dir_path : str, default $TENSORFLOW_HOME/models
24
+ Location for keeping the model parameters.
25
+
26
+ Returns
27
+ -------
28
+ file_path
29
+ Path to the requested pretrained model file.
30
+ """
31
+ sha1_hash = "b90017e816572ddce84f5843f1fa21e6a377975e"
32
+ file_name = "deepspeech-0_1_0-b90017e8.pb"
33
+ local_model_store_dir_path = os.path.expanduser(local_model_store_dir_path)
34
+ file_path = os.path.join(local_model_store_dir_path, file_name)
35
+ if os.path.exists(file_path):
36
+ if _check_sha1(file_path, sha1_hash):
37
+ return file_path
38
+ else:
39
+ logging.warning("Mismatch in the content of model file detected. Downloading again.")
40
+ else:
41
+ logging.info("Model file not found. Downloading to {}.".format(file_path))
42
+
43
+ if not os.path.exists(local_model_store_dir_path):
44
+ os.makedirs(local_model_store_dir_path)
45
+
46
+ zip_file_path = file_path + ".zip"
47
+ _download(
48
+ url="{repo_url}/releases/download/{repo_release_tag}/{file_name}.zip".format(
49
+ repo_url=deepspeech_features_repo_url,
50
+ repo_release_tag="v0.0.1",
51
+ file_name=file_name),
52
+ path=zip_file_path,
53
+ overwrite=True)
54
+ with zipfile.ZipFile(zip_file_path) as zf:
55
+ zf.extractall(local_model_store_dir_path)
56
+ os.remove(zip_file_path)
57
+
58
+ if _check_sha1(file_path, sha1_hash):
59
+ return file_path
60
+ else:
61
+ raise ValueError("Downloaded file has different hash. Please try again.")
62
+
63
+
64
+ def _download(url, path=None, overwrite=False, sha1_hash=None, retries=5, verify_ssl=True):
65
+ """
66
+ Download an given URL
67
+
68
+ Parameters
69
+ ----------
70
+ url : str
71
+ URL to download
72
+ path : str, optional
73
+ Destination path to store downloaded file. By default stores to the
74
+ current directory with same name as in url.
75
+ overwrite : bool, optional
76
+ Whether to overwrite destination file if already exists.
77
+ sha1_hash : str, optional
78
+ Expected sha1 hash in hexadecimal digits. Will ignore existing file when hash is specified
79
+ but doesn't match.
80
+ retries : integer, default 5
81
+ The number of times to attempt the download in case of failure or non 200 return codes
82
+ verify_ssl : bool, default True
83
+ Verify SSL certificates.
84
+
85
+ Returns
86
+ -------
87
+ str
88
+ The file path of the downloaded file.
89
+ """
90
+ import warnings
91
+ try:
92
+ import requests
93
+ except ImportError:
94
+ class requests_failed_to_import(object):
95
+ pass
96
+ requests = requests_failed_to_import
97
+
98
+ if path is None:
99
+ fname = url.split("/")[-1]
100
+ # Empty filenames are invalid
101
+ assert fname, "Can't construct file-name from this URL. Please set the `path` option manually."
102
+ else:
103
+ path = os.path.expanduser(path)
104
+ if os.path.isdir(path):
105
+ fname = os.path.join(path, url.split("/")[-1])
106
+ else:
107
+ fname = path
108
+ assert retries >= 0, "Number of retries should be at least 0"
109
+
110
+ if not verify_ssl:
111
+ warnings.warn(
112
+ "Unverified HTTPS request is being made (verify_ssl=False). "
113
+ "Adding certificate verification is strongly advised.")
114
+
115
+ if overwrite or not os.path.exists(fname) or (sha1_hash and not _check_sha1(fname, sha1_hash)):
116
+ dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname)))
117
+ if not os.path.exists(dirname):
118
+ os.makedirs(dirname)
119
+ while retries + 1 > 0:
120
+ # Disable pyling too broad Exception
121
+ # pylint: disable=W0703
122
+ try:
123
+ print("Downloading {} from {}...".format(fname, url))
124
+ r = requests.get(url, stream=True, verify=verify_ssl)
125
+ if r.status_code != 200:
126
+ raise RuntimeError("Failed downloading url {}".format(url))
127
+ with open(fname, "wb") as f:
128
+ for chunk in r.iter_content(chunk_size=1024):
129
+ if chunk: # filter out keep-alive new chunks
130
+ f.write(chunk)
131
+ if sha1_hash and not _check_sha1(fname, sha1_hash):
132
+ raise UserWarning("File {} is downloaded but the content hash does not match."
133
+ " The repo may be outdated or download may be incomplete. "
134
+ "If the `repo_url` is overridden, consider switching to "
135
+ "the default repo.".format(fname))
136
+ break
137
+ except Exception as e:
138
+ retries -= 1
139
+ if retries <= 0:
140
+ raise e
141
+ else:
142
+ print("download failed, retrying, {} attempt{} left"
143
+ .format(retries, "s" if retries > 1 else ""))
144
+
145
+ return fname
146
+
147
+
148
+ def _check_sha1(filename, sha1_hash):
149
+ """
150
+ Check whether the sha1 hash of the file content matches the expected hash.
151
+
152
+ Parameters
153
+ ----------
154
+ filename : str
155
+ Path to the file.
156
+ sha1_hash : str
157
+ Expected sha1 hash in hexadecimal digits.
158
+
159
+ Returns
160
+ -------
161
+ bool
162
+ Whether the file content matches the expected hash.
163
+ """
164
+ sha1 = hashlib.sha1()
165
+ with open(filename, "rb") as f:
166
+ while True:
167
+ data = f.read(1048576)
168
+ if not data:
169
+ break
170
+ sha1.update(data)
171
+
172
+ return sha1.hexdigest() == sha1_hash
data_utils/deepspeech_features/extract_ds_features.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Script for extracting DeepSpeech features from audio file.
3
+ """
4
+
5
+ import os
6
+ import argparse
7
+ import numpy as np
8
+ import pandas as pd
9
+ from deepspeech_store import get_deepspeech_model_file
10
+ from deepspeech_features import conv_audios_to_deepspeech
11
+
12
+
13
+ def parse_args():
14
+ """
15
+ Create python script parameters.
16
+ Returns
17
+ -------
18
+ ArgumentParser
19
+ Resulted args.
20
+ """
21
+ parser = argparse.ArgumentParser(
22
+ description="Extract DeepSpeech features from audio file",
23
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter)
24
+ parser.add_argument(
25
+ "--input",
26
+ type=str,
27
+ required=True,
28
+ help="path to input audio file or directory")
29
+ parser.add_argument(
30
+ "--output",
31
+ type=str,
32
+ help="path to output file with DeepSpeech features")
33
+ parser.add_argument(
34
+ "--deepspeech",
35
+ type=str,
36
+ help="path to DeepSpeech 0.1.0 frozen model")
37
+ parser.add_argument(
38
+ "--metainfo",
39
+ type=str,
40
+ help="path to file with meta-information")
41
+
42
+ args = parser.parse_args()
43
+ return args
44
+
45
+
46
+ def extract_features(in_audios,
47
+ out_files,
48
+ deepspeech_pb_path,
49
+ metainfo_file_path=None):
50
+ """
51
+ Real extract audio from video file.
52
+ Parameters
53
+ ----------
54
+ in_audios : list of str
55
+ Paths to input audio files.
56
+ out_files : list of str
57
+ Paths to output files with DeepSpeech features.
58
+ deepspeech_pb_path : str
59
+ Path to DeepSpeech 0.1.0 frozen model.
60
+ metainfo_file_path : str, default None
61
+ Path to file with meta-information.
62
+ """
63
+ #deepspeech_pb_path="/disk4/keyu/DeepSpeech/deepspeech-0.9.2-models.pbmm"
64
+ if metainfo_file_path is None:
65
+ num_frames_info = [None] * len(in_audios)
66
+ else:
67
+ train_df = pd.read_csv(
68
+ metainfo_file_path,
69
+ sep="\t",
70
+ index_col=False,
71
+ dtype={"Id": np.int, "File": np.unicode, "Count": np.int})
72
+ num_frames_info = train_df["Count"].values
73
+ assert (len(num_frames_info) == len(in_audios))
74
+
75
+ for i, in_audio in enumerate(in_audios):
76
+ if not out_files[i]:
77
+ file_stem, _ = os.path.splitext(in_audio)
78
+ out_files[i] = file_stem + "_ds.npy"
79
+ #print(out_files[i])
80
+ conv_audios_to_deepspeech(
81
+ audios=in_audios,
82
+ out_files=out_files,
83
+ num_frames_info=num_frames_info,
84
+ deepspeech_pb_path=deepspeech_pb_path)
85
+
86
+
87
+ def main():
88
+ """
89
+ Main body of script.
90
+ """
91
+ args = parse_args()
92
+ in_audio = os.path.expanduser(args.input)
93
+ if not os.path.exists(in_audio):
94
+ raise Exception("Input file/directory doesn't exist: {}".format(in_audio))
95
+ deepspeech_pb_path = args.deepspeech
96
+ #add
97
+ deepspeech_pb_path = True
98
+ args.deepspeech = '~/.tensorflow/models/deepspeech-0_1_0-b90017e8.pb'
99
+ #deepspeech_pb_path="/disk4/keyu/DeepSpeech/deepspeech-0.9.2-models.pbmm"
100
+ if deepspeech_pb_path is None:
101
+ deepspeech_pb_path = ""
102
+ if deepspeech_pb_path:
103
+ deepspeech_pb_path = os.path.expanduser(args.deepspeech)
104
+ if not os.path.exists(deepspeech_pb_path):
105
+ deepspeech_pb_path = get_deepspeech_model_file()
106
+ if os.path.isfile(in_audio):
107
+ extract_features(
108
+ in_audios=[in_audio],
109
+ out_files=[args.output],
110
+ deepspeech_pb_path=deepspeech_pb_path,
111
+ metainfo_file_path=args.metainfo)
112
+ else:
113
+ audio_file_paths = []
114
+ for file_name in os.listdir(in_audio):
115
+ if not os.path.isfile(os.path.join(in_audio, file_name)):
116
+ continue
117
+ _, file_ext = os.path.splitext(file_name)
118
+ if file_ext.lower() == ".wav":
119
+ audio_file_path = os.path.join(in_audio, file_name)
120
+ audio_file_paths.append(audio_file_path)
121
+ audio_file_paths = sorted(audio_file_paths)
122
+ out_file_paths = [""] * len(audio_file_paths)
123
+ extract_features(
124
+ in_audios=audio_file_paths,
125
+ out_files=out_file_paths,
126
+ deepspeech_pb_path=deepspeech_pb_path,
127
+ metainfo_file_path=args.metainfo)
128
+
129
+
130
+ if __name__ == "__main__":
131
+ main()
132
+
data_utils/deepspeech_features/extract_wav.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Script for extracting audio (16-bit, mono, 22000 Hz) from video file.
3
+ """
4
+
5
+ import os
6
+ import argparse
7
+ import subprocess
8
+
9
+
10
+ def parse_args():
11
+ """
12
+ Create python script parameters.
13
+
14
+ Returns
15
+ -------
16
+ ArgumentParser
17
+ Resulted args.
18
+ """
19
+ parser = argparse.ArgumentParser(
20
+ description="Extract audio from video file",
21
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter)
22
+ parser.add_argument(
23
+ "--in-video",
24
+ type=str,
25
+ required=True,
26
+ help="path to input video file or directory")
27
+ parser.add_argument(
28
+ "--out-audio",
29
+ type=str,
30
+ help="path to output audio file")
31
+
32
+ args = parser.parse_args()
33
+ return args
34
+
35
+
36
+ def extract_audio(in_video,
37
+ out_audio):
38
+ """
39
+ Real extract audio from video file.
40
+
41
+ Parameters
42
+ ----------
43
+ in_video : str
44
+ Path to input video file.
45
+ out_audio : str
46
+ Path to output audio file.
47
+ """
48
+ if not out_audio:
49
+ file_stem, _ = os.path.splitext(in_video)
50
+ out_audio = file_stem + ".wav"
51
+ # command1 = "ffmpeg -i {in_video} -vn -acodec copy {aac_audio}"
52
+ # command2 = "ffmpeg -i {aac_audio} -vn -acodec pcm_s16le -ac 1 -ar 22000 {out_audio}"
53
+ # command = "ffmpeg -i {in_video} -vn -acodec pcm_s16le -ac 1 -ar 22000 {out_audio}"
54
+ command = "ffmpeg -i {in_video} -vn -acodec pcm_s16le -ac 1 -ar 16000 {out_audio}"
55
+ subprocess.call([command.format(in_video=in_video, out_audio=out_audio)], shell=True)
56
+
57
+
58
+ def main():
59
+ """
60
+ Main body of script.
61
+ """
62
+ args = parse_args()
63
+ in_video = os.path.expanduser(args.in_video)
64
+ if not os.path.exists(in_video):
65
+ raise Exception("Input file/directory doesn't exist: {}".format(in_video))
66
+ if os.path.isfile(in_video):
67
+ extract_audio(
68
+ in_video=in_video,
69
+ out_audio=args.out_audio)
70
+ else:
71
+ video_file_paths = []
72
+ for file_name in os.listdir(in_video):
73
+ if not os.path.isfile(os.path.join(in_video, file_name)):
74
+ continue
75
+ _, file_ext = os.path.splitext(file_name)
76
+ if file_ext.lower() in (".mp4", ".mkv", ".avi"):
77
+ video_file_path = os.path.join(in_video, file_name)
78
+ video_file_paths.append(video_file_path)
79
+ video_file_paths = sorted(video_file_paths)
80
+ for video_file_path in video_file_paths:
81
+ extract_audio(
82
+ in_video=video_file_path,
83
+ out_audio="")
84
+
85
+
86
+ if __name__ == "__main__":
87
+ main()
data_utils/deepspeech_features/fea_win.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+ net_output = np.load('french.ds.npy').reshape(-1, 29)
4
+ win_size = 16
5
+ zero_pad = np.zeros((int(win_size / 2), net_output.shape[1]))
6
+ net_output = np.concatenate((zero_pad, net_output, zero_pad), axis=0)
7
+ windows = []
8
+ for window_index in range(0, net_output.shape[0] - win_size, 2):
9
+ windows.append(net_output[window_index:window_index + win_size])
10
+ print(np.array(windows).shape)
11
+ np.save('aud_french.npy', np.array(windows))
data_utils/face_parsing/79999_iter.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:468e13ca13a9b43cc0881a9f99083a430e9c0a38abd935431d1c28ee94b26567
3
+ size 53289463
data_utils/face_parsing/logger.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python
2
+ # -*- encoding: utf-8 -*-
3
+
4
+
5
+ import os.path as osp
6
+ import time
7
+ import sys
8
+ import logging
9
+
10
+ import torch.distributed as dist
11
+
12
+
13
+ def setup_logger(logpth):
14
+ logfile = 'BiSeNet-{}.log'.format(time.strftime('%Y-%m-%d-%H-%M-%S'))
15
+ logfile = osp.join(logpth, logfile)
16
+ FORMAT = '%(levelname)s %(filename)s(%(lineno)d): %(message)s'
17
+ log_level = logging.INFO
18
+ if dist.is_initialized() and not dist.get_rank()==0:
19
+ log_level = logging.ERROR
20
+ logging.basicConfig(level=log_level, format=FORMAT, filename=logfile)
21
+ logging.root.addHandler(logging.StreamHandler())
22
+
23
+
data_utils/face_parsing/model.py ADDED
@@ -0,0 +1,285 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python
2
+ # -*- encoding: utf-8 -*-
3
+
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ import torchvision
9
+
10
+ from resnet import Resnet18
11
+ # from modules.bn import InPlaceABNSync as BatchNorm2d
12
+
13
+
14
+ class ConvBNReLU(nn.Module):
15
+ def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1, *args, **kwargs):
16
+ super(ConvBNReLU, self).__init__()
17
+ self.conv = nn.Conv2d(in_chan,
18
+ out_chan,
19
+ kernel_size = ks,
20
+ stride = stride,
21
+ padding = padding,
22
+ bias = False)
23
+ self.bn = nn.BatchNorm2d(out_chan)
24
+ self.init_weight()
25
+
26
+ def forward(self, x):
27
+ x = self.conv(x)
28
+ x = F.relu(self.bn(x))
29
+ return x
30
+
31
+ def init_weight(self):
32
+ for ly in self.children():
33
+ if isinstance(ly, nn.Conv2d):
34
+ nn.init.kaiming_normal_(ly.weight, a=1)
35
+ if not ly.bias is None: nn.init.constant_(ly.bias, 0)
36
+
37
+ class BiSeNetOutput(nn.Module):
38
+ def __init__(self, in_chan, mid_chan, n_classes, *args, **kwargs):
39
+ super(BiSeNetOutput, self).__init__()
40
+ self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1)
41
+ self.conv_out = nn.Conv2d(mid_chan, n_classes, kernel_size=1, bias=False)
42
+ self.init_weight()
43
+
44
+ def forward(self, x):
45
+ x = self.conv(x)
46
+ x = self.conv_out(x)
47
+ return x
48
+
49
+ def init_weight(self):
50
+ for ly in self.children():
51
+ if isinstance(ly, nn.Conv2d):
52
+ nn.init.kaiming_normal_(ly.weight, a=1)
53
+ if not ly.bias is None: nn.init.constant_(ly.bias, 0)
54
+
55
+ def get_params(self):
56
+ wd_params, nowd_params = [], []
57
+ for name, module in self.named_modules():
58
+ if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
59
+ wd_params.append(module.weight)
60
+ if not module.bias is None:
61
+ nowd_params.append(module.bias)
62
+ elif isinstance(module, nn.BatchNorm2d):
63
+ nowd_params += list(module.parameters())
64
+ return wd_params, nowd_params
65
+
66
+
67
+ class AttentionRefinementModule(nn.Module):
68
+ def __init__(self, in_chan, out_chan, *args, **kwargs):
69
+ super(AttentionRefinementModule, self).__init__()
70
+ self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1)
71
+ self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size= 1, bias=False)
72
+ self.bn_atten = nn.BatchNorm2d(out_chan)
73
+ self.sigmoid_atten = nn.Sigmoid()
74
+ self.init_weight()
75
+
76
+ def forward(self, x):
77
+ feat = self.conv(x)
78
+ atten = F.avg_pool2d(feat, feat.size()[2:])
79
+ atten = self.conv_atten(atten)
80
+ atten = self.bn_atten(atten)
81
+ atten = self.sigmoid_atten(atten)
82
+ out = torch.mul(feat, atten)
83
+ return out
84
+
85
+ def init_weight(self):
86
+ for ly in self.children():
87
+ if isinstance(ly, nn.Conv2d):
88
+ nn.init.kaiming_normal_(ly.weight, a=1)
89
+ if not ly.bias is None: nn.init.constant_(ly.bias, 0)
90
+
91
+
92
+ class ContextPath(nn.Module):
93
+ def __init__(self, *args, **kwargs):
94
+ super(ContextPath, self).__init__()
95
+ self.resnet = Resnet18()
96
+ self.arm16 = AttentionRefinementModule(256, 128)
97
+ self.arm32 = AttentionRefinementModule(512, 128)
98
+ self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
99
+ self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
100
+ self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0)
101
+
102
+ self.init_weight()
103
+
104
+ def forward(self, x):
105
+ H0, W0 = x.size()[2:]
106
+ feat8, feat16, feat32 = self.resnet(x)
107
+ H8, W8 = feat8.size()[2:]
108
+ H16, W16 = feat16.size()[2:]
109
+ H32, W32 = feat32.size()[2:]
110
+
111
+ avg = F.avg_pool2d(feat32, feat32.size()[2:])
112
+ avg = self.conv_avg(avg)
113
+ avg_up = F.interpolate(avg, (H32, W32), mode='nearest')
114
+
115
+ feat32_arm = self.arm32(feat32)
116
+ feat32_sum = feat32_arm + avg_up
117
+ feat32_up = F.interpolate(feat32_sum, (H16, W16), mode='nearest')
118
+ feat32_up = self.conv_head32(feat32_up)
119
+
120
+ feat16_arm = self.arm16(feat16)
121
+ feat16_sum = feat16_arm + feat32_up
122
+ feat16_up = F.interpolate(feat16_sum, (H8, W8), mode='nearest')
123
+ feat16_up = self.conv_head16(feat16_up)
124
+
125
+ return feat8, feat16_up, feat32_up # x8, x8, x16
126
+
127
+ def init_weight(self):
128
+ for ly in self.children():
129
+ if isinstance(ly, nn.Conv2d):
130
+ nn.init.kaiming_normal_(ly.weight, a=1)
131
+ if not ly.bias is None: nn.init.constant_(ly.bias, 0)
132
+
133
+ def get_params(self):
134
+ wd_params, nowd_params = [], []
135
+ for name, module in self.named_modules():
136
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
137
+ wd_params.append(module.weight)
138
+ if not module.bias is None:
139
+ nowd_params.append(module.bias)
140
+ elif isinstance(module, nn.BatchNorm2d):
141
+ nowd_params += list(module.parameters())
142
+ return wd_params, nowd_params
143
+
144
+
145
+ ### This is not used, since I replace this with the resnet feature with the same size
146
+ class SpatialPath(nn.Module):
147
+ def __init__(self, *args, **kwargs):
148
+ super(SpatialPath, self).__init__()
149
+ self.conv1 = ConvBNReLU(3, 64, ks=7, stride=2, padding=3)
150
+ self.conv2 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
151
+ self.conv3 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
152
+ self.conv_out = ConvBNReLU(64, 128, ks=1, stride=1, padding=0)
153
+ self.init_weight()
154
+
155
+ def forward(self, x):
156
+ feat = self.conv1(x)
157
+ feat = self.conv2(feat)
158
+ feat = self.conv3(feat)
159
+ feat = self.conv_out(feat)
160
+ return feat
161
+
162
+ def init_weight(self):
163
+ for ly in self.children():
164
+ if isinstance(ly, nn.Conv2d):
165
+ nn.init.kaiming_normal_(ly.weight, a=1)
166
+ if not ly.bias is None: nn.init.constant_(ly.bias, 0)
167
+
168
+ def get_params(self):
169
+ wd_params, nowd_params = [], []
170
+ for name, module in self.named_modules():
171
+ if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
172
+ wd_params.append(module.weight)
173
+ if not module.bias is None:
174
+ nowd_params.append(module.bias)
175
+ elif isinstance(module, nn.BatchNorm2d):
176
+ nowd_params += list(module.parameters())
177
+ return wd_params, nowd_params
178
+
179
+
180
+ class FeatureFusionModule(nn.Module):
181
+ def __init__(self, in_chan, out_chan, *args, **kwargs):
182
+ super(FeatureFusionModule, self).__init__()
183
+ self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0)
184
+ self.conv1 = nn.Conv2d(out_chan,
185
+ out_chan//4,
186
+ kernel_size = 1,
187
+ stride = 1,
188
+ padding = 0,
189
+ bias = False)
190
+ self.conv2 = nn.Conv2d(out_chan//4,
191
+ out_chan,
192
+ kernel_size = 1,
193
+ stride = 1,
194
+ padding = 0,
195
+ bias = False)
196
+ self.relu = nn.ReLU(inplace=True)
197
+ self.sigmoid = nn.Sigmoid()
198
+ self.init_weight()
199
+
200
+ def forward(self, fsp, fcp):
201
+ fcat = torch.cat([fsp, fcp], dim=1)
202
+ feat = self.convblk(fcat)
203
+ atten = F.avg_pool2d(feat, feat.size()[2:])
204
+ atten = self.conv1(atten)
205
+ atten = self.relu(atten)
206
+ atten = self.conv2(atten)
207
+ atten = self.sigmoid(atten)
208
+ feat_atten = torch.mul(feat, atten)
209
+ feat_out = feat_atten + feat
210
+ return feat_out
211
+
212
+ def init_weight(self):
213
+ for ly in self.children():
214
+ if isinstance(ly, nn.Conv2d):
215
+ nn.init.kaiming_normal_(ly.weight, a=1)
216
+ if not ly.bias is None: nn.init.constant_(ly.bias, 0)
217
+
218
+ def get_params(self):
219
+ wd_params, nowd_params = [], []
220
+ for name, module in self.named_modules():
221
+ if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
222
+ wd_params.append(module.weight)
223
+ if not module.bias is None:
224
+ nowd_params.append(module.bias)
225
+ elif isinstance(module, nn.BatchNorm2d):
226
+ nowd_params += list(module.parameters())
227
+ return wd_params, nowd_params
228
+
229
+
230
+ class BiSeNet(nn.Module):
231
+ def __init__(self, n_classes, *args, **kwargs):
232
+ super(BiSeNet, self).__init__()
233
+ self.cp = ContextPath()
234
+ ## here self.sp is deleted
235
+ self.ffm = FeatureFusionModule(256, 256)
236
+ self.conv_out = BiSeNetOutput(256, 256, n_classes)
237
+ self.conv_out16 = BiSeNetOutput(128, 64, n_classes)
238
+ self.conv_out32 = BiSeNetOutput(128, 64, n_classes)
239
+ self.init_weight()
240
+
241
+ def forward(self, x):
242
+ H, W = x.size()[2:]
243
+ feat_res8, feat_cp8, feat_cp16 = self.cp(x) # here return res3b1 feature
244
+ feat_sp = feat_res8 # use res3b1 feature to replace spatial path feature
245
+ feat_fuse = self.ffm(feat_sp, feat_cp8)
246
+
247
+ feat_out = self.conv_out(feat_fuse)
248
+ feat_out16 = self.conv_out16(feat_cp8)
249
+ feat_out32 = self.conv_out32(feat_cp16)
250
+
251
+ feat_out = F.interpolate(feat_out, (H, W), mode='bilinear', align_corners=True)
252
+ feat_out16 = F.interpolate(feat_out16, (H, W), mode='bilinear', align_corners=True)
253
+ feat_out32 = F.interpolate(feat_out32, (H, W), mode='bilinear', align_corners=True)
254
+
255
+ # return feat_out, feat_out16, feat_out32
256
+ return feat_out
257
+
258
+ def init_weight(self):
259
+ for ly in self.children():
260
+ if isinstance(ly, nn.Conv2d):
261
+ nn.init.kaiming_normal_(ly.weight, a=1)
262
+ if not ly.bias is None: nn.init.constant_(ly.bias, 0)
263
+
264
+ def get_params(self):
265
+ wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = [], [], [], []
266
+ for name, child in self.named_children():
267
+ child_wd_params, child_nowd_params = child.get_params()
268
+ if isinstance(child, FeatureFusionModule) or isinstance(child, BiSeNetOutput):
269
+ lr_mul_wd_params += child_wd_params
270
+ lr_mul_nowd_params += child_nowd_params
271
+ else:
272
+ wd_params += child_wd_params
273
+ nowd_params += child_nowd_params
274
+ return wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params
275
+
276
+
277
+ if __name__ == "__main__":
278
+ net = BiSeNet(19)
279
+ net.cuda()
280
+ net.eval()
281
+ in_ten = torch.randn(16, 3, 640, 480).cuda()
282
+ out, out16, out32 = net(in_ten)
283
+ print(out.shape)
284
+
285
+ net.get_params()
data_utils/face_parsing/resnet.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python
2
+ # -*- encoding: utf-8 -*-
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ import torch.utils.model_zoo as modelzoo
8
+
9
+ # from modules.bn import InPlaceABNSync as BatchNorm2d
10
+
11
+ resnet18_url = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
12
+
13
+
14
+ def conv3x3(in_planes, out_planes, stride=1):
15
+ """3x3 convolution with padding"""
16
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
17
+ padding=1, bias=False)
18
+
19
+
20
+ class BasicBlock(nn.Module):
21
+ def __init__(self, in_chan, out_chan, stride=1):
22
+ super(BasicBlock, self).__init__()
23
+ self.conv1 = conv3x3(in_chan, out_chan, stride)
24
+ self.bn1 = nn.BatchNorm2d(out_chan)
25
+ self.conv2 = conv3x3(out_chan, out_chan)
26
+ self.bn2 = nn.BatchNorm2d(out_chan)
27
+ self.relu = nn.ReLU(inplace=True)
28
+ self.downsample = None
29
+ if in_chan != out_chan or stride != 1:
30
+ self.downsample = nn.Sequential(
31
+ nn.Conv2d(in_chan, out_chan,
32
+ kernel_size=1, stride=stride, bias=False),
33
+ nn.BatchNorm2d(out_chan),
34
+ )
35
+
36
+ def forward(self, x):
37
+ residual = self.conv1(x)
38
+ residual = F.relu(self.bn1(residual))
39
+ residual = self.conv2(residual)
40
+ residual = self.bn2(residual)
41
+
42
+ shortcut = x
43
+ if self.downsample is not None:
44
+ shortcut = self.downsample(x)
45
+
46
+ out = shortcut + residual
47
+ out = self.relu(out)
48
+ return out
49
+
50
+
51
+ def create_layer_basic(in_chan, out_chan, bnum, stride=1):
52
+ layers = [BasicBlock(in_chan, out_chan, stride=stride)]
53
+ for i in range(bnum-1):
54
+ layers.append(BasicBlock(out_chan, out_chan, stride=1))
55
+ return nn.Sequential(*layers)
56
+
57
+
58
+ class Resnet18(nn.Module):
59
+ def __init__(self):
60
+ super(Resnet18, self).__init__()
61
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
62
+ bias=False)
63
+ self.bn1 = nn.BatchNorm2d(64)
64
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
65
+ self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1)
66
+ self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2)
67
+ self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2)
68
+ self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2)
69
+ self.init_weight()
70
+
71
+ def forward(self, x):
72
+ x = self.conv1(x)
73
+ x = F.relu(self.bn1(x))
74
+ x = self.maxpool(x)
75
+
76
+ x = self.layer1(x)
77
+ feat8 = self.layer2(x) # 1/8
78
+ feat16 = self.layer3(feat8) # 1/16
79
+ feat32 = self.layer4(feat16) # 1/32
80
+ return feat8, feat16, feat32
81
+
82
+ def init_weight(self):
83
+ state_dict = modelzoo.load_url(resnet18_url)
84
+ self_state_dict = self.state_dict()
85
+ for k, v in state_dict.items():
86
+ if 'fc' in k: continue
87
+ self_state_dict.update({k: v})
88
+ self.load_state_dict(self_state_dict)
89
+
90
+ def get_params(self):
91
+ wd_params, nowd_params = [], []
92
+ for name, module in self.named_modules():
93
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
94
+ wd_params.append(module.weight)
95
+ if not module.bias is None:
96
+ nowd_params.append(module.bias)
97
+ elif isinstance(module, nn.BatchNorm2d):
98
+ nowd_params += list(module.parameters())
99
+ return wd_params, nowd_params
100
+
101
+
102
+ if __name__ == "__main__":
103
+ net = Resnet18()
104
+ x = torch.randn(16, 3, 224, 224)
105
+ out = net(x)
106
+ print(out[0].size())
107
+ print(out[1].size())
108
+ print(out[2].size())
109
+ net.get_params()
data_utils/face_parsing/test.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python
2
+ # -*- encoding: utf-8 -*-
3
+ import numpy as np
4
+ from model import BiSeNet
5
+
6
+ import torch
7
+
8
+ import os
9
+ import os.path as osp
10
+
11
+ from PIL import Image
12
+ import torchvision.transforms as transforms
13
+ import cv2
14
+ from pathlib import Path
15
+ import configargparse
16
+ import tqdm
17
+
18
+ # import ttach as tta
19
+
20
+ def vis_parsing_maps(im, parsing_anno, stride, save_im=False, save_path='vis_results/parsing_map_on_im.jpg',
21
+ img_size=(512, 512)):
22
+ im = np.array(im)
23
+ vis_im = im.copy().astype(np.uint8)
24
+ vis_parsing_anno = parsing_anno.copy().astype(np.uint8)
25
+ vis_parsing_anno = cv2.resize(
26
+ vis_parsing_anno, None, fx=stride, fy=stride, interpolation=cv2.INTER_NEAREST)
27
+ vis_parsing_anno_color = np.zeros(
28
+ (vis_parsing_anno.shape[0], vis_parsing_anno.shape[1], 3)) + np.array([255, 255, 255]) # + 255
29
+ vis_parsing_anno_color_face = np.zeros(
30
+ (vis_parsing_anno.shape[0], vis_parsing_anno.shape[1], 3)) + np.array([255, 255, 255]) # + 255
31
+
32
+ num_of_class = np.max(vis_parsing_anno)
33
+ # print(num_of_class)
34
+ for pi in range(1, 14):
35
+ index = np.where(vis_parsing_anno == pi)
36
+ vis_parsing_anno_color[index[0], index[1], :] = np.array([255, 0, 0])
37
+ for pi in range(14, 16):
38
+ index = np.where(vis_parsing_anno == pi)
39
+ vis_parsing_anno_color[index[0], index[1], :] = np.array([0, 255, 0])
40
+ for pi in range(16, 17):
41
+ index = np.where(vis_parsing_anno == pi)
42
+ vis_parsing_anno_color[index[0], index[1], :] = np.array([0, 0, 255])
43
+ for pi in range(17, num_of_class+1):
44
+ index = np.where(vis_parsing_anno == pi)
45
+ vis_parsing_anno_color[index[0], index[1], :] = np.array([255, 0, 0])
46
+
47
+ vis_parsing_anno_color = vis_parsing_anno_color.astype(np.uint8)
48
+ index = np.where(vis_parsing_anno == num_of_class-1)
49
+ vis_im = cv2.resize(vis_parsing_anno_color, img_size,
50
+ interpolation=cv2.INTER_NEAREST)
51
+ if save_im:
52
+ cv2.imwrite(save_path, vis_im)
53
+
54
+ for pi in range(1, 7):
55
+ index = np.where(vis_parsing_anno == pi)
56
+ vis_parsing_anno_color_face[index[0], index[1], :] = np.array([255, 0, 0])
57
+ for pi in range(10, 14):
58
+ index = np.where(vis_parsing_anno == pi)
59
+ vis_parsing_anno_color_face[index[0], index[1], :] = np.array([255, 0, 0])
60
+ pad = 5
61
+ vis_parsing_anno_color_face = vis_parsing_anno_color_face.astype(np.uint8)
62
+ face_part = (vis_parsing_anno_color_face[..., 0] == 255) & (vis_parsing_anno_color_face[..., 1] == 0) & (vis_parsing_anno_color_face[..., 2] == 0)
63
+ face_coords = np.stack(np.nonzero(face_part), axis=-1)
64
+ sorted_inds = np.lexsort((-face_coords[:, 0], face_coords[:, 1]))
65
+ sorted_face_coords = face_coords[sorted_inds]
66
+ u, uid, ucnt = np.unique(sorted_face_coords[:, 1], return_index=True, return_counts=True)
67
+ bottom_face_coords = sorted_face_coords[uid] + np.array([pad, 0])
68
+ rows, cols, _ = vis_parsing_anno_color_face.shape
69
+
70
+ # 为了保证新的坐标在图片范围内
71
+ bottom_face_coords[:, 0] = np.clip(bottom_face_coords[:, 0], 0, rows - 1)
72
+
73
+ y_min = np.min(bottom_face_coords[:, 1])
74
+ y_max = np.max(bottom_face_coords[:, 1])
75
+
76
+ # 计算1和2部分的开始和结束位置
77
+ y_range = y_max - y_min
78
+ height_per_part = y_range // 4
79
+
80
+ start_y_part1 = y_min + height_per_part
81
+ end_y_part1 = start_y_part1 + height_per_part
82
+
83
+ start_y_part2 = end_y_part1
84
+ end_y_part2 = start_y_part2 + height_per_part
85
+
86
+ for coord in bottom_face_coords:
87
+ x, y = coord
88
+ start_x = max(x - pad, 0)
89
+ end_x = min(x + pad, rows)
90
+ if start_y_part1 <= y <= end_y_part1 or start_y_part2 <= y <= end_y_part2:
91
+ vis_parsing_anno_color_face[start_x:end_x, y] = [255, 0, 0]
92
+ # else:
93
+ # start_x = max(x - 2*pad, 0)
94
+ # end_x = max(x - pad, 0)
95
+ # vis_parsing_anno_color_face[start_x:end_x+1, y] = [255, 255, 255]
96
+
97
+ vis_im = cv2.GaussianBlur(vis_parsing_anno_color_face, (9, 9), cv2.BORDER_DEFAULT)
98
+
99
+ vis_im = cv2.resize(vis_im, img_size,
100
+ interpolation=cv2.INTER_NEAREST)
101
+
102
+ cv2.imwrite(save_path.replace('.png', '_face.png'), vis_im)
103
+
104
+
105
+ def evaluate(respth='./res/test_res', dspth='./data', cp='model_final_diss.pth'):
106
+
107
+ Path(respth).mkdir(parents=True, exist_ok=True)
108
+
109
+ print(f'[INFO] loading model...')
110
+ n_classes = 19
111
+ net = BiSeNet(n_classes=n_classes)
112
+ net.cuda()
113
+ net.load_state_dict(torch.load(cp))
114
+ net.eval()
115
+
116
+ to_tensor = transforms.Compose([
117
+ transforms.ToTensor(),
118
+ transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
119
+ ])
120
+
121
+ image_paths = os.listdir(dspth)
122
+
123
+ with torch.no_grad():
124
+ for image_path in tqdm.tqdm(image_paths):
125
+ if image_path.endswith('.jpg') or image_path.endswith('.png'):
126
+ img = Image.open(osp.join(dspth, image_path))
127
+ ori_size = img.size
128
+ image = img.resize((512, 512), Image.BILINEAR)
129
+ image = image.convert("RGB")
130
+ img = to_tensor(image)
131
+
132
+ # test-time augmentation.
133
+ inputs = torch.unsqueeze(img, 0) # [1, 3, 512, 512]
134
+ outputs = net(inputs.cuda())
135
+ parsing = outputs.mean(0).cpu().numpy().argmax(0)
136
+ image_path = int(image_path[:-4])
137
+ image_path = str(image_path) + '.png'
138
+
139
+ vis_parsing_maps(image, parsing, stride=1, save_im=True, save_path=osp.join(respth, image_path), img_size=ori_size)
140
+
141
+
142
+ if __name__ == "__main__":
143
+ parser = configargparse.ArgumentParser()
144
+ parser.add_argument('--respath', type=str, default='./result/', help='result path for label')
145
+ parser.add_argument('--imgpath', type=str, default='./imgs/', help='path for input images')
146
+ parser.add_argument('--modelpath', type=str, default='data_utils/face_parsing/79999_iter.pth')
147
+ args = parser.parse_args()
148
+ evaluate(respth=args.respath, dspth=args.imgpath, cp=args.modelpath)
data_utils/face_tracking/3DMM/exp_info.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c3196029ed038eb9a461df6c782125fc4d1ec1545f2e5f361891471136b6cbb6
3
+ size 33264853
data_utils/face_tracking/3DMM/keys_info.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:028d3c383bae129c4bdcac880e22c551a5d2436ec7db7a26f5c57148d12469e6
3
+ size 7375
data_utils/face_tracking/3DMM/lands_info.txt ADDED
@@ -0,0 +1,403 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 136
2
+ 19
3
+ 155
4
+ 22
5
+ 177
6
+ 19
7
+ 196
8
+ 15
9
+ 211
10
+ 12
11
+ 223
12
+ 8
13
+ 231
14
+ 6
15
+ 237
16
+ 4
17
+ 241
18
+ 6
19
+ 247
20
+ 4
21
+ 251
22
+ 6
23
+ 257
24
+ 8
25
+ 265
26
+ 12
27
+ 277
28
+ 15
29
+ 292
30
+ 19
31
+ 311
32
+ 22
33
+ 333
34
+ 19
35
+ 352
36
+ 1
37
+ 353
38
+ 1
39
+ 354
40
+ 1
41
+ 355
42
+ 1
43
+ 356
44
+ 1
45
+ 357
46
+ 1
47
+ 358
48
+ 1
49
+ 359
50
+ 1
51
+ 360
52
+ 1
53
+ 361
54
+ 1
55
+ 362
56
+ 1
57
+ 363
58
+ 1
59
+ 364
60
+ 1
61
+ 365
62
+ 1
63
+ 366
64
+ 1
65
+ 367
66
+ 1
67
+ 368
68
+ 1
69
+ 369
70
+ 1
71
+ 370
72
+ 1
73
+ 371
74
+ 1
75
+ 372
76
+ 1
77
+ 373
78
+ 1
79
+ 374
80
+ 1
81
+ 375
82
+ 1
83
+ 376
84
+ 1
85
+ 377
86
+ 1
87
+ 378
88
+ 1
89
+ 379
90
+ 1
91
+ 380
92
+ 1
93
+ 381
94
+ 1
95
+ 382
96
+ 1
97
+ 383
98
+ 1
99
+ 384
100
+ 1
101
+ 385
102
+ 1
103
+ 386
104
+ 1
105
+ 387
106
+ 1
107
+ 388
108
+ 1
109
+ 389
110
+ 1
111
+ 390
112
+ 1
113
+ 391
114
+ 1
115
+ 392
116
+ 1
117
+ 393
118
+ 1
119
+ 394
120
+ 1
121
+ 395
122
+ 1
123
+ 396
124
+ 1
125
+ 397
126
+ 1
127
+ 398
128
+ 1
129
+ 399
130
+ 1
131
+ 400
132
+ 1
133
+ 401
134
+ 1
135
+ 402
136
+ 1
137
+ 16655
138
+ 16901
139
+ 17155
140
+ 17412
141
+ 17669
142
+ 17926
143
+ 18183
144
+ 18440
145
+ 18826
146
+ 19083
147
+ 19340
148
+ 19726
149
+ 19983
150
+ 20240
151
+ 20625
152
+ 21010
153
+ 21396
154
+ 157
155
+ 671
156
+ 16922
157
+ 17177
158
+ 17435
159
+ 17821
160
+ 18208
161
+ 18594
162
+ 18980
163
+ 19366
164
+ 19752
165
+ 20139
166
+ 20525
167
+ 20911
168
+ 21168
169
+ 21555
170
+ 188
171
+ 575
172
+ 961
173
+ 1477
174
+ 1863
175
+ 2249
176
+ 2636
177
+ 3280
178
+ 16411
179
+ 16948
180
+ 17589
181
+ 18232
182
+ 18876
183
+ 19262
184
+ 19648
185
+ 20163
186
+ 20678
187
+ 21192
188
+ 21707
189
+ 340
190
+ 855
191
+ 1370
192
+ 1756
193
+ 2142
194
+ 2657
195
+ 3043
196
+ 3429
197
+ 16363
198
+ 16973
199
+ 17871
200
+ 18644
201
+ 19416
202
+ 20189
203
+ 20833
204
+ 21733
205
+ 752
206
+ 1523
207
+ 2037
208
+ 2681
209
+ 3323
210
+ 3708
211
+ 4222
212
+ 31497
213
+ 31491
214
+ 31484
215
+ 31555
216
+ 31626
217
+ 31730
218
+ 31865
219
+ 3224
220
+ 3737
221
+ 4250
222
+ 4764
223
+ 5150
224
+ 32139
225
+ 32192
226
+ 32271
227
+ 32368
228
+ 32436
229
+ 32521
230
+ 32600
231
+ 32655
232
+ 32445
233
+ 32465
234
+ 32506
235
+ 32546
236
+ 32585
237
+ 32640
238
+ 32716
239
+ 32733
240
+ 32750
241
+ 32785
242
+ 32914
243
+ 32913
244
+ 32912
245
+ 32911
246
+ 32910
247
+ 32909
248
+ 33076
249
+ 33057
250
+ 33038
251
+ 33001
252
+ 33357
253
+ 33333
254
+ 33287
255
+ 33243
256
+ 33202
257
+ 33144
258
+ 33675
259
+ 33612
260
+ 33524
261
+ 33420
262
+ 33348
263
+ 33260
264
+ 33179
265
+ 33123
266
+ 34322
267
+ 34316
268
+ 34309
269
+ 34227
270
+ 34147
271
+ 34034
272
+ 33897
273
+ 13269
274
+ 12750
275
+ 12231
276
+ 11713
277
+ 11325
278
+ 27304
279
+ 26767
280
+ 25869
281
+ 25094
282
+ 24318
283
+ 23543
284
+ 22897
285
+ 21991
286
+ 15699
287
+ 14922
288
+ 14404
289
+ 13758
290
+ 13110
291
+ 12721
292
+ 12203
293
+ 27231
294
+ 26742
295
+ 26103
296
+ 25456
297
+ 24810
298
+ 24422
299
+ 24034
300
+ 23517
301
+ 23000
302
+ 22482
303
+ 21965
304
+ 16061
305
+ 15544
306
+ 15027
307
+ 14639
308
+ 14251
309
+ 13734
310
+ 13346
311
+ 12958
312
+ 26716
313
+ 26465
314
+ 26207
315
+ 25819
316
+ 25432
317
+ 25044
318
+ 24656
319
+ 24268
320
+ 23880
321
+ 23493
322
+ 23105
323
+ 22717
324
+ 22458
325
+ 22071
326
+ 16167
327
+ 15780
328
+ 15392
329
+ 14876
330
+ 14488
331
+ 14100
332
+ 13713
333
+ 13067
334
+ 26939
335
+ 26695
336
+ 26443
337
+ 26184
338
+ 25925
339
+ 25666
340
+ 25407
341
+ 25148
342
+ 24760
343
+ 24501
344
+ 24242
345
+ 23854
346
+ 23595
347
+ 23336
348
+ 22947
349
+ 22558
350
+ 22170
351
+ 16136
352
+ 15618
353
+ 27932
354
+ 28270
355
+ 28552
356
+ 28771
357
+ 28990
358
+ 29567
359
+ 29780
360
+ 30000
361
+ 30316
362
+ 30627
363
+ 8155
364
+ 8173
365
+ 8184
366
+ 8190
367
+ 6516
368
+ 7363
369
+ 8203
370
+ 9043
371
+ 9884
372
+ 1828
373
+ 4016
374
+ 5177
375
+ 6341
376
+ 4804
377
+ 3771
378
+ 9955
379
+ 11094
380
+ 12255
381
+ 14323
382
+ 12526
383
+ 11495
384
+ 5262
385
+ 6024
386
+ 7375
387
+ 8215
388
+ 9055
389
+ 10394
390
+ 11179
391
+ 9674
392
+ 8835
393
+ 8235
394
+ 7635
395
+ 6793
396
+ 5779
397
+ 7384
398
+ 8225
399
+ 9064
400
+ 10536
401
+ 8828
402
+ 8228
403
+ 7628
data_utils/face_tracking/3DMM/sub_mesh.obj ADDED
The diff for this file is too large to render. See raw diff
 
data_utils/face_tracking/3DMM/topology_info.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2edff6a6ad574d2dddf0d0815e0beabfea9369d7c2d6e53e0ba81f809b81e963
3
+ size 4145201
data_utils/face_tracking/3DMM/tris.txt ADDED
The diff for this file is too large to render. See raw diff
 
data_utils/face_tracking/3DMM/vert_tris.txt ADDED
The diff for this file is too large to render. See raw diff
 
data_utils/face_tracking/__init__.py ADDED
File without changes
data_utils/face_tracking/bundle_adjustment.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import os
3
+ from util import *
4
+ import argparse
5
+
6
+
7
+ def set_requires_grad(tensor_list):
8
+ for tensor in tensor_list:
9
+ tensor.requires_grad = True
10
+
11
+
12
+ parser = argparse.ArgumentParser()
13
+
14
+ parser.add_argument(
15
+ "--path", type=str, default="", help="idname of target person")
16
+ parser.add_argument('--img_h', type=int, default=512, help='height if image')
17
+ parser.add_argument('--img_w', type=int, default=512, help='width of image')
18
+ args = parser.parse_args()
19
+ id_dir = args.path
20
+
21
+ params_dict = torch.load(os.path.join(id_dir, 'track_params.pt'))
22
+ euler_angle = params_dict['euler'].cuda()
23
+ trans = params_dict['trans'].cuda() / 1000.0
24
+ focal_len = params_dict['focal'].cuda()
25
+
26
+ track_xys = torch.as_tensor(
27
+ np.load(os.path.join(id_dir, 'track_xys.npy'))).float().cuda()
28
+ num_frames = track_xys.shape[0]
29
+ point_num = track_xys.shape[1]
30
+
31
+ pts = torch.zeros((point_num, 3), dtype=torch.float32).cuda()
32
+ set_requires_grad([euler_angle, trans, pts])
33
+
34
+ cxy = torch.Tensor((args.img_w/2.0, args.img_h/2.0)).float().cuda()
35
+
36
+ optimizer_pts = torch.optim.Adam([pts], lr=1e-2)
37
+ iter_num = 500
38
+ for iter in range(iter_num):
39
+ proj_pts = forward_transform(pts.unsqueeze(0).expand(
40
+ num_frames, -1, -1), euler_angle, trans, focal_len, cxy)
41
+ loss = cal_lan_loss(proj_pts[..., :2], track_xys)
42
+ optimizer_pts.zero_grad()
43
+ loss.backward()
44
+ optimizer_pts.step()
45
+
46
+
47
+ optimizer_ba = torch.optim.Adam([pts, euler_angle, trans], lr=1e-4)
48
+
49
+
50
+ iter_num = 8000
51
+ for iter in range(iter_num):
52
+ proj_pts = forward_transform(pts.unsqueeze(0).expand(
53
+ num_frames, -1, -1), euler_angle, trans, focal_len, cxy)
54
+ loss_lan = cal_lan_loss(proj_pts[..., :2], track_xys)
55
+ loss = loss_lan
56
+ optimizer_ba.zero_grad()
57
+ loss.backward()
58
+ optimizer_ba.step()
59
+
60
+ torch.save({'euler': euler_angle.detach().cpu(),
61
+ 'trans': trans.detach().cpu(),
62
+ 'focal': focal_len.detach().cpu()}, os.path.join(id_dir, 'bundle_adjustment.pt'))
63
+ print('bundle adjustment params saved')