Spaces: Runtime error

Commit: fix

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full change set.
- app.py +3 -3
- deep_privacy/.gitignore +54 -0
- deep_privacy/CHANGELOG.md +13 -0
- deep_privacy/Dockerfile +47 -0
- deep_privacy/LICENSE +201 -0
- deep_privacy/anonymize.py +255 -0
- deep_privacy/attribute_guided_demo.py +144 -0
- deep_privacy/configs/anonymizers/FB_cse.py +28 -0
- deep_privacy/configs/anonymizers/FB_cse_mask.py +29 -0
- deep_privacy/configs/anonymizers/FB_cse_mask_face.py +29 -0
- deep_privacy/configs/anonymizers/deep_privacy1.py +15 -0
- deep_privacy/configs/anonymizers/face.py +17 -0
- deep_privacy/configs/anonymizers/face_fdf128.py +18 -0
- deep_privacy/configs/anonymizers/market1501/blackout.py +8 -0
- deep_privacy/configs/anonymizers/market1501/person.py +6 -0
- deep_privacy/configs/anonymizers/market1501/pixelation16.py +8 -0
- deep_privacy/configs/anonymizers/market1501/pixelation8.py +8 -0
- deep_privacy/configs/datasets/coco_cse.py +69 -0
- deep_privacy/configs/datasets/fdf128.py +24 -0
- deep_privacy/configs/datasets/fdf256.py +55 -0
- deep_privacy/configs/datasets/fdh.py +90 -0
- deep_privacy/configs/datasets/utils.py +21 -0
- deep_privacy/configs/defaults.py +53 -0
- deep_privacy/configs/discriminators/sg2_discriminator.py +43 -0
- deep_privacy/configs/fdf/deep_privacy1.py +9 -0
- deep_privacy/configs/fdf/stylegan.py +14 -0
- deep_privacy/configs/fdf/stylegan_fdf128.py +17 -0
- deep_privacy/configs/fdh/styleganL.py +16 -0
- deep_privacy/configs/fdh/styleganL_nocse.py +14 -0
- deep_privacy/configs/generators/stylegan_unet.py +22 -0
- deep_privacy/dp2/__init__.py +0 -0
- deep_privacy/dp2/anonymizer/__init__.py +1 -0
- deep_privacy/dp2/anonymizer/anonymizer.py +163 -0
- deep_privacy/dp2/anonymizer/histogram_match_anonymizers.py +93 -0
- deep_privacy/dp2/data/__init__.py +0 -0
- deep_privacy/dp2/data/build.py +40 -0
- deep_privacy/dp2/data/datasets/__init__.py +0 -0
- deep_privacy/dp2/data/datasets/coco_cse.py +68 -0
- deep_privacy/dp2/data/datasets/fdf.py +128 -0
- deep_privacy/dp2/data/datasets/fdf128_wds.py +96 -0
- deep_privacy/dp2/data/datasets/fdh.py +142 -0
- deep_privacy/dp2/data/transforms/__init__.py +2 -0
- deep_privacy/dp2/data/transforms/functional.py +57 -0
- deep_privacy/dp2/data/transforms/stylegan2_transform.py +394 -0
- deep_privacy/dp2/data/transforms/transforms.py +277 -0
- deep_privacy/dp2/data/utils.py +122 -0
- deep_privacy/dp2/detection/__init__.py +3 -0
- deep_privacy/dp2/detection/base.py +42 -0
- deep_privacy/dp2/detection/box_utils.py +104 -0
- deep_privacy/dp2/detection/box_utils_fdf.py +202 -0
app.py CHANGED
@@ -9,15 +9,15 @@ os.system("pip install ftfy regex tqdm")
 os.system("pip install --no-deps git+https://github.com/openai/CLIP.git")
 os.system("pip install git+https://github.com/facebookresearch/detectron2@96c752ce821a3340e27edd51c28a00665dd32a30#subdirectory=projects/DensePose")
 os.system("pip install --no-deps git+https://github.com/hukkelas/DSFD-Pytorch-Inference")
-sys.path.insert(0, Path(os.getcwd(), "
+sys.path.insert(0, Path(os.getcwd(), "deep_privacy"))
 os.environ["TORCH_HOME"] = "torch_home"
 from dp2 import utils
 from gradio_demos.modules import ExampleDemo, WebcamDemo
 
-cfg_face = utils.load_config("
+cfg_face = utils.load_config("deep_privacy/configs/anonymizers/face.py")
 for key in ["person_G_cfg", "cse_person_G_cfg", "face_G_cfg", "car_G_cfg"]:
     if key in cfg_face.anonymizer:
-        cfg_face.anonymizer[key] = Path("
+        cfg_face.anonymizer[key] = Path("deep_privacy", cfg_face.anonymizer[key])
 
 
 anonymizer_face = instantiate(cfg_face.anonymizer, load_cache=False)
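The change points the app at the vendored DeepPrivacy2 checkout: the bundled deep_privacy directory is added to sys.path, and both the anonymizer config path and its generator-config entries are prefixed with deep_privacy/ so they resolve from the Space root instead of from inside the repository. Below is a minimal sketch of that path rewrite, not part of the commit; the helper name is hypothetical and cfg_face.anonymizer is treated as a dict-like config node, as in the diff.

    from pathlib import Path

    def prefix_generator_configs(anonymizer_cfg, repo_dir="deep_privacy"):
        # Generator configs inside the repo are written relative to its root;
        # prefix them so they resolve from the Space's working directory instead.
        for key in ["person_G_cfg", "cse_person_G_cfg", "face_G_cfg", "car_G_cfg"]:
            if key in anonymizer_cfg:
                anonymizer_cfg[key] = Path(repo_dir, anonymizer_cfg[key])
        return anonymizer_cfg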
deep_privacy/.gitignore ADDED
@@ -0,0 +1,54 @@
# FILES
*.yaml
*.pkl
*.flist
*.zip
*.out
*.npy
*.gz
*.ckpt
*.pth
*.log
*.pyc
*.csv
*.yml
*.ods
*.ods#
*.json
build_docker.sh

# Images / Videos
#*.png
#*.jpg
*.jpeg
*.m4a
*.mkv
*.mp4

# Directories created by inpaintron
.cache/
test_examples/
.vscode
__pycache__
.debug/
**/.ipynb_checkpoints/**
outputs/


# From pip setup
build/
*.egg-info
*.egg
.npm/

# From dockerfile
.bash_history
.viminfo
.local/
*.pickle
*.onnx


sbatch_files/
figures/
image_dump/
deep_privacy/CHANGELOG.md ADDED
@@ -0,0 +1,13 @@
# Changelog

## 23.03.2023
- Quality of life improvements
- Add support for refined keypoints for the FDH dataset.
- Add FDF128 dataset loader with webdataset.
- Support for using detector and anonymizer from DeepPrivacy1.
- Update visualization of keypoints
- Fix bug for upsampling/downsampling in the anonymization pipeline.
- Support for keypoint-guided face anonymization.
- Add ViTPose + Mask-RCNN detection model for keypoint-guided full-body anonymization.
- Set caching of detections to False as default, as it can produce unexpected behaviour. For example, using a different score threshold requires re-run of detector.
- Add Gradio Demos for face and body anonymization
deep_privacy/Dockerfile ADDED
@@ -0,0 +1,47 @@
FROM nvcr.io/nvidia/pytorch:22.08-py3
ARG UID=1000
ARG UNAME=testuser
ARG WANDB_API_KEY
RUN useradd -ms /bin/bash -u $UID $UNAME && \
    mkdir -p /home/${UNAME} &&\
    chown -R $UID /home/${UNAME}
WORKDIR /home/${UNAME}
ENV DEBIAN_FRONTEND="noninteractive"
ENV WANDB_API_KEY=$WANDB_API_KEY
ENV TORCH_HOME=/home/${UNAME}/.cache

# OPTIONAL - DeepPrivacy2 uses these environment variables to set directories outside the current working directory
#ENV BASE_DATASET_DIR=/work/haakohu/datasets
#ENV BASE_OUTPUT_DIR=/work/haakohu/outputs
#ENV FBA_METRICS_CACHE=/work/haakohu/metrics_cache

RUN apt-get update && apt-get install ffmpeg libsm6 libxext6 qt5-default -y
RUN pip install git+https://github.com/facebookresearch/detectron2@96c752ce821a3340e27edd51c28a00665dd32a30#subdirectory=projects/DensePose
COPY setup.py setup.py
RUN pip install \
    numpy>=1.20 \
    matplotlib \
    cython \
    tensorboard \
    tqdm \
    ninja==1.10.2 \
    opencv-python==4.5.5.64 \
    moviepy \
    pyspng \
    git+https://github.com/hukkelas/DSFD-Pytorch-Inference \
    wandb \
    termcolor \
    git+https://github.com/hukkelas/torch_ops.git \
    git+https://github.com/wmuron/motpy@c77f85d27e371c0a298e9a88ca99292d9b9cbe6b \
    fast_pytorch_kmeans \
    einops_exts \
    einops \
    regex \
    setuptools==59.5.0 \
    resize_right==0.0.2 \
    pillow \
    scipy==1.7.1 \
    webdataset==0.2.26 \
    scikit-image \
    timm==0.6.7
RUN pip install --no-deps torch_fidelity==0.3.0 clip@git+https://github.com/openai/CLIP.git@b46f5ac7587d2e1862f8b7b1573179d80dcdd620
deep_privacy/LICENSE ADDED
@@ -0,0 +1,201 @@
Apache License, Version 2.0, January 2004 (http://www.apache.org/licenses/): the full, unmodified standard license text, comprising Sections 1-9 (definitions, copyright and patent grants, redistribution conditions, submission of contributions, trademarks, warranty disclaimer, limitation of liability, accepting warranty or additional liability) and the appendix boilerplate notice, with the "[yyyy] [name of copyright owner]" placeholders left unfilled.
deep_privacy/anonymize.py ADDED
@@ -0,0 +1,255 @@
import hashlib
from typing import Optional
import click
import tops
import numpy as np
import tqdm
import moviepy.editor as mp
import cv2
from tops.config import instantiate
from pathlib import Path
from PIL import Image
from dp2 import utils
from detectron2.data.detection_utils import _apply_exif_orientation
from tops import logger
from dp2.utils.bufferless_video_capture import BufferlessVideoCapture


def show_video(video_path):
    video_cap = cv2.VideoCapture(str(video_path))
    while video_cap.isOpened():
        ret, frame = video_cap.read()
        cv2.imshow("Frame", frame)
        key = cv2.waitKey(25)
        if key == ord("q"):
            break
    video_cap.release()
    cv2.destroyAllWindows()


class ImageIndexTracker:

    def __init__(self, fn) -> None:
        self.fn = fn
        self.idx = 0

    def fl_image(self, frame):
        self.idx += 1
        return self.fn(frame, self.idx-1)


def anonymize_video(
        video_path, output_path: Path,
        anonymizer, visualize: bool, max_res: int,
        start_time: int, fps: int,
        end_time: int,
        visualize_detection: bool,
        track: bool,
        synthesis_kwargs,
        **kwargs):
    video = mp.VideoFileClip(str(video_path))
    if track:
        anonymizer.initialize_tracker(video.fps)

    def process_frame(frame, idx):
        frame = np.array(resize(Image.fromarray(frame), max_res))
        cache_id = hashlib.md5(frame).hexdigest()
        frame = utils.im2torch(frame, to_float=False, normalize=False)[0]
        cache_id_ = cache_id + str(idx)
        synthesis_kwargs["cache_id"] = cache_id_
        if visualize_detection:
            anonymized = anonymizer.visualize_detection(frame, cache_id=cache_id_)
        else:
            anonymized = anonymizer(frame, **synthesis_kwargs)
        anonymized = utils.im2numpy(anonymized)
        if visualize:
            cv2.imshow("frame", anonymized[:, :, ::-1])
            key = cv2.waitKey(1)
            if key == ord("q"):
                exit()
        return anonymized
    video: mp.VideoClip = video.subclip(start_time, end_time)

    if fps is not None:
        video = video.set_fps(fps)

    video = video.fl_image(ImageIndexTracker(process_frame).fl_image)
    if str(output_path).endswith(".avi"):
        output_path = str(output_path).replace(".avi", ".mp4")
    if not output_path.parent.exists():
        output_path.parent.mkdir(parents=True)
    video.write_videofile(str(output_path))


def resize(frame: Image.Image, max_res):
    if max_res is None:
        return frame
    f = max(*[x/max_res for x in frame.size], 1)
    if f == 1:
        return frame
    new_shape = [int(x/f) for x in frame.size]
    return frame.resize(new_shape, resample=Image.BILINEAR)


def anonymize_image(
        image_path, output_path: Path, visualize: bool,
        anonymizer, max_res: int,
        visualize_detection: bool,
        synthesis_kwargs,
        **kwargs):
    with Image.open(image_path) as im:
        im = _apply_exif_orientation(im)
        orig_im_mode = im.mode

        im = im.convert("RGB")
        im = resize(im, max_res)
    im = np.array(im)
    md5_ = hashlib.md5(im).hexdigest()
    im = utils.im2torch(np.array(im), to_float=False, normalize=False)[0]
    synthesis_kwargs["cache_id"] = md5_
    if visualize_detection:
        im_ = anonymizer.visualize_detection(tops.to_cuda(im), cache_id=md5_)
    else:
        im_ = anonymizer(im, **synthesis_kwargs)
    im_ = utils.im2numpy(im_)
    if visualize:
        while True:
            cv2.imshow("frame", im_[:, :, ::-1])
            key = cv2.waitKey(0)
            if key == ord("q"):
                break
            elif key == ord("u"):
                im_ = utils.im2numpy(anonymizer(im, **synthesis_kwargs))
    im = Image.fromarray(im_).convert(orig_im_mode)
    if output_path is not None:
        output_path.parent.mkdir(exist_ok=True, parents=True)
        im.save(output_path, optimize=False, quality=100)
        print(f"Saved to: {output_path}")


def anonymize_file(input_path: Path, output_path: Optional[Path], **kwargs):
    if output_path is not None and output_path.is_file():
        logger.warn(f"Overwriting previous file: {output_path}")
    if tops.is_image(input_path):
        anonymize_image(input_path, output_path, **kwargs)
    elif tops.is_video(input_path):
        anonymize_video(input_path, output_path, **kwargs)
    else:
        logger.log(f"Filepath not a video or image file: {input_path}")


def anonymize_directory(input_dir: Path, output_dir: Path, **kwargs):
    for childname in tqdm.tqdm(input_dir.iterdir()):
        childpath = input_dir.joinpath(childname.name)
        output_path = output_dir.joinpath(childname.name)
        if not childpath.is_file():
            anonymize_directory(childpath, output_path, **kwargs)
        else:
            assert childpath.is_file()
            anonymize_file(childpath, output_path, **kwargs)

def anonymize_webcam(
        anonymizer, max_res: int,
        synthesis_kwargs,
        visualize_detection,
        track: bool,
        **kwargs):
    import time
    cap = BufferlessVideoCapture(0, width=1920, height=1080)
    t = time.time()
    frames = 0
    if track:
        anonymizer.initialize_tracker(fps=5)  # FPS used for tracking objects
    while True:
        # Capture frame-by-frame
        ret, frame = cap.read()
        frame = Image.fromarray(frame[:, :, ::-1])
        frame = resize(frame, max_res)
        frame = np.array(frame)
        im = utils.im2torch(np.array(frame), to_float=False, normalize=False)[0]
        if visualize_detection:
            im_ = anonymizer.visualize_detection(tops.to_cuda(im))
        else:
            im_ = anonymizer(im, **synthesis_kwargs)
        im_ = utils.im2numpy(im_)

        frames += 1
        delta = time.time() - t
        fps = "?"
        if delta > 1e-6:
            fps = frames / delta
        print(f"FPS: {fps:.3f}", end="\r")
        cv2.imshow('frame', im_[:, :, ::-1])
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break


@click.command()
@click.argument("config_path", type=click.Path(exists=True))
@click.option("-i", "--input_path", help="Input path. Accepted inputs: images, videos, directories.")
@click.option("-o", "--output_path", default=None, type=click.Path(), help="Output path to save. Can be directory or file.")
@click.option("-v","--visualize", default=False, is_flag=True, help="Visualize the result")
@click.option("--max-res", default=None, type=int, help="Maximum resolution of height/wideo")
@click.option("--start-time", "--st", default=0, type=int, help="Start time (second) for vide anonymization")
@click.option("--end-time", "--et", default=None, type=int, help="End time (second) for vide anonymization")
@click.option("--fps", default=None, type=int, help="FPS for anonymization")
@click.option("--detection-score-threshold", "--dst", default=.3, type=click.FloatRange(0, 1), help="Detection threshold, threshold applied for all detection models.")
@click.option("--visualize-detection", "--vd",default=False, is_flag=True, help="Visualize only detections without running anonymization.")
@click.option("--multi-modal-truncation", "--mt", default=False, is_flag=True, help="Enable multi-modal truncation proposed by: https://arxiv.org/pdf/2202.12211.pdf")
@click.option("--cache", default=False, is_flag=True, help="Enable detection caching. Will save and load detections from cache.")
@click.option("--amp", default=True, is_flag=True, help="Use automatic mixed precision for generator forward pass")
@click.option("-t", "--truncation_value", default=0, type=click.FloatRange(0, 1), help="Latent interpolation truncation value.")
@click.option("--track", default=False, is_flag=True, help="Track detections over frames. Will use the same latent variable (z) for tracked identities.")
@click.option("--seed", default=0, type=int, help="Set random seed for generating images.")
@click.option("--person-generator", default=None, help="Config path to unconditional person generator", type=click.Path())
@click.option("--cse-person-generator", default=None, help="Config path to CSE-guided person generator", type=click.Path())
@click.option("--webcam", default=False, is_flag=True, help="Read image from webcam feed.")
def anonymize_path(
        config_path,
        input_path,
        output_path,
        detection_score_threshold: float,
        visualize_detection: bool,
        cache: bool,
        seed: int,
        person_generator: str,
        cse_person_generator: str,
        webcam: bool,
        **kwargs):
    """
    config_path: Specify the path to the anonymization model to use.
    """
    tops.set_seed(seed)
    cfg = utils.load_config(config_path)
    if person_generator is not None:
        cfg.anonymizer.person_G_cfg = person_generator
    if cse_person_generator is not None:
        cfg.anonymizer.cse_person_G_cfg = cse_person_generator
    cfg.detector.score_threshold = detection_score_threshold
    utils.print_config(cfg)

    anonymizer = instantiate(cfg.anonymizer, load_cache=cache)
    synthesis_kwargs = ["amp", "multi_modal_truncation", "truncation_value"]
    synthesis_kwargs = {k: kwargs.pop(k) for k in synthesis_kwargs}

    kwargs["anonymizer"] = anonymizer
    kwargs["visualize_detection"] = visualize_detection
    kwargs["synthesis_kwargs"] = synthesis_kwargs
    if webcam:
        anonymize_webcam(**kwargs)
        return
    input_path = Path(input_path)
    output_path = Path(output_path) if output_path is not None else None
    if output_path is None and not kwargs["visualize"]:
        logger.log("Output path not set. Setting visualize to True")
        kwargs["visualize"] = True
    if input_path.is_dir():
        assert output_path is None or not output_path.is_file()
        anonymize_directory(input_path, output_path, **kwargs)
    else:
        anonymize_file(input_path, output_path, **kwargs)


if __name__ == "__main__":

    anonymize_path()
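anonymize.py wires together the same pieces the Gradio demos use: load a lazy config, instantiate the anonymizer, convert the image to a CHW uint8 tensor, and run the generator. Below is a minimal programmatic sketch of that flow, not part of the commit, assuming it is run from inside the deep_privacy/ directory and that example.jpg is a hypothetical input file.

    import numpy as np
    from PIL import Image
    from tops.config import instantiate
    from dp2 import utils

    # Load the face-anonymization config and build the detector + generator pipeline.
    cfg = utils.load_config("configs/anonymizers/face.py")
    cfg.detector.score_threshold = 0.3  # same default as --detection-score-threshold above
    anonymizer = instantiate(cfg.anonymizer, load_cache=False)

    # HWC uint8 image -> CHW uint8 tensor, mirroring anonymize_image() above.
    im = np.array(Image.open("example.jpg").convert("RGB"))
    im = utils.im2torch(im, to_float=False, normalize=False)[0]
    out = anonymizer(im, amp=True, multi_modal_truncation=False, truncation_value=0)
    Image.fromarray(utils.im2numpy(out)).save("example_anonymized.jpg")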
deep_privacy/attribute_guided_demo.py ADDED
@@ -0,0 +1,144 @@
from collections import defaultdict
import gradio
import numpy as np
import torch
import cv2
from PIL import Image
from dp2 import utils
from tops.config import instantiate
import tops
import gradio.inputs
from stylemc import get_and_cache_direction, get_styles


class GuidedDemo:
    def __init__(self, face_anonymizer, cfg_face) -> None:
        self.anonymizer = face_anonymizer
        assert sum([x is not None for x in list(face_anonymizer.generators.values())]) == 1
        self.generator = [x for x in list(face_anonymizer.generators.values()) if x is not None][0]
        face_G_cfg = utils.load_config(cfg_face.anonymizer.face_G_cfg)
        face_G_cfg.train.batch_size = 1
        self.dl = instantiate(face_G_cfg.data.val.loader)
        self.cache_dir = face_G_cfg.output_dir
        self.precompute_edits()

    def precompute_edits(self):
        self.precomputed_edits = set()
        for edit in self.precomputed_edits:
            get_and_cache_direction(self.cache_dir, self.dl, self.generator, edit)
        if self.cache_dir.joinpath("stylemc_cache").is_dir():
            for path in self.cache_dir.joinpath("stylemc_cache").iterdir():
                text_prompt = path.stem.replace("_", " ")
                self.precomputed_edits.add(text_prompt)
                print(text_prompt)
        self.edits = defaultdict(defaultdict)

    def anonymize(self, img, show_boxes: bool, current_box_idx: int, current_styles, current_boxes, update_identity, edits, cache_id=None):
        if not isinstance(img, torch.Tensor):
            img, cache_id = pil2torch(img)
            img = tops.to_cuda(img)

        current_box_idx = current_box_idx % len(current_boxes)
        edited_styles = [s.clone() for s in current_styles]
        for face_idx, face_edits in edits.items():
            for prompt, strength in face_edits.items():
                direction = get_and_cache_direction(self.cache_dir, self.dl, self.generator, prompt)
                edited_styles[int(face_idx)] += direction * strength
            update_identity[int(face_idx)] = True
        assert img.dtype == torch.uint8
        img = self.anonymizer(
            img, truncation_value=0,
            multi_modal_truncation=True, amp=True,
            cache_id=cache_id,
            all_styles=edited_styles,
            update_identity=update_identity)
        update_identity = [True for i in range(len(update_identity))]
        img = utils.im2numpy(img)
        if show_boxes:
            x0, y0, x1, y1 = [int(_) for _ in current_boxes[int(current_box_idx)]]
            img = cv2.rectangle(img, (x0, y0), (x1, y1), (255, 0, 0), 1)
        return img, update_identity

    def update_image(self, img, show_boxes):
        img, cache_id = pil2torch(img)
        img = tops.to_cuda(img)
        det = self.anonymizer.detector.forward_and_cache(img, cache_id, load_cache=True)[0]
        current_styles = []
        for i in range(len(det)):
            # Need to do forward pass to register all affine modules.
            batch = det.get_crop(i, img)
            batch["condition"] = batch["img"].float()

            s = get_styles(
                np.random.randint(0, 999999), self.generator,
                batch, truncation_value=0)
            current_styles.append(s)
        update_identity = [True for i in range(len(det))]
        current_boxes = np.array(det.boxes)
        edits = defaultdict(defaultdict)
        cur_face_idx = -1 % len(current_boxes)
        img, update_identity = self.anonymize(img, show_boxes, cur_face_idx, current_styles, current_boxes, update_identity, edits, cache_id=cache_id)
        return img, current_styles, current_boxes, update_identity, edits, cur_face_idx

    def change_face(self, change, cur_face_idx, current_boxes, input_image, show_boxes, current_styles, update_identity, edits):
        cur_face_idx = (cur_face_idx+change) % len(current_boxes)
        img, update_identity = self.anonymize(input_image, show_boxes, cur_face_idx, current_styles, current_boxes, update_identity, edits)
        return img, update_identity, cur_face_idx

    def add_style(self, face_idx: int, prompt: str, strength: float, input_image, show_boxes, current_styles, current_boxes, update_identity, edits):
        face_idx = face_idx % len(current_boxes)
        edits[face_idx][prompt] = strength
        img, update_identity = self.anonymize(input_image, show_boxes, face_idx, current_styles, current_boxes, update_identity, edits)
        return img, update_identity, edits

    def setup_interface(self):
        current_styles = gradio.State()
        current_boxes = gradio.State(None)
        update_identity = gradio.State([])
        edits = gradio.State([])
        with gradio.Row():
            input_image = gradio.Image(
                type="pil", label="Upload your image or try the example below!", source="webcam")
            output_image = gradio.Image(type="numpy", label="Output")
        with gradio.Row():
            update_btn = gradio.Button("Update Anonymization").style(full_width=True)
        with gradio.Row():
            show_boxes = gradio.Checkbox(value=True, label="Show Selected")
            cur_face_idx = gradio.Number(value=-1, label="Current", interactive=False)
            previous = gradio.Button("Previous Person")
            next_ = gradio.Button("Next Person")
        with gradio.Row():
            text_prompt = gradio.Textbox(
                placeholder=" | ".join(list(self.precomputed_edits)),
                label="Text Prompt for Edit")
            edit_strength = gradio.Slider(0, 5, step=.01)
            add_btn = gradio.Button("Add Edit")
        add_btn.click(self.add_style, inputs=[cur_face_idx, text_prompt, edit_strength, input_image, show_boxes, current_styles, current_boxes, update_identity, edits], outputs=[output_image, update_identity, edits])
        update_btn.click(self.update_image, inputs=[input_image, show_boxes], outputs=[output_image, current_styles, current_boxes, update_identity, edits, cur_face_idx])
        input_image.change(self.update_image, inputs=[input_image, show_boxes], outputs=[output_image, current_styles, current_boxes, update_identity, edits, cur_face_idx])
        previous.click(self.change_face, inputs=[gradio.State(-1), cur_face_idx, current_boxes, input_image, show_boxes, current_styles, update_identity, edits], outputs=[output_image, update_identity, cur_face_idx])
        next_.click(self.change_face, inputs=[gradio.State(1), cur_face_idx, current_boxes, input_image, show_boxes, current_styles, update_identity, edits], outputs=[output_image, update_identity, cur_face_idx])

        show_boxes.change(self.anonymize, inputs=[input_image, show_boxes, cur_face_idx, current_styles, current_boxes, update_identity, edits], outputs=[output_image, update_identity])


def pil2torch(img: Image.Image):
    img = img.convert("RGB")
    img = np.array(img)
    img = np.rollaxis(img, 2)
    return torch.from_numpy(img), None


cfg_face = utils.load_config("configs/anonymizers/face.py")
anonymizer_face = instantiate(cfg_face.anonymizer, load_cache=False)
anonymizer_face.initialize_tracker(fps=1)


with gradio.Blocks() as demo:
    gradio.Markdown("# <center> DeepPrivacy2 - Realistic Image Anonymization </center>")
    gradio.Markdown("### <center> Håkon Hukkelås, Rudolf Mester, Frank Lindseth </center>")
    with gradio.Tab("Text-Guided Anonymization"):
        GuidedDemo(anonymizer_face, cfg_face).setup_interface()


demo.launch()
deep_privacy/configs/anonymizers/FB_cse.py ADDED
@@ -0,0 +1,28 @@
from dp2.anonymizer import Anonymizer
from dp2.detection.person_detector import CSEPersonDetector
from ..defaults import common
from tops.config import LazyCall as L
from dp2.generator.dummy_generators import MaskOutGenerator


maskout_G = L(MaskOutGenerator)(noise="constant")

detector = L(CSEPersonDetector)(
    mask_rcnn_cfg=dict(),
    cse_cfg=dict(),
    cse_post_process_cfg=dict(
        target_imsize=(288, 160),
        exp_bbox_cfg=dict(percentage_background=0.3, axis_minimum_expansion=.1),
        exp_bbox_filter=dict(minimum_area=32*32, min_bbox_ratio_inside=0, aspect_ratio_range=[0, 99999]),
        iou_combine_threshold=0.4,
        dilation_percentage=0.02,
        normalize_embedding=False
    ),
    score_threshold=0.3,
    cache_directory=common.output_dir.joinpath("cse_person_detection_cache")
)

anonymizer = L(Anonymizer)(
    detector="${detector}",
    cse_person_G_cfg="configs/fdh/styleganL.py",
)
deep_privacy/configs/anonymizers/FB_cse_mask.py ADDED
@@ -0,0 +1,29 @@
from dp2.anonymizer import Anonymizer
from dp2.detection.person_detector import CSEPersonDetector
from ..defaults import common
from tops.config import LazyCall as L
from dp2.generator.dummy_generators import MaskOutGenerator


maskout_G = L(MaskOutGenerator)(noise="constant")

detector = L(CSEPersonDetector)(
    mask_rcnn_cfg=dict(),
    cse_cfg=dict(),
    cse_post_process_cfg=dict(
        target_imsize=(288, 160),
        exp_bbox_cfg=dict(percentage_background=0.3, axis_minimum_expansion=.1),
        exp_bbox_filter=dict(minimum_area=32*32, min_bbox_ratio_inside=0, aspect_ratio_range=[0, 99999]),
        iou_combine_threshold=0.4,
        dilation_percentage=0.02,
        normalize_embedding=False
    ),
    score_threshold=0.3,
    cache_directory=common.output_dir.joinpath("cse_person_detection_cache")
)

anonymizer = L(Anonymizer)(
    detector="${detector}",
    person_G_cfg="configs/fdh/styleganL_nocse.py",
    cse_person_G_cfg="configs/fdh/styleganL.py",
)
deep_privacy/configs/anonymizers/FB_cse_mask_face.py ADDED
@@ -0,0 +1,29 @@
from dp2.anonymizer import Anonymizer
from dp2.detection.cse_mask_face_detector import CSeMaskFaceDetector
from ..defaults import common
from tops.config import LazyCall as L

detector = L(CSeMaskFaceDetector)(
    mask_rcnn_cfg=dict(),
    face_detector_cfg=dict(),
    face_post_process_cfg=dict(target_imsize=(256, 256), fdf128_expand=False),
    cse_cfg=dict(),
    cse_post_process_cfg=dict(
        target_imsize=(288, 160),
        exp_bbox_cfg=dict(percentage_background=0.3, axis_minimum_expansion=.1),
        exp_bbox_filter=dict(minimum_area=32*32, min_bbox_ratio_inside=0, aspect_ratio_range=[0, 99999]),
        iou_combine_threshold=0.4,
        dilation_percentage=0.02,
        normalize_embedding=False
    ),
    score_threshold=0.3,
    cache_directory=common.output_dir.joinpath("cse_mask_face_detection_cache")
)

anonymizer = L(Anonymizer)(
    detector="${detector}",
    face_G_cfg="configs/fdf/stylegan.py",
    person_G_cfg="configs/fdh/styleganL_nocse.py",
    cse_person_G_cfg="configs/fdh/styleganL.py",
    car_G_cfg="configs/generators/dummy/pixelation8.py"
)
deep_privacy/configs/anonymizers/deep_privacy1.py ADDED
@@ -0,0 +1,15 @@
from .face_fdf128 import anonymizer, common, detector
from dp2.detection.deep_privacy1_detector import DeepPrivacy1Detector
from tops.config import LazyCall as L

anonymizer.update(
    face_G_cfg="configs/fdf/deep_privacy1.py",
)

anonymizer.detector = L(DeepPrivacy1Detector)(
    face_detector_cfg=dict(name="DSFDDetector", clip_boxes=True),
    face_post_process_cfg=dict(target_imsize=(128, 128), fdf128_expand=True),
    score_threshold=0.3,
    keypoint_threshold=0.3,
    cache_directory=common.output_dir.joinpath("deep_privacy1_cache")
)
deep_privacy/configs/anonymizers/face.py ADDED
@@ -0,0 +1,17 @@
from dp2.anonymizer import Anonymizer
from dp2.detection.face_detector import FaceDetector
from ..defaults import common
from tops.config import LazyCall as L


detector = L(FaceDetector)(
    face_detector_cfg=dict(name="DSFDDetector", clip_boxes=True),
    face_post_process_cfg=dict(target_imsize=(256, 256), fdf128_expand=False),
    score_threshold=0.3,
    cache_directory=common.output_dir.joinpath("face_detection_cache"),
)

anonymizer = L(Anonymizer)(
    detector="${detector}",
    face_G_cfg="configs/fdf/stylegan.py",
)
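These anonymizer configs are lazy: L(...) only records the target class and its arguments, and nothing is constructed until instantiate is called, so fields can still be overridden beforehand (this is what anonymize.py does for --person-generator and --detection-score-threshold). Below is a small sketch of that pattern, not part of the commit, assuming the working directory is deep_privacy/.

    from tops.config import instantiate
    from dp2 import utils

    cfg = utils.load_config("configs/anonymizers/face.py")

    # Override lazily declared fields before anything is built.
    cfg.detector.score_threshold = 0.5
    cfg.anonymizer.face_G_cfg = "configs/fdf/stylegan.py"

    # instantiate() resolves the "${detector}" reference declared above and then
    # constructs the FaceDetector and Anonymizer objects.
    anonymizer = instantiate(cfg.anonymizer, load_cache=False)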
deep_privacy/configs/anonymizers/face_fdf128.py ADDED
@@ -0,0 +1,18 @@
from dp2.anonymizer import Anonymizer
from dp2.detection.face_detector import FaceDetector
from ..defaults import common
from tops.config import LazyCall as L


detector = L(FaceDetector)(
    face_detector_cfg=dict(name="DSFDDetector", clip_boxes=True),
    face_post_process_cfg=dict(target_imsize=(128, 128), fdf128_expand=True),
    score_threshold=0.3,
    cache_directory=common.output_dir.joinpath("face_detection_cache")
)


anonymizer = L(Anonymizer)(
    detector="${detector}",
    face_G_cfg="configs/fdf/stylegan_fdf128.py",
)
deep_privacy/configs/anonymizers/market1501/blackout.py ADDED
@@ -0,0 +1,8 @@
from ..FB_cse_mask_face import anonymizer, detector, common

detector.score_threshold = .1
detector.face_detector_cfg.confidence_threshold = .5
detector.cse_cfg.score_thres = 0.3
anonymizer.generators.face_G_cfg = None
anonymizer.generators.person_G_cfg = "configs/generators/dummy/maskout.py"
anonymizer.generators.cse_person_G_cfg = "configs/generators/dummy/maskout.py"
deep_privacy/configs/anonymizers/market1501/person.py ADDED
@@ -0,0 +1,6 @@
from ..FB_cse_mask_face import anonymizer, detector, common

detector.score_threshold = .1
detector.face_detector_cfg.confidence_threshold = .5
detector.cse_cfg.score_thres = 0.3
anonymizer.generators.face_G_cfg = None
deep_privacy/configs/anonymizers/market1501/pixelation16.py ADDED
@@ -0,0 +1,8 @@
from ..FB_cse_mask_face import anonymizer, detector, common

detector.score_threshold = .1
detector.face_detector_cfg.confidence_threshold = .5
detector.cse_cfg.score_thres = 0.3
anonymizer.generators.face_G_cfg = None
anonymizer.generators.person_G_cfg = "configs/generators/dummy/pixelation16.py"
anonymizer.generators.cse_person_G_cfg = "configs/generators/dummy/pixelation16.py"
deep_privacy/configs/anonymizers/market1501/pixelation8.py ADDED
@@ -0,0 +1,8 @@
from ..FB_cse_mask_face import anonymizer, detector, common

detector.score_threshold = .1
detector.face_detector_cfg.confidence_threshold = .5
detector.cse_cfg.score_thres = 0.3
anonymizer.generators.face_G_cfg = None
anonymizer.generators.person_G_cfg = "configs/generators/dummy/pixelation8.py"
anonymizer.generators.cse_person_G_cfg = "configs/generators/dummy/pixelation8.py"
deep_privacy/configs/datasets/coco_cse.py ADDED
@@ -0,0 +1,69 @@
import os
from pathlib import Path
from tops.config import LazyCall as L
import torch
import functools
from dp2.data.datasets.coco_cse import CocoCSE
from dp2.data.build import get_dataloader
from dp2.data.transforms.transforms import CreateEmbedding, Normalize, Resize, ToFloat, CreateCondition, RandomHorizontalFlip
from dp2.data.transforms.stylegan2_transform import StyleGANAugmentPipe
from dp2.metrics.torch_metrics import compute_metrics_iteratively
from .utils import final_eval_fn


dataset_base_dir = os.environ["BASE_DATASET_DIR"] if "BASE_DATASET_DIR" in os.environ else "data"
metrics_cache = os.environ["FBA_METRICS_CACHE"] if "FBA_METRICS_CACHE" in os.environ else ".cache"
data_dir = Path(dataset_base_dir, "coco_cse")
data = dict(
    imsize=(288, 160),
    im_channels=3,
    semantic_nc=26,
    cse_nc=16,
    train=dict(
        dataset=L(CocoCSE)(data_dir.joinpath("train"), transform=None, normalize_E=False),
        loader=L(get_dataloader)(
            shuffle=True, num_workers=6, drop_last=True, prefetch_factor=2,
            batch_size="${train.batch_size}",
            dataset="${..dataset}",
            infinite=True,
            gpu_transform=L(torch.nn.Sequential)(*[
                L(ToFloat)(),
                L(StyleGANAugmentPipe)(
                    rotate=0.5, rotate_max=.05,
                    xint=.5, xint_max=0.05,
                    scale=.5, scale_std=.05,
                    aniso=0.5, aniso_std=.05,
                    xfrac=.5, xfrac_std=.05,
                    brightness=.5, brightness_std=.05,
                    contrast=.5, contrast_std=.1,
                    hue=.5, hue_max=.05,
                    saturation=.5, saturation_std=.5,
                    imgfilter=.5, imgfilter_std=.1),
                L(RandomHorizontalFlip)(p=0.5),
                L(CreateEmbedding)(),
                L(Resize)(size="${data.imsize}"),
                L(Normalize)(mean=[.5, .5, .5], std=[.5, .5, .5], inplace=True),
                L(CreateCondition)(),
            ])
        )
    ),
    val=dict(
        dataset=L(CocoCSE)(data_dir.joinpath("val"), transform=None, normalize_E=False),
        loader=L(get_dataloader)(
            shuffle=False, num_workers=6, drop_last=True, prefetch_factor=2,
            batch_size="${train.batch_size}",
            dataset="${..dataset}",
            infinite=False,
            gpu_transform=L(torch.nn.Sequential)(*[
                L(ToFloat)(),
                L(CreateEmbedding)(),
                L(Resize)(size="${data.imsize}"),
                L(Normalize)(mean=[.5, .5, .5], std=[.5, .5, .5], inplace=True),
                L(CreateCondition)(),
            ])
        )
    ),
    # Training evaluation might do optimizations to reduce compute overhead. E.g. compute with AMP.
    train_evaluation_fn=functools.partial(compute_metrics_iteratively, cache_directory=Path(metrics_cache, "coco_cse_val"), include_two_fake=False),
    evaluation_fn=functools.partial(final_eval_fn, cache_directory=Path(metrics_cache, "coco_cse_val_final"), include_two_fake=True)
)
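The dataset configs read two optional environment variables before building any paths: BASE_DATASET_DIR (default "data") for where datasets live, and FBA_METRICS_CACHE (default ".cache") for cached metric statistics; the commented-out ENV lines in the Dockerfile set the same variables. A small sketch with hypothetical locations, not part of the commit:

    import os

    # Point the configs at external storage; the paths below are placeholders.
    os.environ["BASE_DATASET_DIR"] = "/data/datasets"        # e.g. /data/datasets/coco_cse, /data/datasets/fdf256
    os.environ["FBA_METRICS_CACHE"] = "/data/metrics_cache"  # cached metric statistics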
deep_privacy/configs/datasets/fdf128.py ADDED
@@ -0,0 +1,24 @@
from pathlib import Path
from functools import partial
from dp2.data.datasets.fdf import FDFDataset
from .fdf256 import data, dataset_base_dir, metrics_cache, final_eval_fn, train_eval_fn

data_dir = Path(dataset_base_dir, "fdf")
data.train.dataset.dirpath = data_dir.joinpath("train")
data.val.dataset.dirpath = data_dir.joinpath("val")
data.imsize = (128, 128)


data.train_evaluation_fn = partial(
    train_eval_fn, cache_directory=Path(metrics_cache, "fdf128_val_train"))
data.evaluation_fn = partial(
    final_eval_fn, cache_directory=Path(metrics_cache, "fdf128_val_final"))

data.train.dataset.update(
    _target_ = FDFDataset,
    imsize="${data.imsize}"
)
data.val.dataset.update(
    _target_ = FDFDataset,
    imsize="${data.imsize}"
)
deep_privacy/configs/datasets/fdf256.py
ADDED
@@ -0,0 +1,55 @@
import os
from pathlib import Path
from tops.config import LazyCall as L
import torch
import functools
from dp2.data.datasets.fdf import FDF256Dataset
from dp2.data.build import get_dataloader
from dp2.data.transforms.transforms import Normalize, Resize, ToFloat, CreateCondition, RandomHorizontalFlip
from .utils import final_eval_fn, train_eval_fn


dataset_base_dir = os.environ["BASE_DATASET_DIR"] if "BASE_DATASET_DIR" in os.environ else "data"
metrics_cache = os.environ["FBA_METRICS_CACHE"] if "FBA_METRICS_CACHE" in os.environ else ".cache"
data_dir = Path(dataset_base_dir, "fdf256")
data = dict(
    imsize=(256, 256),
    im_channels=3,
    semantic_nc=None,
    cse_nc=None,
    n_keypoints=None,
    train=dict(
        dataset=L(FDF256Dataset)(dirpath=data_dir.joinpath("train"), transform=None, load_keypoints=False),
        loader=L(get_dataloader)(
            shuffle=True, num_workers=3, drop_last=True, prefetch_factor=2,
            batch_size="${train.batch_size}",
            dataset="${..dataset}",
            infinite=True,
            gpu_transform=L(torch.nn.Sequential)(*[
                L(ToFloat)(),
                L(RandomHorizontalFlip)(p=0.5),
                L(Resize)(size="${data.imsize}"),
                L(Normalize)(mean=[.5, .5, .5], std=[.5, .5, .5], inplace=True),
                L(CreateCondition)(),
            ])
        )
    ),
    val=dict(
        dataset=L(FDF256Dataset)(dirpath=data_dir.joinpath("val"), transform=None, load_keypoints=False),
        loader=L(get_dataloader)(
            shuffle=False, num_workers=3, drop_last=False, prefetch_factor=2,
            batch_size="${train.batch_size}",
            dataset="${..dataset}",
            infinite=False,
            gpu_transform=L(torch.nn.Sequential)(*[
                L(ToFloat)(),
                L(Resize)(size="${data.imsize}"),
                L(Normalize)(mean=[.5, .5, .5], std=[.5, .5, .5], inplace=True),
                L(CreateCondition)(),
            ])
        )
    ),
    # Training evaluation might do optimizations to reduce compute overhead. E.g. compute with AMP.
    train_evaluation_fn=functools.partial(train_eval_fn, cache_directory=Path(metrics_cache, "fdf_val_train")),
    evaluation_fn=functools.partial(final_eval_fn, cache_directory=Path(metrics_cache, "fdf_val"))
)
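
The dataset configs above are lazy: nothing is built until the config is instantiated. Below is a minimal usage sketch, assuming that tops.config.instantiate resolves LazyCall nodes and that dp2.utils.load_config (imported by the anonymizer further down in this diff) accepts a config path; the config path and batch size are illustrative only.

# Usage sketch (illustrative values; load_config / instantiate assumed as described above).
from dp2.utils import load_config
from tops.config import instantiate

cfg = load_config("configs/fdf/stylegan.py")        # pulls in the fdf256 data config above
cfg.train.batch_size = 8                            # "${train.batch_size}" references resolve to this
train_loader = instantiate(cfg.data.train.loader)   # builds FDF256Dataset, get_dataloader and the GPU transforms
batch = next(iter(train_loader))                    # dict with "img", "mask", ...
print(batch["img"].shape, batch["mask"].shape)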
deep_privacy/configs/datasets/fdh.py
ADDED
@@ -0,0 +1,90 @@
import os
from pathlib import Path
from tops.config import LazyCall as L
import torch
import functools
from dp2.data.datasets.fdh import get_dataloader_fdh_wds
from dp2.data.utils import get_coco_flipmap
from dp2.data.transforms.transforms import (
    Normalize,
    ToFloat,
    CreateCondition,
    RandomHorizontalFlip,
    CreateEmbedding,
)
from dp2.metrics.torch_metrics import compute_metrics_iteratively
from dp2.metrics.fid_clip import compute_fid_clip
from dp2.metrics.ppl import calculate_ppl
from .utils import train_eval_fn


def final_eval_fn(*args, **kwargs):
    result = compute_metrics_iteratively(*args, **kwargs)
    result2 = calculate_ppl(*args, **kwargs, upsample_size=(288, 160))
    result3 = compute_fid_clip(*args, **kwargs)
    assert all(key not in result for key in result2)
    result.update(result2)
    result.update(result3)
    return result


def get_cache_directory(imsize, subset):
    return Path(metrics_cache, f"{subset}{imsize[0]}")


dataset_base_dir = (
    os.environ["BASE_DATASET_DIR"] if "BASE_DATASET_DIR" in os.environ else "data"
)
metrics_cache = (
    os.environ["FBA_METRICS_CACHE"] if "FBA_METRICS_CACHE" in os.environ else ".cache"
)
data_dir = Path(dataset_base_dir, "fdh")
data = dict(
    imsize=(288, 160),
    im_channels=3,
    cse_nc=16,
    n_keypoints=17,
    train=dict(
        loader=L(get_dataloader_fdh_wds)(
            path=data_dir.joinpath("train", "out-{000000..001423}.tar"),
            batch_size="${train.batch_size}",
            num_workers=6,
            transform=L(torch.nn.Sequential)(
                L(RandomHorizontalFlip)(p=0.5, flip_map=get_coco_flipmap()),
            ),
            gpu_transform=L(torch.nn.Sequential)(
                L(ToFloat)(norm=False, keys=["img", "mask", "E_mask", "maskrcnn_mask"]),
                L(CreateEmbedding)(embed_path=data_dir.joinpath("embed_map.torch")),
                L(Normalize)(mean=[0.5*255, 0.5*255, 0.5*255], std=[0.5*255, 0.5*255, 0.5*255], inplace=True),
                L(CreateCondition)(),
            ),
            infinite=True,
            shuffle=True,
            partial_batches=False,
            load_embedding=True,
            keypoints_split="train",
            load_new_keypoints=False
        )
    ),
    val=dict(
        loader=L(get_dataloader_fdh_wds)(
            path=data_dir.joinpath("val", "out-{000000..000023}.tar"),
            batch_size="${train.batch_size}",
            num_workers=6,
            transform=None,
            gpu_transform="${data.train.loader.gpu_transform}",
            infinite=False,
            shuffle=False,
            partial_batches=True,
            load_embedding=True,
            keypoints_split="val",
            load_new_keypoints="${data.train.loader.load_new_keypoints}"
        )
    ),
    # Training evaluation might do optimizations to reduce compute overhead. E.g. compute with AMP.
    train_evaluation_fn=L(functools.partial)(
        train_eval_fn, cache_directory=L(get_cache_directory)(imsize="${data.imsize}", subset="fdh"),
        data_len=30_000),
    evaluation_fn=L(functools.partial)(
        final_eval_fn, cache_directory=L(get_cache_directory)(imsize="${data.imsize}", subset="fdh_eval"),
        data_len=30_000)
)
deep_privacy/configs/datasets/utils.py
ADDED
@@ -0,0 +1,21 @@
from dp2.metrics.ppl import calculate_ppl
from dp2.metrics.torch_metrics import compute_metrics_iteratively
from dp2.metrics.fid_clip import compute_fid_clip


def final_eval_fn(*args, **kwargs):
    result = compute_metrics_iteratively(*args, **kwargs)
    result2 = calculate_ppl(*args, **kwargs)
    result3 = compute_fid_clip(*args, **kwargs)
    assert all(key not in result for key in result2)
    result.update(result2)
    result.update(result3)
    return result


def train_eval_fn(*args, **kwargs):
    result = compute_metrics_iteratively(*args, **kwargs)
    result2 = compute_fid_clip(*args, **kwargs)
    assert all(key not in result for key in result2)
    result.update(result2)
    return result
deep_privacy/configs/defaults.py
ADDED
@@ -0,0 +1,53 @@
import pathlib
import os
import torch
from tops.config import LazyCall as L

if "PRETRAINED_CHECKPOINTS_PATH" in os.environ:
    PRETRAINED_CHECKPOINTS_PATH = pathlib.Path(os.environ["PRETRAINED_CHECKPOINTS_PATH"])
else:
    PRETRAINED_CHECKPOINTS_PATH = pathlib.Path("pretrained_checkpoints")
if "BASE_OUTPUT_DIR" in os.environ:
    BASE_OUTPUT_DIR = pathlib.Path(os.environ["BASE_OUTPUT_DIR"])
else:
    BASE_OUTPUT_DIR = pathlib.Path("outputs")


common = dict(
    logger_backend=["wandb", "stdout", "json", "image_dumper"],
    wandb_project="deep_privacy2",
    output_dir=BASE_OUTPUT_DIR,
    experiment_name=None,  # Optional experiment name to show on wandb
)

train = dict(
    batch_size=32,
    seed=0,
    ims_per_log=1024,
    ims_per_val=int(200e3),
    max_images_to_train=int(12e6),
    amp=dict(
        enabled=True,
        scaler_D=L(torch.cuda.amp.GradScaler)(init_scale=2**16, growth_factor=4, growth_interval=100, enabled="${..enabled}"),
        scaler_G=L(torch.cuda.amp.GradScaler)(init_scale=2**16, growth_factor=4, growth_interval=100, enabled="${..enabled}"),
    ),
    fp16_ddp_accumulate=False,  # All gather gradients in fp16?
    broadcast_buffers=False,
    bias_act_plugin_enabled=True,
    grid_sample_gradfix_enabled=True,
    conv2d_gradfix_enabled=False,
    channels_last=False,
    compile_G=dict(
        enabled=False,
        mode="default"  # default, reduce-overhead or max-autotune
    ),
    compile_D=dict(
        enabled=False,
        mode="default"  # default, reduce-overhead or max-autotune
    )
)

# exponential moving average
EMA = dict(rampup=0.05)
deep_privacy/configs/discriminators/sg2_discriminator.py
ADDED
@@ -0,0 +1,43 @@
from tops.config import LazyCall as L
from dp2.discriminator import SG2Discriminator
import torch
from dp2.loss import StyleGAN2Loss


discriminator = L(SG2Discriminator)(
    imsize="${data.imsize}",
    im_channels="${data.im_channels}",
    min_fmap_resolution=4,
    max_cnum_mul=8,
    cnum=80,
    input_condition=True,
    conv_clamp=256,
    input_cse=False,
    cse_nc="${data.cse_nc}",
    fix_residual=False,
)


loss_fnc = L(StyleGAN2Loss)(
    lazy_regularization=True,
    lazy_reg_interval=16,
    r1_opts=dict(lambd=5, mask_out=False, mask_out_scale=False),
    EP_lambd=0.001,
    pl_reg_opts=dict(weight=0, batch_shrink=2, start_nimg=int(1e6), pl_decay=0.01)
)


def build_D_optim(type, lr, betas, lazy_regularization, lazy_reg_interval, **kwargs):
    if lazy_regularization:
        # From Analyzing and improving the image quality of stylegan, CVPR 2020
        c = lazy_reg_interval / (lazy_reg_interval + 1)
        betas = [beta ** c for beta in betas]
        lr *= c
        print(f"Lazy regularization on. Setting lr to: {lr}, betas to: {betas}")
    return type(lr=lr, betas=betas, **kwargs)


D_optim = L(build_D_optim)(
    type=torch.optim.Adam, lr=0.001, betas=(0.0, 0.99),
    lazy_regularization="${loss_fnc.lazy_regularization}",
    lazy_reg_interval="${loss_fnc.lazy_reg_interval}")
G_optim = L(torch.optim.Adam)(lr=0.001, betas=(0.0, 0.99))
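
A small worked example of the lazy-regularization rescaling performed by build_D_optim above, using lazy_reg_interval=16 and the Adam settings from D_optim (plain arithmetic, no values beyond those already in this file):

# Worked example of the rescaling in build_D_optim.
lazy_reg_interval = 16
lr, betas = 0.001, (0.0, 0.99)

c = lazy_reg_interval / (lazy_reg_interval + 1)   # 16 / 17, roughly 0.941
adjusted_betas = [beta ** c for beta in betas]    # [0.0, 0.99 ** 0.941, roughly 0.991]
adjusted_lr = lr * c                              # roughly 0.000941

print(adjusted_lr, adjusted_betas)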
deep_privacy/configs/fdf/deep_privacy1.py
ADDED
@@ -0,0 +1,9 @@
from tops.config import LazyCall as L
from dp2.generator.deep_privacy1 import MSGGenerator
from ..datasets.fdf128 import data
from ..defaults import common, train

generator = L(MSGGenerator)()

common.model_url = "https://folk.ntnu.no/haakohu/checkpoints/fdf128_model512.ckpt"
common.model_md5sum = "6cc8b285bdc1fcdfc64f5db7c521d0a6"
deep_privacy/configs/fdf/stylegan.py
ADDED
@@ -0,0 +1,14 @@
from ..generators.stylegan_unet import generator
from ..datasets.fdf256 import data
from ..discriminators.sg2_discriminator import discriminator, G_optim, D_optim, loss_fnc
from ..defaults import train, common, EMA

train.max_images_to_train = int(35e6)
G_optim.lr = 0.002
D_optim.lr = 0.002
generator.input_cse = False
loss_fnc.r1_opts.lambd = 1
train.ims_per_val = int(2e6)

common.model_url = "https://api.loke.aws.unit.no/dlr-gui-backend-resources-content/v2/contents/links/89660f04-5c11-4dbf-adac-cbe2f11b0aeea25cbf78-7558-475a-b3c7-03f5c10b7934646b0720-ca0a-4d53-aded-daddbfa45c9e"
common.model_md5sum = "e8e32190528af2ed75f0cb792b7f2b07"
deep_privacy/configs/fdf/stylegan_fdf128.py
ADDED
@@ -0,0 +1,17 @@
from ..discriminators.sg2_discriminator import discriminator, G_optim, D_optim, loss_fnc
from ..datasets.fdf128 import data
from ..generators.stylegan_unet import generator
from ..defaults import train, common, EMA
from tops.config import LazyCall as L

G_optim.lr = 0.002
D_optim.lr = 0.002
generator.update(cnum=128, max_cnum_mul=4, input_cse=False)
loss_fnc.r1_opts.lambd = 0.1

train.update(ims_per_val=int(2e6), batch_size=64, max_images_to_train=int(35e6))

common.update(
    model_url="https://api.loke.aws.unit.no/dlr-gui-backend-resources-content/v2/contents/links/66d803c0-55ce-44c0-9d53-815c2c0e6ba4eb458409-9e91-45d1-bce0-95c8a47a57218b102fdf-bea3-44dc-aac4-0fb1d370ef1c",
    model_md5sum="bccd4403e7c9bca682566ff3319e8176"
)
deep_privacy/configs/fdh/styleganL.py
ADDED
@@ -0,0 +1,16 @@
from tops.config import LazyCall as L
from ..generators.stylegan_unet import generator
from ..datasets.fdh import data
from ..discriminators.sg2_discriminator import discriminator, G_optim, D_optim, loss_fnc
from ..defaults import train, common, EMA

train.max_images_to_train = int(50e6)
train.batch_size = 64
G_optim.lr = 0.002
D_optim.lr = 0.002
data.train.loader.num_workers = 4
train.ims_per_val = int(1e6)
loss_fnc.r1_opts.lambd = .1

common.model_url = "https://api.loke.aws.unit.no/dlr-gui-backend-resources-content/v2/contents/links/21841da7-2546-4ce3-8460-909b3a63c58b13aac1a1-c778-4c8d-9b69-3e5ed2cde9de1524e76e-7aa6-4dd8-b643-52abc9f0792c"
common.model_md5sum = "3411478b5ec600a4219cccf4499732bd"
deep_privacy/configs/fdh/styleganL_nocse.py
ADDED
@@ -0,0 +1,14 @@
from tops.config import LazyCall as L
from ..generators.stylegan_unet import generator
from ..datasets.fdh import data
from ..discriminators.sg2_discriminator import discriminator, G_optim, D_optim, loss_fnc
from ..defaults import train, common, EMA

train.max_images_to_train = int(50e6)
G_optim.lr = 0.002
D_optim.lr = 0.002
generator.input_cse = False
data.load_embeddings = False
common.model_url = "https://folk.ntnu.no/haakohu/checkpoints/deep_privacy2/fdh_styleganL_nocse.ckpt"
common.model_md5sum = "fda0d809741bc67487abada793975c37"
generator.fix_errors = False
deep_privacy/configs/generators/stylegan_unet.py
ADDED
@@ -0,0 +1,22 @@
from dp2.generator.stylegan_unet import StyleGANUnet
from tops.config import LazyCall as L

generator = L(StyleGANUnet)(
    imsize="${data.imsize}",
    im_channels="${data.im_channels}",
    min_fmap_resolution=8,
    cnum=64,
    max_cnum_mul=8,
    n_middle_blocks=0,
    z_channels=512,
    mask_output=True,
    conv_clamp=256,
    input_cse=True,
    scale_grad=True,
    cse_nc="${data.cse_nc}",
    w_dim=512,
    n_keypoints="${data.n_keypoints}",
    input_keypoints=False,
    input_keypoint_indices=[],
    fix_errors=True
)
deep_privacy/dp2/__init__.py
ADDED
File without changes
deep_privacy/dp2/anonymizer/__init__.py
ADDED
@@ -0,0 +1 @@
from .anonymizer import Anonymizer
deep_privacy/dp2/anonymizer/anonymizer.py
ADDED
@@ -0,0 +1,163 @@
from pathlib import Path
from typing import Union, Optional
import numpy as np
import torch
import tops
import torchvision.transforms.functional as F
from motpy import Detection, MultiObjectTracker
from dp2.utils import load_config
from dp2.infer import build_trained_generator
from dp2.detection.structures import CSEPersonDetection, FaceDetection, PersonDetection, VehicleDetection


def load_generator_from_cfg_path(cfg_path: Union[str, Path]):
    cfg = load_config(cfg_path)
    G = build_trained_generator(cfg)
    tops.logger.log(f"Loaded generator from: {cfg_path}")
    return G


class Anonymizer:

    def __init__(
            self,
            detector,
            load_cache: bool = False,
            person_G_cfg: Optional[Union[str, Path]] = None,
            cse_person_G_cfg: Optional[Union[str, Path]] = None,
            face_G_cfg: Optional[Union[str, Path]] = None,
            car_G_cfg: Optional[Union[str, Path]] = None,
    ) -> None:
        self.detector = detector
        self.generators = {k: None for k in [CSEPersonDetection, PersonDetection, FaceDetection, VehicleDetection]}
        self.load_cache = load_cache
        if cse_person_G_cfg is not None:
            self.generators[CSEPersonDetection] = load_generator_from_cfg_path(cse_person_G_cfg)
        if person_G_cfg is not None:
            self.generators[PersonDetection] = load_generator_from_cfg_path(person_G_cfg)
        if face_G_cfg is not None:
            self.generators[FaceDetection] = load_generator_from_cfg_path(face_G_cfg)
        if car_G_cfg is not None:
            self.generators[VehicleDetection] = load_generator_from_cfg_path(car_G_cfg)

    def initialize_tracker(self, fps: float):
        self.tracker = MultiObjectTracker(dt=1/fps)
        self.track_to_z_idx = dict()

    def reset_tracker(self):
        self.track_to_z_idx = dict()

    def forward_G(self,
                  G,
                  batch,
                  multi_modal_truncation: bool,
                  amp: bool,
                  z_idx: int,
                  truncation_value: float,
                  idx: int,
                  all_styles=None):
        batch["img"] = F.normalize(batch["img"].float(), [0.5*255, 0.5*255, 0.5*255], [0.5*255, 0.5*255, 0.5*255])
        batch["img"] = batch["img"].float()
        batch["condition"] = batch["mask"].float() * batch["img"]

        with torch.cuda.amp.autocast(amp):
            z = None
            if z_idx is not None:
                state = np.random.RandomState(seed=z_idx[idx])
                z = state.normal(size=(1, G.z_channels)).astype(np.float32)
                z = tops.to_cuda(torch.from_numpy(z))

            if all_styles is not None:
                anonymized_im = G(**batch, s=iter(all_styles[idx]))["img"]
            elif multi_modal_truncation:
                w_indices = None
                if z_idx is not None:
                    w_indices = [z_idx[idx] % len(G.style_net.w_centers)]
                anonymized_im = G.multi_modal_truncate(
                    **batch, truncation_value=truncation_value,
                    w_indices=w_indices,
                    z=z
                )["img"]
            else:
                anonymized_im = G.sample(**batch, truncation_value=truncation_value, z=z)["img"]
        anonymized_im = (anonymized_im+1).div(2).clamp(0, 1).mul(255)
        return anonymized_im

    @torch.no_grad()
    def anonymize_detections(self,
                             im, detection,
                             update_identity=None,
                             **synthesis_kwargs
                             ):
        G = self.generators[type(detection)]
        if G is None:
            return im
        C, H, W = im.shape
        if update_identity is None:
            update_identity = [True for i in range(len(detection))]
        for idx in range(len(detection)):
            if not update_identity[idx]:
                continue
            batch = detection.get_crop(idx, im)
            x0, y0, x1, y1 = batch.pop("boxes")[0]
            batch = {k: tops.to_cuda(v) for k, v in batch.items()}
            anonymized_im = self.forward_G(G, batch, **synthesis_kwargs, idx=idx)

            gim = F.resize(anonymized_im[0], (y1-y0, x1-x0), interpolation=F.InterpolationMode.BICUBIC, antialias=True)
            mask = F.resize(batch["mask"][0], (y1-y0, x1-x0), interpolation=F.InterpolationMode.NEAREST).squeeze(0)
            # Remove padding
            pad = [max(-x0, 0), max(-y0, 0)]
            pad = [*pad, max(x1-W, 0), max(y1-H, 0)]
            def remove_pad(x): return x[..., pad[1]:x.shape[-2]-pad[3], pad[0]:x.shape[-1]-pad[2]]

            gim = remove_pad(gim)
            mask = remove_pad(mask) > 0.5
            x0, y0 = max(x0, 0), max(y0, 0)
            x1, y1 = min(x1, W), min(y1, H)
            mask = mask.logical_not()[None].repeat(3, 1, 1)

            im[:, y0:y1, x0:x1][mask] = gim[mask].round().clamp(0, 255).byte()
        return im

    def visualize_detection(self, im: torch.Tensor, cache_id: str = None) -> torch.Tensor:
        all_detections = self.detector.forward_and_cache(im, cache_id, load_cache=self.load_cache)
        im = im.cpu()
        for det in all_detections:
            im = det.visualize(im)
        return im

    @torch.no_grad()
    def forward(self, im: torch.Tensor, cache_id: str = None, track=True, detections=None, **synthesis_kwargs) -> torch.Tensor:
        assert im.dtype == torch.uint8
        im = tops.to_cuda(im)
        all_detections = detections
        if detections is None:
            if self.load_cache:
                all_detections = self.detector.forward_and_cache(im, cache_id)
            else:
                all_detections = self.detector(im)
        if hasattr(self, "tracker") and track:
            [_.pre_process() for _ in all_detections]
            boxes = np.concatenate([_.boxes for _ in all_detections])
            boxes = [Detection(box) for box in boxes]
            self.tracker.step(boxes)
            track_ids = self.tracker.detections_matched_ids
            z_idx = []
            for track_id in track_ids:
                if track_id not in self.track_to_z_idx:
                    self.track_to_z_idx[track_id] = np.random.randint(0, 2**32-1)
                z_idx.append(self.track_to_z_idx[track_id])
            z_idx = np.array(z_idx)
            idx_offset = 0

        for detection in all_detections:
            zs = None
            if hasattr(self, "tracker") and track:
                zs = z_idx[idx_offset:idx_offset+len(detection)]
                idx_offset += len(detection)
            im = self.anonymize_detections(im, detection, z_idx=zs, **synthesis_kwargs)

        return im.cpu()

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)
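
A minimal usage sketch of the Anonymizer class above. The detector object and the generator config path are placeholders (a real setup would instantiate them from the anonymizer configs in this repository), and the synthesis keyword arguments shown are the ones consumed by forward_G:

# Usage sketch -- the detector and the face generator config path are placeholders.
import torchvision
from dp2.anonymizer import Anonymizer

detector = ...  # an instantiated dp2 detector, e.g. built from configs/anonymizers/face.py
anonymizer = Anonymizer(detector, face_G_cfg="configs/fdf/stylegan.py")

im = torchvision.io.read_image("input.jpg")  # uint8 CHW tensor; forward() asserts uint8
anonymized = anonymizer(
    im,
    track=False,                 # call initialize_tracker(fps) first when processing video
    multi_modal_truncation=False,
    amp=True,
    truncation_value=0.5,
)
torchvision.io.write_png(anonymized, "anonymized.png")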
deep_privacy/dp2/anonymizer/histogram_match_anonymizers.py
ADDED
@@ -0,0 +1,93 @@
import torch
import tops
import numpy as np
from kornia.color import rgb_to_hsv
from dp2 import utils
from kornia.enhance import histogram
from .anonymizer import Anonymizer
import torchvision.transforms.functional as F
from skimage.exposure import match_histograms
from kornia.filters import gaussian_blur2d


class LatentHistogramMatchAnonymizer(Anonymizer):

    def forward_G(
            self,
            G,
            batch,
            multi_modal_truncation: bool,
            amp: bool,
            z_idx: int,
            truncation_value: float,
            idx: int,
            n_sampling_steps: int = 1,
            all_styles=None,
    ):
        batch["img"] = F.normalize(batch["img"].float(), [0.5*255, 0.5*255, 0.5*255], [0.5*255, 0.5*255, 0.5*255])
        batch["img"] = batch["img"].float()
        batch["condition"] = batch["mask"].float() * batch["img"]

        assert z_idx is None and all_styles is None, "Arguments not supported with n_sampling_steps > 1."
        real_hls = rgb_to_hsv(utils.denormalize_img(batch["img"]))
        real_hls[:, 0] /= 2 * torch.pi
        indices = [1, 2]
        hist_kwargs = dict(
            bins=torch.linspace(0, 1, 256, dtype=torch.float32, device=tops.get_device()),
            bandwidth=torch.tensor(1., device=tops.get_device()))
        real_hist = [histogram(real_hls[:, i].flatten(start_dim=1), **hist_kwargs) for i in indices]
        for j in range(n_sampling_steps):
            if j == 0:
                if multi_modal_truncation:
                    w = G.style_net.multi_modal_truncate(
                        truncation_value=truncation_value, **batch, w_indices=None).detach()
                else:
                    w = G.style_net.get_truncated(truncation_value, **batch).detach()
                assert z_idx is None and all_styles is None, "Arguments not supported with n_sampling_steps > 1."
                w.requires_grad = True
                optim = torch.optim.Adam([w])
            with torch.set_grad_enabled(True):
                with torch.cuda.amp.autocast(amp):
                    anonymized_im = G(**batch, truncation_value=None, w=w)["img"]
                fake_hls = rgb_to_hsv(anonymized_im*0.5 + 0.5)
                fake_hls[:, 0] /= 2 * torch.pi
                fake_hist = [histogram(fake_hls[:, i].flatten(start_dim=1), **hist_kwargs) for i in indices]
                dist = sum([utils.torch_wasserstein_loss(r, f) for r, f in zip(real_hist, fake_hist)])
                dist.backward()
                if w.grad.sum() == 0:
                    break
                assert w.grad.sum() != 0
                optim.step()
                optim.zero_grad()
            if dist < 0.02:
                break
        anonymized_im = (anonymized_im+1).div(2).clamp(0, 1).mul(255)
        return anonymized_im


class HistogramMatchAnonymizer(Anonymizer):

    def forward_G(self, batch, *args, **kwargs):
        rimg = batch["img"]
        batch["img"] = F.normalize(batch["img"].float(), [0.5*255, 0.5*255, 0.5*255], [0.5*255, 0.5*255, 0.5*255])
        batch["img"] = batch["img"].float()
        batch["condition"] = batch["mask"].float() * batch["img"]

        anonymized_im = super().forward_G(batch, *args, **kwargs)

        equalized_gim = match_histograms(tops.im2numpy(anonymized_im.round().clamp(0, 255).byte()), tops.im2numpy(rimg))
        if equalized_gim.dtype != np.uint8:
            equalized_gim = equalized_gim.astype(np.float32)
            assert equalized_gim.dtype == np.float32, equalized_gim.dtype
            equalized_gim = tops.im2torch(equalized_gim, to_float=False)[0]
        else:
            equalized_gim = tops.im2torch(equalized_gim, to_float=False).float()[0]
        equalized_gim = equalized_gim.to(device=rimg.device)
        assert equalized_gim.dtype == torch.float32
        gaussian_mask = 1 - (batch["maskrcnn_mask"][0].repeat(3, 1, 1) > 0.5).float()

        gaussian_mask = gaussian_blur2d(gaussian_mask[None], kernel_size=[19, 19], sigma=[10, 10])[0]
        gaussian_mask = gaussian_mask / gaussian_mask.max()
        anonymized_im = gaussian_mask * equalized_gim + (1-gaussian_mask) * anonymized_im
        return anonymized_im
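
LatentHistogramMatchAnonymizer optimizes w so that the saturation and value histograms of the generated crop match those of the real crop, scored with a Wasserstein distance between histograms. The sketch below shows one common way to compute such a 1D distance; this is an assumption about what utils.torch_wasserstein_loss computes, written with plain PyTorch rather than the repository helper:

# Sketch of a 1D Wasserstein distance between two normalized histograms over the
# same bins (assumed behaviour of utils.torch_wasserstein_loss, not the repo code).
import torch

def wasserstein_1d(hist_a: torch.Tensor, hist_b: torch.Tensor) -> torch.Tensor:
    # The 1D Wasserstein-1 distance is the area between the two CDFs.
    cdf_a = torch.cumsum(hist_a, dim=-1)
    cdf_b = torch.cumsum(hist_b, dim=-1)
    return (cdf_a - cdf_b).abs().mean(dim=-1)

a = torch.zeros(256); a[100:110] = 0.1   # toy histogram
b = torch.zeros(256); b[120:130] = 0.1   # same mass, shifted bins
print(wasserstein_1d(a, b))              # grows with the shift between the two distributions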
deep_privacy/dp2/data/__init__.py
ADDED
File without changes
deep_privacy/dp2/data/build.py
ADDED
@@ -0,0 +1,40 @@
import torch
import tops
from .utils import collate_fn


def get_dataloader(
        dataset, gpu_transform: torch.nn.Module,
        num_workers,
        batch_size,
        infinite: bool,
        drop_last: bool,
        prefetch_factor: int,
        shuffle,
        channels_last=False
):
    sampler = None
    dl_kwargs = dict(
        pin_memory=True,
    )
    if infinite:
        sampler = tops.InfiniteSampler(
            dataset, rank=tops.rank(),
            num_replicas=tops.world_size(),
            shuffle=shuffle
        )
    elif tops.world_size() > 1:
        sampler = torch.utils.data.DistributedSampler(
            dataset, shuffle=shuffle, num_replicas=tops.world_size(), rank=tops.rank())
        dl_kwargs["drop_last"] = drop_last
    else:
        dl_kwargs["shuffle"] = shuffle
        dl_kwargs["drop_last"] = drop_last
    dataloader = torch.utils.data.DataLoader(
        dataset, sampler=sampler, collate_fn=collate_fn,
        batch_size=batch_size,
        num_workers=num_workers, prefetch_factor=prefetch_factor,
        **dl_kwargs
    )
    dataloader = tops.DataPrefetcher(dataloader, gpu_transform, channels_last=channels_last)
    return dataloader
deep_privacy/dp2/data/datasets/__init__.py
ADDED
File without changes
deep_privacy/dp2/data/datasets/coco_cse.py
ADDED
@@ -0,0 +1,68 @@
import pickle
import torchvision
import torch
import pathlib
import numpy as np
from typing import Callable, Optional, Union
from torch.hub import get_dir as get_hub_dir


def cache_embed_stats(embed_map: torch.Tensor):
    mean = embed_map.mean(dim=0, keepdim=True)
    rstd = ((embed_map - mean).square().mean(dim=0, keepdim=True)+1e-8).rsqrt()

    cache = dict(mean=mean, rstd=rstd, embed_map=embed_map)
    path = pathlib.Path(get_hub_dir(), "embed_map_stats.torch")
    path.parent.mkdir(exist_ok=True, parents=True)
    torch.save(cache, path)


class CocoCSE(torch.utils.data.Dataset):

    def __init__(self,
                 dirpath: Union[str, pathlib.Path],
                 transform: Optional[Callable],
                 normalize_E: bool):
        dirpath = pathlib.Path(dirpath)
        self.dirpath = dirpath

        self.transform = transform
        assert self.dirpath.is_dir(),\
            f"Did not find dataset at: {dirpath}"
        self.image_paths, self.embedding_paths = self._load_impaths()
        self.embed_map = torch.from_numpy(np.load(self.dirpath.joinpath("embed_map.npy")))
        mean = self.embed_map.mean(dim=0, keepdim=True)
        rstd = ((self.embed_map - mean).square().mean(dim=0, keepdim=True)+1e-8).rsqrt()
        self.embed_map = (self.embed_map - mean) * rstd
        cache_embed_stats(self.embed_map)

    def _load_impaths(self):
        image_dir = self.dirpath.joinpath("images")
        image_paths = list(image_dir.glob("*.png"))
        image_paths.sort()
        embedding_paths = [
            self.dirpath.joinpath("embedding", x.stem + ".npy") for x in image_paths
        ]
        return image_paths, embedding_paths

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        im = torchvision.io.read_image(str(self.image_paths[idx]))
        vertices, mask, border = np.split(np.load(self.embedding_paths[idx]), 3, axis=-1)
        vertices = torch.from_numpy(vertices.squeeze()).long()
        mask = torch.from_numpy(mask.squeeze()).float()
        border = torch.from_numpy(border.squeeze()).float()
        E_mask = 1 - mask - border
        batch = {
            "img": im,
            "vertices": vertices[None],
            "mask": mask[None],
            "embed_map": self.embed_map,
            "border": border[None],
            "E_mask": E_mask[None]
        }
        if self.transform is None:
            return batch
        return self.transform(batch)
deep_privacy/dp2/data/datasets/fdf.py
ADDED
@@ -0,0 +1,128 @@
import pathlib
from typing import Tuple
import numpy as np
import torch
try:
    import pyspng
    PYSPNG_IMPORTED = True
except ImportError:
    PYSPNG_IMPORTED = False
    print("Could not load pyspng. Defaulting to pillow image backend.")
    from PIL import Image
from tops import logger


class FDFDataset:

    def __init__(self,
                 dirpath,
                 imsize: Tuple[int],
                 load_keypoints: bool,
                 transform):
        dirpath = pathlib.Path(dirpath)
        self.dirpath = dirpath
        self.transform = transform
        self.imsize = imsize[0]
        self.load_keypoints = load_keypoints
        assert self.dirpath.is_dir(),\
            f"Did not find dataset at: {dirpath}"
        image_dir = self.dirpath.joinpath("images", str(self.imsize))
        self.image_paths = list(image_dir.glob("*.png"))
        assert len(self.image_paths) > 0,\
            f"Did not find images in: {image_dir}"
        self.image_paths.sort(key=lambda x: int(x.stem))
        self.landmarks = np.load(self.dirpath.joinpath("landmarks.npy")).reshape(-1, 7, 2).astype(np.float32)

        self.bounding_boxes = torch.load(self.dirpath.joinpath("bounding_box", f"{self.imsize}.torch"))
        assert len(self.image_paths) == len(self.bounding_boxes)
        assert len(self.image_paths) == len(self.landmarks)
        logger.log(
            f"Dataset loaded from: {dirpath}. Number of samples:{len(self)}, imsize={imsize}")

    def get_mask(self, idx):
        mask = torch.ones((1, self.imsize, self.imsize), dtype=torch.bool)
        bounding_box = self.bounding_boxes[idx]
        x0, y0, x1, y1 = bounding_box
        mask[:, y0:y1, x0:x1] = 0
        return mask

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, index):
        impath = self.image_paths[index]
        if PYSPNG_IMPORTED:
            with open(impath, "rb") as fp:
                im = pyspng.load(fp.read())
        else:
            with Image.open(impath) as fp:
                im = np.array(fp)
        im = torch.from_numpy(np.rollaxis(im, -1, 0))
        masks = self.get_mask(index)
        landmark = self.landmarks[index]
        batch = {
            "img": im,
            "mask": masks,
        }
        if self.load_keypoints:
            batch["keypoints"] = landmark
        if self.transform is None:
            return batch
        return self.transform(batch)


class FDF256Dataset:

    def __init__(self,
                 dirpath,
                 load_keypoints: bool,
                 transform):
        dirpath = pathlib.Path(dirpath)
        self.dirpath = dirpath
        self.transform = transform
        self.load_keypoints = load_keypoints
        assert self.dirpath.is_dir(),\
            f"Did not find dataset at: {dirpath}"
        image_dir = self.dirpath.joinpath("images")
        self.image_paths = list(image_dir.glob("*.png"))
        assert len(self.image_paths) > 0,\
            f"Did not find images in: {image_dir}"
        self.image_paths.sort(key=lambda x: int(x.stem))
        self.landmarks = np.load(self.dirpath.joinpath("landmarks.npy")).reshape(-1, 7, 2).astype(np.float32)
        self.bounding_boxes = torch.from_numpy(np.load(self.dirpath.joinpath("bounding_box.npy")))
        assert len(self.image_paths) == len(self.bounding_boxes)
        assert len(self.image_paths) == len(self.landmarks)
        logger.log(
            f"Dataset loaded from: {dirpath}. Number of samples:{len(self)}")

    def get_mask(self, idx):
        mask = torch.ones((1, 256, 256), dtype=torch.bool)
        bounding_box = self.bounding_boxes[idx]
        x0, y0, x1, y1 = bounding_box
        mask[:, y0:y1, x0:x1] = 0
        return mask

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, index):
        impath = self.image_paths[index]
        if PYSPNG_IMPORTED:
            with open(impath, "rb") as fp:
                im = pyspng.load(fp.read())
        else:
            with Image.open(impath) as fp:
                im = np.array(fp)
        im = torch.from_numpy(np.rollaxis(im, -1, 0))
        masks = self.get_mask(index)
        landmark = self.landmarks[index]
        batch = {
            "img": im,
            "mask": masks,
        }
        if self.load_keypoints:
            batch["keypoints"] = landmark
        if self.transform is None:
            return batch
        return self.transform(batch)
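
A short sketch of reading one sample from FDF256Dataset directly; the dataset path is a placeholder and assumes the FDF256 layout the class expects (an images/ directory plus landmarks.npy and bounding_box.npy):

# Usage sketch -- "data/fdf256/val" is a placeholder path.
from dp2.data.datasets.fdf import FDF256Dataset

dataset = FDF256Dataset("data/fdf256/val", load_keypoints=True, transform=None)
sample = dataset[0]
print(sample["img"].shape)        # uint8 CHW image, 3 x 256 x 256
print(sample["mask"].shape)       # 1 x 256 x 256 bool mask, zero inside the face box
print(sample["keypoints"].shape)  # 7 x 2 float landmarks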
deep_privacy/dp2/data/datasets/fdf128_wds.py
ADDED
@@ -0,0 +1,96 @@
import torch
import tops
import numpy as np
import io
import webdataset as wds
import os
from ..utils import png_decoder, get_num_workers, collate_fn


def kp_decoder(x):
    # Keypoints are between [0, 1] for webdataset
    keypoints = torch.from_numpy(np.load(io.BytesIO(x))).float().view(7, 2).clamp(0, 1)
    keypoints = torch.cat((keypoints, torch.ones((7, 1))), dim=-1)
    return keypoints


def bbox_decoder(x):
    return torch.from_numpy(np.load(io.BytesIO(x))).float().view(4)


class BBoxToMask:

    def __call__(self, sample):
        imsize = sample["image.png"].shape[-1]
        bbox = sample["bounding_box.npy"] * imsize
        x0, y0, x1, y1 = np.round(bbox).astype(np.int64)
        mask = torch.ones((1, imsize, imsize), dtype=torch.bool)
        mask[:, y0:y1, x0:x1] = 0
        sample["mask"] = mask
        return sample


def get_dataloader_fdf_wds(
        path,
        batch_size: int,
        num_workers: int,
        transform: torch.nn.Module,
        gpu_transform: torch.nn.Module,
        infinite: bool,
        shuffle: bool,
        partial_batches: bool,
        sample_shuffle=10_000,
        tar_shuffle=100,
        channels_last=False,
):
    # Need to set this for split_by_node to work.
    os.environ["RANK"] = str(tops.rank())
    os.environ["WORLD_SIZE"] = str(tops.world_size())
    if infinite:
        pipeline = [wds.ResampledShards(str(path))]
    else:
        pipeline = [wds.SimpleShardList(str(path))]
    if shuffle:
        pipeline.append(wds.shuffle(tar_shuffle))
    pipeline.extend([
        wds.split_by_node,
        wds.split_by_worker,
    ])
    if shuffle:
        pipeline.append(wds.shuffle(sample_shuffle))

    decoder = [
        wds.handle_extension("image.png", png_decoder),
        wds.handle_extension("keypoints.npy", kp_decoder),
    ]

    rename_keys = [
        ["img", "image.png"],
        ["keypoints", "keypoints.npy"],
        ["__key__", "__key__"],
        ["mask", "mask"]
    ]

    pipeline.extend([
        wds.tarfile_to_samples(),
        wds.decode(*decoder),
    ])
    pipeline.append(wds.map(BBoxToMask()))
    pipeline.extend([
        wds.batched(batch_size, collation_fn=collate_fn, partial=partial_batches),
        wds.rename_keys(*rename_keys),
    ])

    if transform is not None:
        pipeline.append(wds.map(transform))
    pipeline = wds.DataPipeline(*pipeline)
    if infinite:
        pipeline = pipeline.repeat(nepochs=1000000)

    loader = wds.WebLoader(
        pipeline, batch_size=None, shuffle=False,
        num_workers=get_num_workers(num_workers),
        persistent_workers=True,
    )
    loader = tops.DataPrefetcher(loader, gpu_transform, channels_last=channels_last, to_float=False)
    return loader
deep_privacy/dp2/data/datasets/fdh.py
ADDED
@@ -0,0 +1,142 @@
import torch
import tops
import numpy as np
import io
import webdataset as wds
import os
import json
from pathlib import Path
from ..utils import png_decoder, mask_decoder, get_num_workers, collate_fn


def kp_decoder(x):
    # Keypoints are between [0, 1] for webdataset
    keypoints = torch.from_numpy(np.load(io.BytesIO(x))).float()
    def check_outside(x): return (x < 0).logical_or(x > 1)
    is_outside = check_outside(keypoints[:, 0]).logical_or(
        check_outside(keypoints[:, 1])
    )
    keypoints[:, 2] = (keypoints[:, 2] > 0).logical_and(is_outside.logical_not())
    return keypoints


def vertices_decoder(x):
    vertices = torch.from_numpy(np.load(io.BytesIO(x)).astype(np.int32))
    return vertices.squeeze()[None]


class InsertNewKeypoints:

    def __init__(self, keypoints_path: Path) -> None:
        with open(keypoints_path, "r") as fp:
            self.keypoints = json.load(fp)

    def __call__(self, sample):
        key = sample["__key__"]
        keypoints = torch.tensor(self.keypoints[key], dtype=torch.float32)
        def check_outside(x): return (x < 0).logical_or(x > 1)
        is_outside = check_outside(keypoints[:, 0]).logical_or(
            check_outside(keypoints[:, 1])
        )
        keypoints[:, 2] = (keypoints[:, 2] > 0).logical_and(is_outside.logical_not())

        sample["keypoints.npy"] = keypoints
        return sample


def get_dataloader_fdh_wds(
        path,
        batch_size: int,
        num_workers: int,
        transform: torch.nn.Module,
        gpu_transform: torch.nn.Module,
        infinite: bool,
        shuffle: bool,
        partial_batches: bool,
        load_embedding: bool,
        sample_shuffle=10_000,
        tar_shuffle=100,
        read_condition=False,
        channels_last=False,
        load_new_keypoints=False,
        keypoints_split=None,
):
    # Need to set this for split_by_node to work.
    os.environ["RANK"] = str(tops.rank())
    os.environ["WORLD_SIZE"] = str(tops.world_size())
    if infinite:
        pipeline = [wds.ResampledShards(str(path))]
    else:
        pipeline = [wds.SimpleShardList(str(path))]
    if shuffle:
        pipeline.append(wds.shuffle(tar_shuffle))
    pipeline.extend([
        wds.split_by_node,
        wds.split_by_worker,
    ])
    if shuffle:
        pipeline.append(wds.shuffle(sample_shuffle))

    decoder = [
        wds.handle_extension("image.png", png_decoder),
        wds.handle_extension("mask.png", mask_decoder),
        wds.handle_extension("maskrcnn_mask.png", mask_decoder),
        wds.handle_extension("keypoints.npy", kp_decoder),
    ]

    rename_keys = [
        ["img", "image.png"], ["mask", "mask.png"],
        ["keypoints", "keypoints.npy"], ["maskrcnn_mask", "maskrcnn_mask.png"],
        ["__key__", "__key__"]
    ]
    if load_embedding:
        decoder.extend([
            wds.handle_extension("vertices.npy", vertices_decoder),
            wds.handle_extension("E_mask.png", mask_decoder)
        ])
        rename_keys.extend([
            ["vertices", "vertices.npy"],
            ["E_mask", "e_mask.png"]
        ])

    if read_condition:
        decoder.append(
            wds.handle_extension("condition.png", png_decoder)
        )
        rename_keys.append(["condition", "condition.png"])

    pipeline.extend([
        wds.tarfile_to_samples(),
        wds.decode(*decoder),
    ])
    if load_new_keypoints:
        assert keypoints_split in ["train", "val"]
        keypoint_url = "https://api.loke.aws.unit.no/dlr-gui-backend-resources-content/v2/contents/links/1eb88522-8b91-49c7-b56a-ed98a9c7888cef9c0429-a385-4248-abe3-8682de26d041f268aed1-7c88-4677-baad-7623c2ee330f"
        file_name = "fdh_keypoints_val-050133b34d.json"
        if keypoints_split == "train":
            keypoint_url = "https://api.loke.aws.unit.no/dlr-gui-backend-resources-content/v2/contents/links/3e828b1c-d6c0-4622-90bc-1b2cce48ccfff14ab45d-0a5c-431d-be13-7e60580765bd7938601c-e72e-41d9-8836-fffc49e76f58"
            file_name = "fdh_keypoints_train-2cff11f69a.json"
        # Set check_hash=True if you suspect download is incorrect.
        filepath = tops.download_file(keypoint_url, file_name=file_name, check_hash=False)
        pipeline.append(
            wds.map(InsertNewKeypoints(filepath))
        )
    pipeline.extend([
        wds.batched(batch_size, collation_fn=collate_fn, partial=partial_batches),
        wds.rename_keys(*rename_keys),
    ])

    if transform is not None:
        pipeline.append(wds.map(transform))
    pipeline = wds.DataPipeline(*pipeline)
    if infinite:
        pipeline = pipeline.repeat(nepochs=1000000)

    loader = wds.WebLoader(
        pipeline, batch_size=None, shuffle=False,
        num_workers=get_num_workers(num_workers),
        persistent_workers=True,
    )
    loader = tops.DataPrefetcher(loader, gpu_transform, channels_last=channels_last, to_float=False)
    return loader
deep_privacy/dp2/data/transforms/__init__.py
ADDED
@@ -0,0 +1,2 @@
from .transforms import RandomCrop, CreateCondition, CreateEmbedding, Resize, ToFloat, Normalize
from .stylegan2_transform import StyleGANAugmentPipe
deep_privacy/dp2/data/transforms/functional.py
ADDED
@@ -0,0 +1,57 @@
import torchvision.transforms.functional as F
import torch
import pickle
from tops import download_file, assert_shape
from typing import Dict
from functools import lru_cache

global symmetry_transform


@lru_cache(maxsize=1)
def get_symmetry_transform(symmetry_url):
    file_name = download_file(symmetry_url)
    with open(file_name, "rb") as fp:
        symmetry = pickle.load(fp)
    return torch.from_numpy(symmetry["vertex_transforms"]).long()


hflip_handled_cases = set([
    "keypoints", "img", "mask", "border", "semantic_mask", "vertices", "E_mask", "embed_map", "condition",
    "embedding", "vertx2cat", "maskrcnn_mask", "__key__"])


def hflip(container: Dict[str, torch.Tensor], flip_map=None) -> Dict[str, torch.Tensor]:
    container["img"] = F.hflip(container["img"])
    if "condition" in container:
        container["condition"] = F.hflip(container["condition"])
    if "embedding" in container:
        container["embedding"] = F.hflip(container["embedding"])
    assert all([key in hflip_handled_cases for key in container]), container.keys()
    if "keypoints" in container:
        assert flip_map is not None
        if container["keypoints"].ndim == 3:
            keypoints = container["keypoints"][:, flip_map, :]
            keypoints[:, :, 0] = 1 - keypoints[:, :, 0]
        else:
            assert_shape(container["keypoints"], (None, 3))
            keypoints = container["keypoints"][flip_map, :]
            keypoints[:, 0] = 1 - keypoints[:, 0]
        container["keypoints"] = keypoints
    if "mask" in container:
        container["mask"] = F.hflip(container["mask"])
    if "border" in container:
        container["border"] = F.hflip(container["border"])
    if "semantic_mask" in container:
        container["semantic_mask"] = F.hflip(container["semantic_mask"])
    if "vertices" in container:
        symmetry_transform = get_symmetry_transform(
            "https://dl.fbaipublicfiles.com/densepose/meshes/symmetry/symmetry_smpl_27554.pkl")
        container["vertices"] = F.hflip(container["vertices"])
        symmetry_transform_ = symmetry_transform.to(container["vertices"].device)
        container["vertices"] = symmetry_transform_[container["vertices"].long()]
    if "E_mask" in container:
        container["E_mask"] = F.hflip(container["E_mask"])
    if "maskrcnn_mask" in container:
        container["maskrcnn_mask"] = F.hflip(container["maskrcnn_mask"])
    return container
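
A small example of hflip on a container with normalized keypoints. The 3-point flip map below is made up for illustration; the FDH config uses the 17-point COCO flip map from get_coco_flipmap:

# Example of hflip with a hypothetical 3-keypoint flip map.
import torch
from dp2.data.transforms.functional import hflip

flip_map = [0, 2, 1]  # swap the two eye keypoints when mirroring (illustrative)
container = {
    "img": torch.zeros(3, 8, 8),
    "keypoints": torch.tensor([[0.5, 0.2, 1.0],    # nose
                               [0.4, 0.1, 1.0],    # left eye
                               [0.6, 0.1, 1.0]]),  # right eye
}
flipped = hflip(container, flip_map=flip_map)
print(flipped["keypoints"])  # x -> 1 - x, and the two eye rows are swapped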
deep_privacy/dp2/data/transforms/stylegan2_transform.py
ADDED
@@ -0,0 +1,394 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import scipy.signal
|
3 |
+
import torch
|
4 |
+
try:
|
5 |
+
from sg3_torch_utils import misc
|
6 |
+
from sg3_torch_utils.ops import upfirdn2d
|
7 |
+
from sg3_torch_utils.ops import grid_sample_gradfix
|
8 |
+
from sg3_torch_utils.ops import conv2d_gradfix
|
9 |
+
except:
|
10 |
+
pass
|
11 |
+
#----------------------------------------------------------------------------
|
12 |
+
# Coefficients of various wavelet decomposition low-pass filters.
|
13 |
+
|
14 |
+
wavelets = {
|
15 |
+
'haar': [0.7071067811865476, 0.7071067811865476],
|
16 |
+
'db1': [0.7071067811865476, 0.7071067811865476],
|
17 |
+
'db2': [-0.12940952255092145, 0.22414386804185735, 0.836516303737469, 0.48296291314469025],
|
18 |
+
'db3': [0.035226291882100656, -0.08544127388224149, -0.13501102001039084, 0.4598775021193313, 0.8068915093133388, 0.3326705529509569],
|
19 |
+
'db4': [-0.010597401784997278, 0.032883011666982945, 0.030841381835986965, -0.18703481171888114, -0.02798376941698385, 0.6308807679295904, 0.7148465705525415, 0.23037781330885523],
|
20 |
+
'db5': [0.003335725285001549, -0.012580751999015526, -0.006241490213011705, 0.07757149384006515, -0.03224486958502952, -0.24229488706619015, 0.13842814590110342, 0.7243085284385744, 0.6038292697974729, 0.160102397974125],
|
21 |
+
'db6': [-0.00107730108499558, 0.004777257511010651, 0.0005538422009938016, -0.031582039318031156, 0.02752286553001629, 0.09750160558707936, -0.12976686756709563, -0.22626469396516913, 0.3152503517092432, 0.7511339080215775, 0.4946238903983854, 0.11154074335008017],
|
22 |
+
'db7': [0.0003537138000010399, -0.0018016407039998328, 0.00042957797300470274, 0.012550998556013784, -0.01657454163101562, -0.03802993693503463, 0.0806126091510659, 0.07130921926705004, -0.22403618499416572, -0.14390600392910627, 0.4697822874053586, 0.7291320908465551, 0.39653931948230575, 0.07785205408506236],
|
23 |
+
'db8': [-0.00011747678400228192, 0.0006754494059985568, -0.0003917403729959771, -0.00487035299301066, 0.008746094047015655, 0.013981027917015516, -0.04408825393106472, -0.01736930100202211, 0.128747426620186, 0.00047248457399797254, -0.2840155429624281, -0.015829105256023893, 0.5853546836548691, 0.6756307362980128, 0.3128715909144659, 0.05441584224308161],
|
24 |
+
'sym2': [-0.12940952255092145, 0.22414386804185735, 0.836516303737469, 0.48296291314469025],
|
25 |
+
'sym3': [0.035226291882100656, -0.08544127388224149, -0.13501102001039084, 0.4598775021193313, 0.8068915093133388, 0.3326705529509569],
|
26 |
+
'sym4': [-0.07576571478927333, -0.02963552764599851, 0.49761866763201545, 0.8037387518059161, 0.29785779560527736, -0.09921954357684722, -0.012603967262037833, 0.0322231006040427],
|
27 |
+
'sym5': [0.027333068345077982, 0.029519490925774643, -0.039134249302383094, 0.1993975339773936, 0.7234076904024206, 0.6339789634582119, 0.01660210576452232, -0.17532808990845047, -0.021101834024758855, 0.019538882735286728],
|
28 |
+
'sym6': [0.015404109327027373, 0.0034907120842174702, -0.11799011114819057, -0.048311742585633, 0.4910559419267466, 0.787641141030194, 0.3379294217276218, -0.07263752278646252, -0.021060292512300564, 0.04472490177066578, 0.0017677118642428036, -0.007800708325034148],
|
29 |
+
'sym7': [0.002681814568257878, -0.0010473848886829163, -0.01263630340325193, 0.03051551316596357, 0.0678926935013727, -0.049552834937127255, 0.017441255086855827, 0.5361019170917628, 0.767764317003164, 0.2886296317515146, -0.14004724044296152, -0.10780823770381774, 0.004010244871533663, 0.010268176708511255],
|
30 |
+
'sym8': [-0.0033824159510061256, -0.0005421323317911481, 0.03169508781149298, 0.007607487324917605, -0.1432942383508097, -0.061273359067658524, 0.4813596512583722, 0.7771857517005235, 0.3644418948353314, -0.05194583810770904, -0.027219029917056003, 0.049137179673607506, 0.003808752013890615, -0.01495225833704823, -0.0003029205147213668, 0.0018899503327594609],
|
31 |
+
}
|
32 |
+
|
33 |
+
#----------------------------------------------------------------------------
|
34 |
+
# Helpers for constructing transformation matrices.
|
35 |
+
|
36 |
+
|
37 |
+
def matrix(*rows, device=None):
|
38 |
+
assert all(len(row) == len(rows[0]) for row in rows)
|
39 |
+
elems = [x for row in rows for x in row]
|
40 |
+
ref = [x for x in elems if isinstance(x, torch.Tensor)]
|
41 |
+
if len(ref) == 0:
|
42 |
+
return misc.constant(np.asarray(rows), device=device)
|
43 |
+
assert device is None or device == ref[0].device
|
44 |
+
elems = [x if isinstance(x, torch.Tensor) else misc.constant(x, shape=ref[0].shape, device=ref[0].device) for x in elems]
|
45 |
+
return torch.stack(elems, dim=-1).reshape(ref[0].shape + (len(rows), -1))
|
46 |
+
|
47 |
+
|
48 |
+
def translate2d(tx, ty, **kwargs):
|
49 |
+
return matrix(
|
50 |
+
[1, 0, tx],
|
51 |
+
[0, 1, ty],
|
52 |
+
[0, 0, 1],
|
53 |
+
**kwargs)
|
54 |
+
|
55 |
+
|
56 |
+
def translate3d(tx, ty, tz, **kwargs):
|
57 |
+
return matrix(
|
58 |
+
[1, 0, 0, tx],
|
59 |
+
[0, 1, 0, ty],
|
60 |
+
[0, 0, 1, tz],
|
61 |
+
[0, 0, 0, 1],
|
62 |
+
**kwargs)
|
63 |
+
|
64 |
+
|
65 |
+
def scale2d(sx, sy, **kwargs):
|
66 |
+
return matrix(
|
67 |
+
[sx, 0, 0],
|
68 |
+
[0, sy, 0],
|
69 |
+
[0, 0, 1],
|
70 |
+
**kwargs)
|
71 |
+
|
72 |
+
|
73 |
+
def scale3d(sx, sy, sz, **kwargs):
|
74 |
+
return matrix(
|
75 |
+
[sx, 0, 0, 0],
|
76 |
+
[0, sy, 0, 0],
|
77 |
+
[0, 0, sz, 0],
|
78 |
+
[0, 0, 0, 1],
|
79 |
+
**kwargs)
|
80 |
+
|
81 |
+
|
82 |
+
def rotate2d(theta, **kwargs):
|
83 |
+
return matrix(
|
84 |
+
[torch.cos(theta), torch.sin(-theta), 0],
|
85 |
+
[torch.sin(theta), torch.cos(theta), 0],
|
86 |
+
[0, 0, 1],
|
87 |
+
**kwargs)
|
88 |
+
|
89 |
+
|
90 |
+
def rotate3d(v, theta, **kwargs):
|
91 |
+
vx = v[..., 0]; vy = v[..., 1]; vz = v[..., 2]
|
92 |
+
s = torch.sin(theta); c = torch.cos(theta); cc = 1 - c
|
93 |
+
return matrix(
|
94 |
+
[vx*vx*cc+c, vx*vy*cc-vz*s, vx*vz*cc+vy*s, 0],
|
95 |
+
[vy*vx*cc+vz*s, vy*vy*cc+c, vy*vz*cc-vx*s, 0],
|
96 |
+
[vz*vx*cc-vy*s, vz*vy*cc+vx*s, vz*vz*cc+c, 0],
|
97 |
+
[0, 0, 0, 1],
|
98 |
+
**kwargs)
|
99 |
+
|
100 |
+
|
101 |
+
def translate2d_inv(tx, ty, **kwargs):
|
102 |
+
return translate2d(-tx, -ty, **kwargs)
|
103 |
+
|
104 |
+
|
105 |
+
def scale2d_inv(sx, sy, **kwargs):
|
106 |
+
return scale2d(1 / sx, 1 / sy, **kwargs)
|
107 |
+
|
108 |
+
|
109 |
+
def rotate2d_inv(theta, **kwargs):
|
110 |
+
return rotate2d(-theta, **kwargs)
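These helpers return homogeneous transformation matrices (3x3 for 2D geometry, 4x4 for color space) that are composed by plain matrix multiplication; the *_inv variants let the pipeline accumulate the inverse mapping from output pixels back to input pixels. A minimal sketch of how they compose, assuming this module's own imports (torch, numpy and the misc utilities used by matrix()) are in scope; the concrete values are made up for illustration:

# Illustrative sketch (not part of the diff): composing an inverse pixel transform.
theta = torch.tensor(0.25)                                          # rotation in radians
G_inv = translate2d_inv(torch.tensor(2.0), torch.tensor(3.0))       # undo a (2, 3) px shift
G_inv = G_inv @ rotate2d_inv(theta)                                 # undo the rotation
G_inv = G_inv @ scale2d_inv(torch.tensor(1.5), torch.tensor(1.5))   # undo a 1.5x zoom
# G_inv is a 3x3 homogeneous matrix mapping output pixel coordinates to input coordinates.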
|
111 |
+
|
112 |
+
|
113 |
+
class StyleGANAugmentPipe(torch.nn.Module):
|
114 |
+
def __init__(self,
|
115 |
+
rotate90=0, xint=0, xint_max=0.125,
|
116 |
+
scale=0, rotate=0, aniso=0, xfrac=0, scale_std=0.2, rotate_max=1, aniso_std=0.2, xfrac_std=0.125,
|
117 |
+
brightness=0, contrast=0, lumaflip=0, hue=0, saturation=0, brightness_std=0.2, contrast_std=0.5,
|
118 |
+
hue_max=1, saturation_std=1,
|
119 |
+
imgfilter=0, imgfilter_bands=[1,1,1,1], imgfilter_std=1,
|
120 |
+
):
|
121 |
+
super().__init__()
|
122 |
+
self.register_buffer('p', torch.ones([])) # Overall multiplier for augmentation probability.
|
123 |
+
|
124 |
+
# Pixel blitting.
|
125 |
+
self.rotate90 = float(rotate90) # Probability multiplier for 90 degree rotations.
|
126 |
+
self.xint = float(xint) # Probability multiplier for integer translation.
|
127 |
+
self.xint_max = float(xint_max) # Range of integer translation, relative to image dimensions.
|
128 |
+
|
129 |
+
# General geometric transformations.
|
130 |
+
self.scale = float(scale) # Probability multiplier for isotropic scaling.
|
131 |
+
self.rotate = float(rotate) # Probability multiplier for arbitrary rotation.
|
132 |
+
self.aniso = float(aniso) # Probability multiplier for anisotropic scaling.
|
133 |
+
self.xfrac = float(xfrac) # Probability multiplier for fractional translation.
|
134 |
+
self.scale_std = float(scale_std) # Log2 standard deviation of isotropic scaling.
|
135 |
+
self.rotate_max = float(rotate_max) # Range of arbitrary rotation, 1 = full circle.
|
136 |
+
self.aniso_std = float(aniso_std) # Log2 standard deviation of anisotropic scaling.
|
137 |
+
self.xfrac_std = float(xfrac_std) # Standard deviation of fractional translation, relative to image dimensions.
|
138 |
+
|
139 |
+
# Color transformations.
|
140 |
+
self.brightness = float(brightness) # Probability multiplier for brightness.
|
141 |
+
self.contrast = float(contrast) # Probability multiplier for contrast.
|
142 |
+
self.lumaflip = float(lumaflip) # Probability multiplier for luma flip.
|
143 |
+
self.hue = float(hue) # Probability multiplier for hue rotation.
|
144 |
+
self.saturation = float(saturation) # Probability multiplier for saturation.
|
145 |
+
self.brightness_std = float(brightness_std) # Standard deviation of brightness.
|
146 |
+
self.contrast_std = float(contrast_std) # Log2 standard deviation of contrast.
|
147 |
+
self.hue_max = float(hue_max) # Range of hue rotation, 1 = full circle.
|
148 |
+
self.saturation_std = float(saturation_std) # Log2 standard deviation of saturation.
|
149 |
+
|
150 |
+
# Image-space filtering.
|
151 |
+
self.imgfilter = float(imgfilter) # Probability multiplier for image-space filtering.
|
152 |
+
self.imgfilter_bands = list(imgfilter_bands) # Probability multipliers for individual frequency bands.
|
153 |
+
self.imgfilter_std = float(imgfilter_std) # Log2 standard deviation of image-space filter amplification.
|
154 |
+
|
155 |
+
# Setup orthogonal lowpass filter for geometric augmentations.
|
156 |
+
self.register_buffer('Hz_geom', upfirdn2d.setup_filter(wavelets['sym6']))
|
157 |
+
|
158 |
+
# Construct filter bank for image-space filtering.
|
159 |
+
Hz_lo = np.asarray(wavelets['sym2']) # H(z)
|
160 |
+
Hz_hi = Hz_lo * ((-1) ** np.arange(Hz_lo.size)) # H(-z)
|
161 |
+
Hz_lo2 = np.convolve(Hz_lo, Hz_lo[::-1]) / 2 # H(z) * H(z^-1) / 2
|
162 |
+
Hz_hi2 = np.convolve(Hz_hi, Hz_hi[::-1]) / 2 # H(-z) * H(-z^-1) / 2
|
163 |
+
Hz_fbank = np.eye(4, 1) # Bandpass(H(z), b_i)
|
164 |
+
for i in range(1, Hz_fbank.shape[0]):
|
165 |
+
Hz_fbank = np.dstack([Hz_fbank, np.zeros_like(Hz_fbank)]).reshape(Hz_fbank.shape[0], -1)[:, :-1]
|
166 |
+
Hz_fbank = scipy.signal.convolve(Hz_fbank, [Hz_lo2])
|
167 |
+
Hz_fbank[i, (Hz_fbank.shape[1] - Hz_hi2.size) // 2 : (Hz_fbank.shape[1] + Hz_hi2.size) // 2] += Hz_hi2
|
168 |
+
self.register_buffer('Hz_fbank', torch.as_tensor(Hz_fbank, dtype=torch.float32))
|
169 |
+
|
170 |
+
def forward(self, batch, debug_percentile=None):
|
171 |
+
images = batch["img"]
|
172 |
+
batch["vertices"] = batch["vertices"].float()
|
173 |
+
assert isinstance(images, torch.Tensor) and images.ndim == 4
|
174 |
+
batch_size, num_channels, height, width = images.shape
|
175 |
+
device = images.device
|
176 |
+
self.Hz_fbank = self.Hz_fbank.to(device)
|
177 |
+
self.Hz_geom = self.Hz_geom.to(device)
|
178 |
+
if debug_percentile is not None:
|
179 |
+
debug_percentile = torch.as_tensor(debug_percentile, dtype=torch.float32, device=device)
|
180 |
+
|
181 |
+
# -------------------------------------
|
182 |
+
# Select parameters for pixel blitting.
|
183 |
+
# -------------------------------------
|
184 |
+
|
185 |
+
# Initialize inverse homogeneous 2D transform: G_inv @ pixel_out ==> pixel_in
|
186 |
+
I_3 = torch.eye(3, device=device)
|
187 |
+
G_inv = I_3
|
188 |
+
|
189 |
+
# Apply integer translation with probability (xint * strength).
|
190 |
+
if self.xint > 0:
|
191 |
+
t = (torch.rand([batch_size, 2], device=device) * 2 - 1) * self.xint_max
|
192 |
+
t = torch.where(torch.rand([batch_size, 1], device=device) < self.xint * self.p, t, torch.zeros_like(t))
|
193 |
+
if debug_percentile is not None:
|
194 |
+
t = torch.full_like(t, (debug_percentile * 2 - 1) * self.xint_max)
|
195 |
+
G_inv = G_inv @ translate2d_inv(torch.round(t[:,0] * width), torch.round(t[:,1] * height))
|
196 |
+
|
197 |
+
# --------------------------------------------------------
|
198 |
+
# Select parameters for general geometric transformations.
|
199 |
+
# --------------------------------------------------------
|
200 |
+
|
201 |
+
# Apply isotropic scaling with probability (scale * strength).
|
202 |
+
if self.scale > 0:
|
203 |
+
s = torch.exp2(torch.randn([batch_size], device=device) * self.scale_std)
|
204 |
+
s = torch.where(torch.rand([batch_size], device=device) < self.scale * self.p, s, torch.ones_like(s))
|
205 |
+
if debug_percentile is not None:
|
206 |
+
s = torch.full_like(s, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.scale_std))
|
207 |
+
G_inv = G_inv @ scale2d_inv(s, s)
|
208 |
+
|
209 |
+
# Apply pre-rotation with probability p_rot.
|
210 |
+
p_rot = 1 - torch.sqrt((1 - self.rotate * self.p).clamp(0, 1)) # P(pre OR post) = p
|
211 |
+
if self.rotate > 0:
|
212 |
+
theta = (torch.rand([batch_size], device=device) * 2 - 1) * np.pi * self.rotate_max
|
213 |
+
theta = torch.where(torch.rand([batch_size], device=device) < p_rot, theta, torch.zeros_like(theta))
|
214 |
+
if debug_percentile is not None:
|
215 |
+
theta = torch.full_like(theta, (debug_percentile * 2 - 1) * np.pi * self.rotate_max)
|
216 |
+
G_inv = G_inv @ rotate2d_inv(-theta) # Before anisotropic scaling.
|
217 |
+
|
218 |
+
# Apply anisotropic scaling with probability (aniso * strength).
|
219 |
+
if self.aniso > 0:
|
220 |
+
s = torch.exp2(torch.randn([batch_size], device=device) * self.aniso_std)
|
221 |
+
s = torch.where(torch.rand([batch_size], device=device) < self.aniso * self.p, s, torch.ones_like(s))
|
222 |
+
if debug_percentile is not None:
|
223 |
+
s = torch.full_like(s, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.aniso_std))
|
224 |
+
G_inv = G_inv @ scale2d_inv(s, 1 / s)
|
225 |
+
|
226 |
+
# Apply post-rotation with probability p_rot.
|
227 |
+
if self.rotate > 0:
|
228 |
+
theta = (torch.rand([batch_size], device=device) * 2 - 1) * np.pi * self.rotate_max
|
229 |
+
theta = torch.where(torch.rand([batch_size], device=device) < p_rot, theta, torch.zeros_like(theta))
|
230 |
+
if debug_percentile is not None:
|
231 |
+
theta = torch.zeros_like(theta)
|
232 |
+
G_inv = G_inv @ rotate2d_inv(-theta) # After anisotropic scaling.
|
233 |
+
|
234 |
+
# Apply fractional translation with probability (xfrac * strength).
|
235 |
+
if self.xfrac > 0:
|
236 |
+
t = torch.randn([batch_size, 2], device=device) * self.xfrac_std
|
237 |
+
t = torch.where(torch.rand([batch_size, 1], device=device) < self.xfrac * self.p, t, torch.zeros_like(t))
|
238 |
+
if debug_percentile is not None:
|
239 |
+
t = torch.full_like(t, torch.erfinv(debug_percentile * 2 - 1) * self.xfrac_std)
|
240 |
+
G_inv = G_inv @ translate2d_inv(t[:,0] * width, t[:,1] * height)
|
241 |
+
|
242 |
+
# ----------------------------------
|
243 |
+
# Execute geometric transformations.
|
244 |
+
# ----------------------------------
|
245 |
+
|
246 |
+
# Execute if the transform is not identity.
|
247 |
+
if G_inv is not I_3:
|
248 |
+
# Calculate padding.
|
249 |
+
cx = (width - 1) / 2
|
250 |
+
cy = (height - 1) / 2
|
251 |
+
cp = matrix([-cx, -cy, 1], [cx, -cy, 1], [cx, cy, 1], [-cx, cy, 1], device=device) # [idx, xyz]
|
252 |
+
cp = G_inv @ cp.t() # [batch, xyz, idx]
|
253 |
+
Hz_pad = self.Hz_geom.shape[0] // 4
|
254 |
+
margin = cp[:, :2, :].permute(1, 0, 2).flatten(1) # [xy, batch * idx]
|
255 |
+
margin = torch.cat([-margin, margin]).max(dim=1).values # [x0, y0, x1, y1]
|
256 |
+
margin = margin + misc.constant([Hz_pad * 2 - cx, Hz_pad * 2 - cy] * 2, device=device)
|
257 |
+
margin = margin.max(misc.constant([0, 0] * 2, device=device))
|
258 |
+
margin = margin.min(misc.constant([width-1, height-1] * 2, device=device))
|
259 |
+
mx0, my0, mx1, my1 = margin.ceil().to(torch.int32)
|
260 |
+
|
261 |
+
# Pad image and adjust origin.
|
262 |
+
images = torch.nn.functional.pad(input=images, pad=[mx0,mx1,my0,my1], mode='reflect')
|
263 |
+
batch["mask"] = torch.nn.functional.pad(input=batch["mask"], pad=[mx0,mx1,my0,my1], mode='constant', value=1.0)
|
264 |
+
batch["E_mask"] = torch.nn.functional.pad(input=batch["E_mask"], pad=[mx0,mx1,my0,my1], mode='constant', value=0.0)
|
265 |
+
batch["vertices"] = torch.nn.functional.pad(input=batch["vertices"], pad=[mx0,mx1,my0,my1], mode='constant', value=0.0)
|
266 |
+
G_inv = translate2d((mx0 - mx1) / 2, (my0 - my1) / 2) @ G_inv
|
267 |
+
|
268 |
+
# Upsample.
|
269 |
+
images = upfirdn2d.upsample2d(x=images, f=self.Hz_geom, up=2)
|
270 |
+
batch["mask"] = torch.nn.functional.interpolate(batch["mask"], scale_factor=2, mode="nearest")
|
271 |
+
batch["E_mask"] = torch.nn.functional.interpolate(batch["E_mask"], scale_factor=2, mode="nearest")
|
272 |
+
batch["vertices"] = torch.nn.functional.interpolate(batch["vertices"], scale_factor=2, mode="nearest")
|
273 |
+
G_inv = scale2d(2, 2, device=device) @ G_inv @ scale2d_inv(2, 2, device=device)
|
274 |
+
G_inv = translate2d(-0.5, -0.5, device=device) @ G_inv @ translate2d_inv(-0.5, -0.5, device=device)
|
275 |
+
|
276 |
+
# Execute transformation.
|
277 |
+
shape = [batch_size, num_channels, (height + Hz_pad * 2) * 2, (width + Hz_pad * 2) * 2]
|
278 |
+
G_inv = scale2d(2 / images.shape[3], 2 / images.shape[2], device=device) @ G_inv @ scale2d_inv(2 / shape[3], 2 / shape[2], device=device)
|
279 |
+
grid = torch.nn.functional.affine_grid(theta=G_inv[:,:2,:], size=shape, align_corners=False)
|
280 |
+
images = grid_sample_gradfix.grid_sample(images, grid)
|
281 |
+
|
282 |
+
batch["mask"] = torch.nn.functional.grid_sample(
|
283 |
+
input=batch["mask"], grid=grid, mode='nearest', padding_mode="border", align_corners=False)
|
284 |
+
batch["E_mask"] = torch.nn.functional.grid_sample(
|
285 |
+
input=batch["E_mask"], grid=grid, mode='nearest', padding_mode="border", align_corners=False)
|
286 |
+
batch["vertices"] = torch.nn.functional.grid_sample(
|
287 |
+
input=batch["vertices"], grid=grid, mode='nearest', padding_mode="border", align_corners=False)
|
288 |
+
|
289 |
+
|
290 |
+
# Downsample and crop.
|
291 |
+
images = upfirdn2d.downsample2d(x=images, f=self.Hz_geom, down=2, padding=-Hz_pad*2, flip_filter=True)
|
292 |
+
batch["mask"] = torch.nn.functional.interpolate(batch["mask"][:, :, Hz_pad*2:-Hz_pad*2, Hz_pad*2:-Hz_pad*2], scale_factor=.5, mode="nearest", recompute_scale_factor=False)
|
293 |
+
batch["E_mask"] = torch.nn.functional.interpolate(batch["E_mask"][:, :, Hz_pad*2:-Hz_pad*2, Hz_pad*2:-Hz_pad*2], scale_factor=.5, mode="nearest", recompute_scale_factor=False)
|
294 |
+
batch["vertices"] = torch.nn.functional.interpolate(batch["vertices"][:, :, Hz_pad*2:-Hz_pad*2, Hz_pad*2:-Hz_pad*2], scale_factor=.5, mode="nearest", recompute_scale_factor=False)
|
295 |
+
# --------------------------------------------
|
296 |
+
# Select parameters for color transformations.
|
297 |
+
# --------------------------------------------
|
298 |
+
|
299 |
+
# Initialize homogeneous 3D transformation matrix: C @ color_in ==> color_out
|
300 |
+
I_4 = torch.eye(4, device=device)
|
301 |
+
C = I_4
|
302 |
+
|
303 |
+
# Apply brightness with probability (brightness * strength).
|
304 |
+
if self.brightness > 0:
|
305 |
+
b = torch.randn([batch_size], device=device) * self.brightness_std
|
306 |
+
b = torch.where(torch.rand([batch_size], device=device) < self.brightness * self.p, b, torch.zeros_like(b))
|
307 |
+
if debug_percentile is not None:
|
308 |
+
b = torch.full_like(b, torch.erfinv(debug_percentile * 2 - 1) * self.brightness_std)
|
309 |
+
C = translate3d(b, b, b) @ C
|
310 |
+
|
311 |
+
# Apply contrast with probability (contrast * strength).
|
312 |
+
if self.contrast > 0:
|
313 |
+
c = torch.exp2(torch.randn([batch_size], device=device) * self.contrast_std)
|
314 |
+
c = torch.where(torch.rand([batch_size], device=device) < self.contrast * self.p, c, torch.ones_like(c))
|
315 |
+
if debug_percentile is not None:
|
316 |
+
c = torch.full_like(c, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.contrast_std))
|
317 |
+
C = scale3d(c, c, c) @ C
|
318 |
+
|
319 |
+
# Apply luma flip with probability (lumaflip * strength).
|
320 |
+
v = misc.constant(np.asarray([1, 1, 1, 0]) / np.sqrt(3), device=device) # Luma axis.
|
321 |
+
|
322 |
+
# Apply hue rotation with probability (hue * strength).
|
323 |
+
if self.hue > 0 and num_channels > 1:
|
324 |
+
theta = (torch.rand([batch_size], device=device) * 2 - 1) * np.pi * self.hue_max
|
325 |
+
theta = torch.where(torch.rand([batch_size], device=device) < self.hue * self.p, theta, torch.zeros_like(theta))
|
326 |
+
if debug_percentile is not None:
|
327 |
+
theta = torch.full_like(theta, (debug_percentile * 2 - 1) * np.pi * self.hue_max)
|
328 |
+
C = rotate3d(v, theta) @ C # Rotate around v.
|
329 |
+
|
330 |
+
# Apply saturation with probability (saturation * strength).
|
331 |
+
if self.saturation > 0 and num_channels > 1:
|
332 |
+
s = torch.exp2(torch.randn([batch_size, 1, 1], device=device) * self.saturation_std)
|
333 |
+
s = torch.where(torch.rand([batch_size, 1, 1], device=device) < self.saturation * self.p, s, torch.ones_like(s))
|
334 |
+
if debug_percentile is not None:
|
335 |
+
s = torch.full_like(s, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.saturation_std))
|
336 |
+
C = (v.ger(v) + (I_4 - v.ger(v)) * s) @ C
|
337 |
+
|
338 |
+
# ------------------------------
|
339 |
+
# Execute color transformations.
|
340 |
+
# ------------------------------
|
341 |
+
|
342 |
+
# Execute if the transform is not identity.
|
343 |
+
if C is not I_4:
|
344 |
+
images = images.reshape([batch_size, num_channels, height * width])
|
345 |
+
if num_channels == 3:
|
346 |
+
images = C[:, :3, :3] @ images + C[:, :3, 3:]
|
347 |
+
elif num_channels == 1:
|
348 |
+
C = C[:, :3, :].mean(dim=1, keepdims=True)
|
349 |
+
images = images * C[:, :, :3].sum(dim=2, keepdims=True) + C[:, :, 3:]
|
350 |
+
else:
|
351 |
+
raise ValueError('Image must be RGB (3 channels) or L (1 channel)')
|
352 |
+
images = images.reshape([batch_size, num_channels, height, width])
|
353 |
+
|
354 |
+
# ----------------------
|
355 |
+
# Image-space filtering.
|
356 |
+
# ----------------------
|
357 |
+
|
358 |
+
if self.imgfilter > 0:
|
359 |
+
num_bands = self.Hz_fbank.shape[0]
|
360 |
+
assert len(self.imgfilter_bands) == num_bands
|
361 |
+
expected_power = misc.constant(np.array([10, 1, 1, 1]) / 13, device=device) # Expected power spectrum (1/f).
|
362 |
+
|
363 |
+
# Apply amplification for each band with probability (imgfilter * strength * band_strength).
|
364 |
+
g = torch.ones([batch_size, num_bands], device=device) # Global gain vector (identity).
|
365 |
+
for i, band_strength in enumerate(self.imgfilter_bands):
|
366 |
+
t_i = torch.exp2(torch.randn([batch_size], device=device) * self.imgfilter_std)
|
367 |
+
t_i = torch.where(torch.rand([batch_size], device=device) < self.imgfilter * self.p * band_strength, t_i, torch.ones_like(t_i))
|
368 |
+
if debug_percentile is not None:
|
369 |
+
t_i = torch.full_like(t_i, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.imgfilter_std)) if band_strength > 0 else torch.ones_like(t_i)
|
370 |
+
t = torch.ones([batch_size, num_bands], device=device) # Temporary gain vector.
|
371 |
+
t[:, i] = t_i # Replace i'th element.
|
372 |
+
t = t / (expected_power * t.square()).sum(dim=-1, keepdims=True).sqrt() # Normalize power.
|
373 |
+
g = g * t # Accumulate into global gain.
|
374 |
+
|
375 |
+
# Construct combined amplification filter.
|
376 |
+
Hz_prime = g @ self.Hz_fbank # [batch, tap]
|
377 |
+
Hz_prime = Hz_prime.unsqueeze(1).repeat([1, num_channels, 1]) # [batch, channels, tap]
|
378 |
+
Hz_prime = Hz_prime.reshape([batch_size * num_channels, 1, -1]) # [batch * channels, 1, tap]
|
379 |
+
|
380 |
+
# Apply filter.
|
381 |
+
p = self.Hz_fbank.shape[1] // 2
|
382 |
+
images = images.reshape([1, batch_size * num_channels, height, width])
|
383 |
+
images = torch.nn.functional.pad(input=images, pad=[p,p,p,p], mode='reflect')
|
384 |
+
images = conv2d_gradfix.conv2d(input=images, weight=Hz_prime.unsqueeze(2), groups=batch_size*num_channels)
|
385 |
+
images = conv2d_gradfix.conv2d(input=images, weight=Hz_prime.unsqueeze(3), groups=batch_size*num_channels)
|
386 |
+
images = images.reshape([batch_size, num_channels, height, width])
|
387 |
+
|
388 |
+
# ------------------------
|
389 |
+
# Image-space corruptions.
|
390 |
+
# ------------------------
|
391 |
+
batch["img"] = images
|
392 |
+
batch["vertices"] = batch["vertices"].long()
|
393 |
+
batch["border"] = 1 - batch["E_mask"] - batch["mask"]
|
394 |
+
return batch
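The forward pass above augments the whole batch dict consistently: the image goes through the filtered geometric path, while mask, E_mask and vertices follow with nearest-neighbour resampling, and the color matrix C is applied to the image only. A hedged usage sketch with made-up shapes and probability multipliers (not taken from any config in this repo):

# Illustrative sketch: running the augmentation pipe on a dummy batch.
pipe = StyleGANAugmentPipe(xint=1, scale=1, rotate=1, aniso=1, xfrac=1,
                           brightness=1, contrast=1, hue=1, saturation=1)
pipe.p.fill_(0.5)  # overall augmentation probability, as in ADA
batch = dict(
    img=torch.randn(4, 3, 128, 128),
    mask=torch.ones(4, 1, 128, 128),
    E_mask=torch.zeros(4, 1, 128, 128),
    vertices=torch.zeros(4, 1, 128, 128, dtype=torch.long),
)
out = pipe(batch)  # same keys as the input, plus "border"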
|
deep_privacy/dp2/data/transforms/transforms.py
ADDED
@@ -0,0 +1,277 @@
1 |
+
from pathlib import Path
|
2 |
+
from typing import Dict, List
|
3 |
+
import torchvision
|
4 |
+
import torch
|
5 |
+
import tops
|
6 |
+
import torchvision.transforms.functional as F
|
7 |
+
from .functional import hflip
|
8 |
+
import numpy as np
|
9 |
+
from dp2.utils.vis_utils import get_coco_keypoints
|
10 |
+
from PIL import Image, ImageDraw
|
11 |
+
from typing import Tuple
|
12 |
+
|
13 |
+
|
14 |
+
class RandomHorizontalFlip(torch.nn.Module):
|
15 |
+
|
16 |
+
def __init__(self, p: float, flip_map=None, **kwargs):
|
17 |
+
super().__init__()
|
18 |
+
self.flip_ratio = p
|
19 |
+
self.flip_map = flip_map
|
20 |
+
if self.flip_ratio is None:
|
21 |
+
self.flip_ratio = 0.5
|
22 |
+
assert 0 <= self.flip_ratio <= 1
|
23 |
+
|
24 |
+
def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
25 |
+
if torch.rand(1) > self.flip_ratio:
|
26 |
+
return container
|
27 |
+
return hflip(container, self.flip_map)
|
28 |
+
|
29 |
+
|
30 |
+
class CenterCrop(torch.nn.Module):
|
31 |
+
"""
|
32 |
+
Performs the transform on the image.
|
33 |
+
NOTE: Does not transform the mask to improve runtime.
|
34 |
+
"""
|
35 |
+
|
36 |
+
def __init__(self, size: List[int]):
|
37 |
+
super().__init__()
|
38 |
+
self.size = tuple(size)
|
39 |
+
|
40 |
+
def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
41 |
+
min_size = min(container["img"].shape[1], container["img"].shape[2])
|
42 |
+
if min_size < self.size[0]:
|
43 |
+
container["img"] = F.center_crop(container["img"], min_size)
|
44 |
+
container["img"] = F.resize(container["img"], self.size)
|
45 |
+
return container
|
46 |
+
container["img"] = F.center_crop(container["img"], self.size)
|
47 |
+
return container
|
48 |
+
|
49 |
+
|
50 |
+
class Resize(torch.nn.Module):
|
51 |
+
"""
|
52 |
+
Performs the transform on the image.
|
53 |
+
NOTE: Does not transform the mask to improve runtime.
|
54 |
+
"""
|
55 |
+
|
56 |
+
def __init__(self, size, interpolation=F.InterpolationMode.BILINEAR):
|
57 |
+
super().__init__()
|
58 |
+
self.size = tuple(size)
|
59 |
+
self.interpolation = interpolation
|
60 |
+
|
61 |
+
def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
62 |
+
container["img"] = F.resize(container["img"], self.size, self.interpolation, antialias=True)
|
63 |
+
if "semantic_mask" in container:
|
64 |
+
container["semantic_mask"] = F.resize(
|
65 |
+
container["semantic_mask"], self.size, F.InterpolationMode.NEAREST)
|
66 |
+
if "embedding" in container:
|
67 |
+
container["embedding"] = F.resize(
|
68 |
+
container["embedding"], self.size, self.interpolation)
|
69 |
+
if "mask" in container:
|
70 |
+
container["mask"] = F.resize(
|
71 |
+
container["mask"], self.size, F.InterpolationMode.NEAREST)
|
72 |
+
if "E_mask" in container:
|
73 |
+
container["E_mask"] = F.resize(
|
74 |
+
container["E_mask"], self.size, F.InterpolationMode.NEAREST)
|
75 |
+
if "maskrcnn_mask" in container:
|
76 |
+
container["maskrcnn_mask"] = F.resize(
|
77 |
+
container["maskrcnn_mask"], self.size, F.InterpolationMode.NEAREST)
|
78 |
+
if "vertices" in container:
|
79 |
+
container["vertices"] = F.resize(
|
80 |
+
container["vertices"], self.size, F.InterpolationMode.NEAREST)
|
81 |
+
return container
|
82 |
+
|
83 |
+
def __repr__(self):
|
84 |
+
repr = super().__repr__()
|
85 |
+
vars_ = dict(size=self.size, interpolation=self.interpolation)
|
86 |
+
return repr + " " + " ".join([f"{k}: {v}" for k, v in vars_.items()])
|
87 |
+
|
88 |
+
|
89 |
+
class Normalize(torch.nn.Module):
|
90 |
+
"""
|
91 |
+
Performs the transform on the image.
|
92 |
+
NOTE: Does not transform the mask to improve runtime.
|
93 |
+
"""
|
94 |
+
|
95 |
+
def __init__(self, mean, std, inplace, keys=["img"]):
|
96 |
+
super().__init__()
|
97 |
+
self.mean = mean
|
98 |
+
self.std = std
|
99 |
+
self.inplace = inplace
|
100 |
+
self.keys = keys
|
101 |
+
|
102 |
+
def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
103 |
+
for key in self.keys:
|
104 |
+
container[key] = F.normalize(container[key], self.mean, self.std, self.inplace)
|
105 |
+
return container
|
106 |
+
|
107 |
+
def __repr__(self):
|
108 |
+
repr = super().__repr__()
|
109 |
+
vars_ = dict(mean=self.mean, std=self.std, inplace=self.inplace)
|
110 |
+
return repr + " " + " ".join([f"{k}: {v}" for k, v in vars_.items()])
|
111 |
+
|
112 |
+
|
113 |
+
class ToFloat(torch.nn.Module):
|
114 |
+
|
115 |
+
def __init__(self, keys=["img"], norm=True) -> None:
|
116 |
+
super().__init__()
|
117 |
+
self.keys = keys
|
118 |
+
self.gain = 255 if norm else 1
|
119 |
+
|
120 |
+
def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
121 |
+
for key in self.keys:
|
122 |
+
container[key] = container[key].float() / self.gain
|
123 |
+
return container
|
124 |
+
|
125 |
+
|
126 |
+
class RandomCrop(torchvision.transforms.RandomCrop):
|
127 |
+
"""
|
128 |
+
Performs the transform on the image.
|
129 |
+
NOTE: Does not transform the mask to improve runtime.
|
130 |
+
"""
|
131 |
+
|
132 |
+
def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
133 |
+
container["img"] = super().forward(container["img"])
|
134 |
+
return container
|
135 |
+
|
136 |
+
|
137 |
+
class CreateCondition(torch.nn.Module):
|
138 |
+
|
139 |
+
def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
140 |
+
if container["img"].dtype == torch.uint8:
|
141 |
+
container["condition"] = container["img"] * container["mask"].byte() + (1-container["mask"].byte()) * 127
|
142 |
+
return container
|
143 |
+
container["condition"] = container["img"] * container["mask"]
|
144 |
+
return container
|
145 |
+
|
146 |
+
|
147 |
+
class CreateEmbedding(torch.nn.Module):
|
148 |
+
|
149 |
+
def __init__(self, embed_path: Path, cuda=True) -> None:
|
150 |
+
super().__init__()
|
151 |
+
self.embed_map = torch.load(embed_path, map_location=torch.device("cpu"))
|
152 |
+
if cuda:
|
153 |
+
self.embed_map = tops.to_cuda(self.embed_map)
|
154 |
+
|
155 |
+
def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
156 |
+
vertices = container["vertices"]
|
157 |
+
if vertices.ndim == 3:
|
158 |
+
embedding = self.embed_map[vertices.long()].squeeze(dim=0)
|
159 |
+
embedding = embedding.permute(2, 0, 1) * container["E_mask"]
|
160 |
+
pass
|
161 |
+
else:
|
162 |
+
assert vertices.ndim == 4
|
163 |
+
embedding = self.embed_map[vertices.long()].squeeze(dim=1)
|
164 |
+
embedding = embedding.permute(0, 3, 1, 2) * container["E_mask"]
|
165 |
+
container["embedding"] = embedding
|
166 |
+
container["embed_map"] = self.embed_map.clone()
|
167 |
+
return container
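CreateEmbedding turns the per-pixel CSE vertex indices into continuous embeddings by a simple table lookup, zeroing everything outside E_mask. A minimal sketch of that lookup with hypothetical sizes (the real embed_map is loaded from embed_path):

# Illustrative sketch with made-up sizes: vertex indices -> per-pixel embeddings.
embed_map = torch.randn(27554, 16)                 # hypothetical (num_vertices, embed_dim) table
vertices = torch.randint(0, 27554, (1, 64, 64))    # per-pixel vertex ids
E_mask = torch.ones(1, 64, 64)
embedding = embed_map[vertices.long()].squeeze(dim=0).permute(2, 0, 1) * E_mask  # (16, 64, 64)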
|
168 |
+
|
169 |
+
|
170 |
+
class InsertJointMap(torch.nn.Module):
|
171 |
+
|
172 |
+
def __init__(self, imsize: Tuple) -> None:
|
173 |
+
super().__init__()
|
174 |
+
self.imsize = imsize
|
175 |
+
knames = get_coco_keypoints()[0]
|
176 |
+
knames = knames + ["neck", "mid_hip"]
|
177 |
+
connectivity = {
|
178 |
+
"nose": ["left_eye", "right_eye", "neck"],
|
179 |
+
"left_eye": ["right_eye", "left_ear"],
|
180 |
+
"right_eye": ["right_ear"],
|
181 |
+
"left_shoulder": ["right_shoulder", "left_elbow", "left_hip"],
|
182 |
+
"right_shoulder": ["right_elbow", "right_hip"],
|
183 |
+
"left_elbow": ["left_wrist"],
|
184 |
+
"right_elbow": ["right_wrist"],
|
185 |
+
"left_hip": ["right_hip", "left_knee"],
|
186 |
+
"right_hip": ["right_knee"],
|
187 |
+
"left_knee": ["left_ankle"],
|
188 |
+
"right_knee": ["right_ankle"],
|
189 |
+
"neck": ["mid_hip", "nose"],
|
190 |
+
}
|
191 |
+
category = {
|
192 |
+
("nose", "left_eye"): 0, # head
|
193 |
+
("nose", "right_eye"): 0, # head
|
194 |
+
("nose", "neck"): 0, # head
|
195 |
+
("left_eye", "right_eye"): 0, # head
|
196 |
+
("left_eye", "left_ear"): 0, # head
|
197 |
+
("right_eye", "right_ear"): 0, # head
|
198 |
+
("left_shoulder", "left_elbow"): 1, # left arm
|
199 |
+
("left_elbow", "left_wrist"): 1, # left arm
|
200 |
+
("right_shoulder", "right_elbow"): 2, # right arm
|
201 |
+
("right_elbow", "right_wrist"): 2, # right arm
|
202 |
+
("left_shoulder", "right_shoulder"): 3, # body
|
203 |
+
("left_shoulder", "left_hip"): 3, # body
|
204 |
+
("right_shoulder", "right_hip"): 3, # body
|
205 |
+
("left_hip", "right_hip"): 3, # body
|
206 |
+
("left_hip", "left_knee"): 4, # left leg
|
207 |
+
("left_knee", "left_ankle"): 4, # left leg
|
208 |
+
("right_hip", "right_knee"): 5, # right leg
|
209 |
+
("right_knee", "right_ankle"): 5, # right leg
|
210 |
+
("neck", "mid_hip"): 3, # body
|
211 |
+
("neck", "nose"): 0, # head
|
212 |
+
}
|
213 |
+
self.indices2category = {
|
214 |
+
tuple([knames.index(n) for n in k]): v for k, v in category.items()
|
215 |
+
}
|
216 |
+
self.connectivity_indices = {
|
217 |
+
knames.index(k): [knames.index(v_) for v_ in v]
|
218 |
+
for k, v in connectivity.items()
|
219 |
+
}
|
220 |
+
self.l_shoulder = knames.index("left_shoulder")
|
221 |
+
self.r_shoulder = knames.index("right_shoulder")
|
222 |
+
self.l_hip = knames.index("left_hip")
|
223 |
+
self.r_hip = knames.index("right_hip")
|
224 |
+
self.l_eye = knames.index("left_eye")
|
225 |
+
self.r_eye = knames.index("right_eye")
|
226 |
+
self.nose = knames.index("nose")
|
227 |
+
self.neck = knames.index("neck")
|
228 |
+
|
229 |
+
def create_joint_map(self, N, H, W, keypoints):
|
230 |
+
joint_maps = np.zeros((N, H, W), dtype=np.uint8)
|
231 |
+
for bidx, keypoints in enumerate(keypoints):
|
232 |
+
assert keypoints.shape == (17, 3), keypoints.shape
|
233 |
+
keypoints = torch.cat((keypoints, torch.zeros(2, 3)))
|
234 |
+
visible = keypoints[:, -1] > 0
|
235 |
+
|
236 |
+
if visible[self.l_shoulder] and visible[self.r_shoulder]:
|
237 |
+
neck = (keypoints[self.l_shoulder]
|
238 |
+
+ (keypoints[self.r_shoulder] - keypoints[self.l_shoulder]) / 2)
|
239 |
+
keypoints[-2] = neck
|
240 |
+
visible[-2] = 1
|
241 |
+
if visible[self.l_hip] and visible[self.r_hip]:
|
242 |
+
mhip = (keypoints[self.l_hip]
|
243 |
+
+ (keypoints[self.r_hip] - keypoints[self.l_hip]) / 2
|
244 |
+
)
|
245 |
+
keypoints[-1] = mhip
|
246 |
+
visible[-1] = 1
|
247 |
+
|
248 |
+
keypoints[:, 0] *= W
|
249 |
+
keypoints[:, 1] *= H
|
250 |
+
joint_map = Image.fromarray(np.zeros((H, W), dtype=np.uint8))
|
251 |
+
draw = ImageDraw.Draw(joint_map)
|
252 |
+
for fidx in self.connectivity_indices.keys():
|
253 |
+
for tidx in self.connectivity_indices[fidx]:
|
254 |
+
if visible[fidx] == 0 or visible[tidx] == 0:
|
255 |
+
continue
|
256 |
+
c = self.indices2category[(fidx, tidx)]
|
257 |
+
s = tuple(keypoints[fidx, :2].round().long().numpy().tolist())
|
258 |
+
e = tuple(keypoints[tidx, :2].round().long().numpy().tolist())
|
259 |
+
draw.line((s, e), width=1, fill=c + 1)
|
260 |
+
if visible[self.nose] == 0 and visible[self.neck] == 1:
|
261 |
+
m_eye = (
|
262 |
+
keypoints[self.l_eye]
|
263 |
+
+ (keypoints[self.r_eye] - keypoints[self.l_eye]) / 2
|
264 |
+
)
|
265 |
+
s = tuple(m_eye[:2].round().long().numpy().tolist())
|
266 |
+
e = tuple(keypoints[self.neck, :2].round().long().numpy().tolist())
|
267 |
+
c = self.indices2category[(self.nose, self.neck)]
|
268 |
+
draw.line((s, e), width=1, fill=c + 1)
|
269 |
+
joint_map = np.array(joint_map)
|
270 |
+
|
271 |
+
joint_maps[bidx] = np.array(joint_map)
|
272 |
+
return joint_maps[:, None]
|
273 |
+
|
274 |
+
def forward(self, batch):
|
275 |
+
batch["joint_map"] = torch.from_numpy(self.create_joint_map(
|
276 |
+
batch["img"].shape[0], *self.imsize, batch["keypoints"]))
|
277 |
+
return batch
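All transforms above operate on a dict ("container") instead of a bare tensor, so they can be chained with torchvision's Compose while extra keys such as masks and vertices are resized consistently. A hedged sketch of such a chain; the sizes, normalization constants and keys are illustrative, not taken from a config:

# Illustrative sketch: chaining the dict-based transforms with torchvision Compose.
import torchvision.transforms as T

transform = T.Compose([
    ToFloat(keys=["img"]),                                    # uint8 [0,255] -> float [0,1]
    Resize(size=(288, 160)),                                  # also resizes "mask" (nearest)
    Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
    CreateCondition(),                                        # adds "condition" = img * mask
])
container = dict(
    img=torch.randint(0, 256, (3, 512, 512), dtype=torch.uint8),
    mask=torch.ones(1, 512, 512),
)
out = transform(container)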
|
deep_privacy/dp2/data/utils.py
ADDED
@@ -0,0 +1,122 @@
1 |
+
import torch
|
2 |
+
from PIL import Image
|
3 |
+
import numpy as np
|
4 |
+
import multiprocessing
|
5 |
+
import io
|
6 |
+
from tops import logger
|
7 |
+
from torch.utils.data._utils.collate import default_collate
|
8 |
+
|
9 |
+
try:
|
10 |
+
import pyspng
|
11 |
+
|
12 |
+
PYSPNG_IMPORTED = True
|
13 |
+
except ImportError:
|
14 |
+
PYSPNG_IMPORTED = False
|
15 |
+
print("Could not load pyspng. Defaulting to pillow image backend.")
|
16 |
+
from PIL import Image
|
17 |
+
|
18 |
+
|
19 |
+
def get_fdf_keypoints():
|
20 |
+
return get_coco_keypoints()[:7]
|
21 |
+
|
22 |
+
|
23 |
+
def get_fdf_flipmap():
|
24 |
+
keypoints = get_fdf_keypoints()
|
25 |
+
keypoint_flip_map = {
|
26 |
+
"left_eye": "right_eye",
|
27 |
+
"left_ear": "right_ear",
|
28 |
+
"left_shoulder": "right_shoulder",
|
29 |
+
}
|
30 |
+
for key, value in list(keypoint_flip_map.items()):
|
31 |
+
keypoint_flip_map[value] = key
|
32 |
+
keypoint_flip_map["nose"] = "nose"
|
33 |
+
keypoint_flip_map_idx = []
|
34 |
+
for source in keypoints:
|
35 |
+
keypoint_flip_map_idx.append(keypoints.index(keypoint_flip_map[source]))
|
36 |
+
return keypoint_flip_map_idx
|
37 |
+
|
38 |
+
|
39 |
+
def get_coco_keypoints():
|
40 |
+
return [
|
41 |
+
"nose",
|
42 |
+
"left_eye",
|
43 |
+
"right_eye", # 2
|
44 |
+
"left_ear",
|
45 |
+
"right_ear", # 4
|
46 |
+
"left_shoulder",
|
47 |
+
"right_shoulder", # 6
|
48 |
+
"left_elbow",
|
49 |
+
"right_elbow", # 8
|
50 |
+
"left_wrist",
|
51 |
+
"right_wrist", # 10
|
52 |
+
"left_hip",
|
53 |
+
"right_hip", # 12
|
54 |
+
"left_knee",
|
55 |
+
"right_knee", # 14
|
56 |
+
"left_ankle",
|
57 |
+
"right_ankle", # 16
|
58 |
+
]
|
59 |
+
|
60 |
+
|
61 |
+
def get_coco_flipmap():
|
62 |
+
keypoints = get_coco_keypoints()
|
63 |
+
keypoint_flip_map = {
|
64 |
+
"left_eye": "right_eye",
|
65 |
+
"left_ear": "right_ear",
|
66 |
+
"left_shoulder": "right_shoulder",
|
67 |
+
"left_elbow": "right_elbow",
|
68 |
+
"left_wrist": "right_wrist",
|
69 |
+
"left_hip": "right_hip",
|
70 |
+
"left_knee": "right_knee",
|
71 |
+
"left_ankle": "right_ankle",
|
72 |
+
}
|
73 |
+
for key, value in list(keypoint_flip_map.items()):
|
74 |
+
keypoint_flip_map[value] = key
|
75 |
+
keypoint_flip_map["nose"] = "nose"
|
76 |
+
keypoint_flip_map_idx = []
|
77 |
+
for source in keypoints:
|
78 |
+
keypoint_flip_map_idx.append(keypoints.index(keypoint_flip_map[source]))
|
79 |
+
return keypoint_flip_map_idx
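The flip maps are plain index lists: flipping an image horizontally then only requires re-ordering the keypoint rows and mirroring the x coordinate. A small sketch, assuming keypoints are stored as (x, y, visibility) rows with x normalized to [0, 1]:

# Illustrative sketch: applying the COCO flip map to a [17, 3] keypoint tensor.
flip_map = get_coco_flipmap()
keypoints = torch.rand(17, 3)              # (x, y, visibility), x/y in [0, 1]
flipped = keypoints[flip_map].clone()      # swap left/right joints
flipped[:, 0] = 1.0 - flipped[:, 0]        # mirror the x coordinate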
|
80 |
+
|
81 |
+
|
82 |
+
def mask_decoder(x):
|
83 |
+
mask = torch.from_numpy(np.array(Image.open(io.BytesIO(x)))).squeeze()[None]
|
84 |
+
mask = mask > 0 # This fixes a bug causing mask.float().max() == 255.
|
85 |
+
return mask
|
86 |
+
|
87 |
+
|
88 |
+
def png_decoder(x):
|
89 |
+
if PYSPNG_IMPORTED:
|
90 |
+
return torch.from_numpy(np.rollaxis(pyspng.load(x), 2))
|
91 |
+
with Image.open(io.BytesIO(x)) as im:
|
92 |
+
im = torch.from_numpy(np.rollaxis(np.array(im.convert("RGB")), 2))
|
93 |
+
return im
|
94 |
+
|
95 |
+
|
96 |
+
def jpg_decoder(x):
|
97 |
+
with Image.open(io.BytesIO(x)) as im:
|
98 |
+
im = torch.from_numpy(np.rollaxis(np.array(im.convert("RGB")), 2))
|
99 |
+
return im
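The decoders take raw image bytes (e.g. as read from a dataset shard or a file) and return CHW uint8 tensors, preferring pyspng for PNGs when it is available. A hedged sketch with a hypothetical file path:

# Illustrative sketch: decoding raw image bytes into a CHW uint8 tensor.
with open("example.png", "rb") as fp:      # hypothetical file
    img = png_decoder(fp.read())           # shape (3, H, W), dtype uint8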
|
100 |
+
|
101 |
+
|
102 |
+
def get_num_workers(num_workers: int):
|
103 |
+
n_cpus = multiprocessing.cpu_count()
|
104 |
+
if num_workers > n_cpus:
|
105 |
+
logger.warn(f"Setting the number of workers to match cpu count: {n_cpus}")
|
106 |
+
return n_cpus
|
107 |
+
return num_workers
|
108 |
+
|
109 |
+
|
110 |
+
def collate_fn(batch):
|
111 |
+
elem = batch[0]
|
112 |
+
ignore_keys = set(["embed_map", "vertx2cat"])
|
113 |
+
batch_ = {
|
114 |
+
key: default_collate([d[key] for d in batch])
|
115 |
+
for key in elem
|
116 |
+
if key not in ignore_keys
|
117 |
+
}
|
118 |
+
if "embed_map" in elem:
|
119 |
+
batch_["embed_map"] = elem["embed_map"]
|
120 |
+
if "vertx2cat" in elem:
|
121 |
+
batch_["vertx2cat"] = elem["vertx2cat"]
|
122 |
+
return batch_
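collate_fn falls back to default_collate for per-sample tensors but passes embed_map and vertx2cat through un-batched, since every sample carries the same shared lookup table. A hedged sketch of wiring it into a DataLoader; `dataset` stands for any of the dict-returning datasets in dp2.data.datasets:

# Illustrative sketch: using the custom collate function with a DataLoader.
from torch.utils.data import DataLoader

loader = DataLoader(dataset, batch_size=8, num_workers=get_num_workers(4),
                    collate_fn=collate_fn, shuffle=True)
for batch in loader:
    break  # batch["img"] is stacked; batch["embed_map"], if present, stays un-batched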
|
deep_privacy/dp2/detection/__init__.py
ADDED
@@ -0,0 +1,3 @@
1 |
+
from .cse_mask_face_detector import CSeMaskFaceDetector
|
2 |
+
from .person_detector import CSEPersonDetector
|
3 |
+
from .structures import PersonDetection, VehicleDetection, FaceDetection
|
deep_privacy/dp2/detection/base.py
ADDED
@@ -0,0 +1,42 @@
1 |
+
import pickle
|
2 |
+
import torch
|
3 |
+
import lzma
|
4 |
+
from pathlib import Path
|
5 |
+
from tops import logger
|
6 |
+
|
7 |
+
|
8 |
+
class BaseDetector:
|
9 |
+
|
10 |
+
def __init__(self, cache_directory: str) -> None:
|
11 |
+
if cache_directory is not None:
|
12 |
+
self.cache_directory = Path(cache_directory, str(self.__class__.__name__))
|
13 |
+
self.cache_directory.mkdir(exist_ok=True, parents=True)
|
14 |
+
|
15 |
+
def save_to_cache(self, detection, cache_path: Path, after_preprocess=True):
|
16 |
+
logger.log(f"Caching detection to: {cache_path}")
|
17 |
+
with lzma.open(cache_path, "wb") as fp:
|
18 |
+
torch.save(
|
19 |
+
[det.state_dict(after_preprocess=after_preprocess) for det in detection], fp,
|
20 |
+
pickle_protocol=pickle.HIGHEST_PROTOCOL)
|
21 |
+
|
22 |
+
def load_from_cache(self, cache_path: Path):
|
23 |
+
logger.log(f"Loading detection from cache path: {cache_path}")
|
24 |
+
with lzma.open(cache_path, "rb") as fp:
|
25 |
+
state_dict = torch.load(fp)
|
26 |
+
return [
|
27 |
+
state["cls"].from_state_dict(state_dict=state) for state in state_dict
|
28 |
+
]
|
29 |
+
|
30 |
+
def forward_and_cache(self, im: torch.Tensor, cache_id: str, load_cache: bool):
|
31 |
+
if cache_id is None:
|
32 |
+
return self.forward(im)
|
33 |
+
cache_path = self.cache_directory.joinpath(cache_id + ".torch")
|
34 |
+
if cache_path.is_file() and load_cache:
|
35 |
+
try:
|
36 |
+
return self.load_from_cache(cache_path)
|
37 |
+
except Exception as e:
|
38 |
+
logger.warn(f"The cache file was corrupted: {cache_path}")
|
39 |
+
exit()
|
40 |
+
detections = self.forward(im)
|
41 |
+
self.save_to_cache(detections, cache_path)
|
42 |
+
return detections
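BaseDetector only implements the caching plumbing; subclasses provide forward() and return detection objects that expose state_dict()/from_state_dict() so they can be serialized. A hedged sketch of that contract with a do-nothing detector:

# Illustrative sketch: the minimal contract a subclass has to satisfy.
class DummyDetector(BaseDetector):

    def __init__(self, cache_directory=None):
        super().__init__(cache_directory)

    @torch.no_grad()
    def forward(self, im: torch.Tensor):
        # A real detector would run its model here and wrap the results in
        # detection objects providing state_dict()/from_state_dict().
        return []

detector = DummyDetector("detection_cache")
detections = detector.forward_and_cache(
    im=torch.zeros(3, 256, 256, dtype=torch.uint8), cache_id="frame_0000", load_cache=True)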
|
deep_privacy/dp2/detection/box_utils.py
ADDED
@@ -0,0 +1,104 @@
1 |
+
import numpy as np
|
2 |
+
|
3 |
+
|
4 |
+
def expand_bbox_to_ratio(bbox, imshape, target_aspect_ratio):
|
5 |
+
x0, y0, x1, y1 = [int(_) for _ in bbox]
|
6 |
+
h, w = y1 - y0, x1 - x0
|
7 |
+
cur_ratio = h / w
|
8 |
+
|
9 |
+
if cur_ratio == target_aspect_ratio:
|
10 |
+
return [x0, y0, x1, y1]
|
11 |
+
if cur_ratio < target_aspect_ratio:
|
12 |
+
target_height = int(w*target_aspect_ratio)
|
13 |
+
y0, y1 = expand_axis(y0, y1, target_height, imshape[0])
|
14 |
+
else:
|
15 |
+
target_width = int(h/target_aspect_ratio)
|
16 |
+
x0, x1 = expand_axis(x0, x1, target_width, imshape[1])
|
17 |
+
return x0, y0, x1, y1
|
18 |
+
|
19 |
+
|
20 |
+
def expand_axis(start, end, target_width, limit):
|
21 |
+
# Can return a bbox outside of limit
|
22 |
+
cur_width = end - start
|
23 |
+
start = start - (target_width-cur_width)//2
|
24 |
+
end = end + (target_width-cur_width)//2
|
25 |
+
if end - start != target_width:
|
26 |
+
end += 1
|
27 |
+
assert end - start == target_width
|
28 |
+
if start < 0 and end > limit:
|
29 |
+
return start, end
|
30 |
+
if start < 0 and end < limit:
|
31 |
+
to_shift = min(0 - start, limit - end)
|
32 |
+
start += to_shift
|
33 |
+
end += to_shift
|
34 |
+
if end > limit and start > 0:
|
35 |
+
to_shift = min(end - limit, start)
|
36 |
+
end -= to_shift
|
37 |
+
start -= to_shift
|
38 |
+
assert end - start == target_width
|
39 |
+
return start, end
|
40 |
+
|
41 |
+
|
42 |
+
def expand_box(bbox, imshape, mask, percentage_background: float):
|
43 |
+
assert isinstance(bbox[0], int)
|
44 |
+
assert 0 < percentage_background < 1
|
45 |
+
# Fraction of the box covered by the segmentation mask.
|
46 |
+
mask_pixels = mask.long().sum().cpu()
|
47 |
+
total_pixels = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
|
48 |
+
percentage_mask = mask_pixels / total_pixels
|
49 |
+
if (1 - percentage_mask) > percentage_background:
|
50 |
+
return bbox
|
51 |
+
target_pixels = mask_pixels / (1 - percentage_background)
|
52 |
+
x0, y0, x1, y1 = bbox
|
53 |
+
H = y1 - y0
|
54 |
+
W = x1 - x0
|
55 |
+
p = np.sqrt(target_pixels/(H*W))
|
56 |
+
target_width = int(np.ceil(p * W))
|
57 |
+
target_height = int(np.ceil(p * H))
|
58 |
+
x0, x1 = expand_axis(x0, x1, target_width, imshape[1])
|
59 |
+
y0, y1 = expand_axis(y0, y1, target_height, imshape[0])
|
60 |
+
return [x0, y0, x1, y1]
|
61 |
+
|
62 |
+
|
63 |
+
def expand_axises_by_percentage(bbox_XYXY, imshape, percentage):
|
64 |
+
x0, y0, x1, y1 = bbox_XYXY
|
65 |
+
H = y1 - y0
|
66 |
+
W = x1 - x0
|
67 |
+
expansion = int(((H*W)**0.5) * percentage)
|
68 |
+
new_width = W + expansion
|
69 |
+
new_height = H + expansion
|
70 |
+
x0, x1 = expand_axis(x0, x1, min(new_width, imshape[1]), imshape[1])
|
71 |
+
y0, y1 = expand_axis(y0, y1, min(new_height, imshape[0]), imshape[0])
|
72 |
+
return [x0, y0, x1, y1]
|
73 |
+
|
74 |
+
|
75 |
+
def get_expanded_bbox(
|
76 |
+
bbox_XYXY,
|
77 |
+
imshape,
|
78 |
+
mask,
|
79 |
+
percentage_background: float,
|
80 |
+
axis_minimum_expansion: float,
|
81 |
+
target_aspect_ratio: float):
|
82 |
+
bbox_XYXY = bbox_XYXY.long().cpu().numpy().tolist()
|
83 |
+
# Expand each axis of the bounding box by a minimum percentage
|
84 |
+
bbox_XYXY = expand_axises_by_percentage(bbox_XYXY, imshape, axis_minimum_expansion)
|
85 |
+
# Find the minimum bbox with the aspect ratio. Can be outside of imshape
|
86 |
+
bbox_XYXY = expand_bbox_to_ratio(bbox_XYXY, imshape, target_aspect_ratio)
|
87 |
+
# Expands square box such that X% of the bbox is background
|
88 |
+
bbox_XYXY = expand_box(bbox_XYXY, imshape, mask, percentage_background)
|
89 |
+
assert isinstance(bbox_XYXY[0], (int, np.int64))
|
90 |
+
return bbox_XYXY
|
91 |
+
|
92 |
+
|
93 |
+
def include_box(bbox, minimum_area, aspect_ratio_range, min_bbox_ratio_inside, imshape):
|
94 |
+
def area_inside_ratio(bbox, imshape):
|
95 |
+
area = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
|
96 |
+
area_inside = (min(bbox[2], imshape[1]) - max(0, bbox[0])) * (min(imshape[0], bbox[3]) - max(0, bbox[1]))
|
97 |
+
return area_inside / area
|
98 |
+
ratio = (bbox[3] - bbox[1]) / (bbox[2] - bbox[0])
|
99 |
+
area = (bbox[3] - bbox[1]) * (bbox[2] - bbox[0])
|
100 |
+
if area_inside_ratio(bbox, imshape) < min_bbox_ratio_inside:
|
101 |
+
return False
|
102 |
+
if ratio <= aspect_ratio_range[0] or ratio >= aspect_ratio_range[1] or area < minimum_area:
|
103 |
+
return False
|
104 |
+
return True
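get_expanded_bbox grows a detected box in three stages (a minimum per-axis expansion, padding to the target aspect ratio, then expansion until a given fraction of the box is background), while include_box filters boxes that are too small, too elongated or mostly outside the frame. A hedged end-to-end sketch with made-up thresholds and shapes:

# Illustrative sketch: filtering and expanding a person box (values are made up).
import torch

imshape = (720, 1280, 3)
bbox = torch.tensor([120., 40., 220., 360.])      # x0, y0, x1, y1
mask = torch.zeros(720, 1280, dtype=torch.bool)
mask[40:360, 120:220] = True                      # segmentation inside the box

if include_box(bbox, minimum_area=32 * 32, aspect_ratio_range=(0, 99999),
               min_bbox_ratio_inside=0, imshape=imshape):
    expanded = get_expanded_bbox(bbox, imshape, mask, percentage_background=0.3,
                                 axis_minimum_expansion=0.1, target_aspect_ratio=288 / 160)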
|
deep_privacy/dp2/detection/box_utils_fdf.py
ADDED
@@ -0,0 +1,202 @@
1 |
+
"""
|
2 |
+
The FDF dataset expands bounding boxes differently from what is used for CSE.
|
3 |
+
"""
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
|
8 |
+
def quadratic_bounding_box(x0, y0, width, height, imshape):
|
9 |
+
# We assume that we can create an image that is quadratic without
|
10 |
+
# minimizing any of the sides
|
11 |
+
assert width <= min(imshape[:2])
|
12 |
+
assert height <= min(imshape[:2])
|
13 |
+
min_side = min(height, width)
|
14 |
+
if height != width:
|
15 |
+
side_diff = abs(height - width)
|
16 |
+
# Want to extend the shortest side
|
17 |
+
if min_side == height:
|
18 |
+
# Vertical side
|
19 |
+
height += side_diff
|
20 |
+
if height > imshape[0]:
|
21 |
+
# Take full frame, and shrink width
|
22 |
+
y0 = 0
|
23 |
+
height = imshape[0]
|
24 |
+
|
25 |
+
side_diff = abs(height - width)
|
26 |
+
width -= side_diff
|
27 |
+
x0 += side_diff // 2
|
28 |
+
else:
|
29 |
+
y0 -= side_diff // 2
|
30 |
+
y0 = max(0, y0)
|
31 |
+
else:
|
32 |
+
# Horizontal side
|
33 |
+
width += side_diff
|
34 |
+
if width > imshape[1]:
|
35 |
+
# Take full frame width, and shrink height
|
36 |
+
x0 = 0
|
37 |
+
width = imshape[1]
|
38 |
+
|
39 |
+
side_diff = abs(height - width)
|
40 |
+
height -= side_diff
|
41 |
+
y0 += side_diff // 2
|
42 |
+
else:
|
43 |
+
x0 -= side_diff // 2
|
44 |
+
x0 = max(0, x0)
|
45 |
+
# Check that bbox goes outside image
|
46 |
+
x1 = x0 + width
|
47 |
+
y1 = y0 + height
|
48 |
+
if imshape[1] < x1:
|
49 |
+
diff = x1 - imshape[1]
|
50 |
+
x0 -= diff
|
51 |
+
if imshape[0] < y1:
|
52 |
+
diff = y1 - imshape[0]
|
53 |
+
y0 -= diff
|
54 |
+
assert x0 >= 0, "Bounding box outside image."
|
55 |
+
assert y0 >= 0, "Bounding box outside image."
|
56 |
+
assert x0 + width <= imshape[1], "Bounding box outside image."
|
57 |
+
assert y0 + height <= imshape[0], "Bounding box outside image."
|
58 |
+
return x0, y0, width, height
|
59 |
+
|
60 |
+
|
61 |
+
def expand_bounding_box(bbox, percentage, imshape):
|
62 |
+
orig_bbox = bbox.copy()
|
63 |
+
x0, y0, x1, y1 = bbox
|
64 |
+
width = x1 - x0
|
65 |
+
height = y1 - y0
|
66 |
+
x0, y0, width, height = quadratic_bounding_box(
|
67 |
+
x0, y0, width, height, imshape)
|
68 |
+
expanding_factor = int(max(height, width) * percentage)
|
69 |
+
|
70 |
+
possible_max_expansion = [(imshape[0] - width) // 2,
|
71 |
+
(imshape[1] - height) // 2,
|
72 |
+
expanding_factor]
|
73 |
+
|
74 |
+
expanding_factor = min(possible_max_expansion)
|
75 |
+
# Expand height
|
76 |
+
|
77 |
+
if expanding_factor > 0:
|
78 |
+
|
79 |
+
y0 = y0 - expanding_factor
|
80 |
+
y0 = max(0, y0)
|
81 |
+
|
82 |
+
height += expanding_factor * 2
|
83 |
+
if height > imshape[0]:
|
84 |
+
y0 -= (imshape[0] - height)
|
85 |
+
height = imshape[0]
|
86 |
+
|
87 |
+
if height + y0 > imshape[0]:
|
88 |
+
y0 -= (height + y0 - imshape[0])
|
89 |
+
|
90 |
+
# Expand width
|
91 |
+
x0 = x0 - expanding_factor
|
92 |
+
x0 = max(0, x0)
|
93 |
+
|
94 |
+
width += expanding_factor * 2
|
95 |
+
if width > imshape[1]:
|
96 |
+
x0 -= (imshape[1] - width)
|
97 |
+
width = imshape[1]
|
98 |
+
|
99 |
+
if width + x0 > imshape[1]:
|
100 |
+
x0 -= (width + x0 - imshape[1])
|
101 |
+
y1 = y0 + height
|
102 |
+
x1 = x0 + width
|
103 |
+
assert y0 >= 0, "Y0 is negative"
|
104 |
+
assert height <= imshape[0], "Height is larger than image."
|
105 |
+
assert x0 + width <= imshape[1]
|
106 |
+
assert y0 + height <= imshape[0]
|
107 |
+
assert width == height, "HEIGHT IS NOT EQUAL WIDTH!!"
|
108 |
+
assert x0 >= 0, "X0 is negative"
|
109 |
+
assert width <= imshape[1], "Width is larger than image."
|
110 |
+
# Check that original bbox is within new
|
111 |
+
x0_o, y0_o, x1_o, y1_o = orig_bbox
|
112 |
+
assert x0 <= x0_o, f"New bbox is outside of original. O:{x0_o}, N: {x0}"
|
113 |
+
assert x1 >= x1_o, f"New bbox is outside of original. O:{x1_o}, N: {x1}"
|
114 |
+
assert y0 <= y0_o, f"New bbox is outside of original. O:{y0_o}, N: {y0}"
|
115 |
+
assert y1 >= y1_o, f"New bbox is outside of original. O:{y1_o}, N: {y1}"
|
116 |
+
|
117 |
+
x0, y0, width, height = [int(_) for _ in [x0, y0, width, height]]
|
118 |
+
x1 = x0 + width
|
119 |
+
y1 = y0 + height
|
120 |
+
return np.array([x0, y0, x1, y1])
|
121 |
+
|
122 |
+
|
123 |
+
def is_keypoint_within_bbox(x0, y0, x1, y1, keypoint):
|
124 |
+
keypoint = keypoint[:, :3] # only nose + eyes are relevant
|
125 |
+
kp_X = keypoint[0, :]
|
126 |
+
kp_Y = keypoint[1, :]
|
127 |
+
within_X = np.all(kp_X >= x0) and np.all(kp_X <= x1)
|
128 |
+
within_Y = np.all(kp_Y >= y0) and np.all(kp_Y <= y1)
|
129 |
+
return within_X and within_Y
|
130 |
+
|
131 |
+
|
132 |
+
def expand_bbox_simple(bbox, percentage):
|
133 |
+
x0, y0, x1, y1 = bbox.astype(float)
|
134 |
+
width = x1 - x0
|
135 |
+
height = y1 - y0
|
136 |
+
x_c = int(x0) + width // 2
|
137 |
+
y_c = int(y0) + height // 2
|
138 |
+
avg_size = max(width, height)
|
139 |
+
new_width = avg_size * (1 + percentage)
|
140 |
+
x0 = x_c - new_width // 2
|
141 |
+
y0 = y_c - new_width // 2
|
142 |
+
x1 = x_c + new_width // 2
|
143 |
+
y1 = y_c + new_width // 2
|
144 |
+
return np.array([x0, y0, x1, y1]).astype(int)
|
145 |
+
|
146 |
+
|
147 |
+
def pad_image(im, bbox, pad_value):
|
148 |
+
x0, y0, x1, y1 = bbox
|
149 |
+
if x0 < 0:
|
150 |
+
pad_im = np.zeros((im.shape[0], abs(x0), im.shape[2]),
|
151 |
+
dtype=np.uint8) + pad_value
|
152 |
+
im = np.concatenate((pad_im, im), axis=1)
|
153 |
+
x1 += abs(x0)
|
154 |
+
x0 = 0
|
155 |
+
if y0 < 0:
|
156 |
+
pad_im = np.zeros((abs(y0), im.shape[1], im.shape[2]),
|
157 |
+
dtype=np.uint8) + pad_value
|
158 |
+
im = np.concatenate((pad_im, im), axis=0)
|
159 |
+
y1 += abs(y0)
|
160 |
+
y0 = 0
|
161 |
+
if x1 >= im.shape[1]:
|
162 |
+
pad_im = np.zeros(
|
163 |
+
(im.shape[0], x1 - im.shape[1] + 1, im.shape[2]),
|
164 |
+
dtype=np.uint8) + pad_value
|
165 |
+
im = np.concatenate((im, pad_im), axis=1)
|
166 |
+
if y1 >= im.shape[0]:
|
167 |
+
pad_im = np.zeros(
|
168 |
+
(y1 - im.shape[0] + 1, im.shape[1], im.shape[2]),
|
169 |
+
dtype=np.uint8) + pad_value
|
170 |
+
im = np.concatenate((im, pad_im), axis=0)
|
171 |
+
return im[y0:y1, x0:x1]
|
172 |
+
|
173 |
+
|
174 |
+
def clip_box(bbox, im):
|
175 |
+
bbox[0] = max(0, bbox[0])
|
176 |
+
bbox[1] = max(0, bbox[1])
|
177 |
+
bbox[2] = min(im.shape[1] - 1, bbox[2])
|
178 |
+
bbox[3] = min(im.shape[0] - 1, bbox[3])
|
179 |
+
return bbox
|
180 |
+
|
181 |
+
|
182 |
+
def cut_face(im, bbox, simple_expand=False, pad_value=0, pad_im=True):
|
183 |
+
outside_im = (bbox < 0).any() or bbox[2] > im.shape[1] or bbox[3] > im.shape[0]
|
184 |
+
if simple_expand or (outside_im and pad_im):
|
185 |
+
return pad_image(im, bbox, pad_value)
|
186 |
+
bbox = clip_box(bbox, im)
|
187 |
+
x0, y0, x1, y1 = bbox
|
188 |
+
return im[y0:y1, x0:x1]
|
189 |
+
|
190 |
+
|
191 |
+
def expand_bbox(
|
192 |
+
bbox_ltrb, imshape, simple_expand, default_to_simple=False,
|
193 |
+
expansion_factor=0.35):
|
194 |
+
assert bbox_ltrb.shape == (4,), f"BBox shape was: {bbox_ltrb.shape}"
|
195 |
+
bbox = bbox_ltrb.astype(float)
|
196 |
+
# FDF256 uses simple expand with ratio 0.4
|
197 |
+
if simple_expand:
|
198 |
+
return expand_bbox_simple(bbox, 0.4)
|
199 |
+
try:
|
200 |
+
return expand_bounding_box(bbox, expansion_factor, imshape)
|
201 |
+
except AssertionError:
|
202 |
+
return expand_bbox_simple(bbox, expansion_factor * 2)
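For FDF-style face crops the typical flow is expand_bbox (quadratic expansion, or the simple fixed-ratio 0.4 expansion used for FDF256) followed by cut_face, which pads with a constant value whenever the expanded box leaves the frame. A hedged sketch with a dummy image and box:

# Illustrative sketch: expanding a detected face box and cropping it (dummy data).
im = np.zeros((480, 640, 3), dtype=np.uint8)                      # H, W, C image
bbox_ltrb = np.array([300, 200, 364, 264])                        # tight face box
expanded = expand_bbox(bbox_ltrb, im.shape, simple_expand=True)   # FDF256-style expansion
face = cut_face(im, expanded, simple_expand=True, pad_value=0)    # padded square crop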
|