haakohu committed
Commit 415d841
1 Parent(s): d4bcf77

initial commit

Files changed (40)
  1. .gitignore +54 -0
  2. configs/anonymizers/FB_cse.py +28 -0
  3. configs/anonymizers/FB_cse_mask.py +29 -0
  4. configs/anonymizers/FB_cse_mask_face.py +29 -0
  5. configs/anonymizers/face.py +18 -0
  6. configs/anonymizers/market1501/blackout.py +8 -0
  7. configs/anonymizers/market1501/person.py +6 -0
  8. configs/anonymizers/market1501/pixelation16.py +8 -0
  9. configs/anonymizers/market1501/pixelation8.py +8 -0
  10. configs/datasets/coco_cse.py +69 -0
  11. configs/datasets/fdf128.py +24 -0
  12. configs/datasets/fdf256.py +69 -0
  13. configs/datasets/fdh.py +89 -0
  14. configs/datasets/utils.py +12 -0
  15. configs/defaults.py +45 -0
  16. configs/discriminators/sg2_discriminator.py +42 -0
  17. configs/fdf/stylegan.py +14 -0
  18. configs/fdf/stylegan_fdf128.py +13 -0
  19. configs/fdh/styleganL.py +16 -0
  20. configs/fdh/styleganL_nocse.py +14 -0
  21. configs/generators/stylegan_unet.py +22 -0
  22. multi_app.py +204 -0
  23. sg3_torch_utils/LICENSE.txt +97 -0
  24. sg3_torch_utils/__init__.py +9 -0
  25. sg3_torch_utils/custom_ops.py +126 -0
  26. sg3_torch_utils/misc.py +172 -0
  27. sg3_torch_utils/ops/__init__.py +9 -0
  28. sg3_torch_utils/ops/bias_act.cpp +99 -0
  29. sg3_torch_utils/ops/bias_act.cu +173 -0
  30. sg3_torch_utils/ops/bias_act.h +38 -0
  31. sg3_torch_utils/ops/bias_act.py +215 -0
  32. sg3_torch_utils/ops/conv2d_gradfix.py +175 -0
  33. sg3_torch_utils/ops/conv2d_resample.py +142 -0
  34. sg3_torch_utils/ops/fma.py +63 -0
  35. sg3_torch_utils/ops/grid_sample_gradfix.py +88 -0
  36. sg3_torch_utils/ops/upfirdn2d.cpp +103 -0
  37. sg3_torch_utils/ops/upfirdn2d.cu +350 -0
  38. sg3_torch_utils/ops/upfirdn2d.h +59 -0
  39. sg3_torch_utils/ops/upfirdn2d.py +388 -0
  40. stylemc.py +295 -0
.gitignore ADDED
@@ -0,0 +1,54 @@
# FILES
*.yaml
*.pkl
*.flist
*.zip
*.out
*.npy
*.gz
*.ckpt
*.pth
*.log
*.pyc
*.csv
*.yml
*.ods
*.ods#
*.json
build_docker.sh

# Images / Videos
#*.png
#*.jpg
*.jpeg
*.m4a
*.mkv
*.mp4

# Directories created by inpaintron
.cache/
test_examples/
.vscode
__pycache__
.debug/
**/.ipynb_checkpoints/**
outputs/


# From pip setup
build/
*.egg-info
*.egg
.npm/

# From dockerfile
.bash_history
.viminfo
.local/
*.pickle
*.onnx


sbatch_files/
figures/
image_dump/
configs/anonymizers/FB_cse.py ADDED
@@ -0,0 +1,28 @@
from dp2.anonymizer import Anonymizer
from dp2.detection.person_detector import CSEPersonDetector
from ..defaults import common
from tops.config import LazyCall as L
from dp2.generator.dummy_generators import MaskOutGenerator


maskout_G = L(MaskOutGenerator)(noise="constant")

detector = L(CSEPersonDetector)(
    mask_rcnn_cfg=dict(),
    cse_cfg=dict(),
    cse_post_process_cfg=dict(
        target_imsize=(288, 160),
        exp_bbox_cfg=dict(percentage_background=0.3, axis_minimum_expansion=.1),
        exp_bbox_filter=dict(minimum_area=32*32, min_bbox_ratio_inside=0, aspect_ratio_range=[0, 99999]),
        iou_combine_threshold=0.4,
        dilation_percentage=0.02,
        normalize_embedding=False
    ),
    score_threshold=0.3,
    cache_directory=common.output_dir.joinpath("cse_person_detection_cache")
)

anonymizer = L(Anonymizer)(
    detector="${detector}",
    cse_person_G_cfg="configs/fdh/styleganL.py",
)
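Note: these anonymizer configs are lazy (LazyCall); nothing is built until the config is loaded and instantiated. A minimal sketch of consuming this file, mirroring what multi_app.py does further down (load_config, instantiate, and load_cache=False all appear there; nothing else is assumed):

    from dp2 import utils
    from tops.config import instantiate

    cfg = utils.load_config("configs/anonymizers/FB_cse.py")
    anonymizer = instantiate(cfg.anonymizer, load_cache=False)  # builds the detector and generator lazily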
configs/anonymizers/FB_cse_mask.py ADDED
@@ -0,0 +1,29 @@
from dp2.anonymizer import Anonymizer
from dp2.detection.person_detector import CSEPersonDetector
from ..defaults import common
from tops.config import LazyCall as L
from dp2.generator.dummy_generators import MaskOutGenerator


maskout_G = L(MaskOutGenerator)(noise="constant")

detector = L(CSEPersonDetector)(
    mask_rcnn_cfg=dict(),
    cse_cfg=dict(),
    cse_post_process_cfg=dict(
        target_imsize=(288, 160),
        exp_bbox_cfg=dict(percentage_background=0.3, axis_minimum_expansion=.1),
        exp_bbox_filter=dict(minimum_area=32*32, min_bbox_ratio_inside=0, aspect_ratio_range=[0, 99999]),
        iou_combine_threshold=0.4,
        dilation_percentage=0.02,
        normalize_embedding=False
    ),
    score_threshold=0.3,
    cache_directory=common.output_dir.joinpath("cse_person_detection_cache")
)

anonymizer = L(Anonymizer)(
    detector="${detector}",
    person_G_cfg="configs/fdh/styleganL_nocse.py",
    cse_person_G_cfg="configs/fdh/styleganL.py",
)
configs/anonymizers/FB_cse_mask_face.py ADDED
@@ -0,0 +1,29 @@
from dp2.anonymizer import Anonymizer
from dp2.detection.cse_mask_face_detector import CSeMaskFaceDetector
from ..defaults import common
from tops.config import LazyCall as L

detector = L(CSeMaskFaceDetector)(
    mask_rcnn_cfg=dict(),
    face_detector_cfg=dict(),
    face_post_process_cfg=dict(target_imsize=(256, 256)),
    cse_cfg=dict(),
    cse_post_process_cfg=dict(
        target_imsize=(288, 160),
        exp_bbox_cfg=dict(percentage_background=0.3, axis_minimum_expansion=.1),
        exp_bbox_filter=dict(minimum_area=32*32, min_bbox_ratio_inside=0, aspect_ratio_range=[0, 99999]),
        iou_combine_threshold=0.4,
        dilation_percentage=0.02,
        normalize_embedding=False
    ),
    score_threshold=0.3,
    cache_directory=common.output_dir.joinpath("cse_mask_face_detection_cache")
)

anonymizer = L(Anonymizer)(
    detector="${detector}",
    face_G_cfg="configs/fdf/stylegan.py",
    person_G_cfg="configs/fdh/styleganL_nocse.py",
    cse_person_G_cfg="configs/fdh/styleganL.py",
    car_G_cfg="configs/generators/dummy/pixelation8.py"
)
configs/anonymizers/face.py ADDED
@@ -0,0 +1,18 @@
from dp2.anonymizer import Anonymizer
from dp2.detection.face_detector import FaceDetector
from ..defaults import common
from tops.config import LazyCall as L


detector = L(FaceDetector)(
    face_detector_cfg=dict(name="DSFDDetector", clip_boxes=True),
    face_post_process_cfg=dict(target_imsize=(256, 256)),
    score_threshold=0.3,
    cache_directory=common.output_dir.joinpath("face_detection_cache")
)


anonymizer = L(Anonymizer)(
    detector="${detector}",
    face_G_cfg="configs/fdf/stylegan.py",
)
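This face config is the one driven by the demo in multi_app.py. A hedged sketch of anonymizing a single image with it; the call signature (truncation_value, multi_modal_truncation, amp) and the channel-first uint8 convention are taken from multi_app.py, while the file names and PIL handling are illustrative assumptions:

    import numpy as np
    import torch
    from PIL import Image
    import tops
    from dp2 import utils
    from tops.config import instantiate

    cfg = utils.load_config("configs/anonymizers/face.py")
    anonymizer = instantiate(cfg.anonymizer, load_cache=False)

    img = np.array(Image.open("example.jpg").convert("RGB"))   # HWC, uint8 (example path)
    img = tops.to_cuda(torch.from_numpy(np.rollaxis(img, 2)))  # CHW on GPU, as in pil2torch below
    out = anonymizer(img, truncation_value=0, multi_modal_truncation=True, amp=True)
    Image.fromarray(utils.im2numpy(out)).save("anonymized.jpg")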
configs/anonymizers/market1501/blackout.py ADDED
@@ -0,0 +1,8 @@
from ..FB_cse_mask_face import anonymizer, detector, common

detector.score_threshold = .1
detector.face_detector_cfg.confidence_threshold = .5
detector.cse_cfg.score_thres = 0.3
anonymizer.generators.face_G_cfg = None
anonymizer.generators.person_G_cfg = "configs/generators/dummy/maskout.py"
anonymizer.generators.cse_person_G_cfg = "configs/generators/dummy/maskout.py"
configs/anonymizers/market1501/person.py ADDED
@@ -0,0 +1,6 @@
from ..FB_cse_mask_face import anonymizer, detector, common

detector.score_threshold = .1
detector.face_detector_cfg.confidence_threshold = .5
detector.cse_cfg.score_thres = 0.3
anonymizer.generators.face_G_cfg = None
configs/anonymizers/market1501/pixelation16.py ADDED
@@ -0,0 +1,8 @@
from ..FB_cse_mask_face import anonymizer, detector, common

detector.score_threshold = .1
detector.face_detector_cfg.confidence_threshold = .5
detector.cse_cfg.score_thres = 0.3
anonymizer.generators.face_G_cfg = None
anonymizer.generators.person_G_cfg = "configs/generators/dummy/pixelation16.py"
anonymizer.generators.cse_person_G_cfg = "configs/generators/dummy/pixelation16.py"
configs/anonymizers/market1501/pixelation8.py ADDED
@@ -0,0 +1,8 @@
from ..FB_cse_mask_face import anonymizer, detector, common

detector.score_threshold = .1
detector.face_detector_cfg.confidence_threshold = .5
detector.cse_cfg.score_thres = 0.3
anonymizer.generators.face_G_cfg = None
anonymizer.generators.person_G_cfg = "configs/generators/dummy/pixelation8.py"
anonymizer.generators.cse_person_G_cfg = "configs/generators/dummy/pixelation8.py"
configs/datasets/coco_cse.py ADDED
@@ -0,0 +1,69 @@
import os
from pathlib import Path
from tops.config import LazyCall as L
import torch
import functools
from dp2.data.datasets import CocoCSE
from dp2.data.build import get_dataloader
from dp2.data.transforms.transforms import CreateEmbedding, Normalize, Resize, ToFloat, CreateCondition, RandomHorizontalFlip
from dp2.data.transforms.stylegan2_transform import StyleGANAugmentPipe
from dp2.metrics.torch_metrics import compute_metrics_iteratively
from .utils import final_eval_fn


dataset_base_dir = os.environ["BASE_DATASET_DIR"] if "BASE_DATASET_DIR" in os.environ else "data"
metrics_cache = os.environ["FBA_METRICS_CACHE"] if "FBA_METRICS_CACHE" in os.environ else ".cache"
data_dir = Path(dataset_base_dir, "coco_cse")
data = dict(
    imsize=(288, 160),
    im_channels=3,
    semantic_nc=26,
    cse_nc=16,
    train=dict(
        dataset=L(CocoCSE)(data_dir.joinpath("train"), transform=None, normalize_E=False),
        loader=L(get_dataloader)(
            shuffle=True, num_workers=6, drop_last=True, prefetch_factor=2,
            batch_size="${train.batch_size}",
            dataset="${..dataset}",
            infinite=True,
            gpu_transform=L(torch.nn.Sequential)(*[
                L(ToFloat)(),
                L(StyleGANAugmentPipe)(
                    rotate=0.5, rotate_max=.05,
                    xint=.5, xint_max=0.05,
                    scale=.5, scale_std=.05,
                    aniso=0.5, aniso_std=.05,
                    xfrac=.5, xfrac_std=.05,
                    brightness=.5, brightness_std=.05,
                    contrast=.5, contrast_std=.1,
                    hue=.5, hue_max=.05,
                    saturation=.5, saturation_std=.5,
                    imgfilter=.5, imgfilter_std=.1),
                L(RandomHorizontalFlip)(p=0.5),
                L(CreateEmbedding)(),
                L(Resize)(size="${data.imsize}"),
                L(Normalize)(mean=[.5, .5, .5], std=[.5, .5, .5], inplace=True),
                L(CreateCondition)(),
            ])
        )
    ),
    val=dict(
        dataset=L(CocoCSE)(data_dir.joinpath("val"), transform=None, normalize_E=False),
        loader=L(get_dataloader)(
            shuffle=False, num_workers=6, drop_last=True, prefetch_factor=2,
            batch_size="${train.batch_size}",
            dataset="${..dataset}",
            infinite=False,
            gpu_transform=L(torch.nn.Sequential)(*[
                L(ToFloat)(),
                L(CreateEmbedding)(),
                L(Resize)(size="${data.imsize}"),
                L(Normalize)(mean=[.5, .5, .5], std=[.5, .5, .5], inplace=True),
                L(CreateCondition)(),
            ])
        )
    ),
    # Training evaluation might do optimizations to reduce compute overhead. E.g. compute with AMP.
    train_evaluation_fn=functools.partial(compute_metrics_iteratively, cache_directory=Path(metrics_cache, "coco_cse_val"), include_two_fake=False),
    evaluation_fn=functools.partial(final_eval_fn, cache_directory=Path(metrics_cache, "coco_cse_val_final"), include_two_fake=True)
)
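The dataset root and metrics cache come from the BASE_DATASET_DIR and FBA_METRICS_CACHE environment variables, read at config-import time, so they must be set before the config is loaded. A hedged sketch (the concrete paths are examples; note that the loader node also needs "${train.batch_size}", which is defined in configs/defaults.py and only resolves once the configs are composed in an experiment config):

    import os
    os.environ["BASE_DATASET_DIR"] = "/data/datasets"      # expects /data/datasets/coco_cse/{train,val}
    os.environ["FBA_METRICS_CACHE"] = "/data/metrics_cache"

    from dp2 import utils
    from tops.config import instantiate

    cfg = utils.load_config("configs/datasets/coco_cse.py")
    train_dataset = instantiate(cfg.data.train.dataset)    # the dataset node has no unresolved references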
configs/datasets/fdf128.py ADDED
@@ -0,0 +1,24 @@
from pathlib import Path
from functools import partial
from dp2.data.datasets.fdf import FDFDataset
from .fdf256 import data, dataset_base_dir, metrics_cache, final_eval_fn

data_dir = Path(dataset_base_dir, "fdf")
data.train.dataset.dirpath = data_dir.joinpath("train")
data.val.dataset.dirpath = data_dir.joinpath("val")
data.imsize = (128, 128)


data.train_evaluation_fn = partial(
    final_eval_fn, cache_directory=Path(metrics_cache, "fdf128_val_train"))
data.evaluation_fn = partial(
    final_eval_fn, cache_directory=Path(metrics_cache, "fdf128_val_final"))

data.train.dataset.update(
    _target_=FDFDataset,
    imsize="${data.imsize}"
)
data.val.dataset.update(
    _target_=FDFDataset,
    imsize="${data.imsize}"
)
configs/datasets/fdf256.py ADDED
@@ -0,0 +1,69 @@
import os
from pathlib import Path
from tops.config import LazyCall as L
import torch
import functools
from dp2.data.datasets.fdf import FDF256Dataset
from dp2.data.build import get_dataloader
from dp2.data.transforms.transforms import Normalize, Resize, ToFloat, CreateCondition, RandomHorizontalFlip
from dp2.metrics.torch_metrics import compute_metrics_iteratively
from dp2.metrics.fid_clip import compute_fid_clip
from dp2.metrics.ppl import calculate_ppl
from .utils import final_eval_fn


def final_eval_fn(*args, **kwargs):
    result = compute_metrics_iteratively(*args, **kwargs)
    result2 = compute_fid_clip(*args, **kwargs)
    assert all(key not in result for key in result2)
    result.update(result2)
    result3 = calculate_ppl(*args, **kwargs)
    assert all(key not in result for key in result3)
    result.update(result3)
    return result


dataset_base_dir = os.environ["BASE_DATASET_DIR"] if "BASE_DATASET_DIR" in os.environ else "data"
metrics_cache = os.environ["FBA_METRICS_CACHE"] if "FBA_METRICS_CACHE" in os.environ else ".cache"
data_dir = Path(dataset_base_dir, "fdf256")
data = dict(
    imsize=(256, 256),
    im_channels=3,
    semantic_nc=None,
    cse_nc=None,
    n_keypoints=None,
    train=dict(
        dataset=L(FDF256Dataset)(dirpath=data_dir.joinpath("train"), transform=None, load_keypoints=False),
        loader=L(get_dataloader)(
            shuffle=True, num_workers=3, drop_last=True, prefetch_factor=2,
            batch_size="${train.batch_size}",
            dataset="${..dataset}",
            infinite=True,
            gpu_transform=L(torch.nn.Sequential)(*[
                L(ToFloat)(),
                L(RandomHorizontalFlip)(p=0.5),
                L(Resize)(size="${data.imsize}"),
                L(Normalize)(mean=[.5, .5, .5], std=[.5, .5, .5], inplace=True),
                L(CreateCondition)(),
            ])
        )
    ),
    val=dict(
        dataset=L(FDF256Dataset)(dirpath=data_dir.joinpath("val"), transform=None, load_keypoints=False),
        loader=L(get_dataloader)(
            shuffle=False, num_workers=3, drop_last=False, prefetch_factor=2,
            batch_size="${train.batch_size}",
            dataset="${..dataset}",
            infinite=False,
            gpu_transform=L(torch.nn.Sequential)(*[
                L(ToFloat)(),
                L(Resize)(size="${data.imsize}"),
                L(Normalize)(mean=[.5, .5, .5], std=[.5, .5, .5], inplace=True),
                L(CreateCondition)(),
            ])
        )
    ),
    # Training evaluation might do optimizations to reduce compute overhead. E.g. compute with AMP.
    train_evaluation_fn=functools.partial(compute_metrics_iteratively, cache_directory=Path(metrics_cache, "fdf_val_train")),
    evaluation_fn=functools.partial(final_eval_fn, cache_directory=Path(metrics_cache, "fdf_val"))
)
configs/datasets/fdh.py ADDED
@@ -0,0 +1,89 @@
import os
from pathlib import Path
from tops.config import LazyCall as L
import torch
import functools
from dp2.data.datasets.fdh import get_dataloader_fdh_wds
from dp2.data.utils import get_coco_flipmap
from dp2.data.transforms.transforms import (
    Normalize,
    ToFloat,
    CreateCondition,
    RandomHorizontalFlip,
    CreateEmbedding,
)
from dp2.metrics.torch_metrics import compute_metrics_iteratively
from dp2.metrics.fid_clip import compute_fid_clip
from .utils import final_eval_fn


def train_eval_fn(*args, **kwargs):
    result = compute_metrics_iteratively(*args, **kwargs)
    result2 = compute_fid_clip(*args, **kwargs)
    assert all(key not in result for key in result2)
    result.update(result2)
    return result


dataset_base_dir = (
    os.environ["BASE_DATASET_DIR"] if "BASE_DATASET_DIR" in os.environ else "data"
)
metrics_cache = (
    os.environ["FBA_METRICS_CACHE"] if "FBA_METRICS_CACHE" in os.environ else ".cache"
)
data_dir = Path(dataset_base_dir, "fdh")
data = dict(
    imsize=(288, 160),
    im_channels=3,
    cse_nc=16,
    n_keypoints=17,
    train=dict(
        loader=L(get_dataloader_fdh_wds)(
            path=data_dir.joinpath("train", "out-{000000..001423}.tar"),
            batch_size="${train.batch_size}",
            num_workers=6,
            transform=L(torch.nn.Sequential)(
                L(RandomHorizontalFlip)(p=0.5, flip_map=get_coco_flipmap()),
            ),
            gpu_transform=L(torch.nn.Sequential)(
                L(ToFloat)(norm=False, keys=["img", "mask", "E_mask", "maskrcnn_mask"]),
                L(CreateEmbedding)(embed_path=data_dir.joinpath("embed_map.torch")),
                L(Normalize)(mean=[0.5*255, 0.5*255, 0.5*255], std=[0.5*255, 0.5*255, 0.5*255], inplace=True),
                L(CreateCondition)(),
            ),
            infinite=True,
            shuffle=True,
            partial_batches=False,
            load_embedding=True,
        )
    ),
    val=dict(
        loader=L(get_dataloader_fdh_wds)(
            path=data_dir.joinpath("val", "out-{000000..000023}.tar"),
            batch_size="${train.batch_size}",
            num_workers=6,
            transform=None,
            gpu_transform=L(torch.nn.Sequential)(
                L(ToFloat)(keys=["img", "mask", "E_mask", "maskrcnn_mask"], norm=False),
                L(CreateEmbedding)(embed_path=data_dir.joinpath("embed_map.torch")),
                L(Normalize)(mean=[0.5*255, 0.5*255, 0.5*255], std=[0.5*255, 0.5*255, 0.5*255], inplace=True),
                L(CreateCondition)(),
            ),
            infinite=False,
            shuffle=False,
            partial_batches=True,
            load_embedding=True,
        )
    ),
    # Training evaluation might do optimizations to reduce compute overhead. E.g. compute with AMP.
    train_evaluation_fn=functools.partial(
        train_eval_fn,
        cache_directory=Path(metrics_cache, "fdh_v7_train"),
        data_len=int(30e3),
    ),
    evaluation_fn=functools.partial(
        final_eval_fn,
        cache_directory=Path(metrics_cache, "fdh_v6_val"),
        data_len=int(30e3),
    ),
)
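The FDH loader (get_dataloader_fdh_wds) reads sharded tar files, and the path uses brace notation: "out-{000000..001423}.tar" denotes 1424 training shards (24 for validation). A purely illustrative expansion of that notation, assuming the loader handles the pattern itself:

    # Illustration only: the shard names the brace range "out-{000000..001423}.tar" refers to.
    shards = [f"data/fdh/train/out-{i:06d}.tar" for i in range(1424)]
    print(shards[0], shards[-1])   # data/fdh/train/out-000000.tar data/fdh/train/out-001423.tar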
configs/datasets/utils.py ADDED
@@ -0,0 +1,12 @@
from dp2.metrics.ppl import calculate_ppl
from dp2.metrics.torch_metrics import compute_metrics_iteratively
from dp2.metrics.fid_clip import compute_fid_clip


def final_eval_fn(*args, **kwargs):
    # Merge base metrics, PPL, and FID-CLIP into one dict; each call returns disjoint keys.
    result = compute_metrics_iteratively(*args, **kwargs)
    result2 = calculate_ppl(*args, **kwargs)
    assert all(key not in result for key in result2)
    result.update(result2)
    result3 = compute_fid_clip(*args, **kwargs)
    assert all(key not in result for key in result3)
    result.update(result3)
    return result
configs/defaults.py ADDED
@@ -0,0 +1,45 @@
import pathlib
import os
import torch
from tops.config import LazyCall as L

if "PRETRAINED_CHECKPOINTS_PATH" in os.environ:
    PRETRAINED_CHECKPOINTS_PATH = pathlib.Path(os.environ["PRETRAINED_CHECKPOINTS_PATH"])
else:
    PRETRAINED_CHECKPOINTS_PATH = pathlib.Path("pretrained_checkpoints")
if "BASE_OUTPUT_DIR" in os.environ:
    BASE_OUTPUT_DIR = pathlib.Path(os.environ["BASE_OUTPUT_DIR"])
else:
    BASE_OUTPUT_DIR = pathlib.Path("outputs")



common = dict(
    logger_backend=["wandb", "stdout", "json", "image_dumper"],
    wandb_project="fba_test",
    output_dir=BASE_OUTPUT_DIR,
    experiment_name=None,  # Optional experiment name to show on wandb
)

train = dict(
    batch_size=32,
    seed=0,
    ims_per_log=1024,
    ims_per_val=int(200e3),
    max_images_to_train=int(12e6),
    amp=dict(
        enabled=True,
        scaler_D=L(torch.cuda.amp.GradScaler)(init_scale=2**16, growth_factor=4, growth_interval=100, enabled="${..enabled}"),
        scaler_G=L(torch.cuda.amp.GradScaler)(init_scale=2**16, growth_factor=4, growth_interval=100, enabled="${..enabled}"),
    ),
    fp16_ddp_accumulate=False,  # All-gather gradients in fp16?
    broadcast_buffers=False,
    bias_act_plugin_enabled=True,
    grid_sample_gradfix_enabled=True,
    conv2d_gradfix_enabled=False,
    channels_last=False,
)

# exponential moving average
EMA = dict(rampup=0.05)
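The AMP block configures one GradScaler per network (scaler_D, scaler_G). A minimal sketch of how a scaler configured with these values is typically used in a training step; this is the generic torch.cuda.amp pattern, not the dp2 training loop (which is not part of this commit), and the toy parameter/loss are placeholders:

    import torch

    param = torch.nn.Parameter(torch.zeros(8, device="cuda"))
    optimizer = torch.optim.Adam([param], lr=1e-3)
    scaler = torch.cuda.amp.GradScaler(init_scale=2**16, growth_factor=4, growth_interval=100)

    with torch.cuda.amp.autocast():
        loss = (param - 1).pow(2).sum()   # stand-in for the real D/G loss
    scaler.scale(loss).backward()         # scale the loss to avoid fp16 gradient underflow
    scaler.step(optimizer)                # unscales gradients; skips the step on inf/nan
    scaler.update()                       # adapts the loss scale for the next iteration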
configs/discriminators/sg2_discriminator.py ADDED
@@ -0,0 +1,42 @@
from tops.config import LazyCall as L
from dp2.discriminator import SG2Discriminator
import torch
from dp2.loss import StyleGAN2Loss


discriminator = L(SG2Discriminator)(
    imsize="${data.imsize}",
    im_channels="${data.im_channels}",
    min_fmap_resolution=4,
    max_cnum_mul=8,
    cnum=80,
    input_condition=True,
    conv_clamp=256,
    input_cse=False,
    cse_nc="${data.cse_nc}"
)


loss_fnc = L(StyleGAN2Loss)(
    lazy_regularization=True,
    lazy_reg_interval=16,
    r1_opts=dict(lambd=5, mask_out=False, mask_out_scale=False),
    EP_lambd=0.001,
    pl_reg_opts=dict(weight=0, batch_shrink=2, start_nimg=int(1e6), pl_decay=0.01)
)


def build_D_optim(type, lr, betas, lazy_regularization, lazy_reg_interval, **kwargs):
    if lazy_regularization:
        # From "Analyzing and Improving the Image Quality of StyleGAN", CVPR 2020
        c = lazy_reg_interval / (lazy_reg_interval + 1)
        betas = [beta ** c for beta in betas]
        lr *= c
        print(f"Lazy regularization on. Setting lr to: {lr}, betas to: {betas}")
    return type(lr=lr, betas=betas, **kwargs)


D_optim = L(build_D_optim)(
    type=torch.optim.Adam, lr=0.001, betas=(0.0, 0.99),
    lazy_regularization="${loss_fnc.lazy_regularization}",
    lazy_reg_interval="${loss_fnc.lazy_reg_interval}")
G_optim = L(torch.optim.Adam)(lr=0.001, betas=(0.0, 0.99))
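Since lazy regularization applies the R1 penalty only every lazy_reg_interval-th step, build_D_optim rescales the learning rate and Adam betas so the optimizer behaves as if the penalty were applied every step. A numeric illustration with the values configured above:

    # With lazy_reg_interval = 16: c = 16/17, so
    #   lr    = 0.001 * 16/17 ≈ 0.000941
    #   betas = (0.0 ** (16/17), 0.99 ** (16/17)) ≈ (0.0, 0.9906)
    c = 16 / (16 + 1)
    print(0.001 * c, [beta ** c for beta in (0.0, 0.99)])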
configs/fdf/stylegan.py ADDED
@@ -0,0 +1,14 @@
from ..generators.stylegan_unet import generator
from ..datasets.fdf256 import data
from ..discriminators.sg2_discriminator import discriminator, G_optim, D_optim, loss_fnc
from ..defaults import train, common, EMA

train.max_images_to_train = int(35e6)
G_optim.lr = 0.002
D_optim.lr = 0.002
generator.input_cse = False
loss_fnc.r1_opts.lambd = 1
train.ims_per_val = int(2e6)

common.model_url = "https://api.loke.aws.unit.no/dlr-gui-backend-resources-content/v2/contents/links/89660f04-5c11-4dbf-adac-cbe2f11b0aeea25cbf78-7558-475a-b3c7-03f5c10b7934646b0720-ca0a-4d53-aded-daddbfa45c9e"
common.model_md5sum = "e8e32190528af2ed75f0cb792b7f2b07"
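model_url and model_md5sum describe the pretrained face checkpoint; the framework presumably downloads and verifies it on its own. A hedged sketch of what those two fields amount to (the destination filename is arbitrary; only the URL and md5 come from this config):

    import hashlib, os, urllib.request

    url = "https://api.loke.aws.unit.no/dlr-gui-backend-resources-content/v2/contents/links/89660f04-5c11-4dbf-adac-cbe2f11b0aeea25cbf78-7558-475a-b3c7-03f5c10b7934646b0720-ca0a-4d53-aded-daddbfa45c9e"
    os.makedirs("pretrained_checkpoints", exist_ok=True)
    dst = "pretrained_checkpoints/fdf_stylegan.ckpt"      # arbitrary local name
    urllib.request.urlretrieve(url, dst)
    assert hashlib.md5(open(dst, "rb").read()).hexdigest() == "e8e32190528af2ed75f0cb792b7f2b07"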
configs/fdf/stylegan_fdf128.py ADDED
@@ -0,0 +1,13 @@
from ..discriminators.sg2_discriminator import discriminator, G_optim, D_optim, loss_fnc
from ..datasets.fdf128 import data
from ..generators.stylegan_unet import generator
from ..defaults import train, common, EMA
from tops.config import LazyCall as L

train.max_images_to_train = int(25e6)
G_optim.lr = 0.002
D_optim.lr = 0.002
generator.cnum = 128
generator.max_cnum_mul = 4
generator.input_cse = False
loss_fnc.r1_opts.lambd = .1
configs/fdh/styleganL.py ADDED
@@ -0,0 +1,16 @@
from tops.config import LazyCall as L
from ..generators.stylegan_unet import generator
from ..datasets.fdh import data
from ..discriminators.sg2_discriminator import discriminator, G_optim, D_optim, loss_fnc
from ..defaults import train, common, EMA

train.max_images_to_train = int(50e6)
train.batch_size = 64
G_optim.lr = 0.002
D_optim.lr = 0.002
data.train.loader.num_workers = 4
train.ims_per_val = int(1e6)
loss_fnc.r1_opts.lambd = .1

common.model_url = "https://api.loke.aws.unit.no/dlr-gui-backend-resources-content/v2/contents/links/21841da7-2546-4ce3-8460-909b3a63c58b13aac1a1-c778-4c8d-9b69-3e5ed2cde9de1524e76e-7aa6-4dd8-b643-52abc9f0792c"
common.model_md5sum = "3411478b5ec600a4219cccf4499732bd"
configs/fdh/styleganL_nocse.py ADDED
@@ -0,0 +1,14 @@
from tops.config import LazyCall as L
from ..generators.stylegan_unet import generator
from ..datasets.fdh import data
from ..discriminators.sg2_discriminator import discriminator, G_optim, D_optim, loss_fnc
from ..defaults import train, common, EMA

train.max_images_to_train = int(50e6)
G_optim.lr = 0.002
D_optim.lr = 0.002
generator.input_cse = False
data.load_embeddings = False
common.model_url = "https://folk.ntnu.no/haakohu/checkpoints/deep_privacy2/fdh_styleganL_nocse.ckpt"
common.model_md5sum = "fda0d809741bc67487abada793975c37"
generator.fix_errors = False
configs/generators/stylegan_unet.py ADDED
@@ -0,0 +1,22 @@
from dp2.generator.stylegan_unet import StyleGANUnet
from tops.config import LazyCall as L

generator = L(StyleGANUnet)(
    imsize="${data.imsize}",
    im_channels="${data.im_channels}",
    min_fmap_resolution=8,
    cnum=64,
    max_cnum_mul=8,
    n_middle_blocks=0,
    z_channels=512,
    mask_output=True,
    conv_clamp=256,
    input_cse=True,
    scale_grad=True,
    cse_nc="${data.cse_nc}",
    w_dim=512,
    n_keypoints="${data.n_keypoints}",
    input_keypoints=False,
    input_keypoint_indices=[],
    fix_errors=True
)
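The "${data.*}" references above (imsize, im_channels, cse_nc, n_keypoints) resolve only once this generator config is composed with a dataset config, which is what the experiment configs do. An illustrative sketch, assuming instantiation of the generator node works the same way the demo instantiates the anonymizer node:

    from dp2 import utils
    from tops.config import instantiate

    cfg = utils.load_config("configs/fdh/styleganL.py")   # imports this generator and the FDH dataset config
    G = instantiate(cfg.generator)                        # imsize/cse_nc/n_keypoints are filled from cfg.data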
multi_app.py ADDED
@@ -0,0 +1,204 @@
import os
os.system("pip install git+https://github.com/hukkelas/deep_privacy2@36c2c843cfd3022ebc100e9f8579fb2b82f8bde6")
from collections import defaultdict
import gradio
import numpy as np
import torch
import cv2
from PIL import Image
from dp2 import utils
from tops.config import instantiate
import tops
import gradio.inputs
from stylemc import get_and_cache_direction, get_styles


class GuidedDemo:
    def __init__(self, face_anonymizer, cfg_face) -> None:
        self.anonymizer = face_anonymizer
        assert sum([x is not None for x in list(face_anonymizer.generators.values())]) == 1
        self.generator = [x for x in list(face_anonymizer.generators.values()) if x is not None][0]
        face_G_cfg = utils.load_config(cfg_face.anonymizer.face_G_cfg)
        face_G_cfg.train.batch_size = 1
        self.dl = instantiate(face_G_cfg.data.val.loader)
        self.cache_dir = face_G_cfg.output_dir
        self.precompute_edits()

    def precompute_edits(self):
        self.precomputed_edits = set()
        for edit in self.precomputed_edits:
            get_and_cache_direction(self.cache_dir, self.dl, self.generator, edit)
        if self.cache_dir.joinpath("stylemc_cache").is_dir():
            for path in self.cache_dir.joinpath("stylemc_cache").iterdir():
                text_prompt = path.stem.replace("_", " ")
                self.precomputed_edits.add(text_prompt)
                print(text_prompt)
        self.edits = defaultdict(defaultdict)

    def anonymize(self, img, show_boxes: bool, current_box_idx: int, current_styles, current_boxes, update_identity, edits, cache_id=None):
        if not isinstance(img, torch.Tensor):
            img, cache_id = pil2torch(img)
            img = tops.to_cuda(img)

        current_box_idx = current_box_idx % len(current_boxes)
        edited_styles = [s.clone() for s in current_styles]
        for face_idx, face_edits in edits.items():
            for prompt, strength in face_edits.items():
                direction = get_and_cache_direction(self.cache_dir, self.dl, self.generator, prompt)
                edited_styles[int(face_idx)] += direction * strength
            update_identity[int(face_idx)] = True
        assert img.dtype == torch.uint8
        img = self.anonymizer(
            img, truncation_value=0,
            multi_modal_truncation=True, amp=True,
            cache_id=cache_id,
            all_styles=edited_styles,
            update_identity=update_identity)
        update_identity = [True for i in range(len(update_identity))]
        img = utils.im2numpy(img)
        if show_boxes:
            x0, y0, x1, y1 = [int(_) for _ in current_boxes[int(current_box_idx)]]
            img = cv2.rectangle(img, (x0, y0), (x1, y1), (255, 0, 0), 1)
        return img, update_identity

    def update_image(self, img, show_boxes):
        img, cache_id = pil2torch(img)
        img = tops.to_cuda(img)
        det = self.anonymizer.detector.forward_and_cache(img, cache_id, load_cache=True)[0]
        current_styles = []
        for i in range(len(det)):
            s = get_styles(
                np.random.randint(0, 999999), self.generator,
                None, truncation_value=0)
            current_styles.append(s)
        update_identity = [True for i in range(len(det))]
        current_boxes = np.array(det.boxes)
        edits = defaultdict(defaultdict)
        cur_face_idx = -1 % len(current_boxes)
        img, update_identity = self.anonymize(img, show_boxes, cur_face_idx, current_styles, current_boxes, update_identity, edits, cache_id=cache_id)
        return img, current_styles, current_boxes, update_identity, edits, cur_face_idx

    def change_face(self, change, cur_face_idx, current_boxes, input_image, show_boxes, current_styles, update_identity, edits):
        cur_face_idx = (cur_face_idx + change) % len(current_boxes)
        img, update_identity = self.anonymize(input_image, show_boxes, cur_face_idx, current_styles, current_boxes, update_identity, edits)
        return img, update_identity, cur_face_idx

    def add_style(self, face_idx: int, prompt: str, strength: float, input_image, show_boxes, current_styles, current_boxes, update_identity, edits):
        face_idx = face_idx % len(current_boxes)
        edits[face_idx][prompt] = strength
        img, update_identity = self.anonymize(input_image, show_boxes, face_idx, current_styles, current_boxes, update_identity, edits)
        return img, update_identity, edits

    def setup_interface(self):
        current_styles = gradio.State()
        current_boxes = gradio.State(None)
        update_identity = gradio.State([])
        edits = gradio.State([])
        with gradio.Row():
            input_image = gradio.Image(
                type="pil", label="Upload your image or try the example below!", source="webcam")
            output_image = gradio.Image(type="numpy", label="Output")
        with gradio.Row():
            update_btn = gradio.Button("Update Anonymization").style(full_width=True)
        with gradio.Row():
            show_boxes = gradio.Checkbox(value=True, label="Show Selected")
            cur_face_idx = gradio.Number(value=-1, label="Current", interactive=False)
            previous = gradio.Button("Previous Person")
            next_ = gradio.Button("Next Person")
        with gradio.Row():
            text_prompt = gradio.Textbox(
                placeholder=" | ".join(list(self.precomputed_edits)),
                label="Text Prompt for Edit")
            edit_strength = gradio.Slider(0, 5, step=.01)
            add_btn = gradio.Button("Add Edit")
        add_btn.click(self.add_style, inputs=[cur_face_idx, text_prompt, edit_strength, input_image, show_boxes, current_styles, current_boxes, update_identity, edits], outputs=[output_image, update_identity, edits])
        update_btn.click(self.update_image, inputs=[input_image, show_boxes], outputs=[output_image, current_styles, current_boxes, update_identity, edits, cur_face_idx])
        input_image.change(self.update_image, inputs=[input_image, show_boxes], outputs=[output_image, current_styles, current_boxes, update_identity, edits, cur_face_idx])
        previous.click(self.change_face, inputs=[gradio.State(-1), cur_face_idx, current_boxes, input_image, show_boxes, current_styles, update_identity, edits], outputs=[output_image, update_identity, cur_face_idx])
        next_.click(self.change_face, inputs=[gradio.State(1), cur_face_idx, current_boxes, input_image, show_boxes, current_styles, update_identity, edits], outputs=[output_image, update_identity, cur_face_idx])

        show_boxes.change(self.anonymize, inputs=[input_image, show_boxes, cur_face_idx, current_styles, current_boxes, update_identity, edits], outputs=[output_image, update_identity])


cfg_body = utils.load_config("configs/anonymizers/FB_cse.py")
anonymizer_body = instantiate(cfg_body.anonymizer, load_cache=False)
anonymizer_body.initialize_tracker(fps=1)
cfg_face = utils.load_config("configs/anonymizers/face.py")
anonymizer_face = instantiate(cfg_face.anonymizer, load_cache=False)
anonymizer_face.initialize_tracker(fps=1)


class WebcamDemo:

    def __init__(self, anonymizer) -> None:
        self.anonymizer = anonymizer
        with gradio.Row():
            input_image = gradio.Image(type="pil", source="webcam", streaming=True)
            output_image = gradio.Image(type="numpy", label="Output")
        visualize_det = gradio.Checkbox(value=False, label="Show Detections")
        input_image.stream(self.anonymize, [input_image, visualize_det], [output_image])
        self.track = True

    def anonymize(self, img: Image, visualize_detection: bool):
        img, cache_id = pil2torch(img)
        img = tops.to_cuda(img)
        if visualize_detection:
            img = self.anonymizer.visualize_detection(img, cache_id=cache_id)
        else:
            img = self.anonymizer(
                img, truncation_value=0, multi_modal_truncation=True, amp=True,
                cache_id=cache_id, track=self.track)
        img = utils.im2numpy(img)
        return img


class ExampleDemo(WebcamDemo):

    def __init__(self, anonymizer) -> None:
        self.anonymizer = anonymizer
        with gradio.Row():
            input_image = gradio.Image(type="pil", source="webcam")
            output_image = gradio.Image(type="numpy", label="Output")
        with gradio.Row():
            update_btn = gradio.Button("Update Anonymization").style(full_width=True)
            visualize_det = gradio.Checkbox(value=False, label="Show Detections")
        visualize_det.change(self.anonymize, inputs=[input_image, visualize_det], outputs=[output_image])
        gradio.Examples(
            ["media2/erling.jpg", "media2/regjeringen.jpg"], inputs=[input_image]
        )
        update_btn.click(self.anonymize, inputs=[input_image, visualize_det], outputs=[output_image])
        input_image.change(self.anonymize, inputs=[input_image, visualize_det], outputs=[output_image])
        self.track = False


class Information:

    def __init__(self) -> None:
        gradio.Markdown("## <center> Face Anonymization Architecture </center>")
        gradio.Markdown("---")
        gradio.Image(value="media2/overall_architecture.png")
        gradio.Markdown("## <center> Full-Body Anonymization Architecture </center>")
        gradio.Markdown("---")
        gradio.Image(value="media2/full_body.png")
        gradio.Markdown("### <center> Generative Adversarial Networks </center>")
        gradio.Markdown("---")
        gradio.Image(value="media2/gan_architecture.png")


def pil2torch(img: Image.Image):
    img = img.convert("RGB")
    img = np.array(img)
    img = np.rollaxis(img, 2)
    return torch.from_numpy(img), None


with gradio.Blocks() as demo:
    gradio.Markdown("# <center> DeepPrivacy2 - Realistic Image Anonymization </center>")
    gradio.Markdown("### <center> Håkon Hukkelås, Rudolf Mester, Frank Lindseth </center>")
    with gradio.Tab("Text-Guided Anonymization"):
        GuidedDemo(anonymizer_face, cfg_face).setup_interface()
    with gradio.Tab("Live Full-Body"):
        WebcamDemo(anonymizer_body)
    with gradio.Tab("Live Face"):
        WebcamDemo(anonymizer_face)


demo.launch()
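pil2torch converts a PIL image (HWC, uint8) into the CHW uint8 tensor the anonymizer expects; np.rollaxis(img, 2) moves the channel axis to the front, and the anonymizer asserts img.dtype == torch.uint8. A tiny self-contained check of that convention (toy image, not part of the app):

    import numpy as np
    import torch
    from PIL import Image

    pil = Image.fromarray(np.zeros((64, 48, 3), dtype=np.uint8))          # toy HWC image
    chw = torch.from_numpy(np.rollaxis(np.array(pil.convert("RGB")), 2))  # same conversion as pil2torch
    print(chw.shape, chw.dtype)                                           # torch.Size([3, 64, 48]) torch.uint8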
sg3_torch_utils/LICENSE.txt ADDED
@@ -0,0 +1,97 @@
1
+ Copyright (c) 2021, NVIDIA Corporation & affiliates. All rights reserved.
2
+
3
+
4
+ NVIDIA Source Code License for StyleGAN3
5
+
6
+
7
+ =======================================================================
8
+
9
+ 1. Definitions
10
+
11
+ "Licensor" means any person or entity that distributes its Work.
12
+
13
+ "Software" means the original work of authorship made available under
14
+ this License.
15
+
16
+ "Work" means the Software and any additions to or derivative works of
17
+ the Software that are made available under this License.
18
+
19
+ The terms "reproduce," "reproduction," "derivative works," and
20
+ "distribution" have the meaning as provided under U.S. copyright law;
21
+ provided, however, that for the purposes of this License, derivative
22
+ works shall not include works that remain separable from, or merely
23
+ link (or bind by name) to the interfaces of, the Work.
24
+
25
+ Works, including the Software, are "made available" under this License
26
+ by including in or with the Work either (a) a copyright notice
27
+ referencing the applicability of this License to the Work, or (b) a
28
+ copy of this License.
29
+
30
+ 2. License Grants
31
+
32
+ 2.1 Copyright Grant. Subject to the terms and conditions of this
33
+ License, each Licensor grants to you a perpetual, worldwide,
34
+ non-exclusive, royalty-free, copyright license to reproduce,
35
+ prepare derivative works of, publicly display, publicly perform,
36
+ sublicense and distribute its Work and any resulting derivative
37
+ works in any form.
38
+
39
+ 3. Limitations
40
+
41
+ 3.1 Redistribution. You may reproduce or distribute the Work only
42
+ if (a) you do so under this License, (b) you include a complete
43
+ copy of this License with your distribution, and (c) you retain
44
+ without modification any copyright, patent, trademark, or
45
+ attribution notices that are present in the Work.
46
+
47
+ 3.2 Derivative Works. You may specify that additional or different
48
+ terms apply to the use, reproduction, and distribution of your
49
+ derivative works of the Work ("Your Terms") only if (a) Your Terms
50
+ provide that the use limitation in Section 3.3 applies to your
51
+ derivative works, and (b) you identify the specific derivative
52
+ works that are subject to Your Terms. Notwithstanding Your Terms,
53
+ this License (including the redistribution requirements in Section
54
+ 3.1) will continue to apply to the Work itself.
55
+
56
+ 3.3 Use Limitation. The Work and any derivative works thereof only
57
+ may be used or intended for use non-commercially. Notwithstanding
58
+ the foregoing, NVIDIA and its affiliates may use the Work and any
59
+ derivative works commercially. As used herein, "non-commercially"
60
+ means for research or evaluation purposes only.
61
+
62
+ 3.4 Patent Claims. If you bring or threaten to bring a patent claim
63
+ against any Licensor (including any claim, cross-claim or
64
+ counterclaim in a lawsuit) to enforce any patents that you allege
65
+ are infringed by any Work, then your rights under this License from
66
+ such Licensor (including the grant in Section 2.1) will terminate
67
+ immediately.
68
+
69
+ 3.5 Trademarks. This License does not grant any rights to use any
70
+ Licensor’s or its affiliates’ names, logos, or trademarks, except
71
+ as necessary to reproduce the notices described in this License.
72
+
73
+ 3.6 Termination. If you violate any term of this License, then your
74
+ rights under this License (including the grant in Section 2.1) will
75
+ terminate immediately.
76
+
77
+ 4. Disclaimer of Warranty.
78
+
79
+ THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY
80
+ KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF
81
+ MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR
82
+ NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER
83
+ THIS LICENSE.
84
+
85
+ 5. Limitation of Liability.
86
+
87
+ EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL
88
+ THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE
89
+ SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT,
90
+ INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF
91
+ OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK
92
+ (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION,
93
+ LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER
94
+ COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF
95
+ THE POSSIBILITY OF SUCH DAMAGES.
96
+
97
+ =======================================================================
sg3_torch_utils/__init__.py ADDED
@@ -0,0 +1,9 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

# empty
sg3_torch_utils/custom_ops.py ADDED
@@ -0,0 +1,126 @@
1
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ import os
10
+ import glob
11
+ import torch
12
+ import torch.utils.cpp_extension
13
+ import importlib
14
+ import hashlib
15
+ import shutil
16
+ from pathlib import Path
17
+
18
+ from torch.utils.file_baton import FileBaton
19
+
20
+ #----------------------------------------------------------------------------
21
+ # Global options.
22
+
23
+ verbosity = 'brief' # Verbosity level: 'none', 'brief', 'full'
24
+
25
+ #----------------------------------------------------------------------------
26
+ # Internal helper funcs.
27
+
28
+ def _find_compiler_bindir():
29
+ patterns = [
30
+ 'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64',
31
+ 'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64',
32
+ 'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64',
33
+ 'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin',
34
+ ]
35
+ for pattern in patterns:
36
+ matches = sorted(glob.glob(pattern))
37
+ if len(matches):
38
+ return matches[-1]
39
+ return None
40
+
41
+ #----------------------------------------------------------------------------
42
+ # Main entry point for compiling and loading C++/CUDA plugins.
43
+
44
+ _cached_plugins = dict()
45
+
46
+ def get_plugin(module_name, sources, **build_kwargs):
47
+ assert verbosity in ['none', 'brief', 'full']
48
+
49
+ # Already cached?
50
+ if module_name in _cached_plugins:
51
+ return _cached_plugins[module_name]
52
+
53
+ # Print status.
54
+ if verbosity == 'full':
55
+ print(f'Setting up PyTorch plugin "{module_name}"...')
56
+ elif verbosity == 'brief':
57
+ print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True)
58
+
59
+ try: # pylint: disable=too-many-nested-blocks
60
+ # Make sure we can find the necessary compiler binaries.
61
+ if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0:
62
+ compiler_bindir = _find_compiler_bindir()
63
+ if compiler_bindir is None:
64
+ raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".')
65
+ os.environ['PATH'] += ';' + compiler_bindir
66
+
67
+ # Compile and load.
68
+ verbose_build = (verbosity == 'full')
69
+
70
+ # Incremental build md5sum trickery. Copies all the input source files
71
+ # into a cached build directory under a combined md5 digest of the input
72
+ # source files. Copying is done only if the combined digest has changed.
73
+ # This keeps input file timestamps and filenames the same as in previous
74
+ # extension builds, allowing for fast incremental rebuilds.
75
+ #
76
+ # This optimization is done only in case all the source files reside in
77
+ # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR
78
+ # environment variable is set (we take this as a signal that the user
79
+ # actually cares about this.)
80
+ source_dirs_set = set(os.path.dirname(source) for source in sources)
81
+ if len(source_dirs_set) == 1 and ('TORCH_EXTENSIONS_DIR' in os.environ):
82
+ all_source_files = sorted(list(x for x in Path(list(source_dirs_set)[0]).iterdir() if x.is_file()))
83
+
84
+ # Compute a combined hash digest for all source files in the same
85
+ # custom op directory (usually .cu, .cpp, .py and .h files).
86
+ hash_md5 = hashlib.md5()
87
+ for src in all_source_files:
88
+ with open(src, 'rb') as f:
89
+ hash_md5.update(f.read())
90
+ build_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access
91
+ digest_build_dir = os.path.join(build_dir, hash_md5.hexdigest())
92
+
93
+ if not os.path.isdir(digest_build_dir):
94
+ os.makedirs(digest_build_dir, exist_ok=True)
95
+ baton = FileBaton(os.path.join(digest_build_dir, 'lock'))
96
+ if baton.try_acquire():
97
+ try:
98
+ for src in all_source_files:
99
+ shutil.copyfile(src, os.path.join(digest_build_dir, os.path.basename(src)))
100
+ finally:
101
+ baton.release()
102
+ else:
103
+ # Someone else is copying source files under the digest dir,
104
+ # wait until done and continue.
105
+ baton.wait()
106
+ digest_sources = [os.path.join(digest_build_dir, os.path.basename(x)) for x in sources]
107
+ torch.utils.cpp_extension.load(name=module_name, build_directory=build_dir,
108
+ verbose=verbose_build, sources=digest_sources, **build_kwargs)
109
+ else:
110
+ torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs)
111
+ module = importlib.import_module(module_name)
112
+
113
+ except:
114
+ if verbosity == 'brief':
115
+ print('Failed!')
116
+ raise
117
+
118
+ # Print status and add to cache.
119
+ if verbosity == 'full':
120
+ print(f'Done setting up PyTorch plugin "{module_name}".')
121
+ elif verbosity == 'brief':
122
+ print('Done.')
123
+ _cached_plugins[module_name] = module
124
+ return module
125
+
126
+ #----------------------------------------------------------------------------
sg3_torch_utils/misc.py ADDED
@@ -0,0 +1,172 @@
1
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ import re
10
+ import contextlib
11
+ import numpy as np
12
+ import torch
13
+ import warnings
14
+
15
+ #----------------------------------------------------------------------------
16
+ # Cached construction of constant tensors. Avoids CPU=>GPU copy when the
17
+ # same constant is used multiple times.
18
+
19
+ _constant_cache = dict()
20
+
21
+ def constant(value, shape=None, dtype=None, device=None, memory_format=None):
22
+ value = np.asarray(value)
23
+ if shape is not None:
24
+ shape = tuple(shape)
25
+ if dtype is None:
26
+ dtype = torch.get_default_dtype()
27
+ if device is None:
28
+ device = torch.device('cpu')
29
+ if memory_format is None:
30
+ memory_format = torch.contiguous_format
31
+
32
+ key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format)
33
+ tensor = _constant_cache.get(key, None)
34
+ if tensor is None:
35
+ tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device)
36
+ if shape is not None:
37
+ tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape))
38
+ tensor = tensor.contiguous(memory_format=memory_format)
39
+ _constant_cache[key] = tensor
40
+ return tensor
41
+
42
+ #----------------------------------------------------------------------------
43
+ # Replace NaN/Inf with specified numerical values.
44
+
45
+ try:
46
+ nan_to_num = torch.nan_to_num # 1.8.0a0
47
+ except AttributeError:
48
+ def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin
49
+ assert isinstance(input, torch.Tensor)
50
+ if posinf is None:
51
+ posinf = torch.finfo(input.dtype).max
52
+ if neginf is None:
53
+ neginf = torch.finfo(input.dtype).min
54
+ assert nan == 0
55
+ return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out)
56
+
57
+ #----------------------------------------------------------------------------
58
+ # Symbolic assert.
59
+
60
+ try:
61
+ symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access
62
+ except AttributeError:
63
+ symbolic_assert = torch.Assert # 1.7.0
64
+
65
+ #----------------------------------------------------------------------------
66
+ # Context manager to suppress known warnings in torch.jit.trace().
67
+
68
+ class suppress_tracer_warnings(warnings.catch_warnings):
69
+ def __enter__(self):
70
+ super().__enter__()
71
+ warnings.simplefilter('ignore', category=torch.jit.TracerWarning)
72
+ return self
73
+
74
+ #----------------------------------------------------------------------------
75
+ # Assert that the shape of a tensor matches the given list of integers.
76
+ # None indicates that the size of a dimension is allowed to vary.
77
+ # Performs symbolic assertion when used in torch.jit.trace().
78
+
79
+ def assert_shape(tensor, ref_shape):
80
+ if tensor.ndim != len(ref_shape):
81
+ raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}')
82
+ for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)):
83
+ if ref_size is None:
84
+ pass
85
+ elif isinstance(ref_size, torch.Tensor):
86
+ with suppress_tracer_warnings(): # as_tensor results are registered as constants
87
+ symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}')
88
+ elif isinstance(size, torch.Tensor):
89
+ with suppress_tracer_warnings(): # as_tensor results are registered as constants
90
+ symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}')
91
+ elif size != ref_size:
92
+ raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}')
93
+
94
+ #----------------------------------------------------------------------------
95
+ # Function decorator that calls torch.autograd.profiler.record_function().
96
+
97
+ def profiled_function(fn):
98
+ def decorator(*args, **kwargs):
99
+ with torch.autograd.profiler.record_function(fn.__name__):
100
+ return fn(*args, **kwargs)
101
+ decorator.__name__ = fn.__name__
102
+ return decorator
103
+
104
+ #----------------------------------------------------------------------------
105
+ # Sampler for torch.utils.data.DataLoader that loops over the dataset
106
+ # indefinitely, shuffling items as it goes.
107
+
108
+ class InfiniteSampler(torch.utils.data.Sampler):
109
+ def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5):
110
+ assert len(dataset) > 0
111
+ assert num_replicas > 0
112
+ assert 0 <= rank < num_replicas
113
+ assert 0 <= window_size <= 1
114
+ super().__init__(dataset)
115
+ self.dataset = dataset
116
+ self.rank = rank
117
+ self.num_replicas = num_replicas
118
+ self.shuffle = shuffle
119
+ self.seed = seed
120
+ self.window_size = window_size
121
+
122
+ def __iter__(self):
123
+ order = np.arange(len(self.dataset))
124
+ rnd = None
125
+ window = 0
126
+ if self.shuffle:
127
+ rnd = np.random.RandomState(self.seed)
128
+ rnd.shuffle(order)
129
+ window = int(np.rint(order.size * self.window_size))
130
+
131
+ idx = 0
132
+ while True:
133
+ i = idx % order.size
134
+ if idx % self.num_replicas == self.rank:
135
+ yield order[i]
136
+ if window >= 2:
137
+ j = (i - rnd.randint(window)) % order.size
138
+ order[i], order[j] = order[j], order[i]
139
+ idx += 1
140
+
141
+ #----------------------------------------------------------------------------
142
+ # Utilities for operating with torch.nn.Module parameters and buffers.
143
+
144
+ def params_and_buffers(module):
145
+ assert isinstance(module, torch.nn.Module)
146
+ return list(module.parameters()) + list(module.buffers())
147
+
148
+ def named_params_and_buffers(module):
149
+ assert isinstance(module, torch.nn.Module)
150
+ return list(module.named_parameters()) + list(module.named_buffers())
151
+
152
+ def copy_params_and_buffers(src_module, dst_module, require_all=False):
153
+ assert isinstance(src_module, torch.nn.Module)
154
+ assert isinstance(dst_module, torch.nn.Module)
155
+ src_tensors = {name: tensor for name, tensor in named_params_and_buffers(src_module)}
156
+ for name, tensor in named_params_and_buffers(dst_module):
157
+ assert (name in src_tensors) or (not require_all)
158
+ if name in src_tensors:
159
+ tensor.copy_(src_tensors[name].detach()).requires_grad_(tensor.requires_grad)
160
+
161
+ #----------------------------------------------------------------------------
162
+ # Context manager for easily enabling/disabling DistributedDataParallel
163
+ # synchronization.
164
+
165
+ @contextlib.contextmanager
166
+ def ddp_sync(module, sync):
167
+ assert isinstance(module, torch.nn.Module)
168
+ if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel):
169
+ yield
170
+ else:
171
+ with module.no_sync():
172
+ yield
sg3_torch_utils/ops/__init__.py ADDED
@@ -0,0 +1,9 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

# empty
sg3_torch_utils/ops/bias_act.cpp ADDED
@@ -0,0 +1,99 @@
1
+ // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ //
3
+ // NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ // and proprietary rights in and to this software, related documentation
5
+ // and any modifications thereto. Any use, reproduction, disclosure or
6
+ // distribution of this software and related documentation without an express
7
+ // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ #include <torch/extension.h>
10
+ #include <ATen/cuda/CUDAContext.h>
11
+ #include <c10/cuda/CUDAGuard.h>
12
+ #include "bias_act.h"
13
+
14
+ //------------------------------------------------------------------------
15
+
16
+ static bool has_same_layout(torch::Tensor x, torch::Tensor y)
17
+ {
18
+ if (x.dim() != y.dim())
19
+ return false;
20
+ for (int64_t i = 0; i < x.dim(); i++)
21
+ {
22
+ if (x.size(i) != y.size(i))
23
+ return false;
24
+ if (x.size(i) >= 2 && x.stride(i) != y.stride(i))
25
+ return false;
26
+ }
27
+ return true;
28
+ }
29
+
30
+ //------------------------------------------------------------------------
31
+
32
+ static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp)
33
+ {
34
+ // Validate arguments.
35
+ TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
36
+ TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x");
37
+ TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x");
38
+ TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x");
39
+ TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same dtype and device as x");
40
+ TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
41
+ TORCH_CHECK(b.dim() == 1, "b must have rank 1");
42
+ TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds");
43
+ TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements");
44
+ TORCH_CHECK(grad >= 0, "grad must be non-negative");
45
+
46
+ // Validate layout.
47
+ TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense");
48
+ TORCH_CHECK(b.is_contiguous(), "b must be contiguous");
49
+ TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x");
50
+ TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x");
51
+ TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x");
52
+
53
+ // Create output tensor.
54
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
55
+ torch::Tensor y = torch::empty_like(x);
56
+ TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x");
57
+
58
+ // Initialize CUDA kernel parameters.
59
+ bias_act_kernel_params p;
60
+ p.x = x.data_ptr();
61
+ p.b = (b.numel()) ? b.data_ptr() : NULL;
62
+ p.xref = (xref.numel()) ? xref.data_ptr() : NULL;
63
+ p.yref = (yref.numel()) ? yref.data_ptr() : NULL;
64
+ p.dy = (dy.numel()) ? dy.data_ptr() : NULL;
65
+ p.y = y.data_ptr();
66
+ p.grad = grad;
67
+ p.act = act;
68
+ p.alpha = alpha;
69
+ p.gain = gain;
70
+ p.clamp = clamp;
71
+ p.sizeX = (int)x.numel();
72
+ p.sizeB = (int)b.numel();
73
+ p.stepB = (b.numel()) ? (int)x.stride(dim) : 1;
74
+
75
+ // Choose CUDA kernel.
76
+ void* kernel;
77
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
78
+ {
79
+ kernel = choose_bias_act_kernel<scalar_t>(p);
80
+ });
81
+ TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func");
82
+
83
+ // Launch CUDA kernel.
84
+ p.loopX = 4;
85
+ int blockSize = 4 * 32;
86
+ int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1;
87
+ void* args[] = {&p};
88
+ AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
89
+ return y;
90
+ }
91
+
92
+ //------------------------------------------------------------------------
93
+
94
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
95
+ {
96
+ m.def("bias_act", &bias_act);
97
+ }
98
+
99
+ //------------------------------------------------------------------------
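Editor's note: a quick restatement (not part of this commit) of the launch arithmetic above; each block runs 4 * 32 = 128 threads and each thread handles loopX = 4 elements.

def launch_dims(size_x, loop_x=4, block_size=4 * 32):
    # Same grid-size computation as bias_act(): ceil(size_x / (loop_x * block_size)).
    grid_size = (size_x - 1) // (loop_x * block_size) + 1
    return grid_size, block_size

# launch_dims(1_000_000) -> (1954, 128): 1954 blocks of 128 threads cover 10^6 elements.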
sg3_torch_utils/ops/bias_act.cu ADDED
@@ -0,0 +1,173 @@
1
+ // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ //
3
+ // NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ // and proprietary rights in and to this software, related documentation
5
+ // and any modifications thereto. Any use, reproduction, disclosure or
6
+ // distribution of this software and related documentation without an express
7
+ // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ #include <c10/util/Half.h>
10
+ #include "bias_act.h"
11
+
12
+ //------------------------------------------------------------------------
13
+ // Helpers.
14
+
15
+ template <class T> struct InternalType;
16
+ template <> struct InternalType<double> { typedef double scalar_t; };
17
+ template <> struct InternalType<float> { typedef float scalar_t; };
18
+ template <> struct InternalType<c10::Half> { typedef float scalar_t; };
19
+
20
+ //------------------------------------------------------------------------
21
+ // CUDA kernel.
22
+
23
+ template <class T, int A>
24
+ __global__ void bias_act_kernel(bias_act_kernel_params p)
25
+ {
26
+ typedef typename InternalType<T>::scalar_t scalar_t;
27
+ int G = p.grad;
28
+ scalar_t alpha = (scalar_t)p.alpha;
29
+ scalar_t gain = (scalar_t)p.gain;
30
+ scalar_t clamp = (scalar_t)p.clamp;
31
+ scalar_t one = (scalar_t)1;
32
+ scalar_t two = (scalar_t)2;
33
+ scalar_t expRange = (scalar_t)80;
34
+ scalar_t halfExpRange = (scalar_t)40;
35
+ scalar_t seluScale = (scalar_t)1.0507009873554804934193349852946;
36
+ scalar_t seluAlpha = (scalar_t)1.6732632423543772848170429916717;
37
+
38
+ // Loop over elements.
39
+ int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x;
40
+ for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x)
41
+ {
42
+ // Load.
43
+ scalar_t x = (scalar_t)((const T*)p.x)[xi];
44
+ scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0;
45
+ scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0;
46
+ scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0;
47
+ scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one;
48
+ scalar_t yy = (gain != 0) ? yref / gain : 0;
49
+ scalar_t y = 0;
50
+
51
+ // Apply bias.
52
+ ((G == 0) ? x : xref) += b;
53
+
54
+ // linear
55
+ if (A == 1)
56
+ {
57
+ if (G == 0) y = x;
58
+ if (G == 1) y = x;
59
+ }
60
+
61
+ // relu
62
+ if (A == 2)
63
+ {
64
+ if (G == 0) y = (x > 0) ? x : 0;
65
+ if (G == 1) y = (yy > 0) ? x : 0;
66
+ }
67
+
68
+ // lrelu
69
+ if (A == 3)
70
+ {
71
+ if (G == 0) y = (x > 0) ? x : x * alpha;
72
+ if (G == 1) y = (yy > 0) ? x : x * alpha;
73
+ }
74
+
75
+ // tanh
76
+ if (A == 4)
77
+ {
78
+ if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); }
79
+ if (G == 1) y = x * (one - yy * yy);
80
+ if (G == 2) y = x * (one - yy * yy) * (-two * yy);
81
+ }
82
+
83
+ // sigmoid
84
+ if (A == 5)
85
+ {
86
+ if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one);
87
+ if (G == 1) y = x * yy * (one - yy);
88
+ if (G == 2) y = x * yy * (one - yy) * (one - two * yy);
89
+ }
90
+
91
+ // elu
92
+ if (A == 6)
93
+ {
94
+ if (G == 0) y = (x >= 0) ? x : exp(x) - one;
95
+ if (G == 1) y = (yy >= 0) ? x : x * (yy + one);
96
+ if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one);
97
+ }
98
+
99
+ // selu
100
+ if (A == 7)
101
+ {
102
+ if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one);
103
+ if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha);
104
+ if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha);
105
+ }
106
+
107
+ // softplus
108
+ if (A == 8)
109
+ {
110
+ if (G == 0) y = (x > expRange) ? x : log(exp(x) + one);
111
+ if (G == 1) y = x * (one - exp(-yy));
112
+ if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); }
113
+ }
114
+
115
+ // swish
116
+ if (A == 9)
117
+ {
118
+ if (G == 0)
119
+ y = (x < -expRange) ? 0 : x / (exp(-x) + one);
120
+ else
121
+ {
122
+ scalar_t c = exp(xref);
123
+ scalar_t d = c + one;
124
+ if (G == 1)
125
+ y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d);
126
+ else
127
+ y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d);
128
+ yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain;
129
+ }
130
+ }
131
+
132
+ // Apply gain.
133
+ y *= gain * dy;
134
+
135
+ // Clamp.
136
+ if (clamp >= 0)
137
+ {
138
+ if (G == 0)
139
+ y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp;
140
+ else
141
+ y = (yref > -clamp & yref < clamp) ? y : 0;
142
+ }
143
+
144
+ // Store.
145
+ ((T*)p.y)[xi] = (T)y;
146
+ }
147
+ }
148
+
149
+ //------------------------------------------------------------------------
150
+ // CUDA kernel selection.
151
+
152
+ template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p)
153
+ {
154
+ if (p.act == 1) return (void*)bias_act_kernel<T, 1>;
155
+ if (p.act == 2) return (void*)bias_act_kernel<T, 2>;
156
+ if (p.act == 3) return (void*)bias_act_kernel<T, 3>;
157
+ if (p.act == 4) return (void*)bias_act_kernel<T, 4>;
158
+ if (p.act == 5) return (void*)bias_act_kernel<T, 5>;
159
+ if (p.act == 6) return (void*)bias_act_kernel<T, 6>;
160
+ if (p.act == 7) return (void*)bias_act_kernel<T, 7>;
161
+ if (p.act == 8) return (void*)bias_act_kernel<T, 8>;
162
+ if (p.act == 9) return (void*)bias_act_kernel<T, 9>;
163
+ return NULL;
164
+ }
165
+
166
+ //------------------------------------------------------------------------
167
+ // Template specializations.
168
+
169
+ template void* choose_bias_act_kernel<double> (const bias_act_kernel_params& p);
170
+ template void* choose_bias_act_kernel<float> (const bias_act_kernel_params& p);
171
+ template void* choose_bias_act_kernel<c10::Half> (const bias_act_kernel_params& p);
172
+
173
+ //------------------------------------------------------------------------
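Editor's note: for readability (not part of this commit), a plain-PyTorch restatement of the first-order lrelu gradient branch above (A == 3, G == 1); the saved forward output decides which side of the kink each element is on.

import torch

def lrelu_grad_ref(dy, y_out, alpha=0.2, gain=2 ** 0.5):
    # Slope is 1 where the saved output is positive and alpha elsewhere; the
    # forward gain is applied again, matching `y *= gain * dy` in the kernel.
    slope = torch.where(y_out > 0, torch.ones_like(dy), torch.full_like(dy, alpha))
    return dy * gain * slope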
sg3_torch_utils/ops/bias_act.h ADDED
@@ -0,0 +1,38 @@
1
+ // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ //
3
+ // NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ // and proprietary rights in and to this software, related documentation
5
+ // and any modifications thereto. Any use, reproduction, disclosure or
6
+ // distribution of this software and related documentation without an express
7
+ // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ //------------------------------------------------------------------------
10
+ // CUDA kernel parameters.
11
+
12
+ struct bias_act_kernel_params
13
+ {
14
+ const void* x; // [sizeX]
15
+ const void* b; // [sizeB] or NULL
16
+ const void* xref; // [sizeX] or NULL
17
+ const void* yref; // [sizeX] or NULL
18
+ const void* dy; // [sizeX] or NULL
19
+ void* y; // [sizeX]
20
+
21
+ int grad;
22
+ int act;
23
+ float alpha;
24
+ float gain;
25
+ float clamp;
26
+
27
+ int sizeX;
28
+ int sizeB;
29
+ int stepB;
30
+ int loopX;
31
+ };
32
+
33
+ //------------------------------------------------------------------------
34
+ // CUDA kernel selection.
35
+
36
+ template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p);
37
+
38
+ //------------------------------------------------------------------------
sg3_torch_utils/ops/bias_act.py ADDED
@@ -0,0 +1,215 @@
1
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Custom PyTorch ops for efficient bias and activation."""
10
+
11
+ import os
12
+ import warnings
13
+ import numpy as np
14
+ import torch
15
+ import traceback
16
+
17
+ from .. import custom_ops
18
+ from easydict import EasyDict
19
+ from torch.cuda.amp import custom_bwd, custom_fwd
20
+ #----------------------------------------------------------------------------
21
+
22
+ activation_funcs = {
23
+ 'linear': EasyDict(func=lambda x, **_: x, def_alpha=0, def_gain=1, cuda_idx=1, ref='', has_2nd_grad=False),
24
+ 'relu': EasyDict(func=lambda x, **_: torch.nn.functional.relu(x), def_alpha=0, def_gain=np.sqrt(2), cuda_idx=2, ref='y', has_2nd_grad=False),
25
+ 'lrelu': EasyDict(func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', has_2nd_grad=False),
26
+ 'tanh': EasyDict(func=lambda x, **_: torch.tanh(x), def_alpha=0, def_gain=1, cuda_idx=4, ref='y', has_2nd_grad=True),
27
+ 'sigmoid': EasyDict(func=lambda x, **_: torch.sigmoid(x), def_alpha=0, def_gain=1, cuda_idx=5, ref='y', has_2nd_grad=True),
28
+ 'elu': EasyDict(func=lambda x, **_: torch.nn.functional.elu(x), def_alpha=0, def_gain=1, cuda_idx=6, ref='y', has_2nd_grad=True),
29
+ 'selu': EasyDict(func=lambda x, **_: torch.nn.functional.selu(x), def_alpha=0, def_gain=1, cuda_idx=7, ref='y', has_2nd_grad=True),
30
+ 'softplus': EasyDict(func=lambda x, **_: torch.nn.functional.softplus(x), def_alpha=0, def_gain=1, cuda_idx=8, ref='y', has_2nd_grad=True),
31
+ 'swish': EasyDict(func=lambda x, **_: torch.nn.functional.silu(x), def_alpha=0, def_gain=np.sqrt(2), cuda_idx=9, ref='x', has_2nd_grad=True),
32
+ }
33
+
34
+ #----------------------------------------------------------------------------
35
+
36
+ _inited = False
37
+ _plugin = None
38
+ enabled = False
39
+ _null_tensor = torch.empty([0])
40
+
41
+ def _init():
42
+ global _inited, _plugin
43
+ if not _inited:
44
+ _inited = True
45
+ sources = ['bias_act.cpp', 'bias_act.cu']
46
+ sources = [os.path.join(os.path.dirname(__file__), s) for s in sources]
47
+ try:
48
+ _plugin = custom_ops.get_plugin('bias_act_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math'])
49
+ except:
50
+ warnings.warn('Failed to build CUDA kernels for bias_act. Falling back to slow reference implementation. Details:\n\n' + traceback.format_exc())
51
+ return _plugin is not None
52
+
53
+ #----------------------------------------------------------------------------
54
+
55
+ def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='cuda'):
56
+ r"""Fused bias and activation function.
57
+
58
+ Adds bias `b` to activation tensor `x`, evaluates activation function `act`,
59
+ and scales the result by `gain`. Each of the steps is optional. In most cases,
60
+ the fused op is considerably more efficient than performing the same calculation
61
+ using standard PyTorch ops. It supports first and second order gradients,
62
+ but not third order gradients.
63
+
64
+ Args:
65
+ x: Input activation tensor. Can be of any shape.
66
+ b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type
67
+ as `x`. The shape must be known, and it must match the dimension of `x`
68
+ corresponding to `dim`.
69
+ dim: The dimension in `x` corresponding to the elements of `b`.
70
+ The value of `dim` is ignored if `b` is not specified.
71
+ act: Name of the activation function to evaluate, or `"linear"` to disable.
72
+ Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc.
73
+ See `activation_funcs` for a full list. `None` is not allowed.
74
+ alpha: Shape parameter for the activation function, or `None` to use the default.
75
+ gain: Scaling factor for the output tensor, or `None` to use default.
76
+ See `activation_funcs` for the default scaling of each activation function.
77
+ If unsure, consider specifying 1.
78
+ clamp: Clamp the output values to `[-clamp, +clamp]`, or `None` to disable
79
+ the clamping (default).
80
+ impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).
81
+
82
+ Returns:
83
+ Tensor of the same shape and datatype as `x`.
84
+ """
85
+ assert isinstance(x, torch.Tensor)
86
+ assert impl in ['ref', 'cuda']
87
+ if impl == 'cuda' and x.device.type == 'cuda' and enabled and _init():
88
+ return _bias_act_cuda(dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp).apply(x, b)
89
+ return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp)
90
+
91
+ #----------------------------------------------------------------------------
92
+
93
+ def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None):
94
+ """Slow reference implementation of `bias_act()` using standard PyTorch ops.
95
+ """
96
+ assert isinstance(x, torch.Tensor)
97
+ assert clamp is None or clamp >= 0
98
+ spec = activation_funcs[act]
99
+ alpha = float(alpha if alpha is not None else spec.def_alpha)
100
+ gain = float(gain if gain is not None else spec.def_gain)
101
+ clamp = float(clamp if clamp is not None else -1)
102
+
103
+ # Add bias.
104
+ if b is not None:
105
+ assert isinstance(b, torch.Tensor) and b.ndim == 1
106
+ assert 0 <= dim < x.ndim
107
+ assert b.shape[0] == x.shape[dim]
108
+ x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)])
109
+
110
+ # Evaluate activation function.
111
+ alpha = float(alpha)
112
+ x = spec.func(x, alpha=alpha)
113
+
114
+ # Scale by gain.
115
+ gain = float(gain)
116
+ if gain != 1:
117
+ x = x * gain
118
+
119
+ # Clamp.
120
+ if clamp >= 0:
121
+ x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type
122
+ return x
123
+
124
+ #----------------------------------------------------------------------------
125
+
126
+ _bias_act_cuda_cache = dict()
127
+
128
+ def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None):
129
+ """Fast CUDA implementation of `bias_act()` using custom ops.
130
+ """
131
+ # Parse arguments.
132
+ assert clamp is None or clamp >= 0
133
+ spec = activation_funcs[act]
134
+ alpha = float(alpha if alpha is not None else spec.def_alpha)
135
+ gain = float(gain if gain is not None else spec.def_gain)
136
+ clamp = float(clamp if clamp is not None else -1)
137
+
138
+ # Lookup from cache.
139
+ key = (dim, act, alpha, gain, clamp)
140
+ if key in _bias_act_cuda_cache:
141
+ return _bias_act_cuda_cache[key]
142
+
143
+ # Forward op.
144
+ class BiasActCuda(torch.autograd.Function):
145
+ @staticmethod
146
+ @custom_fwd(cast_inputs=torch.float16)
147
+ def forward(ctx, x, b): # pylint: disable=arguments-differ
148
+ ctx.memory_format = torch.channels_last if x.ndim > 2 and x.stride()[1] == 1 else torch.contiguous_format
149
+ x = x.contiguous(memory_format=ctx.memory_format)
150
+ b = b.contiguous() if b is not None else _null_tensor
151
+ y = x
152
+ if act != 'linear' or gain != 1 or clamp >= 0 or b is not _null_tensor:
153
+ y = _plugin.bias_act(x, b, _null_tensor, _null_tensor, _null_tensor, 0, dim, spec.cuda_idx, alpha, gain, clamp)
154
+ ctx.save_for_backward(
155
+ x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
156
+ b if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
157
+ y if 'y' in spec.ref else _null_tensor)
158
+ return y
159
+
160
+ @staticmethod
161
+ @custom_bwd
162
+ def backward(ctx, dy): # pylint: disable=arguments-differ
163
+ dy = dy.contiguous(memory_format=ctx.memory_format)
164
+ x, b, y = ctx.saved_tensors
165
+ dx = None
166
+ db = None
167
+
168
+ if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
169
+ dx = dy
170
+ if act != 'linear' or gain != 1 or clamp >= 0:
171
+ dx = BiasActCudaGrad.apply(dy, x, b, y)
172
+
173
+ if ctx.needs_input_grad[1]:
174
+ db = dx.sum([i for i in range(dx.ndim) if i != dim])
175
+
176
+ return dx, db
177
+
178
+ # Backward op.
179
+ class BiasActCudaGrad(torch.autograd.Function):
180
+ @staticmethod
181
+ @custom_fwd(cast_inputs=torch.float16)
182
+ def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ
183
+ ctx.memory_format = torch.channels_last if dy.ndim > 2 and dy.stride()[1] == 1 else torch.contiguous_format
184
+ dx = _plugin.bias_act(dy, b, x, y, _null_tensor, 1, dim, spec.cuda_idx, alpha, gain, clamp)
185
+ ctx.save_for_backward(
186
+ dy if spec.has_2nd_grad else _null_tensor,
187
+ x, b, y)
188
+ return dx
189
+
190
+ @staticmethod
191
+ @custom_bwd
192
+ def backward(ctx, d_dx): # pylint: disable=arguments-differ
193
+ d_dx = d_dx.contiguous(memory_format=ctx.memory_format)
194
+ dy, x, b, y = ctx.saved_tensors
195
+ d_dy = None
196
+ d_x = None
197
+ d_b = None
198
+ d_y = None
199
+
200
+ if ctx.needs_input_grad[0]:
201
+ d_dy = BiasActCudaGrad.apply(d_dx, x, b, y)
202
+
203
+ if spec.has_2nd_grad and (ctx.needs_input_grad[1] or ctx.needs_input_grad[2]):
204
+ d_x = _plugin.bias_act(d_dx, b, x, y, dy, 2, dim, spec.cuda_idx, alpha, gain, clamp)
205
+
206
+ if spec.has_2nd_grad and ctx.needs_input_grad[2]:
207
+ d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim])
208
+
209
+ return d_dy, d_x, d_b, d_y
210
+
211
+ # Add to cache.
212
+ _bias_act_cuda_cache[key] = BiasActCuda
213
+ return BiasActCuda
214
+
215
+ #----------------------------------------------------------------------------
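Editor's note: a short usage sketch (not part of this commit). Since `enabled` defaults to False in this module, the call below takes the reference path and also runs on CPU.

import torch
from sg3_torch_utils.ops.bias_act import bias_act

x = torch.randn(8, 512, 16, 16)
b = torch.zeros(512)
y = bias_act(x, b, dim=1, act='lrelu')   # fused bias + leaky ReLU, default gain sqrt(2)
assert y.shape == x.shape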
sg3_torch_utils/ops/conv2d_gradfix.py ADDED
@@ -0,0 +1,175 @@
1
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Custom replacement for `torch.nn.functional.conv2d` that supports
10
+ arbitrarily high order gradients with zero performance penalty."""
11
+
12
+ import warnings
13
+ import contextlib
14
+ import torch
15
+ from torch.cuda.amp import custom_bwd, custom_fwd
16
+
17
+ # pylint: disable=redefined-builtin
18
+ # pylint: disable=arguments-differ
19
+ # pylint: disable=protected-access
20
+
21
+ #----------------------------------------------------------------------------
22
+
23
+ enabled = False # Enable the custom op by setting this to true.
24
+ weight_gradients_disabled = False # Forcefully disable computation of gradients with respect to the weights.
25
+
26
+ @contextlib.contextmanager
27
+ def no_weight_gradients():
28
+ global weight_gradients_disabled
29
+ old = weight_gradients_disabled
30
+ weight_gradients_disabled = True
31
+ yield
32
+ weight_gradients_disabled = old
33
+
34
+ #----------------------------------------------------------------------------
35
+
36
+ def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
37
+ if _should_use_custom_op(input):
38
+ return _conv2d_gradfix(transpose=False, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=0, dilation=dilation, groups=groups).apply(input, weight, bias)
39
+ return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
40
+
41
+ def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
42
+ if _should_use_custom_op(input):
43
+ return _conv2d_gradfix(transpose=True, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation).apply(input, weight, bias)
44
+ return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation)
45
+
46
+ #----------------------------------------------------------------------------
47
+
48
+ def _should_use_custom_op(input):
49
+ assert isinstance(input, torch.Tensor)
50
+ if (not enabled) or (not torch.backends.cudnn.enabled):
51
+ return False
52
+ if input.device.type != 'cuda':
53
+ return False
54
+ if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9', '1.10']):
55
+ return True
56
+ warnings.warn(f'conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d().')
57
+ return False
58
+
59
+ def _tuple_of_ints(xs, ndim):
60
+ xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim
61
+ assert len(xs) == ndim
62
+ assert all(isinstance(x, int) for x in xs)
63
+ return xs
64
+
65
+ #----------------------------------------------------------------------------
66
+
67
+ _conv2d_gradfix_cache = dict()
68
+
69
+ def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, dilation, groups):
70
+ # Parse arguments.
71
+ ndim = 2
72
+ weight_shape = tuple(weight_shape)
73
+ stride = _tuple_of_ints(stride, ndim)
74
+ padding = _tuple_of_ints(padding, ndim)
75
+ output_padding = _tuple_of_ints(output_padding, ndim)
76
+ dilation = _tuple_of_ints(dilation, ndim)
77
+
78
+ # Lookup from cache.
79
+ key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups)
80
+ if key in _conv2d_gradfix_cache:
81
+ return _conv2d_gradfix_cache[key]
82
+
83
+ # Validate arguments.
84
+ assert groups >= 1
85
+ assert len(weight_shape) == ndim + 2
86
+ assert all(stride[i] >= 1 for i in range(ndim))
87
+ assert all(padding[i] >= 0 for i in range(ndim))
88
+ assert all(dilation[i] >= 0 for i in range(ndim))
89
+ if not transpose:
90
+ assert all(output_padding[i] == 0 for i in range(ndim))
91
+ else: # transpose
92
+ assert all(0 <= output_padding[i] < max(stride[i], dilation[i]) for i in range(ndim))
93
+
94
+ # Helpers.
95
+ common_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups)
96
+ def calc_output_padding(input_shape, output_shape):
97
+ if transpose:
98
+ return [0, 0]
99
+ return [
100
+ input_shape[i + 2]
101
+ - (output_shape[i + 2] - 1) * stride[i]
102
+ - (1 - 2 * padding[i])
103
+ - dilation[i] * (weight_shape[i + 2] - 1)
104
+ for i in range(ndim)
105
+ ]
106
+
107
+ # Forward & backward.
108
+ class Conv2d(torch.autograd.Function):
109
+ @staticmethod
110
+ @custom_fwd(cast_inputs=torch.float16)
111
+ def forward(ctx, input, weight, bias):
112
+ assert weight.shape == weight_shape
113
+ if not transpose:
114
+ output = torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, **common_kwargs)
115
+ else: # transpose
116
+ output = torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, output_padding=output_padding, **common_kwargs)
117
+ ctx.save_for_backward(input, weight)
118
+ return output
119
+
120
+ @staticmethod
121
+ @custom_bwd
122
+ def backward(ctx, grad_output):
123
+ input, weight = ctx.saved_tensors
124
+ grad_input = None
125
+ grad_weight = None
126
+ grad_bias = None
127
+
128
+ if ctx.needs_input_grad[0]:
129
+ p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape)
130
+ grad_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output.float(), weight.float(), None)
131
+ assert grad_input.shape == input.shape
132
+
133
+ if ctx.needs_input_grad[1] and not weight_gradients_disabled:
134
+ grad_weight = Conv2dGradWeight.apply(grad_output.float(), input.float())
135
+ assert grad_weight.shape == weight_shape
136
+
137
+ if ctx.needs_input_grad[2]:
138
+ grad_bias = grad_output.float().sum([0, 2, 3])
139
+
140
+ return grad_input, grad_weight, grad_bias
141
+
142
+ # Gradient with respect to the weights.
143
+ class Conv2dGradWeight(torch.autograd.Function):
144
+ @staticmethod
145
+ @custom_fwd(cast_inputs=torch.float16)
146
+ def forward(ctx, grad_output, input):
147
+ op = torch._C._jit_get_operation('aten::cudnn_convolution_backward_weight' if not transpose else 'aten::cudnn_convolution_transpose_backward_weight')
148
+ flags = [torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic, torch.backends.cudnn.allow_tf32]
149
+ grad_weight = op(weight_shape, grad_output, input, padding, stride, dilation, groups, *flags)
150
+ assert grad_weight.shape == weight_shape
151
+ ctx.save_for_backward(grad_output, input)
152
+ return grad_weight
153
+
154
+ @staticmethod
155
+ @custom_bwd
156
+ def backward(ctx, grad2_grad_weight):
157
+ grad_output, input = ctx.saved_tensors
158
+ grad2_grad_output = None
159
+ grad2_input = None
160
+
161
+ if ctx.needs_input_grad[0]:
162
+ grad2_grad_output = Conv2d.apply(input, grad2_grad_weight, None)
163
+ assert grad2_grad_output.shape == grad_output.shape
164
+
165
+ if ctx.needs_input_grad[1]:
166
+ p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape)
167
+ grad2_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, grad2_grad_weight, None)
168
+ assert grad2_input.shape == input.shape
169
+
170
+ return grad2_grad_output, grad2_input
171
+
172
+ _conv2d_gradfix_cache[key] = Conv2d
173
+ return Conv2d
174
+
175
+ #----------------------------------------------------------------------------
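Editor's note: a hedged sketch (not part of this commit) of the intended use of `no_weight_gradients()`: computing an input-gradient penalty without spending work on weight gradients. `D` and `images` are placeholders for a convolutional discriminator and a batch of images.

import torch
from sg3_torch_utils.ops.conv2d_gradfix import no_weight_gradients

def r1_penalty(D, images):
    images = images.detach().requires_grad_(True)
    scores = D(images)
    # Convolutions executed inside this context skip grad-of-weight computation.
    with no_weight_gradients():
        grads, = torch.autograd.grad(scores.sum(), images, create_graph=True)
    return grads.square().sum(dim=[1, 2, 3]).mean()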
sg3_torch_utils/ops/conv2d_resample.py ADDED
@@ -0,0 +1,142 @@
1
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """2D convolution with optional up/downsampling."""
10
+
11
+ import torch
12
+
13
+ from .. import misc
14
+ from . import conv2d_gradfix
15
+ from . import upfirdn2d
16
+ from .upfirdn2d import _parse_padding
17
+ from .upfirdn2d import _get_filter_size
18
+
19
+ #----------------------------------------------------------------------------
20
+
21
+ def _get_weight_shape(w):
22
+ with misc.suppress_tracer_warnings(): # this value will be treated as a constant
23
+ shape = [int(sz) for sz in w.shape]
24
+ misc.assert_shape(w, shape)
25
+ return shape
26
+
27
+ #----------------------------------------------------------------------------
28
+
29
+ def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True):
30
+ """Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations.
31
+ """
32
+ out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)
33
+
34
+ # Flip weight if requested.
35
+ if not flip_weight: # conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False).
36
+ w = w.flip([2, 3])
37
+
38
+ # Otherwise => execute using conv2d_gradfix.
39
+ op = conv2d_gradfix.conv_transpose2d if transpose else conv2d_gradfix.conv2d
40
+ return op(x, w, stride=stride, padding=padding, groups=groups)
41
+
42
+ #----------------------------------------------------------------------------
43
+
44
+ @misc.profiled_function
45
+ def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False):
46
+ r"""2D convolution with optional up/downsampling.
47
+
48
+ Padding is performed only once at the beginning, not between the operations.
49
+
50
+ Args:
51
+ x: Input tensor of shape
52
+ `[batch_size, in_channels, in_height, in_width]`.
53
+ w: Weight tensor of shape
54
+ `[out_channels, in_channels//groups, kernel_height, kernel_width]`.
55
+ f: Low-pass filter for up/downsampling. Must be prepared beforehand by
56
+ calling upfirdn2d.setup_filter(). None = identity (default).
57
+ up: Integer upsampling factor (default: 1).
58
+ down: Integer downsampling factor (default: 1).
59
+ padding: Padding with respect to the upsampled image. Can be a single number
60
+ or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
61
+ (default: 0).
62
+ groups: Split input channels into N groups (default: 1).
63
+ flip_weight: False = convolution, True = correlation (default: True).
64
+ flip_filter: False = convolution, True = correlation (default: False).
65
+
66
+ Returns:
67
+ Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
68
+ """
69
+ # Validate arguments.
70
+ assert isinstance(x, torch.Tensor) and (x.ndim == 4)
71
+ assert isinstance(w, torch.Tensor) and (w.ndim == 4)
72
+ assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32)
73
+ assert isinstance(up, int) and (up >= 1)
74
+ assert isinstance(down, int) and (down >= 1)
75
+ assert isinstance(groups, int) and (groups >= 1)
76
+ out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)
77
+ fw, fh = _get_filter_size(f)
78
+ px0, px1, py0, py1 = _parse_padding(padding)
79
+
80
+ # Adjust padding to account for up/downsampling.
81
+ if up > 1:
82
+ px0 += (fw + up - 1) // 2
83
+ px1 += (fw - up) // 2
84
+ py0 += (fh + up - 1) // 2
85
+ py1 += (fh - up) // 2
86
+ if down > 1:
87
+ px0 += (fw - down + 1) // 2
88
+ px1 += (fw - down) // 2
89
+ py0 += (fh - down + 1) // 2
90
+ py1 += (fh - down) // 2
91
+
92
+ # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve.
93
+ if kw == 1 and kh == 1 and (down > 1 and up == 1):
94
+ x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
95
+ x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
96
+ return x
97
+
98
+ # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample.
99
+ if kw == 1 and kh == 1 and (up > 1 and down == 1):
100
+ x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
101
+ x = upfirdn2d.upfirdn2d(x=x, f=f, up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
102
+ return x
103
+
104
+ # Fast path: downsampling only => use strided convolution.
105
+ if down > 1 and up == 1:
106
+ x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
107
+ x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight)
108
+ return x
109
+
110
+ # Fast path: upsampling with optional downsampling => use transpose strided convolution.
111
+ if up > 1:
112
+ if groups == 1:
113
+ w = w.transpose(0, 1)
114
+ else:
115
+ w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw)
116
+ w = w.transpose(1, 2)
117
+ w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw)
118
+ px0 -= kw - 1
119
+ px1 -= kw - up
120
+ py0 -= kh - 1
121
+ py1 -= kh - up
122
+ pxt = max(min(-px0, -px1), 0)
123
+ pyt = max(min(-py0, -py1), 0)
124
+ x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt,pxt], groups=groups, transpose=True, flip_weight=(not flip_weight))
125
+ x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0+pxt,px1+pxt,py0+pyt,py1+pyt], gain=up**2, flip_filter=flip_filter)
126
+ if down > 1:
127
+ x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
128
+ return x
129
+
130
+ # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d.
131
+ if up == 1 and down == 1:
132
+ if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0:
133
+ return _conv2d_wrapper(x=x, w=w, padding=[py0,px0], groups=groups, flip_weight=flip_weight)
134
+
135
+ # Fallback: Generic reference implementation.
136
+ x = upfirdn2d.upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
137
+ x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
138
+ if down > 1:
139
+ x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
140
+ return x
141
+
142
+ #----------------------------------------------------------------------------
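Editor's note: a minimal usage sketch (not part of this commit); the import paths assume the layout of this commit, and all tensors are illustrative.

import torch
from sg3_torch_utils.ops.conv2d_resample import conv2d_resample
from sg3_torch_utils.ops.upfirdn2d import setup_filter

x = torch.randn(1, 64, 32, 32)
w = torch.randn(128, 64, 3, 3)
f = setup_filter([1, 3, 3, 1])                       # low-pass filter for the resampling step
y = conv2d_resample(x=x, w=w, f=f, up=2, padding=1)  # 2x upsample, then 3x3 conv
print(y.shape)                                       # expected: torch.Size([1, 128, 64, 64])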
sg3_torch_utils/ops/fma.py ADDED
@@ -0,0 +1,63 @@
1
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Fused multiply-add, with slightly faster gradients than `torch.addcmul()`."""
10
+
11
+ import torch
12
+ from torch.cuda.amp import custom_bwd, custom_fwd
13
+
14
+ #----------------------------------------------------------------------------
15
+
16
+ def fma(a, b, c): # => a * b + c
17
+ return _FusedMultiplyAdd.apply(a, b, c)
18
+
19
+ #----------------------------------------------------------------------------
20
+
21
+ class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c
22
+ @staticmethod
23
+ @custom_fwd(cast_inputs=torch.float16)
24
+ def forward(ctx, a, b, c): # pylint: disable=arguments-differ
25
+ out = torch.addcmul(c, a, b)
26
+ ctx.save_for_backward(a, b)
27
+ ctx.c_shape = c.shape
28
+ return out
29
+
30
+ @staticmethod
31
+ @custom_bwd
32
+ def backward(ctx, dout): # pylint: disable=arguments-differ
33
+ a, b = ctx.saved_tensors
34
+ c_shape = ctx.c_shape
35
+ da = None
36
+ db = None
37
+ dc = None
38
+
39
+ if ctx.needs_input_grad[0]:
40
+ da = _unbroadcast(dout * b, a.shape)
41
+
42
+ if ctx.needs_input_grad[1]:
43
+ db = _unbroadcast(dout * a, b.shape)
44
+
45
+ if ctx.needs_input_grad[2]:
46
+ dc = _unbroadcast(dout, c_shape)
47
+
48
+ return da, db, dc
49
+
50
+ #----------------------------------------------------------------------------
51
+
52
+ def _unbroadcast(x, shape):
53
+ extra_dims = x.ndim - len(shape)
54
+ assert extra_dims >= 0
55
+ dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)]
56
+ if len(dim):
57
+ x = x.sum(dim=dim, keepdim=True)
58
+ if extra_dims:
59
+ x = x.reshape(-1, *x.shape[extra_dims+1:])
60
+ assert x.shape == shape
61
+ return x
62
+
63
+ #----------------------------------------------------------------------------
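Editor's note: a small check (not part of this commit) of what `fma()` computes and how `_unbroadcast` folds gradients back to each operand's broadcast shape.

import torch
from sg3_torch_utils.ops.fma import fma

a = torch.randn(4, 1, 8, requires_grad=True)
b = torch.randn(1, 3, 8, requires_grad=True)
c = torch.randn(8, requires_grad=True)
out = fma(a, b, c)                        # same value as a * b + c
assert torch.allclose(out, a * b + c)
out.sum().backward()
# Gradients come back in each operand's own (pre-broadcast) shape.
assert a.grad.shape == a.shape and b.grad.shape == b.shape and c.grad.shape == c.shape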
sg3_torch_utils/ops/grid_sample_gradfix.py ADDED
@@ -0,0 +1,88 @@
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Custom replacement for `torch.nn.functional.grid_sample` that
10
+ supports arbitrarily high order gradients between the input and output.
11
+ Only works on 2D images and assumes
12
+ `mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`."""
13
+
14
+ import torch
15
+ from torch.cuda.amp import custom_bwd, custom_fwd
16
+ from pkg_resources import parse_version
17
+ # pylint: disable=redefined-builtin
18
+ # pylint: disable=arguments-differ
19
+ # pylint: disable=protected-access
20
+ _use_pytorch_1_11_api = parse_version(torch.__version__) >= parse_version('1.11.0a') # Allow prerelease builds of 1.11
21
+
22
+
23
+ #----------------------------------------------------------------------------
24
+
25
+ enabled = False # Enable the custom op by setting this to true.
26
+
27
+ #----------------------------------------------------------------------------
28
+
29
+ def grid_sample(input, grid):
30
+ if _should_use_custom_op():
31
+ return _GridSample2dForward.apply(input, grid)
32
+ return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
33
+
34
+ #----------------------------------------------------------------------------
35
+
36
+ def _should_use_custom_op():
37
+ return enabled
38
+
39
+ #----------------------------------------------------------------------------
40
+
41
+ class _GridSample2dForward(torch.autograd.Function):
42
+ @staticmethod
43
+ @custom_fwd(cast_inputs=torch.float16)
44
+ def forward(ctx, input, grid):
45
+ assert input.ndim == 4
46
+ assert grid.ndim == 4
47
+ output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
48
+ ctx.save_for_backward(input, grid)
49
+ return output
50
+
51
+ @staticmethod
52
+ @custom_bwd
53
+ def backward(ctx, grad_output):
54
+ input, grid = ctx.saved_tensors
55
+ grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid)
56
+ return grad_input, grad_grid
57
+
58
+ #----------------------------------------------------------------------------
59
+
60
+ class _GridSample2dBackward(torch.autograd.Function):
61
+ @staticmethod
62
+ @custom_fwd(cast_inputs=torch.float16)
63
+ def forward(ctx, grad_output, input, grid):
64
+ op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward')
65
+ if _use_pytorch_1_11_api:
66
+ output_mask = (ctx.needs_input_grad[1], ctx.needs_input_grad[2])
67
+ grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False, output_mask)
68
+ else:
69
+ grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False)
70
+ ctx.save_for_backward(grid)
71
+ return grad_input, grad_grid
72
+
73
+ @staticmethod
74
+ @custom_bwd
75
+ def backward(ctx, grad2_grad_input, grad2_grad_grid):
76
+ _ = grad2_grad_grid # unused
77
+ grid, = ctx.saved_tensors
78
+ grad2_grad_output = None
79
+ grad2_input = None
80
+ grad2_grid = None
81
+
82
+ if ctx.needs_input_grad[0]:
83
+ grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid)
84
+
85
+ assert not ctx.needs_input_grad[2]
86
+ return grad2_grad_output, grad2_input, grad2_grid
87
+
88
+ #----------------------------------------------------------------------------
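Editor's note: a small sanity-check sketch (not part of this commit). With `enabled` left at False this takes the standard F.grid_sample path, pinned to bilinear sampling, zeros padding, and align_corners=False.

import torch
from sg3_torch_utils.ops.grid_sample_gradfix import grid_sample

images = torch.randn(2, 3, 64, 64)
theta = torch.eye(2, 3).unsqueeze(0).repeat(2, 1, 1)   # identity affine transform
grid = torch.nn.functional.affine_grid(theta, size=(2, 3, 64, 64), align_corners=False)
out = grid_sample(images, grid)                        # samples at the original pixel centres
assert torch.allclose(out, images, atol=1e-5)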
sg3_torch_utils/ops/upfirdn2d.cpp ADDED
@@ -0,0 +1,103 @@
1
+ // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ //
3
+ // NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ // and proprietary rights in and to this software, related documentation
5
+ // and any modifications thereto. Any use, reproduction, disclosure or
6
+ // distribution of this software and related documentation without an express
7
+ // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ #include <torch/extension.h>
10
+ #include <ATen/cuda/CUDAContext.h>
11
+ #include <c10/cuda/CUDAGuard.h>
12
+ #include "upfirdn2d.h"
13
+
14
+ //------------------------------------------------------------------------
15
+
16
+ static torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain)
17
+ {
18
+ // Validate arguments.
19
+ TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
20
+ TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x");
21
+ TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32");
22
+ TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
23
+ TORCH_CHECK(f.numel() <= INT_MAX, "f is too large");
24
+ TORCH_CHECK(x.dim() == 4, "x must be rank 4");
25
+ TORCH_CHECK(f.dim() == 2, "f must be rank 2");
26
+ TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1");
27
+ TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1");
28
+ TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1");
29
+
30
+ // Create output tensor.
31
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
32
+ int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx;
33
+ int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy;
34
+ TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1");
35
+ torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format());
36
+ TORCH_CHECK(y.numel() <= INT_MAX, "output is too large");
37
+
38
+ // Initialize CUDA kernel parameters.
39
+ upfirdn2d_kernel_params p;
40
+ p.x = x.data_ptr();
41
+ p.f = f.data_ptr<float>();
42
+ p.y = y.data_ptr();
43
+ p.up = make_int2(upx, upy);
44
+ p.down = make_int2(downx, downy);
45
+ p.pad0 = make_int2(padx0, pady0);
46
+ p.flip = (flip) ? 1 : 0;
47
+ p.gain = gain;
48
+ p.inSize = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
49
+ p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0));
50
+ p.filterSize = make_int2((int)f.size(1), (int)f.size(0));
51
+ p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0));
52
+ p.outSize = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0));
53
+ p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0));
54
+ p.sizeMajor = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z;
55
+ p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1;
56
+
57
+ // Choose CUDA kernel.
58
+ upfirdn2d_kernel_spec spec;
59
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
60
+ {
61
+ spec = choose_upfirdn2d_kernel<scalar_t>(p);
62
+ });
63
+
64
+ // Set looping options.
65
+ p.loopMajor = (p.sizeMajor - 1) / 16384 + 1;
66
+ p.loopMinor = spec.loopMinor;
67
+ p.loopX = spec.loopX;
68
+ p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1;
69
+ p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1;
70
+
71
+ // Compute grid size.
72
+ dim3 blockSize, gridSize;
73
+ if (spec.tileOutW < 0) // large
74
+ {
75
+ blockSize = dim3(4, 32, 1);
76
+ gridSize = dim3(
77
+ ((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor,
78
+ (p.outSize.x - 1) / (blockSize.y * p.loopX) + 1,
79
+ p.launchMajor);
80
+ }
81
+ else // small
82
+ {
83
+ blockSize = dim3(256, 1, 1);
84
+ gridSize = dim3(
85
+ ((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor,
86
+ (p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1,
87
+ p.launchMajor);
88
+ }
89
+
90
+ // Launch CUDA kernel.
91
+ void* args[] = {&p};
92
+ AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
93
+ return y;
94
+ }
95
+
96
+ //------------------------------------------------------------------------
97
+
98
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
99
+ {
100
+ m.def("upfirdn2d", &upfirdn2d);
101
+ }
102
+
103
+ //------------------------------------------------------------------------
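Editor's note: the output size computed above follows the standard upfirdn relation; a quick Python restatement (not part of this commit).

def upfirdn2d_out_size(in_w, in_h, f_w, f_h, up=1, down=1, pad=(0, 0, 0, 0)):
    # Mirrors the C++ binding: upsample by `up`, pad, convolve with the
    # f_w x f_h filter, then keep every `down`-th sample.
    padx0, padx1, pady0, pady1 = pad
    out_w = (in_w * up + padx0 + padx1 - f_w + down) // down
    out_h = (in_h * up + pady0 + pady1 - f_h + down) // down
    return out_w, out_h

# A 4-tap filter with up=2 and padding (2, 1, 2, 1) doubles a 32x32 input:
# upfirdn2d_out_size(32, 32, 4, 4, up=2, pad=(2, 1, 2, 1)) -> (64, 64)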
sg3_torch_utils/ops/upfirdn2d.cu ADDED
@@ -0,0 +1,350 @@
1
+ // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ //
3
+ // NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ // and proprietary rights in and to this software, related documentation
5
+ // and any modifications thereto. Any use, reproduction, disclosure or
6
+ // distribution of this software and related documentation without an express
7
+ // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ #include <c10/util/Half.h>
10
+ #include "upfirdn2d.h"
11
+
12
+ //------------------------------------------------------------------------
13
+ // Helpers.
14
+
15
+ template <class T> struct InternalType;
16
+ template <> struct InternalType<double> { typedef double scalar_t; };
17
+ template <> struct InternalType<float> { typedef float scalar_t; };
18
+ template <> struct InternalType<c10::Half> { typedef float scalar_t; };
19
+
20
+ static __device__ __forceinline__ int floor_div(int a, int b)
21
+ {
22
+ int t = 1 - a / b;
23
+ return (a + t * b) / b - t;
24
+ }
25
+
26
+ //------------------------------------------------------------------------
27
+ // Generic CUDA implementation for large filters.
28
+
29
+ template <class T> static __global__ void upfirdn2d_kernel_large(upfirdn2d_kernel_params p)
30
+ {
31
+ typedef typename InternalType<T>::scalar_t scalar_t;
32
+
33
+ // Calculate thread index.
34
+ int minorBase = blockIdx.x * blockDim.x + threadIdx.x;
35
+ int outY = minorBase / p.launchMinor;
36
+ minorBase -= outY * p.launchMinor;
37
+ int outXBase = blockIdx.y * p.loopX * blockDim.y + threadIdx.y;
38
+ int majorBase = blockIdx.z * p.loopMajor;
39
+ if (outXBase >= p.outSize.x | outY >= p.outSize.y | majorBase >= p.sizeMajor)
40
+ return;
41
+
42
+ // Setup Y receptive field.
43
+ int midY = outY * p.down.y + p.up.y - 1 - p.pad0.y;
44
+ int inY = min(max(floor_div(midY, p.up.y), 0), p.inSize.y);
45
+ int h = min(max(floor_div(midY + p.filterSize.y, p.up.y), 0), p.inSize.y) - inY;
46
+ int filterY = midY + p.filterSize.y - (inY + 1) * p.up.y;
47
+ if (p.flip)
48
+ filterY = p.filterSize.y - 1 - filterY;
49
+
50
+ // Loop over major, minor, and X.
51
+ for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++)
52
+ for (int minorIdx = 0, minor = minorBase; minorIdx < p.loopMinor & minor < p.sizeMinor; minorIdx++, minor += p.launchMinor)
53
+ {
54
+ int nc = major * p.sizeMinor + minor;
55
+ int n = nc / p.inSize.z;
56
+ int c = nc - n * p.inSize.z;
57
+ for (int loopX = 0, outX = outXBase; loopX < p.loopX & outX < p.outSize.x; loopX++, outX += blockDim.y)
58
+ {
59
+ // Setup X receptive field.
60
+ int midX = outX * p.down.x + p.up.x - 1 - p.pad0.x;
61
+ int inX = min(max(floor_div(midX, p.up.x), 0), p.inSize.x);
62
+ int w = min(max(floor_div(midX + p.filterSize.x, p.up.x), 0), p.inSize.x) - inX;
63
+ int filterX = midX + p.filterSize.x - (inX + 1) * p.up.x;
64
+ if (p.flip)
65
+ filterX = p.filterSize.x - 1 - filterX;
66
+
67
+ // Initialize pointers.
68
+ const T* xp = &((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w];
69
+ const float* fp = &p.f[filterX * p.filterStride.x + filterY * p.filterStride.y];
70
+ int filterStepX = ((p.flip) ? p.up.x : -p.up.x) * p.filterStride.x;
71
+ int filterStepY = ((p.flip) ? p.up.y : -p.up.y) * p.filterStride.y;
72
+
73
+ // Inner loop.
74
+ scalar_t v = 0;
75
+ for (int y = 0; y < h; y++)
76
+ {
77
+ for (int x = 0; x < w; x++)
78
+ {
79
+ v += (scalar_t)(*xp) * (scalar_t)(*fp);
80
+ xp += p.inStride.x;
81
+ fp += filterStepX;
82
+ }
83
+ xp += p.inStride.y - w * p.inStride.x;
84
+ fp += filterStepY - w * filterStepX;
85
+ }
86
+
87
+ // Store result.
88
+ v *= p.gain;
89
+ ((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v;
90
+ }
91
+ }
92
+ }
93
+
94
+ //------------------------------------------------------------------------
95
+ // Specialized CUDA implementation for small filters.
96
+
97
+ template <class T, int upx, int upy, int downx, int downy, int filterW, int filterH, int tileOutW, int tileOutH, int loopMinor>
98
+ static __global__ void upfirdn2d_kernel_small(upfirdn2d_kernel_params p)
99
+ {
100
+ typedef typename InternalType<T>::scalar_t scalar_t;
101
+ const int tileInW = ((tileOutW - 1) * downx + filterW - 1) / upx + 1;
102
+ const int tileInH = ((tileOutH - 1) * downy + filterH - 1) / upy + 1;
103
+ __shared__ volatile scalar_t sf[filterH][filterW];
104
+ __shared__ volatile scalar_t sx[tileInH][tileInW][loopMinor];
105
+
106
+ // Calculate tile index.
107
+ int minorBase = blockIdx.x;
108
+ int tileOutY = minorBase / p.launchMinor;
109
+ minorBase -= tileOutY * p.launchMinor;
110
+ minorBase *= loopMinor;
111
+ tileOutY *= tileOutH;
112
+ int tileOutXBase = blockIdx.y * p.loopX * tileOutW;
113
+ int majorBase = blockIdx.z * p.loopMajor;
114
+ if (tileOutXBase >= p.outSize.x | tileOutY >= p.outSize.y | majorBase >= p.sizeMajor)
115
+ return;
116
+
117
+ // Load filter (flipped).
118
+ for (int tapIdx = threadIdx.x; tapIdx < filterH * filterW; tapIdx += blockDim.x)
119
+ {
120
+ int fy = tapIdx / filterW;
121
+ int fx = tapIdx - fy * filterW;
122
+ scalar_t v = 0;
123
+ if (fx < p.filterSize.x & fy < p.filterSize.y)
124
+ {
125
+ int ffx = (p.flip) ? fx : p.filterSize.x - 1 - fx;
126
+ int ffy = (p.flip) ? fy : p.filterSize.y - 1 - fy;
127
+ v = (scalar_t)p.f[ffx * p.filterStride.x + ffy * p.filterStride.y];
128
+ }
129
+ sf[fy][fx] = v;
130
+ }
131
+
132
+ // Loop over major and X.
133
+ for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++)
134
+ {
135
+ int baseNC = major * p.sizeMinor + minorBase;
136
+ int n = baseNC / p.inSize.z;
137
+ int baseC = baseNC - n * p.inSize.z;
138
+ for (int loopX = 0, tileOutX = tileOutXBase; loopX < p.loopX & tileOutX < p.outSize.x; loopX++, tileOutX += tileOutW)
139
+ {
140
+ // Load input pixels.
141
+ int tileMidX = tileOutX * downx + upx - 1 - p.pad0.x;
142
+ int tileMidY = tileOutY * downy + upy - 1 - p.pad0.y;
143
+ int tileInX = floor_div(tileMidX, upx);
144
+ int tileInY = floor_div(tileMidY, upy);
145
+ __syncthreads();
146
+ for (int inIdx = threadIdx.x; inIdx < tileInH * tileInW * loopMinor; inIdx += blockDim.x)
147
+ {
148
+ int relC = inIdx;
149
+ int relInX = relC / loopMinor;
150
+ int relInY = relInX / tileInW;
151
+ relC -= relInX * loopMinor;
152
+ relInX -= relInY * tileInW;
153
+ int c = baseC + relC;
154
+ int inX = tileInX + relInX;
155
+ int inY = tileInY + relInY;
156
+ scalar_t v = 0;
157
+ if (inX >= 0 & inY >= 0 & inX < p.inSize.x & inY < p.inSize.y & c < p.inSize.z)
158
+ v = (scalar_t)((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w];
159
+ sx[relInY][relInX][relC] = v;
160
+ }
161
+
162
+ // Loop over output pixels.
163
+ __syncthreads();
164
+ for (int outIdx = threadIdx.x; outIdx < tileOutH * tileOutW * loopMinor; outIdx += blockDim.x)
165
+ {
166
+ int relC = outIdx;
167
+ int relOutX = relC / loopMinor;
168
+ int relOutY = relOutX / tileOutW;
169
+ relC -= relOutX * loopMinor;
170
+ relOutX -= relOutY * tileOutW;
171
+ int c = baseC + relC;
172
+ int outX = tileOutX + relOutX;
173
+ int outY = tileOutY + relOutY;
174
+
175
+ // Setup receptive field.
176
+ int midX = tileMidX + relOutX * downx;
177
+ int midY = tileMidY + relOutY * downy;
178
+ int inX = floor_div(midX, upx);
179
+ int inY = floor_div(midY, upy);
180
+ int relInX = inX - tileInX;
181
+ int relInY = inY - tileInY;
182
+ int filterX = (inX + 1) * upx - midX - 1; // flipped
183
+ int filterY = (inY + 1) * upy - midY - 1; // flipped
184
+
185
+ // Inner loop.
186
+ if (outX < p.outSize.x & outY < p.outSize.y & c < p.outSize.z)
187
+ {
188
+ scalar_t v = 0;
189
+ #pragma unroll
190
+ for (int y = 0; y < filterH / upy; y++)
191
+ #pragma unroll
192
+ for (int x = 0; x < filterW / upx; x++)
193
+ v += sx[relInY + y][relInX + x][relC] * sf[filterY + y * upy][filterX + x * upx];
194
+ v *= p.gain;
195
+ ((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v;
196
+ }
197
+ }
198
+ }
199
+ }
200
+ }
201
+
202
+ //------------------------------------------------------------------------
203
+ // CUDA kernel selection.
204
+
205
+ template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p)
206
+ {
207
+ int s = p.inStride.z, fx = p.filterSize.x, fy = p.filterSize.y;
208
+
209
+ upfirdn2d_kernel_spec spec = {(void*)upfirdn2d_kernel_large<T>, -1,-1,1, 4}; // contiguous
210
+ if (s == 1) spec = {(void*)upfirdn2d_kernel_large<T>, -1,-1,4, 1}; // channels_last
211
+
212
+ if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // contiguous
213
+ {
214
+ if (fx <= 7 && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 7,7, 64,16,1>, 64,16,1, 1};
215
+ if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 6,6, 64,16,1>, 64,16,1, 1};
216
+ if (fx <= 5 && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 5,5, 64,16,1>, 64,16,1, 1};
217
+ if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 64,16,1>, 64,16,1, 1};
218
+ if (fx <= 3 && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 3,3, 64,16,1>, 64,16,1, 1};
219
+ if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 24,1, 128,8,1>, 128,8,1, 1};
220
+ if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 20,1, 128,8,1>, 128,8,1, 1};
221
+ if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,1, 128,8,1>, 128,8,1, 1};
222
+ if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 12,1, 128,8,1>, 128,8,1, 1};
223
+ if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 8,1, 128,8,1>, 128,8,1, 1};
224
+ if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,24, 32,32,1>, 32,32,1, 1};
225
+ if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,20, 32,32,1>, 32,32,1, 1};
226
+ if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,16, 32,32,1>, 32,32,1, 1};
227
+ if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,12, 32,32,1>, 32,32,1, 1};
228
+ if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,8, 32,32,1>, 32,32,1, 1};
229
+ }
230
+ if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // channels_last
231
+ {
232
+ if (fx <= 7 && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 7,7, 16,16,8>, 16,16,8, 1};
233
+ if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
234
+ if (fx <= 5 && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
235
+ if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
236
+ if (fx <= 3 && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
237
+ if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 24,1, 128,1,16>, 128,1,16, 1};
238
+ if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 20,1, 128,1,16>, 128,1,16, 1};
239
+ if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,1, 128,1,16>, 128,1,16, 1};
240
+ if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 12,1, 128,1,16>, 128,1,16, 1};
241
+ if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 8,1, 128,1,16>, 128,1,16, 1};
242
+ if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,24, 1,128,16>, 1,128,16, 1};
243
+ if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,20, 1,128,16>, 1,128,16, 1};
244
+ if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,16, 1,128,16>, 1,128,16, 1};
245
+ if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,12, 1,128,16>, 1,128,16, 1};
246
+ if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,8, 1,128,16>, 1,128,16, 1};
247
+ }
248
+ if (s != 1 && p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // contiguous
249
+ {
250
+ if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 8,8, 64,16,1>, 64,16,1, 1};
251
+ if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 6,6, 64,16,1>, 64,16,1, 1};
252
+ if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 4,4, 64,16,1>, 64,16,1, 1};
253
+ if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 2,2, 64,16,1>, 64,16,1, 1};
254
+ }
255
+ if (s == 1 && p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // channels_last
256
+ {
257
+ if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 8,8, 16,16,8>, 16,16,8, 1};
258
+ if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 6,6, 16,16,8>, 16,16,8, 1};
259
+ if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
260
+ if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 2,2, 16,16,8>, 16,16,8, 1};
261
+ }
262
+ if (s != 1 && p.up.x == 2 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // contiguous
263
+ {
264
+ if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 24,1, 128,8,1>, 128,8,1, 1};
265
+ if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 20,1, 128,8,1>, 128,8,1, 1};
266
+ if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 16,1, 128,8,1>, 128,8,1, 1};
267
+ if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 12,1, 128,8,1>, 128,8,1, 1};
268
+ if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 8,1, 128,8,1>, 128,8,1, 1};
269
+ }
270
+ if (s == 1 && p.up.x == 2 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // channels_last
271
+ {
272
+ if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 24,1, 128,1,16>, 128,1,16, 1};
273
+ if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 20,1, 128,1,16>, 128,1,16, 1};
274
+ if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 16,1, 128,1,16>, 128,1,16, 1};
275
+ if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 12,1, 128,1,16>, 128,1,16, 1};
276
+ if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 8,1, 128,1,16>, 128,1,16, 1};
277
+ }
278
+ if (s != 1 && p.up.x == 1 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // contiguous
279
+ {
280
+ if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,24, 32,32,1>, 32,32,1, 1};
281
+ if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,20, 32,32,1>, 32,32,1, 1};
282
+ if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,16, 32,32,1>, 32,32,1, 1};
283
+ if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,12, 32,32,1>, 32,32,1, 1};
284
+ if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,8, 32,32,1>, 32,32,1, 1};
285
+ }
286
+ if (s == 1 && p.up.x == 1 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // channels_last
287
+ {
288
+ if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,24, 1,128,16>, 1,128,16, 1};
289
+ if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,20, 1,128,16>, 1,128,16, 1};
290
+ if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,16, 1,128,16>, 1,128,16, 1};
291
+ if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,12, 1,128,16>, 1,128,16, 1};
292
+ if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,8, 1,128,16>, 1,128,16, 1};
293
+ }
294
+ if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 2) // contiguous
295
+ {
296
+ if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 8,8, 32,8,1>, 32,8,1, 1};
297
+ if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 6,6, 32,8,1>, 32,8,1, 1};
298
+ if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 4,4, 32,8,1>, 32,8,1, 1};
299
+ if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 2,2, 32,8,1>, 32,8,1, 1};
300
+ }
301
+ if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 2) // channels_last
302
+ {
303
+ if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 8,8, 8,8,8>, 8,8,8, 1};
304
+ if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 6,6, 8,8,8>, 8,8,8, 1};
305
+ if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 4,4, 8,8,8>, 8,8,8, 1};
306
+ if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 2,2, 8,8,8>, 8,8,8, 1};
307
+ }
308
+ if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 1) // contiguous
309
+ {
310
+ if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 24,1, 64,8,1>, 64,8,1, 1};
311
+ if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 20,1, 64,8,1>, 64,8,1, 1};
312
+ if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 16,1, 64,8,1>, 64,8,1, 1};
313
+ if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 12,1, 64,8,1>, 64,8,1, 1};
314
+ if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 8,1, 64,8,1>, 64,8,1, 1};
315
+ }
316
+ if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 1) // channels_last
317
+ {
318
+ if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 24,1, 64,1,8>, 64,1,8, 1};
319
+ if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 20,1, 64,1,8>, 64,1,8, 1};
320
+ if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 16,1, 64,1,8>, 64,1,8, 1};
321
+ if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 12,1, 64,1,8>, 64,1,8, 1};
322
+ if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 8,1, 64,1,8>, 64,1,8, 1};
323
+ }
324
+ if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 2) // contiguous
325
+ {
326
+ if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,24, 32,16,1>, 32,16,1, 1};
327
+ if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,20, 32,16,1>, 32,16,1, 1};
328
+ if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,16, 32,16,1>, 32,16,1, 1};
329
+ if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,12, 32,16,1>, 32,16,1, 1};
330
+ if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,8, 32,16,1>, 32,16,1, 1};
331
+ }
332
+ if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 2) // channels_last
333
+ {
334
+ if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,24, 1,64,8>, 1,64,8, 1};
335
+ if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,20, 1,64,8>, 1,64,8, 1};
336
+ if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,16, 1,64,8>, 1,64,8, 1};
337
+ if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,12, 1,64,8>, 1,64,8, 1};
338
+ if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,8, 1,64,8>, 1,64,8, 1};
339
+ }
340
+ return spec;
341
+ }
342
+
343
+ //------------------------------------------------------------------------
344
+ // Template specializations.
345
+
346
+ template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<double> (const upfirdn2d_kernel_params& p);
347
+ template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<float> (const upfirdn2d_kernel_params& p);
348
+ template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<c10::Half>(const upfirdn2d_kernel_params& p);
349
+
350
+ //------------------------------------------------------------------------
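The specialized upfirdn2d_kernel_small templates above are chosen by resampling mode and filter size, with upfirdn2d_kernel_large as the fallback. A rough Python analogue of that dispatch, added here purely as an illustration (not part of the commit), covering only the square-filter 1x1 branch:

def choose_kernel_sketch(up, down, fw, fh, channels_last):
    # Generic fallback: always valid for any up/down/filter combination.
    spec = ("large", None)
    if up == (1, 1) and down == (1, 1):
        # Branches run from loose to tight, so the smallest template that
        # still covers the requested filter is the one that ends up winning.
        for max_f in (7, 6, 5, 4, 3):
            if fw <= max_f and fh <= max_f:
                tile = (16, 16, 8) if channels_last else (64, 16, 1)
                spec = (f"small_{max_f}x{max_f}", tile)
    return spec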
sg3_torch_utils/ops/upfirdn2d.h ADDED
@@ -0,0 +1,59 @@
1
+ // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ //
3
+ // NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ // and proprietary rights in and to this software, related documentation
5
+ // and any modifications thereto. Any use, reproduction, disclosure or
6
+ // distribution of this software and related documentation without an express
7
+ // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ #include <cuda_runtime.h>
10
+
11
+ //------------------------------------------------------------------------
12
+ // CUDA kernel parameters.
13
+
14
+ struct upfirdn2d_kernel_params
15
+ {
16
+ const void* x;
17
+ const float* f;
18
+ void* y;
19
+
20
+ int2 up;
21
+ int2 down;
22
+ int2 pad0;
23
+ int flip;
24
+ float gain;
25
+
26
+ int4 inSize; // [width, height, channel, batch]
27
+ int4 inStride;
28
+ int2 filterSize; // [width, height]
29
+ int2 filterStride;
30
+ int4 outSize; // [width, height, channel, batch]
31
+ int4 outStride;
32
+ int sizeMinor;
33
+ int sizeMajor;
34
+
35
+ int loopMinor;
36
+ int loopMajor;
37
+ int loopX;
38
+ int launchMinor;
39
+ int launchMajor;
40
+ };
41
+
42
+ //------------------------------------------------------------------------
43
+ // CUDA kernel specialization.
44
+
45
+ struct upfirdn2d_kernel_spec
46
+ {
47
+ void* kernel;
48
+ int tileOutW;
49
+ int tileOutH;
50
+ int loopMinor;
51
+ int loopX;
52
+ };
53
+
54
+ //------------------------------------------------------------------------
55
+ // CUDA kernel selection.
56
+
57
+ template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p);
58
+
59
+ //------------------------------------------------------------------------
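For reference, the output extent implied by these parameters follows the usual upfirdn arithmetic: upsample by zero insertion, pad (negative padding crops), run a "valid" convolution with the filter, then decimate. A small one-dimensional helper spelling this out, added only as an illustration (not part of the commit):

def upfirdn2d_out_size(in_size, up, down, pad0, pad1, filter_size):
    upsampled = in_size * up                  # insert up-1 zeros after each sample
    padded = upsampled + pad0 + pad1          # negative padding crops instead
    convolved = padded - filter_size + 1      # "valid" FIR convolution
    return (convolved + down - 1) // down     # keep every down-th sample

# Example: 2x upsampling of a 64-pixel row with a 4-tap filter and padding (2, 1):
# upfirdn2d_out_size(64, up=2, down=1, pad0=2, pad1=1, filter_size=4) == 128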
sg3_torch_utils/ops/upfirdn2d.py ADDED
@@ -0,0 +1,388 @@
1
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Custom PyTorch ops for efficient resampling of 2D images."""
10
+
11
+ import os
12
+ import warnings
13
+ import numpy as np
14
+ import torch
15
+ import traceback
16
+
17
+ from .. import custom_ops
18
+ from .. import misc
19
+ from . import conv2d_gradfix
20
+ from torch.cuda.amp import custom_bwd, custom_fwd
21
+
22
+ #----------------------------------------------------------------------------
23
+
24
+ _inited = False
25
+ _plugin = None
26
+ enabled = False
27
+
28
+ def _init():
29
+ global _inited, _plugin
30
+ if not _inited:
31
+ sources = ['upfirdn2d.cpp', 'upfirdn2d.cu']
32
+ sources = [os.path.join(os.path.dirname(__file__), s) for s in sources]
33
+ try:
34
+ _plugin = custom_ops.get_plugin('upfirdn2d_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math'])
35
+ except:
36
+ warnings.warn('Failed to build CUDA kernels for upfirdn2d. Falling back to slow reference implementation. Details:\n\n' + traceback.format_exc())
37
+ return _plugin is not None
38
+
39
+ def _parse_scaling(scaling):
40
+ if isinstance(scaling, int):
41
+ scaling = [scaling, scaling]
42
+ assert isinstance(scaling, (list, tuple))
43
+ assert all(isinstance(x, int) for x in scaling)
44
+ sx, sy = scaling
45
+ assert sx >= 1 and sy >= 1
46
+ return sx, sy
47
+
48
+ def _parse_padding(padding):
49
+ if isinstance(padding, int):
50
+ padding = [padding, padding]
51
+ assert isinstance(padding, (list, tuple))
52
+ assert all(isinstance(x, int) for x in padding)
53
+ if len(padding) == 2:
54
+ padx, pady = padding
55
+ padding = [padx, padx, pady, pady]
56
+ padx0, padx1, pady0, pady1 = padding
57
+ return padx0, padx1, pady0, pady1
58
+
59
+ def _get_filter_size(f):
60
+ if f is None:
61
+ return 1, 1
62
+ assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
63
+ fw = f.shape[-1]
64
+ fh = f.shape[0]
65
+ with misc.suppress_tracer_warnings():
66
+ fw = int(fw)
67
+ fh = int(fh)
68
+ misc.assert_shape(f, [fh, fw][:f.ndim])
69
+ assert fw >= 1 and fh >= 1
70
+ return fw, fh
71
+
72
+ #----------------------------------------------------------------------------
73
+
74
+ def setup_filter(f, device=torch.device('cpu'), normalize=True, flip_filter=False, gain=1, separable=None):
75
+ r"""Convenience function to setup 2D FIR filter for `upfirdn2d()`.
76
+
77
+ Args:
78
+ f: Torch tensor, numpy array, or python list of the shape
79
+ `[filter_height, filter_width]` (non-separable),
80
+ `[filter_taps]` (separable),
81
+ `[]` (impulse), or
82
+ `None` (identity).
83
+ device: Result device (default: cpu).
84
+ normalize: Normalize the filter so that it retains the magnitude
85
+ for constant input signal (DC)? (default: True).
86
+ flip_filter: Flip the filter? (default: False).
87
+ gain: Overall scaling factor for signal magnitude (default: 1).
88
+ separable: Return a separable filter? (default: select automatically).
89
+
90
+ Returns:
91
+ Float32 tensor of the shape
92
+ `[filter_height, filter_width]` (non-separable) or
93
+ `[filter_taps]` (separable).
94
+ """
95
+ # Validate.
96
+ if f is None:
97
+ f = 1
98
+ f = torch.as_tensor(f, dtype=torch.float32)
99
+ assert f.ndim in [0, 1, 2]
100
+ assert f.numel() > 0
101
+ if f.ndim == 0:
102
+ f = f[np.newaxis]
103
+
104
+ # Separable?
105
+ if separable is None:
106
+ separable = (f.ndim == 1 and f.numel() >= 8)
107
+ if f.ndim == 1 and not separable:
108
+ f = f.ger(f)
109
+ assert f.ndim == (1 if separable else 2)
110
+
111
+ # Apply normalize, flip, gain, and device.
112
+ if normalize:
113
+ f /= f.sum()
114
+ if flip_filter:
115
+ f = f.flip(list(range(f.ndim)))
116
+ f = f * (gain ** (f.ndim / 2))
117
+ f = f.to(device=device)
118
+ return f
119
+
120
+ #----------------------------------------------------------------------------
121
+
122
+ def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl='cuda'):
123
+ r"""Pad, upsample, filter, and downsample a batch of 2D images.
124
+
125
+ Performs the following sequence of operations for each channel:
126
+
127
+ 1. Upsample the image by inserting N-1 zeros after each pixel (`up`).
128
+
129
+ 2. Pad the image with the specified number of zeros on each side (`padding`).
130
+ Negative padding corresponds to cropping the image.
131
+
132
+ 3. Convolve the image with the specified 2D FIR filter (`f`), shrinking it
133
+ so that the footprint of all output pixels lies within the input image.
134
+
135
+ 4. Downsample the image by keeping every Nth pixel (`down`).
136
+
137
+ This sequence of operations bears close resemblance to scipy.signal.upfirdn().
138
+ The fused op is considerably more efficient than performing the same calculation
139
+ using standard PyTorch ops. It supports gradients of arbitrary order.
140
+
141
+ Args:
142
+ x: Float32/float64/float16 input tensor of the shape
143
+ `[batch_size, num_channels, in_height, in_width]`.
144
+ f: Float32 FIR filter of the shape
145
+ `[filter_height, filter_width]` (non-separable),
146
+ `[filter_taps]` (separable), or
147
+ `None` (identity).
148
+ up: Integer upsampling factor. Can be a single int or a list/tuple
149
+ `[x, y]` (default: 1).
150
+ down: Integer downsampling factor. Can be a single int or a list/tuple
151
+ `[x, y]` (default: 1).
152
+ padding: Padding with respect to the upsampled image. Can be a single number
153
+ or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
154
+ (default: 0).
155
+ flip_filter: False = convolution, True = correlation (default: False).
156
+ gain: Overall scaling factor for signal magnitude (default: 1).
157
+ impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
158
+
159
+ Returns:
160
+ Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
161
+ """
162
+ assert isinstance(x, torch.Tensor)
163
+ assert impl in ['ref', 'cuda']
164
+ if impl == 'cuda' and x.device.type == 'cuda' and enabled and _init():
165
+ return _upfirdn2d_cuda(up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain).apply(x, f)
166
+ return _upfirdn2d_ref(x, f, up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain)
167
+
168
+ #----------------------------------------------------------------------------
169
+
170
+ @misc.profiled_function
171
+ def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1):
172
+ """Slow reference implementation of `upfirdn2d()` using standard PyTorch ops.
173
+ """
174
+ # Validate arguments.
175
+ assert isinstance(x, torch.Tensor) and x.ndim == 4
176
+ if f is None:
177
+ f = torch.ones([1, 1], dtype=torch.float32, device=x.device)
178
+ assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
179
+ assert f.dtype == torch.float32 and not f.requires_grad
180
+ batch_size, num_channels, in_height, in_width = x.shape
181
+ upx, upy = _parse_scaling(up)
182
+ downx, downy = _parse_scaling(down)
183
+ padx0, padx1, pady0, pady1 = _parse_padding(padding)
184
+
185
+ # Upsample by inserting zeros.
186
+ x = x.reshape([batch_size, num_channels, in_height, 1, in_width, 1])
187
+ x = torch.nn.functional.pad(x, [0, upx - 1, 0, 0, 0, upy - 1])
188
+ x = x.reshape([batch_size, num_channels, in_height * upy, in_width * upx])
189
+
190
+ # Pad or crop.
191
+ x = torch.nn.functional.pad(x, [max(padx0, 0), max(padx1, 0), max(pady0, 0), max(pady1, 0)])
192
+ x = x[:, :, max(-pady0, 0) : x.shape[2] - max(-pady1, 0), max(-padx0, 0) : x.shape[3] - max(-padx1, 0)]
193
+
194
+ # Setup filter.
195
+ f = f * (gain ** (f.ndim / 2))
196
+ f = f.to(x.dtype)
197
+ if not flip_filter:
198
+ f = f.flip(list(range(f.ndim)))
199
+
200
+ # Convolve with the filter.
201
+ f = f[np.newaxis, np.newaxis].repeat([num_channels, 1] + [1] * f.ndim)
202
+ if f.ndim == 4:
203
+ x = conv2d_gradfix.conv2d(input=x, weight=f, groups=num_channels)
204
+ else:
205
+ x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(2), groups=num_channels)
206
+ x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(3), groups=num_channels)
207
+
208
+ # Downsample by throwing away pixels.
209
+ x = x[:, :, ::downy, ::downx]
210
+ return x
211
+
212
+ #----------------------------------------------------------------------------
213
+
214
+ _upfirdn2d_cuda_cache = dict()
215
+
216
+ def _upfirdn2d_cuda(up=1, down=1, padding=0, flip_filter=False, gain=1):
217
+ """Fast CUDA implementation of `upfirdn2d()` using custom ops.
218
+ """
219
+ # Parse arguments.
220
+ upx, upy = _parse_scaling(up)
221
+ downx, downy = _parse_scaling(down)
222
+ padx0, padx1, pady0, pady1 = _parse_padding(padding)
223
+
224
+ # Lookup from cache.
225
+ key = (upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain)
226
+ if key in _upfirdn2d_cuda_cache:
227
+ return _upfirdn2d_cuda_cache[key]
228
+
229
+ # Forward op.
230
+ class Upfirdn2dCuda(torch.autograd.Function):
231
+ @staticmethod
232
+ @custom_fwd(cast_inputs=torch.float32)
233
+ def forward(ctx, x, f): # pylint: disable=arguments-differ
234
+ assert isinstance(x, torch.Tensor) and x.ndim == 4
235
+ if f is None:
236
+ f = torch.ones([1, 1], dtype=torch.float32, device=x.device)
237
+ assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
238
+ y = x
239
+ if f.ndim == 2:
240
+ y = _plugin.upfirdn2d(y, f, upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain)
241
+ else:
242
+ y = _plugin.upfirdn2d(y, f.unsqueeze(0), upx, 1, downx, 1, padx0, padx1, 0, 0, flip_filter, np.sqrt(gain))
243
+ y = _plugin.upfirdn2d(y, f.unsqueeze(1), 1, upy, 1, downy, 0, 0, pady0, pady1, flip_filter, np.sqrt(gain))
244
+ ctx.save_for_backward(f)
245
+ ctx.x_shape = x.shape
246
+ return y
247
+
248
+ @staticmethod
249
+ @custom_bwd
250
+ def backward(ctx, dy): # pylint: disable=arguments-differ
251
+ f, = ctx.saved_tensors
252
+ _, _, ih, iw = ctx.x_shape
253
+ _, _, oh, ow = dy.shape
254
+ fw, fh = _get_filter_size(f)
255
+ p = [
256
+ fw - padx0 - 1,
257
+ iw * upx - ow * downx + padx0 - upx + 1,
258
+ fh - pady0 - 1,
259
+ ih * upy - oh * downy + pady0 - upy + 1,
260
+ ]
261
+ dx = None
262
+ df = None
263
+
264
+ if ctx.needs_input_grad[0]:
265
+ dx = _upfirdn2d_cuda(up=down, down=up, padding=p, flip_filter=(not flip_filter), gain=gain).apply(dy, f)
266
+
267
+ assert not ctx.needs_input_grad[1]
268
+ return dx, df
269
+
270
+ # Add to cache.
271
+ _upfirdn2d_cuda_cache[key] = Upfirdn2dCuda
272
+ return Upfirdn2dCuda
273
+
274
+ #----------------------------------------------------------------------------
275
+
276
+ def filter2d(x, f, padding=0, flip_filter=False, gain=1, impl='cuda'):
277
+ r"""Filter a batch of 2D images using the given 2D FIR filter.
278
+
279
+ By default, the result is padded so that its shape matches the input.
280
+ User-specified padding is applied on top of that, with negative values
281
+ indicating cropping. Pixels outside the image are assumed to be zero.
282
+
283
+ Args:
284
+ x: Float32/float64/float16 input tensor of the shape
285
+ `[batch_size, num_channels, in_height, in_width]`.
286
+ f: Float32 FIR filter of the shape
287
+ `[filter_height, filter_width]` (non-separable),
288
+ `[filter_taps]` (separable), or
289
+ `None` (identity).
290
+ padding: Padding with respect to the output. Can be a single number or a
291
+ list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
292
+ (default: 0).
293
+ flip_filter: False = convolution, True = correlation (default: False).
294
+ gain: Overall scaling factor for signal magnitude (default: 1).
295
+ impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
296
+
297
+ Returns:
298
+ Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
299
+ """
300
+ padx0, padx1, pady0, pady1 = _parse_padding(padding)
301
+ fw, fh = _get_filter_size(f)
302
+ p = [
303
+ padx0 + fw // 2,
304
+ padx1 + (fw - 1) // 2,
305
+ pady0 + fh // 2,
306
+ pady1 + (fh - 1) // 2,
307
+ ]
308
+ return upfirdn2d(x, f, padding=p, flip_filter=flip_filter, gain=gain, impl=impl)
309
+
310
+ #----------------------------------------------------------------------------
311
+
312
+ def upsample2d(x, f, up=2, padding=0, flip_filter=False, gain=1, impl='cuda'):
313
+ r"""Upsample a batch of 2D images using the given 2D FIR filter.
314
+
315
+ By default, the result is padded so that its shape is a multiple of the input.
316
+ User-specified padding is applied on top of that, with negative values
317
+ indicating cropping. Pixels outside the image are assumed to be zero.
318
+
319
+ Args:
320
+ x: Float32/float64/float16 input tensor of the shape
321
+ `[batch_size, num_channels, in_height, in_width]`.
322
+ f: Float32 FIR filter of the shape
323
+ `[filter_height, filter_width]` (non-separable),
324
+ `[filter_taps]` (separable), or
325
+ `None` (identity).
326
+ up: Integer upsampling factor. Can be a single int or a list/tuple
327
+ `[x, y]` (default: 1).
328
+ padding: Padding with respect to the output. Can be a single number or a
329
+ list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
330
+ (default: 0).
331
+ flip_filter: False = convolution, True = correlation (default: False).
332
+ gain: Overall scaling factor for signal magnitude (default: 1).
333
+ impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
334
+
335
+ Returns:
336
+ Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
337
+ """
338
+ upx, upy = _parse_scaling(up)
339
+ padx0, padx1, pady0, pady1 = _parse_padding(padding)
340
+ fw, fh = _get_filter_size(f)
341
+ p = [
342
+ padx0 + (fw + upx - 1) // 2,
343
+ padx1 + (fw - upx) // 2,
344
+ pady0 + (fh + upy - 1) // 2,
345
+ pady1 + (fh - upy) // 2,
346
+ ]
347
+ return upfirdn2d(x, f, up=up, padding=p, flip_filter=flip_filter, gain=gain*upx*upy, impl=impl)
348
+
349
+ #----------------------------------------------------------------------------
350
+
351
+ def downsample2d(x, f, down=2, padding=0, flip_filter=False, gain=1, impl='cuda'):
352
+ r"""Downsample a batch of 2D images using the given 2D FIR filter.
353
+
354
+ By default, the result is padded so that its shape is a fraction of the input.
355
+ User-specified padding is applied on top of that, with negative values
356
+ indicating cropping. Pixels outside the image are assumed to be zero.
357
+
358
+ Args:
359
+ x: Float32/float64/float16 input tensor of the shape
360
+ `[batch_size, num_channels, in_height, in_width]`.
361
+ f: Float32 FIR filter of the shape
362
+ `[filter_height, filter_width]` (non-separable),
363
+ `[filter_taps]` (separable), or
364
+ `None` (identity).
365
+ down: Integer downsampling factor. Can be a single int or a list/tuple
366
+ `[x, y]` (default: 2).
367
+ padding: Padding with respect to the input. Can be a single number or a
368
+ list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
369
+ (default: 0).
370
+ flip_filter: False = convolution, True = correlation (default: False).
371
+ gain: Overall scaling factor for signal magnitude (default: 1).
372
+ impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
373
+
374
+ Returns:
375
+ Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
376
+ """
377
+ downx, downy = _parse_scaling(down)
378
+ padx0, padx1, pady0, pady1 = _parse_padding(padding)
379
+ fw, fh = _get_filter_size(f)
380
+ p = [
381
+ padx0 + (fw - downx + 1) // 2,
382
+ padx1 + (fw - downx) // 2,
383
+ pady0 + (fh - downy + 1) // 2,
384
+ pady1 + (fh - downy) // 2,
385
+ ]
386
+ return upfirdn2d(x, f, down=down, padding=p, flip_filter=flip_filter, gain=gain, impl=impl)
387
+
388
+ #----------------------------------------------------------------------------
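A minimal usage sketch for the Python wrappers above (not part of the commit), assuming the repository root is on PYTHONPATH; it uses the slow reference path, so no CUDA extension build is required:

import torch
from sg3_torch_utils.ops import upfirdn2d

x = torch.randn(1, 3, 64, 64)                               # [batch, channels, H, W]
f = upfirdn2d.setup_filter([1, 3, 3, 1])                    # normalized low-pass FIR filter
y_up = upfirdn2d.upsample2d(x, f, up=2, impl='ref')         # -> [1, 3, 128, 128]
y_down = upfirdn2d.downsample2d(x, f, down=2, impl='ref')   # -> [1, 3, 32, 32]
y_same = upfirdn2d.filter2d(x, f, impl='ref')               # -> [1, 3, 64, 64]
print(y_up.shape, y_down.shape, y_same.shape)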
stylemc.py ADDED
@@ -0,0 +1,295 @@
1
+ """
2
+ Approach: "StyleMC: Multi-Channel Based Fast Text-Guided Image Generation and Manipulation"
3
+ Original source code:
4
+ https://github.com/autonomousvision/stylegan_xl/blob/f9be58e98110bd946fcdadef2aac8345466faaf3/run_stylemc.py#
5
+ Modified by Håkon Hukkelås
6
+ """
7
+ import os
8
+ from pathlib import Path
9
+ import tqdm
10
+ import re
11
+ import click
12
+ from dp2 import utils
13
+ import tops
14
+ from typing import List, Optional
15
+ import PIL.Image
16
+ import imageio
17
+ from timeit import default_timer as timer
18
+
19
+ import numpy as np
20
+ import torch
21
+ import torch.nn as nn
22
+ import torch.nn.functional as F
23
+ from torchvision.transforms.functional import resize, normalize
24
+ from dp2.infer import build_trained_generator
25
+ import clip
26
+
27
+ #----------------------------------------------------------------------------
28
+
29
+ class AverageMeter(object):
30
+ """Computes and stores the average and current value"""
31
+ def __init__(self, name, fmt=':f'):
32
+ self.name = name
33
+ self.fmt = fmt
34
+ self.reset()
35
+
36
+ def reset(self):
37
+ self.val = 0
38
+ self.avg = 0
39
+ self.sum = 0
40
+ self.count = 0
41
+
42
+ def update(self, val, n=1):
43
+ self.val = val
44
+ self.sum += val * n
45
+ self.count += n
46
+ self.avg = self.sum / self.count
47
+
48
+ def __str__(self):
49
+ fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
50
+ return fmtstr.format(**self.__dict__)
51
+
52
+
53
+ class ProgressMeter(object):
54
+ def __init__(self, num_batches, meters, prefix=""):
55
+ self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
56
+ self.meters = meters
57
+ self.prefix = prefix
58
+
59
+ def display(self, batch):
60
+ entries = [self.prefix + self.batch_fmtstr.format(batch)]
61
+ entries += [str(meter) for meter in self.meters]
62
+ print('\t'.join(entries))
63
+
64
+ def _get_batch_fmtstr(self, num_batches):
65
+ num_digits = len(str(num_batches // 1))
66
+ fmt = '{:' + str(num_digits) + 'd}'
67
+ return '[' + fmt + '/' + fmt.format(num_batches) + ']'
68
+
69
+
70
+ def save_image(img, path):
71
+ img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
72
+ PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB').save(path)
73
+
74
+
75
+ def unravel_index(index, shape):
76
+ out = []
77
+ for dim in reversed(shape):
78
+ out.append(index % dim)
79
+ index = index // dim
80
+ return tuple(reversed(out))
81
+
82
+
83
+ def num_range(s: str) -> List[int]:
84
+ '''Accept either a comma separated list of numbers 'a,b,c' or a range 'a-c' and return as a list of ints.'''
85
+
86
+ range_re = re.compile(r'^(\d+)-(\d+)$')
87
+ m = range_re.match(s)
88
+ if m:
89
+ return list(range(int(m.group(1)), int(m.group(2))+1))
90
+ vals = s.split(',')
91
+ return [int(x) for x in vals]
92
+
93
+
94
+ #----------------------------------------------------------------------------
95
+
96
+
97
+
98
+ def spherical_dist_loss(x, y):
99
+ x = F.normalize(x, dim=-1)
100
+ y = F.normalize(y, dim=-1)
101
+ return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)
102
+
103
+
104
+ def prompts_dist_loss(x, targets, loss):
105
+ if len(targets) == 1: # Keeps consistent results vs previous method for single objective guidance
106
+ return loss(x, targets[0])
107
+ distances = [loss(x, target) for target in targets]
108
+ return torch.stack(distances, dim=-1).sum(dim=-1)
109
+
110
+
111
+ def embed_text(model, prompt, device='cuda'):
112
+ return model.encode_text(clip.tokenize(prompt).to(device)).float()  # encode the prompt with CLIP, mirroring the usage in find_direction
113
+
114
+
115
+ #----------------------------------------------------------------------------
116
+
117
+ @torch.no_grad()
118
+ @torch.cuda.amp.autocast()
119
+ def generate_edit(
120
+ G,
121
+ dl,
122
+ direction,
123
+ edit_strength,
124
+ path,
125
+ ):
126
+ for it, batch in enumerate(dl):
127
+ batch["embedding"] = None
128
+ styles = get_styles(None, G, batch, truncation_value=0)
129
+ imgs = []
130
+ grad_changes = [_*edit_strength for _ in [0, 0.25, 0.5, 0.75, 1]]
131
+ grad_changes = [*[-x for x in grad_changes][::-1], *grad_changes]
132
+ batch = {k: tops.to_cuda(v) if v is not None else v for k,v in batch.items()}
133
+ for i, grad_change in enumerate(grad_changes):
134
+ s = styles + direction*grad_change
135
+
136
+ img = G(**batch, s=iter(s))["img"]
137
+ img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255)
138
+ imgs.append(img[0].to(torch.uint8).cpu().numpy())
139
+ PIL.Image.fromarray(np.concatenate(imgs, axis=1), 'RGB').save(path + f'{it}.png')
140
+
141
+
142
+ @torch.no_grad()
143
+ def get_styles(seed, G: torch.nn.Module, batch, truncation_value=1):
144
+ all_styles = []
145
+ if seed is None:
146
+ z = np.random.normal(0, 0, size=(1, G.z_channels))
147
+ else:
148
+ z = np.random.RandomState(seed=seed).normal(0, 1, size=(1, G.z_channels))
149
+ z_idx = np.random.RandomState(seed=seed).randint(0, len(G.style_net.w_centers))
150
+ w_c = G.style_net.w_centers[z_idx].to(tops.get_device()).view(1, -1)
151
+ w = G.style_net(torch.from_numpy(z).to(tops.get_device()))
152
+
153
+ w = w_c.to(w.dtype).lerp(w, truncation_value)
154
+ if hasattr(G, "get_comod_y"):
155
+ w = G.get_comod_y(batch, w)
156
+ for block in G.modules():
157
+ if not hasattr(block, "affine") or not hasattr(block.affine, "weight"):
158
+ continue
159
+ gamma0 = block.affine(w)
160
+ if hasattr(block, "affine_beta"):
161
+ beta0 = block.affine_beta(w)
162
+ gamma0 = torch.cat((gamma0, beta0), dim=1)
163
+ all_styles.append(gamma0)
164
+ max_ch = max([s.shape[-1] for s in all_styles])
165
+ all_styles = [F.pad(s, ((0, max_ch - s.shape[-1])), "constant", 0) for s in all_styles]
166
+ all_styles = torch.cat(all_styles)
167
+ return all_styles
168
+
169
+ def get_and_cache_direction(output_dir: Path, dl_val, G, text_prompt):
170
+ cache_path = output_dir.joinpath(
171
+ "stylemc_cache", text_prompt.replace(" ", "_") + ".torch")
172
+ if cache_path.is_file():
173
+ print("Loaded cache from:", cache_path)
174
+ return torch.load(cache_path)
175
+ direction = find_direction(G, text_prompt, None, dl_val=iter(dl_val))
176
+ cache_path.parent.mkdir(exist_ok=True, parents=True)
177
+ torch.save(direction, cache_path)
178
+ return direction
179
+
180
+ @torch.cuda.amp.autocast()
181
+ def find_direction(
182
+ G,
183
+ text_prompt,
184
+ batches,
185
+ #layers,
186
+ n_iterations=128*8,
187
+ batch_size=8,
188
+ dl_val=None
189
+ ):
190
+ time_start = timer()
191
+
192
+ clip_model = clip.load("ViT-B/16", device=tops.get_device())[0]
193
+
194
+ target = [clip_model.encode_text(clip.tokenize(text_prompt).to(tops.get_device())).float()]
195
+ all_styles = []
196
+ if dl_val is not None:
197
+ first_batch = next(dl_val)
198
+ else:
199
+ first_batch = batches[0]
200
+ first_batch["embedding"] = None if "embedding" not in first_batch else first_batch["embedding"]
201
+ s = get_styles(0, G, first_batch)
202
+ # stats tracker
203
+ cos_sim_track = AverageMeter('cos_sim', ':.4f')
204
+ norm_track = AverageMeter('norm', ':.4f')
205
+ n_iterations = n_iterations // batch_size
206
+ progress = ProgressMeter(n_iterations, [cos_sim_track, norm_track])
207
+
208
+ # initalize styles direction
209
+ direction = torch.zeros(s.shape, device=tops.get_device())
210
+ direction.requires_grad_()
211
+ utils.set_requires_grad(G, False)
212
+ direction_tracker = torch.zeros_like(direction)
213
+ opt = torch.optim.AdamW([direction], lr=0.05, betas=(0., 0.999), weight_decay=0.25)
214
+
215
+ grads = []
216
+ for seed_idx in tqdm.trange(n_iterations):
217
+ # forward pass through synthesis network with new styles
218
+ if seed_idx == 0:
219
+ batch = first_batch
220
+ elif dl_val is not None:
221
+ batch = next(dl_val)
222
+ batch["embedding"] = None if "embedding" not in batch else batch["embedding"]
223
+ else:
224
+ batch = {k: tops.to_cuda(v) if v is not None else v for k, v in batches[seed_idx].items()}
225
+ styles = get_styles(seed_idx, G, batch) + direction
226
+ img = G(**batch, s=iter(styles))["img"]
227
+ batch = {k: v.cpu() if v is not None else v for k, v in batch.items()}
228
+ # clip loss
229
+ img = (img + 1)/2
230
+ img = normalize(img, mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
231
+ img = resize(img, (224, 224))
232
+ embeds = clip_model.encode_image(img)
233
+ cos_sim = prompts_dist_loss(embeds, target, spherical_dist_loss)
234
+ cos_sim.backward(retain_graph=True)
235
+
236
+ # track stats
237
+ cos_sim_track.update(cos_sim.item())
238
+ norm_track.update(torch.norm(direction).item())
239
+
240
+ if not (seed_idx % batch_size):
241
+
242
+ # zeroing out gradients for non-optimized layers
243
+ #layers_zeroed = torch.tensor([x for x in range(G.num_ws) if not x in layers])
244
+ #direction.grad[:, layers_zeroed] = 0
245
+
246
+ opt.step()
247
+ grads.append(direction.grad.clone())
248
+ direction.grad.data.zero_()
249
+
250
+ # keep track of gradients over time
251
+ if seed_idx > 3:
252
+ direction_tracker[grads[-2] * grads[-1] < 0] += 1
253
+
254
+ # plot stats
255
+ progress.display(seed_idx)
256
+
257
+ # throw out fluctuating channels
258
+ direction = direction.detach()
259
+ direction[direction_tracker > n_iterations / 4] = 0
260
+ print(direction)
261
+ print(f"Time for direction search: {timer() - time_start:.2f} s")
262
+ return direction
263
+
264
+
265
+
266
+
267
+ @click.command()
268
+ @click.argument("config_path")
269
+ @click.argument("input_path")
270
+ @click.argument("output_path")
271
+ #@click.option('--layers', type=num_range, help='Restrict the style space to a range of layers. We recommend not to optimize the critically sampled layers (last 3).', required=True)
272
+ @click.option('--text-prompt', help='Text', type=str, required=True)
273
+ @click.option('--edit-strength', help='Strength of edit', type=float, required=True)
274
+ @click.option('--outdir', help='Where to save the output images', type=str, required=True)
275
+ def stylemc(
276
+ config_path,
+ input_path,
+ output_path,
277
+ #layers: List[int],
278
+ text_prompt: str,
279
+ edit_strength: float,
280
+ outdir: str,
281
+ ):
282
+ cfg = utils.load_config(config_path)
283
+ G = build_trained_generator(cfg)
284
+ cfg.train.batch_size = 1
285
+ n_iterations = 256
286
+ dl_val = tops.config.instantiate(cfg.data.val.loader)
287
+
288
+ direction = find_direction(G, text_prompt, None, n_iterations=n_iterations, dl_val=iter(dl_val))
289
+
290
+ text_prompt = text_prompt.replace(" ", "_")
291
+ generate_edit(G, input_path, direction, edit_strength, output_path)
292
+
293
+
294
+ if __name__ == "__main__":
295
+ stylemc()
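A hypothetical programmatic use of the pieces above, mirroring what stylemc() does (not part of the commit). The config path, prompt, strength, and output paths are placeholders, and it assumes a CUDA device, a trained checkpoint resolvable through the config, and the clip package installed:

from pathlib import Path
import tops
from dp2 import utils
from dp2.infer import build_trained_generator
from stylemc import get_and_cache_direction, generate_edit

cfg = utils.load_config("configs/fdh/styleganL.py")   # placeholder config
G = build_trained_generator(cfg)
cfg.train.batch_size = 1
dl_val = tops.config.instantiate(cfg.data.val.loader)

# Find (or load a cached) global style-space direction for the text prompt,
# then render edits at increasing strengths for each validation batch.
direction = get_and_cache_direction(Path("outputs"), dl_val, G, "a person with a beard")
generate_edit(G, iter(dl_val), direction, edit_strength=2.0, path="image_dump/beard_")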