haakohu commited on
Commit
5d756f1
1 Parent(s): 24ca44a
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +12 -0
  2. .gitignore +51 -0
  3. README.md +4 -4
  4. app.py +31 -0
  5. configs/anonymizers/FB_cse.py +28 -0
  6. configs/anonymizers/FB_cse_mask.py +29 -0
  7. configs/anonymizers/FB_cse_mask_face.py +29 -0
  8. configs/anonymizers/deep_privacy1.py +15 -0
  9. configs/anonymizers/face.py +17 -0
  10. configs/anonymizers/face_fdf128.py +18 -0
  11. configs/anonymizers/market1501/blackout.py +8 -0
  12. configs/anonymizers/market1501/person.py +6 -0
  13. configs/anonymizers/market1501/pixelation16.py +8 -0
  14. configs/anonymizers/market1501/pixelation8.py +8 -0
  15. configs/datasets/coco_cse.py +69 -0
  16. configs/datasets/fdf128.py +24 -0
  17. configs/datasets/fdf256.py +55 -0
  18. configs/datasets/fdh.py +90 -0
  19. configs/datasets/utils.py +21 -0
  20. configs/defaults.py +53 -0
  21. configs/discriminators/sg2_discriminator.py +43 -0
  22. configs/fdf/deep_privacy1.py +9 -0
  23. configs/fdf/stylegan.py +14 -0
  24. configs/fdf/stylegan_fdf128.py +17 -0
  25. configs/fdh/styleganL.py +16 -0
  26. configs/fdh/styleganL_nocse.py +14 -0
  27. configs/generators/stylegan_unet.py +22 -0
  28. dp2/__init__.py +0 -0
  29. dp2/anonymizer/__init__.py +1 -0
  30. dp2/anonymizer/anonymizer.py +163 -0
  31. dp2/anonymizer/histogram_match_anonymizers.py +93 -0
  32. dp2/data/__init__.py +0 -0
  33. dp2/data/build.py +40 -0
  34. dp2/data/datasets/__init__.py +0 -0
  35. dp2/data/datasets/coco_cse.py +68 -0
  36. dp2/data/datasets/fdf.py +128 -0
  37. dp2/data/datasets/fdf128_wds.py +96 -0
  38. dp2/data/datasets/fdh.py +142 -0
  39. dp2/data/transforms/__init__.py +2 -0
  40. dp2/data/transforms/functional.py +57 -0
  41. dp2/data/transforms/stylegan2_transform.py +394 -0
  42. dp2/data/transforms/transforms.py +277 -0
  43. dp2/data/utils.py +122 -0
  44. dp2/detection/__init__.py +3 -0
  45. dp2/detection/base.py +42 -0
  46. dp2/detection/box_utils.py +104 -0
  47. dp2/detection/box_utils_fdf.py +202 -0
  48. dp2/detection/cse_mask_face_detector.py +116 -0
  49. dp2/detection/deep_privacy1_detector.py +106 -0
  50. dp2/detection/face_detector.py +62 -0
