feng2022 commited on
Commit
89d1ee7
·
1 Parent(s): f9f1dd0

anothertry

Browse files
Files changed (46) hide show
  1. Time-Travel-Rephotography +0 -1
  2. Time_TravelRephotography/LICENSE +21 -0
  3. Time_TravelRephotography/LICENSE-NVIDIA +101 -0
  4. Time_TravelRephotography/LICENSE-STYLEGAN2 +21 -0
  5. Time_TravelRephotography/README.md +118 -0
  6. Time_TravelRephotography/losses/color_transfer_loss.py +60 -0
  7. Time_TravelRephotography/losses/joint_loss.py +167 -0
  8. Time_TravelRephotography/losses/perceptual_loss.py +111 -0
  9. Time_TravelRephotography/losses/reconstruction.py +119 -0
  10. Time_TravelRephotography/losses/regularize_noise.py +37 -0
  11. Time_TravelRephotography/model.py +697 -0
  12. Time_TravelRephotography/models/__init__.py +0 -0
  13. Time_TravelRephotography/models/degrade.py +122 -0
  14. Time_TravelRephotography/models/encoder.py +66 -0
  15. Time_TravelRephotography/models/gaussian_smoothing.py +74 -0
  16. Time_TravelRephotography/models/resnet.py +99 -0
  17. Time_TravelRephotography/models/vggface.py +150 -0
  18. Time_TravelRephotography/op/__init__.py +2 -0
  19. Time_TravelRephotography/op/fused_act.py +86 -0
  20. Time_TravelRephotography/op/fused_bias_act.cpp +21 -0
  21. Time_TravelRephotography/op/fused_bias_act_kernel.cu +99 -0
  22. Time_TravelRephotography/op/upfirdn2d.cpp +23 -0
  23. Time_TravelRephotography/op/upfirdn2d.py +187 -0
  24. Time_TravelRephotography/op/upfirdn2d_kernel.cu +272 -0
  25. Time_TravelRephotography/optim/__init__.py +15 -0
  26. Time_TravelRephotography/optim/radam.py +250 -0
  27. Time_TravelRephotography/projector.py +172 -0
  28. Time_TravelRephotography/requirements.txt +26 -0
  29. Time_TravelRephotography/scripts/download_checkpoints.sh +14 -0
  30. Time_TravelRephotography/scripts/install.sh +6 -0
  31. Time_TravelRephotography/scripts/run.sh +34 -0
  32. Time_TravelRephotography/tools/__init__.py +0 -0
  33. Time_TravelRephotography/tools/data/__init__.py +0 -0
  34. Time_TravelRephotography/tools/data/align_images.py +117 -0
  35. Time_TravelRephotography/tools/initialize.py +160 -0
  36. Time_TravelRephotography/tools/match_histogram.py +167 -0
  37. Time_TravelRephotography/tools/match_skin_histogram.py +67 -0
  38. Time_TravelRephotography/tools/parse_face.py +55 -0
  39. Time_TravelRephotography/utils/__init__.py +0 -0
  40. Time_TravelRephotography/utils/ffhq_dataset/__init__.py +0 -0
  41. Time_TravelRephotography/utils/ffhq_dataset/face_alignment.py +99 -0
  42. Time_TravelRephotography/utils/ffhq_dataset/landmarks_detector.py +71 -0
  43. Time_TravelRephotography/utils/misc.py +18 -0
  44. Time_TravelRephotography/utils/optimize.py +230 -0
  45. Time_TravelRephotography/utils/projector_arguments.py +76 -0
  46. Time_TravelRephotography/utils/torch_helpers.py +36 -0
