haakohu committed on
Commit
97a6728
·
1 Parent(s): 6c1a5e0
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. app.py +3 -3
  2. deep_privacy/.gitignore +54 -0
  3. deep_privacy/CHANGELOG.md +13 -0
  4. deep_privacy/Dockerfile +47 -0
  5. deep_privacy/LICENSE +201 -0
  6. deep_privacy/anonymize.py +255 -0
  7. deep_privacy/attribute_guided_demo.py +144 -0
  8. deep_privacy/configs/anonymizers/FB_cse.py +28 -0
  9. deep_privacy/configs/anonymizers/FB_cse_mask.py +29 -0
  10. deep_privacy/configs/anonymizers/FB_cse_mask_face.py +29 -0
  11. deep_privacy/configs/anonymizers/deep_privacy1.py +15 -0
  12. deep_privacy/configs/anonymizers/face.py +17 -0
  13. deep_privacy/configs/anonymizers/face_fdf128.py +18 -0
  14. deep_privacy/configs/anonymizers/market1501/blackout.py +8 -0
  15. deep_privacy/configs/anonymizers/market1501/person.py +6 -0
  16. deep_privacy/configs/anonymizers/market1501/pixelation16.py +8 -0
  17. deep_privacy/configs/anonymizers/market1501/pixelation8.py +8 -0
  18. deep_privacy/configs/datasets/coco_cse.py +69 -0
  19. deep_privacy/configs/datasets/fdf128.py +24 -0
  20. deep_privacy/configs/datasets/fdf256.py +55 -0
  21. deep_privacy/configs/datasets/fdh.py +90 -0
  22. deep_privacy/configs/datasets/utils.py +21 -0
  23. deep_privacy/configs/defaults.py +53 -0
  24. deep_privacy/configs/discriminators/sg2_discriminator.py +43 -0
  25. deep_privacy/configs/fdf/deep_privacy1.py +9 -0
  26. deep_privacy/configs/fdf/stylegan.py +14 -0
  27. deep_privacy/configs/fdf/stylegan_fdf128.py +17 -0
  28. deep_privacy/configs/fdh/styleganL.py +16 -0
  29. deep_privacy/configs/fdh/styleganL_nocse.py +14 -0
  30. deep_privacy/configs/generators/stylegan_unet.py +22 -0
  31. deep_privacy/dp2/__init__.py +0 -0
  32. deep_privacy/dp2/anonymizer/__init__.py +1 -0
  33. deep_privacy/dp2/anonymizer/anonymizer.py +163 -0
  34. deep_privacy/dp2/anonymizer/histogram_match_anonymizers.py +93 -0
  35. deep_privacy/dp2/data/__init__.py +0 -0
  36. deep_privacy/dp2/data/build.py +40 -0
  37. deep_privacy/dp2/data/datasets/__init__.py +0 -0
  38. deep_privacy/dp2/data/datasets/coco_cse.py +68 -0
  39. deep_privacy/dp2/data/datasets/fdf.py +128 -0
  40. deep_privacy/dp2/data/datasets/fdf128_wds.py +96 -0
  41. deep_privacy/dp2/data/datasets/fdh.py +142 -0
  42. deep_privacy/dp2/data/transforms/__init__.py +2 -0
  43. deep_privacy/dp2/data/transforms/functional.py +57 -0
  44. deep_privacy/dp2/data/transforms/stylegan2_transform.py +394 -0
  45. deep_privacy/dp2/data/transforms/transforms.py +277 -0
  46. deep_privacy/dp2/data/utils.py +122 -0
  47. deep_privacy/dp2/detection/__init__.py +3 -0
  48. deep_privacy/dp2/detection/base.py +42 -0
  49. deep_privacy/dp2/detection/box_utils.py +104 -0
  50. deep_privacy/dp2/detection/box_utils_fdf.py +202 -0
app.py CHANGED
@@ -9,15 +9,15 @@ os.system("pip install ftfy regex tqdm")
9
  os.system("pip install --no-deps git+https://github.com/openai/CLIP.git")
10
  os.system("pip install git+https://github.com/facebookresearch/detectron2@96c752ce821a3340e27edd51c28a00665dd32a30#subdirectory=projects/DensePose")
11
  os.system("pip install --no-deps git+https://github.com/hukkelas/DSFD-Pytorch-Inference")
12
- sys.path.insert(0, Path(os.getcwd(), "deep_privacy2"))
13
  os.environ["TORCH_HOME"] = "torch_home"
14
  from dp2 import utils
15
  from gradio_demos.modules import ExampleDemo, WebcamDemo
16
 
17
- cfg_face = utils.load_config("deep_privacy2/configs/anonymizers/face.py")
18
  for key in ["person_G_cfg", "cse_person_G_cfg", "face_G_cfg", "car_G_cfg"]:
19
  if key in cfg_face.anonymizer:
20
- cfg_face.anonymizer[key] = Path("deep_privacy2", cfg_face.anonymizer[key])
21
 
22
 
23
  anonymizer_face = instantiate(cfg_face.anonymizer, load_cache=False)
 
9
# Install the git-hosted dependencies at start-up.
os.system("pip install --no-deps git+https://github.com/openai/CLIP.git")
os.system("pip install git+https://github.com/facebookresearch/detectron2@96c752ce821a3340e27edd51c28a00665dd32a30#subdirectory=projects/DensePose")
os.system("pip install --no-deps git+https://github.com/hukkelas/DSFD-Pytorch-Inference")
# sys.path entries must be plain strings: the import system ignores any other
# type, so inserting a Path object would not make the `dp2` and `gradio_demos`
# packages below importable from the vendored "deep_privacy" directory.
sys.path.insert(0, str(Path(os.getcwd(), "deep_privacy")))
os.environ["TORCH_HOME"] = "torch_home"
from dp2 import utils
from gradio_demos.modules import ExampleDemo, WebcamDemo

cfg_face = utils.load_config("deep_privacy/configs/anonymizers/face.py")
# Generator config paths inside the anonymizer config are re-rooted under the
# vendored "deep_privacy" directory before the anonymizer is instantiated.
for key in ["person_G_cfg", "cse_person_G_cfg", "face_G_cfg", "car_G_cfg"]:
    if key in cfg_face.anonymizer:
        cfg_face.anonymizer[key] = Path("deep_privacy", cfg_face.anonymizer[key])