.gitattributes CHANGED
@@ -32,3 +32,15 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
35
+ torch_home/hub/checkpoints/21841da7-2546-4ce3-8460-909b3a63c58b13aac1a1-c778-4c8d-9b69-3e5ed2cde9de1524e76e-7aa6-4dd8-b643-52abc9f0792c filter=lfs diff=lfs merge=lfs -text
36
+ torch_home/hub/checkpoints/Base-DensePose-RCNN-FPN-Human.yaml filter=lfs diff=lfs merge=lfs -text
37
+ torch_home/hub/checkpoints/Base-DensePose-RCNN-FPN.yaml filter=lfs diff=lfs merge=lfs -text
38
+ torch_home/hub/checkpoints/densepose_rcnn_R_101_FPN_DL_soft_s1x.yaml filter=lfs diff=lfs merge=lfs -text
39
+ torch_home/hub/checkpoints/model_final_1d3314.pkl filter=lfs diff=lfs merge=lfs -text
40
+ torch_home/hub/checkpoints/89660f04-5c11-4dbf-adac-cbe2f11b0aeea25cbf78-7558-475a-b3c7-03f5c10b7934646b0720-ca0a-4d53-aded-daddbfa45c9e filter=lfs diff=lfs merge=lfs -text
41
+ torch_home/hub/checkpoints/WIDERFace_DSFD_RES152.pth filter=lfs diff=lfs merge=lfs -text
42
+ media2/stylemc_example.jpg filter=lfs diff=lfs merge=lfs -text
43
+ media2/erling.jpg filter=lfs diff=lfs merge=lfs -text
44
+ media2/g7_leaders.jpg filter=lfs diff=lfs merge=lfs -text
45
+ media2/regjeringen.jpg filter=lfs diff=lfs merge=lfs -text
46
+ media/ filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # FILES
2
+ *.flist
3
+ *.zip
4
+ *.out
5
+ *.npy
6
+ *.gz
7
+ *.ckpt
8
+ *.log
9
+ *.pyc
10
+ *.csv
11
+ *.yml
12
+ *.ods
13
+ *.ods#
14
+ *.json
15
+ build_docker.sh
16
+
17
+ # Images / Videos
18
+ #*.png
19
+ #*.jpg
20
+ *.jpeg
21
+ *.m4a
22
+ *.mkv
23
+ *.mp4
24
+
25
+ # Directories created by inpaintron
26
+ .cache/
27
+ test_examples/
28
+ .vscode
29
+ __pycache__
30
+ .debug/
31
+ **/.ipynb_checkpoints/**
32
+ outputs/
33
+
34
+
35
+ # From pip setup
36
+ build/
37
+ *.egg-info
38
+ *.egg
39
+ .npm/
40
+
41
+ # From dockerfile
42
+ .bash_history
43
+ .viminfo
44
+ .local/
45
+ *.pickle
46
+ *.onnx
47
+
48
+
49
+ sbatch_files/
50
+ figures/
51
+ image_dump/
README.md CHANGED
@@ -1,10 +1,10 @@
1
  ---
2
- title: Deep Privacy2 Face
3
- emoji: 👀
4
- colorFrom: purple
5
  colorTo: indigo
6
  sdk: gradio
7
- sdk_version: 3.23.0
8
  app_file: app.py
9
  pinned: false
10
  ---
 
1
  ---
2
+ title: Deep Privacy2
3
+ emoji: 📈
4
+ colorFrom: gray
5
  colorTo: indigo
6
  sdk: gradio
7
+ sdk_version: 3.9.1
8
  app_file: app.py
9
  pinned: false
10
  ---
app.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio
2
+ import os
3
+ from tops.config import instantiate
4
+ import gradio.inputs
5
+ os.system("pip install --upgrade pip")
6
+ os.system("pip install ftfy regex tqdm")
7
+ os.system("pip install --no-deps git+https://github.com/openai/CLIP.git")
8
+ os.system("pip install git+https://github.com/facebookresearch/detectron2@96c752ce821a3340e27edd51c28a00665dd32a30#subdirectory=projects/DensePose")
9
+ os.system("pip install --no-deps git+https://github.com/hukkelas/DSFD-Pytorch-Inference")
10
+ os.environ["TORCH_HOME"] = "torch_home"
11
+ from dp2 import utils
12
+ from gradio_demos.modules import ExampleDemo, WebcamDemo
13
+
14
+ cfg_face = utils.load_config("configs/anonymizers/face.py")
15
+
16
+ anonymizer_face = instantiate(cfg_face.anonymizer, load_cache=False)
17
+
18
+ anonymizer_face.initialize_tracker(fps=1)
19
+
20
+
21
+ with gradio.Blocks() as demo:
22
+ gradio.Markdown("# <center> DeepPrivacy2 - Realistic Image Anonymization </center>")
23
+ gradio.Markdown("### <center> Håkon Hukkelås, Rudolf Mester, Frank Lindseth </center>")
24
+ gradio.Markdown("<center> See more information at: <a href='https://github.com/hukkelas/deep_privacy2'> https://github.com/hukkelas/deep_privacy2 </a> </center>")
25
+ gradio.Markdown("<center> For a demo of face anonymization, see: <a href='https://huggingface.co/spaces/haakohu/deep_privacy2_face'> https://huggingface.co/spaces/haakohu/deep_privacy2_face </a> </center>")
26
+ with gradio.Tab("Face Anonymization"):
27
+ ExampleDemo(anonymizer_face)
28
+ with gradio.Tab("Live Webcam"):
29
+ WebcamDemo(anonymizer_face)
30
+
31
+ demo.launch()
configs/anonymizers/FB_cse.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dp2.anonymizer import Anonymizer
2
+ from dp2.detection.person_detector import CSEPersonDetector
3
+ from ..defaults import common
4
+ from tops.config import LazyCall as L
5
+ from dp2.generator.dummy_generators import MaskOutGenerator
6
+
7
+
8
+ maskout_G = L(MaskOutGenerator)(noise="constant")
9
+
10
+ detector = L(CSEPersonDetector)(
11
+ mask_rcnn_cfg=dict(),
12
+ cse_cfg=dict(),
13
+ cse_post_process_cfg=dict(
14
+ target_imsize=(288, 160),
15
+ exp_bbox_cfg=dict(percentage_background=0.3, axis_minimum_expansion=.1),
16
+ exp_bbox_filter=dict(minimum_area=32*32, min_bbox_ratio_inside=0, aspect_ratio_range=[0, 99999]),
17
+ iou_combine_threshold=0.4,
18
+ dilation_percentage=0.02,
19
+ normalize_embedding=False
20
+ ),
21
+ score_threshold=0.3,
22
+ cache_directory=common.output_dir.joinpath("cse_person_detection_cache")
23
+ )
24
+
25
+ anonymizer = L(Anonymizer)(
26
+ detector="${detector}",
27
+ cse_person_G_cfg="configs/fdh/styleganL.py",
28
+ )
configs/anonymizers/FB_cse_mask.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dp2.anonymizer import Anonymizer
2
+ from dp2.detection.person_detector import CSEPersonDetector
3
+ from ..defaults import common
4
+ from tops.config import LazyCall as L
5
+ from dp2.generator.dummy_generators import MaskOutGenerator
6
+
7
+
8
+ maskout_G = L(MaskOutGenerator)(noise="constant")
9
+
10
+ detector = L(CSEPersonDetector)(
11
+ mask_rcnn_cfg=dict(),
12
+ cse_cfg=dict(),
13
+ cse_post_process_cfg=dict(
14
+ target_imsize=(288, 160),
15
+ exp_bbox_cfg=dict(percentage_background=0.3, axis_minimum_expansion=.1),
16
+ exp_bbox_filter=dict(minimum_area=32*32, min_bbox_ratio_inside=0, aspect_ratio_range=[0, 99999]),
17
+ iou_combine_threshold=0.4,
18
+ dilation_percentage=0.02,
19
+ normalize_embedding=False
20
+ ),
21
+ score_threshold=0.3,
22
+ cache_directory=common.output_dir.joinpath("cse_person_detection_cache")
23
+ )
24
+
25
+ anonymizer = L(Anonymizer)(
26
+ detector="${detector}",
27
+ person_G_cfg="configs/fdh/styleganL_nocse.py",
28
+ cse_person_G_cfg="configs/fdh/styleganL.py",
29
+ )
configs/anonymizers/FB_cse_mask_face.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dp2.anonymizer import Anonymizer
2
+ from dp2.detection.cse_mask_face_detector import CSeMaskFaceDetector
3
+ from ..defaults import common
4
+ from tops.config import LazyCall as L
5
+
6
+ detector = L(CSeMaskFaceDetector)(
7
+ mask_rcnn_cfg=dict(),
8
+ face_detector_cfg=dict(),
9
+ face_post_process_cfg=dict(target_imsize=(256, 256), fdf128_expand=False),
10
+ cse_cfg=dict(),
11
+ cse_post_process_cfg=dict(
12
+ target_imsize=(288, 160),
13
+ exp_bbox_cfg=dict(percentage_background=0.3, axis_minimum_expansion=.1),
14
+ exp_bbox_filter=dict(minimum_area=32*32, min_bbox_ratio_inside=0, aspect_ratio_range=[0, 99999]),
15
+ iou_combine_threshold=0.4,
16
+ dilation_percentage=0.02,
17
+ normalize_embedding=False
18
+ ),
19
+ score_threshold=0.3,
20
+ cache_directory=common.output_dir.joinpath("cse_mask_face_detection_cache")
21
+ )
22
+
23
+ anonymizer = L(Anonymizer)(
24
+ detector="${detector}",
25
+ face_G_cfg="configs/fdf/stylegan.py",
26
+ person_G_cfg="configs/fdh/styleganL_nocse.py",
27
+ cse_person_G_cfg="configs/fdh/styleganL.py",
28
+ car_G_cfg="configs/generators/dummy/pixelation8.py"
29
+ )
configs/anonymizers/deep_privacy1.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .face_fdf128 import anonymizer, common, detector
2
+ from dp2.detection.deep_privacy1_detector import DeepPrivacy1Detector
3
+ from tops.config import LazyCall as L
4
+
5
+ anonymizer.update(
6
+ face_G_cfg="configs/fdf/deep_privacy1.py",
7
+ )
8
+
9
+ anonymizer.detector = L(DeepPrivacy1Detector)(
10
+ face_detector_cfg=dict(name="DSFDDetector", clip_boxes=True),
11
+ face_post_process_cfg=dict(target_imsize=(128, 128), fdf128_expand=True),
12
+ score_threshold=0.3,
13
+ keypoint_threshold=0.3,
14
+ cache_directory=common.output_dir.joinpath("deep_privacy1_cache")
15
+ )
configs/anonymizers/face.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dp2.anonymizer import Anonymizer
2
+ from dp2.detection.face_detector import FaceDetector
3
+ from ..defaults import common
4
+ from tops.config import LazyCall as L
5
+
6
+
7
+ detector = L(FaceDetector)(
8
+ face_detector_cfg=dict(name="DSFDDetector", clip_boxes=True),
9
+ face_post_process_cfg=dict(target_imsize=(256, 256), fdf128_expand=False),
10
+ score_threshold=0.3,
11
+ cache_directory=common.output_dir.joinpath("face_detection_cache"),
12
+ )
13
+
14
+ anonymizer = L(Anonymizer)(
15
+ detector="${detector}",
16
+ face_G_cfg="configs/fdf/stylegan.py",
17
+ )
configs/anonymizers/face_fdf128.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dp2.anonymizer import Anonymizer
2
+ from dp2.detection.face_detector import FaceDetector
3
+ from ..defaults import common
4
+ from tops.config import LazyCall as L
5
+
6
+
7
+ detector = L(FaceDetector)(
8
+ face_detector_cfg=dict(name="DSFDDetector", clip_boxes=True),
9
+ face_post_process_cfg=dict(target_imsize=(128, 128), fdf128_expand=True),
10
+ score_threshold=0.3,
11
+ cache_directory=common.output_dir.joinpath("face_detection_cache")
12
+ )
13
+
14
+
15
+ anonymizer = L(Anonymizer)(
16
+ detector="${detector}",
17
+ face_G_cfg="configs/fdf/stylegan_fdf128.py",
18
+ )
configs/anonymizers/market1501/blackout.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ from ..FB_cse_mask_face import anonymizer, detector, common
2
+
3
+ detector.score_threshold = .1
4
+ detector.face_detector_cfg.confidence_threshold = .5
5
+ detector.cse_cfg.score_thres = 0.3
6
+ anonymizer.generators.face_G_cfg = None
7
+ anonymizer.generators.person_G_cfg = "configs/generators/dummy/maskout.py"
8
+ anonymizer.generators.cse_person_G_cfg = "configs/generators/dummy/maskout.py"
configs/anonymizers/market1501/person.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ from ..FB_cse_mask_face import anonymizer, detector, common
2
+
3
+ detector.score_threshold = .1
4
+ detector.face_detector_cfg.confidence_threshold = .5
5
+ detector.cse_cfg.score_thres = 0.3
6
+ anonymizer.generators.face_G_cfg = None
configs/anonymizers/market1501/pixelation16.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ from ..FB_cse_mask_face import anonymizer, detector, common
2
+
3
+ detector.score_threshold = .1
4
+ detector.face_detector_cfg.confidence_threshold = .5
5
+ detector.cse_cfg.score_thres = 0.3
6
+ anonymizer.generators.face_G_cfg = None
7
+ anonymizer.generators.person_G_cfg = "configs/generators/dummy/pixelation16.py"
8
+ anonymizer.generators.cse_person_G_cfg = "configs/generators/dummy/pixelation16.py"
configs/anonymizers/market1501/pixelation8.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ from ..FB_cse_mask_face import anonymizer, detector, common
2
+
3
+ detector.score_threshold = .1
4
+ detector.face_detector_cfg.confidence_threshold = .5
5
+ detector.cse_cfg.score_thres = 0.3
6
+ anonymizer.generators.face_G_cfg = None
7
+ anonymizer.generators.person_G_cfg = "configs/generators/dummy/pixelation8.py"
8
+ anonymizer.generators.cse_person_G_cfg = "configs/generators/dummy/pixelation8.py"
configs/datasets/coco_cse.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from pathlib import Path
3
+ from tops.config import LazyCall as L
4
+ import torch
5
+ import functools
6
+ from dp2.data.datasets.coco_cse import CocoCSE
7
+ from dp2.data.build import get_dataloader
8
+ from dp2.data.transforms.transforms import CreateEmbedding, Normalize, Resize, ToFloat, CreateCondition, RandomHorizontalFlip
9
+ from dp2.data.transforms.stylegan2_transform import StyleGANAugmentPipe
10
+ from dp2.metrics.torch_metrics import compute_metrics_iteratively
11
+ from .utils import final_eval_fn
12
+
13
+
14
+ dataset_base_dir = os.environ["BASE_DATASET_DIR"] if "BASE_DATASET_DIR" in os.environ else "data"
15
+ metrics_cache = os.environ["FBA_METRICS_CACHE"] if "FBA_METRICS_CACHE" in os.environ else ".cache"
16
+ data_dir = Path(dataset_base_dir, "coco_cse")
17
+ data = dict(
18
+ imsize=(288, 160),
19
+ im_channels=3,
20
+ semantic_nc=26,
21
+ cse_nc=16,
22
+ train=dict(
23
+ dataset=L(CocoCSE)(data_dir.joinpath("train"), transform=None, normalize_E=False),
24
+ loader=L(get_dataloader)(
25
+ shuffle=True, num_workers=6, drop_last=True, prefetch_factor=2,
26
+ batch_size="${train.batch_size}",
27
+ dataset="${..dataset}",
28
+ infinite=True,
29
+ gpu_transform=L(torch.nn.Sequential)(*[
30
+ L(ToFloat)(),
31
+ L(StyleGANAugmentPipe)(
32
+ rotate=0.5, rotate_max=.05,
33
+ xint=.5, xint_max=0.05,
34
+ scale=.5, scale_std=.05,
35
+ aniso=0.5, aniso_std=.05,
36
+ xfrac=.5, xfrac_std=.05,
37
+ brightness=.5, brightness_std=.05,
38
+ contrast=.5, contrast_std=.1,
39
+ hue=.5, hue_max=.05,
40
+ saturation=.5, saturation_std=.5,
41
+ imgfilter=.5, imgfilter_std=.1),
42
+ L(RandomHorizontalFlip)(p=0.5),
43
+ L(CreateEmbedding)(),
44
+ L(Resize)(size="${data.imsize}"),
45
+ L(Normalize)(mean=[.5, .5, .5], std=[.5, .5, .5], inplace=True),
46
+ L(CreateCondition)(),
47
+ ])
48
+ )
49
+ ),
50
+ val=dict(
51
+ dataset=L(CocoCSE)(data_dir.joinpath("val"), transform=None, normalize_E=False),
52
+ loader=L(get_dataloader)(
53
+ shuffle=False, num_workers=6, drop_last=True, prefetch_factor=2,
54
+ batch_size="${train.batch_size}",
55
+ dataset="${..dataset}",
56
+ infinite=False,
57
+ gpu_transform=L(torch.nn.Sequential)(*[
58
+ L(ToFloat)(),
59
+ L(CreateEmbedding)(),
60
+ L(Resize)(size="${data.imsize}"),
61
+ L(Normalize)(mean=[.5, .5, .5], std=[.5, .5, .5], inplace=True),
62
+ L(CreateCondition)(),
63
+ ])
64
+ )
65
+ ),
66
+ # Training evaluation might do optimizations to reduce compute overhead. E.g. compute with AMP.
67
+ train_evaluation_fn=functools.partial(compute_metrics_iteratively, cache_directory=Path(metrics_cache, "coco_cse_val"), include_two_fake=False),
68
+ evaluation_fn=functools.partial(final_eval_fn, cache_directory=Path(metrics_cache, "coco_cse_val_final"), include_two_fake=True)
69
+ )
configs/datasets/fdf128.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ from functools import partial
3
+ from dp2.data.datasets.fdf import FDFDataset
4
+ from .fdf256 import data, dataset_base_dir, metrics_cache, final_eval_fn, train_eval_fn
5
+
6
+ data_dir = Path(dataset_base_dir, "fdf")
7
+ data.train.dataset.dirpath = data_dir.joinpath("train")
8
+ data.val.dataset.dirpath = data_dir.joinpath("val")
9
+ data.imsize = (128, 128)
10
+
11
+
12
+ data.train_evaluation_fn = partial(
13
+ train_eval_fn, cache_directory=Path(metrics_cache, "fdf128_val_train"))
14
+ data.evaluation_fn = partial(
15
+ final_eval_fn, cache_directory=Path(metrics_cache, "fdf128_val_final"))
16
+
17
+ data.train.dataset.update(
18
+ _target_ = FDFDataset,
19
+ imsize="${data.imsize}"
20
+ )
21
+ data.val.dataset.update(
22
+ _target_ = FDFDataset,
23
+ imsize="${data.imsize}"
24
+ )
configs/datasets/fdf256.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from pathlib import Path
3
+ from tops.config import LazyCall as L
4
+ import torch
5
+ import functools
6
+ from dp2.data.datasets.fdf import FDF256Dataset
7
+ from dp2.data.build import get_dataloader
8
+ from dp2.data.transforms.transforms import Normalize, Resize, ToFloat, CreateCondition, RandomHorizontalFlip
9
+ from .utils import final_eval_fn, train_eval_fn
10
+
11
+
12
+ dataset_base_dir = os.environ["BASE_DATASET_DIR"] if "BASE_DATASET_DIR" in os.environ else "data"
13
+ metrics_cache = os.environ["FBA_METRICS_CACHE"] if "FBA_METRICS_CACHE" in os.environ else ".cache"
14
+ data_dir = Path(dataset_base_dir, "fdf256")
15
+ data = dict(
16
+ imsize=(256, 256),
17
+ im_channels=3,
18
+ semantic_nc=None,
19
+ cse_nc=None,
20
+ n_keypoints=None,
21
+ train=dict(
22
+ dataset=L(FDF256Dataset)(dirpath=data_dir.joinpath("train"), transform=None, load_keypoints=False),
23
+ loader=L(get_dataloader)(
24
+ shuffle=True, num_workers=3, drop_last=True, prefetch_factor=2,
25
+ batch_size="${train.batch_size}",
26
+ dataset="${..dataset}",
27
+ infinite=True,
28
+ gpu_transform=L(torch.nn.Sequential)(*[
29
+ L(ToFloat)(),
30
+ L(RandomHorizontalFlip)(p=0.5),
31
+ L(Resize)(size="${data.imsize}"),
32
+ L(Normalize)(mean=[.5, .5, .5], std=[.5, .5, .5], inplace=True),
33
+ L(CreateCondition)(),
34
+ ])
35
+ )
36
+ ),
37
+ val=dict(
38
+ dataset=L(FDF256Dataset)(dirpath=data_dir.joinpath("val"), transform=None, load_keypoints=False),
39
+ loader=L(get_dataloader)(
40
+ shuffle=False, num_workers=3, drop_last=False, prefetch_factor=2,
41
+ batch_size="${train.batch_size}",
42
+ dataset="${..dataset}",
43
+ infinite=False,
44
+ gpu_transform=L(torch.nn.Sequential)(*[
45
+ L(ToFloat)(),
46
+ L(Resize)(size="${data.imsize}"),
47
+ L(Normalize)(mean=[.5, .5, .5], std=[.5, .5, .5], inplace=True),
48
+ L(CreateCondition)(),
49
+ ])
50
+ )
51
+ ),
52
+ # Training evaluation might do optimizations to reduce compute overhead. E.g. compute with AMP.
53
+ train_evaluation_fn=functools.partial(train_eval_fn, cache_directory=Path(metrics_cache, "fdf_val_train")),
54
+ evaluation_fn=functools.partial(final_eval_fn, cache_directory=Path(metrics_cache, "fdf_val"))
55
+ )
configs/datasets/fdh.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from pathlib import Path
3
+ from tops.config import LazyCall as L
4
+ import torch
5
+ import functools
6
+ from dp2.data.datasets.fdh import get_dataloader_fdh_wds
7
+ from dp2.data.utils import get_coco_flipmap
8
+ from dp2.data.transforms.transforms import (
9
+ Normalize,
10
+ ToFloat,
11
+ CreateCondition,
12
+ RandomHorizontalFlip,
13
+ CreateEmbedding,
14
+ )
15
+ from dp2.metrics.torch_metrics import compute_metrics_iteratively
16
+ from dp2.metrics.fid_clip import compute_fid_clip
17
+ from dp2.metrics.ppl import calculate_ppl
18
+ from .utils import train_eval_fn
19
+
20
+
21
+ def final_eval_fn(*args, **kwargs):
22
+ result = compute_metrics_iteratively(*args, **kwargs)
23
+ result2 = calculate_ppl(*args, **kwargs, upsample_size=(288, 160))
24
+ result3 = compute_fid_clip(*args, **kwargs)
25
+ assert all(key not in result for key in result2)
26
+ result.update(result2)
27
+ result.update(result3)
28
+ return result
29
+
30
+
31
+ def get_cache_directory(imsize, subset):
32
+ return Path(metrics_cache, f"{subset}{imsize[0]}")
33
+
34
+ dataset_base_dir = (
35
+ os.environ["BASE_DATASET_DIR"] if "BASE_DATASET_DIR" in os.environ else "data"
36
+ )
37
+ metrics_cache = (
38
+ os.environ["FBA_METRICS_CACHE"] if "FBA_METRICS_CACHE" in os.environ else ".cache"
39
+ )
40
+ data_dir = Path(dataset_base_dir, "fdh")
41
+ data = dict(
42
+ imsize=(288, 160),
43
+ im_channels=3,
44
+ cse_nc=16,
45
+ n_keypoints=17,
46
+ train=dict(
47
+ loader=L(get_dataloader_fdh_wds)(
48
+ path=data_dir.joinpath("train", "out-{000000..001423}.tar"),
49
+ batch_size="${train.batch_size}",
50
+ num_workers=6,
51
+ transform=L(torch.nn.Sequential)(
52
+ L(RandomHorizontalFlip)(p=0.5, flip_map=get_coco_flipmap()),
53
+ ),
54
+ gpu_transform=L(torch.nn.Sequential)(
55
+ L(ToFloat)(norm=False, keys=["img", "mask", "E_mask", "maskrcnn_mask"]),
56
+ L(CreateEmbedding)(embed_path=data_dir.joinpath("embed_map.torch")),
57
+ L(Normalize)(mean=[0.5*255, 0.5*255, 0.5*255], std=[0.5*255, 0.5*255, 0.5*255], inplace=True),
58
+ L(CreateCondition)(),
59
+ ),
60
+ infinite=True,
61
+ shuffle=True,
62
+ partial_batches=False,
63
+ load_embedding=True,
64
+ keypoints_split="train",
65
+ load_new_keypoints=False
66
+ )
67
+ ),
68
+ val=dict(
69
+ loader=L(get_dataloader_fdh_wds)(
70
+ path=data_dir.joinpath("val", "out-{000000..000023}.tar"),
71
+ batch_size="${train.batch_size}",
72
+ num_workers=6,
73
+ transform=None,
74
+ gpu_transform="${data.train.loader.gpu_transform}",
75
+ infinite=False,
76
+ shuffle=False,
77
+ partial_batches=True,
78
+ load_embedding=True,
79
+ keypoints_split="val",
80
+ load_new_keypoints="${data.train.loader.load_new_keypoints}"
81
+ )
82
+ ),
83
+ # Training evaluation might do optimizations to reduce compute overhead. E.g. compute with AMP.
84
+ train_evaluation_fn=L(functools.partial)(
85
+ train_eval_fn, cache_directory=L(get_cache_directory)(imsize="${data.imsize}", subset="fdh"),
86
+ data_len=30_000),
87
+ evaluation_fn=L(functools.partial)(
88
+ final_eval_fn, cache_directory=L(get_cache_directory)(imsize="${data.imsize}", subset="fdh_eval"),
89
+ data_len=30_000)
90
+ )
configs/datasets/utils.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dp2.metrics.ppl import calculate_ppl
2
+ from dp2.metrics.torch_metrics import compute_metrics_iteratively
3
+ from dp2.metrics.fid_clip import compute_fid_clip
4
+
5
+
6
+ def final_eval_fn(*args, **kwargs):
7
+ result = compute_metrics_iteratively(*args, **kwargs)
8
+ result2 = calculate_ppl(*args, **kwargs,)
9
+ result3 = compute_fid_clip(*args, **kwargs)
10
+ assert all(key not in result for key in result2)
11
+ result.update(result2)
12
+ result.update(result3)
13
+ return result
14
+
15
+
16
+ def train_eval_fn(*args, **kwargs):
17
+ result = compute_metrics_iteratively(*args, **kwargs)
18
+ result2 = compute_fid_clip(*args, **kwargs)
19
+ assert all(key not in result for key in result2)
20
+ result.update(result2)
21
+ return result
configs/defaults.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pathlib
2
+ import os
3
+ import torch
4
+ from tops.config import LazyCall as L
5
+
6
+ if "PRETRAINED_CHECKPOINTS_PATH" in os.environ:
7
+ PRETRAINED_CHECKPOINTS_PATH = pathlib.Path(os.environ["PRETRAINED_CHECKPOINTS_PATH"])
8
+ else:
9
+ PRETRAINED_CHECKPOINTS_PATH = pathlib.Path("pretrained_checkpoints")
10
+ if "BASE_OUTPUT_DIR" in os.environ:
11
+ BASE_OUTPUT_DIR = pathlib.Path(os.environ["BASE_OUTPUT_DIR"])
12
+ else:
13
+ BASE_OUTPUT_DIR = pathlib.Path("outputs")
14
+
15
+
16
+
17
+ common = dict(
18
+ logger_backend=["wandb", "stdout", "json", "image_dumper"],
19
+ wandb_project="deep_privacy2",
20
+ output_dir=BASE_OUTPUT_DIR,
21
+ experiment_name=None, # Optional experiment name to show on wandb
22
+ )
23
+
24
+ train = dict(
25
+ batch_size=32,
26
+ seed=0,
27
+ ims_per_log=1024,
28
+ ims_per_val=int(200e3),
29
+ max_images_to_train=int(12e6),
30
+ amp=dict(
31
+ enabled=True,
32
+ scaler_D=L(torch.cuda.amp.GradScaler)(init_scale=2**16, growth_factor=4, growth_interval=100, enabled="${..enabled}"),
33
+ scaler_G=L(torch.cuda.amp.GradScaler)(init_scale=2**16, growth_factor=4, growth_interval=100, enabled="${..enabled}"),
34
+ ),
35
+ fp16_ddp_accumulate=False, # All gather gradients in fp16?
36
+ broadcast_buffers=False,
37
+ bias_act_plugin_enabled=True,
38
+ grid_sample_gradfix_enabled=True,
39
+ conv2d_gradfix_enabled=False,
40
+ channels_last=False,
41
+ compile_G=dict(
42
+ enabled=False,
43
+ mode="default" # default, reduce-overhead or max-autotune
44
+ ),
45
+ compile_D=dict(
46
+ enabled=False,
47
+ mode="default" # default, reduce-overhead or max-autotune
48
+ )
49
+ )
50
+
51
+ # exponential moving average
52
+ EMA = dict(rampup=0.05)
53
+
configs/discriminators/sg2_discriminator.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from tops.config import LazyCall as L
2
+ from dp2.discriminator import SG2Discriminator
3
+ import torch
4
+ from dp2.loss import StyleGAN2Loss
5
+
6
+
7
+ discriminator = L(SG2Discriminator)(
8
+ imsize="${data.imsize}",
9
+ im_channels="${data.im_channels}",
10
+ min_fmap_resolution=4,
11
+ max_cnum_mul=8,
12
+ cnum=80,
13
+ input_condition=True,
14
+ conv_clamp=256,
15
+ input_cse=False,
16
+ cse_nc="${data.cse_nc}",
17
+ fix_residual=False,
18
+ )
19
+
20
+
21
+ loss_fnc = L(StyleGAN2Loss)(
22
+ lazy_regularization=True,
23
+ lazy_reg_interval=16,
24
+ r1_opts=dict(lambd=5, mask_out=False, mask_out_scale=False),
25
+ EP_lambd=0.001,
26
+ pl_reg_opts=dict(weight=0, batch_shrink=2,start_nimg=int(1e6), pl_decay=0.01)
27
+ )
28
+
29
+ def build_D_optim(type, lr, betas, lazy_regularization, lazy_reg_interval, **kwargs):
30
+ if lazy_regularization:
31
+ # From Analyzing and improving the image quality of stylegan, CVPR 2020
32
+ c = lazy_reg_interval / (lazy_reg_interval + 1)
33
+ betas = [beta ** c for beta in betas]
34
+ lr *= c
35
+ print(f"Lazy regularization on. Setting lr to: {lr}, betas to: {betas}")
36
+ return type(lr=lr, betas=betas, **kwargs)
37
+
38
+
39
+ D_optim = L(build_D_optim)(
40
+ type=torch.optim.Adam, lr=0.001, betas=(0.0, 0.99),
41
+ lazy_regularization="${loss_fnc.lazy_regularization}",
42
+ lazy_reg_interval="${loss_fnc.lazy_reg_interval}")
43
+ G_optim = L(torch.optim.Adam)(lr=0.001, betas=(0.0, 0.99))
configs/fdf/deep_privacy1.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ from tops.config import LazyCall as L
2
+ from dp2.generator.deep_privacy1 import MSGGenerator
3
+ from ..datasets.fdf128 import data
4
+ from ..defaults import common, train
5
+
6
+ generator = L(MSGGenerator)()
7
+
8
+ common.model_url = "https://folk.ntnu.no/haakohu/checkpoints/fdf128_model512.ckpt"
9
+ common.model_md5sum = "6cc8b285bdc1fcdfc64f5db7c521d0a6"
configs/fdf/stylegan.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ..generators.stylegan_unet import generator
2
+ from ..datasets.fdf256 import data
3
+ from ..discriminators.sg2_discriminator import discriminator, G_optim, D_optim, loss_fnc
4
+ from ..defaults import train, common, EMA
5
+
6
+ train.max_images_to_train = int(35e6)
7
+ G_optim.lr = 0.002
8
+ D_optim.lr = 0.002
9
+ generator.input_cse = False
10
+ loss_fnc.r1_opts.lambd = 1
11
+ train.ims_per_val = int(2e6)
12
+
13
+ common.model_url = "https://api.loke.aws.unit.no/dlr-gui-backend-resources-content/v2/contents/links/89660f04-5c11-4dbf-adac-cbe2f11b0aeea25cbf78-7558-475a-b3c7-03f5c10b7934646b0720-ca0a-4d53-aded-daddbfa45c9e"
14
+ common.model_md5sum = "e8e32190528af2ed75f0cb792b7f2b07"
configs/fdf/stylegan_fdf128.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ..discriminators.sg2_discriminator import discriminator, G_optim, D_optim, loss_fnc
2
+ from ..datasets.fdf128 import data
3
+ from ..generators.stylegan_unet import generator
4
+ from ..defaults import train, common, EMA
5
+ from tops.config import LazyCall as L
6
+
7
+ G_optim.lr = 0.002
8
+ D_optim.lr = 0.002
9
+ generator.update(cnum=128, max_cnum_mul=4, input_cse=False)
10
+ loss_fnc.r1_opts.lambd = 0.1
11
+
12
+ train.update(ims_per_val=int(2e6), batch_size=64, max_images_to_train=int(35e6))
13
+
14
+ common.update(
15
+ model_url="https://api.loke.aws.unit.no/dlr-gui-backend-resources-content/v2/contents/links/66d803c0-55ce-44c0-9d53-815c2c0e6ba4eb458409-9e91-45d1-bce0-95c8a47a57218b102fdf-bea3-44dc-aac4-0fb1d370ef1c",
16
+ model_md5sum="bccd4403e7c9bca682566ff3319e8176"
17
+ )
configs/fdh/styleganL.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from tops.config import LazyCall as L
2
+ from ..generators.stylegan_unet import generator
3
+ from ..datasets.fdh import data
4
+ from ..discriminators.sg2_discriminator import discriminator, G_optim, D_optim, loss_fnc
5
+ from ..defaults import train, common, EMA
6
+
7
+ train.max_images_to_train = int(50e6)
8
+ train.batch_size = 64
9
+ G_optim.lr = 0.002
10
+ D_optim.lr = 0.002
11
+ data.train.loader.num_workers = 4
12
+ train.ims_per_val = int(1e6)
13
+ loss_fnc.r1_opts.lambd = .1
14
+
15
+ common.model_url = "https://api.loke.aws.unit.no/dlr-gui-backend-resources-content/v2/contents/links/21841da7-2546-4ce3-8460-909b3a63c58b13aac1a1-c778-4c8d-9b69-3e5ed2cde9de1524e76e-7aa6-4dd8-b643-52abc9f0792c"
16
+ common.model_md5sum = "3411478b5ec600a4219cccf4499732bd"
configs/fdh/styleganL_nocse.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from tops.config import LazyCall as L
2
+ from ..generators.stylegan_unet import generator
3
+ from ..datasets.fdh import data
4
+ from ..discriminators.sg2_discriminator import discriminator, G_optim, D_optim, loss_fnc
5
+ from ..defaults import train, common, EMA
6
+
7
+ train.max_images_to_train = int(50e6)
8
+ G_optim.lr = 0.002
9
+ D_optim.lr = 0.002
10
+ generator.input_cse = False
11
+ data.load_embeddings = False
12
+ common.model_url = "https://folk.ntnu.no/haakohu/checkpoints/deep_privacy2/fdh_styleganL_nocse.ckpt"
13
+ common.model_md5sum = "fda0d809741bc67487abada793975c37"
14
+ generator.fix_errors = False
configs/generators/stylegan_unet.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dp2.generator.stylegan_unet import StyleGANUnet
2
+ from tops.config import LazyCall as L
3
+
4
+ generator = L(StyleGANUnet)(
5
+ imsize="${data.imsize}",
6
+ im_channels="${data.im_channels}",
7
+ min_fmap_resolution=8,
8
+ cnum=64,
9
+ max_cnum_mul=8,
10
+ n_middle_blocks=0,
11
+ z_channels=512,
12
+ mask_output=True,
13
+ conv_clamp=256,
14
+ input_cse=True,
15
+ scale_grad=True,
16
+ cse_nc="${data.cse_nc}",
17
+ w_dim=512,
18
+ n_keypoints="${data.n_keypoints}",
19
+ input_keypoints=False,
20
+ input_keypoint_indices=[],
21
+ fix_errors=True
22
+ )
dp2/__init__.py ADDED
File without changes
dp2/anonymizer/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .anonymizer import Anonymizer
dp2/anonymizer/anonymizer.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ from typing import Union, Optional
3
+ import numpy as np
4
+ import torch
5
+ import tops
6
+ import torchvision.transforms.functional as F
7
+ from motpy import Detection, MultiObjectTracker
8
+ from dp2.utils import load_config
9
+ from dp2.infer import build_trained_generator
10
+ from dp2.detection.structures import CSEPersonDetection, FaceDetection, PersonDetection, VehicleDetection
11
+
12
+
13
+ def load_generator_from_cfg_path(cfg_path: Union[str, Path]):
14
+ cfg = load_config(cfg_path)
15
+ G = build_trained_generator(cfg)
16
+ tops.logger.log(f"Loaded generator from: {cfg_path}")
17
+ return G
18
+
19
+
20
+ class Anonymizer:
21
+
22
+ def __init__(
23
+ self,
24
+ detector,
25
+ load_cache: bool = False,
26
+ person_G_cfg: Optional[Union[str, Path]] = None,
27
+ cse_person_G_cfg: Optional[Union[str, Path]] = None,
28
+ face_G_cfg: Optional[Union[str, Path]] = None,
29
+ car_G_cfg: Optional[Union[str, Path]] = None,
30
+ ) -> None:
31
+ self.detector = detector
32
+ self.generators = {k: None for k in [CSEPersonDetection, PersonDetection, FaceDetection, VehicleDetection]}
33
+ self.load_cache = load_cache
34
+ if cse_person_G_cfg is not None:
35
+ self.generators[CSEPersonDetection] = load_generator_from_cfg_path(cse_person_G_cfg)
36
+ if person_G_cfg is not None:
37
+ self.generators[PersonDetection] = load_generator_from_cfg_path(person_G_cfg)
38
+ if face_G_cfg is not None:
39
+ self.generators[FaceDetection] = load_generator_from_cfg_path(face_G_cfg)
40
+ if car_G_cfg is not None:
41
+ self.generators[VehicleDetection] = load_generator_from_cfg_path(car_G_cfg)
42
+
43
+ def initialize_tracker(self, fps: float):
44
+ self.tracker = MultiObjectTracker(dt=1/fps)
45
+ self.track_to_z_idx = dict()
46
+
47
+ def reset_tracker(self):
48
+ self.track_to_z_idx = dict()
49
+
50
+ def forward_G(self,
51
+ G,
52
+ batch,
53
+ multi_modal_truncation: bool,
54
+ amp: bool,
55
+ z_idx: int,
56
+ truncation_value: float,
57
+ idx: int,
58
+ all_styles=None):
59
+ batch["img"] = F.normalize(batch["img"].float(), [0.5*255, 0.5*255, 0.5*255], [0.5*255, 0.5*255, 0.5*255])
60
+ batch["img"] = batch["img"].float()
61
+ batch["condition"] = batch["mask"].float() * batch["img"]
62
+
63
+ with torch.cuda.amp.autocast(amp):
64
+ z = None
65
+ if z_idx is not None:
66
+ state = np.random.RandomState(seed=z_idx[idx])
67
+ z = state.normal(size=(1, G.z_channels)).astype(np.float32)
68
+ z = tops.to_cuda(torch.from_numpy(z))
69
+
70
+ if all_styles is not None:
71
+ anonymized_im = G(**batch, s=iter(all_styles[idx]))["img"]
72
+ elif multi_modal_truncation:
73
+ w_indices = None
74
+ if z_idx is not None:
75
+ w_indices = [z_idx[idx] % len(G.style_net.w_centers)]
76
+ anonymized_im = G.multi_modal_truncate(
77
+ **batch, truncation_value=truncation_value,
78
+ w_indices=w_indices,
79
+ z=z
80
+ )["img"]
81
+ else:
82
+ anonymized_im = G.sample(**batch, truncation_value=truncation_value, z=z)["img"]
83
+ anonymized_im = (anonymized_im+1).div(2).clamp(0, 1).mul(255)
84
+ return anonymized_im
85
+
86
+ @torch.no_grad()
87
+ def anonymize_detections(self,
88
+ im, detection,
89
+ update_identity=None,
90
+ **synthesis_kwargs
91
+ ):
92
+ G = self.generators[type(detection)]
93
+ if G is None:
94
+ return im
95
+ C, H, W = im.shape
96
+ if update_identity is None:
97
+ update_identity = [True for i in range(len(detection))]
98
+ for idx in range(len(detection)):
99
+ if not update_identity[idx]:
100
+ continue
101
+ batch = detection.get_crop(idx, im)
102
+ x0, y0, x1, y1 = batch.pop("boxes")[0]
103
+ batch = {k: tops.to_cuda(v) for k, v in batch.items()}
104
+ anonymized_im = self.forward_G(G, batch, **synthesis_kwargs, idx=idx)
105
+
106
+ gim = F.resize(anonymized_im[0], (y1-y0, x1-x0), interpolation=F.InterpolationMode.BICUBIC, antialias=True)
107
+ mask = F.resize(batch["mask"][0], (y1-y0, x1-x0), interpolation=F.InterpolationMode.NEAREST).squeeze(0)
108
+ # Remove padding
109
+ pad = [max(-x0, 0), max(-y0, 0)]
110
+ pad = [*pad, max(x1-W, 0), max(y1-H, 0)]
111
+ def remove_pad(x): return x[..., pad[1]:x.shape[-2]-pad[3], pad[0]:x.shape[-1]-pad[2]]
112
+
113
+ gim = remove_pad(gim)
114
+ mask = remove_pad(mask) > 0.5
115
+ x0, y0 = max(x0, 0), max(y0, 0)
116
+ x1, y1 = min(x1, W), min(y1, H)
117
+ mask = mask.logical_not()[None].repeat(3, 1, 1)
118
+
119
+ im[:, y0:y1, x0:x1][mask] = gim[mask].round().clamp(0, 255).byte()
120
+ return im
121
+
122
+ def visualize_detection(self, im: torch.Tensor, cache_id: str = None) -> torch.Tensor:
123
+ all_detections = self.detector.forward_and_cache(im, cache_id, load_cache=self.load_cache)
124
+ im = im.cpu()
125
+ for det in all_detections:
126
+ im = det.visualize(im)
127
+ return im
128
+
129
+ @torch.no_grad()
130
+ def forward(self, im: torch.Tensor, cache_id: str = None, track=True, detections=None, **synthesis_kwargs) -> torch.Tensor:
131
+ assert im.dtype == torch.uint8
132
+ im = tops.to_cuda(im)
133
+ all_detections = detections
134
+ if detections is None:
135
+ if self.load_cache:
136
+ all_detections = self.detector.forward_and_cache(im, cache_id)
137
+ else:
138
+ all_detections = self.detector(im)
139
+ if hasattr(self, "tracker") and track:
140
+ [_.pre_process() for _ in all_detections]
141
+ boxes = np.concatenate([_.boxes for _ in all_detections])
142
+ boxes = [Detection(box) for box in boxes]
143
+ self.tracker.step(boxes)
144
+ track_ids = self.tracker.detections_matched_ids
145
+ z_idx = []
146
+ for track_id in track_ids:
147
+ if track_id not in self.track_to_z_idx:
148
+ self.track_to_z_idx[track_id] = np.random.randint(0, 2**32-1)
149
+ z_idx.append(self.track_to_z_idx[track_id])
150
+ z_idx = np.array(z_idx)
151
+ idx_offset = 0
152
+
153
+ for detection in all_detections:
154
+ zs = None
155
+ if hasattr(self, "tracker") and track:
156
+ zs = z_idx[idx_offset:idx_offset+len(detection)]
157
+ idx_offset += len(detection)
158
+ im = self.anonymize_detections(im, detection, z_idx=zs, **synthesis_kwargs)
159
+
160
+ return im.cpu()
161
+
162
+ def __call__(self, *args, **kwargs):
163
+ return self.forward(*args, **kwargs)
dp2/anonymizer/histogram_match_anonymizers.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch
3
+ import tops
4
+ import numpy as np
5
+ from kornia.color import rgb_to_hsv
6
+ from dp2 import utils
7
+ from kornia.enhance import histogram
8
+ from .anonymizer import Anonymizer
9
+ import torchvision.transforms.functional as F
10
+ from skimage.exposure import match_histograms
11
+ from kornia.filters import gaussian_blur2d
12
+
13
+
14
+ class LatentHistogramMatchAnonymizer(Anonymizer):
15
+
16
+ def forward_G(
17
+ self,
18
+ G,
19
+ batch,
20
+ multi_modal_truncation: bool,
21
+ amp: bool,
22
+ z_idx: int,
23
+ truncation_value: float,
24
+ idx: int,
25
+ n_sampling_steps: int = 1,
26
+ all_styles=None,
27
+ ):
28
+ batch["img"] = F.normalize(batch["img"].float(), [0.5*255, 0.5*255, 0.5*255], [0.5*255, 0.5*255, 0.5*255])
29
+ batch["img"] = batch["img"].float()
30
+ batch["condition"] = batch["mask"].float() * batch["img"]
31
+
32
+ assert z_idx is None and all_styles is None, "Arguments not supported with n_sampling_steps > 1."
33
+ real_hls = rgb_to_hsv(utils.denormalize_img(batch["img"]))
34
+ real_hls[:, 0] /= 2 * torch.pi
35
+ indices = [1, 2]
36
+ hist_kwargs = dict(
37
+ bins=torch.linspace(0, 1, 256, dtype=torch.float32, device=tops.get_device()),
38
+ bandwidth=torch.tensor(1., device=tops.get_device()))
39
+ real_hist = [histogram(real_hls[:, i].flatten(start_dim=1), **hist_kwargs) for i in indices]
40
+ for j in range(n_sampling_steps):
41
+ if j == 0:
42
+ if multi_modal_truncation:
43
+ w = G.style_net.multi_modal_truncate(
44
+ truncation_value=truncation_value, **batch, w_indices=None).detach()
45
+ else:
46
+ w = G.style_net.get_truncated(truncation_value, **batch).detach()
47
+ assert z_idx is None and all_styles is None, "Arguments not supported with n_sampling_steps > 1."
48
+ w.requires_grad = True
49
+ optim = torch.optim.Adam([w])
50
+ with torch.set_grad_enabled(True):
51
+ with torch.cuda.amp.autocast(amp):
52
+ anonymized_im = G(**batch, truncation_value=None, w=w)["img"]
53
+ fake_hls = rgb_to_hsv(anonymized_im*0.5 + 0.5)
54
+ fake_hls[:, 0] /= 2 * torch.pi
55
+ fake_hist = [histogram(fake_hls[:, i].flatten(start_dim=1), **hist_kwargs) for i in indices]
56
+ dist = sum([utils.torch_wasserstein_loss(r, f) for r, f in zip(real_hist, fake_hist)])
57
+ dist.backward()
58
+ if w.grad.sum() == 0:
59
+ break
60
+ assert w.grad.sum() != 0
61
+ optim.step()
62
+ optim.zero_grad()
63
+ if dist < 0.02:
64
+ break
65
+ anonymized_im = (anonymized_im+1).div(2).clamp(0, 1).mul(255)
66
+ return anonymized_im
67
+
68
+
69
+ class HistogramMatchAnonymizer(Anonymizer):
70
+
71
+ def forward_G(self, batch, *args, **kwargs):
72
+ rimg = batch["img"]
73
+ batch["img"] = F.normalize(batch["img"].float(), [0.5*255, 0.5*255, 0.5*255], [0.5*255, 0.5*255, 0.5*255])
74
+ batch["img"] = batch["img"].float()
75
+ batch["condition"] = batch["mask"].float() * batch["img"]
76
+
77
+ anonymized_im = super().forward_G(batch, *args, **kwargs)
78
+
79
+ equalized_gim = match_histograms(tops.im2numpy(anonymized_im.round().clamp(0, 255).byte()), tops.im2numpy(rimg))
80
+ if equalized_gim.dtype != np.uint8:
81
+ equalized_gim = equalized_gim.astype(np.float32)
82
+ assert equalized_gim.dtype == np.float32, equalized_gim.dtype
83
+ equalized_gim = tops.im2torch(equalized_gim, to_float=False)[0]
84
+ else:
85
+ equalized_gim = tops.im2torch(equalized_gim, to_float=False).float()[0]
86
+ equalized_gim = equalized_gim.to(device=rimg.device)
87
+ assert equalized_gim.dtype == torch.float32
88
+ gaussian_mask = 1 - (batch["maskrcnn_mask"][0].repeat(3, 1, 1) > 0.5).float()
89
+
90
+ gaussian_mask = gaussian_blur2d(gaussian_mask[None], kernel_size=[19, 19], sigma=[10, 10])[0]
91
+ gaussian_mask = gaussian_mask / gaussian_mask.max()
92
+ anonymized_im = gaussian_mask * equalized_gim + (1-gaussian_mask) * anonymized_im
93
+ return anonymized_im
dp2/data/__init__.py ADDED
File without changes
dp2/data/build.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import tops
3
+ from .utils import collate_fn
4
+
5
+
6
+ def get_dataloader(
7
+ dataset, gpu_transform: torch.nn.Module,
8
+ num_workers,
9
+ batch_size,
10
+ infinite: bool,
11
+ drop_last: bool,
12
+ prefetch_factor: int,
13
+ shuffle,
14
+ channels_last=False
15
+ ):
16
+ sampler = None
17
+ dl_kwargs = dict(
18
+ pin_memory=True,
19
+ )
20
+ if infinite:
21
+ sampler = tops.InfiniteSampler(
22
+ dataset, rank=tops.rank(),
23
+ num_replicas=tops.world_size(),
24
+ shuffle=shuffle
25
+ )
26
+ elif tops.world_size() > 1:
27
+ sampler = torch.utils.data.DistributedSampler(
28
+ dataset, shuffle=shuffle, num_replicas=tops.world_size(), rank=tops.rank())
29
+ dl_kwargs["drop_last"] = drop_last
30
+ else:
31
+ dl_kwargs["shuffle"] = shuffle
32
+ dl_kwargs["drop_last"] = drop_last
33
+ dataloader = torch.utils.data.DataLoader(
34
+ dataset, sampler=sampler, collate_fn=collate_fn,
35
+ batch_size=batch_size,
36
+ num_workers=num_workers, prefetch_factor=prefetch_factor,
37
+ **dl_kwargs
38
+ )
39
+ dataloader = tops.DataPrefetcher(dataloader, gpu_transform, channels_last=channels_last)
40
+ return dataloader
dp2/data/datasets/__init__.py ADDED
File without changes
dp2/data/datasets/coco_cse.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pickle
2
+ import torchvision
3
+ import torch
4
+ import pathlib
5
+ import numpy as np
6
+ from typing import Callable, Optional, Union
7
+ from torch.hub import get_dir as get_hub_dir
8
+
9
+
10
+ def cache_embed_stats(embed_map: torch.Tensor):
11
+ mean = embed_map.mean(dim=0, keepdim=True)
12
+ rstd = ((embed_map - mean).square().mean(dim=0, keepdim=True)+1e-8).rsqrt()
13
+
14
+ cache = dict(mean=mean, rstd=rstd, embed_map=embed_map)
15
+ path = pathlib.Path(get_hub_dir(), f"embed_map_stats.torch")
16
+ path.parent.mkdir(exist_ok=True, parents=True)
17
+ torch.save(cache, path)
18
+
19
+
20
+ class CocoCSE(torch.utils.data.Dataset):
21
+
22
+ def __init__(self,
23
+ dirpath: Union[str, pathlib.Path],
24
+ transform: Optional[Callable],
25
+ normalize_E: bool,):
26
+ dirpath = pathlib.Path(dirpath)
27
+ self.dirpath = dirpath
28
+
29
+ self.transform = transform
30
+ assert self.dirpath.is_dir(),\
31
+ f"Did not find dataset at: {dirpath}"
32
+ self.image_paths, self.embedding_paths = self._load_impaths()
33
+ self.embed_map = torch.from_numpy(np.load(self.dirpath.joinpath("embed_map.npy")))
34
+ mean = self.embed_map.mean(dim=0, keepdim=True)
35
+ rstd = ((self.embed_map - mean).square().mean(dim=0, keepdim=True)+1e-8).rsqrt()
36
+ self.embed_map = (self.embed_map - mean) * rstd
37
+ cache_embed_stats(self.embed_map)
38
+
39
+ def _load_impaths(self):
40
+ image_dir = self.dirpath.joinpath("images")
41
+ image_paths = list(image_dir.glob("*.png"))
42
+ image_paths.sort()
43
+ embedding_paths = [
44
+ self.dirpath.joinpath("embedding", x.stem + ".npy") for x in image_paths
45
+ ]
46
+ return image_paths, embedding_paths
47
+
48
+ def __len__(self):
49
+ return len(self.image_paths)
50
+
51
+ def __getitem__(self, idx):
52
+ im = torchvision.io.read_image(str(self.image_paths[idx]))
53
+ vertices, mask, border = np.split(np.load(self.embedding_paths[idx]), 3, axis=-1)
54
+ vertices = torch.from_numpy(vertices.squeeze()).long()
55
+ mask = torch.from_numpy(mask.squeeze()).float()
56
+ border = torch.from_numpy(border.squeeze()).float()
57
+ E_mask = 1 - mask - border
58
+ batch = {
59
+ "img": im,
60
+ "vertices": vertices[None],
61
+ "mask": mask[None],
62
+ "embed_map": self.embed_map,
63
+ "border": border[None],
64
+ "E_mask": E_mask[None]
65
+ }
66
+ if self.transform is None:
67
+ return batch
68
+ return self.transform(batch)
dp2/data/datasets/fdf.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pathlib
2
+ from typing import Tuple
3
+ import numpy as np
4
+ import torch
5
+ import pathlib
6
+ try:
7
+ import pyspng
8
+ PYSPNG_IMPORTED = True
9
+ except ImportError:
10
+ PYSPNG_IMPORTED = False
11
+ print("Could not load pyspng. Defaulting to pillow image backend.")
12
+ from PIL import Image
13
+ from tops import logger
14
+
15
+
16
+ class FDFDataset:
17
+
18
+ def __init__(self,
19
+ dirpath,
20
+ imsize: Tuple[int],
21
+ load_keypoints: bool,
22
+ transform):
23
+ dirpath = pathlib.Path(dirpath)
24
+ self.dirpath = dirpath
25
+ self.transform = transform
26
+ self.imsize = imsize[0]
27
+ self.load_keypoints = load_keypoints
28
+ assert self.dirpath.is_dir(),\
29
+ f"Did not find dataset at: {dirpath}"
30
+ image_dir = self.dirpath.joinpath("images", str(self.imsize))
31
+ self.image_paths = list(image_dir.glob("*.png"))
32
+ assert len(self.image_paths) > 0,\
33
+ f"Did not find images in: {image_dir}"
34
+ self.image_paths.sort(key=lambda x: int(x.stem))
35
+ self.landmarks = np.load(self.dirpath.joinpath("landmarks.npy")).reshape(-1, 7, 2).astype(np.float32)
36
+
37
+ self.bounding_boxes = torch.load(self.dirpath.joinpath("bounding_box", f"{self.imsize}.torch"))
38
+ assert len(self.image_paths) == len(self.bounding_boxes)
39
+ assert len(self.image_paths) == len(self.landmarks)
40
+ logger.log(
41
+ f"Dataset loaded from: {dirpath}. Number of samples:{len(self)}, imsize={imsize}")
42
+
43
+ def get_mask(self, idx):
44
+ mask = torch.ones((1, self.imsize, self.imsize), dtype=torch.bool)
45
+ bounding_box = self.bounding_boxes[idx]
46
+ x0, y0, x1, y1 = bounding_box
47
+ mask[:, y0:y1, x0:x1] = 0
48
+ return mask
49
+
50
+ def __len__(self):
51
+ return len(self.image_paths)
52
+
53
+ def __getitem__(self, index):
54
+ impath = self.image_paths[index]
55
+ if PYSPNG_IMPORTED:
56
+ with open(impath, "rb") as fp:
57
+ im = pyspng.load(fp.read())
58
+ else:
59
+ with Image.open(impath) as fp:
60
+ im = np.array(fp)
61
+ im = torch.from_numpy(np.rollaxis(im, -1, 0))
62
+ masks = self.get_mask(index)
63
+ landmark = self.landmarks[index]
64
+ batch = {
65
+ "img": im,
66
+ "mask": masks,
67
+ }
68
+ if self.load_keypoints:
69
+ batch["keypoints"] = landmark
70
+ if self.transform is None:
71
+ return batch
72
+ return self.transform(batch)
73
+
74
+
75
+ class FDF256Dataset:
76
+
77
+ def __init__(self,
78
+ dirpath,
79
+ load_keypoints: bool,
80
+ transform):
81
+ dirpath = pathlib.Path(dirpath)
82
+ self.dirpath = dirpath
83
+ self.transform = transform
84
+ self.load_keypoints = load_keypoints
85
+ assert self.dirpath.is_dir(),\
86
+ f"Did not find dataset at: {dirpath}"
87
+ image_dir = self.dirpath.joinpath("images")
88
+ self.image_paths = list(image_dir.glob("*.png"))
89
+ assert len(self.image_paths) > 0,\
90
+ f"Did not find images in: {image_dir}"
91
+ self.image_paths.sort(key=lambda x: int(x.stem))
92
+ self.landmarks = np.load(self.dirpath.joinpath("landmarks.npy")).reshape(-1, 7, 2).astype(np.float32)
93
+ self.bounding_boxes = torch.from_numpy(np.load(self.dirpath.joinpath("bounding_box.npy")))
94
+ assert len(self.image_paths) == len(self.bounding_boxes)
95
+ assert len(self.image_paths) == len(self.landmarks)
96
+ logger.log(
97
+ f"Dataset loaded from: {dirpath}. Number of samples:{len(self)}")
98
+
99
+ def get_mask(self, idx):
100
+ mask = torch.ones((1, 256, 256), dtype=torch.bool)
101
+ bounding_box = self.bounding_boxes[idx]
102
+ x0, y0, x1, y1 = bounding_box
103
+ mask[:, y0:y1, x0:x1] = 0
104
+ return mask
105
+
106
+ def __len__(self):
107
+ return len(self.image_paths)
108
+
109
+ def __getitem__(self, index):
110
+ impath = self.image_paths[index]
111
+ if PYSPNG_IMPORTED:
112
+ with open(impath, "rb") as fp:
113
+ im = pyspng.load(fp.read())
114
+ else:
115
+ with Image.open(impath) as fp:
116
+ im = np.array(fp)
117
+ im = torch.from_numpy(np.rollaxis(im, -1, 0))
118
+ masks = self.get_mask(index)
119
+ landmark = self.landmarks[index]
120
+ batch = {
121
+ "img": im,
122
+ "mask": masks,
123
+ }
124
+ if self.load_keypoints:
125
+ batch["keypoints"] = landmark
126
+ if self.transform is None:
127
+ return batch
128
+ return self.transform(batch)
dp2/data/datasets/fdf128_wds.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import tops
3
+ import numpy as np
4
+ import io
5
+ import webdataset as wds
6
+ import os
7
+ from ..utils import png_decoder, get_num_workers, collate_fn
8
+
9
+
10
+ def kp_decoder(x):
11
+ # Keypoints are between [0, 1] for webdataset
12
+ keypoints = torch.from_numpy(np.load(io.BytesIO(x))).float().view(7, 2).clamp(0, 1)
13
+ keypoints = torch.cat((keypoints, torch.ones((7, 1))), dim=-1)
14
+ return keypoints
15
+
16
+
17
+ def bbox_decoder(x):
18
+ return torch.from_numpy(np.load(io.BytesIO(x))).float().view(4)
19
+
20
+
21
+ class BBoxToMask:
22
+
23
+ def __call__(self, sample):
24
+ imsize = sample["image.png"].shape[-1]
25
+ bbox = sample["bounding_box.npy"] * imsize
26
+ x0, y0, x1, y1 = np.round(bbox).astype(np.int64)
27
+ mask = torch.ones((1, imsize, imsize), dtype=torch.bool)
28
+ mask[:, y0:y1, x0:x1] = 0
29
+ sample["mask"] = mask
30
+ return sample
31
+
32
+
33
+ def get_dataloader_fdf_wds(
34
+ path,
35
+ batch_size: int,
36
+ num_workers: int,
37
+ transform: torch.nn.Module,
38
+ gpu_transform: torch.nn.Module,
39
+ infinite: bool,
40
+ shuffle: bool,
41
+ partial_batches: bool,
42
+ sample_shuffle=10_000,
43
+ tar_shuffle=100,
44
+ channels_last=False,
45
+ ):
46
+ # Need to set this for split_by_node to work.
47
+ os.environ["RANK"] = str(tops.rank())
48
+ os.environ["WORLD_SIZE"] = str(tops.world_size())
49
+ if infinite:
50
+ pipeline = [wds.ResampledShards(str(path))]
51
+ else:
52
+ pipeline = [wds.SimpleShardList(str(path))]
53
+ if shuffle:
54
+ pipeline.append(wds.shuffle(tar_shuffle))
55
+ pipeline.extend([
56
+ wds.split_by_node,
57
+ wds.split_by_worker,
58
+ ])
59
+ if shuffle:
60
+ pipeline.append(wds.shuffle(sample_shuffle))
61
+
62
+ decoder = [
63
+ wds.handle_extension("image.png", png_decoder),
64
+ wds.handle_extension("keypoints.npy", kp_decoder),
65
+ ]
66
+
67
+ rename_keys = [
68
+ ["img", "image.png"],
69
+ ["keypoints", "keypoints.npy"],
70
+ ["__key__", "__key__"],
71
+ ["mask", "mask"]
72
+ ]
73
+
74
+ pipeline.extend([
75
+ wds.tarfile_to_samples(),
76
+ wds.decode(*decoder),
77
+ ])
78
+ pipeline.append(wds.map(BBoxToMask()))
79
+ pipeline.extend([
80
+ wds.batched(batch_size, collation_fn=collate_fn, partial=partial_batches),
81
+ wds.rename_keys(*rename_keys),
82
+ ])
83
+
84
+ if transform is not None:
85
+ pipeline.append(wds.map(transform))
86
+ pipeline = wds.DataPipeline(*pipeline)
87
+ if infinite:
88
+ pipeline = pipeline.repeat(nepochs=1000000)
89
+
90
+ loader = wds.WebLoader(
91
+ pipeline, batch_size=None, shuffle=False,
92
+ num_workers=get_num_workers(num_workers),
93
+ persistent_workers=True,
94
+ )
95
+ loader = tops.DataPrefetcher(loader, gpu_transform, channels_last=channels_last, to_float=False)
96
+ return loader
dp2/data/datasets/fdh.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import tops
3
+ import numpy as np
4
+ import io
5
+ import webdataset as wds
6
+ import os
7
+ import json
8
+ from pathlib import Path
9
+ from ..utils import png_decoder, mask_decoder, get_num_workers, collate_fn
10
+
11
+
12
+ def kp_decoder(x):
13
+ # Keypoints are between [0, 1] for webdataset
14
+ keypoints = torch.from_numpy(np.load(io.BytesIO(x))).float()
15
+ def check_outside(x): return (x < 0).logical_or(x > 1)
16
+ is_outside = check_outside(keypoints[:, 0]).logical_or(
17
+ check_outside(keypoints[:, 1])
18
+ )
19
+ keypoints[:, 2] = (keypoints[:, 2] > 0).logical_and(is_outside.logical_not())
20
+ return keypoints
21
+
22
+
23
+ def vertices_decoder(x):
24
+ vertices = torch.from_numpy(np.load(io.BytesIO(x)).astype(np.int32))
25
+ return vertices.squeeze()[None]
26
+
27
+
28
+ class InsertNewKeypoints:
29
+
30
+ def __init__(self, keypoints_path: Path) -> None:
31
+ with open(keypoints_path, "r") as fp:
32
+ self.keypoints = json.load(fp)
33
+
34
+ def __call__(self, sample):
35
+ key = sample["__key__"]
36
+ keypoints = torch.tensor(self.keypoints[key], dtype=torch.float32)
37
+ def check_outside(x): return (x < 0).logical_or(x > 1)
38
+ is_outside = check_outside(keypoints[:, 0]).logical_or(
39
+ check_outside(keypoints[:, 1])
40
+ )
41
+ keypoints[:, 2] = (keypoints[:, 2] > 0).logical_and(is_outside.logical_not())
42
+
43
+ sample["keypoints.npy"] = keypoints
44
+ return sample
45
+
46
+
47
+ def get_dataloader_fdh_wds(
48
+ path,
49
+ batch_size: int,
50
+ num_workers: int,
51
+ transform: torch.nn.Module,
52
+ gpu_transform: torch.nn.Module,
53
+ infinite: bool,
54
+ shuffle: bool,
55
+ partial_batches: bool,
56
+ load_embedding: bool,
57
+ sample_shuffle=10_000,
58
+ tar_shuffle=100,
59
+ read_condition=False,
60
+ channels_last=False,
61
+ load_new_keypoints=False,
62
+ keypoints_split=None,
63
+ ):
64
+ # Need to set this for split_by_node to work.
65
+ os.environ["RANK"] = str(tops.rank())
66
+ os.environ["WORLD_SIZE"] = str(tops.world_size())
67
+ if infinite:
68
+ pipeline = [wds.ResampledShards(str(path))]
69
+ else:
70
+ pipeline = [wds.SimpleShardList(str(path))]
71
+ if shuffle:
72
+ pipeline.append(wds.shuffle(tar_shuffle))
73
+ pipeline.extend([
74
+ wds.split_by_node,
75
+ wds.split_by_worker,
76
+ ])
77
+ if shuffle:
78
+ pipeline.append(wds.shuffle(sample_shuffle))
79
+
80
+ decoder = [
81
+ wds.handle_extension("image.png", png_decoder),
82
+ wds.handle_extension("mask.png", mask_decoder),
83
+ wds.handle_extension("maskrcnn_mask.png", mask_decoder),
84
+ wds.handle_extension("keypoints.npy", kp_decoder),
85
+ ]
86
+
87
+ rename_keys = [
88
+ ["img", "image.png"], ["mask", "mask.png"],
89
+ ["keypoints", "keypoints.npy"], ["maskrcnn_mask", "maskrcnn_mask.png"],
90
+ ["__key__", "__key__"]
91
+ ]
92
+ if load_embedding:
93
+ decoder.extend([
94
+ wds.handle_extension("vertices.npy", vertices_decoder),
95
+ wds.handle_extension("E_mask.png", mask_decoder)
96
+ ])
97
+ rename_keys.extend([
98
+ ["vertices", "vertices.npy"],
99
+ ["E_mask", "e_mask.png"]
100
+ ])
101
+
102
+ if read_condition:
103
+ decoder.append(
104
+ wds.handle_extension("condition.png", png_decoder)
105
+ )
106
+ rename_keys.append(["condition", "condition.png"])
107
+
108
+ pipeline.extend([
109
+ wds.tarfile_to_samples(),
110
+ wds.decode(*decoder),
111
+
112
+ ])
113
+ if load_new_keypoints:
114
+ assert keypoints_split in ["train", "val"]
115
+ keypoint_url = "https://api.loke.aws.unit.no/dlr-gui-backend-resources-content/v2/contents/links/1eb88522-8b91-49c7-b56a-ed98a9c7888cef9c0429-a385-4248-abe3-8682de26d041f268aed1-7c88-4677-baad-7623c2ee330f"
116
+ file_name = "fdh_keypoints_val-050133b34d.json"
117
+ if keypoints_split == "train":
118
+ keypoint_url = "https://api.loke.aws.unit.no/dlr-gui-backend-resources-content/v2/contents/links/3e828b1c-d6c0-4622-90bc-1b2cce48ccfff14ab45d-0a5c-431d-be13-7e60580765bd7938601c-e72e-41d9-8836-fffc49e76f58"
119
+ file_name = "fdh_keypoints_train-2cff11f69a.json"
120
+ # Set check_hash=True if you suspect download is incorrect.
121
+ filepath = tops.download_file(keypoint_url, file_name=file_name, check_hash=False)
122
+ pipeline.append(
123
+ wds.map(InsertNewKeypoints(filepath))
124
+ )
125
+ pipeline.extend([
126
+ wds.batched(batch_size, collation_fn=collate_fn, partial=partial_batches),
127
+ wds.rename_keys(*rename_keys),
128
+ ])
129
+
130
+ if transform is not None:
131
+ pipeline.append(wds.map(transform))
132
+ pipeline = wds.DataPipeline(*pipeline)
133
+ if infinite:
134
+ pipeline = pipeline.repeat(nepochs=1000000)
135
+
136
+ loader = wds.WebLoader(
137
+ pipeline, batch_size=None, shuffle=False,
138
+ num_workers=get_num_workers(num_workers),
139
+ persistent_workers=True,
140
+ )
141
+ loader = tops.DataPrefetcher(loader, gpu_transform, channels_last=channels_last, to_float=False)
142
+ return loader
dp2/data/transforms/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .transforms import RandomCrop, CreateCondition, CreateEmbedding, Resize, ToFloat, Normalize
2
+ from .stylegan2_transform import StyleGANAugmentPipe
dp2/data/transforms/functional.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torchvision.transforms.functional as F
2
+ import torch
3
+ import pickle
4
+ from tops import download_file, assert_shape
5
+ from typing import Dict
6
+ from functools import lru_cache
7
+
8
+ global symmetry_transform
9
+
10
+
11
+ @lru_cache(maxsize=1)
12
+ def get_symmetry_transform(symmetry_url):
13
+ file_name = download_file(symmetry_url)
14
+ with open(file_name, "rb") as fp:
15
+ symmetry = pickle.load(fp)
16
+ return torch.from_numpy(symmetry["vertex_transforms"]).long()
17
+
18
+
19
+ hflip_handled_cases = set([
20
+ "keypoints", "img", "mask", "border", "semantic_mask", "vertices", "E_mask", "embed_map", "condition",
21
+ "embedding", "vertx2cat", "maskrcnn_mask", "__key__"])
22
+
23
+
24
+ def hflip(container: Dict[str, torch.Tensor], flip_map=None) -> Dict[str, torch.Tensor]:
25
+ container["img"] = F.hflip(container["img"])
26
+ if "condition" in container:
27
+ container["condition"] = F.hflip(container["condition"])
28
+ if "embedding" in container:
29
+ container["embedding"] = F.hflip(container["embedding"])
30
+ assert all([key in hflip_handled_cases for key in container]), container.keys()
31
+ if "keypoints" in container:
32
+ assert flip_map is not None
33
+ if container["keypoints"].ndim == 3:
34
+ keypoints = container["keypoints"][:, flip_map, :]
35
+ keypoints[:, :, 0] = 1 - keypoints[:, :, 0]
36
+ else:
37
+ assert_shape(container["keypoints"], (None, 3))
38
+ keypoints = container["keypoints"][flip_map, :]
39
+ keypoints[:, 0] = 1 - keypoints[:, 0]
40
+ container["keypoints"] = keypoints
41
+ if "mask" in container:
42
+ container["mask"] = F.hflip(container["mask"])
43
+ if "border" in container:
44
+ container["border"] = F.hflip(container["border"])
45
+ if "semantic_mask" in container:
46
+ container["semantic_mask"] = F.hflip(container["semantic_mask"])
47
+ if "vertices" in container:
48
+ symmetry_transform = get_symmetry_transform(
49
+ "https://dl.fbaipublicfiles.com/densepose/meshes/symmetry/symmetry_smpl_27554.pkl")
50
+ container["vertices"] = F.hflip(container["vertices"])
51
+ symmetry_transform_ = symmetry_transform.to(container["vertices"].device)
52
+ container["vertices"] = symmetry_transform_[container["vertices"].long()]
53
+ if "E_mask" in container:
54
+ container["E_mask"] = F.hflip(container["E_mask"])
55
+ if "maskrcnn_mask" in container:
56
+ container["maskrcnn_mask"] = F.hflip(container["maskrcnn_mask"])
57
+ return container
dp2/data/transforms/stylegan2_transform.py ADDED
@@ -0,0 +1,394 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import scipy.signal
3
+ import torch
4
+ try:
5
+ from sg3_torch_utils import misc
6
+ from sg3_torch_utils.ops import upfirdn2d
7
+ from sg3_torch_utils.ops import grid_sample_gradfix
8
+ from sg3_torch_utils.ops import conv2d_gradfix
9
+ except:
10
+ pass
11
+ #----------------------------------------------------------------------------
12
+ # Coefficients of various wavelet decomposition low-pass filters.
13
+
14
+ wavelets = {
15
+ 'haar': [0.7071067811865476, 0.7071067811865476],
16
+ 'db1': [0.7071067811865476, 0.7071067811865476],
17
+ 'db2': [-0.12940952255092145, 0.22414386804185735, 0.836516303737469, 0.48296291314469025],
18
+ 'db3': [0.035226291882100656, -0.08544127388224149, -0.13501102001039084, 0.4598775021193313, 0.8068915093133388, 0.3326705529509569],
19
+ 'db4': [-0.010597401784997278, 0.032883011666982945, 0.030841381835986965, -0.18703481171888114, -0.02798376941698385, 0.6308807679295904, 0.7148465705525415, 0.23037781330885523],
20
+ 'db5': [0.003335725285001549, -0.012580751999015526, -0.006241490213011705, 0.07757149384006515, -0.03224486958502952, -0.24229488706619015, 0.13842814590110342, 0.7243085284385744, 0.6038292697974729, 0.160102397974125],
21
+ 'db6': [-0.00107730108499558, 0.004777257511010651, 0.0005538422009938016, -0.031582039318031156, 0.02752286553001629, 0.09750160558707936, -0.12976686756709563, -0.22626469396516913, 0.3152503517092432, 0.7511339080215775, 0.4946238903983854, 0.11154074335008017],
22
+ 'db7': [0.0003537138000010399, -0.0018016407039998328, 0.00042957797300470274, 0.012550998556013784, -0.01657454163101562, -0.03802993693503463, 0.0806126091510659, 0.07130921926705004, -0.22403618499416572, -0.14390600392910627, 0.4697822874053586, 0.7291320908465551, 0.39653931948230575, 0.07785205408506236],
23
+ 'db8': [-0.00011747678400228192, 0.0006754494059985568, -0.0003917403729959771, -0.00487035299301066, 0.008746094047015655, 0.013981027917015516, -0.04408825393106472, -0.01736930100202211, 0.128747426620186, 0.00047248457399797254, -0.2840155429624281, -0.015829105256023893, 0.5853546836548691, 0.6756307362980128, 0.3128715909144659, 0.05441584224308161],
24
+ 'sym2': [-0.12940952255092145, 0.22414386804185735, 0.836516303737469, 0.48296291314469025],
25
+ 'sym3': [0.035226291882100656, -0.08544127388224149, -0.13501102001039084, 0.4598775021193313, 0.8068915093133388, 0.3326705529509569],
26
+ 'sym4': [-0.07576571478927333, -0.02963552764599851, 0.49761866763201545, 0.8037387518059161, 0.29785779560527736, -0.09921954357684722, -0.012603967262037833, 0.0322231006040427],
27
+ 'sym5': [0.027333068345077982, 0.029519490925774643, -0.039134249302383094, 0.1993975339773936, 0.7234076904024206, 0.6339789634582119, 0.01660210576452232, -0.17532808990845047, -0.021101834024758855, 0.019538882735286728],
28
+ 'sym6': [0.015404109327027373, 0.0034907120842174702, -0.11799011114819057, -0.048311742585633, 0.4910559419267466, 0.787641141030194, 0.3379294217276218, -0.07263752278646252, -0.021060292512300564, 0.04472490177066578, 0.0017677118642428036, -0.007800708325034148],
29
+ 'sym7': [0.002681814568257878, -0.0010473848886829163, -0.01263630340325193, 0.03051551316596357, 0.0678926935013727, -0.049552834937127255, 0.017441255086855827, 0.5361019170917628, 0.767764317003164, 0.2886296317515146, -0.14004724044296152, -0.10780823770381774, 0.004010244871533663, 0.010268176708511255],
30
+ 'sym8': [-0.0033824159510061256, -0.0005421323317911481, 0.03169508781149298, 0.007607487324917605, -0.1432942383508097, -0.061273359067658524, 0.4813596512583722, 0.7771857517005235, 0.3644418948353314, -0.05194583810770904, -0.027219029917056003, 0.049137179673607506, 0.003808752013890615, -0.01495225833704823, -0.0003029205147213668, 0.0018899503327594609],
31
+ }
32
+
33
+ #----------------------------------------------------------------------------
34
+ # Helpers for constructing transformation matrices.
35
+
36
+
37
+ def matrix(*rows, device=None):
38
+ assert all(len(row) == len(rows[0]) for row in rows)
39
+ elems = [x for row in rows for x in row]
40
+ ref = [x for x in elems if isinstance(x, torch.Tensor)]
41
+ if len(ref) == 0:
42
+ return misc.constant(np.asarray(rows), device=device)
43
+ assert device is None or device == ref[0].device
44
+ elems = [x if isinstance(x, torch.Tensor) else misc.constant(x, shape=ref[0].shape, device=ref[0].device) for x in elems]
45
+ return torch.stack(elems, dim=-1).reshape(ref[0].shape + (len(rows), -1))
46
+
47
+
48
+ def translate2d(tx, ty, **kwargs):
49
+ return matrix(
50
+ [1, 0, tx],
51
+ [0, 1, ty],
52
+ [0, 0, 1],
53
+ **kwargs)
54
+
55
+
56
+ def translate3d(tx, ty, tz, **kwargs):
57
+ return matrix(
58
+ [1, 0, 0, tx],
59
+ [0, 1, 0, ty],
60
+ [0, 0, 1, tz],
61
+ [0, 0, 0, 1],
62
+ **kwargs)
63
+
64
+
65
+ def scale2d(sx, sy, **kwargs):
66
+ return matrix(
67
+ [sx, 0, 0],
68
+ [0, sy, 0],
69
+ [0, 0, 1],
70
+ **kwargs)
71
+
72
+
73
+ def scale3d(sx, sy, sz, **kwargs):
74
+ return matrix(
75
+ [sx, 0, 0, 0],
76
+ [0, sy, 0, 0],
77
+ [0, 0, sz, 0],
78
+ [0, 0, 0, 1],
79
+ **kwargs)
80
+
81
+
82
+ def rotate2d(theta, **kwargs):
83
+ return matrix(
84
+ [torch.cos(theta), torch.sin(-theta), 0],
85
+ [torch.sin(theta), torch.cos(theta), 0],
86
+ [0, 0, 1],
87
+ **kwargs)
88
+
89
+
90
+ def rotate3d(v, theta, **kwargs):
91
+ vx = v[..., 0]; vy = v[..., 1]; vz = v[..., 2]
92
+ s = torch.sin(theta); c = torch.cos(theta); cc = 1 - c
93
+ return matrix(
94
+ [vx*vx*cc+c, vx*vy*cc-vz*s, vx*vz*cc+vy*s, 0],
95
+ [vy*vx*cc+vz*s, vy*vy*cc+c, vy*vz*cc-vx*s, 0],
96
+ [vz*vx*cc-vy*s, vz*vy*cc+vx*s, vz*vz*cc+c, 0],
97
+ [0, 0, 0, 1],
98
+ **kwargs)
99
+
100
+
101
+ def translate2d_inv(tx, ty, **kwargs):
102
+ return translate2d(-tx, -ty, **kwargs)
103
+
104
+
105
+ def scale2d_inv(sx, sy, **kwargs):
106
+ return scale2d(1 / sx, 1 / sy, **kwargs)
107
+
108
+
109
+ def rotate2d_inv(theta, **kwargs):
110
+ return rotate2d(-theta, **kwargs)
111
+
112
+
113
+ class StyleGANAugmentPipe(torch.nn.Module):
114
+ def __init__(self,
115
+ rotate90=0, xint=0, xint_max=0.125,
116
+ scale=0, rotate=0, aniso=0, xfrac=0, scale_std=0.2, rotate_max=1, aniso_std=0.2, xfrac_std=0.125,
117
+ brightness=0, contrast=0, lumaflip=0, hue=0, saturation=0, brightness_std=0.2, contrast_std=0.5,
118
+ hue_max=1, saturation_std=1,
119
+ imgfilter=0, imgfilter_bands=[1,1,1,1], imgfilter_std=1,
120
+ ):
121
+ super().__init__()
122
+ self.register_buffer('p', torch.ones([])) # Overall multiplier for augmentation probability.
123
+
124
+ # Pixel blitting.
125
+ self.rotate90 = float(rotate90) # Probability multiplier for 90 degree rotations.
126
+ self.xint = float(xint) # Probability multiplier for integer translation.
127
+ self.xint_max = float(xint_max) # Range of integer translation, relative to image dimensions.
128
+
129
+ # General geometric transformations.
130
+ self.scale = float(scale) # Probability multiplier for isotropic scaling.
131
+ self.rotate = float(rotate) # Probability multiplier for arbitrary rotation.
132
+ self.aniso = float(aniso) # Probability multiplier for anisotropic scaling.
133
+ self.xfrac = float(xfrac) # Probability multiplier for fractional translation.
134
+ self.scale_std = float(scale_std) # Log2 standard deviation of isotropic scaling.
135
+ self.rotate_max = float(rotate_max) # Range of arbitrary rotation, 1 = full circle.
136
+ self.aniso_std = float(aniso_std) # Log2 standard deviation of anisotropic scaling.
137
+ self.xfrac_std = float(xfrac_std) # Standard deviation of frational translation, relative to image dimensions.
138
+
139
+ # Color transformations.
140
+ self.brightness = float(brightness) # Probability multiplier for brightness.
141
+ self.contrast = float(contrast) # Probability multiplier for contrast.
142
+ self.lumaflip = float(lumaflip) # Probability multiplier for luma flip.
143
+ self.hue = float(hue) # Probability multiplier for hue rotation.
144
+ self.saturation = float(saturation) # Probability multiplier for saturation.
145
+ self.brightness_std = float(brightness_std) # Standard deviation of brightness.
146
+ self.contrast_std = float(contrast_std) # Log2 standard deviation of contrast.
147
+ self.hue_max = float(hue_max) # Range of hue rotation, 1 = full circle.
148
+ self.saturation_std = float(saturation_std) # Log2 standard deviation of saturation.
149
+
150
+ # Image-space filtering.
151
+ self.imgfilter = float(imgfilter) # Probability multiplier for image-space filtering.
152
+ self.imgfilter_bands = list(imgfilter_bands) # Probability multipliers for individual frequency bands.
153
+ self.imgfilter_std = float(imgfilter_std) # Log2 standard deviation of image-space filter amplification.
154
+
155
+ # Setup orthogonal lowpass filter for geometric augmentations.
156
+ self.register_buffer('Hz_geom', upfirdn2d.setup_filter(wavelets['sym6']))
157
+
158
+ # Construct filter bank for image-space filtering.
159
+ Hz_lo = np.asarray(wavelets['sym2']) # H(z)
160
+ Hz_hi = Hz_lo * ((-1) ** np.arange(Hz_lo.size)) # H(-z)
161
+ Hz_lo2 = np.convolve(Hz_lo, Hz_lo[::-1]) / 2 # H(z) * H(z^-1) / 2
162
+ Hz_hi2 = np.convolve(Hz_hi, Hz_hi[::-1]) / 2 # H(-z) * H(-z^-1) / 2
163
+ Hz_fbank = np.eye(4, 1) # Bandpass(H(z), b_i)
164
+ for i in range(1, Hz_fbank.shape[0]):
165
+ Hz_fbank = np.dstack([Hz_fbank, np.zeros_like(Hz_fbank)]).reshape(Hz_fbank.shape[0], -1)[:, :-1]
166
+ Hz_fbank = scipy.signal.convolve(Hz_fbank, [Hz_lo2])
167
+ Hz_fbank[i, (Hz_fbank.shape[1] - Hz_hi2.size) // 2 : (Hz_fbank.shape[1] + Hz_hi2.size) // 2] += Hz_hi2
168
+ self.register_buffer('Hz_fbank', torch.as_tensor(Hz_fbank, dtype=torch.float32))
169
+
170
+ def forward(self, batch, debug_percentile=None):
171
+ images = batch["img"]
172
+ batch["vertices"] = batch["vertices"].float()
173
+ assert isinstance(images, torch.Tensor) and images.ndim == 4
174
+ batch_size, num_channels, height, width = images.shape
175
+ device = images.device
176
+ self.Hz_fbank = self.Hz_fbank.to(device)
177
+ self.Hz_geom = self.Hz_geom.to(device)
178
+ if debug_percentile is not None:
179
+ debug_percentile = torch.as_tensor(debug_percentile, dtype=torch.float32, device=device)
180
+
181
+ # -------------------------------------
182
+ # Select parameters for pixel blitting.
183
+ # -------------------------------------
184
+
185
+ # Initialize inverse homogeneous 2D transform: G_inv @ pixel_out ==> pixel_in
186
+ I_3 = torch.eye(3, device=device)
187
+ G_inv = I_3
188
+
189
+ # Apply integer translation with probability (xint * strength).
190
+ if self.xint > 0:
191
+ t = (torch.rand([batch_size, 2], device=device) * 2 - 1) * self.xint_max
192
+ t = torch.where(torch.rand([batch_size, 1], device=device) < self.xint * self.p, t, torch.zeros_like(t))
193
+ if debug_percentile is not None:
194
+ t = torch.full_like(t, (debug_percentile * 2 - 1) * self.xint_max)
195
+ G_inv = G_inv @ translate2d_inv(torch.round(t[:,0] * width), torch.round(t[:,1] * height))
196
+
197
+ # --------------------------------------------------------
198
+ # Select parameters for general geometric transformations.
199
+ # --------------------------------------------------------
200
+
201
+ # Apply isotropic scaling with probability (scale * strength).
202
+ if self.scale > 0:
203
+ s = torch.exp2(torch.randn([batch_size], device=device) * self.scale_std)
204
+ s = torch.where(torch.rand([batch_size], device=device) < self.scale * self.p, s, torch.ones_like(s))
205
+ if debug_percentile is not None:
206
+ s = torch.full_like(s, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.scale_std))
207
+ G_inv = G_inv @ scale2d_inv(s, s)
208
+
209
+ # Apply pre-rotation with probability p_rot.
210
+ p_rot = 1 - torch.sqrt((1 - self.rotate * self.p).clamp(0, 1)) # P(pre OR post) = p
211
+ if self.rotate > 0:
212
+ theta = (torch.rand([batch_size], device=device) * 2 - 1) * np.pi * self.rotate_max
213
+ theta = torch.where(torch.rand([batch_size], device=device) < p_rot, theta, torch.zeros_like(theta))
214
+ if debug_percentile is not None:
215
+ theta = torch.full_like(theta, (debug_percentile * 2 - 1) * np.pi * self.rotate_max)
216
+ G_inv = G_inv @ rotate2d_inv(-theta) # Before anisotropic scaling.
217
+
218
+ # Apply anisotropic scaling with probability (aniso * strength).
219
+ if self.aniso > 0:
220
+ s = torch.exp2(torch.randn([batch_size], device=device) * self.aniso_std)
221
+ s = torch.where(torch.rand([batch_size], device=device) < self.aniso * self.p, s, torch.ones_like(s))
222
+ if debug_percentile is not None:
223
+ s = torch.full_like(s, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.aniso_std))
224
+ G_inv = G_inv @ scale2d_inv(s, 1 / s)
225
+
226
+ # Apply post-rotation with probability p_rot.
227
+ if self.rotate > 0:
228
+ theta = (torch.rand([batch_size], device=device) * 2 - 1) * np.pi * self.rotate_max
229
+ theta = torch.where(torch.rand([batch_size], device=device) < p_rot, theta, torch.zeros_like(theta))
230
+ if debug_percentile is not None:
231
+ theta = torch.zeros_like(theta)
232
+ G_inv = G_inv @ rotate2d_inv(-theta) # After anisotropic scaling.
233
+
234
+ # Apply fractional translation with probability (xfrac * strength).
235
+ if self.xfrac > 0:
236
+ t = torch.randn([batch_size, 2], device=device) * self.xfrac_std
237
+ t = torch.where(torch.rand([batch_size, 1], device=device) < self.xfrac * self.p, t, torch.zeros_like(t))
238
+ if debug_percentile is not None:
239
+ t = torch.full_like(t, torch.erfinv(debug_percentile * 2 - 1) * self.xfrac_std)
240
+ G_inv = G_inv @ translate2d_inv(t[:,0] * width, t[:,1] * height)
241
+
242
+ # ----------------------------------
243
+ # Execute geometric transformations.
244
+ # ----------------------------------
245
+
246
+ # Execute if the transform is not identity.
247
+ if G_inv is not I_3:
248
+ # Calculate padding.
249
+ cx = (width - 1) / 2
250
+ cy = (height - 1) / 2
251
+ cp = matrix([-cx, -cy, 1], [cx, -cy, 1], [cx, cy, 1], [-cx, cy, 1], device=device) # [idx, xyz]
252
+ cp = G_inv @ cp.t() # [batch, xyz, idx]
253
+ Hz_pad = self.Hz_geom.shape[0] // 4
254
+ margin = cp[:, :2, :].permute(1, 0, 2).flatten(1) # [xy, batch * idx]
255
+ margin = torch.cat([-margin, margin]).max(dim=1).values # [x0, y0, x1, y1]
256
+ margin = margin + misc.constant([Hz_pad * 2 - cx, Hz_pad * 2 - cy] * 2, device=device)
257
+ margin = margin.max(misc.constant([0, 0] * 2, device=device))
258
+ margin = margin.min(misc.constant([width-1, height-1] * 2, device=device))
259
+ mx0, my0, mx1, my1 = margin.ceil().to(torch.int32)
260
+
261
+ # Pad image and adjust origin.
262
+ images = torch.nn.functional.pad(input=images, pad=[mx0,mx1,my0,my1], mode='reflect')
263
+ batch["mask"] = torch.nn.functional.pad(input=batch["mask"], pad=[mx0,mx1,my0,my1], mode='constant', value=1.0)
264
+ batch["E_mask"] = torch.nn.functional.pad(input=batch["E_mask"], pad=[mx0,mx1,my0,my1], mode='constant', value=0.0)
265
+ batch["vertices"] = torch.nn.functional.pad(input=batch["vertices"], pad=[mx0,mx1,my0,my1], mode='constant', value=0.0)
266
+ G_inv = translate2d((mx0 - mx1) / 2, (my0 - my1) / 2) @ G_inv
267
+
268
+ # Upsample.
269
+ images = upfirdn2d.upsample2d(x=images, f=self.Hz_geom, up=2)
270
+ batch["mask"] = torch.nn.functional.interpolate(batch["mask"], scale_factor=2, mode="nearest")
271
+ batch["E_mask"] = torch.nn.functional.interpolate(batch["E_mask"], scale_factor=2, mode="nearest")
272
+ batch["vertices"] = torch.nn.functional.interpolate(batch["vertices"], scale_factor=2, mode="nearest")
273
+ G_inv = scale2d(2, 2, device=device) @ G_inv @ scale2d_inv(2, 2, device=device)
274
+ G_inv = translate2d(-0.5, -0.5, device=device) @ G_inv @ translate2d_inv(-0.5, -0.5, device=device)
275
+
276
+ # Execute transformation.
277
+ shape = [batch_size, num_channels, (height + Hz_pad * 2) * 2, (width + Hz_pad * 2) * 2]
278
+ G_inv = scale2d(2 / images.shape[3], 2 / images.shape[2], device=device) @ G_inv @ scale2d_inv(2 / shape[3], 2 / shape[2], device=device)
279
+ grid = torch.nn.functional.affine_grid(theta=G_inv[:,:2,:], size=shape, align_corners=False)
280
+ images = grid_sample_gradfix.grid_sample(images, grid)
281
+
282
+ batch["mask"] = torch.nn.functional.grid_sample(
283
+ input=batch["mask"], grid=grid, mode='nearest', padding_mode="border", align_corners=False)
284
+ batch["E_mask"] = torch.nn.functional.grid_sample(
285
+ input=batch["E_mask"], grid=grid, mode='nearest', padding_mode="border", align_corners=False)
286
+ batch["vertices"] = torch.nn.functional.grid_sample(
287
+ input=batch["vertices"], grid=grid, mode='nearest', padding_mode="border", align_corners=False)
288
+
289
+
290
+ # Downsample and crop.
291
+ images = upfirdn2d.downsample2d(x=images, f=self.Hz_geom, down=2, padding=-Hz_pad*2, flip_filter=True)
292
+ batch["mask"] = torch.nn.functional.interpolate(batch["mask"][:, :, Hz_pad*2:-Hz_pad*2, Hz_pad*2:-Hz_pad*2], scale_factor=.5, mode="nearest", recompute_scale_factor=False)
293
+ batch["E_mask"] = torch.nn.functional.interpolate(batch["E_mask"][:, :, Hz_pad*2:-Hz_pad*2, Hz_pad*2:-Hz_pad*2], scale_factor=.5, mode="nearest", recompute_scale_factor=False)
294
+ batch["vertices"] = torch.nn.functional.interpolate(batch["vertices"][:, :, Hz_pad*2:-Hz_pad*2, Hz_pad*2:-Hz_pad*2], scale_factor=.5, mode="nearest", recompute_scale_factor=False)
295
+ # --------------------------------------------
296
+ # Select parameters for color transformations.
297
+ # --------------------------------------------
298
+
299
+ # Initialize homogeneous 3D transformation matrix: C @ color_in ==> color_out
300
+ I_4 = torch.eye(4, device=device)
301
+ C = I_4
302
+
303
+ # Apply brightness with probability (brightness * strength).
304
+ if self.brightness > 0:
305
+ b = torch.randn([batch_size], device=device) * self.brightness_std
306
+ b = torch.where(torch.rand([batch_size], device=device) < self.brightness * self.p, b, torch.zeros_like(b))
307
+ if debug_percentile is not None:
308
+ b = torch.full_like(b, torch.erfinv(debug_percentile * 2 - 1) * self.brightness_std)
309
+ C = translate3d(b, b, b) @ C
310
+
311
+ # Apply contrast with probability (contrast * strength).
312
+ if self.contrast > 0:
313
+ c = torch.exp2(torch.randn([batch_size], device=device) * self.contrast_std)
314
+ c = torch.where(torch.rand([batch_size], device=device) < self.contrast * self.p, c, torch.ones_like(c))
315
+ if debug_percentile is not None:
316
+ c = torch.full_like(c, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.contrast_std))
317
+ C = scale3d(c, c, c) @ C
318
+
319
+ # Apply luma flip with probability (lumaflip * strength).
320
+ v = misc.constant(np.asarray([1, 1, 1, 0]) / np.sqrt(3), device=device) # Luma axis.
321
+
322
+ # Apply hue rotation with probability (hue * strength).
323
+ if self.hue > 0 and num_channels > 1:
324
+ theta = (torch.rand([batch_size], device=device) * 2 - 1) * np.pi * self.hue_max
325
+ theta = torch.where(torch.rand([batch_size], device=device) < self.hue * self.p, theta, torch.zeros_like(theta))
326
+ if debug_percentile is not None:
327
+ theta = torch.full_like(theta, (debug_percentile * 2 - 1) * np.pi * self.hue_max)
328
+ C = rotate3d(v, theta) @ C # Rotate around v.
329
+
330
+ # Apply saturation with probability (saturation * strength).
331
+ if self.saturation > 0 and num_channels > 1:
332
+ s = torch.exp2(torch.randn([batch_size, 1, 1], device=device) * self.saturation_std)
333
+ s = torch.where(torch.rand([batch_size, 1, 1], device=device) < self.saturation * self.p, s, torch.ones_like(s))
334
+ if debug_percentile is not None:
335
+ s = torch.full_like(s, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.saturation_std))
336
+ C = (v.ger(v) + (I_4 - v.ger(v)) * s) @ C
337
+
338
+ # ------------------------------
339
+ # Execute color transformations.
340
+ # ------------------------------
341
+
342
+ # Execute if the transform is not identity.
343
+ if C is not I_4:
344
+ images = images.reshape([batch_size, num_channels, height * width])
345
+ if num_channels == 3:
346
+ images = C[:, :3, :3] @ images + C[:, :3, 3:]
347
+ elif num_channels == 1:
348
+ C = C[:, :3, :].mean(dim=1, keepdims=True)
349
+ images = images * C[:, :, :3].sum(dim=2, keepdims=True) + C[:, :, 3:]
350
+ else:
351
+ raise ValueError('Image must be RGB (3 channels) or L (1 channel)')
352
+ images = images.reshape([batch_size, num_channels, height, width])
353
+
354
+ # ----------------------
355
+ # Image-space filtering.
356
+ # ----------------------
357
+
358
+ if self.imgfilter > 0:
359
+ num_bands = self.Hz_fbank.shape[0]
360
+ assert len(self.imgfilter_bands) == num_bands
361
+ expected_power = misc.constant(np.array([10, 1, 1, 1]) / 13, device=device) # Expected power spectrum (1/f).
362
+
363
+ # Apply amplification for each band with probability (imgfilter * strength * band_strength).
364
+ g = torch.ones([batch_size, num_bands], device=device) # Global gain vector (identity).
365
+ for i, band_strength in enumerate(self.imgfilter_bands):
366
+ t_i = torch.exp2(torch.randn([batch_size], device=device) * self.imgfilter_std)
367
+ t_i = torch.where(torch.rand([batch_size], device=device) < self.imgfilter * self.p * band_strength, t_i, torch.ones_like(t_i))
368
+ if debug_percentile is not None:
369
+ t_i = torch.full_like(t_i, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.imgfilter_std)) if band_strength > 0 else torch.ones_like(t_i)
370
+ t = torch.ones([batch_size, num_bands], device=device) # Temporary gain vector.
371
+ t[:, i] = t_i # Replace i'th element.
372
+ t = t / (expected_power * t.square()).sum(dim=-1, keepdims=True).sqrt() # Normalize power.
373
+ g = g * t # Accumulate into global gain.
374
+
375
+ # Construct combined amplification filter.
376
+ Hz_prime = g @ self.Hz_fbank # [batch, tap]
377
+ Hz_prime = Hz_prime.unsqueeze(1).repeat([1, num_channels, 1]) # [batch, channels, tap]
378
+ Hz_prime = Hz_prime.reshape([batch_size * num_channels, 1, -1]) # [batch * channels, 1, tap]
379
+
380
+ # Apply filter.
381
+ p = self.Hz_fbank.shape[1] // 2
382
+ images = images.reshape([1, batch_size * num_channels, height, width])
383
+ images = torch.nn.functional.pad(input=images, pad=[p,p,p,p], mode='reflect')
384
+ images = conv2d_gradfix.conv2d(input=images, weight=Hz_prime.unsqueeze(2), groups=batch_size*num_channels)
385
+ images = conv2d_gradfix.conv2d(input=images, weight=Hz_prime.unsqueeze(3), groups=batch_size*num_channels)
386
+ images = images.reshape([batch_size, num_channels, height, width])
387
+
388
+ # ------------------------
389
+ # Image-space corruptions.
390
+ # ------------------------
391
+ batch["img"] = images
392
+ batch["vertices"] = batch["vertices"].long()
393
+ batch["border"] = 1 - batch["E_mask"] - batch["mask"]
394
+ return batch
dp2/data/transforms/transforms.py ADDED
@@ -0,0 +1,277 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ from typing import Dict, List
3
+ import torchvision
4
+ import torch
5
+ import tops
6
+ import torchvision.transforms.functional as F
7
+ from .functional import hflip
8
+ import numpy as np
9
+ from dp2.utils.vis_utils import get_coco_keypoints
10
+ from PIL import Image, ImageDraw
11
+ from typing import Tuple
12
+
13
+
14
+ class RandomHorizontalFlip(torch.nn.Module):
15
+
16
+ def __init__(self, p: float, flip_map=None, **kwargs):
17
+ super().__init__()
18
+ self.flip_ratio = p
19
+ self.flip_map = flip_map
20
+ if self.flip_ratio is None:
21
+ self.flip_ratio = 0.5
22
+ assert 0 <= self.flip_ratio <= 1
23
+
24
+ def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
25
+ if torch.rand(1) > self.flip_ratio:
26
+ return container
27
+ return hflip(container, self.flip_map)
28
+
29
+
30
+ class CenterCrop(torch.nn.Module):
31
+ """
32
+ Performs the transform on the image.
33
+ NOTE: Does not transform the mask to improve runtime.
34
+ """
35
+
36
+ def __init__(self, size: List[int]):
37
+ super().__init__()
38
+ self.size = tuple(size)
39
+
40
+ def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
41
+ min_size = min(container["img"].shape[1], container["img"].shape[2])
42
+ if min_size < self.size[0]:
43
+ container["img"] = F.center_crop(container["img"], min_size)
44
+ container["img"] = F.resize(container["img"], self.size)
45
+ return container
46
+ container["img"] = F.center_crop(container["img"], self.size)
47
+ return container
48
+
49
+
50
+ class Resize(torch.nn.Module):
51
+ """
52
+ Performs the transform on the image.
53
+ NOTE: Does not transform the mask to improve runtime.
54
+ """
55
+
56
+ def __init__(self, size, interpolation=F.InterpolationMode.BILINEAR):
57
+ super().__init__()
58
+ self.size = tuple(size)
59
+ self.interpolation = interpolation
60
+
61
+ def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
62
+ container["img"] = F.resize(container["img"], self.size, self.interpolation, antialias=True)
63
+ if "semantic_mask" in container:
64
+ container["semantic_mask"] = F.resize(
65
+ container["semantic_mask"], self.size, F.InterpolationMode.NEAREST)
66
+ if "embedding" in container:
67
+ container["embedding"] = F.resize(
68
+ container["embedding"], self.size, self.interpolation)
69
+ if "mask" in container:
70
+ container["mask"] = F.resize(
71
+ container["mask"], self.size, F.InterpolationMode.NEAREST)
72
+ if "E_mask" in container:
73
+ container["E_mask"] = F.resize(
74
+ container["E_mask"], self.size, F.InterpolationMode.NEAREST)
75
+ if "maskrcnn_mask" in container:
76
+ container["maskrcnn_mask"] = F.resize(
77
+ container["maskrcnn_mask"], self.size, F.InterpolationMode.NEAREST)
78
+ if "vertices" in container:
79
+ container["vertices"] = F.resize(
80
+ container["vertices"], self.size, F.InterpolationMode.NEAREST)
81
+ return container
82
+
83
+ def __repr__(self):
84
+ repr = super().__repr__()
85
+ vars_ = dict(size=self.size, interpolation=self.interpolation)
86
+ return repr + " " + " ".join([f"{k}: {v}" for k, v in vars_.items()])
87
+
88
+
89
+ class Normalize(torch.nn.Module):
90
+ """
91
+ Performs the transform on the image.
92
+ NOTE: Does not transform the mask to improve runtime.
93
+ """
94
+
95
+ def __init__(self, mean, std, inplace, keys=["img"]):
96
+ super().__init__()
97
+ self.mean = mean
98
+ self.std = std
99
+ self.inplace = inplace
100
+ self.keys = keys
101
+
102
+ def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
103
+ for key in self.keys:
104
+ container[key] = F.normalize(container[key], self.mean, self.std, self.inplace)
105
+ return container
106
+
107
+ def __repr__(self):
108
+ repr = super().__repr__()
109
+ vars_ = dict(mean=self.mean, std=self.std, inplace=self.inplace)
110
+ return repr + " " + " ".join([f"{k}: {v}" for k, v in vars_.items()])
111
+
112
+
113
+ class ToFloat(torch.nn.Module):
114
+
115
+ def __init__(self, keys=["img"], norm=True) -> None:
116
+ super().__init__()
117
+ self.keys = keys
118
+ self.gain = 255 if norm else 1
119
+
120
+ def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
121
+ for key in self.keys:
122
+ container[key] = container[key].float() / self.gain
123
+ return container
124
+
125
+
126
+ class RandomCrop(torchvision.transforms.RandomCrop):
127
+ """
128
+ Performs the transform on the image.
129
+ NOTE: Does not transform the mask to improve runtime.
130
+ """
131
+
132
+ def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
133
+ container["img"] = super().forward(container["img"])
134
+ return container
135
+
136
+
137
+ class CreateCondition(torch.nn.Module):
138
+
139
+ def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
140
+ if container["img"].dtype == torch.uint8:
141
+ container["condition"] = container["img"] * container["mask"].byte() + (1-container["mask"].byte()) * 127
142
+ return container
143
+ container["condition"] = container["img"] * container["mask"]
144
+ return container
145
+
146
+
147
+ class CreateEmbedding(torch.nn.Module):
148
+
149
+ def __init__(self, embed_path: Path, cuda=True) -> None:
150
+ super().__init__()
151
+ self.embed_map = torch.load(embed_path, map_location=torch.device("cpu"))
152
+ if cuda:
153
+ self.embed_map = tops.to_cuda(self.embed_map)
154
+
155
+ def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
156
+ vertices = container["vertices"]
157
+ if vertices.ndim == 3:
158
+ embedding = self.embed_map[vertices.long()].squeeze(dim=0)
159
+ embedding = embedding.permute(2, 0, 1) * container["E_mask"]
160
+ pass
161
+ else:
162
+ assert vertices.ndim == 4
163
+ embedding = self.embed_map[vertices.long()].squeeze(dim=1)
164
+ embedding = embedding.permute(0, 3, 1, 2) * container["E_mask"]
165
+ container["embedding"] = embedding
166
+ container["embed_map"] = self.embed_map.clone()
167
+ return container
168
+
169
+
170
+ class InsertJointMap(torch.nn.Module):
171
+
172
+ def __init__(self, imsize: Tuple) -> None:
173
+ super().__init__()
174
+ self.imsize = imsize
175
+ knames = get_coco_keypoints()[0]
176
+ knames = knames + ["neck", "mid_hip"]
177
+ connectivity = {
178
+ "nose": ["left_eye", "right_eye", "neck"],
179
+ "left_eye": ["right_eye", "left_ear"],
180
+ "right_eye": ["right_ear"],
181
+ "left_shoulder": ["right_shoulder", "left_elbow", "left_hip"],
182
+ "right_shoulder": ["right_elbow", "right_hip"],
183
+ "left_elbow": ["left_wrist"],
184
+ "right_elbow": ["right_wrist"],
185
+ "left_hip": ["right_hip", "left_knee"],
186
+ "right_hip": ["right_knee"],
187
+ "left_knee": ["left_ankle"],
188
+ "right_knee": ["right_ankle"],
189
+ "neck": ["mid_hip", "nose"],
190
+ }
191
+ category = {
192
+ ("nose", "left_eye"): 0, # head
193
+ ("nose", "right_eye"): 0, # head
194
+ ("nose", "neck"): 0, # head
195
+ ("left_eye", "right_eye"): 0, # head
196
+ ("left_eye", "left_ear"): 0, # head
197
+ ("right_eye", "right_ear"): 0, # head
198
+ ("left_shoulder", "left_elbow"): 1, # left arm
199
+ ("left_elbow", "left_wrist"): 1, # left arm
200
+ ("right_shoulder", "right_elbow"): 2, # right arm
201
+ ("right_elbow", "right_wrist"): 2, # right arm
202
+ ("left_shoulder", "right_shoulder"): 3, # body
203
+ ("left_shoulder", "left_hip"): 3, # body
204
+ ("right_shoulder", "right_hip"): 3, # body
205
+ ("left_hip", "right_hip"): 3, # body
206
+ ("left_hip", "left_knee"): 4, # left leg
207
+ ("left_knee", "left_ankle"): 4, # left leg
208
+ ("right_hip", "right_knee"): 5, # right leg
209
+ ("right_knee", "right_ankle"): 5, # right leg
210
+ ("neck", "mid_hip"): 3, # body
211
+ ("neck", "nose"): 0, # head
212
+ }
213
+ self.indices2category = {
214
+ tuple([knames.index(n) for n in k]): v for k, v in category.items()
215
+ }
216
+ self.connectivity_indices = {
217
+ knames.index(k): [knames.index(v_) for v_ in v]
218
+ for k, v in connectivity.items()
219
+ }
220
+ self.l_shoulder = knames.index("left_shoulder")
221
+ self.r_shoulder = knames.index("right_shoulder")
222
+ self.l_hip = knames.index("left_hip")
223
+ self.r_hip = knames.index("right_hip")
224
+ self.l_eye = knames.index("left_eye")
225
+ self.r_eye = knames.index("right_eye")
226
+ self.nose = knames.index("nose")
227
+ self.neck = knames.index("neck")
228
+
229
+ def create_joint_map(self, N, H, W, keypoints):
230
+ joint_maps = np.zeros((N, H, W), dtype=np.uint8)
231
+ for bidx, keypoints in enumerate(keypoints):
232
+ assert keypoints.shape == (17, 3), keypoints.shape
233
+ keypoints = torch.cat((keypoints, torch.zeros(2, 3)))
234
+ visible = keypoints[:, -1] > 0
235
+
236
+ if visible[self.l_shoulder] and visible[self.r_shoulder]:
237
+ neck = (keypoints[self.l_shoulder]
238
+ + (keypoints[self.r_shoulder] - keypoints[self.l_shoulder]) / 2)
239
+ keypoints[-2] = neck
240
+ visible[-2] = 1
241
+ if visible[self.l_hip] and visible[self.r_hip]:
242
+ mhip = (keypoints[self.l_hip]
243
+ + (keypoints[self.r_hip] - keypoints[self.l_hip]) / 2
244
+ )
245
+ keypoints[-1] = mhip
246
+ visible[-1] = 1
247
+
248
+ keypoints[:, 0] *= W
249
+ keypoints[:, 1] *= H
250
+ joint_map = Image.fromarray(np.zeros((H, W), dtype=np.uint8))
251
+ draw = ImageDraw.Draw(joint_map)
252
+ for fidx in self.connectivity_indices.keys():
253
+ for tidx in self.connectivity_indices[fidx]:
254
+ if visible[fidx] == 0 or visible[tidx] == 0:
255
+ continue
256
+ c = self.indices2category[(fidx, tidx)]
257
+ s = tuple(keypoints[fidx, :2].round().long().numpy().tolist())
258
+ e = tuple(keypoints[tidx, :2].round().long().numpy().tolist())
259
+ draw.line((s, e), width=1, fill=c + 1)
260
+ if visible[self.nose] == 0 and visible[self.neck] == 1:
261
+ m_eye = (
262
+ keypoints[self.l_eye]
263
+ + (keypoints[self.r_eye] - keypoints[self.l_eye]) / 2
264
+ )
265
+ s = tuple(m_eye[:2].round().long().numpy().tolist())
266
+ e = tuple(keypoints[self.neck, :2].round().long().numpy().tolist())
267
+ c = self.indices2category[(self.nose, self.neck)]
268
+ draw.line((s, e), width=1, fill=c + 1)
269
+ joint_map = np.array(joint_map)
270
+
271
+ joint_maps[bidx] = np.array(joint_map)
272
+ return joint_maps[:, None]
273
+
274
+ def forward(self, batch):
275
+ batch["joint_map"] = torch.from_numpy(self.create_joint_map(
276
+ batch["img"].shape[0], *self.imsize, batch["keypoints"]))
277
+ return batch
dp2/data/utils.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from PIL import Image
3
+ import numpy as np
4
+ import multiprocessing
5
+ import io
6
+ from tops import logger
7
+ from torch.utils.data._utils.collate import default_collate
8
+
9
+ try:
10
+ import pyspng
11
+
12
+ PYSPNG_IMPORTED = True
13
+ except ImportError:
14
+ PYSPNG_IMPORTED = False
15
+ print("Could not load pyspng. Defaulting to pillow image backend.")
16
+ from PIL import Image
17
+
18
+
19
+ def get_fdf_keypoints():
20
+ return get_coco_keypoints()[:7]
21
+
22
+
23
+ def get_fdf_flipmap():
24
+ keypoints = get_fdf_keypoints()
25
+ keypoint_flip_map = {
26
+ "left_eye": "right_eye",
27
+ "left_ear": "right_ear",
28
+ "left_shoulder": "right_shoulder",
29
+ }
30
+ for key, value in list(keypoint_flip_map.items()):
31
+ keypoint_flip_map[value] = key
32
+ keypoint_flip_map["nose"] = "nose"
33
+ keypoint_flip_map_idx = []
34
+ for source in keypoints:
35
+ keypoint_flip_map_idx.append(keypoints.index(keypoint_flip_map[source]))
36
+ return keypoint_flip_map_idx
37
+
38
+
39
+ def get_coco_keypoints():
40
+ return [
41
+ "nose",
42
+ "left_eye",
43
+ "right_eye", # 2
44
+ "left_ear",
45
+ "right_ear", # 4
46
+ "left_shoulder",
47
+ "right_shoulder", # 6
48
+ "left_elbow",
49
+ "right_elbow", # 8
50
+ "left_wrist",
51
+ "right_wrist", # 10
52
+ "left_hip",
53
+ "right_hip", # 12
54
+ "left_knee",
55
+ "right_knee", # 14
56
+ "left_ankle",
57
+ "right_ankle", # 16
58
+ ]
59
+
60
+
61
+ def get_coco_flipmap():
62
+ keypoints = get_coco_keypoints()
63
+ keypoint_flip_map = {
64
+ "left_eye": "right_eye",
65
+ "left_ear": "right_ear",
66
+ "left_shoulder": "right_shoulder",
67
+ "left_elbow": "right_elbow",
68
+ "left_wrist": "right_wrist",
69
+ "left_hip": "right_hip",
70
+ "left_knee": "right_knee",
71
+ "left_ankle": "right_ankle",
72
+ }
73
+ for key, value in list(keypoint_flip_map.items()):
74
+ keypoint_flip_map[value] = key
75
+ keypoint_flip_map["nose"] = "nose"
76
+ keypoint_flip_map_idx = []
77
+ for source in keypoints:
78
+ keypoint_flip_map_idx.append(keypoints.index(keypoint_flip_map[source]))
79
+ return keypoint_flip_map_idx
80
+
81
+
82
+ def mask_decoder(x):
83
+ mask = torch.from_numpy(np.array(Image.open(io.BytesIO(x)))).squeeze()[None]
84
+ mask = mask > 0 # This fixes bug causing maskf.loat().max() == 255.
85
+ return mask
86
+
87
+
88
+ def png_decoder(x):
89
+ if PYSPNG_IMPORTED:
90
+ return torch.from_numpy(np.rollaxis(pyspng.load(x), 2))
91
+ with Image.open(io.BytesIO(x)) as im:
92
+ im = torch.from_numpy(np.rollaxis(np.array(im.convert("RGB")), 2))
93
+ return im
94
+
95
+
96
+ def jpg_decoder(x):
97
+ with Image.open(io.BytesIO(x)) as im:
98
+ im = torch.from_numpy(np.rollaxis(np.array(im.convert("RGB")), 2))
99
+ return im
100
+
101
+
102
+ def get_num_workers(num_workers: int):
103
+ n_cpus = multiprocessing.cpu_count()
104
+ if num_workers > n_cpus:
105
+ logger.warn(f"Setting the number of workers to match cpu count: {n_cpus}")
106
+ return n_cpus
107
+ return num_workers
108
+
109
+
110
+ def collate_fn(batch):
111
+ elem = batch[0]
112
+ ignore_keys = set(["embed_map", "vertx2cat"])
113
+ batch_ = {
114
+ key: default_collate([d[key] for d in batch])
115
+ for key in elem
116
+ if key not in ignore_keys
117
+ }
118
+ if "embed_map" in elem:
119
+ batch_["embed_map"] = elem["embed_map"]
120
+ if "vertx2cat" in elem:
121
+ batch_["vertx2cat"] = elem["vertx2cat"]
122
+ return batch_
dp2/detection/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .cse_mask_face_detector import CSeMaskFaceDetector
2
+ from .person_detector import CSEPersonDetector
3
+ from .structures import PersonDetection, VehicleDetection, FaceDetection
dp2/detection/base.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pickle
2
+ import torch
3
+ import lzma
4
+ from pathlib import Path
5
+ from tops import logger
6
+
7
+
8
+ class BaseDetector:
9
+
10
+ def __init__(self, cache_directory: str) -> None:
11
+ if cache_directory is not None:
12
+ self.cache_directory = Path(cache_directory, str(self.__class__.__name__))
13
+ self.cache_directory.mkdir(exist_ok=True, parents=True)
14
+
15
+ def save_to_cache(self, detection, cache_path: Path, after_preprocess=True):
16
+ logger.log(f"Caching detection to: {cache_path}")
17
+ with lzma.open(cache_path, "wb") as fp:
18
+ torch.save(
19
+ [det.state_dict(after_preprocess=after_preprocess) for det in detection], fp,
20
+ pickle_protocol=pickle.HIGHEST_PROTOCOL)
21
+
22
+ def load_from_cache(self, cache_path: Path):
23
+ logger.log(f"Loading detection from cache path: {cache_path}")
24
+ with lzma.open(cache_path, "rb") as fp:
25
+ state_dict = torch.load(fp)
26
+ return [
27
+ state["cls"].from_state_dict(state_dict=state) for state in state_dict
28
+ ]
29
+
30
+ def forward_and_cache(self, im: torch.Tensor, cache_id: str, load_cache: bool):
31
+ if cache_id is None:
32
+ return self.forward(im)
33
+ cache_path = self.cache_directory.joinpath(cache_id + ".torch")
34
+ if cache_path.is_file() and load_cache:
35
+ try:
36
+ return self.load_from_cache(cache_path)
37
+ except Exception as e:
38
+ logger.warn(f"The cache file was corrupted: {cache_path}")
39
+ exit()
40
+ detections = self.forward(im)
41
+ self.save_to_cache(detections, cache_path)
42
+ return detections
dp2/detection/box_utils.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+
4
+ def expand_bbox_to_ratio(bbox, imshape, target_aspect_ratio):
5
+ x0, y0, x1, y1 = [int(_) for _ in bbox]
6
+ h, w = y1 - y0, x1 - x0
7
+ cur_ratio = h / w
8
+
9
+ if cur_ratio == target_aspect_ratio:
10
+ return [x0, y0, x1, y1]
11
+ if cur_ratio < target_aspect_ratio:
12
+ target_height = int(w*target_aspect_ratio)
13
+ y0, y1 = expand_axis(y0, y1, target_height, imshape[0])
14
+ else:
15
+ target_width = int(h/target_aspect_ratio)
16
+ x0, x1 = expand_axis(x0, x1, target_width, imshape[1])
17
+ return x0, y0, x1, y1
18
+
19
+
20
+ def expand_axis(start, end, target_width, limit):
21
+ # Can return a bbox outside of limit
22
+ cur_width = end - start
23
+ start = start - (target_width-cur_width)//2
24
+ end = end + (target_width-cur_width)//2
25
+ if end - start != target_width:
26
+ end += 1
27
+ assert end - start == target_width
28
+ if start < 0 and end > limit:
29
+ return start, end
30
+ if start < 0 and end < limit:
31
+ to_shift = min(0 - start, limit - end)
32
+ start += to_shift
33
+ end += to_shift
34
+ if end > limit and start > 0:
35
+ to_shift = min(end - limit, start)
36
+ end -= to_shift
37
+ start -= to_shift
38
+ assert end - start == target_width
39
+ return start, end
40
+
41
+
42
+ def expand_box(bbox, imshape, mask, percentage_background: float):
43
+ assert isinstance(bbox[0], int)
44
+ assert 0 < percentage_background < 1
45
+ # Percentage in S
46
+ mask_pixels = mask.long().sum().cpu()
47
+ total_pixels = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
48
+ percentage_mask = mask_pixels / total_pixels
49
+ if (1 - percentage_mask) > percentage_background:
50
+ return bbox
51
+ target_pixels = mask_pixels / (1 - percentage_background)
52
+ x0, y0, x1, y1 = bbox
53
+ H = y1 - y0
54
+ W = x1 - x0
55
+ p = np.sqrt(target_pixels/(H*W))
56
+ target_width = int(np.ceil(p * W))
57
+ target_height = int(np.ceil(p * H))
58
+ x0, x1 = expand_axis(x0, x1, target_width, imshape[1])
59
+ y0, y1 = expand_axis(y0, y1, target_height, imshape[0])
60
+ return [x0, y0, x1, y1]
61
+
62
+
63
+ def expand_axises_by_percentage(bbox_XYXY, imshape, percentage):
64
+ x0, y0, x1, y1 = bbox_XYXY
65
+ H = y1 - y0
66
+ W = x1 - x0
67
+ expansion = int(((H*W)**0.5) * percentage)
68
+ new_width = W + expansion
69
+ new_height = H + expansion
70
+ x0, x1 = expand_axis(x0, x1, min(new_width, imshape[1]), imshape[1])
71
+ y0, y1 = expand_axis(y0, y1, min(new_height, imshape[0]), imshape[0])
72
+ return [x0, y0, x1, y1]
73
+
74
+
75
+ def get_expanded_bbox(
76
+ bbox_XYXY,
77
+ imshape,
78
+ mask,
79
+ percentage_background: float,
80
+ axis_minimum_expansion: float,
81
+ target_aspect_ratio: float):
82
+ bbox_XYXY = bbox_XYXY.long().cpu().numpy().tolist()
83
+ # Expand each axis of the bounding box by a minimum percentage
84
+ bbox_XYXY = expand_axises_by_percentage(bbox_XYXY, imshape, axis_minimum_expansion)
85
+ # Find the minimum bbox with the aspect ratio. Can be outside of imshape
86
+ bbox_XYXY = expand_bbox_to_ratio(bbox_XYXY, imshape, target_aspect_ratio)
87
+ # Expands square box such that X% of the bbox is background
88
+ bbox_XYXY = expand_box(bbox_XYXY, imshape, mask, percentage_background)
89
+ assert isinstance(bbox_XYXY[0], (int, np.int64))
90
+ return bbox_XYXY
91
+
92
+
93
+ def include_box(bbox, minimum_area, aspect_ratio_range, min_bbox_ratio_inside, imshape):
94
+ def area_inside_ratio(bbox, imshape):
95
+ area = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
96
+ area_inside = (min(bbox[2], imshape[1]) - max(0, bbox[0])) * (min(imshape[0], bbox[3]) - max(0, bbox[1]))
97
+ return area_inside / area
98
+ ratio = (bbox[3] - bbox[1]) / (bbox[2] - bbox[0])
99
+ area = (bbox[3] - bbox[1]) * (bbox[2] - bbox[0])
100
+ if area_inside_ratio(bbox, imshape) < min_bbox_ratio_inside:
101
+ return False
102
+ if ratio <= aspect_ratio_range[0] or ratio >= aspect_ratio_range[1] or area < minimum_area:
103
+ return False
104
+ return True
dp2/detection/box_utils_fdf.py ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ The FDF dataset expands bound boxes differently from what is used for CSE.
3
+ """
4
+
5
+ import numpy as np
6
+
7
+
8
+ def quadratic_bounding_box(x0, y0, width, height, imshape):
9
+ # We assume that we can create a image that is quadratic without
10
+ # minimizing any of the sides
11
+ assert width <= min(imshape[:2])
12
+ assert height <= min(imshape[:2])
13
+ min_side = min(height, width)
14
+ if height != width:
15
+ side_diff = abs(height - width)
16
+ # Want to extend the shortest side
17
+ if min_side == height:
18
+ # Vertical side
19
+ height += side_diff
20
+ if height > imshape[0]:
21
+ # Take full frame, and shrink width
22
+ y0 = 0
23
+ height = imshape[0]
24
+
25
+ side_diff = abs(height - width)
26
+ width -= side_diff
27
+ x0 += side_diff // 2
28
+ else:
29
+ y0 -= side_diff // 2
30
+ y0 = max(0, y0)
31
+ else:
32
+ # Horizontal side
33
+ width += side_diff
34
+ if width > imshape[1]:
35
+ # Take full frame width, and shrink height
36
+ x0 = 0
37
+ width = imshape[1]
38
+
39
+ side_diff = abs(height - width)
40
+ height -= side_diff
41
+ y0 += side_diff // 2
42
+ else:
43
+ x0 -= side_diff // 2
44
+ x0 = max(0, x0)
45
+ # Check that bbox goes outside image
46
+ x1 = x0 + width
47
+ y1 = y0 + height
48
+ if imshape[1] < x1:
49
+ diff = x1 - imshape[1]
50
+ x0 -= diff
51
+ if imshape[0] < y1:
52
+ diff = y1 - imshape[0]
53
+ y0 -= diff
54
+ assert x0 >= 0, "Bounding box outside image."
55
+ assert y0 >= 0, "Bounding box outside image."
56
+ assert x0 + width <= imshape[1], "Bounding box outside image."
57
+ assert y0 + height <= imshape[0], "Bounding box outside image."
58
+ return x0, y0, width, height
59
+
60
+
61
+ def expand_bounding_box(bbox, percentage, imshape):
62
+ orig_bbox = bbox.copy()
63
+ x0, y0, x1, y1 = bbox
64
+ width = x1 - x0
65
+ height = y1 - y0
66
+ x0, y0, width, height = quadratic_bounding_box(
67
+ x0, y0, width, height, imshape)
68
+ expanding_factor = int(max(height, width) * percentage)
69
+
70
+ possible_max_expansion = [(imshape[0] - width) // 2,
71
+ (imshape[1] - height) // 2,
72
+ expanding_factor]
73
+
74
+ expanding_factor = min(possible_max_expansion)
75
+ # Expand height
76
+
77
+ if expanding_factor > 0:
78
+
79
+ y0 = y0 - expanding_factor
80
+ y0 = max(0, y0)
81
+
82
+ height += expanding_factor * 2
83
+ if height > imshape[0]:
84
+ y0 -= (imshape[0] - height)
85
+ height = imshape[0]
86
+
87
+ if height + y0 > imshape[0]:
88
+ y0 -= (height + y0 - imshape[0])
89
+
90
+ # Expand width
91
+ x0 = x0 - expanding_factor
92
+ x0 = max(0, x0)
93
+
94
+ width += expanding_factor * 2
95
+ if width > imshape[1]:
96
+ x0 -= (imshape[1] - width)
97
+ width = imshape[1]
98
+
99
+ if width + x0 > imshape[1]:
100
+ x0 -= (width + x0 - imshape[1])
101
+ y1 = y0 + height
102
+ x1 = x0 + width
103
+ assert y0 >= 0, "Y0 is minus"
104
+ assert height <= imshape[0], "Height is larger than image."
105
+ assert x0 + width <= imshape[1]
106
+ assert y0 + height <= imshape[0]
107
+ assert width == height, "HEIGHT IS NOT EQUAL WIDTH!!"
108
+ assert x0 >= 0, "Y0 is minus"
109
+ assert width <= imshape[1], "Height is larger than image."
110
+ # Check that original bbox is within new
111
+ x0_o, y0_o, x1_o, y1_o = orig_bbox
112
+ assert x0 <= x0_o, f"New bbox is outisde of original. O:{x0_o}, N: {x0}"
113
+ assert x1 >= x1_o, f"New bbox is outisde of original. O:{x1_o}, N: {x1}"
114
+ assert y0 <= y0_o, f"New bbox is outisde of original. O:{y0_o}, N: {y0}"
115
+ assert y1 >= y1_o, f"New bbox is outisde of original. O:{y1_o}, N: {y1}"
116
+
117
+ x0, y0, width, height = [int(_) for _ in [x0, y0, width, height]]
118
+ x1 = x0 + width
119
+ y1 = y0 + height
120
+ return np.array([x0, y0, x1, y1])
121
+
122
+
123
+ def is_keypoint_within_bbox(x0, y0, x1, y1, keypoint):
124
+ keypoint = keypoint[:, :3] # only nose + eyes are relevant
125
+ kp_X = keypoint[0, :]
126
+ kp_Y = keypoint[1, :]
127
+ within_X = np.all(kp_X >= x0) and np.all(kp_X <= x1)
128
+ within_Y = np.all(kp_Y >= y0) and np.all(kp_Y <= y1)
129
+ return within_X and within_Y
130
+
131
+
132
+ def expand_bbox_simple(bbox, percentage):
133
+ x0, y0, x1, y1 = bbox.astype(float)
134
+ width = x1 - x0
135
+ height = y1 - y0
136
+ x_c = int(x0) + width // 2
137
+ y_c = int(y0) + height // 2
138
+ avg_size = max(width, height)
139
+ new_width = avg_size * (1 + percentage)
140
+ x0 = x_c - new_width // 2
141
+ y0 = y_c - new_width // 2
142
+ x1 = x_c + new_width // 2
143
+ y1 = y_c + new_width // 2
144
+ return np.array([x0, y0, x1, y1]).astype(int)
145
+
146
+
147
+ def pad_image(im, bbox, pad_value):
148
+ x0, y0, x1, y1 = bbox
149
+ if x0 < 0:
150
+ pad_im = np.zeros((im.shape[0], abs(x0), im.shape[2]),
151
+ dtype=np.uint8) + pad_value
152
+ im = np.concatenate((pad_im, im), axis=1)
153
+ x1 += abs(x0)
154
+ x0 = 0
155
+ if y0 < 0:
156
+ pad_im = np.zeros((abs(y0), im.shape[1], im.shape[2]),
157
+ dtype=np.uint8) + pad_value
158
+ im = np.concatenate((pad_im, im), axis=0)
159
+ y1 += abs(y0)
160
+ y0 = 0
161
+ if x1 >= im.shape[1]:
162
+ pad_im = np.zeros(
163
+ (im.shape[0], x1 - im.shape[1] + 1, im.shape[2]),
164
+ dtype=np.uint8) + pad_value
165
+ im = np.concatenate((im, pad_im), axis=1)
166
+ if y1 >= im.shape[0]:
167
+ pad_im = np.zeros(
168
+ (y1 - im.shape[0] + 1, im.shape[1], im.shape[2]),
169
+ dtype=np.uint8) + pad_value
170
+ im = np.concatenate((im, pad_im), axis=0)
171
+ return im[y0:y1, x0:x1]
172
+
173
+
174
+ def clip_box(bbox, im):
175
+ bbox[0] = max(0, bbox[0])
176
+ bbox[1] = max(0, bbox[1])
177
+ bbox[2] = min(im.shape[1] - 1, bbox[2])
178
+ bbox[3] = min(im.shape[0] - 1, bbox[3])
179
+ return bbox
180
+
181
+
182
+ def cut_face(im, bbox, simple_expand=False, pad_value=0, pad_im=True):
183
+ outside_im = (bbox < 0).any() or bbox[2] > im.shape[1] or bbox[3] > im.shape[0]
184
+ if simple_expand or (outside_im and pad_im):
185
+ return pad_image(im, bbox, pad_value)
186
+ bbox = clip_box(bbox, im)
187
+ x0, y0, x1, y1 = bbox
188
+ return im[y0:y1, x0:x1]
189
+
190
+
191
+ def expand_bbox(
192
+ bbox_ltrb, imshape, simple_expand, default_to_simple=False,
193
+ expansion_factor=0.35):
194
+ assert bbox_ltrb.shape == (4,), f"BBox shape was: {bbox_ltrb.shape}"
195
+ bbox = bbox_ltrb.astype(float)
196
+ # FDF256 uses simple expand with ratio 0.4
197
+ if simple_expand:
198
+ return expand_bbox_simple(bbox, 0.4)
199
+ try:
200
+ return expand_bounding_box(bbox, expansion_factor, imshape)
201
+ except AssertionError:
202
+ return expand_bbox_simple(bbox, expansion_factor * 2)
dp2/detection/cse_mask_face_detector.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import lzma
3
+ import tops
4
+ from pathlib import Path
5
+ from dp2.detection.base import BaseDetector
6
+ from .utils import combine_cse_maskrcnn_dets
7
+ from face_detection import build_detector as build_face_detector
8
+ from .models.cse import CSEDetector
9
+ from .models.mask_rcnn import MaskRCNNDetector
10
+ from .structures import CSEPersonDetection, VehicleDetection, FaceDetection, PersonDetection
11
+ from tops import logger
12
+
13
+
14
+ def box1_inside_box2(box1: torch.Tensor, box2: torch.Tensor):
15
+ assert len(box1.shape) == 2
16
+ assert len(box2.shape) == 2
17
+ box1_inside = torch.zeros(box1.shape[0], device=box1.device, dtype=torch.bool)
18
+ # This can be batched
19
+ for i, box in enumerate(box1):
20
+ is_outside_lefttop = (box[None, [0, 1]] <= box2[:, [0, 1]]).any(dim=1)
21
+ is_outside_rightbot = (box[None, [2, 3]] >= box2[:, [2, 3]]).any(dim=1)
22
+ is_outside = is_outside_lefttop.logical_or(is_outside_rightbot)
23
+ box1_inside[i] = is_outside.logical_not().any()
24
+ return box1_inside
25
+
26
+
27
+ class CSeMaskFaceDetector(BaseDetector):
28
+
29
+ def __init__(
30
+ self,
31
+ mask_rcnn_cfg,
32
+ face_detector_cfg: dict,
33
+ cse_cfg: dict,
34
+ face_post_process_cfg: dict,
35
+ cse_post_process_cfg,
36
+ score_threshold: float,
37
+ **kwargs
38
+ ) -> None:
39
+ super().__init__(**kwargs)
40
+ self.mask_rcnn = MaskRCNNDetector(**mask_rcnn_cfg, score_thres=score_threshold)
41
+ if "confidence_threshold" not in face_detector_cfg:
42
+ face_detector_cfg["confidence_threshold"] = score_threshold
43
+ if "score_thres" not in cse_cfg:
44
+ cse_cfg["score_thres"] = score_threshold
45
+ self.cse_detector = CSEDetector(**cse_cfg)
46
+ self.face_detector = build_face_detector(**face_detector_cfg, clip_boxes=True)
47
+ self.cse_post_process_cfg = cse_post_process_cfg
48
+ self.face_mean = tops.to_cuda(torch.from_numpy(self.face_detector.mean).view(3, 1, 1))
49
+ self.mask_cse_iou_combine_threshold = self.cse_post_process_cfg.pop("iou_combine_threshold")
50
+ self.face_post_process_cfg = face_post_process_cfg
51
+
52
+ def __call__(self, *args, **kwargs):
53
+ return self.forward(*args, **kwargs)
54
+
55
+ def _detect_faces(self, im: torch.Tensor):
56
+ H, W = im.shape[1:]
57
+ im = im.float() - self.face_mean
58
+ im = self.face_detector.resize(im[None], 1.0)
59
+ boxes_XYXY = self.face_detector._batched_detect(im)[0][:, :-1] # Remove score
60
+ boxes_XYXY[:, [0, 2]] *= W
61
+ boxes_XYXY[:, [1, 3]] *= H
62
+ return boxes_XYXY.round().long()
63
+
64
+ def load_from_cache(self, cache_path: Path):
65
+ logger.log(f"Loading detection from cache path: {cache_path}",)
66
+ with lzma.open(cache_path, "rb") as fp:
67
+ state_dict = torch.load(fp, map_location="cpu")
68
+ kwargs = dict(
69
+ post_process_cfg=self.cse_post_process_cfg,
70
+ embed_map=self.cse_detector.embed_map,
71
+ **self.face_post_process_cfg
72
+ )
73
+ return [
74
+ state["cls"].from_state_dict(**kwargs, state_dict=state)
75
+ for state in state_dict
76
+ ]
77
+
78
+ @torch.no_grad()
79
+ def forward(self, im: torch.Tensor):
80
+ maskrcnn_dets = self.mask_rcnn(im)
81
+ cse_dets = self.cse_detector(im)
82
+ embed_map = self.cse_detector.embed_map
83
+ print("Calling face detector.")
84
+ face_boxes = self._detect_faces(im).cpu()
85
+ maskrcnn_person = {
86
+ k: v[maskrcnn_dets["is_person"]] for k, v in maskrcnn_dets.items()
87
+ }
88
+ maskrcnn_other = {
89
+ k: v[maskrcnn_dets["is_person"].logical_not()] for k, v in maskrcnn_dets.items()
90
+ }
91
+ maskrcnn_other = VehicleDetection(maskrcnn_other["segmentation"])
92
+ combined_segmentation, cse_dets, matches = combine_cse_maskrcnn_dets(
93
+ maskrcnn_person["segmentation"], cse_dets, self.mask_cse_iou_combine_threshold)
94
+
95
+ persons_with_cse = CSEPersonDetection(
96
+ combined_segmentation, cse_dets, **self.cse_post_process_cfg,
97
+ embed_map=embed_map, orig_imshape_CHW=im.shape
98
+ )
99
+ persons_with_cse.pre_process()
100
+ not_matched = [i for i in range(maskrcnn_person["segmentation"].shape[0]) if i not in matches[:, 0]]
101
+ persons_without_cse = PersonDetection(
102
+ maskrcnn_person["segmentation"][not_matched], **self.cse_post_process_cfg,
103
+ orig_imshape_CHW=im.shape
104
+ )
105
+ persons_without_cse.pre_process()
106
+
107
+ face_boxes_covered = box1_inside_box2(face_boxes, persons_with_cse.dilated_boxes).logical_or(
108
+ box1_inside_box2(face_boxes, persons_without_cse.dilated_boxes)
109
+ )
110
+ face_boxes = face_boxes[face_boxes_covered.logical_not()]
111
+ face_boxes = FaceDetection(face_boxes, **self.face_post_process_cfg)
112
+
113
+ # Order matters. The anonymizer will anonymize FIFO.
114
+ # Later detections will overwrite.
115
+ all_detections = [face_boxes, maskrcnn_other, persons_without_cse, persons_with_cse]
116
+ return all_detections
dp2/detection/deep_privacy1_detector.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import tops
3
+ import lzma
4
+ from torchvision.models.detection import keypointrcnn_resnet50_fpn, KeypointRCNN_ResNet50_FPN_Weights
5
+ from .base import BaseDetector
6
+ from face_detection import build_detector as build_face_detector
7
+ from .structures import FaceDetection
8
+ from tops import logger
9
+ from pathlib import Path
10
+
11
+ def is_keypoint_within_bbox(x0, y0, x1, y1, keypoint):
12
+ keypoint = keypoint[:3, :] # only nose + eyes are relevant
13
+ kp_X = keypoint[:, 0]
14
+ kp_Y = keypoint[:, 1]
15
+ within_X = (kp_X >= x0).all() and (kp_X <= x1).all()
16
+ within_Y = (kp_Y >= y0).all() and (kp_Y <= y1).all()
17
+ return within_X and within_Y
18
+
19
+
20
+ def match_bbox_keypoint(bounding_boxes, keypoints):
21
+ """
22
+ bounding_boxes shape: [N, 5]
23
+ keypoints: [N persons, K keypoints, (x, y)]
24
+ """
25
+ if len(bounding_boxes) == 0 or len(keypoints) == 0:
26
+ return torch.empty((0, 5)), torch.empty((0, 7, 2))
27
+ assert bounding_boxes.shape[1] == 4,\
28
+ f"Shape was : {bounding_boxes.shape}"
29
+ assert keypoints.shape[-1] == 2,\
30
+ f"Expected (x,y) in last axis, got: {keypoints.shape}"
31
+ assert keypoints.shape[1] in (5, 7),\
32
+ f"Expeted 5 or 7 keypoints. Keypoint shape was: {keypoints.shape}"
33
+
34
+ matches = []
35
+ for bbox_idx, bbox in enumerate(bounding_boxes):
36
+ keypoint = None
37
+ for kp_idx, keypoint in enumerate(keypoints):
38
+ if kp_idx in (x[1] for x in matches):
39
+ continue
40
+ if is_keypoint_within_bbox(*bbox, keypoint):
41
+ matches.append((bbox_idx, kp_idx))
42
+ break
43
+ keypoint_idx = [x[1] for x in matches]
44
+ bbox_idx = [x[0] for x in matches]
45
+ return bounding_boxes[bbox_idx], keypoints[keypoint_idx]
46
+
47
+
48
+ class DeepPrivacy1Detector(BaseDetector):
49
+
50
+ def __init__(self,
51
+ keypoint_threshold: float,
52
+ face_detector_cfg,
53
+ score_threshold: float,
54
+ face_post_process_cfg,
55
+ **kwargs):
56
+ super().__init__(**kwargs)
57
+ self.keypoint_detector = tops.to_cuda(keypointrcnn_resnet50_fpn(
58
+ weights=KeypointRCNN_ResNet50_FPN_Weights.COCO_V1).eval())
59
+ self.keypoint_threshold = keypoint_threshold
60
+ self.face_detector = build_face_detector(**face_detector_cfg, confidence_threshold=score_threshold)
61
+ self.face_mean = tops.to_cuda(torch.from_numpy(self.face_detector.mean).view(3, 1, 1))
62
+ self.face_post_process_cfg = face_post_process_cfg
63
+
64
+ @torch.no_grad()
65
+ def _detect_faces(self, im: torch.Tensor):
66
+ H, W = im.shape[1:]
67
+ im = im.float() - self.face_mean
68
+ im = self.face_detector.resize(im[None], 1.0)
69
+ boxes_XYXY = self.face_detector._batched_detect(im)[0][:, :-1] # Remove score
70
+ boxes_XYXY[:, [0, 2]] *= W
71
+ boxes_XYXY[:, [1, 3]] *= H
72
+ return boxes_XYXY.round().long().cpu()
73
+
74
+ @torch.no_grad()
75
+ def _detect_keypoints(self, img: torch.Tensor):
76
+ img = img.float() / 255
77
+ outputs = self.keypoint_detector([img])
78
+
79
+ # Shape: [N persons, K keypoints, (x,y,visibility)]
80
+ keypoints = outputs[0]["keypoints"]
81
+ scores = outputs[0]["scores"]
82
+ assert list(scores) == sorted(list(scores))[::-1]
83
+ mask = scores >= self.keypoint_threshold
84
+ keypoints = keypoints[mask, :, :2]
85
+ return keypoints[:, :7, :2]
86
+
87
+ def __call__(self, *args, **kwargs):
88
+ return self.forward(*args, **kwargs)
89
+
90
+ @torch.no_grad()
91
+ def forward(self, im: torch.Tensor):
92
+ face_boxes = self._detect_faces(im)
93
+ keypoints = self._detect_keypoints(im)
94
+ face_boxes, keypoints = match_bbox_keypoint(face_boxes, keypoints)
95
+ face_boxes = FaceDetection(face_boxes, **self.face_post_process_cfg, keypoints=keypoints)
96
+ return [face_boxes]
97
+
98
+ def load_from_cache(self, cache_path: Path):
99
+ logger.log(f"Loading detection from cache path: {cache_path}",)
100
+ with lzma.open(cache_path, "rb") as fp:
101
+ state_dict = torch.load(fp, map_location="cpu")
102
+ kwargs = self.face_post_process_cfg
103
+ return [
104
+ state["cls"].from_state_dict(**kwargs, state_dict=state)
105
+ for state in state_dict
106
+ ]
dp2/detection/face_detector.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import lzma
3
+ import tops
4
+ from pathlib import Path
5
+ from dp2.detection.base import BaseDetector
6
+ from face_detection import build_detector as build_face_detector
7
+ from .structures import FaceDetection
8
+ from tops import logger
9
+
10
+
11
+ def box1_inside_box2(box1: torch.Tensor, box2: torch.Tensor):
12
+ assert len(box1.shape) == 2
13
+ assert len(box2.shape) == 2
14
+ box1_inside = torch.zeros(box1.shape[0], device=box1.device, dtype=torch.bool)
15
+ # This can be batched
16
+ for i, box in enumerate(box1):
17
+ is_outside_lefttop = (box[None, [0, 1]] <= box2[:, [0, 1]]).any(dim=1)
18
+ is_outside_rightbot = (box[None, [2, 3]] >= box2[:, [2, 3]]).any(dim=1)
19
+ is_outside = is_outside_lefttop.logical_or(is_outside_rightbot)
20
+ box1_inside[i] = is_outside.logical_not().any()
21
+ return box1_inside
22
+
23
+
24
+ class FaceDetector(BaseDetector):
25
+
26
+ def __init__(
27
+ self,
28
+ face_detector_cfg: dict,
29
+ score_threshold: float,
30
+ face_post_process_cfg: dict,
31
+ **kwargs
32
+ ) -> None:
33
+ super().__init__(**kwargs)
34
+ self.face_detector = build_face_detector(**face_detector_cfg, confidence_threshold=score_threshold)
35
+ self.face_mean = tops.to_cuda(torch.from_numpy(self.face_detector.mean).view(3, 1, 1))
36
+ self.face_post_process_cfg = face_post_process_cfg
37
+
38
+ def __call__(self, *args, **kwargs):
39
+ return self.forward(*args, **kwargs)
40
+
41
+ def _detect_faces(self, im: torch.Tensor):
42
+ H, W = im.shape[1:]
43
+ im = im.float() - self.face_mean
44
+ im = self.face_detector.resize(im[None], 1.0)
45
+ boxes_XYXY = self.face_detector._batched_detect(im)[0][:, :-1] # Remove score
46
+ boxes_XYXY[:, [0, 2]] *= W
47
+ boxes_XYXY[:, [1, 3]] *= H
48
+ return boxes_XYXY.round().long().cpu()
49
+
50
+ @torch.no_grad()
51
+ def forward(self, im: torch.Tensor):
52
+ face_boxes = self._detect_faces(im)
53
+ face_boxes = FaceDetection(face_boxes, **self.face_post_process_cfg)
54
+ return [face_boxes]
55
+
56
+ def load_from_cache(self, cache_path: Path):
57
+ logger.log(f"Loading detection from cache path: {cache_path}")
58
+ with lzma.open(cache_path, "rb") as fp:
59
+ state_dict = torch.load(fp)
60
+ return [
61
+ state["cls"].from_state_dict(state_dict=state, **self.face_post_process_cfg) for state in state_dict
62
+ ]