Time-Travel-Rephotography DELETED
@@ -1 +0,0 @@
1
- Subproject commit 2045d895f671e72e4dca1f81327b1ce462a7d32f
 
 
Time_TravelRephotography/LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2020 Time-Travel-Rephotography
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
Time_TravelRephotography/LICENSE-NVIDIA ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2
+
3
+
4
+ Nvidia Source Code License-NC
5
+
6
+ =======================================================================
7
+
8
+ 1. Definitions
9
+
10
+ "Licensor" means any person or entity that distributes its Work.
11
+
12
+ "Software" means the original work of authorship made available under
13
+ this License.
14
+
15
+ "Work" means the Software and any additions to or derivative works of
16
+ the Software that are made available under this License.
17
+
18
+ "Nvidia Processors" means any central processing unit (CPU), graphics
19
+ processing unit (GPU), field-programmable gate array (FPGA),
20
+ application-specific integrated circuit (ASIC) or any combination
21
+ thereof designed, made, sold, or provided by Nvidia or its affiliates.
22
+
23
+ The terms "reproduce," "reproduction," "derivative works," and
24
+ "distribution" have the meaning as provided under U.S. copyright law;
25
+ provided, however, that for the purposes of this License, derivative
26
+ works shall not include works that remain separable from, or merely
27
+ link (or bind by name) to the interfaces of, the Work.
28
+
29
+ Works, including the Software, are "made available" under this License
30
+ by including in or with the Work either (a) a copyright notice
31
+ referencing the applicability of this License to the Work, or (b) a
32
+ copy of this License.
33
+
34
+ 2. License Grants
35
+
36
+ 2.1 Copyright Grant. Subject to the terms and conditions of this
37
+ License, each Licensor grants to you a perpetual, worldwide,
38
+ non-exclusive, royalty-free, copyright license to reproduce,
39
+ prepare derivative works of, publicly display, publicly perform,
40
+ sublicense and distribute its Work and any resulting derivative
41
+ works in any form.
42
+
43
+ 3. Limitations
44
+
45
+ 3.1 Redistribution. You may reproduce or distribute the Work only
46
+ if (a) you do so under this License, (b) you include a complete
47
+ copy of this License with your distribution, and (c) you retain
48
+ without modification any copyright, patent, trademark, or
49
+ attribution notices that are present in the Work.
50
+
51
+ 3.2 Derivative Works. You may specify that additional or different
52
+ terms apply to the use, reproduction, and distribution of your
53
+ derivative works of the Work ("Your Terms") only if (a) Your Terms
54
+ provide that the use limitation in Section 3.3 applies to your
55
+ derivative works, and (b) you identify the specific derivative
56
+ works that are subject to Your Terms. Notwithstanding Your Terms,
57
+ this License (including the redistribution requirements in Section
58
+ 3.1) will continue to apply to the Work itself.
59
+
60
+ 3.3 Use Limitation. The Work and any derivative works thereof only
61
+ may be used or intended for use non-commercially. The Work or
62
+ derivative works thereof may be used or intended for use by Nvidia
63
+ or its affiliates commercially or non-commercially. As used herein,
64
+ "non-commercially" means for research or evaluation purposes only.
65
+
66
+ 3.4 Patent Claims. If you bring or threaten to bring a patent claim
67
+ against any Licensor (including any claim, cross-claim or
68
+ counterclaim in a lawsuit) to enforce any patents that you allege
69
+ are infringed by any Work, then your rights under this License from
70
+ such Licensor (including the grants in Sections 2.1 and 2.2) will
71
+ terminate immediately.
72
+
73
+ 3.5 Trademarks. This License does not grant any rights to use any
74
+ Licensor's or its affiliates' names, logos, or trademarks, except
75
+ as necessary to reproduce the notices described in this License.
76
+
77
+ 3.6 Termination. If you violate any term of this License, then your
78
+ rights under this License (including the grants in Sections 2.1 and
79
+ 2.2) will terminate immediately.
80
+
81
+ 4. Disclaimer of Warranty.
82
+
83
+ THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY
84
+ KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF
85
+ MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR
86
+ NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER
87
+ THIS LICENSE.
88
+
89
+ 5. Limitation of Liability.
90
+
91
+ EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL
92
+ THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE
93
+ SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT,
94
+ INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF
95
+ OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK
96
+ (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION,
97
+ LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER
98
+ COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF
99
+ THE POSSIBILITY OF SUCH DAMAGES.
100
+
101
+ =======================================================================
Time_TravelRephotography/LICENSE-STYLEGAN2 ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2019 Kim Seonghyeon
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
Time_TravelRephotography/README.md ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # [SIGGRAPH Asia 2021] Time-Travel Rephotography
2
+ <a href="https://arxiv.org/abs/2012.12261"><img src="https://img.shields.io/badge/arXiv-2008.00951-b31b1b.svg"></a>
3
+ <a href="https://opensource.org/licenses/MIT"><img src="https://img.shields.io/badge/License-MIT-yellow.svg"></a>
4
+ [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1KZXGkHVhvz2X3ljaCQANC1bDr7OrzDpg?usp=sharing)
5
+ ### [[Project Website](https://time-travel-rephotography.github.io/)]
6
+ <p align='center'>
7
+ <img src="time-travel-rephotography.gif" width='100%'/>
8
+ </p>
9
+
10
+ Many historical people were only ever captured by old, faded, black and white photos, that are distorted due to the limitations of early cameras and the passage of time. This paper simulates traveling back in time with a modern camera to rephotograph famous subjects. Unlike conventional image restoration filters which apply independent operations like denoising, colorization, and superresolution, we leverage the StyleGAN2 framework to project old photos into the space of modern high-resolution photos, achieving all of these effects in a unified framework. A unique challenge with this approach is retaining the identity and pose of the subject in the original photo, while discarding the many artifacts frequently seen in low-quality antique photos. Our comparisons to current state-of-the-art restoration filters show significant improvements and compelling results for a variety of important historical people.
11
+ <br/>
12
+
13
+ **Time-Travel Rephotography**
14
+ <br/>
15
+ [Xuan Luo](https://roxanneluo.github.io),
16
+ [Xuaner Zhang](https://people.eecs.berkeley.edu/~cecilia77/),
17
+ [Paul Yoo](https://www.linkedin.com/in/paul-yoo-768a3715b),
18
+ [Ricardo Martin-Brualla](http://www.ricardomartinbrualla.com/),
19
+ [Jason Lawrence](http://jasonlawrence.info/), and
20
+ [Steven M. Seitz](https://homes.cs.washington.edu/~seitz/)
21
+ <br/>
22
+ In SIGGRAPH Asia 2021.
23
+
24
+ ## Demo
25
+ We provide an easy-to-get-started demo using Google Colab!
26
+ The Colab will allow you to try our method on the sample Abraham Lincoln photo or **your own photos** using Cloud GPUs on Google Colab.
27
+
28
+ [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/15D2WIF_vE2l48ddxEx45cM3RykZwQXM8?usp=sharing)
29
+
30
+ Or you can run our method on your own machine following the instructions below.
31
+
32
+ ## Prerequisite
33
+ - Pull third-party packages.
34
+ ```
35
+ git submodule update --init --recursive
36
+ ```
37
+ - Install python packages.
38
+ ```
39
+ conda create --name rephotography python=3.8.5
40
+ conda activate rephotography
41
+ conda install pytorch==1.4.0 torchvision==0.5.0 cudatoolkit=10.1 -c pytorch
42
+ pip install -r requirements.txt
43
+ ```
44
+
45
+ ## Quick Start
46
+ Run our method on the example photo of Abraham Lincoln.
47
+ - Download models:
48
+ ```
49
+ ./scripts/download_checkpoints.sh
50
+ ```
51
+ - Run:
52
+ ```
53
+ ./scripts/run.sh b "dataset/Abraham Lincoln_01.png" 0.75
54
+ ```
55
+ - You can inspect the optimization process by
56
+ ```
57
+ tensorboard --logdir "log/Abraham Lincoln_01"
58
+ ```
59
+ - You can find your results as below.
60
+ ```
61
+ results/
62
+ Abraham Lincoln_01/ # intermediate outputs for histogram matching and face parsing
63
+ Abraham Lincoln_01_b.png # the input after matching the histogram of the sibling image
64
+ Abraham Lincoln_01-b-G0.75-init(10,18)-s256-vgg1-vggface0.3-eye0.1-color1.0e+10-cx0.1(relu3_4,relu2_2,relu1_2)-NR5.0e+04-lr0.1_0.01-c32-wp(250,750)-init.png # the sibling image
65
+ Abraham Lincoln_01-b-G0.75-init(10,18)-s256-vgg1-vggface0.3-eye0.1-color1.0e+10-cx0.1(relu3_4,relu2_2,relu1_2)-NR5.0e+04-lr0.1_0.01-c32-wp(250,750)-init.pt # the sibing latent codes and initialized noise maps
66
+ Abraham Lincoln_01-b-G0.75-init(10,18)-s256-vgg1-vggface0.3-eye0.1-color1.0e+10-cx0.1(relu3_4,relu2_2,relu1_2)-NR5.0e+04-lr0.1_0.01-c32-wp(250,750).png # the output result
67
+ Abraham Lincoln_01-b-G0.75-init(10,18)-s256-vgg1-vggface0.3-eye0.1-color1.0e+10-cx0.1(relu3_4,relu2_2,relu1_2)-NR5.0e+04-lr0.1_0.01-c32-wp(250,750).pt # the final optimized latent codes and noise maps
68
+ Abraham Lincoln_01-b-G0.75-init(10,18)-s256-vgg1-vggface0.3-eye0.1-color1.0e+10-cx0.1(relu3_4,relu2_2,relu1_2)-NR5.0e+04-lr0.1_0.01-c32-wp(250,750)-rand.png # the result with the final latent codes but random noise maps
69
+
70
+ ```
71
+
72
+ ## Run on Your Own Image
73
+ - Crop and align the head regions of your images:
74
+ ```
75
+ python -m tools.data.align_images <input_raw_image_dir> <aligned_image_dir>
76
+ ```
77
+ - Run:
78
+ ```
79
+ ./scripts/run.sh <spectral_sensitivity> <input_image_path> <blur_radius>
80
+ ```
81
+ The `spectral_sensitivity` can be `b` (blue-sensitive), `gb` (orthochromatic), or `g` (panchromatic). You can roughly estimate the `spectral_sensitivity` of your photo as follows. Use the *blue-sensitive* model for photos before 1873, manually select between blue-sensitive and *orthochromatic* for images from 1873 to 1906 and among all models for photos taken afterwards.
82
+
83
+ The `blur_radius` is the estimated gaussian blur radius in pixels if the input photot is resized to 1024x1024.
84
+
85
+ ## Historical Wiki Face Dataset
86
+ | Path | Size | Description |
87
+ |----------- | ----------- | ----------- |
88
+ | [Historical Wiki Face Dataset.zip](https://drive.google.com/open?id=1mgC2U7quhKSz_lTL97M-0cPrIILTiUCE&authuser=xuanluo%40cs.washington.edu&usp=drive_fs)| 148 MB | Images|
89
+ | [spectral_sensitivity.json](https://drive.google.com/open?id=1n3Bqd8G0g-wNpshlgoZiOMXxLlOycAXr&authuser=xuanluo%40cs.washington.edu&usp=drive_fs)| 6 KB | Spectral sensitivity (`b`, `gb`, or `g`). |
90
+ | [blur_radius.json](https://drive.google.com/open?id=1n4vUsbQo2BcxtKVMGfD1wFHaINzEmAVP&authuser=xuanluo%40cs.washington.edu&usp=drive_fs)| 6 KB | Blur radius in pixels|
91
+
92
+ The `json`s are dictionares that map input names to the corresponding spectral sensitivity or blur radius.
93
+ Due to copyright constraints, `Historical Wiki Face Dataset.zip` contains all images in the *Historical Wiki Face Dataset* that were used in our user study except the photo of [Mao Zedong](https://en.wikipedia.org/wiki/File:Mao_Zedong_in_1959_%28cropped%29.jpg). You can download it separately and crop it as [above](#run-on-your-own-image).
94
+
95
+ ## Citation
96
+ If you find our code useful, please consider citing our paper:
97
+ ```
98
+ @article{Luo-Rephotography-2021,
99
+ author = {Luo, Xuan and Zhang, Xuaner and Yoo, Paul and Martin-Brualla, Ricardo and Lawrence, Jason and Seitz, Steven M.},
100
+ title = {Time-Travel Rephotography},
101
+ journal = {ACM Transactions on Graphics (Proceedings of ACM SIGGRAPH Asia 2021)},
102
+ publisher = {ACM New York, NY, USA},
103
+ volume = {40},
104
+ number = {6},
105
+ articleno = {213},
106
+ doi = {https://doi.org/10.1145/3478513.3480485},
107
+ year = {2021},
108
+ month = {12}
109
+ }
110
+ ```
111
+
112
+ ## License
113
+ This work is licensed under MIT License. See [LICENSE](LICENSE) for details.
114
+
115
+ Codes for the StyleGAN2 model come from [https://github.com/rosinality/stylegan2-pytorch](https://github.com/rosinality/stylegan2-pytorch).
116
+
117
+ ## Acknowledgments
118
+ We thank [Nick Brandreth](https://www.nickbrandreth.com/) for capturing the dry plate photos. We thank Bo Zhang, Qingnan Fan, Roy Or-El, Aleksander Holynski and Keunhong Park for insightful advice.
Time_TravelRephotography/losses/color_transfer_loss.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Optional
2
+
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn.functional import (
6
+ smooth_l1_loss,
7
+ )
8
+
9
+
10
+ def flatten_CHW(im: torch.Tensor) -> torch.Tensor:
11
+ """
12
+ (B, C, H, W) -> (B, -1)
13
+ """
14
+ B = im.shape[0]
15
+ return im.reshape(B, -1)
16
+
17
+
18
+ def stddev(x: torch.Tensor) -> torch.Tensor:
19
+ """
20
+ x: (B, -1), assume with mean normalized
21
+ Retuens:
22
+ stddev: (B)
23
+ """
24
+ return torch.sqrt(torch.mean(x * x, dim=-1))
25
+
26
+
27
+ def gram_matrix(input_):
28
+ B, C = input_.shape[:2]
29
+ features = input_.view(B, C, -1)
30
+ N = features.shape[-1]
31
+ G = torch.bmm(features, features.transpose(1, 2)) # C x C
32
+ return G.div(C * N)
33
+
34
+
35
+ class ColorTransferLoss(nn.Module):
36
+ """Penalize the gram matrix difference between StyleGAN2's ToRGB outputs"""
37
+ def __init__(
38
+ self,
39
+ init_rgbs,
40
+ scale_rgb: bool = False
41
+ ):
42
+ super().__init__()
43
+
44
+ with torch.no_grad():
45
+ init_feats = [x.detach() for x in init_rgbs]
46
+ self.stds = [stddev(flatten_CHW(rgb)) if scale_rgb else 1 for rgb in init_feats] # (B, 1, 1, 1) or scalar
47
+ self.grams = [gram_matrix(rgb / std) for rgb, std in zip(init_feats, self.stds)]
48
+
49
+ def forward(self, rgbs: List[torch.Tensor], level: int = None):
50
+ if level is None:
51
+ level = len(self.grams)
52
+
53
+ feats = rgbs
54
+ loss = 0
55
+ for i, (rgb, std) in enumerate(zip(feats[:level], self.stds[:level])):
56
+ G = gram_matrix(rgb / std)
57
+ loss = loss + smooth_l1_loss(G, self.grams[i])
58
+
59
+ return loss
60
+
Time_TravelRephotography/losses/joint_loss.py ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from argparse import (
2
+ ArgumentParser,
3
+ Namespace,
4
+ )
5
+ from typing import (
6
+ Dict,
7
+ Iterable,
8
+ Optional,
9
+ Tuple,
10
+ )
11
+
12
+ import numpy as np
13
+ import torch
14
+ from torch import nn
15
+
16
+ from utils.misc import (
17
+ optional_string,
18
+ iterable_to_str,
19
+ )
20
+
21
+ from .contextual_loss import ContextualLoss
22
+ from .color_transfer_loss import ColorTransferLoss
23
+ from .regularize_noise import NoiseRegularizer
24
+ from .reconstruction import (
25
+ EyeLoss,
26
+ FaceLoss,
27
+ create_perceptual_loss,
28
+ ReconstructionArguments,
29
+ )
30
+
31
+ class LossArguments:
32
+ @staticmethod
33
+ def add_arguments(parser: ArgumentParser):
34
+ ReconstructionArguments.add_arguments(parser)
35
+
36
+ parser.add_argument("--color_transfer", type=float, default=1e10, help="color transfer loss weight")
37
+ parser.add_argument("--eye", type=float, default=0.1, help="eye loss weight")
38
+ parser.add_argument('--noise_regularize', type=float, default=5e4)
39
+ # contextual loss
40
+ parser.add_argument("--contextual", type=float, default=0.1, help="contextual loss weight")
41
+ parser.add_argument("--cx_layers", nargs='*', help="contextual loss layers",
42
+ choices=['relu1_2', 'relu2_2', 'relu3_4', 'relu4_4', 'relu5_4'],
43
+ default=['relu3_4', 'relu2_2', 'relu1_2'])
44
+
45
+ @staticmethod
46
+ def to_string(args: Namespace) -> str:
47
+ return (
48
+ ReconstructionArguments.to_string(args)
49
+ + optional_string(args.eye > 0, f"-eye{args.eye}")
50
+ + optional_string(args.color_transfer, f"-color{args.color_transfer:.1e}")
51
+ + optional_string(
52
+ args.contextual,
53
+ f"-cx{args.contextual}({iterable_to_str(args.cx_layers)})"
54
+ )
55
+ #+ optional_string(args.mse, f"-mse{args.mse}")
56
+ + optional_string(args.noise_regularize, f"-NR{args.noise_regularize:.1e}")
57
+ )
58
+
59
+
60
+ class BakedMultiContextualLoss(nn.Module):
61
+ """Random sample different image patches for different vgg layers."""
62
+ def __init__(self, sibling: torch.Tensor, args: Namespace, size: int = 256):
63
+ super().__init__()
64
+
65
+ self.cxs = nn.ModuleList([ContextualLoss(use_vgg=True, vgg_layers=[layer])
66
+ for layer in args.cx_layers])
67
+ self.size = size
68
+ self.sibling = sibling.detach()
69
+
70
+ def forward(self, img: torch.Tensor):
71
+ cx_loss = 0
72
+ for cx in self.cxs:
73
+ h, w = np.random.randint(0, high=img.shape[-1] - self.size, size=2)
74
+ cx_loss = cx(self.sibling[..., h:h+self.size, w:w+self.size], img[..., h:h+self.size, w:w+self.size]) + cx_loss
75
+ return cx_loss
76
+
77
+
78
+ class BakedContextualLoss(ContextualLoss):
79
+ def __init__(self, sibling: torch.Tensor, args: Namespace, size: int = 256):
80
+ super().__init__(use_vgg=True, vgg_layers=args.cx_layers)
81
+ self.size = size
82
+ self.sibling = sibling.detach()
83
+
84
+ def forward(self, img: torch.Tensor):
85
+ h, w = np.random.randint(0, high=img.shape[-1] - self.size, size=2)
86
+ return super().forward(self.sibling[..., h:h+self.size, w:w+self.size], img[..., h:h+self.size, w:w+self.size])
87
+
88
+
89
+ class JointLoss(nn.Module):
90
+ def __init__(
91
+ self,
92
+ args: Namespace,
93
+ target: torch.Tensor,
94
+ sibling: Optional[torch.Tensor],
95
+ sibling_rgbs: Optional[Iterable[torch.Tensor]] = None,
96
+ ):
97
+ super().__init__()
98
+
99
+ self.weights = {
100
+ "face": 1., "eye": args.eye,
101
+ "contextual": args.contextual, "color_transfer": args.color_transfer,
102
+ "noise": args.noise_regularize,
103
+ }
104
+
105
+ reconstruction = {}
106
+ if args.vgg > 0 or args.vggface > 0:
107
+ percept = create_perceptual_loss(args)
108
+ reconstruction.update(
109
+ {"face": FaceLoss(target, input_size=args.generator_size, size=args.recon_size, percept=percept)}
110
+ )
111
+ if args.eye > 0:
112
+ reconstruction.update(
113
+ {"eye": EyeLoss(target, input_size=args.generator_size, percept=percept)}
114
+ )
115
+ self.reconstruction = nn.ModuleDict(reconstruction)
116
+
117
+ exemplar = {}
118
+ if args.contextual > 0 and len(args.cx_layers) > 0:
119
+ assert sibling is not None
120
+ exemplar.update(
121
+ {"contextual": BakedContextualLoss(sibling, args)}
122
+ )
123
+ if args.color_transfer > 0:
124
+ assert sibling_rgbs is not None
125
+ self.sibling_rgbs = sibling_rgbs
126
+ exemplar.update(
127
+ {"color_transfer": ColorTransferLoss(init_rgbs=sibling_rgbs)}
128
+ )
129
+ self.exemplar = nn.ModuleDict(exemplar)
130
+
131
+ if args.noise_regularize > 0:
132
+ self.noise_criterion = NoiseRegularizer()
133
+
134
+ def forward(
135
+ self, img, degrade=None, noises=None, rgbs=None, rgb_level: Optional[int] = None
136
+ ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
137
+ """
138
+ Args:
139
+ rgbs: results from the ToRGB layers
140
+ """
141
+ # TODO: add current optimization resolution for noises
142
+
143
+ losses = {}
144
+
145
+ # reconstruction losses
146
+ for name, criterion in self.reconstruction.items():
147
+ losses[name] = criterion(img, degrade=degrade)
148
+
149
+ # exemplar losses
150
+ if 'contextual' in self.exemplar:
151
+ losses["contextual"] = self.exemplar["contextual"](img)
152
+ if "color_transfer" in self.exemplar:
153
+ assert rgbs is not None
154
+ losses["color_transfer"] = self.exemplar["color_transfer"](rgbs, level=rgb_level)
155
+
156
+ # noise regularizer
157
+ if self.weights["noise"] > 0:
158
+ losses["noise"] = self.noise_criterion(noises)
159
+
160
+ total_loss = 0
161
+ for name, loss in losses.items():
162
+ total_loss = total_loss + self.weights[name] * loss
163
+ return total_loss, losses
164
+
165
+ def update_sibling(self, sibling: torch.Tensor):
166
+ assert "contextual" in self.exemplar
167
+ self.exemplar["contextual"].sibling = sibling.detach()
Time_TravelRephotography/losses/perceptual_loss.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Code borrowed from https://gist.github.com/alper111/8233cdb0414b4cb5853f2f730ab95a49#file-vgg_perceptual_loss-py-L5
3
+ """
4
+ import torch
5
+ import torchvision
6
+ from models.vggface import VGGFaceFeats
7
+
8
+
9
+ def cos_loss(fi, ft):
10
+ return 1 - torch.nn.functional.cosine_similarity(fi, ft).mean()
11
+
12
+
13
+ class VGGPerceptualLoss(torch.nn.Module):
14
+ def __init__(self, resize=False):
15
+ super(VGGPerceptualLoss, self).__init__()
16
+ blocks = []
17
+ blocks.append(torchvision.models.vgg16(pretrained=True).features[:4].eval())
18
+ blocks.append(torchvision.models.vgg16(pretrained=True).features[4:9].eval())
19
+ blocks.append(torchvision.models.vgg16(pretrained=True).features[9:16].eval())
20
+ blocks.append(torchvision.models.vgg16(pretrained=True).features[16:23].eval())
21
+ for bl in blocks:
22
+ for p in bl:
23
+ p.requires_grad = False
24
+ self.blocks = torch.nn.ModuleList(blocks)
25
+ self.transform = torch.nn.functional.interpolate
26
+ self.mean = torch.nn.Parameter(torch.tensor([0.485, 0.456, 0.406]).view(1,3,1,1))
27
+ self.std = torch.nn.Parameter(torch.tensor([0.229, 0.224, 0.225]).view(1,3,1,1))
28
+ self.resize = resize
29
+
30
+ def forward(self, input, target, max_layer=4, cos_dist: bool = False):
31
+ target = (target + 1) * 0.5
32
+ input = (input + 1) * 0.5
33
+
34
+ if input.shape[1] != 3:
35
+ input = input.repeat(1, 3, 1, 1)
36
+ target = target.repeat(1, 3, 1, 1)
37
+ input = (input-self.mean) / self.std
38
+ target = (target-self.mean) / self.std
39
+ if self.resize:
40
+ input = self.transform(input, mode='bilinear', size=(224, 224), align_corners=False)
41
+ target = self.transform(target, mode='bilinear', size=(224, 224), align_corners=False)
42
+ x = input
43
+ y = target
44
+ loss = 0.0
45
+ loss_func = cos_loss if cos_dist else torch.nn.functional.l1_loss
46
+ for bi, block in enumerate(self.blocks[:max_layer]):
47
+ x = block(x)
48
+ y = block(y)
49
+ loss += loss_func(x, y.detach())
50
+ return loss
51
+
52
+
53
+ class VGGFacePerceptualLoss(torch.nn.Module):
54
+ def __init__(self, weight_path: str = "checkpoint/vgg_face_dag.pt", resize: bool = False):
55
+ super().__init__()
56
+ self.vgg = VGGFaceFeats()
57
+ self.vgg.load_state_dict(torch.load(weight_path))
58
+
59
+ mean = torch.tensor(self.vgg.meta["mean"]).view(1, 3, 1, 1) / 255.0
60
+ self.register_buffer("mean", mean)
61
+
62
+ self.transform = torch.nn.functional.interpolate
63
+ self.resize = resize
64
+
65
+ def forward(self, input, target, max_layer: int = 4, cos_dist: bool = False):
66
+ target = (target + 1) * 0.5
67
+ input = (input + 1) * 0.5
68
+
69
+ # preprocessing
70
+ if input.shape[1] != 3:
71
+ input = input.repeat(1, 3, 1, 1)
72
+ target = target.repeat(1, 3, 1, 1)
73
+ input = input - self.mean
74
+ target = target - self.mean
75
+ if self.resize:
76
+ input = self.transform(input, mode='bilinear', size=(224, 224), align_corners=False)
77
+ target = self.transform(target, mode='bilinear', size=(224, 224), align_corners=False)
78
+
79
+ input_feats = self.vgg(input)
80
+ target_feats = self.vgg(target)
81
+
82
+ loss_func = cos_loss if cos_dist else torch.nn.functional.l1_loss
83
+ # calc perceptual loss
84
+ loss = 0.0
85
+ for fi, ft in zip(input_feats[:max_layer], target_feats[:max_layer]):
86
+ loss = loss + loss_func(fi, ft.detach())
87
+ return loss
88
+
89
+
90
+ class PerceptualLoss(torch.nn.Module):
91
+ def __init__(
92
+ self, lambda_vggface: float = 0.025 / 0.15, lambda_vgg: float = 1, eps: float = 1e-8, cos_dist: bool = False
93
+ ):
94
+ super().__init__()
95
+ self.register_buffer("lambda_vggface", torch.tensor(lambda_vggface))
96
+ self.register_buffer("lambda_vgg", torch.tensor(lambda_vgg))
97
+ self.cos_dist = cos_dist
98
+
99
+ if lambda_vgg > eps:
100
+ self.vgg = VGGPerceptualLoss()
101
+ if lambda_vggface > eps:
102
+ self.vggface = VGGFacePerceptualLoss()
103
+
104
+ def forward(self, input, target, eps=1e-8, use_vggface: bool = True, use_vgg=True, max_vgg_layer=4):
105
+ loss = 0.0
106
+ if self.lambda_vgg > eps and use_vgg:
107
+ loss = loss + self.lambda_vgg * self.vgg(input, target, max_layer=max_vgg_layer)
108
+ if self.lambda_vggface > eps and use_vggface:
109
+ loss = loss + self.lambda_vggface * self.vggface(input, target, cos_dist=self.cos_dist)
110
+ return loss
111
+
Time_TravelRephotography/losses/reconstruction.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from argparse import (
2
+ ArgumentParser,
3
+ Namespace,
4
+ )
5
+ from typing import Optional
6
+
7
+ import numpy as np
8
+ import torch
9
+ from torch import nn
10
+
11
+ from losses.perceptual_loss import PerceptualLoss
12
+ from models.degrade import Downsample
13
+ from utils.misc import optional_string
14
+
15
+
16
+ class ReconstructionArguments:
17
+ @staticmethod
18
+ def add_arguments(parser: ArgumentParser):
19
+ parser.add_argument("--vggface", type=float, default=0.3, help="vggface")
20
+ parser.add_argument("--vgg", type=float, default=1, help="vgg")
21
+ parser.add_argument('--recon_size', type=int, default=256, help="size for face reconstruction loss")
22
+
23
+ @staticmethod
24
+ def to_string(args: Namespace) -> str:
25
+ return (
26
+ f"s{args.recon_size}"
27
+ + optional_string(args.vgg > 0, f"-vgg{args.vgg}")
28
+ + optional_string(args.vggface > 0, f"-vggface{args.vggface}")
29
+ )
30
+
31
+
32
+ def create_perceptual_loss(args: Namespace):
33
+ return PerceptualLoss(lambda_vgg=args.vgg, lambda_vggface=args.vggface, cos_dist=False)
34
+
35
+
36
+ class EyeLoss(nn.Module):
37
+ def __init__(
38
+ self,
39
+ target: torch.Tensor,
40
+ input_size: int = 1024,
41
+ input_channels: int = 3,
42
+ percept: Optional[nn.Module] = None,
43
+ args: Optional[Namespace] = None
44
+ ):
45
+ """
46
+ target: target image
47
+ """
48
+ assert not (percept is None and args is None)
49
+
50
+ super().__init__()
51
+
52
+ self.target = target
53
+
54
+ target_size = target.shape[-1]
55
+ self.downsample = Downsample(input_size, target_size, input_channels) \
56
+ if target_size != input_size else (lambda x: x)
57
+
58
+ self.percept = percept if percept is not None else create_perceptual_loss(args)
59
+
60
+ eye_size = np.array((224, 224))
61
+ btlrs = []
62
+ for sgn in [1, -1]:
63
+ center = np.array((480, 384 * sgn)) # (y, x)
64
+ b, t = center[0] - eye_size[0] // 2, center[0] + eye_size[0] // 2
65
+ l, r = center[1] - eye_size[1] // 2, center[1] + eye_size[1] // 2
66
+ btlrs.append((np.array((b, t, l, r)) / 1024 * target_size).astype(int))
67
+ self.btlrs = np.stack(btlrs, axis=0)
68
+
69
+ def forward(self, img: torch.Tensor, degrade: nn.Module = None):
70
+ """
71
+ img: it should be the degraded version of the generated image
72
+ """
73
+ if degrade is not None:
74
+ img = degrade(img, downsample=self.downsample)
75
+
76
+ loss = 0
77
+ for (b, t, l, r) in self.btlrs:
78
+ loss = loss + self.percept(
79
+ img[:, :, b:t, l:r], self.target[:, :, b:t, l:r],
80
+ use_vggface=False, max_vgg_layer=4,
81
+ # use_vgg=False,
82
+ )
83
+ return loss
84
+
85
+
86
+ class FaceLoss(nn.Module):
87
+ def __init__(
88
+ self,
89
+ target: torch.Tensor,
90
+ input_size: int = 1024,
91
+ input_channels: int = 3,
92
+ size: int = 256,
93
+ percept: Optional[nn.Module] = None,
94
+ args: Optional[Namespace] = None
95
+ ):
96
+ """
97
+ target: target image
98
+ """
99
+ assert not (percept is None and args is None)
100
+
101
+ super().__init__()
102
+
103
+ target_size = target.shape[-1]
104
+ self.target = target if target_size == size \
105
+ else Downsample(target_size, size, target.shape[1]).to(target.device)(target)
106
+
107
+ self.downsample = Downsample(input_size, size, input_channels) \
108
+ if size != input_size else (lambda x: x)
109
+
110
+ self.percept = percept if percept is not None else create_perceptual_loss(args)
111
+
112
+ def forward(self, img: torch.Tensor, degrade: nn.Module = None):
113
+ """
114
+ img: it should be the degraded version of the generated image
115
+ """
116
+ if degrade is not None:
117
+ img = degrade(img, downsample=self.downsample)
118
+ loss = self.percept(img, self.target)
119
+ return loss
Time_TravelRephotography/losses/regularize_noise.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Iterable
2
+
3
+ import torch
4
+ from torch import nn
5
+
6
+
7
+ class NoiseRegularizer(nn.Module):
8
+ def forward(self, noises: Iterable[torch.Tensor]):
9
+ loss = 0
10
+
11
+ for noise in noises:
12
+ size = noise.shape[2]
13
+
14
+ while True:
15
+ loss = (
16
+ loss
17
+ + (noise * torch.roll(noise, shifts=1, dims=3)).mean().pow(2)
18
+ + (noise * torch.roll(noise, shifts=1, dims=2)).mean().pow(2)
19
+ )
20
+
21
+ if size <= 8:
22
+ break
23
+
24
+ noise = noise.reshape([1, 1, size // 2, 2, size // 2, 2])
25
+ noise = noise.mean([3, 5])
26
+ size //= 2
27
+
28
+ return loss
29
+
30
+ @staticmethod
31
+ def normalize(noises: Iterable[torch.Tensor]):
32
+ for noise in noises:
33
+ mean = noise.mean()
34
+ std = noise.std()
35
+
36
+ noise.data.add_(-mean).div_(std)
37
+
Time_TravelRephotography/model.py ADDED
@@ -0,0 +1,697 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import random
3
+ import functools
4
+ import operator
5
+ import numpy as np
6
+
7
+ import torch
8
+ from torch import nn
9
+ from torch.nn import functional as F
10
+ from torch.autograd import Function
11
+
12
+ from op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d
13
+
14
+
15
+ class PixelNorm(nn.Module):
16
+ def __init__(self):
17
+ super().__init__()
18
+
19
+ def forward(self, input):
20
+ return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)
21
+
22
+
23
+ def make_kernel(k):
24
+ k = torch.tensor(k, dtype=torch.float32)
25
+
26
+ if k.ndim == 1:
27
+ k = k[None, :] * k[:, None]
28
+
29
+ k /= k.sum()
30
+
31
+ return k
32
+
33
+
34
+ class Upsample(nn.Module):
35
+ def __init__(self, kernel, factor=2):
36
+ super().__init__()
37
+
38
+ self.factor = factor
39
+ kernel = make_kernel(kernel) * (factor ** 2)
40
+ self.register_buffer('kernel', kernel)
41
+
42
+ p = kernel.shape[0] - factor
43
+
44
+ pad0 = (p + 1) // 2 + factor - 1
45
+ pad1 = p // 2
46
+
47
+ self.pad = (pad0, pad1)
48
+
49
+ def forward(self, input):
50
+ out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)
51
+
52
+ return out
53
+
54
+
55
+ class Downsample(nn.Module):
56
+ def __init__(self, kernel, factor=2):
57
+ super().__init__()
58
+
59
+ self.factor = factor
60
+ kernel = make_kernel(kernel)
61
+ self.register_buffer('kernel', kernel)
62
+
63
+ p = kernel.shape[0] - factor
64
+
65
+ pad0 = (p + 1) // 2
66
+ pad1 = p // 2
67
+
68
+ self.pad = (pad0, pad1)
69
+
70
+ def forward(self, input):
71
+ out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad)
72
+
73
+ return out
74
+
75
+
76
+ class Blur(nn.Module):
77
+ def __init__(self, kernel, pad, upsample_factor=1):
78
+ super().__init__()
79
+
80
+ kernel = make_kernel(kernel)
81
+
82
+ if upsample_factor > 1:
83
+ kernel = kernel * (upsample_factor ** 2)
84
+
85
+ self.register_buffer('kernel', kernel)
86
+
87
+ self.pad = pad
88
+
89
+ def forward(self, input):
90
+ out = upfirdn2d(input, self.kernel, pad=self.pad)
91
+
92
+ return out
93
+
94
+
95
+ class EqualConv2d(nn.Module):
96
+ def __init__(
97
+ self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True
98
+ ):
99
+ super().__init__()
100
+
101
+ self.weight = nn.Parameter(
102
+ torch.randn(out_channel, in_channel, kernel_size, kernel_size)
103
+ )
104
+ self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)
105
+
106
+ self.stride = stride
107
+ self.padding = padding
108
+
109
+ if bias:
110
+ self.bias = nn.Parameter(torch.zeros(out_channel))
111
+
112
+ else:
113
+ self.bias = None
114
+
115
+ def forward(self, input):
116
+ out = F.conv2d(
117
+ input,
118
+ self.weight * self.scale,
119
+ bias=self.bias,
120
+ stride=self.stride,
121
+ padding=self.padding,
122
+ )
123
+
124
+ return out
125
+
126
+ def __repr__(self):
127
+ return (
128
+ f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},'
129
+ f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})'
130
+ )
131
+
132
+
133
+ class EqualLinear(nn.Module):
134
+ def __init__(
135
+ self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None
136
+ ):
137
+ super().__init__()
138
+
139
+ self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
140
+
141
+ if bias:
142
+ self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
143
+
144
+ else:
145
+ self.bias = None
146
+
147
+ self.activation = activation
148
+
149
+ self.scale = (1 / math.sqrt(in_dim)) * lr_mul
150
+ self.lr_mul = lr_mul
151
+
152
+ def forward(self, input):
153
+ if self.activation:
154
+ out = F.linear(input, self.weight * self.scale)
155
+ out = fused_leaky_relu(out, self.bias * self.lr_mul)
156
+
157
+ else:
158
+ out = F.linear(
159
+ input, self.weight * self.scale, bias=self.bias * self.lr_mul
160
+ )
161
+
162
+ return out
163
+
164
+ def __repr__(self):
165
+ return (
166
+ f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})'
167
+ )
168
+
169
+
170
+ class ScaledLeakyReLU(nn.Module):
171
+ def __init__(self, negative_slope=0.2):
172
+ super().__init__()
173
+
174
+ self.negative_slope = negative_slope
175
+
176
+ def forward(self, input):
177
+ out = F.leaky_relu(input, negative_slope=self.negative_slope)
178
+
179
+ return out * math.sqrt(2)
180
+
181
+
182
+ class ModulatedConv2d(nn.Module):
183
+ def __init__(
184
+ self,
185
+ in_channel,
186
+ out_channel,
187
+ kernel_size,
188
+ style_dim,
189
+ demodulate=True,
190
+ upsample=False,
191
+ downsample=False,
192
+ blur_kernel=[1, 3, 3, 1],
193
+ ):
194
+ super().__init__()
195
+
196
+ self.eps = 1e-8
197
+ self.kernel_size = kernel_size
198
+ self.in_channel = in_channel
199
+ self.out_channel = out_channel
200
+ self.upsample = upsample
201
+ self.downsample = downsample
202
+
203
+ if upsample:
204
+ factor = 2
205
+ p = (len(blur_kernel) - factor) - (kernel_size - 1)
206
+ pad0 = (p + 1) // 2 + factor - 1
207
+ pad1 = p // 2 + 1
208
+
209
+ self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor)
210
+
211
+ if downsample:
212
+ factor = 2
213
+ p = (len(blur_kernel) - factor) + (kernel_size - 1)
214
+ pad0 = (p + 1) // 2
215
+ pad1 = p // 2
216
+
217
+ self.blur = Blur(blur_kernel, pad=(pad0, pad1))
218
+
219
+ fan_in = in_channel * kernel_size ** 2
220
+ self.scale = 1 / math.sqrt(fan_in)
221
+ self.padding = kernel_size // 2
222
+
223
+ self.weight = nn.Parameter(
224
+ torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)
225
+ )
226
+
227
+ self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)
228
+
229
+ self.demodulate = demodulate
230
+
231
+ def __repr__(self):
232
+ return (
233
+ f'{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, '
234
+ f'upsample={self.upsample}, downsample={self.downsample})'
235
+ )
236
+
237
+ def forward(self, input, style):
238
+ batch, in_channel, height, width = input.shape
239
+
240
+ style = self.modulation(style).view(batch, 1, in_channel, 1, 1)
241
+ weight = self.scale * self.weight * style
242
+
243
+ if self.demodulate:
244
+ demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
245
+ weight = weight * demod.view(batch, self.out_channel, 1, 1, 1)
246
+
247
+ weight = weight.view(
248
+ batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size
249
+ )
250
+
251
+ if self.upsample:
252
+ input = input.view(1, batch * in_channel, height, width)
253
+ weight = weight.view(
254
+ batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size
255
+ )
256
+ weight = weight.transpose(1, 2).reshape(
257
+ batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size
258
+ )
259
+ out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch)
260
+ _, _, height, width = out.shape
261
+ out = out.view(batch, self.out_channel, height, width)
262
+ out = self.blur(out)
263
+
264
+ elif self.downsample:
265
+ input = self.blur(input)
266
+ _, _, height, width = input.shape
267
+ input = input.view(1, batch * in_channel, height, width)
268
+ out = F.conv2d(input, weight, padding=0, stride=2, groups=batch)
269
+ _, _, height, width = out.shape
270
+ out = out.view(batch, self.out_channel, height, width)
271
+
272
+ else:
273
+ input = input.view(1, batch * in_channel, height, width)
274
+ out = F.conv2d(input, weight, padding=self.padding, groups=batch)
275
+ _, _, height, width = out.shape
276
+ out = out.view(batch, self.out_channel, height, width)
277
+
278
+ return out
279
+
280
+
281
+ class NoiseInjection(nn.Module):
282
+ def __init__(self):
283
+ super().__init__()
284
+
285
+ self.weight = nn.Parameter(torch.zeros(1))
286
+
287
+ def forward(self, image, noise=None):
288
+ if noise is None:
289
+ batch, _, height, width = image.shape
290
+ noise = image.new_empty(batch, 1, height, width).normal_()
291
+
292
+ return image + self.weight * noise
293
+
294
+
295
+ class ConstantInput(nn.Module):
296
+ def __init__(self, channel, size=4):
297
+ super().__init__()
298
+
299
+ self.input = nn.Parameter(torch.randn(1, channel, size, size))
300
+
301
+ def forward(self, input):
302
+ batch = input.shape[0]
303
+ out = self.input.repeat(batch, 1, 1, 1)
304
+
305
+ return out
306
+
307
+
308
+ class StyledConv(nn.Module):
309
+ def __init__(
310
+ self,
311
+ in_channel,
312
+ out_channel,
313
+ kernel_size,
314
+ style_dim,
315
+ upsample=False,
316
+ blur_kernel=[1, 3, 3, 1],
317
+ demodulate=True,
318
+ ):
319
+ super().__init__()
320
+
321
+ self.conv = ModulatedConv2d(
322
+ in_channel,
323
+ out_channel,
324
+ kernel_size,
325
+ style_dim,
326
+ upsample=upsample,
327
+ blur_kernel=blur_kernel,
328
+ demodulate=demodulate,
329
+ )
330
+
331
+ self.noise = NoiseInjection()
332
+ # self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1))
333
+ # self.activate = ScaledLeakyReLU(0.2)
334
+ self.activate = FusedLeakyReLU(out_channel)
335
+
336
+ def forward(self, input, style, noise=None):
337
+ out = self.conv(input, style)
338
+ out = self.noise(out, noise=noise)
339
+ # out = out + self.bias
340
+ out = self.activate(out)
341
+
342
+ return out
343
+
344
+
345
+ class ToRGB(nn.Module):
346
+ def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]):
347
+ super().__init__()
348
+
349
+ if upsample:
350
+ self.upsample = Upsample(blur_kernel)
351
+
352
+ self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False)
353
+ self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
354
+
355
+ def forward(self, input, style, skip=None):
356
+ out = self.conv(input, style)
357
+ style_modulated = out
358
+ out = out + self.bias
359
+
360
+ if skip is not None:
361
+ skip = self.upsample(skip)
362
+
363
+ out = out + skip
364
+
365
+ return out, style_modulated
366
+
367
+
368
+ class Generator(nn.Module):
369
+ def __init__(
370
+ self,
371
+ size,
372
+ style_dim,
373
+ n_mlp,
374
+ channel_multiplier=2,
375
+ blur_kernel=[1, 3, 3, 1],
376
+ lr_mlp=0.01,
377
+ ):
378
+ super().__init__()
379
+
380
+ self.size = size
381
+
382
+ self.style_dim = style_dim
383
+
384
+ layers = [PixelNorm()]
385
+
386
+ for i in range(n_mlp):
387
+ layers.append(
388
+ EqualLinear(
389
+ style_dim, style_dim, lr_mul=lr_mlp, activation='fused_lrelu'
390
+ )
391
+ )
392
+
393
+ self.style = nn.Sequential(*layers)
394
+
395
+ self.channels = {
396
+ 4: 512,
397
+ 8: 512,
398
+ 16: 512,
399
+ 32: 512,
400
+ 64: 256 * channel_multiplier,
401
+ 128: 128 * channel_multiplier,
402
+ 256: 64 * channel_multiplier,
403
+ 512: 32 * channel_multiplier,
404
+ 1024: 16 * channel_multiplier,
405
+ }
406
+
407
+ self.input = ConstantInput(self.channels[4])
408
+ self.conv1 = StyledConv(
409
+ self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel
410
+ )
411
+ self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False)
412
+
413
+ self.log_size = int(math.log(size, 2))
414
+ self.num_layers = (self.log_size - 2) * 2 + 1
415
+
416
+ self.convs = nn.ModuleList()
417
+ self.upsamples = nn.ModuleList()
418
+ self.to_rgbs = nn.ModuleList()
419
+ self.noises = nn.Module()
420
+
421
+ in_channel = self.channels[4]
422
+
423
+ for layer_idx in range(self.num_layers):
424
+ res = (layer_idx + 5) // 2
425
+ shape = [1, 1, 2 ** res, 2 ** res]
426
+ self.noises.register_buffer(f'noise_{layer_idx}', torch.randn(*shape))
427
+
428
+ for i in range(3, self.log_size + 1):
429
+ out_channel = self.channels[2 ** i]
430
+
431
+ self.convs.append(
432
+ StyledConv(
433
+ in_channel,
434
+ out_channel,
435
+ 3,
436
+ style_dim,
437
+ upsample=True,
438
+ blur_kernel=blur_kernel,
439
+ )
440
+ )
441
+
442
+ self.convs.append(
443
+ StyledConv(
444
+ out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel
445
+ )
446
+ )
447
+
448
+ self.to_rgbs.append(ToRGB(out_channel, style_dim))
449
+
450
+ in_channel = out_channel
451
+
452
+ self.n_latent = self.log_size * 2 - 2
453
+
454
+ @property
455
+ def device(self):
456
+ # TODO if multi-gpu is expected, could use the following more expensive version
457
+ #device, = list(set(p.device for p in self.parameters()))
458
+ return next(self.parameters()).device
459
+
460
+ @staticmethod
461
+ def get_latent_size(size):
462
+ log_size = int(math.log(size, 2))
463
+ return log_size * 2 - 2
464
+
465
+ @staticmethod
466
+ def make_noise_by_size(size: int, device: torch.device):
467
+ log_size = int(math.log(size, 2))
468
+ noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)]
469
+
470
+ for i in range(3, log_size + 1):
471
+ for _ in range(2):
472
+ noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device))
473
+
474
+ return noises
475
+
476
+
477
+ def make_noise(self):
478
+ return self.make_noise_by_size(self.size, self.input.input.device)
479
+
480
+ def mean_latent(self, n_latent):
481
+ latent_in = torch.randn(
482
+ n_latent, self.style_dim, device=self.input.input.device
483
+ )
484
+ latent = self.style(latent_in).mean(0, keepdim=True)
485
+
486
+ return latent
487
+
488
+ def get_latent(self, input):
489
+ return self.style(input)
490
+
491
+ def forward(
492
+ self,
493
+ styles,
494
+ return_latents=False,
495
+ inject_index=None,
496
+ truncation=1,
497
+ truncation_latent=None,
498
+ input_is_latent=False,
499
+ noise=None,
500
+ randomize_noise=True,
501
+ ):
502
+ if not input_is_latent:
503
+ styles = [self.style(s) for s in styles]
504
+
505
+ if noise is None:
506
+ if randomize_noise:
507
+ noise = [None] * self.num_layers
508
+ else:
509
+ noise = [
510
+ getattr(self.noises, f'noise_{i}') for i in range(self.num_layers)
511
+ ]
512
+
513
+ if truncation < 1:
514
+ style_t = []
515
+
516
+ for style in styles:
517
+ style_t.append(
518
+ truncation_latent + truncation * (style - truncation_latent)
519
+ )
520
+
521
+ styles = style_t
522
+
523
+ if len(styles) < 2:
524
+ inject_index = self.n_latent
525
+
526
+ if styles[0].ndim < 3:
527
+ latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
528
+
529
+ else:
530
+ latent = styles[0]
531
+
532
+ else:
533
+ if inject_index is None:
534
+ inject_index = random.randint(1, self.n_latent - 1)
535
+
536
+ latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
537
+ latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1)
538
+
539
+ latent = torch.cat([latent, latent2], 1)
540
+
541
+ out = self.input(latent)
542
+ out = self.conv1(out, latent[:, 0], noise=noise[0])
543
+
544
+ skip, rgb_mod = self.to_rgb1(out, latent[:, 1])
545
+
546
+
547
+ rgbs = [rgb_mod] # all but the last skip
548
+ i = 1
549
+ for conv1, conv2, noise1, noise2, to_rgb in zip(
550
+ self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
551
+ ):
552
+ out = conv1(out, latent[:, i], noise=noise1)
553
+ out = conv2(out, latent[:, i + 1], noise=noise2)
554
+ skip, rgb_mod = to_rgb(out, latent[:, i + 2], skip)
555
+ rgbs.append(rgb_mod)
556
+
557
+ i += 2
558
+
559
+ image = skip
560
+
561
+ if return_latents:
562
+ return image, latent, rgbs
563
+
564
+ else:
565
+ return image, None, rgbs
566
+
567
+
568
+ class ConvLayer(nn.Sequential):
569
+ def __init__(
570
+ self,
571
+ in_channel,
572
+ out_channel,
573
+ kernel_size,
574
+ downsample=False,
575
+ blur_kernel=[1, 3, 3, 1],
576
+ bias=True,
577
+ activate=True,
578
+ ):
579
+ layers = []
580
+
581
+ if downsample:
582
+ factor = 2
583
+ p = (len(blur_kernel) - factor) + (kernel_size - 1)
584
+ pad0 = (p + 1) // 2
585
+ pad1 = p // 2
586
+
587
+ layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
588
+
589
+ stride = 2
590
+ self.padding = 0
591
+
592
+ else:
593
+ stride = 1
594
+ self.padding = kernel_size // 2
595
+
596
+ layers.append(
597
+ EqualConv2d(
598
+ in_channel,
599
+ out_channel,
600
+ kernel_size,
601
+ padding=self.padding,
602
+ stride=stride,
603
+ bias=bias and not activate,
604
+ )
605
+ )
606
+
607
+ if activate:
608
+ if bias:
609
+ layers.append(FusedLeakyReLU(out_channel))
610
+
611
+ else:
612
+ layers.append(ScaledLeakyReLU(0.2))
613
+
614
+ super().__init__(*layers)
615
+
616
+
617
+ class ResBlock(nn.Module):
618
+ def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
619
+ super().__init__()
620
+
621
+ self.conv1 = ConvLayer(in_channel, in_channel, 3)
622
+ self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)
623
+
624
+ self.skip = ConvLayer(
625
+ in_channel, out_channel, 1, downsample=True, activate=False, bias=False
626
+ )
627
+
628
+ def forward(self, input):
629
+ out = self.conv1(input)
630
+ out = self.conv2(out)
631
+
632
+ skip = self.skip(input)
633
+ out = (out + skip) / math.sqrt(2)
634
+
635
+ return out
636
+
637
+
638
+ class Discriminator(nn.Module):
639
+ def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]):
640
+ super().__init__()
641
+
642
+ channels = {
643
+ 4: 512,
644
+ 8: 512,
645
+ 16: 512,
646
+ 32: 512,
647
+ 64: 256 * channel_multiplier,
648
+ 128: 128 * channel_multiplier,
649
+ 256: 64 * channel_multiplier,
650
+ 512: 32 * channel_multiplier,
651
+ 1024: 16 * channel_multiplier,
652
+ }
653
+
654
+ convs = [ConvLayer(3, channels[size], 1)]
655
+
656
+ log_size = int(math.log(size, 2))
657
+
658
+ in_channel = channels[size]
659
+
660
+ for i in range(log_size, 2, -1):
661
+ out_channel = channels[2 ** (i - 1)]
662
+
663
+ convs.append(ResBlock(in_channel, out_channel, blur_kernel))
664
+
665
+ in_channel = out_channel
666
+
667
+ self.convs = nn.Sequential(*convs)
668
+
669
+ self.stddev_group = 4
670
+ self.stddev_feat = 1
671
+
672
+ self.final_conv = ConvLayer(in_channel + 1, channels[4], 3)
673
+ self.final_linear = nn.Sequential(
674
+ EqualLinear(channels[4] * 4 * 4, channels[4], activation='fused_lrelu'),
675
+ EqualLinear(channels[4], 1),
676
+ )
677
+
678
+ def forward(self, input):
679
+ out = self.convs(input)
680
+
681
+ batch, channel, height, width = out.shape
682
+ group = min(batch, self.stddev_group)
683
+ stddev = out.view(
684
+ group, -1, self.stddev_feat, channel // self.stddev_feat, height, width
685
+ )
686
+ stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
687
+ stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
688
+ stddev = stddev.repeat(group, 1, height, width)
689
+ out = torch.cat([out, stddev], 1)
690
+
691
+ out = self.final_conv(out)
692
+
693
+ out = out.view(batch, -1)
694
+ out = self.final_linear(out)
695
+
696
+ return out
697
+
Time_TravelRephotography/models/__init__.py ADDED
File without changes
Time_TravelRephotography/models/degrade.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from argparse import (
2
+ ArgumentParser,
3
+ Namespace,
4
+ )
5
+
6
+ import torch
7
+ from torch import nn
8
+ from torch.nn import functional as F
9
+
10
+ from utils.misc import optional_string
11
+
12
+ from .gaussian_smoothing import GaussianSmoothing
13
+
14
+
15
+ class DegradeArguments:
16
+ @staticmethod
17
+ def add_arguments(parser: ArgumentParser):
18
+ parser.add_argument('--spectral_sensitivity', choices=["g", "b", "gb"], default="g",
19
+ help="Type of spectral sensitivity. g: grayscale (panchromatic), b: blue-sensitive, gb: green+blue (orthochromatic)")
20
+ parser.add_argument('--gaussian', type=float, default=0,
21
+ help="estimated blur radius in pixels of the input photo if it is scaled to 1024x1024")
22
+
23
+ @staticmethod
24
+ def to_string(args: Namespace) -> str:
25
+ return (
26
+ f"{args.spectral_sensitivity}"
27
+ + optional_string(args.gaussian > 0, f"-G{args.gaussian}")
28
+ )
29
+
30
+
31
+ class CameraResponse(nn.Module):
32
+ def __init__(self):
33
+ super().__init__()
34
+
35
+ self.register_parameter("gamma", nn.Parameter(torch.ones(1)))
36
+ self.register_parameter("offset", nn.Parameter(torch.zeros(1)))
37
+ self.register_parameter("gain", nn.Parameter(torch.ones(1)))
38
+
39
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
40
+ x = torch.clamp(x, max=1, min=-1+1e-2)
41
+ x = (1 + x) * 0.5
42
+ x = self.offset + self.gain * torch.pow(x, self.gamma)
43
+ x = (x - 0.5) * 2
44
+ # b = torch.clamp(b, max=1, min=-1)
45
+ return x
46
+
47
+
48
+ class SpectralResponse(nn.Module):
49
+ # TODO: use enum instead for color mode
50
+ def __init__(self, spectral_sensitivity: str = 'b'):
51
+ assert spectral_sensitivity in ("g", "b", "gb"), f"spectral_sensitivity {spectral_sensitivity} is not implemented."
52
+
53
+ super().__init__()
54
+
55
+ self.spectral_sensitivity = spectral_sensitivity
56
+
57
+ if self.spectral_sensitivity == "g":
58
+ self.register_buffer("to_gray", torch.tensor([0.299, 0.587, 0.114]).reshape(1, -1, 1, 1))
59
+
60
+ def forward(self, rgb: torch.Tensor) -> torch.Tensor:
61
+ if self.spectral_sensitivity == "b":
62
+ x = rgb[:, -1:]
63
+ elif self.spectral_sensitivity == "gb":
64
+ x = (rgb[:, 1:2] + rgb[:, -1:]) * 0.5
65
+ else:
66
+ assert self.spectral_sensitivity == "g"
67
+ x = (rgb * self.to_gray).sum(dim=1, keepdim=True)
68
+ return x
69
+
70
+
71
+ class Downsample(nn.Module):
72
+ """Antialiasing downsampling"""
73
+ def __init__(self, input_size: int, output_size: int, channels: int):
74
+ super().__init__()
75
+ if input_size % output_size == 0:
76
+ self.stride = input_size // output_size
77
+ self.grid = None
78
+ else:
79
+ self.stride = 1
80
+ step = input_size / output_size
81
+ x = torch.arange(output_size) * step
82
+ Y, X = torch.meshgrid(x, x)
83
+ grid = torch.stack((X, Y), dim=-1)
84
+ grid /= torch.Tensor((input_size - 1, input_size - 1)).view(1, 1, -1)
85
+ grid = grid * 2 - 1
86
+ self.register_buffer("grid", grid)
87
+ sigma = 0.5 * input_size / output_size
88
+ #print(f"{input_size} -> {output_size}: sigma={sigma}")
89
+ self.blur = GaussianSmoothing(channels, int(2 * (sigma * 2) + 1 + 0.5), sigma)
90
+
91
+ def forward(self, im: torch.Tensor):
92
+ out = self.blur(im, stride=self.stride)
93
+ if self.grid is not None:
94
+ out = F.grid_sample(out, self.grid[None].expand(im.shape[0], -1, -1, -1))
95
+ return out
96
+
97
+
98
+
99
+ class Degrade(nn.Module):
100
+ """
101
+ Simulate the degradation of antique film
102
+ """
103
+ def __init__(self, args:Namespace):
104
+ super().__init__()
105
+ self.srf = SpectralResponse(args.spectral_sensitivity)
106
+ self.crf = CameraResponse()
107
+ self.gaussian = None
108
+ if args.gaussian is not None and args.gaussian > 0:
109
+ self.gaussian = GaussianSmoothing(3, 2 * int(args.gaussian * 2 + 0.5) + 1, args.gaussian)
110
+
111
+ def forward(self, img: torch.Tensor, downsample: nn.Module = None):
112
+ if self.gaussian is not None:
113
+ img = self.gaussian(img)
114
+ if downsample is not None:
115
+ img = downsample(img)
116
+ img = self.srf(img)
117
+ img = self.crf(img)
118
+ # Note that I changed it back to 3 channels
119
+ return img.repeat((1, 3, 1, 1)) if img.shape[1] == 1 else img
120
+
121
+
122
+
Time_TravelRephotography/models/encoder.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from argparse import Namespace, ArgumentParser
2
+ from functools import partial
3
+
4
+ from torch import nn
5
+
6
+ from .resnet import ResNetBasicBlock, activation_func, norm_module, Conv2dAuto
7
+
8
+
9
+ def add_arguments(parser: ArgumentParser) -> ArgumentParser:
10
+ parser.add_argument("--latent_size", type=int, default=512, help="latent size")
11
+ return parser
12
+
13
+
14
+ def create_model(args) -> nn.Module:
15
+ in_channels = 3 if "rgb" in args and args.rgb else 1
16
+ return Encoder(in_channels, args.encoder_size, latent_size=args.latent_size)
17
+
18
+
19
+ class Flatten(nn.Module):
20
+ def forward(self, input_):
21
+ return input_.view(input_.size(0), -1)
22
+
23
+
24
+ class Encoder(nn.Module):
25
+ def __init__(
26
+ self, in_channels: int, size: int, latent_size: int = 512,
27
+ activation: str = 'leaky_relu', norm: str = "instance"
28
+ ):
29
+ super().__init__()
30
+
31
+ out_channels0 = 64
32
+ norm_m = norm_module(norm)
33
+ self.conv0 = nn.Sequential(
34
+ Conv2dAuto(in_channels, out_channels0, kernel_size=5),
35
+ norm_m(out_channels0),
36
+ activation_func(activation),
37
+ )
38
+
39
+ pool_kernel = 2
40
+ self.pool = nn.AvgPool2d(pool_kernel)
41
+
42
+ num_channels = [128, 256, 512, 512]
43
+ # FIXME: this is a hack
44
+ if size >= 256:
45
+ num_channels.append(512)
46
+
47
+ residual = partial(ResNetBasicBlock, activation=activation, norm=norm, bias=True)
48
+ residual_blocks = nn.ModuleList()
49
+ for in_channel, out_channel in zip([out_channels0] + num_channels[:-1], num_channels):
50
+ residual_blocks.append(residual(in_channel, out_channel))
51
+ residual_blocks.append(nn.AvgPool2d(pool_kernel))
52
+ self.residual_blocks = nn.Sequential(*residual_blocks)
53
+
54
+ self.last = nn.Sequential(
55
+ nn.ReLU(),
56
+ nn.AvgPool2d(4), # TODO: not sure whehter this would cause problem
57
+ Flatten(),
58
+ nn.Linear(num_channels[-1], latent_size, bias=True)
59
+ )
60
+
61
+ def forward(self, input_):
62
+ out = self.conv0(input_)
63
+ out = self.pool(out)
64
+ out = self.residual_blocks(out)
65
+ out = self.last(out)
66
+ return out
Time_TravelRephotography/models/gaussian_smoothing.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import numbers
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import functional as F
6
+
7
+
8
+ class GaussianSmoothing(nn.Module):
9
+ """
10
+ Apply gaussian smoothing on a
11
+ 1d, 2d or 3d tensor. Filtering is performed seperately for each channel
12
+ in the input using a depthwise convolution.
13
+ Arguments:
14
+ channels (int, sequence): Number of channels of the input tensors. Output will
15
+ have this number of channels as well.
16
+ kernel_size (int, sequence): Size of the gaussian kernel.
17
+ sigma (float, sequence): Standard deviation of the gaussian kernel.
18
+ dim (int, optional): The number of dimensions of the data.
19
+ Default value is 2 (spatial).
20
+ """
21
+ def __init__(self, channels, kernel_size, sigma, dim=2):
22
+ super(GaussianSmoothing, self).__init__()
23
+ if isinstance(kernel_size, numbers.Number):
24
+ kernel_size = [kernel_size] * dim
25
+ if isinstance(sigma, numbers.Number):
26
+ sigma = [sigma] * dim
27
+
28
+ # The gaussian kernel is the product of the
29
+ # gaussian function of each dimension.
30
+ kernel = 1
31
+ meshgrids = torch.meshgrid(
32
+ [
33
+ torch.arange(size, dtype=torch.float32)
34
+ for size in kernel_size
35
+ ]
36
+ )
37
+ for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
38
+ mean = (size - 1) / 2
39
+ kernel *= 1 / (std * math.sqrt(2 * math.pi)) * \
40
+ torch.exp(-((mgrid - mean) / (2 * std)) ** 2)
41
+
42
+ # Make sure sum of values in gaussian kernel equals 1.
43
+ kernel = kernel / torch.sum(kernel)
44
+
45
+ # Reshape to depthwise convolutional weight
46
+ kernel = kernel.view(1, 1, *kernel.size())
47
+ kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))
48
+
49
+ self.register_buffer('weight', kernel)
50
+ self.groups = channels
51
+
52
+ if dim == 1:
53
+ self.conv = F.conv1d
54
+ elif dim == 2:
55
+ self.conv = F.conv2d
56
+ elif dim == 3:
57
+ self.conv = F.conv3d
58
+ else:
59
+ raise RuntimeError(
60
+ 'Only 1, 2 and 3 dimensions are supported. Received {}.'.format(dim)
61
+ )
62
+
63
+ def forward(self, input, stride: int = 1):
64
+ """
65
+ Apply gaussian filter to input.
66
+ Arguments:
67
+ input (torch.Tensor): Input to apply gaussian filter on.
68
+ stride for applying conv
69
+ Returns:
70
+ filtered (torch.Tensor): Filtered output.
71
+ """
72
+ padding = (self.weight.shape[-1] - 1) // 2
73
+ return self.conv(input, weight=self.weight, groups=self.groups, padding=padding, stride=stride)
74
+
Time_TravelRephotography/models/resnet.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import partial
2
+
3
+ from torch import nn
4
+
5
+
6
+ def activation_func(activation: str):
7
+ return nn.ModuleDict([
8
+ ['relu', nn.ReLU(inplace=True)],
9
+ ['leaky_relu', nn.LeakyReLU(negative_slope=0.01, inplace=True)],
10
+ ['selu', nn.SELU(inplace=True)],
11
+ ['none', nn.Identity()]
12
+ ])[activation]
13
+
14
+
15
+ def norm_module(norm: str):
16
+ return {
17
+ 'batch': nn.BatchNorm2d,
18
+ 'instance': nn.InstanceNorm2d,
19
+ }[norm]
20
+
21
+
22
+ class Conv2dAuto(nn.Conv2d):
23
+ def __init__(self, *args, **kwargs):
24
+ super().__init__(*args, **kwargs)
25
+ # dynamic add padding based on the kernel_size
26
+ self.padding = (self.kernel_size[0] // 2, self.kernel_size[1] // 2)
27
+
28
+
29
+ conv3x3 = partial(Conv2dAuto, kernel_size=3)
30
+
31
+
32
+ class ResidualBlock(nn.Module):
33
+ def __init__(self, in_channels: int, out_channels: int, activation: str = 'relu'):
34
+ super().__init__()
35
+ self.in_channels, self.out_channels = in_channels, out_channels
36
+ self.blocks = nn.Identity()
37
+ self.activate = activation_func(activation)
38
+ self.shortcut = nn.Identity()
39
+
40
+ def forward(self, x):
41
+ residual = x
42
+ if self.should_apply_shortcut:
43
+ residual = self.shortcut(x)
44
+ x = self.blocks(x)
45
+ x += residual
46
+ x = self.activate(x)
47
+ return x
48
+
49
+ @property
50
+ def should_apply_shortcut(self):
51
+ return self.in_channels != self.out_channels
52
+
53
+
54
+ class ResNetResidualBlock(ResidualBlock):
55
+ def __init__(
56
+ self, in_channels: int, out_channels: int,
57
+ expansion: int = 1, downsampling: int = 1,
58
+ conv=conv3x3, norm: str = 'batch', *args, **kwargs
59
+ ):
60
+ super().__init__(in_channels, out_channels, *args, **kwargs)
61
+ self.expansion, self.downsampling = expansion, downsampling
62
+ self.conv, self.norm = conv, norm_module(norm)
63
+ self.shortcut = nn.Sequential(
64
+ nn.Conv2d(self.in_channels, self.expanded_channels, kernel_size=1,
65
+ stride=self.downsampling, bias=False),
66
+ self.norm(self.expanded_channels)) if self.should_apply_shortcut else None
67
+
68
+ @property
69
+ def expanded_channels(self):
70
+ return self.out_channels * self.expansion
71
+
72
+ @property
73
+ def should_apply_shortcut(self):
74
+ return self.in_channels != self.expanded_channels
75
+
76
+
77
+ def conv_norm(in_channels: int, out_channels: int, conv, norm, *args, **kwargs):
78
+ return nn.Sequential(conv(in_channels, out_channels, *args, **kwargs), norm(out_channels))
79
+
80
+
81
+ class ResNetBasicBlock(ResNetResidualBlock):
82
+ """
83
+ Basic ResNet block composed by two layers of 3x3conv/batchnorm/activation
84
+ """
85
+ expansion = 1
86
+
87
+ def __init__(
88
+ self, in_channels: int, out_channels: int, bias: bool = False, *args, **kwargs
89
+ ):
90
+ super().__init__(in_channels, out_channels, *args, **kwargs)
91
+ self.blocks = nn.Sequential(
92
+ conv_norm(
93
+ self.in_channels, self.out_channels, conv=self.conv, norm=self.norm,
94
+ bias=bias, stride=self.downsampling
95
+ ),
96
+ self.activate,
97
+ conv_norm(self.out_channels, self.expanded_channels, conv=self.conv, norm=self.norm, bias=bias),
98
+ )
99
+
Time_TravelRephotography/models/vggface.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+
6
+ class Vgg_face_dag(nn.Module):
7
+
8
+ def __init__(self):
9
+ super(Vgg_face_dag, self).__init__()
10
+ self.meta = {'mean': [129.186279296875, 104.76238250732422, 93.59396362304688],
11
+ 'std': [1, 1, 1],
12
+ 'imageSize': [224, 224, 3]}
13
+ self.conv1_1 = nn.Conv2d(3, 64, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1))
14
+ self.relu1_1 = nn.ReLU(inplace=True)
15
+ self.conv1_2 = nn.Conv2d(64, 64, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1))
16
+ self.relu1_2 = nn.ReLU(inplace=True)
17
+ self.pool1 = nn.MaxPool2d(kernel_size=[2, 2], stride=[2, 2], padding=0, dilation=1, ceil_mode=False)
18
+ self.conv2_1 = nn.Conv2d(64, 128, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1))
19
+ self.relu2_1 = nn.ReLU(inplace=True)
20
+ self.conv2_2 = nn.Conv2d(128, 128, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1))
21
+ self.relu2_2 = nn.ReLU(inplace=True)
22
+ self.pool2 = nn.MaxPool2d(kernel_size=[2, 2], stride=[2, 2], padding=0, dilation=1, ceil_mode=False)
23
+ self.conv3_1 = nn.Conv2d(128, 256, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1))
24
+ self.relu3_1 = nn.ReLU(inplace=True)
25
+ self.conv3_2 = nn.Conv2d(256, 256, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1))
26
+ self.relu3_2 = nn.ReLU(inplace=True)
27
+ self.conv3_3 = nn.Conv2d(256, 256, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1))
28
+ self.relu3_3 = nn.ReLU(inplace=True)
29
+ self.pool3 = nn.MaxPool2d(kernel_size=[2, 2], stride=[2, 2], padding=0, dilation=1, ceil_mode=False)
30
+ self.conv4_1 = nn.Conv2d(256, 512, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1))
31
+ self.relu4_1 = nn.ReLU(inplace=True)
32
+ self.conv4_2 = nn.Conv2d(512, 512, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1))
33
+ self.relu4_2 = nn.ReLU(inplace=True)
34
+ self.conv4_3 = nn.Conv2d(512, 512, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1))
35
+ self.relu4_3 = nn.ReLU(inplace=True)
36
+ self.pool4 = nn.MaxPool2d(kernel_size=[2, 2], stride=[2, 2], padding=0, dilation=1, ceil_mode=False)
37
+ self.conv5_1 = nn.Conv2d(512, 512, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1))
38
+ self.relu5_1 = nn.ReLU(inplace=True)
39
+ self.conv5_2 = nn.Conv2d(512, 512, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1))
40
+ self.relu5_2 = nn.ReLU(inplace=True)
41
+ self.conv5_3 = nn.Conv2d(512, 512, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1))
42
+ self.relu5_3 = nn.ReLU(inplace=True)
43
+ self.pool5 = nn.MaxPool2d(kernel_size=[2, 2], stride=[2, 2], padding=0, dilation=1, ceil_mode=False)
44
+ self.fc6 = nn.Linear(in_features=25088, out_features=4096, bias=True)
45
+ self.relu6 = nn.ReLU(inplace=True)
46
+ self.dropout6 = nn.Dropout(p=0.5)
47
+ self.fc7 = nn.Linear(in_features=4096, out_features=4096, bias=True)
48
+ self.relu7 = nn.ReLU(inplace=True)
49
+ self.dropout7 = nn.Dropout(p=0.5)
50
+ self.fc8 = nn.Linear(in_features=4096, out_features=2622, bias=True)
51
+
52
+ def forward(self, x0):
53
+ x1 = self.conv1_1(x0)
54
+ x2 = self.relu1_1(x1)
55
+ x3 = self.conv1_2(x2)
56
+ x4 = self.relu1_2(x3)
57
+ x5 = self.pool1(x4)
58
+ x6 = self.conv2_1(x5)
59
+ x7 = self.relu2_1(x6)
60
+ x8 = self.conv2_2(x7)
61
+ x9 = self.relu2_2(x8)
62
+ x10 = self.pool2(x9)
63
+ x11 = self.conv3_1(x10)
64
+ x12 = self.relu3_1(x11)
65
+ x13 = self.conv3_2(x12)
66
+ x14 = self.relu3_2(x13)
67
+ x15 = self.conv3_3(x14)
68
+ x16 = self.relu3_3(x15)
69
+ x17 = self.pool3(x16)
70
+ x18 = self.conv4_1(x17)
71
+ x19 = self.relu4_1(x18)
72
+ x20 = self.conv4_2(x19)
73
+ x21 = self.relu4_2(x20)
74
+ x22 = self.conv4_3(x21)
75
+ x23 = self.relu4_3(x22)
76
+ x24 = self.pool4(x23)
77
+ x25 = self.conv5_1(x24)
78
+ x26 = self.relu5_1(x25)
79
+ x27 = self.conv5_2(x26)
80
+ x28 = self.relu5_2(x27)
81
+ x29 = self.conv5_3(x28)
82
+ x30 = self.relu5_3(x29)
83
+ x31_preflatten = self.pool5(x30)
84
+ x31 = x31_preflatten.view(x31_preflatten.size(0), -1)
85
+ x32 = self.fc6(x31)
86
+ x33 = self.relu6(x32)
87
+ x34 = self.dropout6(x33)
88
+ x35 = self.fc7(x34)
89
+ x36 = self.relu7(x35)
90
+ x37 = self.dropout7(x36)
91
+ x38 = self.fc8(x37)
92
+ return x38
93
+
94
+
95
+ def vgg_face_dag(weights_path=None, **kwargs):
96
+ """
97
+ load imported model instance
98
+
99
+ Args:
100
+ weights_path (str): If set, loads model weights from the given path
101
+ """
102
+ model = Vgg_face_dag()
103
+ if weights_path:
104
+ state_dict = torch.load(weights_path)
105
+ model.load_state_dict(state_dict)
106
+ return model
107
+
108
+
109
+ class VGGFaceFeats(Vgg_face_dag):
110
+ def forward(self, x0):
111
+ x1 = self.conv1_1(x0)
112
+ x2 = self.relu1_1(x1)
113
+ x3 = self.conv1_2(x2)
114
+ x4 = self.relu1_2(x3)
115
+ x5 = self.pool1(x4)
116
+ x6 = self.conv2_1(x5)
117
+ x7 = self.relu2_1(x6)
118
+ x8 = self.conv2_2(x7)
119
+ x9 = self.relu2_2(x8)
120
+ x10 = self.pool2(x9)
121
+ x11 = self.conv3_1(x10)
122
+ x12 = self.relu3_1(x11)
123
+ x13 = self.conv3_2(x12)
124
+ x14 = self.relu3_2(x13)
125
+ x15 = self.conv3_3(x14)
126
+ x16 = self.relu3_3(x15)
127
+ x17 = self.pool3(x16)
128
+ x18 = self.conv4_1(x17)
129
+ x19 = self.relu4_1(x18)
130
+ x20 = self.conv4_2(x19)
131
+ x21 = self.relu4_2(x20)
132
+ x22 = self.conv4_3(x21)
133
+ x23 = self.relu4_3(x22)
134
+ x24 = self.pool4(x23)
135
+ x25 = self.conv5_1(x24)
136
+ # x26 = self.relu5_1(x25)
137
+ # x27 = self.conv5_2(x26)
138
+ # x28 = self.relu5_2(x27)
139
+ # x29 = self.conv5_3(x28)
140
+ # x30 = self.relu5_3(x29)
141
+ # x31_preflatten = self.pool5(x30)
142
+ # x31 = x31_preflatten.view(x31_preflatten.size(0), -1)
143
+ # x32 = self.fc6(x31)
144
+ # x33 = self.relu6(x32)
145
+ # x34 = self.dropout6(x33)
146
+ # x35 = self.fc7(x34)
147
+ # x36 = self.relu7(x35)
148
+ # x37 = self.dropout7(x36)
149
+ # x38 = self.fc8(x37)
150
+ return x1, x6, x11, x18, x25
Time_TravelRephotography/op/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .fused_act import FusedLeakyReLU, fused_leaky_relu
2
+ from .upfirdn2d import upfirdn2d
Time_TravelRephotography/op/fused_act.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import torch
4
+ from torch import nn
5
+ from torch.autograd import Function
6
+ from torch.utils.cpp_extension import load
7
+
8
+
9
+ module_path = os.path.dirname(__file__)
10
+ fused = load(
11
+ 'fused',
12
+ sources=[
13
+ os.path.join(module_path, 'fused_bias_act.cpp'),
14
+ os.path.join(module_path, 'fused_bias_act_kernel.cu'),
15
+ ],
16
+ )
17
+
18
+
19
+ class FusedLeakyReLUFunctionBackward(Function):
20
+ @staticmethod
21
+ def forward(ctx, grad_output, out, negative_slope, scale):
22
+ ctx.save_for_backward(out)
23
+ ctx.negative_slope = negative_slope
24
+ ctx.scale = scale
25
+
26
+ empty = grad_output.new_empty(0)
27
+
28
+ grad_input = fused.fused_bias_act(
29
+ grad_output, empty, out, 3, 1, negative_slope, scale
30
+ )
31
+
32
+ dim = [0]
33
+
34
+ if grad_input.ndim > 2:
35
+ dim += list(range(2, grad_input.ndim))
36
+
37
+ grad_bias = grad_input.sum(dim).detach()
38
+
39
+ return grad_input, grad_bias
40
+
41
+ @staticmethod
42
+ def backward(ctx, gradgrad_input, gradgrad_bias):
43
+ out, = ctx.saved_tensors
44
+ gradgrad_out = fused.fused_bias_act(
45
+ gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale
46
+ )
47
+
48
+ return gradgrad_out, None, None, None
49
+
50
+
51
+ class FusedLeakyReLUFunction(Function):
52
+ @staticmethod
53
+ def forward(ctx, input, bias, negative_slope, scale):
54
+ empty = input.new_empty(0)
55
+ out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale)
56
+ ctx.save_for_backward(out)
57
+ ctx.negative_slope = negative_slope
58
+ ctx.scale = scale
59
+
60
+ return out
61
+
62
+ @staticmethod
63
+ def backward(ctx, grad_output):
64
+ out, = ctx.saved_tensors
65
+
66
+ grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(
67
+ grad_output, out, ctx.negative_slope, ctx.scale
68
+ )
69
+
70
+ return grad_input, grad_bias, None, None
71
+
72
+
73
+ class FusedLeakyReLU(nn.Module):
74
+ def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5):
75
+ super().__init__()
76
+
77
+ self.bias = nn.Parameter(torch.zeros(channel))
78
+ self.negative_slope = negative_slope
79
+ self.scale = scale
80
+
81
+ def forward(self, input):
82
+ return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
83
+
84
+
85
+ def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
86
+ return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)
Time_TravelRephotography/op/fused_bias_act.cpp ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <torch/extension.h>
2
+
3
+
4
+ torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
5
+ int act, int grad, float alpha, float scale);
6
+
7
+ #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
8
+ #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
9
+ #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
10
+
11
+ torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
12
+ int act, int grad, float alpha, float scale) {
13
+ CHECK_CUDA(input);
14
+ CHECK_CUDA(bias);
15
+
16
+ return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale);
17
+ }
18
+
19
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
20
+ m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)");
21
+ }
Time_TravelRephotography/op/fused_bias_act_kernel.cu ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2
+ //
3
+ // This work is made available under the Nvidia Source Code License-NC.
4
+ // To view a copy of this license, visit
5
+ // https://nvlabs.github.io/stylegan2/license.html
6
+
7
+ #include <torch/types.h>
8
+
9
+ #include <ATen/ATen.h>
10
+ #include <ATen/AccumulateType.h>
11
+ #include <ATen/cuda/CUDAContext.h>
12
+ #include <ATen/cuda/CUDAApplyUtils.cuh>
13
+
14
+ #include <cuda.h>
15
+ #include <cuda_runtime.h>
16
+
17
+
18
+ template <typename scalar_t>
19
+ static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref,
20
+ int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) {
21
+ int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;
22
+
23
+ scalar_t zero = 0.0;
24
+
25
+ for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) {
26
+ scalar_t x = p_x[xi];
27
+
28
+ if (use_bias) {
29
+ x += p_b[(xi / step_b) % size_b];
30
+ }
31
+
32
+ scalar_t ref = use_ref ? p_ref[xi] : zero;
33
+
34
+ scalar_t y;
35
+
36
+ switch (act * 10 + grad) {
37
+ default:
38
+ case 10: y = x; break;
39
+ case 11: y = x; break;
40
+ case 12: y = 0.0; break;
41
+
42
+ case 30: y = (x > 0.0) ? x : x * alpha; break;
43
+ case 31: y = (ref > 0.0) ? x : x * alpha; break;
44
+ case 32: y = 0.0; break;
45
+ }
46
+
47
+ out[xi] = y * scale;
48
+ }
49
+ }
50
+
51
+
52
+ torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
53
+ int act, int grad, float alpha, float scale) {
54
+ int curDevice = -1;
55
+ cudaGetDevice(&curDevice);
56
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
57
+
58
+ auto x = input.contiguous();
59
+ auto b = bias.contiguous();
60
+ auto ref = refer.contiguous();
61
+
62
+ int use_bias = b.numel() ? 1 : 0;
63
+ int use_ref = ref.numel() ? 1 : 0;
64
+
65
+ int size_x = x.numel();
66
+ int size_b = b.numel();
67
+ int step_b = 1;
68
+
69
+ for (int i = 1 + 1; i < x.dim(); i++) {
70
+ step_b *= x.size(i);
71
+ }
72
+
73
+ int loop_x = 4;
74
+ int block_size = 4 * 32;
75
+ int grid_size = (size_x - 1) / (loop_x * block_size) + 1;
76
+
77
+ auto y = torch::empty_like(x);
78
+
79
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] {
80
+ fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(
81
+ y.data_ptr<scalar_t>(),
82
+ x.data_ptr<scalar_t>(),
83
+ b.data_ptr<scalar_t>(),
84
+ ref.data_ptr<scalar_t>(),
85
+ act,
86
+ grad,
87
+ alpha,
88
+ scale,
89
+ loop_x,
90
+ size_x,
91
+ step_b,
92
+ size_b,
93
+ use_bias,
94
+ use_ref
95
+ );
96
+ });
97
+
98
+ return y;
99
+ }
Time_TravelRephotography/op/upfirdn2d.cpp ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <torch/extension.h>
2
+
3
+
4
+ torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,
5
+ int up_x, int up_y, int down_x, int down_y,
6
+ int pad_x0, int pad_x1, int pad_y0, int pad_y1);
7
+
8
+ #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
9
+ #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
10
+ #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
11
+
12
+ torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel,
13
+ int up_x, int up_y, int down_x, int down_y,
14
+ int pad_x0, int pad_x1, int pad_y0, int pad_y1) {
15
+ CHECK_CUDA(input);
16
+ CHECK_CUDA(kernel);
17
+
18
+ return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1);
19
+ }
20
+
21
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
22
+ m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
23
+ }
Time_TravelRephotography/op/upfirdn2d.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import torch
4
+ from torch.autograd import Function
5
+ from torch.utils.cpp_extension import load
6
+
7
+
8
+ module_path = os.path.dirname(__file__)
9
+ upfirdn2d_op = load(
10
+ 'upfirdn2d',
11
+ sources=[
12
+ os.path.join(module_path, 'upfirdn2d.cpp'),
13
+ os.path.join(module_path, 'upfirdn2d_kernel.cu'),
14
+ ],
15
+ )
16
+
17
+
18
+ class UpFirDn2dBackward(Function):
19
+ @staticmethod
20
+ def forward(
21
+ ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size
22
+ ):
23
+
24
+ up_x, up_y = up
25
+ down_x, down_y = down
26
+ g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad
27
+
28
+ grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)
29
+
30
+ grad_input = upfirdn2d_op.upfirdn2d(
31
+ grad_output,
32
+ grad_kernel,
33
+ down_x,
34
+ down_y,
35
+ up_x,
36
+ up_y,
37
+ g_pad_x0,
38
+ g_pad_x1,
39
+ g_pad_y0,
40
+ g_pad_y1,
41
+ )
42
+ grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3])
43
+
44
+ ctx.save_for_backward(kernel)
45
+
46
+ pad_x0, pad_x1, pad_y0, pad_y1 = pad
47
+
48
+ ctx.up_x = up_x
49
+ ctx.up_y = up_y
50
+ ctx.down_x = down_x
51
+ ctx.down_y = down_y
52
+ ctx.pad_x0 = pad_x0
53
+ ctx.pad_x1 = pad_x1
54
+ ctx.pad_y0 = pad_y0
55
+ ctx.pad_y1 = pad_y1
56
+ ctx.in_size = in_size
57
+ ctx.out_size = out_size
58
+
59
+ return grad_input
60
+
61
+ @staticmethod
62
+ def backward(ctx, gradgrad_input):
63
+ kernel, = ctx.saved_tensors
64
+
65
+ gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1)
66
+
67
+ gradgrad_out = upfirdn2d_op.upfirdn2d(
68
+ gradgrad_input,
69
+ kernel,
70
+ ctx.up_x,
71
+ ctx.up_y,
72
+ ctx.down_x,
73
+ ctx.down_y,
74
+ ctx.pad_x0,
75
+ ctx.pad_x1,
76
+ ctx.pad_y0,
77
+ ctx.pad_y1,
78
+ )
79
+ # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3])
80
+ gradgrad_out = gradgrad_out.view(
81
+ ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1]
82
+ )
83
+
84
+ return gradgrad_out, None, None, None, None, None, None, None, None
85
+
86
+
87
+ class UpFirDn2d(Function):
88
+ @staticmethod
89
+ def forward(ctx, input, kernel, up, down, pad):
90
+ up_x, up_y = up
91
+ down_x, down_y = down
92
+ pad_x0, pad_x1, pad_y0, pad_y1 = pad
93
+
94
+ kernel_h, kernel_w = kernel.shape
95
+ batch, channel, in_h, in_w = input.shape
96
+ ctx.in_size = input.shape
97
+
98
+ input = input.reshape(-1, in_h, in_w, 1)
99
+
100
+ ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))
101
+
102
+ out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
103
+ out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
104
+ ctx.out_size = (out_h, out_w)
105
+
106
+ ctx.up = (up_x, up_y)
107
+ ctx.down = (down_x, down_y)
108
+ ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)
109
+
110
+ g_pad_x0 = kernel_w - pad_x0 - 1
111
+ g_pad_y0 = kernel_h - pad_y0 - 1
112
+ g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
113
+ g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1
114
+
115
+ ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)
116
+
117
+ out = upfirdn2d_op.upfirdn2d(
118
+ input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
119
+ )
120
+ # out = out.view(major, out_h, out_w, minor)
121
+ out = out.view(-1, channel, out_h, out_w)
122
+
123
+ return out
124
+
125
+ @staticmethod
126
+ def backward(ctx, grad_output):
127
+ kernel, grad_kernel = ctx.saved_tensors
128
+
129
+ grad_input = UpFirDn2dBackward.apply(
130
+ grad_output,
131
+ kernel,
132
+ grad_kernel,
133
+ ctx.up,
134
+ ctx.down,
135
+ ctx.pad,
136
+ ctx.g_pad,
137
+ ctx.in_size,
138
+ ctx.out_size,
139
+ )
140
+
141
+ return grad_input, None, None, None, None
142
+
143
+
144
+ def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
145
+ out = UpFirDn2d.apply(
146
+ input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1])
147
+ )
148
+
149
+ return out
150
+
151
+
152
+ def upfirdn2d_native(
153
+ input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
154
+ ):
155
+ _, in_h, in_w, minor = input.shape
156
+ kernel_h, kernel_w = kernel.shape
157
+
158
+ out = input.view(-1, in_h, 1, in_w, 1, minor)
159
+ out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
160
+ out = out.view(-1, in_h * up_y, in_w * up_x, minor)
161
+
162
+ out = F.pad(
163
+ out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
164
+ )
165
+ out = out[
166
+ :,
167
+ max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
168
+ max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
169
+ :,
170
+ ]
171
+
172
+ out = out.permute(0, 3, 1, 2)
173
+ out = out.reshape(
174
+ [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
175
+ )
176
+ w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
177
+ out = F.conv2d(out, w)
178
+ out = out.reshape(
179
+ -1,
180
+ minor,
181
+ in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
182
+ in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
183
+ )
184
+ out = out.permute(0, 2, 3, 1)
185
+
186
+ return out[:, ::down_y, ::down_x, :]
187
+
Time_TravelRephotography/op/upfirdn2d_kernel.cu ADDED
@@ -0,0 +1,272 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2
+ //
3
+ // This work is made available under the Nvidia Source Code License-NC.
4
+ // To view a copy of this license, visit
5
+ // https://nvlabs.github.io/stylegan2/license.html
6
+
7
+ #include <torch/types.h>
8
+
9
+ #include <ATen/ATen.h>
10
+ #include <ATen/AccumulateType.h>
11
+ #include <ATen/cuda/CUDAContext.h>
12
+ #include <ATen/cuda/CUDAApplyUtils.cuh>
13
+
14
+ #include <cuda.h>
15
+ #include <cuda_runtime.h>
16
+
17
+
18
+ static __host__ __device__ __forceinline__ int floor_div(int a, int b) {
19
+ int c = a / b;
20
+
21
+ if (c * b > a) {
22
+ c--;
23
+ }
24
+
25
+ return c;
26
+ }
27
+
28
+
29
+ struct UpFirDn2DKernelParams {
30
+ int up_x;
31
+ int up_y;
32
+ int down_x;
33
+ int down_y;
34
+ int pad_x0;
35
+ int pad_x1;
36
+ int pad_y0;
37
+ int pad_y1;
38
+
39
+ int major_dim;
40
+ int in_h;
41
+ int in_w;
42
+ int minor_dim;
43
+ int kernel_h;
44
+ int kernel_w;
45
+ int out_h;
46
+ int out_w;
47
+ int loop_major;
48
+ int loop_x;
49
+ };
50
+
51
+
52
+ template <typename scalar_t, int up_x, int up_y, int down_x, int down_y, int kernel_h, int kernel_w, int tile_out_h, int tile_out_w>
53
+ __global__ void upfirdn2d_kernel(scalar_t* out, const scalar_t* input, const scalar_t* kernel, const UpFirDn2DKernelParams p) {
54
+ const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1;
55
+ const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1;
56
+
57
+ __shared__ volatile float sk[kernel_h][kernel_w];
58
+ __shared__ volatile float sx[tile_in_h][tile_in_w];
59
+
60
+ int minor_idx = blockIdx.x;
61
+ int tile_out_y = minor_idx / p.minor_dim;
62
+ minor_idx -= tile_out_y * p.minor_dim;
63
+ tile_out_y *= tile_out_h;
64
+ int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w;
65
+ int major_idx_base = blockIdx.z * p.loop_major;
66
+
67
+ if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h | major_idx_base >= p.major_dim) {
68
+ return;
69
+ }
70
+
71
+ for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w; tap_idx += blockDim.x) {
72
+ int ky = tap_idx / kernel_w;
73
+ int kx = tap_idx - ky * kernel_w;
74
+ scalar_t v = 0.0;
75
+
76
+ if (kx < p.kernel_w & ky < p.kernel_h) {
77
+ v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)];
78
+ }
79
+
80
+ sk[ky][kx] = v;
81
+ }
82
+
83
+ for (int loop_major = 0, major_idx = major_idx_base; loop_major < p.loop_major & major_idx < p.major_dim; loop_major++, major_idx++) {
84
+ for (int loop_x = 0, tile_out_x = tile_out_x_base; loop_x < p.loop_x & tile_out_x < p.out_w; loop_x++, tile_out_x += tile_out_w) {
85
+ int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0;
86
+ int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0;
87
+ int tile_in_x = floor_div(tile_mid_x, up_x);
88
+ int tile_in_y = floor_div(tile_mid_y, up_y);
89
+
90
+ __syncthreads();
91
+
92
+ for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w; in_idx += blockDim.x) {
93
+ int rel_in_y = in_idx / tile_in_w;
94
+ int rel_in_x = in_idx - rel_in_y * tile_in_w;
95
+ int in_x = rel_in_x + tile_in_x;
96
+ int in_y = rel_in_y + tile_in_y;
97
+
98
+ scalar_t v = 0.0;
99
+
100
+ if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) {
101
+ v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim + minor_idx];
102
+ }
103
+
104
+ sx[rel_in_y][rel_in_x] = v;
105
+ }
106
+
107
+ __syncthreads();
108
+ for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w; out_idx += blockDim.x) {
109
+ int rel_out_y = out_idx / tile_out_w;
110
+ int rel_out_x = out_idx - rel_out_y * tile_out_w;
111
+ int out_x = rel_out_x + tile_out_x;
112
+ int out_y = rel_out_y + tile_out_y;
113
+
114
+ int mid_x = tile_mid_x + rel_out_x * down_x;
115
+ int mid_y = tile_mid_y + rel_out_y * down_y;
116
+ int in_x = floor_div(mid_x, up_x);
117
+ int in_y = floor_div(mid_y, up_y);
118
+ int rel_in_x = in_x - tile_in_x;
119
+ int rel_in_y = in_y - tile_in_y;
120
+ int kernel_x = (in_x + 1) * up_x - mid_x - 1;
121
+ int kernel_y = (in_y + 1) * up_y - mid_y - 1;
122
+
123
+ scalar_t v = 0.0;
124
+
125
+ #pragma unroll
126
+ for (int y = 0; y < kernel_h / up_y; y++)
127
+ #pragma unroll
128
+ for (int x = 0; x < kernel_w / up_x; x++)
129
+ v += sx[rel_in_y + y][rel_in_x + x] * sk[kernel_y + y * up_y][kernel_x + x * up_x];
130
+
131
+ if (out_x < p.out_w & out_y < p.out_h) {
132
+ out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + minor_idx] = v;
133
+ }
134
+ }
135
+ }
136
+ }
137
+ }
138
+
139
+
140
+ torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,
141
+ int up_x, int up_y, int down_x, int down_y,
142
+ int pad_x0, int pad_x1, int pad_y0, int pad_y1) {
143
+ int curDevice = -1;
144
+ cudaGetDevice(&curDevice);
145
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
146
+
147
+ UpFirDn2DKernelParams p;
148
+
149
+ auto x = input.contiguous();
150
+ auto k = kernel.contiguous();
151
+
152
+ p.major_dim = x.size(0);
153
+ p.in_h = x.size(1);
154
+ p.in_w = x.size(2);
155
+ p.minor_dim = x.size(3);
156
+ p.kernel_h = k.size(0);
157
+ p.kernel_w = k.size(1);
158
+ p.up_x = up_x;
159
+ p.up_y = up_y;
160
+ p.down_x = down_x;
161
+ p.down_y = down_y;
162
+ p.pad_x0 = pad_x0;
163
+ p.pad_x1 = pad_x1;
164
+ p.pad_y0 = pad_y0;
165
+ p.pad_y1 = pad_y1;
166
+
167
+ p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) / p.down_y;
168
+ p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) / p.down_x;
169
+
170
+ auto out = at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options());
171
+
172
+ int mode = -1;
173
+
174
+ int tile_out_h;
175
+ int tile_out_w;
176
+
177
+ if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 4 && p.kernel_w <= 4) {
178
+ mode = 1;
179
+ tile_out_h = 16;
180
+ tile_out_w = 64;
181
+ }
182
+
183
+ if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 3 && p.kernel_w <= 3) {
184
+ mode = 2;
185
+ tile_out_h = 16;
186
+ tile_out_w = 64;
187
+ }
188
+
189
+ if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 4 && p.kernel_w <= 4) {
190
+ mode = 3;
191
+ tile_out_h = 16;
192
+ tile_out_w = 64;
193
+ }
194
+
195
+ if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 2 && p.kernel_w <= 2) {
196
+ mode = 4;
197
+ tile_out_h = 16;
198
+ tile_out_w = 64;
199
+ }
200
+
201
+ if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && p.kernel_h <= 4 && p.kernel_w <= 4) {
202
+ mode = 5;
203
+ tile_out_h = 8;
204
+ tile_out_w = 32;
205
+ }
206
+
207
+ if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && p.kernel_h <= 2 && p.kernel_w <= 2) {
208
+ mode = 6;
209
+ tile_out_h = 8;
210
+ tile_out_w = 32;
211
+ }
212
+
213
+ dim3 block_size;
214
+ dim3 grid_size;
215
+
216
+ if (tile_out_h > 0 && tile_out_w) {
217
+ p.loop_major = (p.major_dim - 1) / 16384 + 1;
218
+ p.loop_x = 1;
219
+ block_size = dim3(32 * 8, 1, 1);
220
+ grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim,
221
+ (p.out_w - 1) / (p.loop_x * tile_out_w) + 1,
222
+ (p.major_dim - 1) / p.loop_major + 1);
223
+ }
224
+
225
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] {
226
+ switch (mode) {
227
+ case 1:
228
+ upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 4, 4, 16, 64><<<grid_size, block_size, 0, stream>>>(
229
+ out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
230
+ );
231
+
232
+ break;
233
+
234
+ case 2:
235
+ upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 3, 3, 16, 64><<<grid_size, block_size, 0, stream>>>(
236
+ out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
237
+ );
238
+
239
+ break;
240
+
241
+ case 3:
242
+ upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 4, 4, 16, 64><<<grid_size, block_size, 0, stream>>>(
243
+ out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
244
+ );
245
+
246
+ break;
247
+
248
+ case 4:
249
+ upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 2, 2, 16, 64><<<grid_size, block_size, 0, stream>>>(
250
+ out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
251
+ );
252
+
253
+ break;
254
+
255
+ case 5:
256
+ upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32><<<grid_size, block_size, 0, stream>>>(
257
+ out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
258
+ );
259
+
260
+ break;
261
+
262
+ case 6:
263
+ upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32><<<grid_size, block_size, 0, stream>>>(
264
+ out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
265
+ );
266
+
267
+ break;
268
+ }
269
+ });
270
+
271
+ return out;
272
+ }
Time_TravelRephotography/optim/__init__.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch.optim import Adam
2
+ from torch.optim.lbfgs import LBFGS
3
+ from .radam import RAdam
4
+
5
+
6
+ OPTIMIZER_MAP = {
7
+ "adam": Adam,
8
+ "radam": RAdam,
9
+ "lbfgs": LBFGS,
10
+ }
11
+
12
+
13
+ def get_optimizer_class(optimizer_name):
14
+ name = optimizer_name.lower()
15
+ return OPTIMIZER_MAP[name]
Time_TravelRephotography/optim/radam.py ADDED
@@ -0,0 +1,250 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ from torch.optim.optimizer import Optimizer, required
4
+
5
+
6
+ class RAdam(Optimizer):
7
+
8
+ def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, degenerated_to_sgd=True):
9
+ if not 0.0 <= lr:
10
+ raise ValueError("Invalid learning rate: {}".format(lr))
11
+ if not 0.0 <= eps:
12
+ raise ValueError("Invalid epsilon value: {}".format(eps))
13
+ if not 0.0 <= betas[0] < 1.0:
14
+ raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
15
+ if not 0.0 <= betas[1] < 1.0:
16
+ raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
17
+
18
+ self.degenerated_to_sgd = degenerated_to_sgd
19
+ if isinstance(params, (list, tuple)) and len(params) > 0 and isinstance(params[0], dict):
20
+ for param in params:
21
+ if 'betas' in param and (param['betas'][0] != betas[0] or param['betas'][1] != betas[1]):
22
+ param['buffer'] = [[None, None, None] for _ in range(10)]
23
+ defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay,
24
+ buffer=[[None, None, None] for _ in range(10)])
25
+ super(RAdam, self).__init__(params, defaults)
26
+
27
+ def __setstate__(self, state):
28
+ super(RAdam, self).__setstate__(state)
29
+
30
+ def step(self, closure=None):
31
+
32
+ loss = None
33
+ if closure is not None:
34
+ loss = closure()
35
+
36
+ for group in self.param_groups:
37
+
38
+ for p in group['params']:
39
+ if p.grad is None:
40
+ continue
41
+ grad = p.grad.data.float()
42
+ if grad.is_sparse:
43
+ raise RuntimeError('RAdam does not support sparse gradients')
44
+
45
+ p_data_fp32 = p.data.float()
46
+
47
+ state = self.state[p]
48
+
49
+ if len(state) == 0:
50
+ state['step'] = 0
51
+ state['exp_avg'] = torch.zeros_like(p_data_fp32)
52
+ state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
53
+ else:
54
+ state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
55
+ state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
56
+
57
+ exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
58
+ beta1, beta2 = group['betas']
59
+
60
+ exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
61
+ exp_avg.mul_(beta1).add_(1 - beta1, grad)
62
+
63
+ state['step'] += 1
64
+ buffered = group['buffer'][int(state['step'] % 10)]
65
+ if state['step'] == buffered[0]:
66
+ N_sma, step_size = buffered[1], buffered[2]
67
+ else:
68
+ buffered[0] = state['step']
69
+ beta2_t = beta2 ** state['step']
70
+ N_sma_max = 2 / (1 - beta2) - 1
71
+ N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
72
+ buffered[1] = N_sma
73
+
74
+ # more conservative since it's an approximated value
75
+ if N_sma >= 5:
76
+ step_size = math.sqrt(
77
+ (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (
78
+ N_sma_max - 2)) / (1 - beta1 ** state['step'])
79
+ elif self.degenerated_to_sgd:
80
+ step_size = 1.0 / (1 - beta1 ** state['step'])
81
+ else:
82
+ step_size = -1
83
+ buffered[2] = step_size
84
+
85
+ # more conservative since it's an approximated value
86
+ if N_sma >= 5:
87
+ if group['weight_decay'] != 0:
88
+ p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
89
+ denom = exp_avg_sq.sqrt().add_(group['eps'])
90
+ p_data_fp32.addcdiv_(-step_size * group['lr'], exp_avg, denom)
91
+ p.data.copy_(p_data_fp32)
92
+ elif step_size > 0:
93
+ if group['weight_decay'] != 0:
94
+ p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
95
+ p_data_fp32.add_(-step_size * group['lr'], exp_avg)
96
+ p.data.copy_(p_data_fp32)
97
+
98
+ return loss
99
+
100
+
101
+ class PlainRAdam(Optimizer):
102
+
103
+ def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, degenerated_to_sgd=True):
104
+ if not 0.0 <= lr:
105
+ raise ValueError("Invalid learning rate: {}".format(lr))
106
+ if not 0.0 <= eps:
107
+ raise ValueError("Invalid epsilon value: {}".format(eps))
108
+ if not 0.0 <= betas[0] < 1.0:
109
+ raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
110
+ if not 0.0 <= betas[1] < 1.0:
111
+ raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
112
+
113
+ self.degenerated_to_sgd = degenerated_to_sgd
114
+ defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
115
+
116
+ super(PlainRAdam, self).__init__(params, defaults)
117
+
118
+ def __setstate__(self, state):
119
+ super(PlainRAdam, self).__setstate__(state)
120
+
121
+ def step(self, closure=None):
122
+
123
+ loss = None
124
+ if closure is not None:
125
+ loss = closure()
126
+
127
+ for group in self.param_groups:
128
+
129
+ for p in group['params']:
130
+ if p.grad is None:
131
+ continue
132
+ grad = p.grad.data.float()
133
+ if grad.is_sparse:
134
+ raise RuntimeError('RAdam does not support sparse gradients')
135
+
136
+ p_data_fp32 = p.data.float()
137
+
138
+ state = self.state[p]
139
+
140
+ if len(state) == 0:
141
+ state['step'] = 0
142
+ state['exp_avg'] = torch.zeros_like(p_data_fp32)
143
+ state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
144
+ else:
145
+ state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
146
+ state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
147
+
148
+ exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
149
+ beta1, beta2 = group['betas']
150
+
151
+ exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
152
+ exp_avg.mul_(beta1).add_(1 - beta1, grad)
153
+
154
+ state['step'] += 1
155
+ beta2_t = beta2 ** state['step']
156
+ N_sma_max = 2 / (1 - beta2) - 1
157
+ N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
158
+
159
+ # more conservative since it's an approximated value
160
+ if N_sma >= 5:
161
+ if group['weight_decay'] != 0:
162
+ p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
163
+ step_size = group['lr'] * math.sqrt(
164
+ (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (
165
+ N_sma_max - 2)) / (1 - beta1 ** state['step'])
166
+ denom = exp_avg_sq.sqrt().add_(group['eps'])
167
+ p_data_fp32.addcdiv_(-step_size, exp_avg, denom)
168
+ p.data.copy_(p_data_fp32)
169
+ elif self.degenerated_to_sgd:
170
+ if group['weight_decay'] != 0:
171
+ p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
172
+ step_size = group['lr'] / (1 - beta1 ** state['step'])
173
+ p_data_fp32.add_(-step_size, exp_avg)
174
+ p.data.copy_(p_data_fp32)
175
+
176
+ return loss
177
+
178
+
179
+ class AdamW(Optimizer):
180
+
181
+ def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, warmup=0):
182
+ if not 0.0 <= lr:
183
+ raise ValueError("Invalid learning rate: {}".format(lr))
184
+ if not 0.0 <= eps:
185
+ raise ValueError("Invalid epsilon value: {}".format(eps))
186
+ if not 0.0 <= betas[0] < 1.0:
187
+ raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
188
+ if not 0.0 <= betas[1] < 1.0:
189
+ raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
190
+
191
+ defaults = dict(lr=lr, betas=betas, eps=eps,
192
+ weight_decay=weight_decay, warmup=warmup)
193
+ super(AdamW, self).__init__(params, defaults)
194
+
195
+ def __setstate__(self, state):
196
+ super(AdamW, self).__setstate__(state)
197
+
198
+ def step(self, closure=None):
199
+ loss = None
200
+ if closure is not None:
201
+ loss = closure()
202
+
203
+ for group in self.param_groups:
204
+
205
+ for p in group['params']:
206
+ if p.grad is None:
207
+ continue
208
+ grad = p.grad.data.float()
209
+ if grad.is_sparse:
210
+ raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
211
+
212
+ p_data_fp32 = p.data.float()
213
+
214
+ state = self.state[p]
215
+
216
+ if len(state) == 0:
217
+ state['step'] = 0
218
+ state['exp_avg'] = torch.zeros_like(p_data_fp32)
219
+ state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
220
+ else:
221
+ state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
222
+ state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
223
+
224
+ exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
225
+ beta1, beta2 = group['betas']
226
+
227
+ state['step'] += 1
228
+
229
+ exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
230
+ exp_avg.mul_(beta1).add_(1 - beta1, grad)
231
+
232
+ denom = exp_avg_sq.sqrt().add_(group['eps'])
233
+ bias_correction1 = 1 - beta1 ** state['step']
234
+ bias_correction2 = 1 - beta2 ** state['step']
235
+
236
+ if group['warmup'] > state['step']:
237
+ scheduled_lr = 1e-8 + state['step'] * group['lr'] / group['warmup']
238
+ else:
239
+ scheduled_lr = group['lr']
240
+
241
+ step_size = scheduled_lr * math.sqrt(bias_correction2) / bias_correction1
242
+
243
+ if group['weight_decay'] != 0:
244
+ p_data_fp32.add_(-group['weight_decay'] * scheduled_lr, p_data_fp32)
245
+
246
+ p_data_fp32.addcdiv_(-step_size, exp_avg, denom)
247
+
248
+ p.data.copy_(p_data_fp32)
249
+
250
+ return loss
Time_TravelRephotography/projector.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from argparse import Namespace
2
+ import os
3
+ from os.path import join as pjoin
4
+ import random
5
+ import sys
6
+ from typing import (
7
+ Iterable,
8
+ Optional,
9
+ )
10
+
11
+ import cv2
12
+ import numpy as np
13
+ from PIL import Image
14
+ import torch
15
+ from torch.utils.tensorboard import SummaryWriter
16
+ from torchvision.transforms import (
17
+ Compose,
18
+ Grayscale,
19
+ Resize,
20
+ ToTensor,
21
+ Normalize,
22
+ )
23
+
24
+ from losses.joint_loss import JointLoss
25
+ from model import Generator
26
+ from tools.initialize import Initializer
27
+ from tools.match_skin_histogram import match_skin_histogram
28
+ from utils.projector_arguments import ProjectorArguments
29
+ from utils import torch_helpers as th
30
+ from utils.torch_helpers import make_image
31
+ from utils.misc import stem
32
+ from utils.optimize import Optimizer
33
+ from models.degrade import (
34
+ Degrade,
35
+ Downsample,
36
+ )
37
+
38
+
39
+ def set_random_seed(seed: int):
40
+ # FIXME (xuanluo): this setup still allows randomness somehow
41
+ torch.manual_seed(seed)
42
+ random.seed(seed)
43
+ np.random.seed(seed)
44
+
45
+
46
+ def read_images(paths: str, max_size: Optional[int] = None):
47
+ transform = Compose(
48
+ [
49
+ Grayscale(),
50
+ ToTensor(),
51
+ ]
52
+ )
53
+
54
+ imgs = []
55
+ for path in paths:
56
+ img = Image.open(path)
57
+ if max_size is not None and img.width > max_size:
58
+ img = img.resize((max_size, max_size))
59
+ img = transform(img)
60
+ imgs.append(img)
61
+ imgs = torch.stack(imgs, 0)
62
+ return imgs
63
+
64
+
65
+ def normalize(img: torch.Tensor, mean=0.5, std=0.5):
66
+ """[0, 1] -> [-1, 1]"""
67
+ return (img - mean) / std
68
+
69
+
70
+ def create_generator(args: Namespace, device: torch.device):
71
+ generator = Generator(args.generator_size, 512, 8)
72
+ generator.load_state_dict(torch.load(args.ckpt)['g_ema'], strict=False)
73
+ generator.eval()
74
+ generator = generator.to(device)
75
+ return generator
76
+
77
+
78
+ def save(
79
+ path_prefixes: Iterable[str],
80
+ imgs: torch.Tensor, # BCHW
81
+ latents: torch.Tensor,
82
+ noises: torch.Tensor,
83
+ imgs_rand: Optional[torch.Tensor] = None,
84
+ ):
85
+ assert len(path_prefixes) == len(imgs) and len(latents) == len(path_prefixes)
86
+ if imgs_rand is not None:
87
+ assert len(imgs) == len(imgs_rand)
88
+ imgs_arr = make_image(imgs)
89
+ for path_prefix, img, latent, noise in zip(path_prefixes, imgs_arr, latents, noises):
90
+ os.makedirs(os.path.dirname(path_prefix), exist_ok=True)
91
+ cv2.imwrite(path_prefix + ".png", img[...,::-1])
92
+ torch.save({"latent": latent.detach().cpu(), "noise": noise.detach().cpu()},
93
+ path_prefix + ".pt")
94
+
95
+ if imgs_rand is not None:
96
+ imgs_arr = make_image(imgs_rand)
97
+ for path_prefix, img in zip(path_prefixes, imgs_arr):
98
+ cv2.imwrite(path_prefix + "-rand.png", img[...,::-1])
99
+
100
+
101
+ def main(args):
102
+ opt_str = ProjectorArguments.to_string(args)
103
+ print(opt_str)
104
+
105
+ if args.rand_seed is not None:
106
+ set_random_seed(args.rand_seed)
107
+ device = th.device()
108
+
109
+ # read inputs. TODO imgs_orig has channel 1
110
+ imgs_orig = read_images([args.input], max_size=args.generator_size).to(device)
111
+ imgs = normalize(imgs_orig) # actually this will be overwritten by the histogram matching result
112
+
113
+ # initialize
114
+ with torch.no_grad():
115
+ init = Initializer(args).to(device)
116
+ latent_init = init(imgs_orig)
117
+
118
+ # create generator
119
+ generator = create_generator(args, device)
120
+
121
+ # init noises
122
+ with torch.no_grad():
123
+ noises_init = generator.make_noise()
124
+
125
+ # create a new input by matching the input's histogram to the sibling image
126
+ with torch.no_grad():
127
+ sibling, _, sibling_rgbs = generator([latent_init], input_is_latent=True, noise=noises_init)
128
+ mh_dir = pjoin(args.results_dir, stem(args.input))
129
+ imgs = match_skin_histogram(
130
+ imgs, sibling,
131
+ args.spectral_sensitivity,
132
+ pjoin(mh_dir, "input_sibling"),
133
+ pjoin(mh_dir, "skin_mask"),
134
+ matched_hist_fn=mh_dir.rstrip(os.sep) + f"_{args.spectral_sensitivity}.png",
135
+ normalize=normalize,
136
+ ).to(device)
137
+ torch.cuda.empty_cache()
138
+ # TODO imgs has channel 3
139
+
140
+ degrade = Degrade(args).to(device)
141
+
142
+ rgb_levels = generator.get_latent_size(args.coarse_min) // 2 + len(args.wplus_step) - 1
143
+ criterion = JointLoss(
144
+ args, imgs,
145
+ sibling=sibling.detach(), sibling_rgbs=sibling_rgbs[:rgb_levels]).to(device)
146
+
147
+ # save initialization
148
+ save(
149
+ [pjoin(args.results_dir, f"{stem(args.input)}-{opt_str}-init")],
150
+ sibling, latent_init, noises_init,
151
+ )
152
+
153
+ writer = SummaryWriter(pjoin(args.log_dir, f"{stem(args.input)}/{opt_str}"))
154
+ # start optimize
155
+ latent, noises = Optimizer.optimize(generator, criterion, degrade, imgs, latent_init, noises_init, args, writer=writer)
156
+
157
+ # generate output
158
+ img_out, _, _ = generator([latent], input_is_latent=True, noise=noises)
159
+ img_out_rand_noise, _, _ = generator([latent], input_is_latent=True)
160
+ # save output
161
+ save(
162
+ [pjoin(args.results_dir, f"{stem(args.input)}-{opt_str}")],
163
+ img_out, latent, noises,
164
+ imgs_rand=img_out_rand_noise
165
+ )
166
+
167
+
168
+ def parse_args():
169
+ return ProjectorArguments().parse()
170
+
171
+ if __name__ == "__main__":
172
+ sys.exit(main(parse_args()))
Time_TravelRephotography/requirements.txt ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Torch
2
+ #--find-links https://download.pytorch.org/whl/torch_stable.html
3
+ #torch==1.4.0+cu100
4
+ #torchvision==0.11.2+cu100
5
+ #torchaudio==0.10.1+cu100
6
+ #setuptools==59.5.0
7
+
8
+ Pillow
9
+ ninja
10
+ tqdm
11
+ opencv-python
12
+ scikit-image
13
+ numpy
14
+
15
+ tensorboard
16
+
17
+ # for face alignment
18
+ tensorflow
19
+ #keras
20
+ #bz2
21
+ dlib
22
+ scipy
23
+
24
+ matplotlib
25
+ pprintpp
26
+ huggingface_hub
Time_TravelRephotography/scripts/download_checkpoints.sh ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ set -exo
2
+
3
+ mkdir -p checkpoint
4
+ gdown https://drive.google.com/uc?id=1hWc2JLM58_PkwfLG23Q5IH3Ysj2Mo1nr -O checkpoint/e4e_ffhq_encode.pt
5
+ gdown https://drive.google.com/uc?id=1hvAAql9Jo0wlmLBSHRIGrtXHcKQE-Whn -O checkpoint/stylegan2-ffhq-config-f.pt
6
+ gdown https://drive.google.com/uc?id=1mbGWbjivZxMGxZqyyOHbE310aOkYe2BR -O checkpoint/vgg_face_dag.pt
7
+ mkdir -p checkpoint/encoder
8
+ gdown https://drive.google.com/uc?id=1ha4WXsaIpZfMHsqNLvqOPlUXsgh9VawU -O checkpoint/encoder/checkpoint_b.pt
9
+ gdown https://drive.google.com/uc?id=1hfxDLujRIGU0G7pOdW9MMSBRzxZBmSKJ -O checkpoint/encoder/checkpoint_g.pt
10
+ gdown https://drive.google.com/uc?id=1htekHopgxaW-MIjs6pYy7pyIK0v7Q0iS -O checkpoint/encoder/checkpoint_gb.pt
11
+
12
+ pushd third_party/face_parsing
13
+ ./scripts/download_checkpoints.sh
14
+ popd
Time_TravelRephotography/scripts/install.sh ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ # conda create -n stylegan python=3.7
2
+ # conda activate stylegan
3
+ conda install -c conda-forge/label/gcc7 opencv --yes
4
+ conda install tensorflow-gpu=1.15 cudatoolkit=10.0 --yes
5
+ conda install pytorch torchvision cudatoolkit=10.0 -c pytorch --yes
6
+ pip install -r requirements.txt
Time_TravelRephotography/scripts/run.sh ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ set -x
2
+
3
+ # Example command
4
+ # ```
5
+ # ./scripts/run.sh b "dataset/Abraham Lincoln_01.png" 0.75
6
+ # ```
7
+
8
+ spectral_sensitivity="$1"
9
+ path="$2"
10
+ blur_radius="$3"
11
+
12
+
13
+ list="$(dirname "${path}")"
14
+ list="$(basename "${list}")"
15
+
16
+ if [ "${spectral_sensitivity}" == "b" ]; then
17
+ FLAGS=(--spectral_sensitivity b --encoder_ckpt checkpoint/encoder/checkpoint_b.pt);
18
+ elif [ "${spectral_sensitivity}" == "gb" ]; then
19
+ FLAGS=(--spectral_sensitivity "gb" --encoder_ckpt checkpoint/encoder/checkpoint_gb.pt);
20
+ else
21
+ FLAGS=(--spectral_sensitivity "g" --encoder_ckpt checkpoint/encoder/checkpoint_g.pt);
22
+ fi
23
+
24
+ name="${path%.*}"
25
+ name="${name##*/}"
26
+ echo "${name}"
27
+
28
+ # TODO: I did l2 or cos for contextual
29
+ time python projector.py \
30
+ "${path}" \
31
+ --gaussian "${blur_radius}" \
32
+ --log_dir "log/" \
33
+ --results_dir "results/" \
34
+ "${FLAGS[@]}"
Time_TravelRephotography/tools/__init__.py ADDED
File without changes
Time_TravelRephotography/tools/data/__init__.py ADDED
File without changes
Time_TravelRephotography/tools/data/align_images.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+ import os
4
+ from os.path import join as pjoin
5
+ import sys
6
+ import bz2
7
+ import numpy as np
8
+ import cv2
9
+ from tqdm import tqdm
10
+ from tensorflow.keras.utils import get_file
11
+ from utils.ffhq_dataset.face_alignment import image_align
12
+ from utils.ffhq_dataset.landmarks_detector import LandmarksDetector
13
+
14
+ LANDMARKS_MODEL_URL = 'http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2'
15
+
16
+
17
+ def unpack_bz2(src_path):
18
+ data = bz2.BZ2File(src_path).read()
19
+ dst_path = src_path[:-4]
20
+ with open(dst_path, 'wb') as fp:
21
+ fp.write(data)
22
+ return dst_path
23
+
24
+
25
+ class SizePathMap(dict):
26
+ """{size: {aligned_face_path0, aligned_face_path1, ...}, ...}"""
27
+ def add_item(self, size, path):
28
+ if size not in self:
29
+ self[size] = set()
30
+ self[size].add(path)
31
+
32
+ def get_sizes(self):
33
+ sizes = []
34
+ for key, paths in self.items():
35
+ sizes.extend([key,]*len(paths))
36
+ return sizes
37
+
38
+ def serialize(self):
39
+ result = {}
40
+ for key, paths in self.items():
41
+ result[key] = list(paths)
42
+ return result
43
+
44
+
45
+ def main(args):
46
+ landmarks_model_path = unpack_bz2(get_file('shape_predictor_68_face_landmarks.dat.bz2',
47
+ LANDMARKS_MODEL_URL, cache_subdir='temp'))
48
+
49
+ landmarks_detector = LandmarksDetector(landmarks_model_path)
50
+ face_sizes = SizePathMap()
51
+ raw_img_dir = args.raw_image_dir
52
+ img_names = [n for n in os.listdir(raw_img_dir) if os.path.isfile(pjoin(raw_img_dir, n))]
53
+ aligned_image_dir = args.aligned_image_dir
54
+ os.makedirs(aligned_image_dir, exist_ok=True)
55
+ pbar = tqdm(img_names)
56
+ for img_name in pbar:
57
+ pbar.set_description(img_name)
58
+ if os.path.splitext(img_name)[-1] == '.txt':
59
+ continue
60
+ raw_img_path = os.path.join(raw_img_dir, img_name)
61
+ try:
62
+ for i, face_landmarks in enumerate(landmarks_detector.get_landmarks(raw_img_path), start=1):
63
+ face_img_name = '%s_%02d.png' % (os.path.splitext(img_name)[0], i)
64
+ aligned_face_path = os.path.join(aligned_image_dir, face_img_name)
65
+
66
+ face_size = image_align(
67
+ raw_img_path, aligned_face_path, face_landmarks, resize=args.resize
68
+ )
69
+ face_sizes.add_item(face_size, aligned_face_path)
70
+ pbar.set_description(f"{img_name}: {face_size}")
71
+
72
+ if args.draw:
73
+ visual = LandmarksDetector.draw(cv2.imread(raw_img_path), face_landmarks)
74
+ cv2.imwrite(
75
+ pjoin(args.aligned_image_dir, os.path.splitext(face_img_name)[0] + "_landmarks.png"),
76
+ visual
77
+ )
78
+ except Exception as e:
79
+ print('[Error]', e, 'error happened when processing', raw_img_path)
80
+
81
+ print(args.raw_image_dir, ':')
82
+ sizes = face_sizes.get_sizes()
83
+ results = {
84
+ 'mean_size': np.mean(sizes),
85
+ 'num_faces_detected': len(sizes),
86
+ 'num_images': len(img_names),
87
+ 'sizes': sizes,
88
+ 'size_path_dict': face_sizes.serialize(),
89
+ }
90
+ print('\t', results)
91
+ if args.out_stats is not None:
92
+ os.makedirs(os.path.dirname(args.out_stats), exist_ok=True)
93
+ with open(out_stats, 'w') as f:
94
+ json.dump(results, f)
95
+
96
+
97
+ def parse_args(args=None, namespace=None):
98
+ parser = argparse.ArgumentParser(description="""
99
+ Extracts and aligns all faces from images using DLib and a function from original FFHQ dataset preparation step
100
+ python align_images.py /raw_images /aligned_images
101
+ """
102
+ )
103
+ parser.add_argument('raw_image_dir')
104
+ parser.add_argument('aligned_image_dir')
105
+ parser.add_argument('--resize',
106
+ help="True if want to resize to 1024",
107
+ action='store_true')
108
+ parser.add_argument('--draw',
109
+ help="True if want to visualize landmarks",
110
+ action='store_true')
111
+ parser.add_argument('--out_stats',
112
+ help="output_fn for statistics of faces", default=None)
113
+ return parser.parse_args(args=args, namespace=namespace)
114
+
115
+
116
+ if __name__ == "__main__":
117
+ main(parse_args())
Time_TravelRephotography/tools/initialize.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from argparse import ArgumentParser, Namespace
2
+ from typing import (
3
+ List,
4
+ Tuple,
5
+ )
6
+
7
+ import numpy as np
8
+ from PIL import Image
9
+ import torch
10
+ from torch import nn
11
+ import torch.nn.functional as F
12
+ from torchvision.transforms import (
13
+ Compose,
14
+ Grayscale,
15
+ Resize,
16
+ ToTensor,
17
+ )
18
+
19
+ from models.encoder import Encoder
20
+ from models.encoder4editing import (
21
+ get_latents as get_e4e_latents,
22
+ setup_model as setup_e4e_model,
23
+ )
24
+ from utils.misc import (
25
+ optional_string,
26
+ iterable_to_str,
27
+ stem,
28
+ )
29
+
30
+
31
+
32
+ class ColorEncoderArguments:
33
+ def __init__(self):
34
+ parser = ArgumentParser("Encode an image via a feed-forward encoder")
35
+
36
+ self.add_arguments(parser)
37
+
38
+ self.parser = parser
39
+
40
+ @staticmethod
41
+ def add_arguments(parser: ArgumentParser):
42
+ parser.add_argument("--encoder_ckpt", default=None,
43
+ help="encoder checkpoint path. initialize w with encoder output if specified")
44
+ parser.add_argument("--encoder_size", type=int, default=256,
45
+ help="Resize to this size to pass as input to the encoder")
46
+
47
+
48
+ class InitializerArguments:
49
+ @classmethod
50
+ def add_arguments(cls, parser: ArgumentParser):
51
+ ColorEncoderArguments.add_arguments(parser)
52
+ cls.add_e4e_arguments(parser)
53
+ parser.add_argument("--mix_layer_range", default=[10, 18], type=int, nargs=2,
54
+ help="replace layers <start> to <end> in the e4e code by the color code")
55
+
56
+ parser.add_argument("--init_latent", default=None, help="path to init wp")
57
+
58
+ @staticmethod
59
+ def to_string(args: Namespace):
60
+ return (f"init{stem(args.init_latent).lstrip('0')[:10]}" if args.init_latent
61
+ else f"init({iterable_to_str(args.mix_layer_range)})")
62
+ #+ optional_string(args.init_noise > 0, f"-initN{args.init_noise}")
63
+
64
+ @staticmethod
65
+ def add_e4e_arguments(parser: ArgumentParser):
66
+ parser.add_argument("--e4e_ckpt", default='checkpoint/e4e_ffhq_encode.pt',
67
+ help="e4e checkpoint path.")
68
+ parser.add_argument("--e4e_size", type=int, default=256,
69
+ help="Resize to this size to pass as input to the e4e")
70
+
71
+
72
+
73
+ def create_color_encoder(args: Namespace):
74
+ encoder = Encoder(1, args.encoder_size, 512)
75
+ ckpt = torch.load(args.encoder_ckpt)
76
+ encoder.load_state_dict(ckpt["model"])
77
+ return encoder
78
+
79
+
80
+ def transform_input(img: Image):
81
+ tsfm = Compose([
82
+ Grayscale(),
83
+ Resize(args.encoder_size),
84
+ ToTensor(),
85
+ ])
86
+ return tsfm(img)
87
+
88
+
89
+ def encode_color(imgs: torch.Tensor, args: Namespace) -> torch.Tensor:
90
+ assert args.encoder_size is not None
91
+
92
+ imgs = Resize(args.encoder_size)(imgs)
93
+
94
+ color_encoder = create_color_encoder(args).to(imgs.device)
95
+ color_encoder.eval()
96
+ with torch.no_grad():
97
+ latent = color_encoder(imgs)
98
+ return latent.detach()
99
+
100
+
101
+ def resize(imgs: torch.Tensor, size: int) -> torch.Tensor:
102
+ return F.interpolate(imgs, size=size, mode='bilinear')
103
+
104
+
105
+ class Initializer(nn.Module):
106
+ def __init__(self, args: Namespace):
107
+ super().__init__()
108
+
109
+ self.path = None
110
+ if args.init_latent is not None:
111
+ self.path = args.init_latent
112
+ return
113
+
114
+
115
+ assert args.encoder_size is not None
116
+ self.color_encoder = create_color_encoder(args)
117
+ self.color_encoder.eval()
118
+ self.color_encoder_size = args.encoder_size
119
+
120
+ self.e4e, e4e_opts = setup_e4e_model(args.e4e_ckpt)
121
+ assert 'cars_' not in e4e_opts.dataset_type
122
+ self.e4e.decoder.eval()
123
+ self.e4e.eval()
124
+ self.e4e_size = args.e4e_size
125
+
126
+ self.mix_layer_range = args.mix_layer_range
127
+
128
+ def encode_color(self, imgs: torch.Tensor) -> torch.Tensor:
129
+ """
130
+ Get the color W code
131
+ """
132
+ imgs = resize(imgs, self.color_encoder_size)
133
+
134
+ latent = self.color_encoder(imgs)
135
+
136
+ return latent
137
+
138
+ def encode_shape(self, imgs: torch.Tensor) -> torch.Tensor:
139
+ imgs = resize(imgs, self.e4e_size)
140
+ imgs = (imgs - 0.5) / 0.5
141
+ if imgs.shape[1] == 1: # 1 channel
142
+ imgs = imgs.repeat(1, 3, 1, 1)
143
+ return get_e4e_latents(self.e4e, imgs)
144
+
145
+ def load(self, device: torch.device):
146
+ latent_np = np.load(self.path)
147
+ return torch.tensor(latent_np, device=device)[None, ...]
148
+
149
+ def forward(self, imgs: torch.Tensor) -> torch.Tensor:
150
+ if self.path is not None:
151
+ return self.load(imgs.device)
152
+
153
+ shape_code = self.encode_shape(imgs)
154
+ color_code = self.encode_color(imgs)
155
+
156
+ # style mix
157
+ latent = shape_code
158
+ start, end = self.mix_layer_range
159
+ latent[:, start:end] = color_code
160
+ return latent
Time_TravelRephotography/tools/match_histogram.py ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from argparse import (
2
+ ArgumentParser,
3
+ Namespace,
4
+ )
5
+ import os
6
+ from os.path import join as pjoin
7
+ from typing import Optional
8
+ import sys
9
+
10
+ import numpy as np
11
+ import cv2
12
+ from skimage import exposure
13
+
14
+
15
+ # sys.path.append('Face_Detection')
16
+ # from align_warp_back_multiple_dlib import match_histograms
17
+
18
+
19
+ def calculate_cdf(histogram):
20
+ """
21
+ This method calculates the cumulative distribution function
22
+ :param array histogram: The values of the histogram
23
+ :return: normalized_cdf: The normalized cumulative distribution function
24
+ :rtype: array
25
+ """
26
+ # Get the cumulative sum of the elements
27
+ cdf = histogram.cumsum()
28
+
29
+ # Normalize the cdf
30
+ normalized_cdf = cdf / float(cdf.max())
31
+
32
+ return normalized_cdf
33
+
34
+
35
+ def calculate_lookup(src_cdf, ref_cdf):
36
+ """
37
+ This method creates the lookup table
38
+ :param array src_cdf: The cdf for the source image
39
+ :param array ref_cdf: The cdf for the reference image
40
+ :return: lookup_table: The lookup table
41
+ :rtype: array
42
+ """
43
+ lookup_table = np.zeros(256)
44
+ lookup_val = 0
45
+ for src_pixel_val in range(len(src_cdf)):
46
+ lookup_val
47
+ for ref_pixel_val in range(len(ref_cdf)):
48
+ if ref_cdf[ref_pixel_val] >= src_cdf[src_pixel_val]:
49
+ lookup_val = ref_pixel_val
50
+ break
51
+ lookup_table[src_pixel_val] = lookup_val
52
+ return lookup_table
53
+
54
+
55
+ def match_histograms(src_image, ref_image, src_mask=None, ref_mask=None):
56
+ """
57
+ This method matches the source image histogram to the
58
+ reference signal
59
+ :param image src_image: The original source image
60
+ :param image ref_image: The reference image
61
+ :return: image_after_matching
62
+ :rtype: image (array)
63
+ """
64
+ # Split the images into the different color channels
65
+ # b means blue, g means green and r means red
66
+ src_b, src_g, src_r = cv2.split(src_image)
67
+ ref_b, ref_g, ref_r = cv2.split(ref_image)
68
+
69
+ def rv(im):
70
+ if ref_mask is None:
71
+ return im.flatten()
72
+ return im[ref_mask]
73
+
74
+ def sv(im):
75
+ if src_mask is None:
76
+ return im.flatten()
77
+ return im[src_mask]
78
+
79
+ # Compute the b, g, and r histograms separately
80
+ # The flatten() Numpy method returns a copy of the array c
81
+ # collapsed into one dimension.
82
+ src_hist_blue, bin_0 = np.histogram(sv(src_b), 256, [0, 256])
83
+ src_hist_green, bin_1 = np.histogram(sv(src_g), 256, [0, 256])
84
+ src_hist_red, bin_2 = np.histogram(sv(src_r), 256, [0, 256])
85
+ ref_hist_blue, bin_3 = np.histogram(rv(ref_b), 256, [0, 256])
86
+ ref_hist_green, bin_4 = np.histogram(rv(ref_g), 256, [0, 256])
87
+ ref_hist_red, bin_5 = np.histogram(rv(ref_r), 256, [0, 256])
88
+
89
+ # Compute the normalized cdf for the source and reference image
90
+ src_cdf_blue = calculate_cdf(src_hist_blue)
91
+ src_cdf_green = calculate_cdf(src_hist_green)
92
+ src_cdf_red = calculate_cdf(src_hist_red)
93
+ ref_cdf_blue = calculate_cdf(ref_hist_blue)
94
+ ref_cdf_green = calculate_cdf(ref_hist_green)
95
+ ref_cdf_red = calculate_cdf(ref_hist_red)
96
+
97
+ # Make a separate lookup table for each color
98
+ blue_lookup_table = calculate_lookup(src_cdf_blue, ref_cdf_blue)
99
+ green_lookup_table = calculate_lookup(src_cdf_green, ref_cdf_green)
100
+ red_lookup_table = calculate_lookup(src_cdf_red, ref_cdf_red)
101
+
102
+ # Use the lookup function to transform the colors of the original
103
+ # source image
104
+ blue_after_transform = cv2.LUT(src_b, blue_lookup_table)
105
+ green_after_transform = cv2.LUT(src_g, green_lookup_table)
106
+ red_after_transform = cv2.LUT(src_r, red_lookup_table)
107
+
108
+ # Put the image back together
109
+ image_after_matching = cv2.merge([blue_after_transform, green_after_transform, red_after_transform])
110
+ image_after_matching = cv2.convertScaleAbs(image_after_matching)
111
+
112
+ return image_after_matching
113
+
114
+
115
+ def convert_to_BW(im, mode):
116
+ if mode == "b":
117
+ gray = im[..., 0]
118
+ elif mode == "gb":
119
+ gray = (im[..., 0].astype(float) + im[..., 1]) / 2.0
120
+ else:
121
+ gray = cv2.cvtColor(im, cv2.COLOR_BGR2GRAY)
122
+ gray = gray.astype(np.uint8)
123
+
124
+ return np.stack([gray] * 3, axis=-1)
125
+
126
+
127
+ def parse_args(args=None, namespace: Optional[Namespace] = None):
128
+ parser = ArgumentParser('match histogram of src to ref')
129
+ parser.add_argument('src')
130
+ parser.add_argument('ref')
131
+ parser.add_argument('--out', default=None, help="converted src that matches ref")
132
+ parser.add_argument('--src_mask', default=None, help="mask on which to match the histogram")
133
+ parser.add_argument('--ref_mask', default=None, help="mask on which to match the histogram")
134
+ parser.add_argument('--spectral_sensitivity', choices=['b', 'gb', 'g'], help="match the histogram of corresponding sensitive channel(s)")
135
+ parser.add_argument('--crop', type=int, default=0, help="crop the boundary to match")
136
+ return parser.parse_args(args=args, namespace=namespace)
137
+
138
+
139
+ def main(args):
140
+ A = cv2.imread(args.ref)
141
+ A = convert_to_BW(A, args.spectral_sensitivity)
142
+ B = cv2.imread(args.src, 0)
143
+ B = np.stack((B,) * 3, axis=-1)
144
+
145
+ mask_A = cv2.resize(cv2.imread(args.ref_mask, 0), A.shape[:2][::-1],
146
+ interpolation=cv2.INTER_NEAREST) > 0 if args.ref_mask else None
147
+ mask_B = cv2.resize(cv2.imread(args.src_mask, 0), B.shape[:2][::-1],
148
+ interpolation=cv2.INTER_NEAREST) > 0 if args.src_mask else None
149
+
150
+ if args.crop > 0:
151
+ c = args.crop
152
+ bc = int(c / A.shape[0] * B.shape[0] + 0.5)
153
+ A = A[c:-c, c:-c]
154
+ B = B[bc:-bc, bc:-bc]
155
+
156
+ B = match_histograms(B, A, src_mask=mask_B, ref_mask=mask_A)
157
+ # B = exposure.match_histograms(B, A, multichannel=True)
158
+
159
+ if args.out:
160
+ os.makedirs(os.path.dirname(args.out), exist_ok=True)
161
+ cv2.imwrite(args.out, B)
162
+
163
+ return B
164
+
165
+
166
+ if __name__ == "__main__":
167
+ main(parse_args())
Time_TravelRephotography/tools/match_skin_histogram.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from argparse import Namespace
2
+ import os
3
+ from os.path import join as pjoin
4
+ from typing import Optional
5
+
6
+ import cv2
7
+ import torch
8
+
9
+ from tools import (
10
+ parse_face,
11
+ match_histogram,
12
+ )
13
+ from utils.torch_helpers import make_image
14
+ from utils.misc import stem
15
+
16
+
17
+ def match_skin_histogram(
18
+ imgs: torch.Tensor,
19
+ sibling_img: torch.Tensor,
20
+ spectral_sensitivity,
21
+ im_sibling_dir: str,
22
+ mask_dir: str,
23
+ matched_hist_fn: Optional[str] = None,
24
+ normalize=None, # normalize the range of the tensor
25
+ ):
26
+ """
27
+ Extract the skin of the input and sibling images. Create a new input image by matching
28
+ its histogram to the sibling.
29
+ """
30
+ # TODO: Currently only allows imgs of batch size 1
31
+ im_sibling_dir = os.path.abspath(im_sibling_dir)
32
+ mask_dir = os.path.abspath(mask_dir)
33
+
34
+ img_np = make_image(imgs)[0]
35
+ sibling_np = make_image(sibling_img)[0][...,::-1]
36
+
37
+ # save img, sibling
38
+ os.makedirs(im_sibling_dir, exist_ok=True)
39
+ im_name, sibling_name = 'input.png', 'sibling.png'
40
+ cv2.imwrite(pjoin(im_sibling_dir, im_name), img_np)
41
+ cv2.imwrite(pjoin(im_sibling_dir, sibling_name), sibling_np)
42
+
43
+ # face parsing
44
+ parse_face.main(
45
+ Namespace(in_dir=im_sibling_dir, out_dir=mask_dir, include_hair=False)
46
+ )
47
+
48
+ # match_histogram
49
+ mh_args = match_histogram.parse_args(
50
+ args=[
51
+ pjoin(im_sibling_dir, im_name),
52
+ pjoin(im_sibling_dir, sibling_name),
53
+ ],
54
+ namespace=Namespace(
55
+ out=matched_hist_fn if matched_hist_fn else pjoin(im_sibling_dir, "match_histogram.png"),
56
+ src_mask=pjoin(mask_dir, im_name),
57
+ ref_mask=pjoin(mask_dir, sibling_name),
58
+ spectral_sensitivity=spectral_sensitivity,
59
+ )
60
+ )
61
+ matched_np = match_histogram.main(mh_args) / 255.0 # [0, 1]
62
+ matched = torch.FloatTensor(matched_np).permute(2, 0, 1)[None,...] #BCHW
63
+
64
+ if normalize is not None:
65
+ matched = normalize(matched)
66
+
67
+ return matched
Time_TravelRephotography/tools/parse_face.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from argparse import ArgumentParser
2
+ import os
3
+ from os.path import join as pjoin
4
+ from subprocess import run
5
+
6
+ import numpy as np
7
+ import cv2
8
+ from tqdm import tqdm
9
+
10
+
11
+ def create_skin_mask(anno_dir, mask_dir, skin_thresh=13, include_hair=False):
12
+ names = os.listdir(anno_dir)
13
+ names = [n for n in names if n.endswith('.png')]
14
+ os.makedirs(mask_dir, exist_ok=True)
15
+ for name in tqdm(names):
16
+ anno = cv2.imread(pjoin(anno_dir, name), 0)
17
+ mask = np.logical_and(0 < anno, anno <= skin_thresh)
18
+ if include_hair:
19
+ mask |= anno == 17
20
+ cv2.imwrite(pjoin(mask_dir, name), mask * 255)
21
+
22
+
23
+ def main(args):
24
+ FACE_PARSING_DIR = 'third_party/face_parsing'
25
+
26
+ main_env = os.getcwd()
27
+ os.chdir(FACE_PARSING_DIR)
28
+ tmp_parse_dir = pjoin(args.out_dir, 'face_parsing')
29
+ cmd = [
30
+ 'python',
31
+ 'test.py',
32
+ args.in_dir,
33
+ tmp_parse_dir,
34
+ ]
35
+ print(' '.join(cmd))
36
+ run(cmd)
37
+
38
+ create_skin_mask(tmp_parse_dir, args.out_dir, include_hair=args.include_hair)
39
+
40
+ os.chdir(main_env)
41
+
42
+
43
+ def parse_args(args=None, namespace=None):
44
+ parser = ArgumentParser("Face Parsing and generate skin (& hair) mask")
45
+ parser.add_argument('in_dir')
46
+ parser.add_argument('out_dir')
47
+ parser.add_argument('--include_hair', action="store_true", help="include hair in the mask")
48
+ return parser.parse_args(args=args, namespace=namespace)
49
+
50
+
51
+ if __name__ == "__main__":
52
+ main(parse_args())
53
+
54
+
55
+
Time_TravelRephotography/utils/__init__.py ADDED
File without changes
Time_TravelRephotography/utils/ffhq_dataset/__init__.py ADDED
File without changes
Time_TravelRephotography/utils/ffhq_dataset/face_alignment.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import scipy.ndimage
3
+ import os
4
+ import PIL.Image
5
+
6
+
7
+ def image_align(src_file, dst_file, face_landmarks, resize=True, output_size=1024, transform_size=4096, enable_padding=True):
8
+ # Align function from FFHQ dataset pre-processing step
9
+ # https://github.com/NVlabs/ffhq-dataset/blob/master/download_ffhq.py
10
+
11
+ lm = np.array(face_landmarks)
12
+ lm_chin = lm[0 : 17] # left-right
13
+ lm_eyebrow_left = lm[17 : 22] # left-right
14
+ lm_eyebrow_right = lm[22 : 27] # left-right
15
+ lm_nose = lm[27 : 31] # top-down
16
+ lm_nostrils = lm[31 : 36] # top-down
17
+ lm_eye_left = lm[36 : 42] # left-clockwise
18
+ lm_eye_right = lm[42 : 48] # left-clockwise
19
+ lm_mouth_outer = lm[48 : 60] # left-clockwise
20
+ lm_mouth_inner = lm[60 : 68] # left-clockwise
21
+
22
+ # Calculate auxiliary vectors.
23
+ eye_left = np.mean(lm_eye_left, axis=0)
24
+ eye_right = np.mean(lm_eye_right, axis=0)
25
+ eye_avg = (eye_left + eye_right) * 0.5
26
+ eye_to_eye = eye_right - eye_left
27
+ mouth_left = lm_mouth_outer[0]
28
+ mouth_right = lm_mouth_outer[6]
29
+ mouth_avg = (mouth_left + mouth_right) * 0.5
30
+ eye_to_mouth = mouth_avg - eye_avg
31
+
32
+ # Choose oriented crop rectangle.
33
+ x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
34
+ x /= np.hypot(*x)
35
+ x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8)
36
+ y = np.flipud(x) * [-1, 1]
37
+ c = eye_avg + eye_to_mouth * 0.1
38
+ quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
39
+ qsize = np.hypot(*x) * 2
40
+
41
+ # Load in-the-wild image.
42
+ if not os.path.isfile(src_file):
43
+ print('\nCannot find source image. Please run "--wilds" before "--align".')
44
+ return
45
+ #img = cv2.imread(src_file)
46
+ #img = PIL.Image.fromarray(img)
47
+ img = PIL.Image.open(src_file)
48
+
49
+ # Shrink.
50
+ shrink = int(np.floor(qsize / output_size * 0.5))
51
+ if shrink > 1:
52
+ rsize = (int(np.rint(float(img.size[0]) / shrink)), int(np.rint(float(img.size[1]) / shrink)))
53
+ img = img.resize(rsize, PIL.Image.ANTIALIAS)
54
+ quad /= shrink
55
+ qsize /= shrink
56
+
57
+ # Crop.
58
+ border = max(int(np.rint(qsize * 0.1)), 3)
59
+ crop = (int(np.floor(min(quad[:,0]))), int(np.floor(min(quad[:,1]))), int(np.ceil(max(quad[:,0]))), int(np.ceil(max(quad[:,1]))))
60
+ crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, img.size[0]), min(crop[3] + border, img.size[1]))
61
+ if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]:
62
+ img = img.crop(crop)
63
+ quad -= crop[0:2]
64
+
65
+ # Pad.
66
+ pad = (int(np.floor(min(quad[:,0]))), int(np.floor(min(quad[:,1]))), int(np.ceil(max(quad[:,0]))), int(np.ceil(max(quad[:,1]))))
67
+ pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - img.size[0] + border, 0), max(pad[3] - img.size[1] + border, 0))
68
+ if enable_padding and max(pad) > border - 4:
69
+ img = np.float32(img)
70
+ if img.ndim == 2:
71
+ img = np.stack((img,)*3, axis=-1)
72
+ pad = np.maximum(pad, int(np.rint(qsize * 0.3)))
73
+ img = np.pad(img, ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')
74
+ h, w, _ = img.shape
75
+ y, x, _ = np.ogrid[:h, :w, :1]
76
+ mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0], np.float32(w-1-x) / pad[2]), 1.0 - np.minimum(np.float32(y) / pad[1], np.float32(h-1-y) / pad[3]))
77
+ blur = qsize * 0.02
78
+ img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
79
+ img += (np.median(img, axis=(0,1)) - img) * np.clip(mask, 0.0, 1.0)
80
+ img = PIL.Image.fromarray(np.uint8(np.clip(np.rint(img), 0, 255)), 'RGB')
81
+ quad += pad[:2]
82
+
83
+ xmin, xmax = np.amin(quad[:,0]), np.amax(quad[:,0])
84
+ ymin, ymax = np.amin(quad[:,1]), np.amax(quad[:,1])
85
+ quad_size = int(max(xmax-xmin, ymax-ymin)+0.5)
86
+
87
+ if not resize:
88
+ transform_size = output_size = quad_size
89
+
90
+
91
+ # Transform.
92
+ img = img.transform((transform_size, transform_size), PIL.Image.QUAD, (quad + 0.5).flatten(), PIL.Image.BILINEAR)
93
+ if output_size < transform_size:
94
+ img = img.resize((output_size, output_size), PIL.Image.ANTIALIAS)
95
+
96
+ # Save aligned image.
97
+ os.makedirs(os.path.dirname(dst_file), exist_ok=True)
98
+ img.save(dst_file, 'PNG')
99
+ return quad_size
Time_TravelRephotography/utils/ffhq_dataset/landmarks_detector.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import dlib
2
+ import cv2
3
+
4
+
5
+ class LandmarksDetector:
6
+ def __init__(self, predictor_model_path):
7
+ """
8
+ :param predictor_model_path: path to shape_predictor_68_face_landmarks.dat file
9
+ """
10
+ self.detector = dlib.get_frontal_face_detector() # cnn_face_detection_model_v1 also can be used
11
+ self.shape_predictor = dlib.shape_predictor(predictor_model_path)
12
+
13
+ def get_landmarks(self, image):
14
+ img = dlib.load_rgb_image(image)
15
+ dets = self.detector(img, 1)
16
+ #print('face bounding boxes', dets)
17
+
18
+ for detection in dets:
19
+ face_landmarks = [(item.x, item.y) for item in self.shape_predictor(img, detection).parts()]
20
+ #print('face landmarks', face_landmarks)
21
+ yield face_landmarks
22
+
23
+ def draw(img, landmarks):
24
+ for (x, y) in landmarks:
25
+ cv2.circle(img, (x, y), 1, (0, 0, 255), -1)
26
+ return img
27
+
28
+
29
+ class DNNLandmarksDetector:
30
+ def __init__(self, predictor_model_path, DNN='TF'):
31
+ """
32
+ :param
33
+ DNN: "TF" or "CAFFE"
34
+ predictor_model_path: path to shape_predictor_68_face_landmarks.dat file
35
+ """
36
+ if DNN == "CAFFE":
37
+ modelFile = "res10_300x300_ssd_iter_140000_fp16.caffemodel"
38
+ configFile = "deploy.prototxt"
39
+ net = cv2.dnn.readNetFromCaffe(configFile, modelFile)
40
+ else:
41
+ modelFile = "opencv_face_detector_uint8.pb"
42
+ configFile = "opencv_face_detector.pbtxt"
43
+ net = cv2.dnn.readNetFromTensorflow(modelFile, configFile)
44
+
45
+ self.shape_predictor = dlib.shape_predictor(predictor_model_path)
46
+
47
+ def detect_faces(self, image, conf_threshold=0):
48
+ H, W = image.shape[:2]
49
+ blob = cv2.dnn.blobFromImage(image, 1.0, (300, 300), [104, 117, 123], False, False)
50
+ net.setInput(blob)
51
+ detections = net.forward()
52
+ bboxes = []
53
+ for i in range(detections.shape[2]):
54
+ confidence = detections[0, 0, i, 2]
55
+ if confidence > conf_threshold:
56
+ x1 = int(detections[0, 0, i, 3] * W)
57
+ y1 = int(detections[0, 0, i, 4] * H)
58
+ x2 = int(detections[0, 0, i, 5] * W)
59
+ y2 = int(detections[0, 0, i, 6] * H)
60
+ bboxes.append(dlib.rectangle(x1, y1, x2, y2))
61
+ return bboxes
62
+
63
+ def get_landmarks(self, image):
64
+ img = cv2.imread(image)
65
+ dets = self.detect_faces(img, 0)
66
+ print('face bounding boxes', dets)
67
+
68
+ for detection in dets:
69
+ face_landmarks = [(item.x, item.y) for item in self.shape_predictor(img, detection).parts()]
70
+ print('face landmarks', face_landmarks)
71
+ yield face_landmarks
Time_TravelRephotography/utils/misc.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Iterable
3
+
4
+
5
+ def optional_string(condition: bool, string: str):
6
+ return string if condition else ""
7
+
8
+
9
+ def parent_dir(path: str) -> str:
10
+ return os.path.basename(os.path.dirname(path))
11
+
12
+
13
+ def stem(path: str) -> str:
14
+ return os.path.splitext(os.path.basename(path))[0]
15
+
16
+
17
+ def iterable_to_str(iterable: Iterable) -> str:
18
+ return ','.join([str(x) for x in iterable])
Time_TravelRephotography/utils/optimize.py ADDED
@@ -0,0 +1,230 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from argparse import (
3
+ ArgumentParser,
4
+ Namespace,
5
+ )
6
+ from typing import (
7
+ Dict,
8
+ Iterable,
9
+ Optional,
10
+ Tuple,
11
+ )
12
+
13
+ import numpy as np
14
+ from tqdm import tqdm
15
+ import torch
16
+ from torch import nn
17
+ import torch.nn.functional as F
18
+ from torch.utils.tensorboard import SummaryWriter
19
+ from torchvision.utils import make_grid
20
+ from torchvision.transforms import Resize
21
+
22
+ #from optim import get_optimizer_class, OPTIMIZER_MAP
23
+ from losses.regularize_noise import NoiseRegularizer
24
+ from optim import RAdam
25
+ from utils.misc import (
26
+ iterable_to_str,
27
+ optional_string,
28
+ )
29
+
30
+
31
+ class OptimizerArguments:
32
+ @staticmethod
33
+ def add_arguments(parser: ArgumentParser):
34
+ parser.add_argument('--coarse_min', type=int, default=32)
35
+ parser.add_argument('--wplus_step', type=int, nargs="+", default=[250, 750], help="#step for optimizing w_plus")
36
+ #parser.add_argument('--lr_rampup', type=float, default=0.05)
37
+ #parser.add_argument('--lr_rampdown', type=float, default=0.25)
38
+ parser.add_argument('--lr', type=float, default=0.1)
39
+ parser.add_argument('--noise_strength', type=float, default=.0)
40
+ parser.add_argument('--noise_ramp', type=float, default=0.75)
41
+ #parser.add_argument('--optimize_noise', action="store_true")
42
+ parser.add_argument('--camera_lr', type=float, default=0.01)
43
+
44
+ parser.add_argument("--log_dir", default="log/projector", help="tensorboard log directory")
45
+ parser.add_argument("--log_freq", type=int, default=10, help="log frequency")
46
+ parser.add_argument("--log_visual_freq", type=int, default=50, help="log frequency")
47
+
48
+ @staticmethod
49
+ def to_string(args: Namespace) -> str:
50
+ return (
51
+ f"lr{args.lr}_{args.camera_lr}-c{args.coarse_min}"
52
+ + f"-wp({iterable_to_str(args.wplus_step)})"
53
+ + optional_string(args.noise_strength, f"-n{args.noise_strength}")
54
+ )
55
+
56
+
57
+ class LatentNoiser(nn.Module):
58
+ def __init__(
59
+ self, generator: torch.nn,
60
+ noise_ramp: float = 0.75, noise_strength: float = 0.05,
61
+ n_mean_latent: int = 10000
62
+ ):
63
+ super().__init__()
64
+
65
+ self.noise_ramp = noise_ramp
66
+ self.noise_strength = noise_strength
67
+
68
+ with torch.no_grad():
69
+ # TODO: get 512 from generator
70
+ noise_sample = torch.randn(n_mean_latent, 512, device=generator.device)
71
+ latent_out = generator.style(noise_sample)
72
+
73
+ latent_mean = latent_out.mean(0)
74
+ self.latent_std = ((latent_out - latent_mean).pow(2).sum() / n_mean_latent) ** 0.5
75
+
76
+ def forward(self, latent: torch.Tensor, t: float) -> torch.Tensor:
77
+ strength = self.latent_std * self.noise_strength * max(0, 1 - t / self.noise_ramp) ** 2
78
+ noise = torch.randn_like(latent) * strength
79
+ return latent + noise
80
+
81
+
82
+ class Optimizer:
83
+ @classmethod
84
+ def optimize(
85
+ cls,
86
+ generator: torch.nn,
87
+ criterion: torch.nn,
88
+ degrade: torch.nn,
89
+ target: torch.Tensor, # only used in writer since it's mostly baked in criterion
90
+ latent_init: torch.Tensor,
91
+ noise_init: torch.Tensor,
92
+ args: Namespace,
93
+ writer: Optional[SummaryWriter] = None,
94
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
95
+ # do not optimize generator
96
+ generator = generator.eval()
97
+ target = target.detach()
98
+ # prepare parameters
99
+ noises = []
100
+ for n in noise_init:
101
+ noise = n.detach().clone()
102
+ noise.requires_grad = True
103
+ noises.append(noise)
104
+
105
+
106
+ def create_parameters(latent_coarse):
107
+ parameters = [
108
+ {'params': [latent_coarse], 'lr': args.lr},
109
+ {'params': noises, 'lr': args.lr},
110
+ {'params': degrade.parameters(), 'lr': args.camera_lr},
111
+ ]
112
+ return parameters
113
+
114
+
115
+ device = target.device
116
+
117
+ # start optimize
118
+ total_steps = np.sum(args.wplus_step)
119
+ max_coarse_size = (2 ** (len(args.wplus_step) - 1)) * args.coarse_min
120
+ noiser = LatentNoiser(generator, noise_ramp=args.noise_ramp, noise_strength=args.noise_strength).to(device)
121
+ latent = latent_init.detach().clone()
122
+ for coarse_level, steps in enumerate(args.wplus_step):
123
+ if criterion.weights["contextual"] > 0:
124
+ with torch.no_grad():
125
+ # synthesize new sibling image using the current optimization results
126
+ # FIXME: update rgbs sibling
127
+ sibling, _, _ = generator([latent], input_is_latent=True, randomize_noise=True)
128
+ criterion.update_sibling(sibling)
129
+
130
+ coarse_size = (2 ** coarse_level) * args.coarse_min
131
+ latent_coarse, latent_fine = cls.split_latent(
132
+ latent, generator.get_latent_size(coarse_size))
133
+ parameters = create_parameters(latent_coarse)
134
+ optimizer = RAdam(parameters)
135
+
136
+ print(f"Optimizing {coarse_size}x{coarse_size}")
137
+ pbar = tqdm(range(steps))
138
+ for si in pbar:
139
+ latent = torch.cat((latent_coarse, latent_fine), dim=1)
140
+ niters = si + np.sum(args.wplus_step[:coarse_level])
141
+ latent_noisy = noiser(latent, niters / total_steps)
142
+ img_gen, _, rgbs = generator([latent_noisy], input_is_latent=True, noise=noises)
143
+ # TODO: use coarse_size instead of args.coarse_size for rgb_level
144
+ loss, losses = criterion(img_gen, degrade=degrade, noises=noises, rgbs=rgbs)
145
+
146
+ optimizer.zero_grad()
147
+ loss.backward()
148
+ optimizer.step()
149
+
150
+ NoiseRegularizer.normalize(noises)
151
+
152
+ # log
153
+ pbar.set_description("; ".join([f"{k}: {v.item(): .3e}" for k, v in losses.items()]))
154
+
155
+ if writer is not None and niters % args.log_freq == 0:
156
+ cls.log_losses(writer, niters, loss, losses, criterion.weights)
157
+ cls.log_parameters(writer, niters, degrade.named_parameters())
158
+ if writer is not None and niters % args.log_visual_freq == 0:
159
+ cls.log_visuals(writer, niters, img_gen, target, degraded=degrade(img_gen), rgbs=rgbs)
160
+
161
+ latent = torch.cat((latent_coarse, latent_fine), dim=1).detach()
162
+
163
+ return latent, noises
164
+
165
+ @staticmethod
166
+ def split_latent(latent: torch.Tensor, coarse_latent_size: int):
167
+ latent_coarse = latent[:, :coarse_latent_size]
168
+ latent_coarse.requires_grad = True
169
+ latent_fine = latent[:, coarse_latent_size:]
170
+ latent_fine.requires_grad = False
171
+ return latent_coarse, latent_fine
172
+
173
+ @staticmethod
174
+ def log_losses(
175
+ writer: SummaryWriter,
176
+ niters: int,
177
+ loss_total: torch.Tensor,
178
+ losses: Dict[str, torch.Tensor],
179
+ weights: Optional[Dict[str, torch.Tensor]] = None
180
+ ):
181
+ writer.add_scalar("loss", loss_total.item(), niters)
182
+
183
+ for name, loss in losses.items():
184
+ writer.add_scalar(name, loss.item(), niters)
185
+ if weights is not None:
186
+ writer.add_scalar(f"weighted_{name}", weights[name] * loss.item(), niters)
187
+
188
+ @staticmethod
189
+ def log_parameters(
190
+ writer: SummaryWriter,
191
+ niters: int,
192
+ named_parameters: Iterable[Tuple[str, torch.nn.Parameter]],
193
+ ):
194
+ for name, para in named_parameters:
195
+ writer.add_scalar(name, para.item(), niters)
196
+
197
+ @classmethod
198
+ def log_visuals(
199
+ cls,
200
+ writer: SummaryWriter,
201
+ niters: int,
202
+ img: torch.Tensor,
203
+ target: torch.Tensor,
204
+ degraded=None,
205
+ rgbs=None,
206
+ ):
207
+ if target.shape[-1] != img.shape[-1]:
208
+ visual = make_grid(img, nrow=1, normalize=True, range=(-1, 1))
209
+ writer.add_image("pred", visual, niters)
210
+
211
+ def resize(img):
212
+ return F.interpolate(img, size=target.shape[2:], mode="area")
213
+
214
+ vis = resize(img)
215
+ if degraded is not None:
216
+ vis = torch.cat((resize(degraded), vis), dim=-1)
217
+ visual = make_grid(torch.cat((target.repeat(1, vis.shape[1] // target.shape[1], 1, 1), vis), dim=-1), nrow=1, normalize=True, range=(-1, 1))
218
+ writer.add_image("gnd[-degraded]-pred", visual, niters)
219
+
220
+ # log to rgbs
221
+ if rgbs is not None:
222
+ cls.log_torgbs(writer, niters, rgbs)
223
+
224
+ @staticmethod
225
+ def log_torgbs(writer: SummaryWriter, niters: int, rgbs: Iterable[torch.Tensor], prefix: str = ""):
226
+ for ri, rgb in enumerate(rgbs):
227
+ scale = 2 ** (-(len(rgbs) - ri))
228
+ visual = make_grid(torch.cat((rgb, rgb / scale), dim=-1), nrow=1, normalize=True, range=(-1, 1))
229
+ writer.add_image(f"{prefix}to_rbg_{2 ** (ri + 2)}", visual, niters)
230
+
Time_TravelRephotography/utils/projector_arguments.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from argparse import (
3
+ ArgumentParser,
4
+ Namespace,
5
+ )
6
+
7
+ from models.degrade import DegradeArguments
8
+ from tools.initialize import InitializerArguments
9
+ from losses.joint_loss import LossArguments
10
+ from utils.optimize import OptimizerArguments
11
+ from .misc import (
12
+ optional_string,
13
+ iterable_to_str,
14
+ )
15
+
16
+
17
+ class ProjectorArguments:
18
+ def __init__(self):
19
+ parser = ArgumentParser("Project image into stylegan2")
20
+ self.add_arguments(parser)
21
+ self.parser = parser
22
+
23
+ @classmethod
24
+ def add_arguments(cls, parser: ArgumentParser):
25
+ parser.add_argument('--rand_seed', type=int, default=None,
26
+ help="random seed")
27
+ cls.add_io_args(parser)
28
+ cls.add_preprocess_args(parser)
29
+ cls.add_stylegan_args(parser)
30
+
31
+ InitializerArguments.add_arguments(parser)
32
+ LossArguments.add_arguments(parser)
33
+ OptimizerArguments.add_arguments(parser)
34
+ DegradeArguments.add_arguments(parser)
35
+
36
+ @staticmethod
37
+ def add_stylegan_args(parser: ArgumentParser):
38
+ parser.add_argument('--ckpt', type=str, default="checkpoint/stylegan2-ffhq-config-f.pt",
39
+ help="stylegan2 checkpoint")
40
+ parser.add_argument('--generator_size', type=int, default=1024,
41
+ help="output size of the generator")
42
+
43
+ @staticmethod
44
+ def add_io_args(parser: ArgumentParser) -> ArgumentParser:
45
+ parser.add_argument('input', type=str, help="input image path")
46
+ parser.add_argument('--results_dir', default="results/projector", help="directory to save results.")
47
+
48
+ @staticmethod
49
+ def add_preprocess_args(parser: ArgumentParser):
50
+ # parser.add_argument("--match_histogram", action='store_true', help="match the histogram of the input image to the sibling")
51
+ pass
52
+
53
+ def parse(self, args=None, namespace=None) -> Namespace:
54
+ args = self.parser.parse_args(args, namespace=namespace)
55
+ self.print(args)
56
+ return args
57
+
58
+ @staticmethod
59
+ def print(args: Namespace):
60
+ print("------------ Parameters -------------")
61
+ args = vars(args)
62
+ for k, v in sorted(args.items()):
63
+ print(f"{k}: {v}")
64
+ print("-------------------------------------")
65
+
66
+ @staticmethod
67
+ def to_string(args: Namespace) -> str:
68
+ return "-".join([
69
+ #+ optional_string(args.no_camera_response, "-noCR")
70
+ #+ optional_string(args.match_histogram, "-MH")
71
+ DegradeArguments.to_string(args),
72
+ InitializerArguments.to_string(args),
73
+ LossArguments.to_string(args),
74
+ OptimizerArguments.to_string(args),
75
+ ]) + optional_string(args.rand_seed is not None, f"-S{args.rand_seed}")
76
+
Time_TravelRephotography/utils/torch_helpers.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+
4
+
5
+ def device(gpu_id=0):
6
+ if torch.cuda.is_available():
7
+ return torch.device(f"cuda:{gpu_id}")
8
+ return torch.device("cpu")
9
+
10
+
11
+ def load_matching_state_dict(model: nn.Module, state_dict):
12
+ model_dict = model.state_dict()
13
+ filtered_dict = {k: v for k, v in state_dict.items() if k in model_dict}
14
+ model.load_state_dict(filtered_dict)
15
+
16
+
17
+ def resize(t: torch.Tensor, size: int) -> torch.Tensor:
18
+ B, C, H, W = t.shape
19
+ t = t.reshape(B, C, size, H // size, size, W // size)
20
+ return t.mean([3, 5])
21
+
22
+
23
+ def make_image(tensor):
24
+ return (
25
+ tensor.detach()
26
+ .clamp_(min=-1, max=1)
27
+ .add(1)
28
+ .div_(2)
29
+ .mul(255)
30
+ .type(torch.uint8)
31
+ .permute(0, 2, 3, 1)
32
+ .to('cpu')
33
+ .numpy()
34
+ )
35
+
36
+