anonymizer_face = instantiate(cfg_face.anonymizer, load_cache=False)
deep_privacy/.gitignore ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # FILES
2
+ *.yaml
3
+ *.pkl
4
+ *.flist
5
+ *.zip
6
+ *.out
7
+ *.npy
8
+ *.gz
9
+ *.ckpt
10
+ *.pth
11
+ *.log
12
+ *.pyc
13
+ *.csv
14
+ *.yml
15
+ *.ods
16
+ *.ods#
17
+ *.json
18
+ build_docker.sh
19
+
20
+ # Images / Videos
21
+ #*.png
22
+ #*.jpg
23
+ *.jpeg
24
+ *.m4a
25
+ *.mkv
26
+ *.mp4
27
+
28
+ # Directories created by inpaintron
29
+ .cache/
30
+ test_examples/
31
+ .vscode
32
+ __pycache__
33
+ .debug/
34
+ **/.ipynb_checkpoints/**
35
+ outputs/
36
+
37
+
38
+ # From pip setup
39
+ build/
40
+ *.egg-info
41
+ *.egg
42
+ .npm/
43
+
44
+ # From dockerfile
45
+ .bash_history
46
+ .viminfo
47
+ .local/
48
+ *.pickle
49
+ *.onnx
50
+
51
+
52
+ sbatch_files/
53
+ figures/
54
+ image_dump/
deep_privacy/CHANGELOG.md ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Changelog
2
+
3
+ ## 23.03.2023
4
+ - Quality of life improvements
5
+ - Add support for refined keypoints for the FDH dataset.
6
+ - Add FDF128 dataset loader with webdataset.
7
+ - Support for using detector and anonymizer from DeepPrivacy1.
8
+ - Update visualization of keypoints
9
+ - Fix bug for upsampling/downsampling in the anonymization pipeline.
10
+ - Support for keypoint-guided face anonymization.
11
+ - Add ViTPose + Mask-RCNN detection model for keypoint-guided full-body anonymization.
12
+ - Set caching of detections to False as default, as it can produce unexpected behaviour. For example, using a different score threshold requires re-run of detector.
13
+ - Add Gradio Demos for face and body anonymization
deep_privacy/Dockerfile ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM nvcr.io/nvidia/pytorch:22.08-py3
2
+ ARG UID=1000
3
+ ARG UNAME=testuser
4
+ ARG WANDB_API_KEY
5
+ RUN useradd -ms /bin/bash -u $UID $UNAME && \
6
+ mkdir -p /home/${UNAME} &&\
7
+ chown -R $UID /home/${UNAME}
8
+ WORKDIR /home/${UNAME}
9
+ ENV DEBIAN_FRONTEND="noninteractive"
10
+ ENV WANDB_API_KEY=$WANDB_API_KEY
11
+ ENV TORCH_HOME=/home/${UNAME}/.cache
12
+
13
+ # OPTIONAL - DeepPrivacy2 uses these environment variables to set directories outside the current working directory
14
+ #ENV BASE_DATASET_DIR=/work/haakohu/datasets
15
+ #ENV BASE_OUTPUT_DIR=/work/haakohu/outputs
16
+ #ENV FBA_METRICS_CACHE=/work/haakohu/metrics_cache
17
+
18
+ RUN apt-get update && apt-get install ffmpeg libsm6 libxext6 qt5-default -y
19
+ RUN pip install git+https://github.com/facebookresearch/detectron2@96c752ce821a3340e27edd51c28a00665dd32a30#subdirectory=projects/DensePose
20
+ COPY setup.py setup.py
21
# Python dependencies for DeepPrivacy2.
# NOTE: a requirement specifier containing ">" must be quoted. Unquoted, the
# shell parses ">=1.20" as an output redirection (to a file named "=1.20"),
# so pip would install an unconstrained numpy and its output would be lost.
RUN pip install \
    "numpy>=1.20" \
    matplotlib \
    cython \
    tensorboard \
    tqdm \
    ninja==1.10.2 \
    opencv-python==4.5.5.64 \
    moviepy \
    pyspng \
    git+https://github.com/hukkelas/DSFD-Pytorch-Inference \
    wandb \
    termcolor \
    git+https://github.com/hukkelas/torch_ops.git \
    git+https://github.com/wmuron/motpy@c77f85d27e371c0a298e9a88ca99292d9b9cbe6b \
    fast_pytorch_kmeans \
    einops_exts \
    einops \
    regex \
    setuptools==59.5.0 \
    resize_right==0.0.2 \
    pillow \
    scipy==1.7.1 \
    webdataset==0.2.26 \
    scikit-image \
    timm==0.6.7
47
+ RUN pip install --no-deps torch_fidelity==0.3.0 clip@git+https://github.com/openai/CLIP.git@b46f5ac7587d2e1862f8b7b1573179d80dcdd620
deep_privacy/LICENSE ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
deep_privacy/anonymize.py ADDED
@@ -0,0 +1,255 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import hashlib
2
+ from typing import Optional
3
+ import click
4
+ import tops
5
+ import numpy as np
6
+ import tqdm
7
+ import moviepy.editor as mp
8
+ import cv2
9
+ from tops.config import instantiate
10
+ from pathlib import Path
11
+ from PIL import Image
12
+ from dp2 import utils
13
+ from detectron2.data.detection_utils import _apply_exif_orientation
14
+ from tops import logger
15
+ from dp2.utils.bufferless_video_capture import BufferlessVideoCapture
16
+
17
+
18
def show_video(video_path):
    """Play a video in an OpenCV window until it ends or "q" is pressed.

    Args:
        video_path: Location of the video; converted with ``str`` before
            being handed to ``cv2.VideoCapture``.
    """
    video_cap = cv2.VideoCapture(str(video_path))
    try:
        while video_cap.isOpened():
            ret, frame = video_cap.read()
            if not ret:
                # No frame came back (end of stream or read failure); handing
                # the empty frame to imshow would raise, so stop playback.
                break
            cv2.imshow("Frame", frame)
            key = cv2.waitKey(25)
            if key == ord("q"):
                break
    finally:
        # Always free the capture handle and close the preview window, even
        # if displaying a frame failed.
        video_cap.release()
        cv2.destroyAllWindows()
28
+
29
+
30
class ImageIndexTracker:
    """Wrap a ``fn(frame, index)`` callback so it can be used as a one-argument
    frame filter, supplying a running zero-based frame counter."""

    def __init__(self, fn) -> None:
        # Callback invoked as fn(frame, index) for every frame.
        self.fn = fn
        # Index that will be handed to the next call.
        self.idx = 0

    def fl_image(self, frame):
        """Call the wrapped callback with ``frame`` and its position.

        The counter is advanced before the callback runs, so the callback
        observes ``self.idx`` already pointing at the following frame.
        """
        current = self.idx
        self.idx = current + 1
        return self.fn(frame, current)
39
+
40
+
41
def anonymize_video(
        video_path, output_path: Path,
        anonymizer, visualize: bool, max_res: int,
        start_time: int, fps: int,
        end_time: int,
        visualize_detection: bool,
        track: bool,
        synthesis_kwargs,
        **kwargs):
    """Anonymize a video frame by frame and optionally write the result.

    Args:
        video_path: Input video, opened with moviepy.
        output_path: Destination file, or None to only process/visualize.
            A ``.avi`` suffix is replaced by ``.mp4``.
        anonymizer: Callable applied to every frame; also provides
            ``initialize_tracker`` and ``visualize_detection``.
        visualize: Show every processed frame in an OpenCV window
            (pressing "q" exits the process).
        max_res: Upper bound passed to ``resize`` for each frame, or None.
        start_time: Start of the processed sub-clip.
        fps: Output frame rate; None keeps the clip's own rate.
        end_time: End of the processed sub-clip.
        visualize_detection: Render detections instead of anonymizing.
        track: Initialise the anonymizer's tracker with the clip's fps.
        synthesis_kwargs: Keyword arguments forwarded to the anonymizer; its
            "cache_id" entry is overwritten for every frame.
        **kwargs: Ignored; accepts the shared CLI options.
    """
    video = mp.VideoFileClip(str(video_path))
    if track:
        anonymizer.initialize_tracker(video.fps)

    def process_frame(frame, idx):
        frame = np.array(resize(Image.fromarray(frame), max_res))
        cache_id = hashlib.md5(frame).hexdigest()
        frame = utils.im2torch(frame, to_float=False, normalize=False)[0]
        # The frame index is appended so identical frames at different
        # positions get distinct cache ids.
        cache_id_ = cache_id + str(idx)
        synthesis_kwargs["cache_id"] = cache_id_
        if visualize_detection:
            anonymized = anonymizer.visualize_detection(frame, cache_id=cache_id_)
        else:
            anonymized = anonymizer(frame, **synthesis_kwargs)
        anonymized = utils.im2numpy(anonymized)
        if visualize:
            # Channel order is reversed for display.
            cv2.imshow("frame", anonymized[:, :, ::-1])
            key = cv2.waitKey(1)
            if key == ord("q"):
                exit()
        return anonymized
    video: mp.VideoClip = video.subclip(start_time, end_time)

    if fps is not None:
        video = video.set_fps(fps)

    video = video.fl_image(ImageIndexTracker(process_frame).fl_image)
    if output_path is None:
        # Nothing to write. The frame filter is only evaluated when frames
        # are requested, so pull them explicitly to drive the visualization.
        for _ in video.iter_frames():
            pass
        return
    # Keep output_path a Path: the previous string replacement turned it into
    # a str, which then failed on the `.parent` access below.
    output_path = Path(output_path)
    if output_path.suffix == ".avi":
        output_path = output_path.with_suffix(".mp4")
    output_path.parent.mkdir(parents=True, exist_ok=True)
    video.write_videofile(str(output_path))
82
+
83
+
84
def resize(frame: Image.Image, max_res):
    """Downscale ``frame`` so that neither side exceeds ``max_res``.

    The aspect ratio is kept and images are never enlarged.

    Args:
        frame: Image to resize.
        max_res: Maximum allowed width/height in pixels, or None to disable
            resizing.

    Returns:
        ``frame`` itself when no resizing is needed, otherwise a bilinearly
        resampled copy.
    """
    if max_res is None:
        return frame
    # Scale factor of the larger side relative to max_res, floored at 1 so
    # that small images are left untouched.
    f = max(*[x/max_res for x in frame.size], 1)
    if f == 1:
        return frame
    # Clamp to 1 pixel: for extreme aspect ratios the short side would
    # otherwise truncate to 0, which is not a valid image size.
    new_shape = [max(1, int(x/f)) for x in frame.size]
    return frame.resize(new_shape, resample=Image.BILINEAR)
92
+
93
+
94
def anonymize_image(
        image_path, output_path: Path, visualize: bool,
        anonymizer, max_res: int,
        visualize_detection: bool,
        synthesis_kwargs,
        **kwargs):
    """Anonymize one image file and optionally display and/or save it.

    Args:
        image_path: Image to read.
        output_path: Where to save the result; None skips saving.
        visualize: Show the result in an OpenCV window. "q" closes the
            loop, "u" re-runs the anonymizer on the same input.
        anonymizer: Callable producing the anonymized image; also provides
            ``visualize_detection``.
        max_res: Upper bound passed to ``resize``, or None.
        visualize_detection: Render detections instead of anonymizing.
        synthesis_kwargs: Keyword arguments forwarded to the anonymizer; its
            "cache_id" entry is set to the MD5 of the decoded pixels.
        **kwargs: Ignored; accepts the shared CLI options.
    """
    with Image.open(image_path) as opened:
        oriented = _apply_exif_orientation(opened)
        # Remembered so the output can be converted back to the input mode.
        orig_im_mode = oriented.mode
        pixels = np.array(resize(oriented.convert("RGB"), max_res))

    md5_ = hashlib.md5(pixels).hexdigest()
    tensor = utils.im2torch(np.array(pixels), to_float=False, normalize=False)[0]
    synthesis_kwargs["cache_id"] = md5_

    def run_anonymizer():
        return anonymizer(tensor, **synthesis_kwargs)

    if visualize_detection:
        result = anonymizer.visualize_detection(tops.to_cuda(tensor), cache_id=md5_)
    else:
        result = run_anonymizer()
    result = utils.im2numpy(result)

    showing = visualize
    while showing:
        # Channel order is reversed for display.
        cv2.imshow("frame", result[:, :, ::-1])
        key = cv2.waitKey(0)
        if key == ord("q"):
            showing = False
        elif key == ord("u"):
            result = utils.im2numpy(run_anonymizer())

    final = Image.fromarray(result).convert(orig_im_mode)
    if output_path is None:
        return
    output_path.parent.mkdir(exist_ok=True, parents=True)
    final.save(output_path, optimize=False, quality=100)
    print(f"Saved to: {output_path}")
128
+
129
+
130
def anonymize_file(input_path: Path, output_path: Optional[Path], **kwargs):
    """Dispatch a single file to the image or video anonymization routine.

    A warning is logged when ``output_path`` already exists as a file. Inputs
    that are neither images nor videos (as judged by ``tops``) are only
    logged and otherwise skipped.

    Args:
        input_path: File to anonymize.
        output_path: Destination file, or None.
        **kwargs: Forwarded unchanged to the selected routine.
    """
    target_exists = output_path is not None and output_path.is_file()
    if target_exists:
        logger.warn(f"Overwriting previous file: {output_path}")

    if tops.is_image(input_path):
        anonymize_image(input_path, output_path, **kwargs)
        return
    if tops.is_video(input_path):
        anonymize_video(input_path, output_path, **kwargs)
        return
    logger.log(f"Filepath not a video or image file: {input_path}")
139
+
140
+
141
def anonymize_directory(input_dir: Path, output_dir: Path, **kwargs):
    """Recursively anonymize every file below ``input_dir``.

    The directory layout is mirrored under ``output_dir``: each child is
    written to ``output_dir / <child name>``.

    Args:
        input_dir: Directory to walk.
        output_dir: Root of the mirrored output tree, or None when results
            should not be saved (the caller permits this combination).
        **kwargs: Forwarded unchanged to ``anonymize_file``.
    """
    # iterdir() already yields paths rooted at input_dir.
    for childpath in tqdm.tqdm(input_dir.iterdir()):
        # Propagate "no output" instead of failing on None.joinpath.
        if output_dir is None:
            output_path = None
        else:
            output_path = output_dir.joinpath(childpath.name)
        if childpath.is_file():
            anonymize_file(childpath, output_path, **kwargs)
        else:
            anonymize_directory(childpath, output_path, **kwargs)
150
+
151
def anonymize_webcam(
        anonymizer, max_res: int,
        synthesis_kwargs,
        visualize_detection,
        track: bool,
        **kwargs):
    """Anonymize a live capture feed and display it until "q" is pressed.

    Args:
        anonymizer: Callable applied to every frame; also provides
            ``initialize_tracker`` and ``visualize_detection``.
        max_res: Upper bound passed to ``resize`` for each frame, or None.
        synthesis_kwargs: Keyword arguments forwarded to the anonymizer.
        visualize_detection: Render detections instead of anonymizing.
        track: Initialise the anonymizer's tracker before the loop.
        **kwargs: Ignored; accepts the shared CLI options.
    """
    import time
    # Source 0 is presumably the default camera - confirm against
    # BufferlessVideoCapture.
    cap = BufferlessVideoCapture(0, width=1920, height=1080)
    t = time.time()
    frames = 0
    if track:
        anonymizer.initialize_tracker(fps=5)  # FPS used for tracking objects
    while True:
        # Capture frame-by-frame
        ret, frame = cap.read()
        # Channel order is reversed before building the PIL image
        # (presumably BGR -> RGB; confirm against the capture class).
        frame = Image.fromarray(frame[:, :, ::-1])
        frame = resize(frame, max_res)
        frame = np.array(frame)
        im = utils.im2torch(np.array(frame), to_float=False, normalize=False)[0]
        if visualize_detection:
            im_ = anonymizer.visualize_detection(tops.to_cuda(im))
        else:
            im_ = anonymizer(im, **synthesis_kwargs)
        im_ = utils.im2numpy(im_)

        frames += 1
        delta = time.time() - t
        # Format the number only when it exists: applying ":.3f" to the "?"
        # placeholder string would raise ValueError.
        if delta > 1e-6:
            fps_text = f"{frames / delta:.3f}"
        else:
            fps_text = "?"
        print(f"FPS: {fps_text}", end="\r")
        cv2.imshow('frame', im_[:, :, ::-1])
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
185
+
186
+
187
@click.command()
@click.argument("config_path", type=click.Path(exists=True))
@click.option("-i", "--input_path", help="Input path. Accepted inputs: images, videos, directories.")
@click.option("-o", "--output_path", default=None, type=click.Path(), help="Output path to save. Can be directory or file.")
@click.option("-v","--visualize", default=False, is_flag=True, help="Visualize the result")
@click.option("--max-res", default=None, type=int, help="Maximum resolution of height/width")
@click.option("--start-time", "--st", default=0, type=int, help="Start time (second) for video anonymization")
@click.option("--end-time", "--et", default=None, type=int, help="End time (second) for video anonymization")
@click.option("--fps", default=None, type=int, help="FPS for anonymization")
@click.option("--detection-score-threshold", "--dst", default=.3, type=click.FloatRange(0, 1), help="Detection threshold, threshold applied for all detection models.")
@click.option("--visualize-detection", "--vd",default=False, is_flag=True, help="Visualize only detections without running anonymization.")
@click.option("--multi-modal-truncation", "--mt", default=False, is_flag=True, help="Enable multi-modal truncation proposed by: https://arxiv.org/pdf/2202.12211.pdf")
@click.option("--cache", default=False, is_flag=True, help="Enable detection caching. Will save and load detections from cache.")
# NOTE(review): with is_flag=True and default=True, click appears to derive the
# flag value as the negation of the default, so passing --amp would turn AMP
# off. Confirm, and consider an explicit "--amp/--no-amp" pair.
@click.option("--amp", default=True, is_flag=True, help="Use automatic mixed precision for generator forward pass")
@click.option("-t", "--truncation_value", default=0, type=click.FloatRange(0, 1), help="Latent interpolation truncation value.")
@click.option("--track", default=False, is_flag=True, help="Track detections over frames. Will use the same latent variable (z) for tracked identities.")
@click.option("--seed", default=0, type=int, help="Set random seed for generating images.")
@click.option("--person-generator", default=None, help="Config path to unconditional person generator", type=click.Path())
@click.option("--cse-person-generator", default=None, help="Config path to CSE-guided person generator", type=click.Path())
@click.option("--webcam", default=False, is_flag=True, help="Read image from webcam feed.")
def anonymize_path(
        config_path,
        input_path,
        output_path,
        detection_score_threshold: float,
        visualize_detection: bool,
        cache: bool,
        seed: int,
        person_generator: str,
        cse_person_generator: str,
        webcam: bool,
        **kwargs):
    """
    config_path: Specify the path to the anonymization model to use.
    """
    # The docstring above doubles as the command's --help text, so it is kept
    # short; implementation notes live in comments.
    if not webcam and input_path is None:
        # Fail early with a readable message instead of a TypeError from
        # Path(None) after the models have already been loaded.
        raise click.UsageError("Either -i/--input_path or --webcam must be given.")
    tops.set_seed(seed)
    cfg = utils.load_config(config_path)
    # Optional command-line overrides of the generator configs.
    if person_generator is not None:
        cfg.anonymizer.person_G_cfg = person_generator
    if cse_person_generator is not None:
        cfg.anonymizer.cse_person_G_cfg = cse_person_generator
    cfg.detector.score_threshold = detection_score_threshold
    utils.print_config(cfg)

    anonymizer = instantiate(cfg.anonymizer, load_cache=cache)
    # Options consumed by the generator are split off from the remaining CLI
    # options and passed on as one dict.
    synthesis_kwargs = ["amp", "multi_modal_truncation", "truncation_value"]
    synthesis_kwargs = {k: kwargs.pop(k) for k in synthesis_kwargs}

    kwargs["anonymizer"] = anonymizer
    kwargs["visualize_detection"] = visualize_detection
    kwargs["synthesis_kwargs"] = synthesis_kwargs
    if webcam:
        anonymize_webcam(**kwargs)
        return
    input_path = Path(input_path)
    output_path = Path(output_path) if output_path is not None else None
    if output_path is None and not kwargs["visualize"]:
        logger.log("Output path not set. Setting visualize to True")
        kwargs["visualize"] = True
    if input_path.is_dir():
        assert output_path is None or not output_path.is_file()
        anonymize_directory(input_path, output_path, **kwargs)
    else:
        anonymize_file(input_path, output_path, **kwargs)


if __name__ == "__main__":

    anonymize_path()
deep_privacy/attribute_guided_demo.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import defaultdict
2
+ import gradio
3
+ import numpy as np
4
+ import torch
5
+ import cv2
6
+ from PIL import Image
7
+ from dp2 import utils
8
+ from tops.config import instantiate
9
+ import tops
10
+ import gradio.inputs
11
+ from stylemc import get_and_cache_direction, get_styles
12
+
13
+
14
class GuidedDemo:
    """Gradio UI for text-guided editing of anonymized faces.

    The demo wraps an anonymizer that has exactly one active generator. For
    every detection it keeps one style object per face, and lets the user add
    text-prompt "edits" (prompt -> strength) that are applied as directions on
    top of those styles before the image is re-anonymized.
    """

    def __init__(self, face_anonymizer, cfg_face) -> None:
        """Store the anonymizer, pick its single generator and build the loader.

        Args:
            face_anonymizer: anonymizer exposing ``generators`` (a mapping) and
                ``detector``; exactly one generator must be non-``None``.
            cfg_face: anonymizer config; ``cfg_face.anonymizer.face_G_cfg`` is
                loaded to obtain the validation loader and the output dir.
        """
        self.anonymizer = face_anonymizer
        active_generators = [
            g for g in face_anonymizer.generators.values() if g is not None]
        assert len(active_generators) == 1
        self.generator = active_generators[0]
        face_G_cfg = utils.load_config(cfg_face.anonymizer.face_G_cfg)
        # Directions are estimated one sample at a time.
        face_G_cfg.train.batch_size = 1
        self.dl = instantiate(face_G_cfg.data.val.loader)
        self.cache_dir = face_G_cfg.output_dir
        self.precompute_edits()

    @staticmethod
    def _num_boxes(current_boxes) -> int:
        """Return the number of boxes, treating ``None`` (no image yet) as 0."""
        return 0 if current_boxes is None else len(current_boxes)

    def precompute_edits(self):
        """Collect the prompts that already have a cached direction on disk.

        Every file in ``<cache_dir>/stylemc_cache`` contributes one prompt: the
        file stem with underscores replaced by spaces. Directions for other
        prompts are obtained on demand in ``anonymize``.
        """
        self.precomputed_edits = set()
        stylemc_cache = self.cache_dir.joinpath("stylemc_cache")
        if stylemc_cache.is_dir():
            for path in stylemc_cache.iterdir():
                text_prompt = path.stem.replace("_", " ")
                self.precomputed_edits.add(text_prompt)
                print(text_prompt)
        self.edits = defaultdict(defaultdict)

    def anonymize(self, img, show_boxes: bool, current_box_idx: int, current_styles, current_boxes, update_identity, edits, cache_id=None):
        """Re-anonymize ``img`` with the current styles plus all recorded edits.

        Args:
            img: a ``torch.Tensor`` (used as is) or a PIL image (converted with
                ``pil2torch``); ``None`` means no image is available.
            show_boxes: draw the currently selected box on the result.
            current_box_idx: selected face; wrapped modulo the number of boxes.
            current_styles: one style object per detected face.
            current_boxes: one ``x0, y0, x1, y1`` box per detected face.
            update_identity: per-face flags forwarded to the anonymizer.
            edits: mapping ``face index -> {prompt: strength}``.
            cache_id: forwarded to the anonymizer.

        Returns:
            ``(image, update_identity)``. The image is ``None`` when no input
            image was given.
        """
        if img is None:
            # Gradio passes None for an empty/cleared image component.
            return None, update_identity
        if not isinstance(img, torch.Tensor):
            img, cache_id = pil2torch(img)
            img = tops.to_cuda(img)

        num_boxes = self._num_boxes(current_boxes)
        # Clone so that the stored per-face styles are never edited in place.
        edited_styles = [s.clone() for s in (current_styles or [])]
        for face_idx, face_edits in edits.items():
            for prompt, strength in face_edits.items():
                direction = get_and_cache_direction(self.cache_dir, self.dl, self.generator, prompt)
                edited_styles[int(face_idx)] += direction * strength
                update_identity[int(face_idx)] = True
        assert img.dtype == torch.uint8
        img = self.anonymizer(
            img, truncation_value=0,
            multi_modal_truncation=True, amp=True,
            cache_id=cache_id,
            all_styles=edited_styles,
            update_identity=update_identity)
        update_identity = [True for _ in update_identity]
        img = utils.im2numpy(img)
        # Without detections there is no box to select or draw; the modulo
        # would otherwise divide by zero.
        if show_boxes and num_boxes > 0:
            current_box_idx = current_box_idx % num_boxes
            x0, y0, x1, y1 = [int(v) for v in current_boxes[int(current_box_idx)]]
            img = cv2.rectangle(img, (x0, y0), (x1, y1), (255, 0, 0), 1)
        return img, update_identity

    def update_image(self, img, show_boxes):
        """Detect faces in a new image, sample fresh styles and anonymize it.

        Returns:
            ``(image, current_styles, current_boxes, update_identity, edits,
            cur_face_idx)``. All per-image state is reset; ``cur_face_idx`` is
            ``-1`` when there is nothing to select.
        """
        if img is None:
            # The image component was cleared: reset the state to its defaults.
            return None, [], None, [], defaultdict(defaultdict), -1
        img, cache_id = pil2torch(img)
        img = tops.to_cuda(img)
        det = self.anonymizer.detector.forward_and_cache(img, cache_id, load_cache=True)[0]
        current_styles = []
        for i in range(len(det)):
            # Need to do forward pass to register all affine modules.
            batch = det.get_crop(i, img)
            batch["condition"] = batch["img"].float()

            s = get_styles(
                np.random.randint(0, 999999), self.generator,
                batch, truncation_value=0)
            current_styles.append(s)
        update_identity = [True for _ in range(len(det))]
        current_boxes = np.array(det.boxes)
        edits = defaultdict(defaultdict)
        num_boxes = len(current_boxes)
        # Select the last face; keep -1 when nothing was detected.
        cur_face_idx = -1 % num_boxes if num_boxes else -1
        img, update_identity = self.anonymize(img, show_boxes, cur_face_idx, current_styles, current_boxes, update_identity, edits, cache_id=cache_id)
        return img, current_styles, current_boxes, update_identity, edits, cur_face_idx

    def change_face(self, change, cur_face_idx, current_boxes, input_image, show_boxes, current_styles, update_identity, edits):
        """Move the selection by ``change`` faces (wrapping) and redraw.

        The index is left unchanged when there are no boxes.
        """
        num_boxes = self._num_boxes(current_boxes)
        if num_boxes:
            cur_face_idx = (cur_face_idx + change) % num_boxes
        img, update_identity = self.anonymize(input_image, show_boxes, cur_face_idx, current_styles, current_boxes, update_identity, edits)
        return img, update_identity, cur_face_idx

    def add_style(self, face_idx: int, prompt: str, strength: float, input_image, show_boxes, current_styles, current_boxes, update_identity, edits):
        """Record ``prompt`` with ``strength`` for the selected face and redraw.

        Nothing is recorded when there are no boxes, since there is no face
        the edit could apply to.
        """
        num_boxes = self._num_boxes(current_boxes)
        if num_boxes:
            face_idx = face_idx % num_boxes
            # setdefault works for both plain dicts and defaultdicts.
            edits.setdefault(face_idx, {})[prompt] = strength
        img, update_identity = self.anonymize(input_image, show_boxes, face_idx, current_styles, current_boxes, update_identity, edits)
        return img, update_identity, edits

    def setup_interface(self):
        """Create the Gradio components and wire them to the handlers above.

        Must be called inside an active ``gradio.Blocks`` context.
        """
        current_styles = gradio.State()
        current_boxes = gradio.State(None)
        update_identity = gradio.State([])
        # The handlers treat edits as a mapping (face index -> {prompt: strength}).
        edits = gradio.State({})
        with gradio.Row():
            input_image = gradio.Image(
                type="pil", label="Upload your image or try the example below!", source="webcam")
            output_image = gradio.Image(type="numpy", label="Output")
        with gradio.Row():
            update_btn = gradio.Button("Update Anonymization").style(full_width=True)
        with gradio.Row():
            show_boxes = gradio.Checkbox(value=True, label="Show Selected")
            cur_face_idx = gradio.Number(value=-1, label="Current", interactive=False)
            previous = gradio.Button("Previous Person")
            next_ = gradio.Button("Next Person")
        with gradio.Row():
            text_prompt = gradio.Textbox(
                placeholder=" | ".join(list(self.precomputed_edits)),
                label="Text Prompt for Edit")
            edit_strength = gradio.Slider(0, 5, step=.01)
            add_btn = gradio.Button("Add Edit")
        add_btn.click(
            self.add_style,
            inputs=[cur_face_idx, text_prompt, edit_strength, input_image, show_boxes, current_styles, current_boxes, update_identity, edits],
            outputs=[output_image, update_identity, edits])
        update_btn.click(
            self.update_image,
            inputs=[input_image, show_boxes],
            outputs=[output_image, current_styles, current_boxes, update_identity, edits, cur_face_idx])
        input_image.change(
            self.update_image,
            inputs=[input_image, show_boxes],
            outputs=[output_image, current_styles, current_boxes, update_identity, edits, cur_face_idx])
        previous.click(
            self.change_face,
            inputs=[gradio.State(-1), cur_face_idx, current_boxes, input_image, show_boxes, current_styles, update_identity, edits],
            outputs=[output_image, update_identity, cur_face_idx])
        next_.click(
            self.change_face,
            inputs=[gradio.State(1), cur_face_idx, current_boxes, input_image, show_boxes, current_styles, update_identity, edits],
            outputs=[output_image, update_identity, cur_face_idx])

        show_boxes.change(
            self.anonymize,
            inputs=[input_image, show_boxes, cur_face_idx, current_styles, current_boxes, update_identity, edits],
            outputs=[output_image, update_identity])
123
+
124
+
125
def pil2torch(img: "Image.Image"):
    """Convert a PIL image into a channel-first tensor.

    The image is converted to RGB, turned into a NumPy array and its last
    axis is moved to the front before wrapping it with ``torch.from_numpy``.

    Returns:
        ``(tensor, None)``; the second element is the (absent) cache id.
    """
    rgb_pixels = np.array(img.convert("RGB"))
    # Same axis order as np.rollaxis(x, 2): last axis first, rest unchanged.
    channel_first = rgb_pixels.transpose(2, 0, 1)
    return torch.from_numpy(channel_first), None
130
+
131
+
132
+ cfg_face = utils.load_config("configs/anonymizers/face.py")
133
+ anonymizer_face = instantiate(cfg_face.anonymizer, load_cache=False)
134
+ anonymizer_face.initialize_tracker(fps=1)
135
+
136
+
137
+ with gradio.Blocks() as demo:
138
+ gradio.Markdown("# <center> DeepPrivacy2 - Realistic Image Anonymization </center>")
139
+ gradio.Markdown("### <center> Håkon Hukkelås, Rudolf Mester, Frank Lindseth </center>")
140
+ with gradio.Tab("Text-Guided Anonymization"):
141
+ GuidedDemo(anonymizer_face, cfg_face).setup_interface()
142
+
143
+
144
+ demo.launch()
deep_privacy/configs/anonymizers/FB_cse.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dp2.anonymizer import Anonymizer
2
+ from dp2.detection.person_detector import CSEPersonDetector
3
+ from ..defaults import common
4
+ from tops.config import LazyCall as L
5
+ from dp2.generator.dummy_generators import MaskOutGenerator
6
+
7
+
8
+ maskout_G = L(MaskOutGenerator)(noise="constant")
9
+
10
+ detector = L(CSEPersonDetector)(
11
+ mask_rcnn_cfg=dict(),
12
+ cse_cfg=dict(),
13
+ cse_post_process_cfg=dict(
14
+ target_imsize=(288, 160),
15
+ exp_bbox_cfg=dict(percentage_background=0.3, axis_minimum_expansion=.1),
16
+ exp_bbox_filter=dict(minimum_area=32*32, min_bbox_ratio_inside=0, aspect_ratio_range=[0, 99999]),
17
+ iou_combine_threshold=0.4,
18
+ dilation_percentage=0.02,
19
+ normalize_embedding=False
20
+ ),
21
+ score_threshold=0.3,
22
+ cache_directory=common.output_dir.joinpath("cse_person_detection_cache")
23
+ )
24
+
25
+ anonymizer = L(Anonymizer)(
26
+ detector="${detector}",
27
+ cse_person_G_cfg="configs/fdh/styleganL.py",
28
+ )
deep_privacy/configs/anonymizers/FB_cse_mask.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dp2.anonymizer import Anonymizer
2
+ from dp2.detection.person_detector import CSEPersonDetector
3
+ from ..defaults import common
4
+ from tops.config import LazyCall as L
5
+ from dp2.generator.dummy_generators import MaskOutGenerator
6
+
7
+
8
+ maskout_G = L(MaskOutGenerator)(noise="constant")
9
+
10
+ detector = L(CSEPersonDetector)(
11
+ mask_rcnn_cfg=dict(),
12
+ cse_cfg=dict(),
13
+ cse_post_process_cfg=dict(
14
+ target_imsize=(288, 160),
15
+ exp_bbox_cfg=dict(percentage_background=0.3, axis_minimum_expansion=.1),
16
+ exp_bbox_filter=dict(minimum_area=32*32, min_bbox_ratio_inside=0, aspect_ratio_range=[0, 99999]),
17
+ iou_combine_threshold=0.4,
18
+ dilation_percentage=0.02,
19
+ normalize_embedding=False
20
+ ),
21
+ score_threshold=0.3,
22
+ cache_directory=common.output_dir.joinpath("cse_person_detection_cache")
23
+ )
24
+
25
+ anonymizer = L(Anonymizer)(
26
+ detector="${detector}",
27
+ person_G_cfg="configs/fdh/styleganL_nocse.py",
28
+ cse_person_G_cfg="configs/fdh/styleganL.py",
29
+ )
deep_privacy/configs/anonymizers/FB_cse_mask_face.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dp2.anonymizer import Anonymizer
2
+ from dp2.detection.cse_mask_face_detector import CSeMaskFaceDetector
3
+ from ..defaults import common
4
+ from tops.config import LazyCall as L
5
+
6
+ detector = L(CSeMaskFaceDetector)(
7
+ mask_rcnn_cfg=dict(),
8
+ face_detector_cfg=dict(),
9
+ face_post_process_cfg=dict(target_imsize=(256, 256), fdf128_expand=False),
10
+ cse_cfg=dict(),
11
+ cse_post_process_cfg=dict(
12
+ target_imsize=(288, 160),
13
+ exp_bbox_cfg=dict(percentage_background=0.3, axis_minimum_expansion=.1),
14
+ exp_bbox_filter=dict(minimum_area=32*32, min_bbox_ratio_inside=0, aspect_ratio_range=[0, 99999]),
15
+ iou_combine_threshold=0.4,
16
+ dilation_percentage=0.02,
17
+ normalize_embedding=False
18
+ ),
19
+ score_threshold=0.3,
20
+ cache_directory=common.output_dir.joinpath("cse_mask_face_detection_cache")
21
+ )
22
+
23
+ anonymizer = L(Anonymizer)(
24
+ detector="${detector}",
25
+ face_G_cfg="configs/fdf/stylegan.py",
26
+ person_G_cfg="configs/fdh/styleganL_nocse.py",
27
+ cse_person_G_cfg="configs/fdh/styleganL.py",
28
+ car_G_cfg="configs/generators/dummy/pixelation8.py"
29
+ )
deep_privacy/configs/anonymizers/deep_privacy1.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .face_fdf128 import anonymizer, common, detector
2
+ from dp2.detection.deep_privacy1_detector import DeepPrivacy1Detector
3
+ from tops.config import LazyCall as L
4
+
5
+ anonymizer.update(
6
+ face_G_cfg="configs/fdf/deep_privacy1.py",
7
+ )
8
+
9
+ anonymizer.detector = L(DeepPrivacy1Detector)(
10
+ face_detector_cfg=dict(name="DSFDDetector", clip_boxes=True),
11
+ face_post_process_cfg=dict(target_imsize=(128, 128), fdf128_expand=True),
12
+ score_threshold=0.3,
13
+ keypoint_threshold=0.3,
14
+ cache_directory=common.output_dir.joinpath("deep_privacy1_cache")
15
+ )
deep_privacy/configs/anonymizers/face.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dp2.anonymizer import Anonymizer
2
+ from dp2.detection.face_detector import FaceDetector
3
+ from ..defaults import common
4
+ from tops.config import LazyCall as L
5
+
6
+
7
+ detector = L(FaceDetector)(
8
+ face_detector_cfg=dict(name="DSFDDetector", clip_boxes=True),
9
+ face_post_process_cfg=dict(target_imsize=(256, 256), fdf128_expand=False),
10
+ score_threshold=0.3,
11
+ cache_directory=common.output_dir.joinpath("face_detection_cache"),
12
+ )
13
+
14
+ anonymizer = L(Anonymizer)(
15
+ detector="${detector}",
16
+ face_G_cfg="configs/fdf/stylegan.py",
17
+ )
deep_privacy/configs/anonymizers/face_fdf128.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dp2.anonymizer import Anonymizer
2
+ from dp2.detection.face_detector import FaceDetector
3
+ from ..defaults import common
4
+ from tops.config import LazyCall as L
5
+
6
+
7
+ detector = L(FaceDetector)(
8
+ face_detector_cfg=dict(name="DSFDDetector", clip_boxes=True),
9
+ face_post_process_cfg=dict(target_imsize=(128, 128), fdf128_expand=True),
10
+ score_threshold=0.3,
11
+ cache_directory=common.output_dir.joinpath("face_detection_cache")
12
+ )
13
+
14
+
15
+ anonymizer = L(Anonymizer)(
16
+ detector="${detector}",
17
+ face_G_cfg="configs/fdf/stylegan_fdf128.py",
18
+ )
deep_privacy/configs/anonymizers/market1501/blackout.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ from ..FB_cse_mask_face import anonymizer, detector, common
2
+
3
+ detector.score_threshold = .1
4
+ detector.face_detector_cfg.confidence_threshold = .5
5
+ detector.cse_cfg.score_thres = 0.3
6
+ anonymizer.generators.face_G_cfg = None
7
+ anonymizer.generators.person_G_cfg = "configs/generators/dummy/maskout.py"
8
+ anonymizer.generators.cse_person_G_cfg = "configs/generators/dummy/maskout.py"
deep_privacy/configs/anonymizers/market1501/person.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ from ..FB_cse_mask_face import anonymizer, detector, common
2
+
3
+ detector.score_threshold = .1
4
+ detector.face_detector_cfg.confidence_threshold = .5
5
+ detector.cse_cfg.score_thres = 0.3
6
+ anonymizer.generators.face_G_cfg = None
deep_privacy/configs/anonymizers/market1501/pixelation16.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ from ..FB_cse_mask_face import anonymizer, detector, common
2
+
3
+ detector.score_threshold = .1
4
+ detector.face_detector_cfg.confidence_threshold = .5
5
+ detector.cse_cfg.score_thres = 0.3
6
+ anonymizer.generators.face_G_cfg = None
7
+ anonymizer.generators.person_G_cfg = "configs/generators/dummy/pixelation16.py"
8
+ anonymizer.generators.cse_person_G_cfg = "configs/generators/dummy/pixelation16.py"
deep_privacy/configs/anonymizers/market1501/pixelation8.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ from ..FB_cse_mask_face import anonymizer, detector, common
2
+
3
+ detector.score_threshold = .1
4
+ detector.face_detector_cfg.confidence_threshold = .5
5
+ detector.cse_cfg.score_thres = 0.3
6
+ anonymizer.generators.face_G_cfg = None
7
+ anonymizer.generators.person_G_cfg = "configs/generators/dummy/pixelation8.py"
8
+ anonymizer.generators.cse_person_G_cfg = "configs/generators/dummy/pixelation8.py"
deep_privacy/configs/datasets/coco_cse.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from pathlib import Path
3
+ from tops.config import LazyCall as L
4
+ import torch
5
+ import functools
6
+ from dp2.data.datasets.coco_cse import CocoCSE
7
+ from dp2.data.build import get_dataloader
8
+ from dp2.data.transforms.transforms import CreateEmbedding, Normalize, Resize, ToFloat, CreateCondition, RandomHorizontalFlip
9
+ from dp2.data.transforms.stylegan2_transform import StyleGANAugmentPipe
10
+ from dp2.metrics.torch_metrics import compute_metrics_iteratively
11
+ from .utils import final_eval_fn
12
+
13
+
14
+ dataset_base_dir = os.environ["BASE_DATASET_DIR"] if "BASE_DATASET_DIR" in os.environ else "data"
15
+ metrics_cache = os.environ["FBA_METRICS_CACHE"] if "FBA_METRICS_CACHE" in os.environ else ".cache"
16
+ data_dir = Path(dataset_base_dir, "coco_cse")
17
+ data = dict(
18
+ imsize=(288, 160),
19
+ im_channels=3,
20
+ semantic_nc=26,
21
+ cse_nc=16,
22
+ train=dict(
23
+ dataset=L(CocoCSE)(data_dir.joinpath("train"), transform=None, normalize_E=False),
24
+ loader=L(get_dataloader)(
25
+ shuffle=True, num_workers=6, drop_last=True, prefetch_factor=2,
26
+ batch_size="${train.batch_size}",
27
+ dataset="${..dataset}",
28
+ infinite=True,
29
+ gpu_transform=L(torch.nn.Sequential)(*[
30
+ L(ToFloat)(),
31
+ L(StyleGANAugmentPipe)(
32
+ rotate=0.5, rotate_max=.05,
33
+ xint=.5, xint_max=0.05,
34
+ scale=.5, scale_std=.05,
35
+ aniso=0.5, aniso_std=.05,
36
+ xfrac=.5, xfrac_std=.05,
37
+ brightness=.5, brightness_std=.05,
38
+ contrast=.5, contrast_std=.1,
39
+ hue=.5, hue_max=.05,
40
+ saturation=.5, saturation_std=.5,
41
+ imgfilter=.5, imgfilter_std=.1),
42
+ L(RandomHorizontalFlip)(p=0.5),
43
+ L(CreateEmbedding)(),
44
+ L(Resize)(size="${data.imsize}"),
45
+ L(Normalize)(mean=[.5, .5, .5], std=[.5, .5, .5], inplace=True),
46
+ L(CreateCondition)(),
47
+ ])
48
+ )
49
+ ),
50
+ val=dict(
51
+ dataset=L(CocoCSE)(data_dir.joinpath("val"), transform=None, normalize_E=False),
52
+ loader=L(get_dataloader)(
53
+ shuffle=False, num_workers=6, drop_last=True, prefetch_factor=2,
54
+ batch_size="${train.batch_size}",
55
+ dataset="${..dataset}",
56
+ infinite=False,
57
+ gpu_transform=L(torch.nn.Sequential)(*[
58
+ L(ToFloat)(),
59
+ L(CreateEmbedding)(),
60
+ L(Resize)(size="${data.imsize}"),
61
+ L(Normalize)(mean=[.5, .5, .5], std=[.5, .5, .5], inplace=True),
62
+ L(CreateCondition)(),
63
+ ])
64
+ )
65
+ ),
66
+ # Training evaluation might do optimizations to reduce compute overhead. E.g. compute with AMP.
67
+ train_evaluation_fn=functools.partial(compute_metrics_iteratively, cache_directory=Path(metrics_cache, "coco_cse_val"), include_two_fake=False),
68
+ evaluation_fn=functools.partial(final_eval_fn, cache_directory=Path(metrics_cache, "coco_cse_val_final"), include_two_fake=True)
69
+ )
deep_privacy/configs/datasets/fdf128.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ from functools import partial
3
+ from dp2.data.datasets.fdf import FDFDataset
4
+ from .fdf256 import data, dataset_base_dir, metrics_cache, final_eval_fn, train_eval_fn
5
+
6
+ data_dir = Path(dataset_base_dir, "fdf")
7
+ data.train.dataset.dirpath = data_dir.joinpath("train")
8
+ data.val.dataset.dirpath = data_dir.joinpath("val")
9
+ data.imsize = (128, 128)
10
+
11
+
12
+ data.train_evaluation_fn = partial(
13
+ train_eval_fn, cache_directory=Path(metrics_cache, "fdf128_val_train"))
14
+ data.evaluation_fn = partial(
15
+ final_eval_fn, cache_directory=Path(metrics_cache, "fdf128_val_final"))
16
+
17
+ data.train.dataset.update(
18
+ _target_ = FDFDataset,
19
+ imsize="${data.imsize}"
20
+ )
21
+ data.val.dataset.update(
22
+ _target_ = FDFDataset,
23
+ imsize="${data.imsize}"
24
+ )
deep_privacy/configs/datasets/fdf256.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from pathlib import Path
3
+ from tops.config import LazyCall as L
4
+ import torch
5
+ import functools
6
+ from dp2.data.datasets.fdf import FDF256Dataset
7
+ from dp2.data.build import get_dataloader
8
+ from dp2.data.transforms.transforms import Normalize, Resize, ToFloat, CreateCondition, RandomHorizontalFlip
9
+ from .utils import final_eval_fn, train_eval_fn
10
+
11
+
12
+ dataset_base_dir = os.environ["BASE_DATASET_DIR"] if "BASE_DATASET_DIR" in os.environ else "data"
13
+ metrics_cache = os.environ["FBA_METRICS_CACHE"] if "FBA_METRICS_CACHE" in os.environ else ".cache"
14
+ data_dir = Path(dataset_base_dir, "fdf256")
15
+ data = dict(
16
+ imsize=(256, 256),
17
+ im_channels=3,
18
+ semantic_nc=None,
19
+ cse_nc=None,
20
+ n_keypoints=None,
21
+ train=dict(
22
+ dataset=L(FDF256Dataset)(dirpath=data_dir.joinpath("train"), transform=None, load_keypoints=False),
23
+ loader=L(get_dataloader)(
24
+ shuffle=True, num_workers=3, drop_last=True, prefetch_factor=2,
25
+ batch_size="${train.batch_size}",
26
+ dataset="${..dataset}",
27
+ infinite=True,
28
+ gpu_transform=L(torch.nn.Sequential)(*[
29
+ L(ToFloat)(),
30
+ L(RandomHorizontalFlip)(p=0.5),
31
+ L(Resize)(size="${data.imsize}"),
32
+ L(Normalize)(mean=[.5, .5, .5], std=[.5, .5, .5], inplace=True),
33
+ L(CreateCondition)(),
34
+ ])
35
+ )
36
+ ),
37
+ val=dict(
38
+ dataset=L(FDF256Dataset)(dirpath=data_dir.joinpath("val"), transform=None, load_keypoints=False),
39
+ loader=L(get_dataloader)(
40
+ shuffle=False, num_workers=3, drop_last=False, prefetch_factor=2,
41
+ batch_size="${train.batch_size}",
42
+ dataset="${..dataset}",
43
+ infinite=False,
44
+ gpu_transform=L(torch.nn.Sequential)(*[
45
+ L(ToFloat)(),
46
+ L(Resize)(size="${data.imsize}"),
47
+ L(Normalize)(mean=[.5, .5, .5], std=[.5, .5, .5], inplace=True),
48
+ L(CreateCondition)(),
49
+ ])
50
+ )
51
+ ),
52
+ # Training evaluation might do optimizations to reduce compute overhead. E.g. compute with AMP.
53
+ train_evaluation_fn=functools.partial(train_eval_fn, cache_directory=Path(metrics_cache, "fdf_val_train")),
54
+ evaluation_fn=functools.partial(final_eval_fn, cache_directory=Path(metrics_cache, "fdf_val"))
55
+ )
deep_privacy/configs/datasets/fdh.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from pathlib import Path
3
+ from tops.config import LazyCall as L
4
+ import torch
5
+ import functools
6
+ from dp2.data.datasets.fdh import get_dataloader_fdh_wds
7
+ from dp2.data.utils import get_coco_flipmap
8
+ from dp2.data.transforms.transforms import (
9
+ Normalize,
10
+ ToFloat,
11
+ CreateCondition,
12
+ RandomHorizontalFlip,
13
+ CreateEmbedding,
14
+ )
15
+ from dp2.metrics.torch_metrics import compute_metrics_iteratively
16
+ from dp2.metrics.fid_clip import compute_fid_clip
17
+ from dp2.metrics.ppl import calculate_ppl
18
+ from .utils import train_eval_fn
19
+
20
+
21
def final_eval_fn(*args, **kwargs):
    """Run the final evaluation and merge all metric dictionaries.

    Calls ``compute_metrics_iteratively``, ``calculate_ppl`` (with
    ``upsample_size=(288, 160)``) and ``compute_fid_clip`` with the same
    arguments and merges their results into one dictionary.

    Raises:
        AssertionError: if two metric sources report the same key, since a
            merge would silently overwrite one of the values.
    """
    result = compute_metrics_iteratively(*args, **kwargs)
    result2 = calculate_ppl(*args, **kwargs, upsample_size=(288, 160))
    result3 = compute_fid_clip(*args, **kwargs)
    assert all(key not in result for key in result2)
    result.update(result2)
    # The FID-CLIP keys are checked against everything merged so far as well.
    assert all(key not in result for key in result3)
    result.update(result3)
    return result
29
+
30
+
31
def get_cache_directory(imsize, subset):
    """Return the metrics cache directory for ``subset`` at the given size.

    The directory name is ``subset`` followed by the first entry of
    ``imsize``, placed under the module-level ``metrics_cache``.
    """
    directory_name = f"{subset}{imsize[0]}"
    return Path(metrics_cache, directory_name)
33
+
34
+ dataset_base_dir = (
35
+ os.environ["BASE_DATASET_DIR"] if "BASE_DATASET_DIR" in os.environ else "data"
36
+ )
37
+ metrics_cache = (
38
+ os.environ["FBA_METRICS_CACHE"] if "FBA_METRICS_CACHE" in os.environ else ".cache"
39
+ )
40
+ data_dir = Path(dataset_base_dir, "fdh")
41
+ data = dict(
42
+ imsize=(288, 160),
43
+ im_channels=3,
44
+ cse_nc=16,
45
+ n_keypoints=17,
46
+ train=dict(
47
+ loader=L(get_dataloader_fdh_wds)(
48
+ path=data_dir.joinpath("train", "out-{000000..001423}.tar"),
49
+ batch_size="${train.batch_size}",
50
+ num_workers=6,
51
+ transform=L(torch.nn.Sequential)(
52
+ L(RandomHorizontalFlip)(p=0.5, flip_map=get_coco_flipmap()),
53
+ ),
54
+ gpu_transform=L(torch.nn.Sequential)(
55
+ L(ToFloat)(norm=False, keys=["img", "mask", "E_mask", "maskrcnn_mask"]),
56
+ L(CreateEmbedding)(embed_path=data_dir.joinpath("embed_map.torch")),
57
+ L(Normalize)(mean=[0.5*255, 0.5*255, 0.5*255], std=[0.5*255, 0.5*255, 0.5*255], inplace=True),
58
+ L(CreateCondition)(),
59
+ ),
60
+ infinite=True,
61
+ shuffle=True,
62
+ partial_batches=False,
63
+ load_embedding=True,
64
+ keypoints_split="train",
65
+ load_new_keypoints=False
66
+ )
67
+ ),
68
+ val=dict(
69
+ loader=L(get_dataloader_fdh_wds)(
70
+ path=data_dir.joinpath("val", "out-{000000..000023}.tar"),
71
+ batch_size="${train.batch_size}",
72
+ num_workers=6,
73
+ transform=None,
74
+ gpu_transform="${data.train.loader.gpu_transform}",
75
+ infinite=False,
76
+ shuffle=False,
77
+ partial_batches=True,
78
+ load_embedding=True,
79
+ keypoints_split="val",
80
+ load_new_keypoints="${data.train.loader.load_new_keypoints}"
81
+ )
82
+ ),
83
+ # Training evaluation might do optimizations to reduce compute overhead. E.g. compute with AMP.
84
+ train_evaluation_fn=L(functools.partial)(
85
+ train_eval_fn, cache_directory=L(get_cache_directory)(imsize="${data.imsize}", subset="fdh"),
86
+ data_len=30_000),
87
+ evaluation_fn=L(functools.partial)(
88
+ final_eval_fn, cache_directory=L(get_cache_directory)(imsize="${data.imsize}", subset="fdh_eval"),
89
+ data_len=30_000)
90
+ )
deep_privacy/configs/datasets/utils.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dp2.metrics.ppl import calculate_ppl
2
+ from dp2.metrics.torch_metrics import compute_metrics_iteratively
3
+ from dp2.metrics.fid_clip import compute_fid_clip
4
+
5
+
6
def final_eval_fn(*args, **kwargs):
    """Run the final evaluation and merge all metric dictionaries.

    Calls ``compute_metrics_iteratively``, ``calculate_ppl`` and
    ``compute_fid_clip`` with the same arguments and merges their results
    into one dictionary.

    Raises:
        AssertionError: if two metric sources report the same key, since a
            merge would silently overwrite one of the values.
    """
    result = compute_metrics_iteratively(*args, **kwargs)
    result2 = calculate_ppl(*args, **kwargs)
    result3 = compute_fid_clip(*args, **kwargs)
    assert all(key not in result for key in result2)
    result.update(result2)
    # Same guard as in train_eval_fn, applied to everything merged so far.
    assert all(key not in result for key in result3)
    result.update(result3)
    return result
14
+
15
+
16
def train_eval_fn(*args, **kwargs):
    """Compute the training-time metrics and return them as one dictionary.

    Runs ``compute_metrics_iteratively`` and ``compute_fid_clip`` with the
    same arguments; the two results must not share any key.
    """
    metrics = compute_metrics_iteratively(*args, **kwargs)
    clip_metrics = compute_fid_clip(*args, **kwargs)
    # Refuse to merge when a key would be overwritten.
    assert not any(name in metrics for name in clip_metrics)
    metrics.update(clip_metrics)
    return metrics
deep_privacy/configs/defaults.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pathlib
2
+ import os
3
+ import torch
4
+ from tops.config import LazyCall as L
5
+
6
+ if "PRETRAINED_CHECKPOINTS_PATH" in os.environ:
7
+ PRETRAINED_CHECKPOINTS_PATH = pathlib.Path(os.environ["PRETRAINED_CHECKPOINTS_PATH"])
8
+ else:
9
+ PRETRAINED_CHECKPOINTS_PATH = pathlib.Path("pretrained_checkpoints")
10
+ if "BASE_OUTPUT_DIR" in os.environ:
11
+ BASE_OUTPUT_DIR = pathlib.Path(os.environ["BASE_OUTPUT_DIR"])
12
+ else:
13
+ BASE_OUTPUT_DIR = pathlib.Path("outputs")
14
+
15
+
16
+
17
+ common = dict(
18
+ logger_backend=["wandb", "stdout", "json", "image_dumper"],
19
+ wandb_project="deep_privacy2",
20
+ output_dir=BASE_OUTPUT_DIR,
21
+ experiment_name=None, # Optional experiment name to show on wandb
22
+ )
23
+
24
+ train = dict(
25
+ batch_size=32,
26
+ seed=0,
27
+ ims_per_log=1024,
28
+ ims_per_val=int(200e3),
29
+ max_images_to_train=int(12e6),
30
+ amp=dict(
31
+ enabled=True,
32
+ scaler_D=L(torch.cuda.amp.GradScaler)(init_scale=2**16, growth_factor=4, growth_interval=100, enabled="${..enabled}"),
33
+ scaler_G=L(torch.cuda.amp.GradScaler)(init_scale=2**16, growth_factor=4, growth_interval=100, enabled="${..enabled}"),
34
+ ),
35
+ fp16_ddp_accumulate=False, # All gather gradients in fp16?
36
+ broadcast_buffers=False,
37
+ bias_act_plugin_enabled=True,
38
+ grid_sample_gradfix_enabled=True,
39
+ conv2d_gradfix_enabled=False,
40
+ channels_last=False,
41
+ compile_G=dict(
42
+ enabled=False,
43
+ mode="default" # default, reduce-overhead or max-autotune
44
+ ),
45
+ compile_D=dict(
46
+ enabled=False,
47
+ mode="default" # default, reduce-overhead or max-autotune
48
+ )
49
+ )
50
+
51
+ # exponential moving average
52
+ EMA = dict(rampup=0.05)
53
+
deep_privacy/configs/discriminators/sg2_discriminator.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from tops.config import LazyCall as L
2
+ from dp2.discriminator import SG2Discriminator
3
+ import torch
4
+ from dp2.loss import StyleGAN2Loss
5
+
6
+
7
+ discriminator = L(SG2Discriminator)(
8
+ imsize="${data.imsize}",
9
+ im_channels="${data.im_channels}",
10
+ min_fmap_resolution=4,
11
+ max_cnum_mul=8,
12
+ cnum=80,
13
+ input_condition=True,
14
+ conv_clamp=256,
15
+ input_cse=False,
16
+ cse_nc="${data.cse_nc}",
17
+ fix_residual=False,
18
+ )
19
+
20
+
21
+ loss_fnc = L(StyleGAN2Loss)(
22
+ lazy_regularization=True,
23
+ lazy_reg_interval=16,
24
+ r1_opts=dict(lambd=5, mask_out=False, mask_out_scale=False),
25
+ EP_lambd=0.001,
26
+ pl_reg_opts=dict(weight=0, batch_shrink=2,start_nimg=int(1e6), pl_decay=0.01)
27
+ )
28
+
29
def build_D_optim(type, lr, betas, lazy_regularization, lazy_reg_interval, **kwargs):
    """Build the discriminator optimizer, compensating for lazy regularization.

    Args:
        type: optimizer factory, called as ``type(lr=..., betas=..., **kwargs)``.
        lr: base learning rate.
        betas: base beta coefficients.
        lazy_regularization: when true, rescale ``lr`` and ``betas``.
        lazy_reg_interval: interval used for the rescaling factor.
        **kwargs: forwarded to ``type`` unchanged.

    Returns:
        Whatever ``type`` returns.
    """
    if not lazy_regularization:
        return type(lr=lr, betas=betas, **kwargs)

    # From Analyzing and improving the image quality of stylegan, CVPR 2020
    c = lazy_reg_interval / (lazy_reg_interval + 1)
    betas = [beta ** c for beta in betas]
    lr = lr * c
    print(f"Lazy regularization on. Setting lr to: {lr}, betas to: {betas}")
    return type(lr=lr, betas=betas, **kwargs)
37
+
38
+
39
+ D_optim = L(build_D_optim)(
40
+ type=torch.optim.Adam, lr=0.001, betas=(0.0, 0.99),
41
+ lazy_regularization="${loss_fnc.lazy_regularization}",
42
+ lazy_reg_interval="${loss_fnc.lazy_reg_interval}")
43
+ G_optim = L(torch.optim.Adam)(lr=0.001, betas=(0.0, 0.99))
deep_privacy/configs/fdf/deep_privacy1.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ from tops.config import LazyCall as L
2
+ from dp2.generator.deep_privacy1 import MSGGenerator
3
+ from ..datasets.fdf128 import data
4
+ from ..defaults import common, train
5
+
6
+ generator = L(MSGGenerator)()
7
+
8
+ common.model_url = "https://folk.ntnu.no/haakohu/checkpoints/fdf128_model512.ckpt"
9
+ common.model_md5sum = "6cc8b285bdc1fcdfc64f5db7c521d0a6"
deep_privacy/configs/fdf/stylegan.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ..generators.stylegan_unet import generator
2
+ from ..datasets.fdf256 import data
3
+ from ..discriminators.sg2_discriminator import discriminator, G_optim, D_optim, loss_fnc
4
+ from ..defaults import train, common, EMA
5
+
6
+ train.max_images_to_train = int(35e6)
7
+ G_optim.lr = 0.002
8
+ D_optim.lr = 0.002
9
+ generator.input_cse = False
10
+ loss_fnc.r1_opts.lambd = 1
11
+ train.ims_per_val = int(2e6)
12
+
13
+ common.model_url = "https://api.loke.aws.unit.no/dlr-gui-backend-resources-content/v2/contents/links/89660f04-5c11-4dbf-adac-cbe2f11b0aeea25cbf78-7558-475a-b3c7-03f5c10b7934646b0720-ca0a-4d53-aded-daddbfa45c9e"
14
+ common.model_md5sum = "e8e32190528af2ed75f0cb792b7f2b07"
deep_privacy/configs/fdf/stylegan_fdf128.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ..discriminators.sg2_discriminator import discriminator, G_optim, D_optim, loss_fnc
2
+ from ..datasets.fdf128 import data
3
+ from ..generators.stylegan_unet import generator
4
+ from ..defaults import train, common, EMA
5
+ from tops.config import LazyCall as L
6
+
7
+ G_optim.lr = 0.002
8
+ D_optim.lr = 0.002
9
+ generator.update(cnum=128, max_cnum_mul=4, input_cse=False)
10
+ loss_fnc.r1_opts.lambd = 0.1
11
+
12
+ train.update(ims_per_val=int(2e6), batch_size=64, max_images_to_train=int(35e6))
13
+
14
+ common.update(
15
+ model_url="https://api.loke.aws.unit.no/dlr-gui-backend-resources-content/v2/contents/links/66d803c0-55ce-44c0-9d53-815c2c0e6ba4eb458409-9e91-45d1-bce0-95c8a47a57218b102fdf-bea3-44dc-aac4-0fb1d370ef1c",
16
+ model_md5sum="bccd4403e7c9bca682566ff3319e8176"
17
+ )
deep_privacy/configs/fdh/styleganL.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from tops.config import LazyCall as L
2
+ from ..generators.stylegan_unet import generator
3
+ from ..datasets.fdh import data
4
+ from ..discriminators.sg2_discriminator import discriminator, G_optim, D_optim, loss_fnc
5
+ from ..defaults import train, common, EMA
6
+
7
+ train.max_images_to_train = int(50e6)
8
+ train.batch_size = 64
9
+ G_optim.lr = 0.002
10
+ D_optim.lr = 0.002
11
+ data.train.loader.num_workers = 4
12
+ train.ims_per_val = int(1e6)
13
+ loss_fnc.r1_opts.lambd = .1
14
+
15
+ common.model_url = "https://api.loke.aws.unit.no/dlr-gui-backend-resources-content/v2/contents/links/21841da7-2546-4ce3-8460-909b3a63c58b13aac1a1-c778-4c8d-9b69-3e5ed2cde9de1524e76e-7aa6-4dd8-b643-52abc9f0792c"
16
+ common.model_md5sum = "3411478b5ec600a4219cccf4499732bd"
deep_privacy/configs/fdh/styleganL_nocse.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from tops.config import LazyCall as L
2
+ from ..generators.stylegan_unet import generator
3
+ from ..datasets.fdh import data
4
+ from ..discriminators.sg2_discriminator import discriminator, G_optim, D_optim, loss_fnc
5
+ from ..defaults import train, common, EMA
6
+
7
# Training schedule override.
train.max_images_to_train = int(50e6)

# Learning rates for the generator and discriminator optimizers.
G_optim.lr = 0.002
D_optim.lr = 0.002

# This variant runs the generator without CSE input.
generator.input_cse = False
generator.fix_errors = False

# NOTE(review): the FDH webdataset loader's parameter is spelled
# `load_embedding` (singular) -- confirm that `load_embeddings` is the key the
# dataset config actually reads.
data.load_embeddings = False

# Location of the pretrained checkpoint and the MD5 sum stored next to it.
common.model_url = "https://folk.ntnu.no/haakohu/checkpoints/deep_privacy2/fdh_styleganL_nocse.ckpt"
common.model_md5sum = "fda0d809741bc67487abada793975c37"
deep_privacy/configs/generators/stylegan_unet.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dp2.generator.stylegan_unet import StyleGANUnet
2
+ from tops.config import LazyCall as L
3
+
4
# Default StyleGANUnet generator config, wrapped in LazyCall (presumably so
# that construction is deferred until the config is instantiated -- confirm in
# tops.config). Other configs in this package import `generator` and override
# individual fields such as cnum, max_cnum_mul, input_cse and fix_errors.
generator = L(StyleGANUnet)(
    # "${data.*}" values are string references to the dataset config
    # (presumably resolved by the config system when the config is composed).
    imsize="${data.imsize}",
    im_channels="${data.im_channels}",
    min_fmap_resolution=8,
    cnum=64,
    max_cnum_mul=8,
    n_middle_blocks=0,
    z_channels=512,
    mask_output=True,
    conv_clamp=256,
    input_cse=True,
    scale_grad=True,
    cse_nc="${data.cse_nc}",
    w_dim=512,
    n_keypoints="${data.n_keypoints}",
    input_keypoints=False,
    input_keypoint_indices=[],
    fix_errors=True
)
deep_privacy/dp2/__init__.py ADDED
File without changes
deep_privacy/dp2/anonymizer/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .anonymizer import Anonymizer
deep_privacy/dp2/anonymizer/anonymizer.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ from typing import Union, Optional
3
+ import numpy as np
4
+ import torch
5
+ import tops
6
+ import torchvision.transforms.functional as F
7
+ from motpy import Detection, MultiObjectTracker
8
+ from dp2.utils import load_config
9
+ from dp2.infer import build_trained_generator
10
+ from dp2.detection.structures import CSEPersonDetection, FaceDetection, PersonDetection, VehicleDetection
11
+
12
+
13
def load_generator_from_cfg_path(cfg_path: Union[str, Path]):
    """Build the trained generator described by the config file at ``cfg_path``.

    The config is read with ``load_config`` and handed to
    ``build_trained_generator``; the source path is logged before the
    generator is returned.
    """
    generator = build_trained_generator(load_config(cfg_path))
    tops.logger.log(f"Loaded generator from: {cfg_path}")
    return generator
18
+
19
+
20
class Anonymizer:
    """Detect objects in an image and overwrite each detection with generator output.

    One generator can be registered per detection type (CSE person, person,
    face, vehicle). Detection types without a generator are left untouched.
    """

    def __init__(
            self,
            detector,
            load_cache: bool = False,
            person_G_cfg: Optional[Union[str, Path]] = None,
            cse_person_G_cfg: Optional[Union[str, Path]] = None,
            face_G_cfg: Optional[Union[str, Path]] = None,
            car_G_cfg: Optional[Union[str, Path]] = None,
    ) -> None:
        """Store the detector and load a generator for every config path given.

        Args:
            detector: Callable detector; must also provide ``forward_and_cache``.
            load_cache: Whether detections should go through the detector cache.
            person_G_cfg: Config path of the generator for ``PersonDetection``.
            cse_person_G_cfg: Config path of the generator for ``CSEPersonDetection``.
            face_G_cfg: Config path of the generator for ``FaceDetection``.
            car_G_cfg: Config path of the generator for ``VehicleDetection``.
        """
        self.detector = detector
        self.load_cache = load_cache
        cfg_by_detection_type = {
            CSEPersonDetection: cse_person_G_cfg,
            PersonDetection: person_G_cfg,
            FaceDetection: face_G_cfg,
            VehicleDetection: car_G_cfg,
        }
        self.generators = {
            det_type: None if cfg is None else load_generator_from_cfg_path(cfg)
            for det_type, cfg in cfg_by_detection_type.items()
        }

    def initialize_tracker(self, fps: float):
        """Create a new multi-object tracker (time step ``1/fps``) and clear track seeds."""
        self.tracker = MultiObjectTracker(dt=1/fps)
        self.track_to_z_idx = dict()

    def reset_tracker(self):
        """Forget the latent seed assigned to each track id (the tracker itself is kept)."""
        self.track_to_z_idx = dict()

    def forward_G(self,
                  G,
                  batch,
                  multi_modal_truncation: bool,
                  amp: bool,
                  z_idx: Optional[np.ndarray],
                  truncation_value: float,
                  idx: int,
                  all_styles=None):
        """Run generator ``G`` on one crop and return the result scaled to [0, 255].

        Args:
            G: Generator to sample from.
            batch: Crop dictionary; ``batch["img"]`` is normalized in place
                (mean and std of 127.5 per channel) and ``batch["condition"]``
                is set to ``mask * img``.
            multi_modal_truncation: Use ``G.multi_modal_truncate`` instead of ``G.sample``.
            amp: Enables CUDA autocast.
            z_idx: Optional per-detection seeds; ``z_idx[idx]`` seeds the latent.
            truncation_value: Forwarded to the generator's sampling method.
            idx: Index of the current detection.
            all_styles: Optional per-detection styles; when given, ``G`` is
                called directly with ``s=iter(all_styles[idx])``.
        """
        batch["img"] = F.normalize(batch["img"].float(), [0.5*255, 0.5*255, 0.5*255], [0.5*255, 0.5*255, 0.5*255])
        batch["condition"] = batch["mask"].float() * batch["img"]

        with torch.cuda.amp.autocast(amp):
            z = None
            if z_idx is not None:
                # A fixed seed per track yields the same latent for every frame of that track.
                state = np.random.RandomState(seed=z_idx[idx])
                z = state.normal(size=(1, G.z_channels)).astype(np.float32)
                z = tops.to_cuda(torch.from_numpy(z))

            if all_styles is not None:
                anonymized_im = G(**batch, s=iter(all_styles[idx]))["img"]
            elif multi_modal_truncation:
                w_indices = None
                if z_idx is not None:
                    w_indices = [z_idx[idx] % len(G.style_net.w_centers)]
                anonymized_im = G.multi_modal_truncate(
                    **batch, truncation_value=truncation_value,
                    w_indices=w_indices,
                    z=z
                )["img"]
            else:
                anonymized_im = G.sample(**batch, truncation_value=truncation_value, z=z)["img"]
        # Map generator output from [-1, 1] to [0, 255].
        anonymized_im = (anonymized_im+1).div(2).clamp(0, 1).mul(255)
        return anonymized_im

    @torch.no_grad()
    def anonymize_detections(self,
                             im, detection,
                             update_identity=None,
                             **synthesis_kwargs
                             ):
        """Overwrite every selected detection in ``im`` with generated pixels.

        Args:
            im: Image tensor with three dimensions; modified in place.
            detection: Detection container of a single type.
            update_identity: Optional per-detection flags; entries that are
                falsy are skipped. ``None`` processes all detections.
            **synthesis_kwargs: Forwarded to ``forward_G``.

        Returns:
            ``im`` (unchanged if no generator is registered for this detection type).
        """
        G = self.generators[type(detection)]
        if G is None:
            return im
        _, H, W = im.shape
        if update_identity is None:
            update_identity = [True] * len(detection)
        for idx in range(len(detection)):
            if not update_identity[idx]:
                continue
            batch = detection.get_crop(idx, im)
            x0, y0, x1, y1 = batch.pop("boxes")[0]
            batch = {k: tops.to_cuda(v) for k, v in batch.items()}
            anonymized_im = self.forward_G(G, batch, **synthesis_kwargs, idx=idx)

            gim = F.resize(anonymized_im[0], (y1-y0, x1-x0), interpolation=F.InterpolationMode.BICUBIC, antialias=True)
            mask = F.resize(batch["mask"][0], (y1-y0, x1-x0), interpolation=F.InterpolationMode.NEAREST).squeeze(0)
            # The box may extend past the image; trim the part that falls outside.
            pad = [max(-x0, 0), max(-y0, 0), max(x1-W, 0), max(y1-H, 0)]

            def remove_pad(x): return x[..., pad[1]:x.shape[-2]-pad[3], pad[0]:x.shape[-1]-pad[2]]

            gim = remove_pad(gim)
            mask = remove_pad(mask) > 0.5
            x0, y0 = max(x0, 0), max(y0, 0)
            x1, y1 = min(x1, W), min(y1, H)
            # Pixels where the mask is off are the ones replaced by generated content.
            mask = mask.logical_not()[None].repeat(3, 1, 1)

            im[:, y0:y1, x0:x1][mask] = gim[mask].round().clamp(0, 255).byte()
        return im

    def visualize_detection(self, im: torch.Tensor, cache_id: str = None) -> torch.Tensor:
        """Run the detector (through its cache interface) and draw every detection on ``im``."""
        all_detections = self.detector.forward_and_cache(im, cache_id, load_cache=self.load_cache)
        im = im.cpu()
        for det in all_detections:
            im = det.visualize(im)
        return im

    def _assign_track_latents(self, all_detections) -> np.ndarray:
        """Step the tracker with all boxes and return one latent seed per matched detection."""
        for detection in all_detections:
            detection.pre_process()
        box_groups = [detection.boxes for detection in all_detections]
        # np.concatenate raises on an empty sequence, so handle "no detections" explicitly.
        boxes = np.concatenate(box_groups) if box_groups else []
        self.tracker.step([Detection(box) for box in boxes])
        seeds = []
        for track_id in self.tracker.detections_matched_ids:
            if track_id not in self.track_to_z_idx:
                # dtype=int64 keeps the 2**32-1 bound valid where the default
                # NumPy integer is 32-bit.
                self.track_to_z_idx[track_id] = int(np.random.randint(0, 2**32-1, dtype=np.int64))
            seeds.append(self.track_to_z_idx[track_id])
        return np.array(seeds, dtype=np.int64)

    @torch.no_grad()
    def forward(self, im: torch.Tensor, cache_id: str = None, track=True, detections=None, **synthesis_kwargs) -> torch.Tensor:
        """Anonymize ``im`` and return the result on the CPU.

        Args:
            im: ``uint8`` image tensor.
            cache_id: Identifier passed to the detector cache when ``load_cache`` is set.
            track: Reuse a per-track latent seed; only has an effect after
                ``initialize_tracker`` has been called.
            detections: Precomputed detections; the detector is skipped when given.
            **synthesis_kwargs: Forwarded to ``anonymize_detections``.
        """
        assert im.dtype == torch.uint8
        im = tops.to_cuda(im)
        all_detections = detections
        if detections is None:
            if self.load_cache:
                # Same call form as visualize_detection.
                all_detections = self.detector.forward_and_cache(im, cache_id, load_cache=self.load_cache)
            else:
                all_detections = self.detector(im)

        use_tracking = hasattr(self, "tracker") and track
        z_idx = self._assign_track_latents(all_detections) if use_tracking else None
        idx_offset = 0

        for detection in all_detections:
            zs = None
            if use_tracking:
                # Seeds are stored in the same order the boxes were concatenated.
                zs = z_idx[idx_offset:idx_offset+len(detection)]
                idx_offset += len(detection)
            im = self.anonymize_detections(im, detection, z_idx=zs, **synthesis_kwargs)

        return im.cpu()

    def __call__(self, *args, **kwargs):
        """Alias for ``forward``."""
        return self.forward(*args, **kwargs)
deep_privacy/dp2/anonymizer/histogram_match_anonymizers.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch
3
+ import tops
4
+ import numpy as np
5
+ from kornia.color import rgb_to_hsv
6
+ from dp2 import utils
7
+ from kornia.enhance import histogram
8
+ from .anonymizer import Anonymizer
9
+ import torchvision.transforms.functional as F
10
+ from skimage.exposure import match_histograms
11
+ from kornia.filters import gaussian_blur2d
12
+
13
+
14
class LatentHistogramMatchAnonymizer(Anonymizer):
    """Anonymizer that optimizes the latent ``w`` so the output's histograms match the input's."""

    def forward_G(
            self,
            G,
            batch,
            multi_modal_truncation: bool,
            amp: bool,
            z_idx,
            truncation_value: float,
            idx: int,
            n_sampling_steps: int = 1,
            all_styles=None,
    ):
        """Generate a crop, taking up to ``n_sampling_steps`` Adam steps on ``w``.

        The loss is the summed ``utils.torch_wasserstein_loss`` between the
        histograms of channels 1 and 2 of ``rgb_to_hsv`` for the real and the
        generated image.

        Args:
            G: Generator; its ``style_net`` provides the initial ``w``.
            batch: Crop dictionary; ``img`` is normalized in place and
                ``condition`` is set to ``mask * img``.
            multi_modal_truncation: Selects how the initial ``w`` is sampled.
            amp: Enables CUDA autocast for the generator call.
            z_idx: Must be ``None`` (not supported here).
            truncation_value: Forwarded to the style network.
            idx: Unused; kept for interface compatibility with ``Anonymizer``.
            n_sampling_steps: Maximum number of generator evaluations (>= 1).
            all_styles: Must be ``None`` (not supported here).

        Returns:
            The last generated image, scaled to [0, 255].

        Raises:
            ValueError: If ``n_sampling_steps`` is smaller than 1.
        """
        if n_sampling_steps < 1:
            raise ValueError(f"n_sampling_steps must be >= 1, got {n_sampling_steps}.")
        batch["img"] = F.normalize(batch["img"].float(), [0.5*255, 0.5*255, 0.5*255], [0.5*255, 0.5*255, 0.5*255])
        batch["condition"] = batch["mask"].float() * batch["img"]

        assert z_idx is None and all_styles is None, \
            "z_idx and all_styles are not supported by LatentHistogramMatchAnonymizer."
        real_hls = rgb_to_hsv(utils.denormalize_img(batch["img"]))
        real_hls[:, 0] /= 2 * torch.pi
        indices = [1, 2]
        hist_kwargs = dict(
            bins=torch.linspace(0, 1, 256, dtype=torch.float32, device=tops.get_device()),
            bandwidth=torch.tensor(1., device=tops.get_device()))
        real_hist = [histogram(real_hls[:, i].flatten(start_dim=1), **hist_kwargs) for i in indices]

        # Initial latent: sampled once, then optimized as a free parameter.
        if multi_modal_truncation:
            w = G.style_net.multi_modal_truncate(
                truncation_value=truncation_value, **batch, w_indices=None).detach()
        else:
            w = G.style_net.get_truncated(truncation_value, **batch).detach()
        w.requires_grad = True
        optim = torch.optim.Adam([w])

        for _ in range(n_sampling_steps):
            # Callers run under no_grad, so gradients are re-enabled for the optimization.
            with torch.set_grad_enabled(True):
                with torch.cuda.amp.autocast(amp):
                    anonymized_im = G(**batch, truncation_value=None, w=w)["img"]
                fake_hls = rgb_to_hsv(anonymized_im*0.5 + 0.5)
                fake_hls[:, 0] /= 2 * torch.pi
                fake_hist = [histogram(fake_hls[:, i].flatten(start_dim=1), **hist_kwargs) for i in indices]
                dist = sum([utils.torch_wasserstein_loss(r, f) for r, f in zip(real_hist, fake_hist)])
                dist.backward()
            if w.grad.sum() == 0:
                # No gradient signal reaches w; further steps would not change it.
                break
            optim.step()
            optim.zero_grad()
            if dist < 0.02:
                break
        anonymized_im = (anonymized_im+1).div(2).clamp(0, 1).mul(255)
        return anonymized_im
67
+
68
+
69
class HistogramMatchAnonymizer(Anonymizer):
    """Anonymizer that histogram-matches the generated crop to the original crop."""

    def forward_G(self, G, batch, *args, **kwargs):
        """Generate a crop with the base class, then blend in a histogram-matched version.

        The signature follows ``Anonymizer.forward_G`` (generator first, then
        the batch), which is how ``anonymize_detections`` calls it. Image
        normalization and ``batch["condition"]`` are handled by the base
        implementation, so they are not repeated here.

        Args:
            G: Generator, forwarded to ``Anonymizer.forward_G``.
            batch: Crop dictionary; must contain ``img`` and ``maskrcnn_mask``.
            *args, **kwargs: Forwarded to ``Anonymizer.forward_G``.

        Returns:
            Blend of the histogram-matched image and the raw generator output,
            weighted by a blurred mask derived from ``maskrcnn_mask``.
        """
        # Keep a reference to the un-normalized crop; the base call replaces batch["img"].
        rimg = batch["img"]

        anonymized_im = super().forward_G(G, batch, *args, **kwargs)

        equalized_gim = match_histograms(tops.im2numpy(anonymized_im.round().clamp(0, 255).byte()), tops.im2numpy(rimg))
        if equalized_gim.dtype != np.uint8:
            equalized_gim = equalized_gim.astype(np.float32)
            assert equalized_gim.dtype == np.float32, equalized_gim.dtype
            equalized_gim = tops.im2torch(equalized_gim, to_float=False)[0]
        else:
            equalized_gim = tops.im2torch(equalized_gim, to_float=False).float()[0]
        equalized_gim = equalized_gim.to(device=rimg.device)
        assert equalized_gim.dtype == torch.float32
        gaussian_mask = 1 - (batch["maskrcnn_mask"][0].repeat(3, 1, 1) > 0.5).float()

        # Blur and rescale so the blend weight peaks at 1 and falls off smoothly.
        gaussian_mask = gaussian_blur2d(gaussian_mask[None], kernel_size=[19, 19], sigma=[10, 10])[0]
        gaussian_mask = gaussian_mask / gaussian_mask.max()
        anonymized_im = gaussian_mask * equalized_gim + (1-gaussian_mask) * anonymized_im
        return anonymized_im
deep_privacy/dp2/data/__init__.py ADDED
File without changes
deep_privacy/dp2/data/build.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import tops
3
+ from .utils import collate_fn
4
+
5
+
6
def get_dataloader(
        dataset, gpu_transform: torch.nn.Module,
        num_workers,
        batch_size,
        infinite: bool,
        drop_last: bool,
        prefetch_factor: int,
        shuffle,
        channels_last=False
):
    """Wrap ``dataset`` in a DataLoader and a ``tops.DataPrefetcher``.

    Sampler selection:
      * ``infinite``: ``tops.InfiniteSampler`` split over ranks.
      * otherwise, with more than one process: ``DistributedSampler``.
      * otherwise: the DataLoader's own ``shuffle`` flag is used.

    Args:
        dataset: Dataset to load from.
        gpu_transform: Module handed to the prefetcher.
        num_workers: Number of DataLoader worker processes.
        batch_size: Batch size.
        infinite: Use the infinite sampler.
        drop_last: Drop the last incomplete batch (ignored when ``infinite``).
        prefetch_factor: Forwarded to the DataLoader only when
            ``num_workers > 0``; the DataLoader rejects an explicit value
            for single-process loading.
        shuffle: Whether samples are shuffled.
        channels_last: Forwarded to the prefetcher.
    """
    sampler = None
    dl_kwargs = dict(
        pin_memory=True,
    )
    if infinite:
        sampler = tops.InfiniteSampler(
            dataset, rank=tops.rank(),
            num_replicas=tops.world_size(),
            shuffle=shuffle
        )
    elif tops.world_size() > 1:
        sampler = torch.utils.data.DistributedSampler(
            dataset, shuffle=shuffle, num_replicas=tops.world_size(), rank=tops.rank())
        dl_kwargs["drop_last"] = drop_last
    else:
        dl_kwargs["shuffle"] = shuffle
        dl_kwargs["drop_last"] = drop_last
    if num_workers > 0:
        # prefetch_factor is only meaningful with worker processes.
        dl_kwargs["prefetch_factor"] = prefetch_factor
    dataloader = torch.utils.data.DataLoader(
        dataset, sampler=sampler, collate_fn=collate_fn,
        batch_size=batch_size,
        num_workers=num_workers,
        **dl_kwargs
    )
    dataloader = tops.DataPrefetcher(dataloader, gpu_transform, channels_last=channels_last)
    return dataloader
deep_privacy/dp2/data/datasets/__init__.py ADDED
File without changes
deep_privacy/dp2/data/datasets/coco_cse.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pickle
2
+ import torchvision
3
+ import torch
4
+ import pathlib
5
+ import numpy as np
6
+ from typing import Callable, Optional, Union
7
+ from torch.hub import get_dir as get_hub_dir
8
+
9
+
10
def cache_embed_stats(embed_map: torch.Tensor):
    """Save ``embed_map`` with its mean and reciprocal std (over dim 0) to the torch hub dir.

    The result is written to ``<hub dir>/embed_map_stats.torch`` as a dict
    with the keys ``mean``, ``rstd`` and ``embed_map``.
    """
    mean = embed_map.mean(dim=0, keepdim=True)
    centered_sq = (embed_map - mean).square()
    # 1e-8 guards the reciprocal square root against zero variance.
    rstd = (centered_sq.mean(dim=0, keepdim=True)+1e-8).rsqrt()

    target = pathlib.Path(get_hub_dir(), "embed_map_stats.torch")
    target.parent.mkdir(exist_ok=True, parents=True)
    torch.save(dict(mean=mean, rstd=rstd, embed_map=embed_map), target)
18
+
19
+
20
class CocoCSE(torch.utils.data.Dataset):
    """Dataset of PNG images with per-image ``.npy`` files holding vertices, mask and border.

    Expected layout under ``dirpath``: ``images/*.png``,
    ``embedding/<stem>.npy`` and ``embed_map.npy``.
    """

    def __init__(self,
                 dirpath: Union[str, pathlib.Path],
                 transform: Optional[Callable],
                 normalize_E: bool,):
        """Index the image files and load the standardized embedding map.

        Args:
            dirpath: Dataset root directory.
            transform: Optional callable applied to every sample dict.
            normalize_E: Accepted but not used by this class.
        """
        dirpath = pathlib.Path(dirpath)
        self.dirpath = dirpath

        self.transform = transform
        assert self.dirpath.is_dir(),\
            f"Did not find dataset at: {dirpath}"
        self.image_paths, self.embedding_paths = self._load_impaths()

        raw_map = torch.from_numpy(np.load(self.dirpath / "embed_map.npy"))
        mean = raw_map.mean(dim=0, keepdim=True)
        rstd = ((raw_map - mean).square().mean(dim=0, keepdim=True)+1e-8).rsqrt()
        # Standardize over dim 0; the standardized map is what gets cached.
        self.embed_map = (raw_map - mean) * rstd
        cache_embed_stats(self.embed_map)

    def _load_impaths(self):
        """Return the sorted image paths and the matching embedding paths."""
        image_paths = sorted((self.dirpath / "images").glob("*.png"))
        embedding_dir = self.dirpath / "embedding"
        embedding_paths = [embedding_dir / (p.stem + ".npy") for p in image_paths]
        return image_paths, embedding_paths

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load sample ``idx`` as a dict, passed through ``transform`` when one is set."""
        im = torchvision.io.read_image(str(self.image_paths[idx]))
        # The embedding file is split into three equal parts along its last axis.
        vertices, mask, border = np.split(np.load(self.embedding_paths[idx]), 3, axis=-1)
        vertices = torch.from_numpy(vertices.squeeze()).long()
        mask = torch.from_numpy(mask.squeeze()).float()
        border = torch.from_numpy(border.squeeze()).float()
        sample = {
            "img": im,
            "vertices": vertices[None],
            "mask": mask[None],
            "embed_map": self.embed_map,
            "border": border[None],
            "E_mask": (1 - mask - border)[None]
        }
        return sample if self.transform is None else self.transform(sample)
deep_privacy/dp2/data/datasets/fdf.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pathlib
2
+ from typing import Tuple
3
+ import numpy as np
4
+ import torch
5
+ import pathlib
6
+ try:
7
+ import pyspng
8
+ PYSPNG_IMPORTED = True
9
+ except ImportError:
10
+ PYSPNG_IMPORTED = False
11
+ print("Could not load pyspng. Defaulting to pillow image backend.")
12
+ from PIL import Image
13
+ from tops import logger
14
+
15
+
16
class FDFDataset:
    """FDF dataset stored as PNG files, with landmarks and bounding boxes on disk.

    Expected layout under ``dirpath``: ``images/<imsize>/*.png`` (integer file
    stems), ``landmarks.npy`` and ``bounding_box/<imsize>.torch``.
    """

    def __init__(self,
                 dirpath,
                 imsize: Tuple[int],
                 load_keypoints: bool,
                 transform):
        """Index the images and load landmarks and bounding boxes.

        Args:
            dirpath: Dataset root directory.
            imsize: Image size; only the first entry is used.
            load_keypoints: Include ``keypoints`` in every sample.
            transform: Optional callable applied to every sample dict.
        """
        dirpath = pathlib.Path(dirpath)
        self.dirpath = dirpath
        self.transform = transform
        self.imsize = imsize[0]
        self.load_keypoints = load_keypoints
        assert self.dirpath.is_dir(),\
            f"Did not find dataset at: {dirpath}"
        image_dir = self.dirpath / "images" / str(self.imsize)
        self.image_paths = list(image_dir.glob("*.png"))
        assert len(self.image_paths) > 0,\
            f"Did not find images in: {image_dir}"
        # File stems are integers; sort numerically so indices line up with the arrays below.
        self.image_paths.sort(key=lambda p: int(p.stem))
        landmarks = np.load(self.dirpath / "landmarks.npy")
        self.landmarks = landmarks.reshape(-1, 7, 2).astype(np.float32)

        self.bounding_boxes = torch.load(self.dirpath / "bounding_box" / f"{self.imsize}.torch")
        assert len(self.image_paths) == len(self.bounding_boxes)
        assert len(self.image_paths) == len(self.landmarks)
        logger.log(
            f"Dataset loaded from: {dirpath}. Number of samples:{len(self)}, imsize={imsize}")

    def get_mask(self, idx):
        """Return a ``(1, imsize, imsize)`` bool mask that is False inside bounding box ``idx``."""
        x0, y0, x1, y1 = self.bounding_boxes[idx]
        mask = torch.ones((1, self.imsize, self.imsize), dtype=torch.bool)
        mask[:, y0:y1, x0:x1] = 0
        return mask

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, index):
        """Load sample ``index`` as a dict, passed through ``transform`` when one is set."""
        impath = self.image_paths[index]
        if PYSPNG_IMPORTED:
            with open(impath, "rb") as fp:
                im = pyspng.load(fp.read())
        else:
            with Image.open(impath) as fp:
                im = np.array(fp)
        # Move the last axis (channels) to the front.
        im = torch.from_numpy(np.rollaxis(im, -1, 0))
        sample = {
            "img": im,
            "mask": self.get_mask(index),
        }
        landmark = self.landmarks[index]
        if self.load_keypoints:
            sample["keypoints"] = landmark
        return sample if self.transform is None else self.transform(sample)
73
+
74
+
75
class FDF256Dataset:
    """FDF dataset with 256x256 masks, stored as PNG files plus ``.npy`` annotations.

    Expected layout under ``dirpath``: ``images/*.png`` (integer file stems),
    ``landmarks.npy`` and ``bounding_box.npy``.
    """

    def __init__(self,
                 dirpath,
                 load_keypoints: bool,
                 transform):
        """Index the images and load landmarks and bounding boxes.

        Args:
            dirpath: Dataset root directory.
            load_keypoints: Include ``keypoints`` in every sample.
            transform: Optional callable applied to every sample dict.
        """
        dirpath = pathlib.Path(dirpath)
        self.dirpath = dirpath
        self.transform = transform
        self.load_keypoints = load_keypoints
        assert self.dirpath.is_dir(),\
            f"Did not find dataset at: {dirpath}"
        image_dir = self.dirpath / "images"
        self.image_paths = list(image_dir.glob("*.png"))
        assert len(self.image_paths) > 0,\
            f"Did not find images in: {image_dir}"
        # File stems are integers; sort numerically so indices line up with the arrays below.
        self.image_paths.sort(key=lambda p: int(p.stem))
        landmarks = np.load(self.dirpath / "landmarks.npy")
        self.landmarks = landmarks.reshape(-1, 7, 2).astype(np.float32)
        self.bounding_boxes = torch.from_numpy(np.load(self.dirpath / "bounding_box.npy"))
        assert len(self.image_paths) == len(self.bounding_boxes)
        assert len(self.image_paths) == len(self.landmarks)
        logger.log(
            f"Dataset loaded from: {dirpath}. Number of samples:{len(self)}")

    def get_mask(self, idx):
        """Return a ``(1, 256, 256)`` bool mask that is False inside bounding box ``idx``."""
        x0, y0, x1, y1 = self.bounding_boxes[idx]
        mask = torch.ones((1, 256, 256), dtype=torch.bool)
        mask[:, y0:y1, x0:x1] = 0
        return mask

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, index):
        """Load sample ``index`` as a dict, passed through ``transform`` when one is set."""
        impath = self.image_paths[index]
        if PYSPNG_IMPORTED:
            with open(impath, "rb") as fp:
                im = pyspng.load(fp.read())
        else:
            with Image.open(impath) as fp:
                im = np.array(fp)
        # Move the last axis (channels) to the front.
        im = torch.from_numpy(np.rollaxis(im, -1, 0))
        sample = {
            "img": im,
            "mask": self.get_mask(index),
        }
        landmark = self.landmarks[index]
        if self.load_keypoints:
            sample["keypoints"] = landmark
        return sample if self.transform is None else self.transform(sample)
deep_privacy/dp2/data/datasets/fdf128_wds.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import tops
3
+ import numpy as np
4
+ import io
5
+ import webdataset as wds
6
+ import os
7
+ from ..utils import png_decoder, get_num_workers, collate_fn
8
+
9
+
10
def kp_decoder(x):
    """Decode ``.npy`` bytes into a ``(7, 3)`` float tensor of keypoints.

    The first two columns are the stored coordinates clamped to [0, 1]
    (keypoints are between [0, 1] for webdataset); the third column is all ones.
    """
    raw = np.load(io.BytesIO(x))
    xy = torch.from_numpy(raw).float().view(7, 2).clamp(0, 1)
    ones_column = torch.ones((7, 1))
    return torch.cat((xy, ones_column), dim=-1)
15
+
16
+
17
def bbox_decoder(x):
    """Decode ``.npy`` bytes into a float tensor with four elements."""
    raw = np.load(io.BytesIO(x))
    return torch.from_numpy(raw).float().view(4)
19
+
20
+
21
class BBoxToMask:
    """Sample transform that turns a relative bounding box into a boolean mask."""

    def __call__(self, sample):
        """Add ``sample["mask"]``: True everywhere except inside the bounding box.

        The box in ``sample["bounding_box.npy"]`` is scaled by the size of the
        last dimension of ``sample["image.png"]`` and rounded to integers.
        """
        side = sample["image.png"].shape[-1]
        corners = np.round(sample["bounding_box.npy"] * side).astype(np.int64)
        x0, y0, x1, y1 = corners
        mask = torch.ones((1, side, side), dtype=torch.bool)
        mask[:, y0:y1, x0:x1] = 0
        sample["mask"] = mask
        return sample
31
+
32
+
33
def get_dataloader_fdf_wds(
        path,
        batch_size: int,
        num_workers: int,
        transform: torch.nn.Module,
        gpu_transform: torch.nn.Module,
        infinite: bool,
        shuffle: bool,
        partial_batches: bool,
        sample_shuffle=10_000,
        tar_shuffle=100,
        channels_last=False,
):
    """Build the webdataset loader for FDF shards, wrapped in a ``tops.DataPrefetcher``.

    Args:
        path: Shard specification handed to webdataset (converted with ``str``).
        batch_size: Number of samples per batch.
        num_workers: Passed through ``get_num_workers`` to the loader.
        transform: Optional callable mapped over the batched samples.
        gpu_transform: Module handed to the prefetcher.
        infinite: Resample shards and repeat the pipeline.
        shuffle: Shuffle shards and samples.
        partial_batches: Allow a final incomplete batch.
        sample_shuffle: Argument of the sample-level ``wds.shuffle``.
        tar_shuffle: Argument of the shard-level ``wds.shuffle``.
        channels_last: Forwarded to the prefetcher.
    """
    # Need to set this for split_by_node to work.
    os.environ["RANK"] = str(tops.rank())
    os.environ["WORLD_SIZE"] = str(tops.world_size())

    shard_source = wds.ResampledShards if infinite else wds.SimpleShardList
    stages = [shard_source(str(path))]
    if shuffle:
        stages.append(wds.shuffle(tar_shuffle))
    stages += [wds.split_by_node, wds.split_by_worker]
    if shuffle:
        stages.append(wds.shuffle(sample_shuffle))

    stages += [
        wds.tarfile_to_samples(),
        wds.decode(
            wds.handle_extension("image.png", png_decoder),
            wds.handle_extension("keypoints.npy", kp_decoder),
        ),
        wds.map(BBoxToMask()),
        wds.batched(batch_size, collation_fn=collate_fn, partial=partial_batches),
        wds.rename_keys(
            ["img", "image.png"],
            ["keypoints", "keypoints.npy"],
            ["__key__", "__key__"],
            ["mask", "mask"],
        ),
    ]
    if transform is not None:
        stages.append(wds.map(transform))

    pipeline = wds.DataPipeline(*stages)
    if infinite:
        pipeline = pipeline.repeat(nepochs=1000000)

    # Batching already happens inside the pipeline, hence batch_size=None.
    loader = wds.WebLoader(
        pipeline, batch_size=None, shuffle=False,
        num_workers=get_num_workers(num_workers),
        persistent_workers=True,
    )
    return tops.DataPrefetcher(loader, gpu_transform, channels_last=channels_last, to_float=False)
deep_privacy/dp2/data/datasets/fdh.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import tops
3
+ import numpy as np
4
+ import io
5
+ import webdataset as wds
6
+ import os
7
+ import json
8
+ from pathlib import Path
9
+ from ..utils import png_decoder, mask_decoder, get_num_workers, collate_fn
10
+
11
+
12
def kp_decoder(x):
    """Decode ``.npy`` bytes into a float keypoint tensor and recompute column 2.

    Column 2 becomes 1 only where its stored value is positive and neither of
    the first two columns lies outside [0, 1] (keypoints are between [0, 1]
    for webdataset); otherwise it becomes 0.
    """
    keypoints = torch.from_numpy(np.load(io.BytesIO(x))).float()
    xs, ys = keypoints[:, 0], keypoints[:, 1]
    out_of_frame = (xs < 0) | (xs > 1) | (ys < 0) | (ys > 1)
    keypoints[:, 2] = (keypoints[:, 2] > 0) & ~out_of_frame
    return keypoints
21
+
22
+
23
def vertices_decoder(x):
    """Decode ``.npy`` bytes into an int32 tensor, squeezed and given a leading axis of size 1."""
    raw = np.load(io.BytesIO(x))
    as_int32 = torch.from_numpy(raw.astype(np.int32))
    return as_int32.squeeze()[None]
26
+
27
+
28
class InsertNewKeypoints:
    """Sample transform that replaces ``keypoints.npy`` with keypoints from a JSON file."""

    def __init__(self, keypoints_path: Path) -> None:
        """Load the JSON mapping from sample key to keypoint list."""
        with open(keypoints_path, "r") as fp:
            self.keypoints = json.load(fp)

    def __call__(self, sample):
        """Look up the keypoints for ``sample["__key__"]`` and store them in the sample.

        Column 2 is recomputed: it stays 1 only where the stored value is
        positive and neither of the first two columns lies outside [0, 1].
        """
        keypoints = torch.tensor(self.keypoints[sample["__key__"]], dtype=torch.float32)
        xs, ys = keypoints[:, 0], keypoints[:, 1]
        out_of_frame = (xs < 0) | (xs > 1) | (ys < 0) | (ys > 1)
        keypoints[:, 2] = (keypoints[:, 2] > 0) & ~out_of_frame

        sample["keypoints.npy"] = keypoints
        return sample
45
+
46
+
47
def get_dataloader_fdh_wds(
        path,
        batch_size: int,
        num_workers: int,
        transform: torch.nn.Module,
        gpu_transform: torch.nn.Module,
        infinite: bool,
        shuffle: bool,
        partial_batches: bool,
        load_embedding: bool,
        sample_shuffle=10_000,
        tar_shuffle=100,
        read_condition=False,
        channels_last=False,
        load_new_keypoints=False,
        keypoints_split=None,
):
    """Build the webdataset loader for FDH shards, wrapped in a ``tops.DataPrefetcher``.

    Args:
        path: Shard specification handed to webdataset (converted with ``str``).
        batch_size: Number of samples per batch.
        num_workers: Passed through ``get_num_workers`` to the loader.
        transform: Optional callable mapped over the batched samples.
        gpu_transform: Module handed to the prefetcher.
        infinite: Resample shards and repeat the pipeline.
        shuffle: Shuffle shards and samples.
        partial_batches: Allow a final incomplete batch.
        load_embedding: Also decode ``vertices.npy`` and ``E_mask.png``.
        sample_shuffle: Argument of the sample-level ``wds.shuffle``.
        tar_shuffle: Argument of the shard-level ``wds.shuffle``.
        read_condition: Also decode ``condition.png``.
        channels_last: Forwarded to the prefetcher.
        load_new_keypoints: Replace the stored keypoints with a downloaded JSON set.
        keypoints_split: ``"train"`` or ``"val"``; required with ``load_new_keypoints``.
    """
    # Need to set this for split_by_node to work.
    os.environ["RANK"] = str(tops.rank())
    os.environ["WORLD_SIZE"] = str(tops.world_size())

    shard_source = wds.ResampledShards if infinite else wds.SimpleShardList
    stages = [shard_source(str(path))]
    if shuffle:
        stages.append(wds.shuffle(tar_shuffle))
    stages += [wds.split_by_node, wds.split_by_worker]
    if shuffle:
        stages.append(wds.shuffle(sample_shuffle))

    decoders = [
        wds.handle_extension("image.png", png_decoder),
        wds.handle_extension("mask.png", mask_decoder),
        wds.handle_extension("maskrcnn_mask.png", mask_decoder),
        wds.handle_extension("keypoints.npy", kp_decoder),
    ]
    key_renames = [
        ["img", "image.png"], ["mask", "mask.png"],
        ["keypoints", "keypoints.npy"], ["maskrcnn_mask", "maskrcnn_mask.png"],
        ["__key__", "__key__"]
    ]
    if load_embedding:
        decoders += [
            wds.handle_extension("vertices.npy", vertices_decoder),
            wds.handle_extension("E_mask.png", mask_decoder)
        ]
        key_renames += [
            ["vertices", "vertices.npy"],
            ["E_mask", "e_mask.png"]
        ]
    if read_condition:
        decoders.append(wds.handle_extension("condition.png", png_decoder))
        key_renames.append(["condition", "condition.png"])

    stages += [
        wds.tarfile_to_samples(),
        wds.decode(*decoders),
    ]
    if load_new_keypoints:
        assert keypoints_split in ["train", "val"]
        keypoint_sources = {
            "val": (
                "https://api.loke.aws.unit.no/dlr-gui-backend-resources-content/v2/contents/links/1eb88522-8b91-49c7-b56a-ed98a9c7888cef9c0429-a385-4248-abe3-8682de26d041f268aed1-7c88-4677-baad-7623c2ee330f",
                "fdh_keypoints_val-050133b34d.json",
            ),
            "train": (
                "https://api.loke.aws.unit.no/dlr-gui-backend-resources-content/v2/contents/links/3e828b1c-d6c0-4622-90bc-1b2cce48ccfff14ab45d-0a5c-431d-be13-7e60580765bd7938601c-e72e-41d9-8836-fffc49e76f58",
                "fdh_keypoints_train-2cff11f69a.json",
            ),
        }
        keypoint_url, file_name = keypoint_sources[keypoints_split]
        # Set check_hash=True if you suspect download is incorrect.
        filepath = tops.download_file(keypoint_url, file_name=file_name, check_hash=False)
        stages.append(wds.map(InsertNewKeypoints(filepath)))

    stages += [
        wds.batched(batch_size, collation_fn=collate_fn, partial=partial_batches),
        wds.rename_keys(*key_renames),
    ]
    if transform is not None:
        stages.append(wds.map(transform))

    pipeline = wds.DataPipeline(*stages)
    if infinite:
        pipeline = pipeline.repeat(nepochs=1000000)

    # Batching already happens inside the pipeline, hence batch_size=None.
    loader = wds.WebLoader(
        pipeline, batch_size=None, shuffle=False,
        num_workers=get_num_workers(num_workers),
        persistent_workers=True,
    )
    return tops.DataPrefetcher(loader, gpu_transform, channels_last=channels_last, to_float=False)
deep_privacy/dp2/data/transforms/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .transforms import RandomCrop, CreateCondition, CreateEmbedding, Resize, ToFloat, Normalize
2
+ from .stylegan2_transform import StyleGANAugmentPipe
deep_privacy/dp2/data/transforms/functional.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torchvision.transforms.functional as F
2
+ import torch
3
+ import pickle
4
+ from tops import download_file, assert_shape
5
+ from typing import Dict
6
+ from functools import lru_cache
7
+
8
+ global symmetry_transform
9
+
10
+
11
@lru_cache(maxsize=1)
def get_symmetry_transform(symmetry_url):
    """Download the symmetry file and return its ``vertex_transforms`` as a long tensor.

    The result for the most recently used URL is cached.
    """
    local_path = download_file(symmetry_url)
    with open(local_path, "rb") as fp:
        # NOTE(review): pickle.load can execute arbitrary code; the file is
        # downloaded, so the URL must point to a trusted source.
        vertex_transforms = pickle.load(fp)["vertex_transforms"]
    return torch.from_numpy(vertex_transforms).long()
17
+
18
+
19
# Every container key that hflip knows about; hflip asserts that it is given no other key.
hflip_handled_cases = {
    "keypoints", "img", "mask", "border", "semantic_mask", "vertices",
    "E_mask", "embed_map", "condition", "embedding", "vertx2cat",
    "maskrcnn_mask", "__key__",
}
22
+
23
+
24
def hflip(container: Dict[str, torch.Tensor], flip_map=None) -> Dict[str, torch.Tensor]:
    """Horizontally flip every supported entry of ``container`` in place.

    Args:
        container: Sample dictionary; must contain ``img`` and may only
            contain keys listed in ``hflip_handled_cases``.
        flip_map: Index permutation applied to the keypoint axis; required
            when ``keypoints`` is present.

    Returns:
        The same ``container`` object.
    """
    container["img"] = F.hflip(container["img"])
    for key in ("condition", "embedding"):
        if key in container:
            container[key] = F.hflip(container[key])
    assert all([key in hflip_handled_cases for key in container]), container.keys()

    if "keypoints" in container:
        assert flip_map is not None
        original = container["keypoints"]
        if original.ndim == 3:
            flipped = original[:, flip_map, :]
            # Mirror the first coordinate (values are treated as relative to [0, 1]).
            flipped[:, :, 0] = 1 - flipped[:, :, 0]
        else:
            assert_shape(original, (None, 3))
            flipped = original[flip_map, :]
            flipped[:, 0] = 1 - flipped[:, 0]
        container["keypoints"] = flipped

    for key in ("mask", "border", "semantic_mask"):
        if key in container:
            container[key] = F.hflip(container[key])

    if "vertices" in container:
        symmetry_transform = get_symmetry_transform(
            "https://dl.fbaipublicfiles.com/densepose/meshes/symmetry/symmetry_smpl_27554.pkl")
        mirrored = F.hflip(container["vertices"])
        container["vertices"] = mirrored
        # After the spatial flip, remap each vertex index through the symmetry table.
        lookup = symmetry_transform.to(mirrored.device)
        container["vertices"] = lookup[mirrored.long()]

    for key in ("E_mask", "maskrcnn_mask"):
        if key in container:
            container[key] = F.hflip(container[key])
    return container
deep_privacy/dp2/data/transforms/stylegan2_transform.py ADDED
@@ -0,0 +1,394 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import scipy.signal
3
+ import torch
4
+ try:
5
+ from sg3_torch_utils import misc
6
+ from sg3_torch_utils.ops import upfirdn2d
7
+ from sg3_torch_utils.ops import grid_sample_gradfix
8
+ from sg3_torch_utils.ops import conv2d_gradfix
9
+ except:
10
+ pass
11
+ #----------------------------------------------------------------------------
12
+ # Coefficients of various wavelet decomposition low-pass filters.
13
+
14
# Several entries share identical taps ('haar'/'db1', 'db2'/'sym2', 'db3'/'sym3').
# The shared values are written once below and copied per key, so every key
# still owns an independent list.
_haar_taps = [0.7071067811865476, 0.7071067811865476]
_db2_taps = [-0.12940952255092145, 0.22414386804185735, 0.836516303737469, 0.48296291314469025]
_db3_taps = [0.035226291882100656, -0.08544127388224149, -0.13501102001039084, 0.4598775021193313, 0.8068915093133388, 0.3326705529509569]

wavelets = {
    'haar': list(_haar_taps),
    'db1': list(_haar_taps),
    'db2': list(_db2_taps),
    'db3': list(_db3_taps),
    'db4': [-0.010597401784997278, 0.032883011666982945, 0.030841381835986965, -0.18703481171888114, -0.02798376941698385, 0.6308807679295904, 0.7148465705525415, 0.23037781330885523],
    'db5': [0.003335725285001549, -0.012580751999015526, -0.006241490213011705, 0.07757149384006515, -0.03224486958502952, -0.24229488706619015, 0.13842814590110342, 0.7243085284385744, 0.6038292697974729, 0.160102397974125],
    'db6': [-0.00107730108499558, 0.004777257511010651, 0.0005538422009938016, -0.031582039318031156, 0.02752286553001629, 0.09750160558707936, -0.12976686756709563, -0.22626469396516913, 0.3152503517092432, 0.7511339080215775, 0.4946238903983854, 0.11154074335008017],
    'db7': [0.0003537138000010399, -0.0018016407039998328, 0.00042957797300470274, 0.012550998556013784, -0.01657454163101562, -0.03802993693503463, 0.0806126091510659, 0.07130921926705004, -0.22403618499416572, -0.14390600392910627, 0.4697822874053586, 0.7291320908465551, 0.39653931948230575, 0.07785205408506236],
    'db8': [-0.00011747678400228192, 0.0006754494059985568, -0.0003917403729959771, -0.00487035299301066, 0.008746094047015655, 0.013981027917015516, -0.04408825393106472, -0.01736930100202211, 0.128747426620186, 0.00047248457399797254, -0.2840155429624281, -0.015829105256023893, 0.5853546836548691, 0.6756307362980128, 0.3128715909144659, 0.05441584224308161],
    'sym2': list(_db2_taps),
    'sym3': list(_db3_taps),
    'sym4': [-0.07576571478927333, -0.02963552764599851, 0.49761866763201545, 0.8037387518059161, 0.29785779560527736, -0.09921954357684722, -0.012603967262037833, 0.0322231006040427],
    'sym5': [0.027333068345077982, 0.029519490925774643, -0.039134249302383094, 0.1993975339773936, 0.7234076904024206, 0.6339789634582119, 0.01660210576452232, -0.17532808990845047, -0.021101834024758855, 0.019538882735286728],
    'sym6': [0.015404109327027373, 0.0034907120842174702, -0.11799011114819057, -0.048311742585633, 0.4910559419267466, 0.787641141030194, 0.3379294217276218, -0.07263752278646252, -0.021060292512300564, 0.04472490177066578, 0.0017677118642428036, -0.007800708325034148],
    'sym7': [0.002681814568257878, -0.0010473848886829163, -0.01263630340325193, 0.03051551316596357, 0.0678926935013727, -0.049552834937127255, 0.017441255086855827, 0.5361019170917628, 0.767764317003164, 0.2886296317515146, -0.14004724044296152, -0.10780823770381774, 0.004010244871533663, 0.010268176708511255],
    'sym8': [-0.0033824159510061256, -0.0005421323317911481, 0.03169508781149298, 0.007607487324917605, -0.1432942383508097, -0.061273359067658524, 0.4813596512583722, 0.7771857517005235, 0.3644418948353314, -0.05194583810770904, -0.027219029917056003, 0.049137179673607506, 0.003808752013890615, -0.01495225833704823, -0.0003029205147213668, 0.0018899503327594609],
}

# The helper names were only needed to build the table.
del _haar_taps, _db2_taps, _db3_taps
32
+
33
+ #----------------------------------------------------------------------------
34
+ # Helpers for constructing transformation matrices.
35
+
36
+
37
def matrix(*rows, device=None):
    """
    Assemble a matrix (or a batch of matrices) from rows of scalars and tensors.

    When no element is a tensor, the rows are turned into a constant via
    ``misc.constant``. Otherwise every non-tensor element is expanded to the
    shape of the first tensor element, and the result has that shape followed
    by ``(len(rows), row_length)``.

    Args:
        *rows: Equal-length sequences of Python numbers and/or tensors.
        device: Optional device; if tensors are present it must match theirs.
    """
    assert all(len(row) == len(rows[0]) for row in rows)

    flat = []
    for row in rows:
        flat.extend(row)

    tensor_elems = [entry for entry in flat if isinstance(entry, torch.Tensor)]
    if not tensor_elems:
        return misc.constant(np.asarray(rows), device=device)

    template = tensor_elems[0]
    assert device is None or device == template.device

    promoted = []
    for entry in flat:
        if not isinstance(entry, torch.Tensor):
            # Expand plain numbers so they can be stacked with the tensor entries.
            entry = misc.constant(entry, shape=template.shape, device=template.device)
        promoted.append(entry)

    stacked = torch.stack(promoted, dim=-1)
    return stacked.reshape(template.shape + (len(rows), -1))
46
+
47
+
48
def translate2d(tx, ty, **kwargs):
    """Build the 3x3 homogeneous matrix with (tx, ty) in the last column; kwargs go to matrix()."""
    rows = (
        [1, 0, tx],
        [0, 1, ty],
        [0, 0, 1],
    )
    return matrix(*rows, **kwargs)
54
+
55
+
56
def translate3d(tx, ty, tz, **kwargs):
    """Build the 4x4 homogeneous matrix with (tx, ty, tz) in the last column; kwargs go to matrix()."""
    rows = (
        [1, 0, 0, tx],
        [0, 1, 0, ty],
        [0, 0, 1, tz],
        [0, 0, 0, 1],
    )
    return matrix(*rows, **kwargs)
63
+
64
+
65
def scale2d(sx, sy, **kwargs):
    """Build the 3x3 homogeneous matrix with (sx, sy, 1) on the diagonal; kwargs go to matrix()."""
    rows = (
        [sx, 0, 0],
        [0, sy, 0],
        [0, 0, 1],
    )
    return matrix(*rows, **kwargs)
71
+
72
+
73
def scale3d(sx, sy, sz, **kwargs):
    """Build the 4x4 homogeneous matrix with (sx, sy, sz, 1) on the diagonal; kwargs go to matrix()."""
    rows = (
        [sx, 0, 0, 0],
        [0, sy, 0, 0],
        [0, 0, sz, 0],
        [0, 0, 0, 1],
    )
    return matrix(*rows, **kwargs)
80
+
81
+
82
def rotate2d(theta, **kwargs):
    """Build the 3x3 homogeneous 2D rotation matrix for the angle tensor theta; kwargs go to matrix()."""
    top = [torch.cos(theta), torch.sin(-theta), 0]
    middle = [torch.sin(theta), torch.cos(theta), 0]
    bottom = [0, 0, 1]
    return matrix(top, middle, bottom, **kwargs)
88
+
89
+
90
def rotate3d(v, theta, **kwargs):
    """
    Build the 4x4 homogeneous rotation matrix about axis v by the angle tensor theta.

    Only the first three components along the last dimension of ``v`` are read.
    kwargs are forwarded to matrix().
    """
    vx = v[..., 0]
    vy = v[..., 1]
    vz = v[..., 2]
    s = torch.sin(theta)
    c = torch.cos(theta)
    cc = 1 - c
    rows = (
        [vx*vx*cc+c, vx*vy*cc-vz*s, vx*vz*cc+vy*s, 0],
        [vy*vx*cc+vz*s, vy*vy*cc+c, vy*vz*cc-vx*s, 0],
        [vz*vx*cc-vy*s, vz*vy*cc+vx*s, vz*vz*cc+c, 0],
        [0, 0, 0, 1],
    )
    return matrix(*rows, **kwargs)
99
+
100
+
101
def translate2d_inv(tx, ty, **kwargs):
    """Return translate2d() of the negated offsets, i.e. the inverse translation."""
    neg_tx, neg_ty = -tx, -ty
    return translate2d(neg_tx, neg_ty, **kwargs)
103
+
104
+
105
def scale2d_inv(sx, sy, **kwargs):
    """Return scale2d() of the reciprocal factors, i.e. the inverse scaling."""
    inv_sx = 1 / sx
    inv_sy = 1 / sy
    return scale2d(inv_sx, inv_sy, **kwargs)
107
+
108
+
109
def rotate2d_inv(theta, **kwargs):
    """Return rotate2d() of the negated angle, i.e. the inverse rotation."""
    negated = -theta
    return rotate2d(negated, **kwargs)
111
+
112
+
113
class StyleGANAugmentPipe(torch.nn.Module):
    """
    Randomised augmentation module that operates on a batch dictionary.

    ``forward`` reads ``batch["img"]``, ``batch["mask"]``, ``batch["E_mask"]``
    and ``batch["vertices"]``. Geometric transforms are applied to the image
    with filtered up/down-sampling, and the same sampling grid is applied to
    the mask, E_mask and vertices with nearest-neighbour lookup. Colour
    transforms and image-space filtering touch the image only.

    Each ``<name>`` constructor argument is a probability multiplier for that
    augmentation (0 disables it); the matching ``*_std`` / ``*_max`` argument
    controls its strength. The buffer ``p`` scales all probabilities.

    NOTE(review): ``rotate90`` and ``lumaflip`` are stored in ``__init__`` but
    are not read anywhere in ``forward``.
    """

    def __init__(self,
        rotate90=0, xint=0, xint_max=0.125,
        scale=0, rotate=0, aniso=0, xfrac=0, scale_std=0.2, rotate_max=1, aniso_std=0.2, xfrac_std=0.125,
        brightness=0, contrast=0, lumaflip=0, hue=0, saturation=0, brightness_std=0.2, contrast_std=0.5,
        hue_max=1, saturation_std=1,
        imgfilter=0, imgfilter_bands=[1,1,1,1], imgfilter_std=1,
    ):
        # NOTE: the list default for imgfilter_bands is copied below via list(), so it is never shared.
        super().__init__()
        self.register_buffer('p', torch.ones([])) # Overall multiplier for augmentation probability.

        # Pixel blitting.
        self.rotate90 = float(rotate90) # Probability multiplier for 90 degree rotations.
        self.xint = float(xint) # Probability multiplier for integer translation.
        self.xint_max = float(xint_max) # Range of integer translation, relative to image dimensions.

        # General geometric transformations.
        self.scale = float(scale) # Probability multiplier for isotropic scaling.
        self.rotate = float(rotate) # Probability multiplier for arbitrary rotation.
        self.aniso = float(aniso) # Probability multiplier for anisotropic scaling.
        self.xfrac = float(xfrac) # Probability multiplier for fractional translation.
        self.scale_std = float(scale_std) # Log2 standard deviation of isotropic scaling.
        self.rotate_max = float(rotate_max) # Range of arbitrary rotation, 1 = full circle.
        self.aniso_std = float(aniso_std) # Log2 standard deviation of anisotropic scaling.
        self.xfrac_std = float(xfrac_std) # Standard deviation of fractional translation, relative to image dimensions.

        # Color transformations.
        self.brightness = float(brightness) # Probability multiplier for brightness.
        self.contrast = float(contrast) # Probability multiplier for contrast.
        self.lumaflip = float(lumaflip) # Probability multiplier for luma flip.
        self.hue = float(hue) # Probability multiplier for hue rotation.
        self.saturation = float(saturation) # Probability multiplier for saturation.
        self.brightness_std = float(brightness_std) # Standard deviation of brightness.
        self.contrast_std = float(contrast_std) # Log2 standard deviation of contrast.
        self.hue_max = float(hue_max) # Range of hue rotation, 1 = full circle.
        self.saturation_std = float(saturation_std) # Log2 standard deviation of saturation.

        # Image-space filtering.
        self.imgfilter = float(imgfilter) # Probability multiplier for image-space filtering.
        self.imgfilter_bands = list(imgfilter_bands) # Probability multipliers for individual frequency bands.
        self.imgfilter_std = float(imgfilter_std) # Log2 standard deviation of image-space filter amplification.

        # Setup orthogonal lowpass filter for geometric augmentations.
        self.register_buffer('Hz_geom', upfirdn2d.setup_filter(wavelets['sym6']))

        # Construct filter bank for image-space filtering.
        # The bank has 4 rows (one per band); each loop iteration widens every row
        # and adds the high-pass product filter into the centre of row i.
        Hz_lo = np.asarray(wavelets['sym2']) # H(z)
        Hz_hi = Hz_lo * ((-1) ** np.arange(Hz_lo.size)) # H(-z)
        Hz_lo2 = np.convolve(Hz_lo, Hz_lo[::-1]) / 2 # H(z) * H(z^-1) / 2
        Hz_hi2 = np.convolve(Hz_hi, Hz_hi[::-1]) / 2 # H(-z) * H(-z^-1) / 2
        Hz_fbank = np.eye(4, 1) # Bandpass(H(z), b_i)
        for i in range(1, Hz_fbank.shape[0]):
            Hz_fbank = np.dstack([Hz_fbank, np.zeros_like(Hz_fbank)]).reshape(Hz_fbank.shape[0], -1)[:, :-1]
            Hz_fbank = scipy.signal.convolve(Hz_fbank, [Hz_lo2])
            Hz_fbank[i, (Hz_fbank.shape[1] - Hz_hi2.size) // 2 : (Hz_fbank.shape[1] + Hz_hi2.size) // 2] += Hz_hi2
        self.register_buffer('Hz_fbank', torch.as_tensor(Hz_fbank, dtype=torch.float32))

    def forward(self, batch, debug_percentile=None):
        """
        Augment ``batch`` in place and return it.

        Args:
            batch: Dictionary with a 4-D "img" tensor plus "mask", "E_mask" and
                "vertices" entries. "vertices" is cast to float on entry and
                back to long on exit. "border" is (re)computed on exit as
                ``1 - E_mask - mask``.
            debug_percentile: When not None, every random parameter is replaced
                by a deterministic value derived from this percentile (the
                post-rotation angle is set to zero).

        Raises:
            ValueError: If a colour transform is active and "img" has neither
                3 nor 1 channels.
        """
        images = batch["img"]
        batch["vertices"] = batch["vertices"].float()
        assert isinstance(images, torch.Tensor) and images.ndim == 4
        batch_size, num_channels, height, width = images.shape
        device = images.device
        # Move the filter buffers to the device of the incoming images.
        self.Hz_fbank = self.Hz_fbank.to(device)
        self.Hz_geom = self.Hz_geom.to(device)
        if debug_percentile is not None:
            debug_percentile = torch.as_tensor(debug_percentile, dtype=torch.float32, device=device)

        # -------------------------------------
        # Select parameters for pixel blitting.
        # -------------------------------------

        # Initialize inverse homogeneous 2D transform: G_inv @ pixel_out ==> pixel_in
        # I_3 is kept as a sentinel: "G_inv is not I_3" below detects whether any transform was composed.
        I_3 = torch.eye(3, device=device)
        G_inv = I_3

        # Apply integer translation with probability (xint * strength).
        if self.xint > 0:
            t = (torch.rand([batch_size, 2], device=device) * 2 - 1) * self.xint_max
            t = torch.where(torch.rand([batch_size, 1], device=device) < self.xint * self.p, t, torch.zeros_like(t))
            if debug_percentile is not None:
                t = torch.full_like(t, (debug_percentile * 2 - 1) * self.xint_max)
            G_inv = G_inv @ translate2d_inv(torch.round(t[:,0] * width), torch.round(t[:,1] * height))

        # --------------------------------------------------------
        # Select parameters for general geometric transformations.
        # --------------------------------------------------------

        # Apply isotropic scaling with probability (scale * strength).
        if self.scale > 0:
            s = torch.exp2(torch.randn([batch_size], device=device) * self.scale_std)
            s = torch.where(torch.rand([batch_size], device=device) < self.scale * self.p, s, torch.ones_like(s))
            if debug_percentile is not None:
                s = torch.full_like(s, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.scale_std))
            G_inv = G_inv @ scale2d_inv(s, s)

        # Apply pre-rotation with probability p_rot.
        p_rot = 1 - torch.sqrt((1 - self.rotate * self.p).clamp(0, 1)) # P(pre OR post) = p
        if self.rotate > 0:
            theta = (torch.rand([batch_size], device=device) * 2 - 1) * np.pi * self.rotate_max
            theta = torch.where(torch.rand([batch_size], device=device) < p_rot, theta, torch.zeros_like(theta))
            if debug_percentile is not None:
                theta = torch.full_like(theta, (debug_percentile * 2 - 1) * np.pi * self.rotate_max)
            G_inv = G_inv @ rotate2d_inv(-theta) # Before anisotropic scaling.

        # Apply anisotropic scaling with probability (aniso * strength).
        if self.aniso > 0:
            s = torch.exp2(torch.randn([batch_size], device=device) * self.aniso_std)
            s = torch.where(torch.rand([batch_size], device=device) < self.aniso * self.p, s, torch.ones_like(s))
            if debug_percentile is not None:
                s = torch.full_like(s, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.aniso_std))
            G_inv = G_inv @ scale2d_inv(s, 1 / s)

        # Apply post-rotation with probability p_rot.
        if self.rotate > 0:
            theta = (torch.rand([batch_size], device=device) * 2 - 1) * np.pi * self.rotate_max
            theta = torch.where(torch.rand([batch_size], device=device) < p_rot, theta, torch.zeros_like(theta))
            if debug_percentile is not None:
                theta = torch.zeros_like(theta)
            G_inv = G_inv @ rotate2d_inv(-theta) # After anisotropic scaling.

        # Apply fractional translation with probability (xfrac * strength).
        if self.xfrac > 0:
            t = torch.randn([batch_size, 2], device=device) * self.xfrac_std
            t = torch.where(torch.rand([batch_size, 1], device=device) < self.xfrac * self.p, t, torch.zeros_like(t))
            if debug_percentile is not None:
                t = torch.full_like(t, torch.erfinv(debug_percentile * 2 - 1) * self.xfrac_std)
            G_inv = G_inv @ translate2d_inv(t[:,0] * width, t[:,1] * height)

        # ----------------------------------
        # Execute geometric transformations.
        # ----------------------------------

        # Execute if the transform is not identity.
        if G_inv is not I_3:
            # Calculate padding.
            cx = (width - 1) / 2
            cy = (height - 1) / 2
            cp = matrix([-cx, -cy, 1], [cx, -cy, 1], [cx, cy, 1], [-cx, cy, 1], device=device) # [idx, xyz]
            cp = G_inv @ cp.t() # [batch, xyz, idx]
            Hz_pad = self.Hz_geom.shape[0] // 4
            margin = cp[:, :2, :].permute(1, 0, 2).flatten(1) # [xy, batch * idx]
            margin = torch.cat([-margin, margin]).max(dim=1).values # [x0, y0, x1, y1]
            margin = margin + misc.constant([Hz_pad * 2 - cx, Hz_pad * 2 - cy] * 2, device=device)
            margin = margin.max(misc.constant([0, 0] * 2, device=device))
            margin = margin.min(misc.constant([width-1, height-1] * 2, device=device))
            mx0, my0, mx1, my1 = margin.ceil().to(torch.int32)

            # Pad image and adjust origin.
            # The image is reflect-padded; the mask is padded with 1.0, E_mask and vertices with 0.0.
            images = torch.nn.functional.pad(input=images, pad=[mx0,mx1,my0,my1], mode='reflect')
            batch["mask"] = torch.nn.functional.pad(input=batch["mask"], pad=[mx0,mx1,my0,my1], mode='constant', value=1.0)
            batch["E_mask"] = torch.nn.functional.pad(input=batch["E_mask"], pad=[mx0,mx1,my0,my1], mode='constant', value=0.0)
            batch["vertices"] = torch.nn.functional.pad(input=batch["vertices"], pad=[mx0,mx1,my0,my1], mode='constant', value=0.0)
            G_inv = translate2d((mx0 - mx1) / 2, (my0 - my1) / 2) @ G_inv

            # Upsample.
            # The image goes through the Hz_geom filter; the annotation maps use nearest-neighbour upsampling.
            images = upfirdn2d.upsample2d(x=images, f=self.Hz_geom, up=2)
            batch["mask"] = torch.nn.functional.interpolate(batch["mask"], scale_factor=2, mode="nearest")
            batch["E_mask"] = torch.nn.functional.interpolate(batch["E_mask"], scale_factor=2, mode="nearest")
            batch["vertices"] = torch.nn.functional.interpolate(batch["vertices"], scale_factor=2, mode="nearest")
            G_inv = scale2d(2, 2, device=device) @ G_inv @ scale2d_inv(2, 2, device=device)
            G_inv = translate2d(-0.5, -0.5, device=device) @ G_inv @ translate2d_inv(-0.5, -0.5, device=device)

            # Execute transformation.
            shape = [batch_size, num_channels, (height + Hz_pad * 2) * 2, (width + Hz_pad * 2) * 2]
            G_inv = scale2d(2 / images.shape[3], 2 / images.shape[2], device=device) @ G_inv @ scale2d_inv(2 / shape[3], 2 / shape[2], device=device)
            grid = torch.nn.functional.affine_grid(theta=G_inv[:,:2,:], size=shape, align_corners=False)
            images = grid_sample_gradfix.grid_sample(images, grid)

            # The same sampling grid is reused for the annotation maps, with nearest lookup and border padding.
            batch["mask"] = torch.nn.functional.grid_sample(
                input=batch["mask"], grid=grid, mode='nearest', padding_mode="border", align_corners=False)
            batch["E_mask"] = torch.nn.functional.grid_sample(
                input=batch["E_mask"], grid=grid, mode='nearest', padding_mode="border", align_corners=False)
            batch["vertices"] = torch.nn.functional.grid_sample(
                input=batch["vertices"], grid=grid, mode='nearest', padding_mode="border", align_corners=False)


            # Downsample and crop.
            # For the annotation maps the Hz_pad*2 margin is sliced off before halving the resolution.
            images = upfirdn2d.downsample2d(x=images, f=self.Hz_geom, down=2, padding=-Hz_pad*2, flip_filter=True)
            batch["mask"] = torch.nn.functional.interpolate(batch["mask"][:, :, Hz_pad*2:-Hz_pad*2, Hz_pad*2:-Hz_pad*2], scale_factor=.5, mode="nearest", recompute_scale_factor=False)
            batch["E_mask"] = torch.nn.functional.interpolate(batch["E_mask"][:, :, Hz_pad*2:-Hz_pad*2, Hz_pad*2:-Hz_pad*2], scale_factor=.5, mode="nearest", recompute_scale_factor=False)
            batch["vertices"] = torch.nn.functional.interpolate(batch["vertices"][:, :, Hz_pad*2:-Hz_pad*2, Hz_pad*2:-Hz_pad*2], scale_factor=.5, mode="nearest", recompute_scale_factor=False)
        # --------------------------------------------
        # Select parameters for color transformations.
        # --------------------------------------------

        # Initialize homogeneous 3D transformation matrix: C @ color_in ==> color_out
        # I_4 is kept as a sentinel, mirroring I_3 above.
        I_4 = torch.eye(4, device=device)
        C = I_4

        # Apply brightness with probability (brightness * strength).
        if self.brightness > 0:
            b = torch.randn([batch_size], device=device) * self.brightness_std
            b = torch.where(torch.rand([batch_size], device=device) < self.brightness * self.p, b, torch.zeros_like(b))
            if debug_percentile is not None:
                b = torch.full_like(b, torch.erfinv(debug_percentile * 2 - 1) * self.brightness_std)
            C = translate3d(b, b, b) @ C

        # Apply contrast with probability (contrast * strength).
        if self.contrast > 0:
            c = torch.exp2(torch.randn([batch_size], device=device) * self.contrast_std)
            c = torch.where(torch.rand([batch_size], device=device) < self.contrast * self.p, c, torch.ones_like(c))
            if debug_percentile is not None:
                c = torch.full_like(c, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.contrast_std))
            C = scale3d(c, c, c) @ C

        # Luma axis, used by the hue and saturation transforms below.
        # NOTE(review): no luma-flip step is executed here even though self.lumaflip is configurable.
        v = misc.constant(np.asarray([1, 1, 1, 0]) / np.sqrt(3), device=device) # Luma axis.

        # Apply hue rotation with probability (hue * strength).
        if self.hue > 0 and num_channels > 1:
            theta = (torch.rand([batch_size], device=device) * 2 - 1) * np.pi * self.hue_max
            theta = torch.where(torch.rand([batch_size], device=device) < self.hue * self.p, theta, torch.zeros_like(theta))
            if debug_percentile is not None:
                theta = torch.full_like(theta, (debug_percentile * 2 - 1) * np.pi * self.hue_max)
            C = rotate3d(v, theta) @ C # Rotate around v.

        # Apply saturation with probability (saturation * strength).
        if self.saturation > 0 and num_channels > 1:
            s = torch.exp2(torch.randn([batch_size, 1, 1], device=device) * self.saturation_std)
            s = torch.where(torch.rand([batch_size, 1, 1], device=device) < self.saturation * self.p, s, torch.ones_like(s))
            if debug_percentile is not None:
                s = torch.full_like(s, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.saturation_std))
            C = (v.ger(v) + (I_4 - v.ger(v)) * s) @ C

        # ------------------------------
        # Execute color transformations.
        # ------------------------------

        # Execute if the transform is not identity.
        if C is not I_4:
            images = images.reshape([batch_size, num_channels, height * width])
            if num_channels == 3:
                images = C[:, :3, :3] @ images + C[:, :3, 3:]
            elif num_channels == 1:
                # Single-channel input: collapse the colour matrix to one gain and one offset.
                C = C[:, :3, :].mean(dim=1, keepdims=True)
                images = images * C[:, :, :3].sum(dim=2, keepdims=True) + C[:, :, 3:]
            else:
                raise ValueError('Image must be RGB (3 channels) or L (1 channel)')
            images = images.reshape([batch_size, num_channels, height, width])

        # ----------------------
        # Image-space filtering.
        # ----------------------

        if self.imgfilter > 0:
            num_bands = self.Hz_fbank.shape[0]
            assert len(self.imgfilter_bands) == num_bands
            expected_power = misc.constant(np.array([10, 1, 1, 1]) / 13, device=device) # Expected power spectrum (1/f).

            # Apply amplification for each band with probability (imgfilter * strength * band_strength).
            g = torch.ones([batch_size, num_bands], device=device) # Global gain vector (identity).
            for i, band_strength in enumerate(self.imgfilter_bands):
                t_i = torch.exp2(torch.randn([batch_size], device=device) * self.imgfilter_std)
                t_i = torch.where(torch.rand([batch_size], device=device) < self.imgfilter * self.p * band_strength, t_i, torch.ones_like(t_i))
                if debug_percentile is not None:
                    t_i = torch.full_like(t_i, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.imgfilter_std)) if band_strength > 0 else torch.ones_like(t_i)
                t = torch.ones([batch_size, num_bands], device=device) # Temporary gain vector.
                t[:, i] = t_i # Replace i'th element.
                t = t / (expected_power * t.square()).sum(dim=-1, keepdims=True).sqrt() # Normalize power.
                g = g * t # Accumulate into global gain.

            # Construct combined amplification filter.
            Hz_prime = g @ self.Hz_fbank # [batch, tap]
            Hz_prime = Hz_prime.unsqueeze(1).repeat([1, num_channels, 1]) # [batch, channels, tap]
            Hz_prime = Hz_prime.reshape([batch_size * num_channels, 1, -1]) # [batch * channels, 1, tap]

            # Apply filter.
            # The 1-D filter is applied separably (once per spatial axis) as a grouped convolution,
            # one group per (sample, channel) pair.
            p = self.Hz_fbank.shape[1] // 2
            images = images.reshape([1, batch_size * num_channels, height, width])
            images = torch.nn.functional.pad(input=images, pad=[p,p,p,p], mode='reflect')
            images = conv2d_gradfix.conv2d(input=images, weight=Hz_prime.unsqueeze(2), groups=batch_size*num_channels)
            images = conv2d_gradfix.conv2d(input=images, weight=Hz_prime.unsqueeze(3), groups=batch_size*num_channels)
            images = images.reshape([batch_size, num_channels, height, width])

        # ------------------------
        # Write results back into the batch (no image-space corruptions are applied here).
        # ------------------------
        batch["img"] = images
        batch["vertices"] = batch["vertices"].long()
        batch["border"] = 1 - batch["E_mask"] - batch["mask"]
        return batch
deep_privacy/dp2/data/transforms/transforms.py ADDED
@@ -0,0 +1,277 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ from typing import Dict, List
3
+ import torchvision
4
+ import torch
5
+ import tops
6
+ import torchvision.transforms.functional as F
7
+ from .functional import hflip
8
+ import numpy as np
9
+ from dp2.utils.vis_utils import get_coco_keypoints
10
+ from PIL import Image, ImageDraw
11
+ from typing import Tuple
12
+
13
+
14
class RandomHorizontalFlip(torch.nn.Module):
    """
    Mirrors a sample container horizontally with a fixed probability.

    Args:
        p: Flip probability in [0, 1]. ``None`` falls back to 0.5.
        flip_map: Passed through to ``hflip`` for keypoint reordering.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, p: float, flip_map=None, **kwargs):
        super().__init__()
        self.flip_ratio = 0.5 if p is None else p
        self.flip_map = flip_map
        assert 0 <= self.flip_ratio <= 1

    def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        # One uniform draw per call; a draw above the ratio leaves the sample untouched.
        keep_as_is = bool(torch.rand(1) > self.flip_ratio)
        if not keep_as_is:
            return hflip(container, self.flip_map)
        return container
28
+
29
+
30
class CenterCrop(torch.nn.Module):
    """
    Center-crops ``container["img"]`` to ``size``.

    If the shorter image side is smaller than ``size[0]``, the image is first
    center-cropped to a square of that shorter side and then resized to
    ``size``. Only the "img" entry is modified.
    """

    def __init__(self, size: List[int]):
        super().__init__()
        self.size = tuple(size)

    def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        img = container["img"]
        shortest_side = min(img.shape[1], img.shape[2])
        if shortest_side < self.size[0]:
            # Too small to crop directly: take the largest centered square, then scale it.
            container["img"] = F.center_crop(img, shortest_side)
            container["img"] = F.resize(container["img"], self.size)
        else:
            container["img"] = F.center_crop(img, self.size)
        return container
48
+
49
+
50
class Resize(torch.nn.Module):
    """
    Resizes the image and the spatial annotations present in the container.

    "img" uses ``interpolation`` with antialiasing, "embedding" uses
    ``interpolation``, and "semantic_mask", "mask", "E_mask", "maskrcnn_mask"
    and "vertices" use nearest-neighbour interpolation.
    """

    def __init__(self, size, interpolation=F.InterpolationMode.BILINEAR):
        super().__init__()
        self.size = tuple(size)
        self.interpolation = interpolation

    def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        nearest = F.InterpolationMode.NEAREST
        container["img"] = F.resize(container["img"], self.size, self.interpolation, antialias=True)
        if "semantic_mask" in container:
            container["semantic_mask"] = F.resize(container["semantic_mask"], self.size, nearest)
        if "embedding" in container:
            container["embedding"] = F.resize(container["embedding"], self.size, self.interpolation)
        # Remaining label-like maps must not be blended, hence nearest-neighbour.
        for name in ("mask", "E_mask", "maskrcnn_mask", "vertices"):
            if name in container:
                container[name] = F.resize(container[name], self.size, nearest)
        return container

    def __repr__(self):
        base = super().__repr__()
        vars_ = dict(size=self.size, interpolation=self.interpolation)
        details = " ".join([f"{k}: {v}" for k, v in vars_.items()])
        return base + " " + details
87
+
88
+
89
class Normalize(torch.nn.Module):
    """
    Normalizes the selected container entries with ``F.normalize``.

    Args:
        mean: Mean passed to ``F.normalize``.
        std: Standard deviation passed to ``F.normalize``.
        inplace: Forwarded to ``F.normalize``.
        keys: Container keys to normalize. Only these entries are touched.
    """

    def __init__(self, mean, std, inplace, keys=["img"]):
        super().__init__()
        self.mean = mean
        self.std = std
        self.inplace = inplace
        # Copy the sequence: the default list object is shared by every call, so
        # storing it directly would let one instance's edits leak into all others.
        self.keys = list(keys)

    def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        for key in self.keys:
            container[key] = F.normalize(container[key], self.mean, self.std, self.inplace)
        return container

    def __repr__(self):
        base = super().__repr__()
        vars_ = dict(mean=self.mean, std=self.std, inplace=self.inplace)
        return base + " " + " ".join([f"{k}: {v}" for k, v in vars_.items()])
111
+
112
+
113
class ToFloat(torch.nn.Module):
    """
    Casts the selected container entries to float and divides them by a gain.

    Args:
        keys: Container keys to convert.
        norm: When true the values are divided by 255, otherwise by 1.
    """

    def __init__(self, keys=["img"], norm=True) -> None:
        super().__init__()
        # Copy the sequence so instances never share (and mutate) the default list object.
        self.keys = list(keys)
        self.gain = 255 if norm else 1

    def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        for key in self.keys:
            container[key] = container[key].float() / self.gain
        return container
124
+
125
+
126
class RandomCrop(torchvision.transforms.RandomCrop):
    """
    Container-aware wrapper around torchvision's RandomCrop.

    Only ``container["img"]`` is cropped; every other entry is left as it is.
    """

    def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        cropped = super().forward(container["img"])
        container["img"] = cropped
        return container
135
+
136
+
137
class CreateCondition(torch.nn.Module):
    """
    Adds ``container["condition"]``: the image with the masked-out region blanked.

    For uint8 images, pixels where the mask is 0 are set to 127. For any other
    dtype the image is simply multiplied by the mask.
    """

    def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        img = container["img"]
        mask = container["mask"]
        if img.dtype != torch.uint8:
            container["condition"] = img * mask
            return container
        keep = mask.byte()
        container["condition"] = img * keep + (1 - keep) * 127
        return container
145
+
146
+
147
class CreateEmbedding(torch.nn.Module):
    """
    Looks up a per-pixel embedding from ``container["vertices"]``.

    The embedding table is loaded from ``embed_path`` onto the CPU and moved
    with ``tops.to_cuda`` when ``cuda`` is true. ``forward`` writes the
    channel-first embedding (multiplied by "E_mask") to
    ``container["embedding"]`` and a clone of the table to
    ``container["embed_map"]``.
    """

    def __init__(self, embed_path: Path, cuda=True) -> None:
        super().__init__()
        table = torch.load(embed_path, map_location=torch.device("cpu"))
        self.embed_map = table
        if cuda:
            self.embed_map = tops.to_cuda(self.embed_map)

    def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        vertex_ids = container["vertices"]
        if vertex_ids.ndim == 3:
            # Unbatched input: drop dim 0 of the lookup result, then move the last axis to the front.
            looked_up = self.embed_map[vertex_ids.long()].squeeze(dim=0)
            embedding = looked_up.permute(2, 0, 1) * container["E_mask"]
        else:
            assert vertex_ids.ndim == 4
            # Batched input: drop dim 1 of the lookup result, then move the last axis to position 1.
            looked_up = self.embed_map[vertex_ids.long()].squeeze(dim=1)
            embedding = looked_up.permute(0, 3, 1, 2) * container["E_mask"]
        container["embedding"] = embedding
        container["embed_map"] = self.embed_map.clone()
        return container
168
+
169
+
170
class InsertJointMap(torch.nn.Module):
    """
    Rasterizes keypoints into a single-channel skeleton map.

    ``forward`` adds ``batch["joint_map"]``: a uint8 tensor of shape
    (N, 1, H, W) with (H, W) = ``imsize``. Background pixels are 0 and each
    drawn limb uses ``category + 1`` (head=1, left arm=2, right arm=3,
    body=4, left leg=5, right leg=6).
    """

    def __init__(self, imsize: Tuple) -> None:
        super().__init__()
        self.imsize = imsize
        # The keypoint names are extended with two synthetic joints computed in create_joint_map.
        knames = get_coco_keypoints()[0]
        knames = knames + ["neck", "mid_hip"]
        # Directed edge list: each key joint is connected to every joint in its value list.
        connectivity = {
            "nose": ["left_eye", "right_eye", "neck"],
            "left_eye": ["right_eye", "left_ear"],
            "right_eye": ["right_ear"],
            "left_shoulder": ["right_shoulder", "left_elbow", "left_hip"],
            "right_shoulder": ["right_elbow", "right_hip"],
            "left_elbow": ["left_wrist"],
            "right_elbow": ["right_wrist"],
            "left_hip": ["right_hip", "left_knee"],
            "right_hip": ["right_knee"],
            "left_knee": ["left_ankle"],
            "right_knee": ["right_ankle"],
            "neck": ["mid_hip", "nose"],
        }
        # Body-part category of every edge; the drawn pixel value is category + 1.
        category = {
            ("nose", "left_eye"): 0,  # head
            ("nose", "right_eye"): 0,  # head
            ("nose", "neck"): 0,  # head
            ("left_eye", "right_eye"): 0,  # head
            ("left_eye", "left_ear"): 0,  # head
            ("right_eye", "right_ear"): 0,  # head
            ("left_shoulder", "left_elbow"): 1,  # left arm
            ("left_elbow", "left_wrist"): 1,  # left arm
            ("right_shoulder", "right_elbow"): 2,  # right arm
            ("right_elbow", "right_wrist"): 2,  # right arm
            ("left_shoulder", "right_shoulder"): 3,  # body
            ("left_shoulder", "left_hip"): 3,  # body
            ("right_shoulder", "right_hip"): 3,  # body
            ("left_hip", "right_hip"): 3,  # body
            ("left_hip", "left_knee"): 4,  # left leg
            ("left_knee", "left_ankle"): 4,  # left leg
            ("right_hip", "right_knee"): 5,  # right leg
            ("right_knee", "right_ankle"): 5,  # right leg
            ("neck", "mid_hip"): 3,  # body
            ("neck", "nose"): 0,  # head
        }
        # Same tables, keyed by keypoint index instead of name.
        self.indices2category = {
            tuple([knames.index(n) for n in k]): v for k, v in category.items()
        }
        self.connectivity_indices = {
            knames.index(k): [knames.index(v_) for v_ in v]
            for k, v in connectivity.items()
        }
        self.l_shoulder = knames.index("left_shoulder")
        self.r_shoulder = knames.index("right_shoulder")
        self.l_hip = knames.index("left_hip")
        self.r_hip = knames.index("right_hip")
        self.l_eye = knames.index("left_eye")
        self.r_eye = knames.index("right_eye")
        self.nose = knames.index("nose")
        self.neck = knames.index("neck")

    def create_joint_map(self, N, H, W, keypoints):
        """
        Draw one skeleton map per sample.

        Args:
            N: Number of samples (first dimension of the output).
            H: Output height in pixels.
            W: Output width in pixels.
            keypoints: Iterable of per-sample tensors of shape (17, 3). The last
                column is treated as visibility (> 0 means visible). Columns 0
                and 1 are scaled by W and H, so they are presumably normalized
                x / y coordinates - TODO confirm against the data loader.

        Returns:
            uint8 numpy array of shape (N, 1, H, W).
        """
        joint_maps = np.zeros((N, H, W), dtype=np.uint8)
        # NOTE: the loop variable deliberately reuses the name `keypoints` for the per-sample tensor.
        for bidx, keypoints in enumerate(keypoints):
            assert keypoints.shape == (17, 3), keypoints.shape
            # Append two zero rows for the synthetic "neck" and "mid_hip" joints
            # (torch.cat allocates a new tensor, so the caller's data is not modified below).
            keypoints = torch.cat((keypoints, torch.zeros(2, 3)))
            visible = keypoints[:, -1] > 0

            # Neck = midpoint of the shoulders, only when both are visible.
            if visible[self.l_shoulder] and visible[self.r_shoulder]:
                neck = (keypoints[self.l_shoulder]
                        + (keypoints[self.r_shoulder] - keypoints[self.l_shoulder]) / 2)
                keypoints[-2] = neck
                visible[-2] = 1
            # Mid-hip = midpoint of the hips, only when both are visible.
            if visible[self.l_hip] and visible[self.r_hip]:
                mhip = (keypoints[self.l_hip]
                        + (keypoints[self.r_hip] - keypoints[self.l_hip]) / 2
                        )
                keypoints[-1] = mhip
                visible[-1] = 1

            # Scale the first two columns to pixel units.
            keypoints[:, 0] *= W
            keypoints[:, 1] *= H
            joint_map = Image.fromarray(np.zeros((H, W), dtype=np.uint8))
            draw = ImageDraw.Draw(joint_map)
            # Later lines overwrite earlier ones where they overlap, so iteration order matters.
            for fidx in self.connectivity_indices.keys():
                for tidx in self.connectivity_indices[fidx]:
                    if visible[fidx] == 0 or visible[tidx] == 0:
                        continue
                    c = self.indices2category[(fidx, tidx)]
                    s = tuple(keypoints[fidx, :2].round().long().numpy().tolist())
                    e = tuple(keypoints[tidx, :2].round().long().numpy().tolist())
                    draw.line((s, e), width=1, fill=c + 1)
            # Fallback head line when the nose is missing: eye midpoint -> neck.
            # NOTE(review): the visibility of the two eyes is not checked here - confirm that
            # invisible eyes cannot produce a line from a meaningless position.
            if visible[self.nose] == 0 and visible[self.neck] == 1:
                m_eye = (
                    keypoints[self.l_eye]
                    + (keypoints[self.r_eye] - keypoints[self.l_eye]) / 2
                )
                s = tuple(m_eye[:2].round().long().numpy().tolist())
                e = tuple(keypoints[self.neck, :2].round().long().numpy().tolist())
                c = self.indices2category[(self.nose, self.neck)]
                draw.line((s, e), width=1, fill=c + 1)
            joint_map = np.array(joint_map)

            joint_maps[bidx] = np.array(joint_map)
        # Insert a channel axis: (N, H, W) -> (N, 1, H, W).
        return joint_maps[:, None]

    def forward(self, batch):
        """Add ``batch["joint_map"]`` built from ``batch["keypoints"]``; N is taken from ``batch["img"]``."""
        batch["joint_map"] = torch.from_numpy(self.create_joint_map(
            batch["img"].shape[0], *self.imsize, batch["keypoints"]))
        return batch
deep_privacy/dp2/data/utils.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from PIL import Image
3
+ import numpy as np
4
+ import multiprocessing
5
+ import io
6
+ from tops import logger
7
+ from torch.utils.data._utils.collate import default_collate
8
+
9
+ try:
10
+ import pyspng
11
+
12
+ PYSPNG_IMPORTED = True
13
+ except ImportError:
14
+ PYSPNG_IMPORTED = False
15
+ print("Could not load pyspng. Defaulting to pillow image backend.")
16
+ from PIL import Image
17
+
18
+
19
def get_fdf_keypoints():
    """Return the FDF keypoint names: the first seven COCO keypoint names."""
    coco_names = get_coco_keypoints()
    return coco_names[:7]
21
+
22
+
23
def get_fdf_flipmap():
    """Return, for each FDF keypoint index, the index of its mirrored keypoint.

    Left/right keypoints map to each other; the nose maps to itself.
    """
    names = get_fdf_keypoints()
    left_to_right = {
        "left_eye": "right_eye",
        "left_ear": "right_ear",
        "left_shoulder": "right_shoulder",
    }
    # Make the mapping symmetric, then add the self-mapped nose.
    mirrored = dict(left_to_right)
    mirrored.update({right: left for left, right in left_to_right.items()})
    mirrored["nose"] = "nose"
    return [names.index(mirrored[name]) for name in names]
37
+
38
+
39
def get_coco_keypoints():
    """Return the 17 COCO keypoint names in index order.

    Index 0 is the nose; every following (odd, even) index pair is the
    left/right variant of one body part.
    """
    left_right_pairs = (
        ("left_eye", "right_eye"),  # 1, 2
        ("left_ear", "right_ear"),  # 3, 4
        ("left_shoulder", "right_shoulder"),  # 5, 6
        ("left_elbow", "right_elbow"),  # 7, 8
        ("left_wrist", "right_wrist"),  # 9, 10
        ("left_hip", "right_hip"),  # 11, 12
        ("left_knee", "right_knee"),  # 13, 14
        ("left_ankle", "right_ankle"),  # 15, 16
    )
    names = ["nose"]  # 0
    for pair in left_right_pairs:
        names.extend(pair)
    return names
59
+
60
+
61
def get_coco_flipmap():
    """Return, for each COCO keypoint index, the index of its mirrored keypoint.

    Left/right keypoints map to each other; the nose maps to itself.
    """
    names = get_coco_keypoints()
    left_to_right = {
        "left_eye": "right_eye",
        "left_ear": "right_ear",
        "left_shoulder": "right_shoulder",
        "left_elbow": "right_elbow",
        "left_wrist": "right_wrist",
        "left_hip": "right_hip",
        "left_knee": "right_knee",
        "left_ankle": "right_ankle",
    }
    # Make the mapping symmetric, then add the self-mapped nose.
    mirrored = dict(left_to_right)
    mirrored.update({right: left for left, right in left_to_right.items()})
    mirrored["nose"] = "nose"
    return [names.index(mirrored[name]) for name in names]
80
+
81
+
82
def mask_decoder(x):
    """Decode an encoded mask image into a boolean tensor.

    Args:
        x: Raw bytes of an image file that Pillow can open.

    Returns:
        A boolean tensor holding the decoded pixels with size-1 dimensions
        squeezed out and a new leading dimension of size 1. Entries are
        True wherever the decoded pixel value is greater than zero.
    """
    # Open inside a context manager so the Pillow image is closed as soon
    # as its pixels have been copied into the numpy array.
    with Image.open(io.BytesIO(x)) as im:
        mask = torch.from_numpy(np.array(im)).squeeze()[None]
    # Binarise: without this, mask.float().max() could be 255 instead of 1.
    return mask > 0
86
+
87
+
88
def png_decoder(x):
    """Decode PNG bytes into a tensor with the third array axis moved first.

    Uses pyspng when it was importable at module load, otherwise Pillow
    (converting to RGB).
    """
    if not PYSPNG_IMPORTED:
        with Image.open(io.BytesIO(x)) as pil_image:
            pixels = np.array(pil_image.convert("RGB"))
        return torch.from_numpy(np.rollaxis(pixels, 2))
    decoded = pyspng.load(x)
    return torch.from_numpy(np.rollaxis(decoded, 2))
94
+
95
+
96
def jpg_decoder(x):
    """Decode image bytes with Pillow into an RGB tensor, third axis first."""
    with Image.open(io.BytesIO(x)) as pil_image:
        pixels = np.array(pil_image.convert("RGB"))
    # Move the last (colour) axis of the H x W x 3 array to the front.
    return torch.from_numpy(np.rollaxis(pixels, 2))
100
+
101
+
102
def get_num_workers(num_workers: int):
    """Return ``num_workers`` capped at the machine's CPU count.

    A warning is logged whenever the requested value has to be reduced.
    """
    available = multiprocessing.cpu_count()
    if num_workers <= available:
        return num_workers
    logger.warn(f"Setting the number of workers to match cpu count: {available}")
    return available
108
+
109
+
110
def collate_fn(batch):
    """Collate a list of sample dicts into one batch dict.

    Every key of the first sample is stacked with ``default_collate``,
    except ``"embed_map"`` and ``"vertx2cat"``: those are copied from the
    first sample only (not stacked across the batch).
    """
    first = batch[0]
    passthrough_keys = ("embed_map", "vertx2cat")
    collated = {}
    for key in first:
        if key in passthrough_keys:
            continue
        collated[key] = default_collate([sample[key] for sample in batch])
    # Pass-through entries go last, taken from the first sample as-is.
    for key in passthrough_keys:
        if key in first:
            collated[key] = first[key]
    return collated
deep_privacy/dp2/detection/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .cse_mask_face_detector import CSeMaskFaceDetector
2
+ from .person_detector import CSEPersonDetector
3
+ from .structures import PersonDetection, VehicleDetection, FaceDetection
deep_privacy/dp2/detection/base.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pickle
2
+ import torch
3
+ import lzma
4
+ from pathlib import Path
5
+ from tops import logger
6
+
7
+
8
class BaseDetector:
    """Base class that wraps ``forward`` with an optional on-disk cache.

    Subclasses are expected to provide ``forward(im)``; it is called here
    but not defined. Cached detections are stored lzma-compressed, one file
    per ``cache_id``, under ``<cache_directory>/<ClassName>/``.
    """

    # Class-level default so the attribute always exists, even when the
    # constructor is given ``None`` (caching disabled).
    cache_directory = None

    def __init__(self, cache_directory: str) -> None:
        """Create the per-class cache folder unless ``cache_directory`` is None."""
        if cache_directory is not None:
            self.cache_directory = Path(cache_directory, str(self.__class__.__name__))
            self.cache_directory.mkdir(exist_ok=True, parents=True)

    def save_to_cache(self, detection, cache_path: Path, after_preprocess=True):
        """Serialise the ``state_dict`` of every detection to ``cache_path``."""
        logger.log(f"Caching detection to: {cache_path}")
        states = [det.state_dict(after_preprocess=after_preprocess) for det in detection]
        with lzma.open(cache_path, "wb") as fp:
            torch.save(states, fp, pickle_protocol=pickle.HIGHEST_PROTOCOL)

    def load_from_cache(self, cache_path: Path):
        """Rebuild detections from ``cache_path``.

        Each stored state carries its own class under the ``"cls"`` key,
        and that class's ``from_state_dict`` restores the detection.
        """
        logger.log(f"Loading detection from cache path: {cache_path}")
        with lzma.open(cache_path, "rb") as fp:
            # NOTE(security): torch.load unpickles the file; only load cache
            # files that this program wrote itself.
            state_dict = torch.load(fp)
        return [
            state["cls"].from_state_dict(state_dict=state) for state in state_dict
        ]

    def forward_and_cache(self, im: torch.Tensor, cache_id: str, load_cache: bool):
        """Run ``forward`` on ``im``, reading/writing the cache when possible.

        Args:
            im: Input passed straight to ``forward``.
            cache_id: File stem of the cache entry. ``None`` disables
                caching for this call.
            load_cache: If true, an existing cache file is returned instead
                of recomputing.

        Returns:
            The detections, either loaded from the cache or freshly computed.
        """
        # Without an id or a configured cache folder there is nothing to
        # read or write; previously a missing folder raised AttributeError.
        if cache_id is None or self.cache_directory is None:
            return self.forward(im)
        cache_path = self.cache_directory.joinpath(cache_id + ".torch")
        if load_cache and cache_path.is_file():
            try:
                return self.load_from_cache(cache_path)
            except Exception as e:
                # An unreadable entry is treated as a cache miss: report the
                # cause and recompute (the file is overwritten below)
                # instead of terminating the whole process.
                logger.warn(
                    f"The cache file was corrupted: {cache_path} ({e!r}). Recomputing."
                )
        detections = self.forward(im)
        self.save_to_cache(detections, cache_path)
        return detections
deep_privacy/dp2/detection/box_utils.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+
4
def expand_bbox_to_ratio(bbox, imshape, target_aspect_ratio):
    """Grow one axis of ``bbox`` so height / width matches the target ratio.

    Coordinates are truncated to ints first. A box that already has the
    target ratio is returned as a list; otherwise the shorter axis is
    expanded with ``expand_axis`` and a tuple ``(x0, y0, x1, y1)`` is
    returned. The result may extend beyond ``imshape``.
    """
    left, top, right, bottom = (int(coord) for coord in bbox)
    box_h = bottom - top
    box_w = right - left
    ratio = box_h / box_w

    if ratio == target_aspect_ratio:
        return [left, top, right, bottom]
    if ratio < target_aspect_ratio:
        # Too wide for the target: grow vertically.
        wanted_h = int(box_w * target_aspect_ratio)
        top, bottom = expand_axis(top, bottom, wanted_h, imshape[0])
    else:
        # Too tall for the target: grow horizontally.
        wanted_w = int(box_h / target_aspect_ratio)
        left, right = expand_axis(left, right, wanted_w, imshape[1])
    return left, top, right, bottom
18
+
19
+
20
def expand_axis(start, end, target_width, limit):
    """Resize the interval ``[start, end)`` to ``target_width`` around its centre.

    The interval is then shifted to fit inside ``[0, limit]`` where a shift
    can help. The result can still extend outside the limit, e.g. when it
    overflows on both sides.

    Returns:
        ``(start, end)`` with ``end - start == target_width``.
    """
    half_growth = (target_width - (end - start)) // 2
    lo = start - half_growth
    hi = end + half_growth
    if hi - lo != target_width:
        # Odd difference: the floor division left one unit missing.
        hi += 1
    assert hi - lo == target_width

    if lo < 0 and hi > limit:
        # Overflows on both sides; shifting cannot fix that.
        return lo, hi
    if lo < 0 and hi < limit:
        # Slide right, as far as the free space at the end allows.
        shift = min(0 - lo, limit - hi)
        lo += shift
        hi += shift
    if hi > limit and lo > 0:
        # Slide left, as far as the free space at the start allows.
        shift = min(hi - limit, lo)
        hi -= shift
        lo -= shift
    assert hi - lo == target_width
    return lo, hi
40
+
41
+
42
def expand_box(bbox, imshape, mask, percentage_background: float):
    """Enlarge ``bbox`` until enough of it is background.

    If the share of the box not covered by ``mask`` already exceeds
    ``percentage_background`` the box is returned unchanged (same object).
    Otherwise both axes are scaled by the same factor with ``expand_axis``
    and a new list ``[x0, y0, x1, y1]`` is returned.

    ``mask`` must support ``.long().sum().cpu()``; ``bbox[0]`` must be a
    Python int and ``percentage_background`` must lie strictly in (0, 1).
    """
    assert isinstance(bbox[0], int)
    assert 0 < percentage_background < 1
    foreground_pixels = mask.long().sum().cpu()
    box_area = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
    foreground_share = foreground_pixels / box_area
    if (1 - foreground_share) > percentage_background:
        return bbox
    # Box area at which the mask leaves exactly the requested background.
    wanted_area = foreground_pixels / (1 - percentage_background)
    x0, y0, x1, y1 = bbox
    box_h = y1 - y0
    box_w = x1 - x0
    scale = np.sqrt(wanted_area / (box_h * box_w))
    new_w = int(np.ceil(scale * box_w))
    new_h = int(np.ceil(scale * box_h))
    x0, x1 = expand_axis(x0, x1, new_w, imshape[1])
    y0, y1 = expand_axis(y0, y1, new_h, imshape[0])
    return [x0, y0, x1, y1]
61
+
62
+
63
def expand_axises_by_percentage(bbox_XYXY, imshape, percentage):
    """Grow both axes of a box by the same absolute number of pixels.

    The growth is ``percentage`` times the square root of the box area,
    truncated to an int. Each new side length is capped at the image size
    on that axis before being handed to ``expand_axis``.

    Returns:
        A list ``[x0, y0, x1, y1]``.
    """
    x0, y0, x1, y1 = bbox_XYXY
    box_h = y1 - y0
    box_w = x1 - x0
    growth = int(((box_h * box_w) ** 0.5) * percentage)
    wanted_w = min(box_w + growth, imshape[1])
    wanted_h = min(box_h + growth, imshape[0])
    x0, x1 = expand_axis(x0, x1, wanted_w, imshape[1])
    y0, y1 = expand_axis(y0, y1, wanted_h, imshape[0])
    return [x0, y0, x1, y1]
73
+
74
+
75
def get_expanded_bbox(
        bbox_XYXY,
        imshape,
        mask,
        percentage_background: float,
        axis_minimum_expansion: float,
        target_aspect_ratio: float):
    """Run the three-stage box expansion and return the resulting box.

    ``bbox_XYXY`` must support ``.long().cpu().numpy().tolist()``. The
    stages are applied in this order:

    1. grow each axis by a minimum percentage,
    2. grow one axis to reach ``target_aspect_ratio`` (may leave the image),
    3. grow further until ``percentage_background`` of the box is background.
    """
    box = bbox_XYXY.long().cpu().numpy().tolist()
    box = expand_axises_by_percentage(box, imshape, axis_minimum_expansion)
    box = expand_bbox_to_ratio(box, imshape, target_aspect_ratio)
    box = expand_box(box, imshape, mask, percentage_background)
    assert isinstance(box[0], (int, np.int64))
    return box
91
+
92
+
93
def include_box(bbox, minimum_area, aspect_ratio_range, min_bbox_ratio_inside, imshape):
    """Decide whether a bounding box passes the size, shape and visibility filters.

    Args:
        bbox: Box as ``(x0, y0, x1, y1)``.
        minimum_area: Boxes with a smaller area are rejected.
        aspect_ratio_range: ``(low, high)``; the height / width ratio must
            lie strictly between the two values.
        min_bbox_ratio_inside: Minimum fraction of the box area that must
            lie inside the image.
        imshape: Image shape; index 0 is the height and index 1 the width.

    Returns:
        True if the box passes every filter, False otherwise. Boxes with a
        non-positive width or height are always rejected.
    """
    width = bbox[2] - bbox[0]
    height = bbox[3] - bbox[1]
    # A degenerate box has no area; reject it rather than divide by zero.
    if width <= 0 or height <= 0:
        return False
    area = width * height

    # Clamp each overlap at zero. Otherwise a box lying outside the image
    # on both axes yields two negative overlaps whose product is positive,
    # and the box would be counted as inside.
    inside_width = max(0, min(bbox[2], imshape[1]) - max(0, bbox[0]))
    inside_height = max(0, min(imshape[0], bbox[3]) - max(0, bbox[1]))
    if (inside_width * inside_height) / area < min_bbox_ratio_inside:
        return False

    ratio = height / width
    if ratio <= aspect_ratio_range[0] or ratio >= aspect_ratio_range[1] or area < minimum_area:
        return False
    return True
deep_privacy/dp2/detection/box_utils_fdf.py ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ The FDF dataset expands bounding boxes differently from what is used for CSE.
3
+ """
4
+
5
+ import numpy as np
6
+
7
+
8
def quadratic_bounding_box(x0, y0, width, height, imshape):
    """Turn a box into a square by extending its shorter side.

    Both sides must already fit the shorter image side. The shorter side is
    grown to the length of the longer one, centred where possible, and the
    box is shifted back if it then runs past the right or bottom edge.

    Returns:
        ``(x0, y0, width, height)`` of the squared box.

    Raises:
        AssertionError: If a side exceeds the shorter image side, or the
            resulting box does not lie inside the image.
    """
    # We assume a square can be formed without shrinking either side.
    shortest_image_side = min(imshape[:2])
    assert width <= shortest_image_side
    assert height <= shortest_image_side
    if height != width:
        gap = abs(height - width)
        if height < width:
            # Height is the shorter side: extend vertically.
            height += gap
            if height > imshape[0]:
                # Take the full frame height and shrink the width instead.
                y0 = 0
                height = imshape[0]

                gap = abs(height - width)
                width -= gap
                x0 += gap // 2
            else:
                y0 -= gap // 2
                y0 = max(0, y0)
        else:
            # Width is the shorter side: extend horizontally.
            width += gap
            if width > imshape[1]:
                # Take the full frame width and shrink the height instead.
                x0 = 0
                width = imshape[1]

                gap = abs(height - width)
                height -= gap
                y0 += gap // 2
            else:
                x0 -= gap // 2
                x0 = max(0, x0)
        # Shift the box back if it runs past the right or bottom edge.
        right = x0 + width
        bottom = y0 + height
        if imshape[1] < right:
            x0 -= right - imshape[1]
        if imshape[0] < bottom:
            y0 -= bottom - imshape[0]
    assert x0 >= 0, "Bounding box outside image."
    assert y0 >= 0, "Bounding box outside image."
    assert x0 + width <= imshape[1], "Bounding box outside image."
    assert y0 + height <= imshape[0], "Bounding box outside image."
    return x0, y0, width, height
59
+
60
+
61
def expand_bounding_box(bbox, percentage, imshape):
    """Square a box and grow it by a fraction of its side, inside the image.

    Args:
        bbox: ``(x0, y0, x1, y1)``; must provide ``.copy()``.
        percentage: Requested growth per side, as a fraction of the longer
            side of the squared box.
        imshape: Image shape; index 0 is the height and index 1 the width.

    Returns:
        ``np.array([x0, y0, x1, y1])`` with integer coordinates.

    Raises:
        AssertionError: If the box cannot be squared inside the image, or
            the expanded box is not square, leaves the image, or no longer
            contains the original box.
    """
    orig_bbox = bbox.copy()
    x0, y0, x1, y1 = bbox
    width = x1 - x0
    height = y1 - y0
    x0, y0, width, height = quadratic_bounding_box(
        x0, y0, width, height, imshape)
    expanding_factor = int(max(height, width) * percentage)

    # Cap the growth by the room left on each axis. Each image dimension is
    # paired with the matching box side (the box is square at this point).
    possible_max_expansion = [(imshape[0] - height) // 2,
                              (imshape[1] - width) // 2,
                              expanding_factor]
    expanding_factor = min(possible_max_expansion)

    if expanding_factor > 0:
        # Expand height
        y0 = y0 - expanding_factor
        y0 = max(0, y0)

        height += expanding_factor * 2
        if height > imshape[0]:
            y0 -= (imshape[0] - height)
            height = imshape[0]

        if height + y0 > imshape[0]:
            y0 -= (height + y0 - imshape[0])

        # Expand width
        x0 = x0 - expanding_factor
        x0 = max(0, x0)

        width += expanding_factor * 2
        if width > imshape[1]:
            x0 -= (imshape[1] - width)
            width = imshape[1]

        if width + x0 > imshape[1]:
            x0 -= (width + x0 - imshape[1])
    y1 = y0 + height
    x1 = x0 + width
    assert y0 >= 0, "Y0 is negative."
    assert height <= imshape[0], "Height is larger than image."
    assert x0 + width <= imshape[1]
    assert y0 + height <= imshape[0]
    assert width == height, "HEIGHT IS NOT EQUAL WIDTH!!"
    assert x0 >= 0, "X0 is negative."
    assert width <= imshape[1], "Width is larger than image."
    # Check that the original bbox lies within the new one.
    x0_o, y0_o, x1_o, y1_o = orig_bbox
    assert x0 <= x0_o, f"Original bbox is outside of new bbox. O:{x0_o}, N: {x0}"
    assert x1 >= x1_o, f"Original bbox is outside of new bbox. O:{x1_o}, N: {x1}"
    assert y0 <= y0_o, f"Original bbox is outside of new bbox. O:{y0_o}, N: {y0}"
    assert y1 >= y1_o, f"Original bbox is outside of new bbox. O:{y1_o}, N: {y1}"

    x0, y0, width, height = [int(_) for _ in [x0, y0, width, height]]
    x1 = x0 + width
    y1 = y0 + height
    return np.array([x0, y0, x1, y1])
121
+
122
+
123
def is_keypoint_within_bbox(x0, y0, x1, y1, keypoint):
    """Check that the first three keypoints lie inside the box (edges included).

    ``keypoint`` is indexed as ``[coordinate, keypoint_index]``: row 0 holds
    the x values and row 1 the y values. Only the first three columns are
    inspected (nose and eyes, per the original author's note).
    """
    xs = keypoint[0, :3]
    ys = keypoint[1, :3]
    inside_x = np.all(xs >= x0) and np.all(xs <= x1)
    inside_y = np.all(ys >= y0) and np.all(ys <= y1)
    return inside_x and inside_y
130
+
131
+
132
def expand_bbox_simple(bbox, percentage):
    """Return a square box centred on ``bbox``, enlarged by ``percentage``.

    The side length is the longer side of ``bbox`` times
    ``1 + percentage``. The result is not clipped to any image, so it may
    contain negative coordinates.

    Returns:
        An integer numpy array ``[x0, y0, x1, y1]``.
    """
    left, top, right, bottom = bbox.astype(float)
    box_w = right - left
    box_h = bottom - top
    centre_x = int(left) + box_w // 2
    centre_y = int(top) + box_h // 2
    side = max(box_w, box_h) * (1 + percentage)
    half_side = side // 2
    corners = [
        centre_x - half_side,
        centre_y - half_side,
        centre_x + half_side,
        centre_y + half_side,
    ]
    return np.array(corners).astype(int)
145
+
146
+
147
def pad_image(im, bbox, pad_value):
    """Crop ``bbox`` from ``im``, padding wherever the box leaves the image.

    Padding blocks are created as uint8 zeros plus ``pad_value``. ``im`` is
    indexed as (rows, columns, channels) whenever padding is needed.

    Returns:
        The crop ``im[y0:y1, x0:x1]`` taken from the padded image.
    """
    def filler(image, rows, cols):
        # Block of ``pad_value`` with the same channel count as ``image``.
        return np.zeros((rows, cols, image.shape[2]), dtype=np.uint8) + pad_value

    left, top, right, bottom = bbox
    if left < 0:
        # Pad on the left and move the box into the new coordinates.
        shift = abs(left)
        im = np.concatenate((filler(im, im.shape[0], shift), im), axis=1)
        right += shift
        left = 0
    if top < 0:
        # Pad on top and move the box into the new coordinates.
        shift = abs(top)
        im = np.concatenate((filler(im, shift, im.shape[1]), im), axis=0)
        bottom += shift
        top = 0
    if right >= im.shape[1]:
        extra_cols = right - im.shape[1] + 1
        im = np.concatenate((im, filler(im, im.shape[0], extra_cols)), axis=1)
    if bottom >= im.shape[0]:
        extra_rows = bottom - im.shape[0] + 1
        im = np.concatenate((im, filler(im, extra_rows, im.shape[1])), axis=0)
    return im[top:bottom, left:right]
172
+
173
+
174
def clip_box(bbox, im):
    """Clamp ``bbox`` in place to the valid pixel indices of ``im``.

    The lower corner is limited to 0 and the upper corner to the last
    row/column index. The same (mutated) ``bbox`` object is returned.
    """
    rows, cols = im.shape[0], im.shape[1]
    bbox[0] = max(0, bbox[0])
    bbox[1] = max(0, bbox[1])
    bbox[2] = min(cols - 1, bbox[2])
    bbox[3] = min(rows - 1, bbox[3])
    return bbox
180
+
181
+
182
def cut_face(im, bbox, simple_expand=False, pad_value=0, pad_im=True):
    """Crop ``bbox`` out of ``im``, padding or clipping at the image border.

    The crop is padded with ``pad_value`` when ``simple_expand`` is set, or
    when the box leaves the image and ``pad_im`` is true. Otherwise the box
    is clipped to the image (mutating ``bbox``) before cropping.

    ``bbox`` must support elementwise comparison (``bbox < 0``).
    """
    leaves_image = (bbox < 0).any() or bbox[2] > im.shape[1] or bbox[3] > im.shape[0]
    use_padding = simple_expand or (leaves_image and pad_im)
    if not use_padding:
        left, top, right, bottom = clip_box(bbox, im)
        return im[top:bottom, left:right]
    return pad_image(im, bbox, pad_value)
189
+
190
+
191
def expand_bbox(
        bbox_ltrb, imshape, simple_expand, default_to_simple=False,
        expansion_factor=0.35):
    """Expand a face box using either the simple or the in-image strategy.

    With ``simple_expand`` the box is enlarged by a fixed 0.4 (the FDF256
    setting) without regard to the image border. Otherwise the in-image
    expansion is attempted with ``expansion_factor``; if any of its
    assertions fail, the simple expansion is used with twice that factor.

    ``default_to_simple`` is accepted but not used by this function.
    """
    assert bbox_ltrb.shape == (4,), f"BBox shape was: {bbox_ltrb.shape}"
    box = bbox_ltrb.astype(float)
    if not simple_expand:
        try:
            return expand_bounding_box(box, expansion_factor, imshape)
        except AssertionError:
            return expand_bbox_simple(box, expansion_factor * 2)
    # FDF256 uses simple expand with ratio 0.4
    return expand_bbox_simple(box, 0.4)