邱浩楠 committed
Commit 45beb96
Parent: 27609eb
This view is limited to 50 files because the commit contains too many changes; see the raw diff for the complete set.
Files changed (50)
  1. .gitattributes +6 -0
  2. README.md +3 -3
  3. app.py +8 -0
  4. checkpoints/stylefacev/latest_net_FE.pth +3 -0
  5. checkpoints/stylefacev/latest_net_FE_lm.pth +3 -0
  6. checkpoints/stylefacev/latest_net_FE_pose.pth +3 -0
  7. data/__init__.py +93 -0
  8. data/base_dataset.py +157 -0
  9. data/image_folder.py +90 -0
  10. data/noiseshufflevideo_dataset.py +71 -0
  11. dnnlib/__init__.py +9 -0
  12. dnnlib/util.py +491 -0
  13. legacy.py +323 -0
  14. models/__init__.py +68 -0
  15. models/base_model.py +234 -0
  16. models/diy_networks.py +918 -0
  17. models/lmcode_networks.py +394 -0
  18. models/resnet.py +1452 -0
  19. models/rnn_net.py +99 -0
  20. models/sample_model.py +243 -0
  21. options/__init__.py +1 -0
  22. options/base_options.py +138 -0
  23. options/test_options.py +23 -0
  24. options/train_options.py +43 -0
  25. pretrained_models/.DS_Store +0 -0
  26. pretrained_models/motion_net.pth +3 -0
  27. pretrained_models/network-snapshot-005000.pkl +3 -0
  28. pretrained_models/wing.ckpt +3 -0
  29. torch_utils/__init__.py +9 -0
  30. torch_utils/custom_ops.py +157 -0
  31. torch_utils/misc.py +266 -0
  32. torch_utils/ops/__init__.py +9 -0
  33. torch_utils/ops/bias_act.cpp +99 -0
  34. torch_utils/ops/bias_act.cu +173 -0
  35. torch_utils/ops/bias_act.h +38 -0
  36. torch_utils/ops/bias_act.py +209 -0
  37. torch_utils/ops/conv2d_gradfix.py +198 -0
  38. torch_utils/ops/conv2d_resample.py +143 -0
  39. torch_utils/ops/filtered_lrelu.cpp +300 -0
  40. torch_utils/ops/filtered_lrelu.cu +1284 -0
  41. torch_utils/ops/filtered_lrelu.h +90 -0
  42. torch_utils/ops/filtered_lrelu.py +274 -0
  43. torch_utils/ops/filtered_lrelu_ns.cu +27 -0
  44. torch_utils/ops/filtered_lrelu_rd.cu +27 -0
  45. torch_utils/ops/filtered_lrelu_wr.cu +27 -0
  46. torch_utils/ops/fma.py +60 -0
  47. torch_utils/ops/grid_sample_gradfix.py +77 -0
  48. torch_utils/ops/upfirdn2d.cpp +107 -0
  49. torch_utils/ops/upfirdn2d.cu +384 -0
  50. torch_utils/ops/upfirdn2d.h +59 -0
.gitattributes CHANGED
@@ -29,3 +29,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ checkpoints/stylefacev/latest_net_FE.pth filter=lfs diff=lfs merge=lfs -text
+ checkpoints/stylefacev/latest_net_FE_lm.pth filter=lfs diff=lfs merge=lfs -text
+ checkpoints/stylefacev/latest_net_FE_pose.pth filter=lfs diff=lfs merge=lfs -text
+ pretrained_models/network-snapshot-005000.pkl filter=lfs diff=lfs merge=lfs -text
+ pretrained_models/wing.ckpt filter=lfs diff=lfs merge=lfs -text
+ pretrained_models/motion_net.pth filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,8 +1,8 @@
  ---
  title: StyleFaceV
- emoji: 👁
- colorFrom: green
- colorTo: red
+ emoji: 🏢
+ colorFrom: indigo
+ colorTo: purple
  sdk: gradio
  sdk_version: 3.1.7
  app_file: app.py
app.py ADDED
@@ -0,0 +1,8 @@
+ import gradio as gr
+
+ def greet(name):
+     return "Hello " + name + "!!"
+
+ iface = gr.Interface(fn=greet, inputs="text", outputs="text")
+ iface.launch()
+
checkpoints/stylefacev/latest_net_FE.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:52d5f2eb71cb79fce9faa1448450033d83037a427f1c69cd3924fcc9771fe1ef
+ size 559223985
checkpoints/stylefacev/latest_net_FE_lm.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:34827db294ffd611971460fead1445844144311543e2a35fb2ccd1b52ae8d07c
+ size 25497505
checkpoints/stylefacev/latest_net_FE_pose.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:27a838efca2be63aa71a1b6da020998ac466aaeee2772c9e2631ba65f993174b
+ size 6447709
data/__init__.py ADDED
@@ -0,0 +1,93 @@
+ """This package includes all the modules related to data loading and preprocessing
+
+ To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset.
+ You need to implement four functions:
+     -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
+     -- <__len__>: return the size of dataset.
+     -- <__getitem__>: get a data point from data loader.
+     -- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
+
+ Now you can use the dataset class by specifying flag '--dataset_mode dummy'.
+ See our template dataset class 'template_dataset.py' for more details.
+ """
+ import importlib
+ import torch.utils.data
+ from data.base_dataset import BaseDataset
+
+
+ def find_dataset_using_name(dataset_name):
+     """Import the module "data/[dataset_name]_dataset.py".
+
+     In the file, the class called DatasetNameDataset() will
+     be instantiated. It has to be a subclass of BaseDataset,
+     and it is case-insensitive.
+     """
+     dataset_filename = "data." + dataset_name + "_dataset"
+     datasetlib = importlib.import_module(dataset_filename)
+
+     dataset = None
+     target_dataset_name = dataset_name.replace('_', '') + 'dataset'
+     for name, cls in datasetlib.__dict__.items():
+         if name.lower() == target_dataset_name.lower() \
+            and issubclass(cls, BaseDataset):
+             dataset = cls
+
+     if dataset is None:
+         raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name))
+
+     return dataset
+
+
+ def get_option_setter(dataset_name):
+     """Return the static method <modify_commandline_options> of the dataset class."""
+     dataset_class = find_dataset_using_name(dataset_name)
+     return dataset_class.modify_commandline_options
+
+
+ def create_dataset(opt):
+     """Create a dataset given the option.
+
+     This function wraps the class CustomDatasetDataLoader.
+     This is the main interface between this package and 'train.py'/'test.py'
+
+     Example:
+         >>> from data import create_dataset
+         >>> dataset = create_dataset(opt)
+     """
+     data_loader = CustomDatasetDataLoader(opt)
+     dataset = data_loader.load_data()
+     return dataset
+
+
+ class CustomDatasetDataLoader():
+     """Wrapper class of Dataset class that performs multi-threaded data loading"""
+
+     def __init__(self, opt):
+         """Initialize this class
+
+         Step 1: create a dataset instance given the name [dataset_mode]
+         Step 2: create a multi-threaded data loader.
+         """
+         self.opt = opt
+         dataset_class = find_dataset_using_name(opt.dataset_mode)
+         self.dataset = dataset_class(opt)
+         print("dataset [%s] was created" % type(self.dataset).__name__)
+         self.dataloader = torch.utils.data.DataLoader(
+             self.dataset,
+             batch_size=opt.batch_size,
+             shuffle=not opt.serial_batches,
+             num_workers=int(opt.num_threads))
+
+     def load_data(self):
+         return self
+
+     def __len__(self):
+         """Return the number of data in the dataset"""
+         return min(len(self.dataset), self.opt.max_dataset_size)
+
+     def __iter__(self):
+         """Return a batch of data"""
+         for i, data in enumerate(self.dataloader):
+             if i * self.opt.batch_size >= self.opt.max_dataset_size:
+                 break
+             yield data
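For orientation, a minimal sketch of how create_dataset is driven with the noiseshufflevideo dataset added in this commit. The option object and every value in it are illustrative assumptions (in the repo they come from options/base_options.py), not code from the diff:

from types import SimpleNamespace
from data import create_dataset

# Hypothetical options; attribute names follow what the code above actually reads.
opt = SimpleNamespace(dataset_mode='noiseshufflevideo', dataroot='/path/to/clips',
                      batch_size=1, serial_batches=False, num_threads=4,
                      max_dataset_size=float('inf'), preprocess='resize_and_crop',
                      load_size=286, crop_size=256, no_flip=False,
                      direction='AtoB', input_nc=3, output_nc=3)

dataset = create_dataset(opt)      # builds NoiseShuffleVideoDataset wrapped in a DataLoader
for batch in dataset:              # each batch is a dict with keys 'A', 'A_paths', 'B'
    print(batch['A'].shape)        # [1, N, 3, 256, 256] for an N-frame clip (N >= 60)
    break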
data/base_dataset.py ADDED
@@ -0,0 +1,157 @@
+ """This module implements an abstract base class (ABC) 'BaseDataset' for datasets.
+
+ It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses.
+ """
+ import random
+ import numpy as np
+ import torch.utils.data as data
+ from PIL import Image
+ import torchvision.transforms as transforms
+ from abc import ABC, abstractmethod
+
+
+ class BaseDataset(data.Dataset, ABC):
+     """This class is an abstract base class (ABC) for datasets.
+
+     To create a subclass, you need to implement the following four functions:
+     -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
+     -- <__len__>: return the size of dataset.
+     -- <__getitem__>: get a data point.
+     -- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
+     """
+
+     def __init__(self, opt):
+         """Initialize the class; save the options in the class
+
+         Parameters:
+             opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
+         """
+         self.opt = opt
+         self.root = opt.dataroot
+
+     @staticmethod
+     def modify_commandline_options(parser, is_train):
+         """Add new dataset-specific options, and rewrite default values for existing options.
+
+         Parameters:
+             parser          -- original option parser
+             is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
+
+         Returns:
+             the modified parser.
+         """
+         return parser
+
+     @abstractmethod
+     def __len__(self):
+         """Return the total number of images in the dataset."""
+         return 0
+
+     @abstractmethod
+     def __getitem__(self, index):
+         """Return a data point and its metadata information.
+
+         Parameters:
+             index - - a random integer for data indexing
+
+         Returns:
+             a dictionary of data with their names. It ususally contains the data itself and its metadata information.
+         """
+         pass
+
+
+ def get_params(opt, size):
+     w, h = size
+     new_h = h
+     new_w = w
+     if opt.preprocess == 'resize_and_crop':
+         new_h = new_w = opt.load_size
+     elif opt.preprocess == 'scale_width_and_crop':
+         new_w = opt.load_size
+         new_h = opt.load_size * h // w
+
+     x = random.randint(0, np.maximum(0, new_w - opt.crop_size))
+     y = random.randint(0, np.maximum(0, new_h - opt.crop_size))
+
+     flip = random.random() > 0.5
+
+     return {'crop_pos': (x, y), 'flip': flip}
+
+
+ def get_transform(opt, params=None, grayscale=False, method=Image.BICUBIC, convert=True):
+     transform_list = []
+     if grayscale:
+         transform_list.append(transforms.Grayscale(1))
+     if 'resize' in opt.preprocess:
+         osize = [opt.load_size, opt.load_size]
+         transform_list.append(transforms.Resize(osize, method))
+     elif 'scale_width' in opt.preprocess:
+         transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, opt.crop_size, method)))
+
+     if 'crop' in opt.preprocess:
+         if params is None:
+             transform_list.append(transforms.RandomCrop(opt.crop_size))
+         else:
+             transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.crop_size)))
+
+     if opt.preprocess == 'none':
+         transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base=4, method=method)))
+
+     if not opt.no_flip:
+         if params is None:
+             transform_list.append(transforms.RandomHorizontalFlip())
+         elif params['flip']:
+             transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))
+
+     if convert:
+         transform_list += [transforms.ToTensor()]
+         if grayscale:
+             transform_list += [transforms.Normalize((0.5,), (0.5,))]
+         else:
+             transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
+     return transforms.Compose(transform_list)
+
+
+ def __make_power_2(img, base, method=Image.BICUBIC):
+     ow, oh = img.size
+     h = int(round(oh / base) * base)
+     w = int(round(ow / base) * base)
+     if h == oh and w == ow:
+         return img
+
+     __print_size_warning(ow, oh, w, h)
+     return img.resize((w, h), method)
+
+
+ def __scale_width(img, target_size, crop_size, method=Image.BICUBIC):
+     ow, oh = img.size
+     if ow == target_size and oh >= crop_size:
+         return img
+     w = target_size
+     h = int(max(target_size * oh / ow, crop_size))
+     return img.resize((w, h), method)
+
+
+ def __crop(img, pos, size):
+     ow, oh = img.size
+     x1, y1 = pos
+     tw = th = size
+     if (ow > tw or oh > th):
+         return img.crop((x1, y1, x1 + tw, y1 + th))
+     return img
+
+
+ def __flip(img, flip):
+     if flip:
+         return img.transpose(Image.FLIP_LEFT_RIGHT)
+     return img
+
+
+ def __print_size_warning(ow, oh, w, h):
+     """Print warning information about image size(only print once)"""
+     if not hasattr(__print_size_warning, 'has_printed'):
+         print("The image size needs to be a multiple of 4. "
+               "The loaded image size was (%d, %d), so it was adjusted to "
+               "(%d, %d). This adjustment will be done to all images "
+               "whose sizes are not multiples of 4" % (ow, oh, w, h))
+         __print_size_warning.has_printed = True
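To make the transform pipeline concrete, a short sketch of get_params/get_transform applied to a single image; the option values and the file name are assumptions for illustration only:

from types import SimpleNamespace
from PIL import Image
from data.base_dataset import get_params, get_transform

opt = SimpleNamespace(preprocess='resize_and_crop', load_size=286, crop_size=256, no_flip=False)
img = Image.open('face.png').convert('RGB')      # hypothetical input image
params = get_params(opt, img.size)               # fixes one crop position and one flip decision
tensor = get_transform(opt, params)(img)         # 3 x 256 x 256 tensor, normalized to [-1, 1]

Passing the same params into get_transform for every frame is what keeps the crop and flip consistent across a clip, which is how noiseshufflevideo_dataset.py (further below) uses these helpers.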
data/image_folder.py ADDED
@@ -0,0 +1,90 @@
+ """A modified image folder class
+
+ We modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py)
+ so that this class can load images from both current directory and its subdirectories.
+ """
+
+ import torch.utils.data as data
+
+ from PIL import Image
+ import os
+
+ IMG_EXTENSIONS = [
+     '.jpg', '.JPG', '.jpeg', '.JPEG',
+     '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
+     '.tif', '.TIF', '.tiff', '.TIFF',
+ ]
+
+
+ def is_image_file(filename):
+     return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
+
+
+ def make_id_dataset(dir, max_dataset_size=float("inf")):
+     ids = []
+     images = []
+     assert os.path.isdir(dir), '%s is not a valid directory' % dir
+
+     id_names = sorted(os.listdir(dir))
+     for id_name in id_names:
+         id_path = os.path.join(dir, id_name)
+         fnames = os.listdir(id_path)
+         for fname in fnames:
+             path = os.path.join(dir, id_name, fname)
+             images.append(path)
+             ids.append(id_name)
+     return images[:min(max_dataset_size, len(images))], ids[:min(max_dataset_size, len(ids))]
+
+ def make_noid_dataset(dir, max_dataset_size=float("inf")):
+     images = []
+     assert os.path.isdir(dir), '%s is not a valid directory' % dir
+
+     fnames = sorted(os.listdir(dir))
+     for fname in fnames:
+         path = os.path.join(dir, fname)
+         images.append(path)
+     return images[:min(max_dataset_size, len(images))]
+
+ def make_dataset(dir, max_dataset_size=float("inf")):
+     images = []
+     assert os.path.isdir(dir), '%s is not a valid directory' % dir
+
+     for root, _, fnames in sorted(os.walk(dir)):
+         for fname in fnames:
+             if is_image_file(fname):
+                 path = os.path.join(root, fname)
+                 images.append(path)
+     return images[:min(max_dataset_size, len(images))]
+
+
+ def default_loader(path):
+     return Image.open(path).convert('RGB')
+
+
+ class ImageFolder(data.Dataset):
+
+     def __init__(self, root, transform=None, return_paths=False,
+                  loader=default_loader):
+         imgs = make_dataset(root)
+         if len(imgs) == 0:
+             raise(RuntimeError("Found 0 images in: " + root + "\n"
+                                "Supported image extensions are: " + ",".join(IMG_EXTENSIONS)))
+
+         self.root = root
+         self.imgs = imgs
+         self.transform = transform
+         self.return_paths = return_paths
+         self.loader = loader
+
+     def __getitem__(self, index):
+         path = self.imgs[index]
+         img = self.loader(path)
+         if self.transform is not None:
+             img = self.transform(img)
+         if self.return_paths:
+             return img, path
+         else:
+             return img
+
+     def __len__(self):
+         return len(self.imgs)
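make_id_dataset is the helper the new video dataset below relies on. From the way its output is consumed there (each returned path is listed again as a directory of frames), the implied layout is dataroot/<identity>/<clip>/<frame>; that layout is an inference from the code, not something stated in the diff:

from data.image_folder import make_id_dataset

# Hypothetical layout:  /path/to/clips/id0001/clip_000/000001.png, 000002.png, ...
clips, ids = make_id_dataset('/path/to/clips')
# clips[i] -> '/path/to/clips/<identity>/<clip>' directory, ids[i] -> '<identity>'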
data/noiseshufflevideo_dataset.py ADDED
@@ -0,0 +1,71 @@
+ from data.base_dataset import BaseDataset, get_transform, get_params
+ from data.image_folder import make_id_dataset
+ from PIL import Image
+ import random
+ import os
+ import numpy as np
+ import torch
+
+
+ class NoiseShuffleVideoDataset(BaseDataset):
+     """This dataset class can load a set of images specified by the path --dataroot /path/to/data.
+
+     It can be used for generating CycleGAN results only for one side with the model option '-model test'.
+     """
+
+     def __init__(self, opt):
+         """Initialize this dataset class.
+
+         Parameters:
+             opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
+         """
+         BaseDataset.__init__(self, opt)
+         self.opt = opt
+         self.A_paths, self.A_ids = make_id_dataset(opt.dataroot, opt.max_dataset_size)
+
+         self.input_nc = self.opt.output_nc if self.opt.direction == 'BtoA' else self.opt.input_nc
+
+     def __getitem__(self, index):
+         """Return a data point and its metadata information.
+
+         Parameters:
+             index - - a random integer for data indexing
+
+         Returns a dictionary that contains A and A_paths
+             A(tensor) - - an image in one domain
+             A_paths(str) - - the path of the image
+         """
+         # A_id = self.A_ids[index]
+         A_list = []
+         random.seed(index)
+         A_index = int(random.random() * (len(self.A_paths) - 1))
+         A_video = self.A_paths[A_index]
+         A_frames = sorted(os.listdir(A_video))
+         max_frames = len(A_frames)
+         while max_frames < 60:
+             A_index = (A_index + 1) % len(self.A_paths)
+             A_video = self.A_paths[A_index]
+             A_frames = sorted(os.listdir(A_video))
+             max_frames = len(A_frames)
+
+         for i in range(max_frames):
+             A_frame = A_frames[i]
+             # print(A_frame)
+             A_path = os.path.join(A_video, A_frame)
+             A_img = Image.open(A_path).convert('RGB')
+
+             if i == 0:
+                 transform_params = get_params(self.opt, A_img.size)
+                 self.transform = get_transform(self.opt, transform_params, grayscale=(self.input_nc == 1))
+
+             A = self.transform(A_img)
+             A_list.append(A.unsqueeze(0))
+
+         A = torch.cat(A_list, 0)
+         B = torch.from_numpy(np.random.RandomState(index).randn(512))
+
+         return {'A': A, 'A_paths': A_path, 'B': B}
+
+     def __len__(self):
+         """Return the total number of images in the dataset."""
+         return len(self.A_paths)
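For reference, a sketch of what a single item from this dataset contains, reusing the hypothetical opt object from the create_dataset example earlier; the shapes assume crop_size=256:

from data.noiseshufflevideo_dataset import NoiseShuffleVideoDataset

item = NoiseShuffleVideoDataset(opt)[0]
item['A'].shape        # torch.Size([N, 3, 256, 256]) -- all N frames of one clip, N >= 60
item['A_paths']        # path of the last frame read from that clip
item['B'].shape        # torch.Size([512]) -- per-item Gaussian noise seeded with the index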
dnnlib/__init__.py ADDED
@@ -0,0 +1,9 @@
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ #
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
+ # and proprietary rights in and to this software, related documentation
+ # and any modifications thereto. Any use, reproduction, disclosure or
+ # distribution of this software and related documentation without an express
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+ from .util import EasyDict, make_cache_dir_path
dnnlib/util.py ADDED
@@ -0,0 +1,491 @@
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Miscellaneous utility classes and functions."""
10
+
11
+ import ctypes
12
+ import fnmatch
13
+ import importlib
14
+ import inspect
15
+ import numpy as np
16
+ import os
17
+ import shutil
18
+ import sys
19
+ import types
20
+ import io
21
+ import pickle
22
+ import re
23
+ import requests
24
+ import html
25
+ import hashlib
26
+ import glob
27
+ import tempfile
28
+ import urllib
29
+ import urllib.request
30
+ import uuid
31
+
32
+ from distutils.util import strtobool
33
+ from typing import Any, List, Tuple, Union
34
+
35
+
36
+ # Util classes
37
+ # ------------------------------------------------------------------------------------------
38
+
39
+
40
+ class EasyDict(dict):
41
+ """Convenience class that behaves like a dict but allows access with the attribute syntax."""
42
+
43
+ def __getattr__(self, name: str) -> Any:
44
+ try:
45
+ return self[name]
46
+ except KeyError:
47
+ raise AttributeError(name)
48
+
49
+ def __setattr__(self, name: str, value: Any) -> None:
50
+ self[name] = value
51
+
52
+ def __delattr__(self, name: str) -> None:
53
+ del self[name]
54
+
55
+
56
+ class Logger(object):
57
+ """Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file."""
58
+
59
+ def __init__(self, file_name: str = None, file_mode: str = "w", should_flush: bool = True):
60
+ self.file = None
61
+
62
+ if file_name is not None:
63
+ self.file = open(file_name, file_mode)
64
+
65
+ self.should_flush = should_flush
66
+ self.stdout = sys.stdout
67
+ self.stderr = sys.stderr
68
+
69
+ sys.stdout = self
70
+ sys.stderr = self
71
+
72
+ def __enter__(self) -> "Logger":
73
+ return self
74
+
75
+ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
76
+ self.close()
77
+
78
+ def write(self, text: Union[str, bytes]) -> None:
79
+ """Write text to stdout (and a file) and optionally flush."""
80
+ if isinstance(text, bytes):
81
+ text = text.decode()
82
+ if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash
83
+ return
84
+
85
+ if self.file is not None:
86
+ self.file.write(text)
87
+
88
+ self.stdout.write(text)
89
+
90
+ if self.should_flush:
91
+ self.flush()
92
+
93
+ def flush(self) -> None:
94
+ """Flush written text to both stdout and a file, if open."""
95
+ if self.file is not None:
96
+ self.file.flush()
97
+
98
+ self.stdout.flush()
99
+
100
+ def close(self) -> None:
101
+ """Flush, close possible files, and remove stdout/stderr mirroring."""
102
+ self.flush()
103
+
104
+ # if using multiple loggers, prevent closing in wrong order
105
+ if sys.stdout is self:
106
+ sys.stdout = self.stdout
107
+ if sys.stderr is self:
108
+ sys.stderr = self.stderr
109
+
110
+ if self.file is not None:
111
+ self.file.close()
112
+ self.file = None
113
+
114
+
115
+ # Cache directories
116
+ # ------------------------------------------------------------------------------------------
117
+
118
+ _dnnlib_cache_dir = None
119
+
120
+ def set_cache_dir(path: str) -> None:
121
+ global _dnnlib_cache_dir
122
+ _dnnlib_cache_dir = path
123
+
124
+ def make_cache_dir_path(*paths: str) -> str:
125
+ if _dnnlib_cache_dir is not None:
126
+ return os.path.join(_dnnlib_cache_dir, *paths)
127
+ if 'DNNLIB_CACHE_DIR' in os.environ:
128
+ return os.path.join(os.environ['DNNLIB_CACHE_DIR'], *paths)
129
+ if 'HOME' in os.environ:
130
+ return os.path.join(os.environ['HOME'], '.cache', 'dnnlib', *paths)
131
+ if 'USERPROFILE' in os.environ:
132
+ return os.path.join(os.environ['USERPROFILE'], '.cache', 'dnnlib', *paths)
133
+ return os.path.join(tempfile.gettempdir(), '.cache', 'dnnlib', *paths)
134
+
135
+ # Small util functions
136
+ # ------------------------------------------------------------------------------------------
137
+
138
+
139
+ def format_time(seconds: Union[int, float]) -> str:
140
+ """Convert the seconds to human readable string with days, hours, minutes and seconds."""
141
+ s = int(np.rint(seconds))
142
+
143
+ if s < 60:
144
+ return "{0}s".format(s)
145
+ elif s < 60 * 60:
146
+ return "{0}m {1:02}s".format(s // 60, s % 60)
147
+ elif s < 24 * 60 * 60:
148
+ return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60)
149
+ else:
150
+ return "{0}d {1:02}h {2:02}m".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60)
151
+
152
+
153
+ def format_time_brief(seconds: Union[int, float]) -> str:
154
+ """Convert the seconds to human readable string with days, hours, minutes and seconds."""
155
+ s = int(np.rint(seconds))
156
+
157
+ if s < 60:
158
+ return "{0}s".format(s)
159
+ elif s < 60 * 60:
160
+ return "{0}m {1:02}s".format(s // 60, s % 60)
161
+ elif s < 24 * 60 * 60:
162
+ return "{0}h {1:02}m".format(s // (60 * 60), (s // 60) % 60)
163
+ else:
164
+ return "{0}d {1:02}h".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24)
165
+
166
+
167
+ def ask_yes_no(question: str) -> bool:
168
+ """Ask the user the question until the user inputs a valid answer."""
169
+ while True:
170
+ try:
171
+ print("{0} [y/n]".format(question))
172
+ return strtobool(input().lower())
173
+ except ValueError:
174
+ pass
175
+
176
+
177
+ def tuple_product(t: Tuple) -> Any:
178
+ """Calculate the product of the tuple elements."""
179
+ result = 1
180
+
181
+ for v in t:
182
+ result *= v
183
+
184
+ return result
185
+
186
+
187
+ _str_to_ctype = {
188
+ "uint8": ctypes.c_ubyte,
189
+ "uint16": ctypes.c_uint16,
190
+ "uint32": ctypes.c_uint32,
191
+ "uint64": ctypes.c_uint64,
192
+ "int8": ctypes.c_byte,
193
+ "int16": ctypes.c_int16,
194
+ "int32": ctypes.c_int32,
195
+ "int64": ctypes.c_int64,
196
+ "float32": ctypes.c_float,
197
+ "float64": ctypes.c_double
198
+ }
199
+
200
+
201
+ def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]:
202
+ """Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes."""
203
+ type_str = None
204
+
205
+ if isinstance(type_obj, str):
206
+ type_str = type_obj
207
+ elif hasattr(type_obj, "__name__"):
208
+ type_str = type_obj.__name__
209
+ elif hasattr(type_obj, "name"):
210
+ type_str = type_obj.name
211
+ else:
212
+ raise RuntimeError("Cannot infer type name from input")
213
+
214
+ assert type_str in _str_to_ctype.keys()
215
+
216
+ my_dtype = np.dtype(type_str)
217
+ my_ctype = _str_to_ctype[type_str]
218
+
219
+ assert my_dtype.itemsize == ctypes.sizeof(my_ctype)
220
+
221
+ return my_dtype, my_ctype
222
+
223
+
224
+ def is_pickleable(obj: Any) -> bool:
225
+ try:
226
+ with io.BytesIO() as stream:
227
+ pickle.dump(obj, stream)
228
+ return True
229
+ except:
230
+ return False
231
+
232
+
233
+ # Functionality to import modules/objects by name, and call functions by name
234
+ # ------------------------------------------------------------------------------------------
235
+
236
+ def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]:
237
+ """Searches for the underlying module behind the name to some python object.
238
+ Returns the module and the object name (original name with module part removed)."""
239
+
240
+ # allow convenience shorthands, substitute them by full names
241
+ obj_name = re.sub("^np.", "numpy.", obj_name)
242
+ obj_name = re.sub("^tf.", "tensorflow.", obj_name)
243
+
244
+ # list alternatives for (module_name, local_obj_name)
245
+ parts = obj_name.split(".")
246
+ name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)]
247
+
248
+ # try each alternative in turn
249
+ for module_name, local_obj_name in name_pairs:
250
+ try:
251
+ module = importlib.import_module(module_name) # may raise ImportError
252
+ get_obj_from_module(module, local_obj_name) # may raise AttributeError
253
+ return module, local_obj_name
254
+ except:
255
+ pass
256
+
257
+ # maybe some of the modules themselves contain errors?
258
+ for module_name, _local_obj_name in name_pairs:
259
+ try:
260
+ importlib.import_module(module_name) # may raise ImportError
261
+ except ImportError:
262
+ if not str(sys.exc_info()[1]).startswith("No module named '" + module_name + "'"):
263
+ raise
264
+
265
+ # maybe the requested attribute is missing?
266
+ for module_name, local_obj_name in name_pairs:
267
+ try:
268
+ module = importlib.import_module(module_name) # may raise ImportError
269
+ get_obj_from_module(module, local_obj_name) # may raise AttributeError
270
+ except ImportError:
271
+ pass
272
+
273
+ # we are out of luck, but we have no idea why
274
+ raise ImportError(obj_name)
275
+
276
+
277
+ def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any:
278
+ """Traverses the object name and returns the last (rightmost) python object."""
279
+ if obj_name == '':
280
+ return module
281
+ obj = module
282
+ for part in obj_name.split("."):
283
+ obj = getattr(obj, part)
284
+ return obj
285
+
286
+
287
+ def get_obj_by_name(name: str) -> Any:
288
+ """Finds the python object with the given name."""
289
+ module, obj_name = get_module_from_obj_name(name)
290
+ return get_obj_from_module(module, obj_name)
291
+
292
+
293
+ def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any:
294
+ """Finds the python object with the given name and calls it as a function."""
295
+ assert func_name is not None
296
+ func_obj = get_obj_by_name(func_name)
297
+ assert callable(func_obj)
298
+ return func_obj(*args, **kwargs)
299
+
300
+
301
+ def construct_class_by_name(*args, class_name: str = None, **kwargs) -> Any:
302
+ """Finds the python class with the given name and constructs it with the given arguments."""
303
+ return call_func_by_name(*args, func_name=class_name, **kwargs)
304
+
305
+
306
+ def get_module_dir_by_obj_name(obj_name: str) -> str:
307
+ """Get the directory path of the module containing the given object name."""
308
+ module, _ = get_module_from_obj_name(obj_name)
309
+ return os.path.dirname(inspect.getfile(module))
310
+
311
+
312
+ def is_top_level_function(obj: Any) -> bool:
313
+ """Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'."""
314
+ return callable(obj) and obj.__name__ in sys.modules[obj.__module__].__dict__
315
+
316
+
317
+ def get_top_level_function_name(obj: Any) -> str:
318
+ """Return the fully-qualified name of a top-level function."""
319
+ assert is_top_level_function(obj)
320
+ module = obj.__module__
321
+ if module == '__main__':
322
+ module = os.path.splitext(os.path.basename(sys.modules[module].__file__))[0]
323
+ return module + "." + obj.__name__
324
+
325
+
326
+ # File system helpers
327
+ # ------------------------------------------------------------------------------------------
328
+
329
+ def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]:
330
+ """List all files recursively in a given directory while ignoring given file and directory names.
331
+ Returns list of tuples containing both absolute and relative paths."""
332
+ assert os.path.isdir(dir_path)
333
+ base_name = os.path.basename(os.path.normpath(dir_path))
334
+
335
+ if ignores is None:
336
+ ignores = []
337
+
338
+ result = []
339
+
340
+ for root, dirs, files in os.walk(dir_path, topdown=True):
341
+ for ignore_ in ignores:
342
+ dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)]
343
+
344
+ # dirs need to be edited in-place
345
+ for d in dirs_to_remove:
346
+ dirs.remove(d)
347
+
348
+ files = [f for f in files if not fnmatch.fnmatch(f, ignore_)]
349
+
350
+ absolute_paths = [os.path.join(root, f) for f in files]
351
+ relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths]
352
+
353
+ if add_base_to_relative:
354
+ relative_paths = [os.path.join(base_name, p) for p in relative_paths]
355
+
356
+ assert len(absolute_paths) == len(relative_paths)
357
+ result += zip(absolute_paths, relative_paths)
358
+
359
+ return result
360
+
361
+
362
+ def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None:
363
+ """Takes in a list of tuples of (src, dst) paths and copies files.
364
+ Will create all necessary directories."""
365
+ for file in files:
366
+ target_dir_name = os.path.dirname(file[1])
367
+
368
+ # will create all intermediate-level directories
369
+ if not os.path.exists(target_dir_name):
370
+ os.makedirs(target_dir_name)
371
+
372
+ shutil.copyfile(file[0], file[1])
373
+
374
+
375
+ # URL helpers
376
+ # ------------------------------------------------------------------------------------------
377
+
378
+ def is_url(obj: Any, allow_file_urls: bool = False) -> bool:
379
+ """Determine whether the given object is a valid URL string."""
380
+ if not isinstance(obj, str) or not "://" in obj:
381
+ return False
382
+ if allow_file_urls and obj.startswith('file://'):
383
+ return True
384
+ try:
385
+ res = requests.compat.urlparse(obj)
386
+ if not res.scheme or not res.netloc or not "." in res.netloc:
387
+ return False
388
+ res = requests.compat.urlparse(requests.compat.urljoin(obj, "/"))
389
+ if not res.scheme or not res.netloc or not "." in res.netloc:
390
+ return False
391
+ except:
392
+ return False
393
+ return True
394
+
395
+
396
+ def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False, cache: bool = True) -> Any:
397
+ """Download the given URL and return a binary-mode file object to access the data."""
398
+ assert num_attempts >= 1
399
+ assert not (return_filename and (not cache))
400
+
401
+ # Doesn't look like an URL scheme so interpret it as a local filename.
402
+ if not re.match('^[a-z]+://', url):
403
+ return url if return_filename else open(url, "rb")
404
+
405
+ # Handle file URLs. This code handles unusual file:// patterns that
406
+ # arise on Windows:
407
+ #
408
+ # file:///c:/foo.txt
409
+ #
410
+ # which would translate to a local '/c:/foo.txt' filename that's
411
+ # invalid. Drop the forward slash for such pathnames.
412
+ #
413
+ # If you touch this code path, you should test it on both Linux and
414
+ # Windows.
415
+ #
416
+ # Some internet resources suggest using urllib.request.url2pathname() but
417
+ # but that converts forward slashes to backslashes and this causes
418
+ # its own set of problems.
419
+ if url.startswith('file://'):
420
+ filename = urllib.parse.urlparse(url).path
421
+ if re.match(r'^/[a-zA-Z]:', filename):
422
+ filename = filename[1:]
423
+ return filename if return_filename else open(filename, "rb")
424
+
425
+ assert is_url(url)
426
+
427
+ # Lookup from cache.
428
+ if cache_dir is None:
429
+ cache_dir = make_cache_dir_path('downloads')
430
+
431
+ url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
432
+ if cache:
433
+ cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*"))
434
+ if len(cache_files) == 1:
435
+ filename = cache_files[0]
436
+ return filename if return_filename else open(filename, "rb")
437
+
438
+ # Download.
439
+ url_name = None
440
+ url_data = None
441
+ with requests.Session() as session:
442
+ if verbose:
443
+ print("Downloading %s ..." % url, end="", flush=True)
444
+ for attempts_left in reversed(range(num_attempts)):
445
+ try:
446
+ with session.get(url) as res:
447
+ res.raise_for_status()
448
+ if len(res.content) == 0:
449
+ raise IOError("No data received")
450
+
451
+ if len(res.content) < 8192:
452
+ content_str = res.content.decode("utf-8")
453
+ if "download_warning" in res.headers.get("Set-Cookie", ""):
454
+ links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link]
455
+ if len(links) == 1:
456
+ url = requests.compat.urljoin(url, links[0])
457
+ raise IOError("Google Drive virus checker nag")
458
+ if "Google Drive - Quota exceeded" in content_str:
459
+ raise IOError("Google Drive download quota exceeded -- please try again later")
460
+
461
+ match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", ""))
462
+ url_name = match[1] if match else url
463
+ url_data = res.content
464
+ if verbose:
465
+ print(" done")
466
+ break
467
+ except KeyboardInterrupt:
468
+ raise
469
+ except:
470
+ if not attempts_left:
471
+ if verbose:
472
+ print(" failed")
473
+ raise
474
+ if verbose:
475
+ print(".", end="", flush=True)
476
+
477
+ # Save to cache.
478
+ if cache:
479
+ safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name)
480
+ cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name)
481
+ temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name)
482
+ os.makedirs(cache_dir, exist_ok=True)
483
+ with open(temp_file, "wb") as f:
484
+ f.write(url_data)
485
+ os.replace(temp_file, cache_file) # atomic
486
+ if return_filename:
487
+ return cache_file
488
+
489
+ # Return data as file object.
490
+ assert not return_filename
491
+ return io.BytesIO(url_data)
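The two helpers from dnnlib/util.py that the rest of this commit leans on most are EasyDict and open_url. A minimal sketch of both, with a placeholder URL rather than a real asset from this repo:

import dnnlib

cfg = dnnlib.EasyDict(lr=0.002, betas=(0.0, 0.99))       # a dict with attribute-style access
print(cfg.lr, cfg['betas'])

with dnnlib.util.open_url('https://example.com/model.pkl', cache=True) as f:   # placeholder URL
    blob = f.read()            # bytes, downloaded once and cached under ~/.cache/dnnlib/downloads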
legacy.py ADDED
@@ -0,0 +1,323 @@
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Converting legacy network pickle into the new format."""
10
+
11
+ import click
12
+ import pickle
13
+ import re
14
+ import copy
15
+ import numpy as np
16
+ import torch
17
+ import dnnlib
18
+ from torch_utils import misc
19
+
20
+ #----------------------------------------------------------------------------
21
+
22
+ def load_network_pkl(f, force_fp16=False):
23
+ data = _LegacyUnpickler(f).load()
24
+
25
+ # Legacy TensorFlow pickle => convert.
26
+ if isinstance(data, tuple) and len(data) == 3 and all(isinstance(net, _TFNetworkStub) for net in data):
27
+ tf_G, tf_D, tf_Gs = data
28
+ G = convert_tf_generator(tf_G)
29
+ D = convert_tf_discriminator(tf_D)
30
+ G_ema = convert_tf_generator(tf_Gs)
31
+ data = dict(G=G, D=D, G_ema=G_ema)
32
+
33
+ # Add missing fields.
34
+ if 'training_set_kwargs' not in data:
35
+ data['training_set_kwargs'] = None
36
+ if 'augment_pipe' not in data:
37
+ data['augment_pipe'] = None
38
+
39
+ # Validate contents.
40
+ assert isinstance(data['G'], torch.nn.Module)
41
+ assert isinstance(data['D'], torch.nn.Module)
42
+ assert isinstance(data['G_ema'], torch.nn.Module)
43
+ assert isinstance(data['training_set_kwargs'], (dict, type(None)))
44
+ assert isinstance(data['augment_pipe'], (torch.nn.Module, type(None)))
45
+
46
+ # Force FP16.
47
+ if force_fp16:
48
+ for key in ['G', 'D', 'G_ema']:
49
+ old = data[key]
50
+ kwargs = copy.deepcopy(old.init_kwargs)
51
+ fp16_kwargs = kwargs.get('synthesis_kwargs', kwargs)
52
+ fp16_kwargs.num_fp16_res = 4
53
+ fp16_kwargs.conv_clamp = 256
54
+ if kwargs != old.init_kwargs:
55
+ new = type(old)(**kwargs).eval().requires_grad_(False)
56
+ misc.copy_params_and_buffers(old, new, require_all=True)
57
+ data[key] = new
58
+ return data
59
+
60
+ #----------------------------------------------------------------------------
61
+
62
+ class _TFNetworkStub(dnnlib.EasyDict):
63
+ pass
64
+
65
+ class _LegacyUnpickler(pickle.Unpickler):
66
+ def find_class(self, module, name):
67
+ if module == 'dnnlib.tflib.network' and name == 'Network':
68
+ return _TFNetworkStub
69
+ return super().find_class(module, name)
70
+
71
+ #----------------------------------------------------------------------------
72
+
73
+ def _collect_tf_params(tf_net):
74
+ # pylint: disable=protected-access
75
+ tf_params = dict()
76
+ def recurse(prefix, tf_net):
77
+ for name, value in tf_net.variables:
78
+ tf_params[prefix + name] = value
79
+ for name, comp in tf_net.components.items():
80
+ recurse(prefix + name + '/', comp)
81
+ recurse('', tf_net)
82
+ return tf_params
83
+
84
+ #----------------------------------------------------------------------------
85
+
86
+ def _populate_module_params(module, *patterns):
87
+ for name, tensor in misc.named_params_and_buffers(module):
88
+ found = False
89
+ value = None
90
+ for pattern, value_fn in zip(patterns[0::2], patterns[1::2]):
91
+ match = re.fullmatch(pattern, name)
92
+ if match:
93
+ found = True
94
+ if value_fn is not None:
95
+ value = value_fn(*match.groups())
96
+ break
97
+ try:
98
+ assert found
99
+ if value is not None:
100
+ tensor.copy_(torch.from_numpy(np.array(value)))
101
+ except:
102
+ print(name, list(tensor.shape))
103
+ raise
104
+
105
+ #----------------------------------------------------------------------------
106
+
107
+ def convert_tf_generator(tf_G):
108
+ if tf_G.version < 4:
109
+ raise ValueError('TensorFlow pickle version too low')
110
+
111
+ # Collect kwargs.
112
+ tf_kwargs = tf_G.static_kwargs
113
+ known_kwargs = set()
114
+ def kwarg(tf_name, default=None, none=None):
115
+ known_kwargs.add(tf_name)
116
+ val = tf_kwargs.get(tf_name, default)
117
+ return val if val is not None else none
118
+
119
+ # Convert kwargs.
120
+ from training import networks_stylegan2
121
+ network_class = networks_stylegan2.Generator
122
+ kwargs = dnnlib.EasyDict(
123
+ z_dim = kwarg('latent_size', 512),
124
+ c_dim = kwarg('label_size', 0),
125
+ w_dim = kwarg('dlatent_size', 512),
126
+ img_resolution = kwarg('resolution', 1024),
127
+ img_channels = kwarg('num_channels', 3),
128
+ channel_base = kwarg('fmap_base', 16384) * 2,
129
+ channel_max = kwarg('fmap_max', 512),
130
+ num_fp16_res = kwarg('num_fp16_res', 0),
131
+ conv_clamp = kwarg('conv_clamp', None),
132
+ architecture = kwarg('architecture', 'skip'),
133
+ resample_filter = kwarg('resample_kernel', [1,3,3,1]),
134
+ use_noise = kwarg('use_noise', True),
135
+ activation = kwarg('nonlinearity', 'lrelu'),
136
+ mapping_kwargs = dnnlib.EasyDict(
137
+ num_layers = kwarg('mapping_layers', 8),
138
+ embed_features = kwarg('label_fmaps', None),
139
+ layer_features = kwarg('mapping_fmaps', None),
140
+ activation = kwarg('mapping_nonlinearity', 'lrelu'),
141
+ lr_multiplier = kwarg('mapping_lrmul', 0.01),
142
+ w_avg_beta = kwarg('w_avg_beta', 0.995, none=1),
143
+ ),
144
+ )
145
+
146
+ # Check for unknown kwargs.
147
+ kwarg('truncation_psi')
148
+ kwarg('truncation_cutoff')
149
+ kwarg('style_mixing_prob')
150
+ kwarg('structure')
151
+ kwarg('conditioning')
152
+ kwarg('fused_modconv')
153
+ unknown_kwargs = list(set(tf_kwargs.keys()) - known_kwargs)
154
+ if len(unknown_kwargs) > 0:
155
+ raise ValueError('Unknown TensorFlow kwarg', unknown_kwargs[0])
156
+
157
+ # Collect params.
158
+ tf_params = _collect_tf_params(tf_G)
159
+ for name, value in list(tf_params.items()):
160
+ match = re.fullmatch(r'ToRGB_lod(\d+)/(.*)', name)
161
+ if match:
162
+ r = kwargs.img_resolution // (2 ** int(match.group(1)))
163
+ tf_params[f'{r}x{r}/ToRGB/{match.group(2)}'] = value
164
+ kwargs.synthesis.kwargs.architecture = 'orig'
165
+ #for name, value in tf_params.items(): print(f'{name:<50s}{list(value.shape)}')
166
+
167
+ # Convert params.
168
+ G = network_class(**kwargs).eval().requires_grad_(False)
169
+ # pylint: disable=unnecessary-lambda
170
+ # pylint: disable=f-string-without-interpolation
171
+ _populate_module_params(G,
172
+ r'mapping\.w_avg', lambda: tf_params[f'dlatent_avg'],
173
+ r'mapping\.embed\.weight', lambda: tf_params[f'mapping/LabelEmbed/weight'].transpose(),
174
+ r'mapping\.embed\.bias', lambda: tf_params[f'mapping/LabelEmbed/bias'],
175
+ r'mapping\.fc(\d+)\.weight', lambda i: tf_params[f'mapping/Dense{i}/weight'].transpose(),
176
+ r'mapping\.fc(\d+)\.bias', lambda i: tf_params[f'mapping/Dense{i}/bias'],
177
+ r'synthesis\.b4\.const', lambda: tf_params[f'synthesis/4x4/Const/const'][0],
178
+ r'synthesis\.b4\.conv1\.weight', lambda: tf_params[f'synthesis/4x4/Conv/weight'].transpose(3, 2, 0, 1),
179
+ r'synthesis\.b4\.conv1\.bias', lambda: tf_params[f'synthesis/4x4/Conv/bias'],
180
+ r'synthesis\.b4\.conv1\.noise_const', lambda: tf_params[f'synthesis/noise0'][0, 0],
181
+ r'synthesis\.b4\.conv1\.noise_strength', lambda: tf_params[f'synthesis/4x4/Conv/noise_strength'],
182
+ r'synthesis\.b4\.conv1\.affine\.weight', lambda: tf_params[f'synthesis/4x4/Conv/mod_weight'].transpose(),
183
+ r'synthesis\.b4\.conv1\.affine\.bias', lambda: tf_params[f'synthesis/4x4/Conv/mod_bias'] + 1,
184
+ r'synthesis\.b(\d+)\.conv0\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/weight'][::-1, ::-1].transpose(3, 2, 0, 1),
185
+ r'synthesis\.b(\d+)\.conv0\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/bias'],
186
+ r'synthesis\.b(\d+)\.conv0\.noise_const', lambda r: tf_params[f'synthesis/noise{int(np.log2(int(r)))*2-5}'][0, 0],
187
+ r'synthesis\.b(\d+)\.conv0\.noise_strength', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/noise_strength'],
188
+ r'synthesis\.b(\d+)\.conv0\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/mod_weight'].transpose(),
189
+ r'synthesis\.b(\d+)\.conv0\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/mod_bias'] + 1,
190
+ r'synthesis\.b(\d+)\.conv1\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/weight'].transpose(3, 2, 0, 1),
191
+ r'synthesis\.b(\d+)\.conv1\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/bias'],
192
+ r'synthesis\.b(\d+)\.conv1\.noise_const', lambda r: tf_params[f'synthesis/noise{int(np.log2(int(r)))*2-4}'][0, 0],
193
+ r'synthesis\.b(\d+)\.conv1\.noise_strength', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/noise_strength'],
194
+ r'synthesis\.b(\d+)\.conv1\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/mod_weight'].transpose(),
195
+ r'synthesis\.b(\d+)\.conv1\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/mod_bias'] + 1,
196
+ r'synthesis\.b(\d+)\.torgb\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/weight'].transpose(3, 2, 0, 1),
197
+ r'synthesis\.b(\d+)\.torgb\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/bias'],
198
+ r'synthesis\.b(\d+)\.torgb\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/mod_weight'].transpose(),
199
+ r'synthesis\.b(\d+)\.torgb\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/mod_bias'] + 1,
200
+ r'synthesis\.b(\d+)\.skip\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Skip/weight'][::-1, ::-1].transpose(3, 2, 0, 1),
201
+ r'.*\.resample_filter', None,
202
+ r'.*\.act_filter', None,
203
+ )
204
+ return G
205
+
206
+ #----------------------------------------------------------------------------
207
+
208
+ def convert_tf_discriminator(tf_D):
209
+ if tf_D.version < 4:
210
+ raise ValueError('TensorFlow pickle version too low')
211
+
212
+ # Collect kwargs.
213
+ tf_kwargs = tf_D.static_kwargs
214
+ known_kwargs = set()
215
+ def kwarg(tf_name, default=None):
216
+ known_kwargs.add(tf_name)
217
+ return tf_kwargs.get(tf_name, default)
218
+
219
+ # Convert kwargs.
220
+ kwargs = dnnlib.EasyDict(
221
+ c_dim = kwarg('label_size', 0),
222
+ img_resolution = kwarg('resolution', 1024),
223
+ img_channels = kwarg('num_channels', 3),
224
+ architecture = kwarg('architecture', 'resnet'),
225
+ channel_base = kwarg('fmap_base', 16384) * 2,
226
+ channel_max = kwarg('fmap_max', 512),
227
+ num_fp16_res = kwarg('num_fp16_res', 0),
228
+ conv_clamp = kwarg('conv_clamp', None),
229
+ cmap_dim = kwarg('mapping_fmaps', None),
230
+ block_kwargs = dnnlib.EasyDict(
231
+ activation = kwarg('nonlinearity', 'lrelu'),
232
+ resample_filter = kwarg('resample_kernel', [1,3,3,1]),
233
+ freeze_layers = kwarg('freeze_layers', 0),
234
+ ),
235
+ mapping_kwargs = dnnlib.EasyDict(
236
+ num_layers = kwarg('mapping_layers', 0),
237
+ embed_features = kwarg('mapping_fmaps', None),
238
+ layer_features = kwarg('mapping_fmaps', None),
239
+ activation = kwarg('nonlinearity', 'lrelu'),
240
+ lr_multiplier = kwarg('mapping_lrmul', 0.1),
241
+ ),
242
+ epilogue_kwargs = dnnlib.EasyDict(
243
+ mbstd_group_size = kwarg('mbstd_group_size', None),
244
+ mbstd_num_channels = kwarg('mbstd_num_features', 1),
245
+ activation = kwarg('nonlinearity', 'lrelu'),
246
+ ),
247
+ )
248
+
249
+ # Check for unknown kwargs.
250
+ kwarg('structure')
251
+ kwarg('conditioning')
252
+ unknown_kwargs = list(set(tf_kwargs.keys()) - known_kwargs)
253
+ if len(unknown_kwargs) > 0:
254
+ raise ValueError('Unknown TensorFlow kwarg', unknown_kwargs[0])
255
+
256
+ # Collect params.
257
+ tf_params = _collect_tf_params(tf_D)
258
+ for name, value in list(tf_params.items()):
259
+ match = re.fullmatch(r'FromRGB_lod(\d+)/(.*)', name)
260
+ if match:
261
+ r = kwargs.img_resolution // (2 ** int(match.group(1)))
262
+ tf_params[f'{r}x{r}/FromRGB/{match.group(2)}'] = value
263
+ kwargs.architecture = 'orig'
264
+ #for name, value in tf_params.items(): print(f'{name:<50s}{list(value.shape)}')
265
+
266
+ # Convert params.
267
+ from training import networks_stylegan2
268
+ D = networks_stylegan2.Discriminator(**kwargs).eval().requires_grad_(False)
269
+ # pylint: disable=unnecessary-lambda
270
+ # pylint: disable=f-string-without-interpolation
271
+ _populate_module_params(D,
272
+ r'b(\d+)\.fromrgb\.weight', lambda r: tf_params[f'{r}x{r}/FromRGB/weight'].transpose(3, 2, 0, 1),
273
+ r'b(\d+)\.fromrgb\.bias', lambda r: tf_params[f'{r}x{r}/FromRGB/bias'],
274
+ r'b(\d+)\.conv(\d+)\.weight', lambda r, i: tf_params[f'{r}x{r}/Conv{i}{["","_down"][int(i)]}/weight'].transpose(3, 2, 0, 1),
275
+ r'b(\d+)\.conv(\d+)\.bias', lambda r, i: tf_params[f'{r}x{r}/Conv{i}{["","_down"][int(i)]}/bias'],
276
+ r'b(\d+)\.skip\.weight', lambda r: tf_params[f'{r}x{r}/Skip/weight'].transpose(3, 2, 0, 1),
277
+ r'mapping\.embed\.weight', lambda: tf_params[f'LabelEmbed/weight'].transpose(),
278
+ r'mapping\.embed\.bias', lambda: tf_params[f'LabelEmbed/bias'],
279
+ r'mapping\.fc(\d+)\.weight', lambda i: tf_params[f'Mapping{i}/weight'].transpose(),
280
+ r'mapping\.fc(\d+)\.bias', lambda i: tf_params[f'Mapping{i}/bias'],
281
+ r'b4\.conv\.weight', lambda: tf_params[f'4x4/Conv/weight'].transpose(3, 2, 0, 1),
282
+ r'b4\.conv\.bias', lambda: tf_params[f'4x4/Conv/bias'],
283
+ r'b4\.fc\.weight', lambda: tf_params[f'4x4/Dense0/weight'].transpose(),
284
+ r'b4\.fc\.bias', lambda: tf_params[f'4x4/Dense0/bias'],
285
+ r'b4\.out\.weight', lambda: tf_params[f'Output/weight'].transpose(),
286
+ r'b4\.out\.bias', lambda: tf_params[f'Output/bias'],
287
+ r'.*\.resample_filter', None,
288
+ )
289
+ return D
290
+
291
+ #----------------------------------------------------------------------------
292
+
293
+ @click.command()
294
+ @click.option('--source', help='Input pickle', required=True, metavar='PATH')
295
+ @click.option('--dest', help='Output pickle', required=True, metavar='PATH')
296
+ @click.option('--force-fp16', help='Force the networks to use FP16', type=bool, default=False, metavar='BOOL', show_default=True)
297
+ def convert_network_pickle(source, dest, force_fp16):
298
+ """Convert legacy network pickle into the native PyTorch format.
299
+
300
+ The tool is able to load the main network configurations exported using the TensorFlow version of StyleGAN2 or StyleGAN2-ADA.
301
+ It does not support e.g. StyleGAN2-ADA comparison methods, StyleGAN2 configs A-D, or StyleGAN1 networks.
302
+
303
+ Example:
304
+
305
+ \b
306
+ python legacy.py \\
307
+ --source=https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/stylegan2-cat-config-f.pkl \\
308
+ --dest=stylegan2-cat-config-f.pkl
309
+ """
310
+ print(f'Loading "{source}"...')
311
+ with dnnlib.util.open_url(source) as f:
312
+ data = load_network_pkl(f, force_fp16=force_fp16)
313
+ print(f'Saving "{dest}"...')
314
+ with open(dest, 'wb') as f:
315
+ pickle.dump(data, f)
316
+ print('Done.')
317
+
318
+ #----------------------------------------------------------------------------
319
+
320
+ if __name__ == "__main__":
321
+ convert_network_pickle() # pylint: disable=no-value-for-parameter
322
+
323
+ #----------------------------------------------------------------------------
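Besides the click CLI above, legacy.py is normally used programmatically to load a StyleGAN pickle. A hedged sketch of loading the snapshot added under pretrained_models/ in this commit; the 'G_ema' key comes from load_network_pkl above, while the z_dim attribute and the commented forward call follow the usual StyleGAN generator convention and are assumptions here:

import torch
import dnnlib
import legacy

with dnnlib.util.open_url('pretrained_models/network-snapshot-005000.pkl') as f:
    G = legacy.load_network_pkl(f)['G_ema']        # generator with EMA weights
G = G.eval().requires_grad_(False)

z = torch.randn(1, G.z_dim)                        # one latent code
# img = G(z, None)                                 # c=None for an unconditional model (sketch only)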
models/__init__.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """This package contains modules related to objective functions, optimizations, and network architectures.
2
+
3
+ To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel.
4
+ You need to implement the following five functions:
5
+ -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
6
+ -- <set_input>: unpack data from dataset and apply preprocessing.
7
+ -- <forward>: produce intermediate results.
8
+ -- <optimize_parameters>: calculate loss, gradients, and update network weights.
9
+ -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
10
+
11
+ In the function <__init__>, you need to define four lists:
12
+ -- self.loss_names (str list): specify the training losses that you want to plot and save.
13
+ -- self.model_names (str list): define networks used in our training.
14
+ -- self.visual_names (str list): specify the images that you want to display and save.
15
+ -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an usage.
16
+
17
+ Now you can use the model class by specifying flag '--model dummy'.
18
+ See our template model class 'template_model.py' for more details.
19
+ """
20
+
21
+ import importlib
22
+ from models.base_model import BaseModel
23
+
24
+
25
+ def find_model_using_name(model_name):
26
+ """Import the module "models/[model_name]_model.py".
27
+
28
+ In the file, the class called DatasetNameModel() will
29
+ be instantiated. It has to be a subclass of BaseModel,
30
+ and it is case-insensitive.
31
+ """
32
+ model_filename = "models." + model_name + "_model"
33
+ print(model_filename)
34
+ modellib = importlib.import_module(model_filename)
35
+ model = None
36
+ target_model_name = model_name.replace('_', '') + 'model'
37
+ for name, cls in modellib.__dict__.items():
38
+ if name.lower() == target_model_name.lower() \
39
+ and issubclass(cls, BaseModel):
40
+ model = cls
41
+
42
+ if model is None:
43
+ print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name))
44
+ exit(0)
45
+
46
+ return model
47
+
48
+
49
+ def get_option_setter(model_name):
50
+ """Return the static method <modify_commandline_options> of the model class."""
51
+ model_class = find_model_using_name(model_name)
52
+ return model_class.modify_commandline_options
53
+
54
+
55
+ def create_model(opt):
56
+ """Create a model given the option.
57
+
58
+ This function instantiates the model class specified by opt.model.
59
+ This is the main interface between this package and 'train.py'/'test.py'
60
+
61
+ Example:
62
+ >>> from models import create_model
63
+ >>> model = create_model(opt)
64
+ """
65
+ model = find_model_using_name(opt.model)
66
+ instance = model(opt)
67
+ print("model [%s] was created" % type(instance).__name__)
68
+ return instance
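A minimal sketch of the 'dummy' model described in the package docstring above; the file name, class name, and the 'data' key used in set_input are illustrative assumptions, not part of this repository:

    # models/dummy_model.py (hypothetical file following the naming convention)
    from .base_model import BaseModel

    class DummyModel(BaseModel):
        @staticmethod
        def modify_commandline_options(parser, is_train=True):
            return parser                      # no model-specific options in this sketch

        def __init__(self, opt):
            BaseModel.__init__(self, opt)
            self.loss_names = []               # nothing to plot or save
            self.model_names = []              # no sub-networks to save/load
            self.visual_names = ['data']       # display the input as-is
            self.optimizers = []

        def set_input(self, input):
            self.data = input['data'].to(self.device)   # assumes a 'data' key

        def forward(self):
            pass                               # no intermediate results

        def optimize_parameters(self):
            pass                               # nothing to optimize

With such a file in place, passing --model dummy makes create_model(opt) return a DummyModel instance.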
models/base_model.py ADDED
@@ -0,0 +1,234 @@
1
+ import os
2
+ import torch
3
+ from collections import OrderedDict
4
+ from abc import ABC, abstractmethod
5
+ from . import networks
6
+
7
+
8
+ class BaseModel(ABC):
9
+ """This class is an abstract base class (ABC) for models.
10
+ To create a subclass, you need to implement the following five functions:
11
+ -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
12
+ -- <set_input>: unpack data from dataset and apply preprocessing.
13
+ -- <forward>: produce intermediate results.
14
+ -- <optimize_parameters>: calculate losses, gradients, and update network weights.
15
+ -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
16
+ """
17
+
18
+ def __init__(self, opt):
19
+ """Initialize the BaseModel class.
20
+
21
+ Parameters:
22
+ opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
23
+
24
+ When creating your custom class, you need to implement your own initialization.
25
+ In this function, you should first call <BaseModel.__init__(self, opt)>
26
+ Then, you need to define four lists:
27
+ -- self.loss_names (str list): specify the training losses that you want to plot and save.
28
+ -- self.model_names (str list): define networks used in our training.
29
+ -- self.visual_names (str list): specify the images that you want to display and save.
30
+ -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
31
+ """
32
+ self.opt = opt
33
+ self.gpu_ids = opt.gpu_ids
34
+ self.isTrain = opt.isTrain
35
+ self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu') # get device name: CPU or GPU
36
+ self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) # save all the checkpoints to save_dir
37
+ if opt.preprocess != 'scale_width': # with [scale_width], input images might have different sizes, which hurts the performance of cudnn.benchmark.
38
+ torch.backends.cudnn.benchmark = True
39
+ self.loss_names = []
40
+ self.model_names = []
41
+ self.visual_names = []
42
+ self.optimizers = []
43
+ self.image_paths = []
44
+ self.metric = 0 # used for learning rate policy 'plateau'
45
+
46
+ @staticmethod
47
+ def modify_commandline_options(parser, is_train):
48
+ """Add new model-specific options, and rewrite default values for existing options.
49
+
50
+ Parameters:
51
+ parser -- original option parser
52
+ is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
53
+
54
+ Returns:
55
+ the modified parser.
56
+ """
57
+ return parser
58
+
59
+ @abstractmethod
60
+ def set_input(self, input):
61
+ """Unpack input data from the dataloader and perform necessary pre-processing steps.
62
+
63
+ Parameters:
64
+ input (dict): includes the data itself and its metadata information.
65
+ """
66
+ pass
67
+
68
+ @abstractmethod
69
+ def forward(self):
70
+ """Run forward pass; called by both functions <optimize_parameters> and <test>."""
71
+ pass
72
+
73
+ @abstractmethod
74
+ def optimize_parameters(self):
75
+ """Calculate losses, gradients, and update network weights; called in every training iteration"""
76
+ pass
77
+
78
+ def setup(self, opt):
79
+ """Load and print networks; create schedulers
80
+
81
+ Parameters:
82
+ opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
83
+ """
84
+ if self.isTrain:
85
+ self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers]
86
+ if not self.isTrain or opt.continue_train:
87
+ load_suffix = 'iter_%d' % opt.load_iter if opt.load_iter > 0 else opt.epoch
88
+ self.load_networks(load_suffix)
89
+ self.print_networks(opt.verbose)
90
+
91
+ def eval(self):
92
+ """Make models eval mode during test time"""
93
+ for name in self.model_names:
94
+ if isinstance(name, str):
95
+ net = getattr(self, 'net' + name)
96
+ net.eval()
97
+
98
+ def test(self):
99
+ """Forward function used in test time.
100
+
101
+ This function wraps the <forward> function in no_grad() so intermediate results are not kept for backprop.
102
+ It also calls <compute_visuals> to produce additional visualization results
103
+ """
104
+ with torch.no_grad():
105
+ self.forward()
106
+ self.compute_visuals()
107
+
108
+ def compute_visuals(self):
109
+ """Calculate additional output images for visdom and HTML visualization"""
110
+ pass
111
+
112
+ def get_image_paths(self):
113
+ """ Return image paths that are used to load current data"""
114
+ return self.image_paths
115
+
116
+ def update_learning_rate(self):
117
+ """Update learning rates for all the networks; called at the end of every epoch"""
118
+ old_lr = self.optimizers[0].param_groups[0]['lr']
119
+ for scheduler in self.schedulers:
120
+ if self.opt.lr_policy == 'plateau':
121
+ scheduler.step(self.metric)
122
+ else:
123
+ scheduler.step()
124
+
125
+ lr = self.optimizers[0].param_groups[0]['lr']
126
+ print('learning rate %.7f -> %.7f' % (old_lr, lr))
127
+
128
+ def get_current_visuals(self):
129
+ """Return visualization images. train.py will display these images with visdom, and save the images to a HTML"""
130
+ visual_ret = OrderedDict()
131
+ for name in self.visual_names:
132
+ if isinstance(name, str):
133
+ visual_ret[name] = getattr(self, name)
134
+ return visual_ret
135
+
136
+ def get_current_losses(self):
137
+ """Return traning losses / errors. train.py will print out these errors on console, and save them to a file"""
138
+ errors_ret = OrderedDict()
139
+ for name in self.loss_names:
140
+ if isinstance(name, str):
141
+ errors_ret[name] = float(getattr(self, 'loss_' + name)) # float(...) works for both scalar tensor and float number
142
+ return errors_ret
143
+
144
+ def save_networks(self, epoch):
145
+ """Save all the networks to the disk.
146
+
147
+ Parameters:
148
+ epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
149
+ """
150
+ for name in self.model_names:
151
+ if isinstance(name, str):
152
+ save_filename = '%s_net_%s.pth' % (epoch, name)
153
+ save_path = os.path.join(self.save_dir, save_filename)
154
+ net = getattr(self, 'net' + name)
155
+
156
+ if len(self.gpu_ids) > 0 and torch.cuda.is_available():
157
+ if hasattr(net, 'module'):
158
+ torch.save(net.module.cpu().state_dict(), save_path)
159
+ net.cuda(self.gpu_ids[0])
160
+ else:
161
+ torch.save(net.cpu().state_dict(), save_path)
162
+ net.cuda(self.gpu_ids[0])
163
+ else:
164
+ torch.save(net.cpu().state_dict(), save_path)
165
+
166
+ def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
167
+ """Fix InstanceNorm checkpoints incompatibility (prior to 0.4)"""
168
+ key = keys[i]
169
+ if i + 1 == len(keys): # at the end, pointing to a parameter/buffer
170
+ if module.__class__.__name__.startswith('InstanceNorm') and \
171
+ (key == 'running_mean' or key == 'running_var'):
172
+ if getattr(module, key) is None:
173
+ state_dict.pop('.'.join(keys))
174
+ if module.__class__.__name__.startswith('InstanceNorm') and \
175
+ (key == 'num_batches_tracked'):
176
+ state_dict.pop('.'.join(keys))
177
+ else:
178
+ self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)
179
+
180
+ def load_networks(self, epoch):
181
+ """Load all the networks from the disk.
182
+
183
+ Parameters:
184
+ epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
185
+ """
186
+ for name in self.model_names:
187
+ if isinstance(name, str):
188
+ load_filename = '%s_net_%s.pth' % (epoch, name)
189
+ load_path = os.path.join(self.save_dir, load_filename)
190
+ net = getattr(self, 'net' + name)
191
+ if isinstance(net, torch.nn.DataParallel):
192
+ net = net.module
193
+ print('loading the model from %s' % load_path)
194
+ # if you are using PyTorch newer than 0.4 (e.g., built from
195
+ # GitHub source), you can remove str() on self.device
196
+ state_dict = torch.load(load_path, map_location=str(self.device))
197
+ if hasattr(state_dict, '_metadata'):
198
+ del state_dict._metadata
199
+
200
+ # patch InstanceNorm checkpoints prior to 0.4
201
+ for key in list(state_dict.keys()): # need to copy keys here because we mutate in loop
202
+ self.__patch_instance_norm_state_dict(state_dict, net, key.split('.'))
203
+ net.load_state_dict(state_dict)
204
+
205
+ def print_networks(self, verbose):
206
+ """Print the total number of parameters in the network and (if verbose) network architecture
207
+
208
+ Parameters:
209
+ verbose (bool) -- if verbose: print the network architecture
210
+ """
211
+ print('---------- Networks initialized -------------')
212
+ for name in self.model_names:
213
+ if isinstance(name, str):
214
+ net = getattr(self, 'net' + name)
215
+ num_params = 0
216
+ for param in net.parameters():
217
+ num_params += param.numel()
218
+ if verbose:
219
+ print(net)
220
+ print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6))
221
+ print('-----------------------------------------------')
222
+
223
+ def set_requires_grad(self, nets, requires_grad=False):
224
+ """Set requies_grad=Fasle for all the networks to avoid unnecessary computations
225
+ Parameters:
226
+ nets (network list) -- a list of networks
227
+ requires_grad (bool) -- whether the networks require gradients or not
228
+ """
229
+ if not isinstance(nets, list):
230
+ nets = [nets]
231
+ for net in nets:
232
+ if net is not None:
233
+ for param in net.parameters():
234
+ param.requires_grad = requires_grad
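A sketch of how a training script typically drives the BaseModel API above; 'opt' and the dataset iterable are assumed to come from the options and data packages:

    from models import create_model

    def train(opt, dataset, n_epochs=1):
        model = create_model(opt)              # instantiates the class named by opt.model
        model.setup(opt)                       # schedulers, optional checkpoint loading, printout
        for epoch in range(n_epochs):
            for data in dataset:
                model.set_input(data)          # unpack the dict yielded by the dataloader
                model.optimize_parameters()    # forward, losses, backward, optimizer steps
            print(model.get_current_losses())  # OrderedDict of float loss values
            model.update_learning_rate()       # step every scheduler once per epoch
            model.save_networks('latest')      # writes latest_net_<name>.pth into save_dir
        return model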
models/diy_networks.py ADDED
@@ -0,0 +1,918 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch import Tensor
4
+ from types import FunctionType
5
+ from typing import Type, Any, Callable, Union, List, Optional
6
+
7
+ def _log_api_usage_once(obj: Any) -> None:
8
+ if not obj.__module__.startswith("torchvision"):
9
+ return
10
+ name = obj.__class__.__name__
11
+ if isinstance(obj, FunctionType):
12
+ name = obj.__name__
13
+ torch._C._log_api_usage_once(f"{obj.__module__}.{name}")
14
+
15
+ def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
16
+ """3x3 convolution with padding"""
17
+ return nn.Conv2d(
18
+ in_planes,
19
+ out_planes,
20
+ kernel_size=3,
21
+ stride=stride,
22
+ padding=dilation,
23
+ groups=groups,
24
+ bias=False,
25
+ dilation=dilation,
26
+ )
27
+
28
+
29
+ def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
30
+ """1x1 convolution"""
31
+ return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
32
+
33
+ class Bottleneck(nn.Module):
34
+ # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
35
+ # while original implementation places the stride at the first 1x1 convolution(self.conv1)
36
+ # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
37
+ # This variant is also known as ResNet V1.5 and improves accuracy according to
38
+ # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
39
+
40
+ expansion: int = 4
41
+
42
+ def __init__(
43
+ self,
44
+ inplanes: int,
45
+ planes: int,
46
+ stride: int = 1,
47
+ downsample: Optional[nn.Module] = None,
48
+ groups: int = 1,
49
+ base_width: int = 64,
50
+ dilation: int = 1,
51
+ norm_layer: Optional[Callable[..., nn.Module]] = None,
52
+ ) -> None:
53
+ super().__init__()
54
+ if norm_layer is None:
55
+ norm_layer = nn.BatchNorm2d
56
+ width = int(planes * (base_width / 64.0)) * groups
57
+ # Both self.conv2 and self.downsample layers downsample the input when stride != 1
58
+ self.conv1 = conv1x1(inplanes, width)
59
+ self.bn1 = norm_layer(width)
60
+ self.conv2 = conv3x3(width, width, stride, groups, dilation)
61
+ self.bn2 = norm_layer(width)
62
+ self.conv3 = conv1x1(width, planes * self.expansion)
63
+ self.bn3 = norm_layer(planes * self.expansion)
64
+ self.relu = nn.ReLU(inplace=True)
65
+ self.downsample = downsample
66
+ self.stride = stride
67
+
68
+ def forward(self, x: Tensor) -> Tensor:
69
+ identity = x
70
+
71
+ out = self.conv1(x)
72
+ out = self.bn1(out)
73
+ out = self.relu(out)
74
+
75
+ out = self.conv2(out)
76
+ out = self.bn2(out)
77
+ out = self.relu(out)
78
+
79
+ out = self.conv3(out)
80
+ out = self.bn3(out)
81
+
82
+ if self.downsample is not None:
83
+ identity = self.downsample(x)
84
+
85
+ out += identity
86
+ out = self.relu(out)
87
+
88
+ return out
89
+
90
+ class BasicBlock(nn.Module):
91
+ expansion: int = 1
92
+
93
+ def __init__(
94
+ self,
95
+ inplanes: int,
96
+ planes: int,
97
+ stride: int = 1,
98
+ downsample: Optional[nn.Module] = None,
99
+ groups: int = 1,
100
+ base_width: int = 64,
101
+ dilation: int = 1,
102
+ norm_layer: Optional[Callable[..., nn.Module]] = None,
103
+ ) -> None:
104
+ super().__init__()
105
+ if norm_layer is None:
106
+ norm_layer = nn.BatchNorm2d
107
+ if groups != 1 or base_width != 64:
108
+ raise ValueError("BasicBlock only supports groups=1 and base_width=64")
109
+ if dilation > 1:
110
+ raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
111
+ # Both self.conv1 and self.downsample layers downsample the input when stride != 1
112
+ self.conv1 = conv3x3(inplanes, planes, stride)
113
+ self.bn1 = norm_layer(planes)
114
+ self.relu = nn.ReLU(inplace=True)
115
+ self.conv2 = conv3x3(planes, planes)
116
+ self.bn2 = norm_layer(planes)
117
+ self.downsample = downsample
118
+ self.stride = stride
119
+
120
+ def forward(self, x: Tensor) -> Tensor:
121
+ identity = x
122
+
123
+ out = self.conv1(x)
124
+ out = self.bn1(out)
125
+ out = self.relu(out)
126
+
127
+ out = self.conv2(out)
128
+ out = self.bn2(out)
129
+
130
+ if self.downsample is not None:
131
+ identity = self.downsample(x)
132
+
133
+ out += identity
134
+ out = self.relu(out)
135
+
136
+ return out
137
+
138
+
139
+ class ResPoseNet(nn.Module):
140
+ def __init__(
141
+ self,
142
+ block: Type[Union[BasicBlock, Bottleneck]],
143
+ zero_init_residual: bool = False,
144
+ groups: int = 1,
145
+ width_per_group: int = 64,
146
+ num_point: int = 12,
147
+ replace_stride_with_dilation: Optional[List[bool]] = None,
148
+ norm_layer: Optional[Callable[..., nn.Module]] = None,
149
+ ) -> None:
150
+ super().__init__()
151
+ _log_api_usage_once(self)
152
+
153
+ block.expansion = 1
154
+
155
+ if norm_layer is None:
156
+ norm_layer = nn.BatchNorm2d
157
+ self._norm_layer = norm_layer
158
+
159
+ self.inplanes = 98
160
+ self.dilation = 1
161
+ if replace_stride_with_dilation is None:
162
+ # each element in the tuple indicates if we should replace
163
+ # the 2x2 stride with a dilated convolution instead
164
+ replace_stride_with_dilation = [False, False, False]
165
+ if len(replace_stride_with_dilation) != 3:
166
+ raise ValueError(
167
+ "replace_stride_with_dilation should be None "
168
+ f"or a 3-element tuple, got {replace_stride_with_dilation}"
169
+ )
170
+ self.groups = groups
171
+ self.base_width = width_per_group
172
+ self.layer1 = self._make_layer(block, 98, 3, stride=2)
173
+ self.layer2 = self._make_layer(block, 49, 3, stride=2)
174
+ self.layer3 = self._make_layer(block, 1, 3, stride=2)
175
+
176
+ self.layer4 = self._make_layer(block, 32, 3, stride=2)
177
+ self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
178
+ self.fc = nn.Linear(32 * block.expansion, num_point * 2)
179
+
180
+ # self.layer4_1 = self._make_layer(block, 16, 3, stride=2)
181
+ # self.layer4_2 = self._make_layer(block, 16, 3, stride=2)
182
+ # self.layer4_3 = self._make_layer(block, 16, 3, stride=2)
183
+ # self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
184
+ # self.fc_1 = nn.Linear(16 * block.expansion, 2) # move
185
+ # self.fc_2 = nn.Linear(16 * block.expansion, 3) # pose
186
+ # self.fc_3 = nn.Linear(16 * block.expansion, 11) # attributes
187
+
188
+ for m in self.modules():
189
+ if isinstance(m, nn.Conv2d):
190
+ nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
191
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
192
+ nn.init.constant_(m.weight, 1)
193
+ nn.init.constant_(m.bias, 0)
194
+
195
+ # Zero-initialize the last BN in each residual branch,
196
+ # so that the residual branch starts with zeros, and each residual block behaves like an identity.
197
+ # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
198
+ if zero_init_residual:
199
+ for m in self.modules():
200
+ if isinstance(m, Bottleneck):
201
+ nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type]
202
+ elif isinstance(m, BasicBlock):
203
+ nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type]
204
+
205
+ def _make_layer(
206
+ self,
207
+ block: Type[Union[BasicBlock, Bottleneck]],
208
+ planes: int,
209
+ blocks: int,
210
+ stride: int = 1,
211
+ dilate: bool = False,
212
+ ) -> nn.Sequential:
213
+ norm_layer = self._norm_layer
214
+ downsample = None
215
+ previous_dilation = self.dilation
216
+ if dilate:
217
+ self.dilation *= stride
218
+ stride = 1
219
+ if stride != 1 or self.inplanes != planes * block.expansion:
220
+ downsample = nn.Sequential(
221
+ conv1x1(self.inplanes, planes * block.expansion, stride),
222
+ norm_layer(planes * block.expansion),
223
+ )
224
+
225
+ layers = []
226
+ layers.append(
227
+ block(
228
+ self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation, norm_layer
229
+ )
230
+ )
231
+ self.inplanes = planes * block.expansion
232
+ for _ in range(1, blocks):
233
+ layers.append(
234
+ block(
235
+ self.inplanes,
236
+ planes,
237
+ groups=self.groups,
238
+ base_width=self.base_width,
239
+ dilation=self.dilation,
240
+ norm_layer=norm_layer,
241
+ )
242
+ )
243
+
244
+ return nn.Sequential(*layers)
245
+
246
+ def _forward_impl(self, x: Tensor) -> Tensor:
247
+ # See note [TorchScript super()]
248
+ x = self.layer1(x)
249
+ x = self.layer2(x)
250
+ x = self.layer3(x)
251
+
252
+ x = self.layer4(x)
253
+ x = self.avgpool(x)
254
+ x = torch.flatten(x, 1)
255
+ x = torch.sigmoid(self.fc(x))
256
+
257
+ return x
258
+
259
+ def _forward_feature(self, x: Tensor) -> Tensor:
260
+ # See note [TorchScript super()]
261
+ x = self.layer1(x)
262
+ x = self.layer2(x)
263
+ x = self.layer3(x)
264
+ return x
265
+
266
+ def _forward_trans(self, x: Tensor) -> Tensor:
267
+ # See note [TorchScript super()]
268
+ x = self.layer4(x)
269
+ x = self.avgpool(x)
270
+ x = torch.flatten(x, 1)
271
+ x = torch.sigmoid(self.fc(x))
272
+
273
+ return x
274
+
275
+ # def _forward_impl(self, x: Tensor) -> Tensor:
276
+ # # See note [TorchScript super()]
277
+ # x = self.layer1(x)
278
+ # x = self.layer2(x)
279
+ # x = self.layer3(x)
280
+ #
281
+ # x_1 = self.layer4_1(x)
282
+ # x_1 = self.avgpool(x_1)
283
+ # x_1 = torch.flatten(x_1, 1)
284
+ # x_1 = self.fc_1(x_1)
285
+ #
286
+ # x_2 = self.layer4_2(x)
287
+ # x_2 = self.avgpool(x_2)
288
+ # x_2 = torch.flatten(x_2, 1)
289
+ # x_2 = self.fc_2(x_2)
290
+ #
291
+ # x_3 = self.layer4_3(x)
292
+ # x_3 = self.avgpool(x_3)
293
+ # x_3 = torch.flatten(x_3, 1)
294
+ # x_3 = self.fc_3(x_3)
295
+ #
296
+ # return x_1, x_2, x_3
297
+ #
298
+ # def _forward_feature(self, x: Tensor) -> Tensor:
299
+ # # See note [TorchScript super()]
300
+ # x = self.layer1(x)
301
+ # x = self.layer2(x)
302
+ # x = self.layer3(x)
303
+ # return x
304
+ #
305
+ # def _forward_trans(self, x: Tensor) -> Tensor:
306
+ # # See note [TorchScript super()]
307
+ # x_1 = self.layer4_1(x)
308
+ # x_1 = self.avgpool(x_1)
309
+ # x_1 = torch.flatten(x_1, 1)
310
+ # x_1 = self.fc_1(x_1)
311
+ #
312
+ # x_2 = self.layer4_2(x)
313
+ # x_2 = self.avgpool(x_2)
314
+ # x_2 = torch.flatten(x_2, 1)
315
+ # x_2 = self.fc_2(x_2)
316
+ #
317
+ # x_3 = self.layer4_3(x)
318
+ # x_3 = self.avgpool(x_3)
319
+ # x_3 = torch.flatten(x_3, 1)
320
+ # x_3 = self.fc_3(x_3)
321
+ #
322
+ # return x_1, x_2, x_3
323
+ #
324
+ def forward(self, x: Tensor, mode: int = 0) -> Tensor:
325
+ if mode == 0:
326
+ return self._forward_impl(x)
327
+ elif mode == 1:
328
+ return self._forward_feature(x)
329
+ elif mode == 2:
330
+ return self._forward_trans(x)
331
+
332
+ class NormResPoseNet(nn.Module):
333
+ def __init__(
334
+ self,
335
+ block: Type[Union[BasicBlock, Bottleneck]],
336
+ zero_init_residual: bool = False,
337
+ groups: int = 1,
338
+ width_per_group: int = 64,
339
+ num_point: int = 12,
340
+ replace_stride_with_dilation: Optional[List[bool]] = None,
341
+ norm_layer: Optional[Callable[..., nn.Module]] = None,
342
+ ) -> None:
343
+ super().__init__()
344
+ _log_api_usage_once(self)
345
+
346
+ block.expansion = 1
347
+
348
+ if norm_layer is None:
349
+ norm_layer = nn.BatchNorm2d
350
+ self._norm_layer = norm_layer
351
+
352
+ self.inplanes = 98
353
+ self.dilation = 1
354
+ if replace_stride_with_dilation is None:
355
+ # each element in the tuple indicates if we should replace
356
+ # the 2x2 stride with a dilated convolution instead
357
+ replace_stride_with_dilation = [False, False, False]
358
+ if len(replace_stride_with_dilation) != 3:
359
+ raise ValueError(
360
+ "replace_stride_with_dilation should be None "
361
+ f"or a 3-element tuple, got {replace_stride_with_dilation}"
362
+ )
363
+ self.groups = groups
364
+ self.base_width = width_per_group
365
+ self.layer1 = self._make_layer(block, 98, 3, stride=2)
366
+ self.layer2 = self._make_layer(block, 49, 3, stride=2)
367
+ self.layer3 = self._make_layer(block, 1, 3, stride=2)
368
+
369
+ self.layer4 = self._make_layer(block, 32, 3, stride=2)
370
+ self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
371
+ self.fc = nn.Linear(32 * block.expansion, num_point * 2)
372
+
373
+ for m in self.modules():
374
+ if isinstance(m, nn.Conv2d):
375
+ nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
376
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
377
+ nn.init.constant_(m.weight, 1)
378
+ nn.init.constant_(m.bias, 0)
379
+
380
+ # Zero-initialize the last BN in each residual branch,
381
+ # so that the residual branch starts with zeros, and each residual block behaves like an identity.
382
+ # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
383
+ if zero_init_residual:
384
+ for m in self.modules():
385
+ if isinstance(m, Bottleneck):
386
+ nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type]
387
+ elif isinstance(m, BasicBlock):
388
+ nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type]
389
+
390
+ def _make_layer(
391
+ self,
392
+ block: Type[Union[BasicBlock, Bottleneck]],
393
+ planes: int,
394
+ blocks: int,
395
+ stride: int = 1,
396
+ dilate: bool = False,
397
+ ) -> nn.Sequential:
398
+ norm_layer = self._norm_layer
399
+ downsample = None
400
+ previous_dilation = self.dilation
401
+ if dilate:
402
+ self.dilation *= stride
403
+ stride = 1
404
+ if stride != 1 or self.inplanes != planes * block.expansion:
405
+ downsample = nn.Sequential(
406
+ conv1x1(self.inplanes, planes * block.expansion, stride),
407
+ norm_layer(planes * block.expansion),
408
+ )
409
+
410
+ layers = []
411
+ layers.append(
412
+ block(
413
+ self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation, norm_layer
414
+ )
415
+ )
416
+ self.inplanes = planes * block.expansion
417
+ for _ in range(1, blocks):
418
+ layers.append(
419
+ block(
420
+ self.inplanes,
421
+ planes,
422
+ groups=self.groups,
423
+ base_width=self.base_width,
424
+ dilation=self.dilation,
425
+ norm_layer=norm_layer,
426
+ )
427
+ )
428
+
429
+ return nn.Sequential(*layers)
430
+
431
+ def _forward_impl(self, x: Tensor) -> Tensor:
432
+ # See note [TorchScript super()]
433
+ x = self.layer1(x)
434
+ x = self.layer2(x)
435
+ x = torch.sigmoid(self.layer3(x))
436
+
437
+ x = self.layer4(x)
438
+ x = self.avgpool(x)
439
+ x = torch.flatten(x, 1)
440
+ x = torch.sigmoid(self.fc(x))
441
+
442
+ return x
443
+
444
+ def _forward_feature(self, x: Tensor) -> Tensor:
445
+ # See note [TorchScript super()]
446
+ x = self.layer1(x)
447
+ x = self.layer2(x)
448
+ x = torch.sigmoid(self.layer3(x))
449
+ return x
450
+
451
+ def _forward_trans(self, x: Tensor) -> Tensor:
452
+ # See note [TorchScript super()]
453
+ x = self.layer4(x)
454
+ x = self.avgpool(x)
455
+ x = torch.flatten(x, 1)
456
+ x = torch.sigmoid(self.fc(x))
457
+
458
+ return x
459
+
460
+ def forward(self, x: Tensor, mode: int = 0) -> Tensor:
461
+ if mode == 0:
462
+ return self._forward_impl(x)
463
+ elif mode == 1:
464
+ return self._forward_feature(x)
465
+ elif mode == 2:
466
+ return self._forward_trans(x)
467
+
468
+
469
+ class ResPoseWNet(nn.Module):
470
+ def __init__(
471
+ self,
472
+ block: Type[Union[BasicBlock, Bottleneck]],
473
+ layers: List[int],
474
+ num_classes: int = 1000,
475
+ zero_init_residual: bool = False,
476
+ groups: int = 1,
477
+ width_per_group: int = 64,
478
+ num_point: int = 12,
479
+ replace_stride_with_dilation: Optional[List[bool]] = None,
480
+ norm_layer: Optional[Callable[..., nn.Module]] = None,
481
+ ) -> None:
482
+ super().__init__()
483
+ _log_api_usage_once(self)
484
+
485
+ # from .attention_networks import Self_Attn
486
+ # self.attention_layer = Self_Attn(64, 'relu')
487
+
488
+ if norm_layer is None:
489
+ norm_layer = nn.BatchNorm2d
490
+ self._norm_layer = norm_layer
491
+
492
+ self.inplanes = 3
493
+ self.dilation = 1
494
+ block.expansion = 1
495
+ if replace_stride_with_dilation is None:
496
+ # each element in the tuple indicates if we should replace
497
+ # the 2x2 stride with a dilated convolution instead
498
+ replace_stride_with_dilation = [False, False, False]
499
+ if len(replace_stride_with_dilation) != 3:
500
+ raise ValueError(
501
+ "replace_stride_with_dilation should be None "
502
+ f"or a 3-element tuple, got {replace_stride_with_dilation}"
503
+ )
504
+ self.groups = groups
505
+ self.base_width = width_per_group
506
+ self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
507
+ self.bn1 = norm_layer(self.inplanes)
508
+ self.relu = nn.ReLU(inplace=True)
509
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
510
+ self.layer1 = self._make_layer(block, 64, layers[0])
511
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0])
512
+ self.layer3 = self._make_layer(block, 64, layers[2], stride=2, dilate=replace_stride_with_dilation[1])
513
+ self.layer4 = self._make_layer(block, 1, layers[3], stride=2, dilate=replace_stride_with_dilation[2])
514
+ # self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
515
+ # self.fc = nn.Linear(512 * block.expansion, num_classes)
516
+
517
+ self.layer5 = self._make_layer(block, 32, 3, stride=2)
518
+ self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
519
+ self.fc = nn.Linear(32 * block.expansion, num_point * 2)
520
+
521
+ for m in self.modules():
522
+ if isinstance(m, nn.Conv2d):
523
+ nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
524
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
525
+ nn.init.constant_(m.weight, 1)
526
+ nn.init.constant_(m.bias, 0)
527
+
528
+ # Zero-initialize the last BN in each residual branch,
529
+ # so that the residual branch starts with zeros, and each residual block behaves like an identity.
530
+ # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
531
+ if zero_init_residual:
532
+ for m in self.modules():
533
+ if isinstance(m, Bottleneck):
534
+ nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type]
535
+ elif isinstance(m, BasicBlock):
536
+ nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type]
537
+
538
+ def _make_layer(
539
+ self,
540
+ block: Type[Union[BasicBlock, Bottleneck]],
541
+ planes: int,
542
+ blocks: int,
543
+ stride: int = 1,
544
+ dilate: bool = False,
545
+ ) -> nn.Sequential:
546
+ norm_layer = self._norm_layer
547
+ downsample = None
548
+ previous_dilation = self.dilation
549
+ if dilate:
550
+ self.dilation *= stride
551
+ stride = 1
552
+ if stride != 1 or self.inplanes != planes * block.expansion:
553
+ downsample = nn.Sequential(
554
+ conv1x1(self.inplanes, planes * block.expansion, stride),
555
+ norm_layer(planes * block.expansion),
556
+ )
557
+
558
+ layers = []
559
+ layers.append(
560
+ block(
561
+ self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation, norm_layer
562
+ )
563
+ )
564
+ self.inplanes = planes * block.expansion
565
+ for _ in range(1, blocks):
566
+ layers.append(
567
+ block(
568
+ self.inplanes,
569
+ planes,
570
+ groups=self.groups,
571
+ base_width=self.base_width,
572
+ dilation=self.dilation,
573
+ norm_layer=norm_layer,
574
+ )
575
+ )
576
+
577
+ return nn.Sequential(*layers)
578
+
579
+ def _forward_impl(self, x: Tensor) -> Tensor:
580
+ # See note [TorchScript super()]
581
+ x = self.conv1(x)
582
+ x = self.bn1(x)
583
+ x = self.relu(x)
584
+ x = self.maxpool(x)
585
+
586
+ x = self.layer1(x)
587
+ x = self.layer2(x)
588
+ x = self.layer3(x)
589
+ x = torch.tanh(self.layer4(x))
590
+
591
+ x = self.layer5(x)
592
+ x = self.avgpool(x)
593
+ x = torch.flatten(x, 1)
594
+ x = torch.sigmoid(self.fc(x))
595
+
596
+ return x
597
+
598
+ def _forward_feature(self, x: Tensor) -> Tensor:
599
+ # See note [TorchScript super()]
600
+ x = self.conv1(x)
601
+ x = self.bn1(x)
602
+ x = self.relu(x)
603
+ x = self.maxpool(x)
604
+
605
+ x = self.layer1(x)
606
+ x = self.layer2(x)
607
+ x = self.layer3(x)
608
+ x = torch.tanh(self.layer4(x))
609
+
610
+ return x
611
+
612
+ def _forward_trans(self, x: Tensor) -> Tensor:
613
+ # See note [TorchScript super()]
614
+ x = self.layer5(x)
615
+ x = self.avgpool(x)
616
+ x = torch.flatten(x, 1)
617
+ x = torch.sigmoid(self.fc(x))
618
+
619
+ return x
620
+
621
+ # def _forward_impl(self, x: Tensor) -> Tensor:
622
+ # # See note [TorchScript super()]
623
+ # x = self.layer1(x)
624
+ # x = self.layer2(x)
625
+ # x = self.layer3(x)
626
+ #
627
+ # x_1 = self.layer4_1(x)
628
+ # x_1 = self.avgpool(x_1)
629
+ # x_1 = torch.flatten(x_1, 1)
630
+ # x_1 = self.fc_1(x_1)
631
+ #
632
+ # x_2 = self.layer4_2(x)
633
+ # x_2 = self.avgpool(x_2)
634
+ # x_2 = torch.flatten(x_2, 1)
635
+ # x_2 = self.fc_2(x_2)
636
+ #
637
+ # x_3 = self.layer4_3(x)
638
+ # x_3 = self.avgpool(x_3)
639
+ # x_3 = torch.flatten(x_3, 1)
640
+ # x_3 = self.fc_3(x_3)
641
+ #
642
+ # return x_1, x_2, x_3
643
+ #
644
+ # def _forward_feature(self, x: Tensor) -> Tensor:
645
+ # # See note [TorchScript super()]
646
+ # x = self.layer1(x)
647
+ # x = self.layer2(x)
648
+ # x = self.layer3(x)
649
+ # return x
650
+ #
651
+ # def _forward_trans(self, x: Tensor) -> Tensor:
652
+ # # See note [TorchScript super()]
653
+ # x_1 = self.layer4_1(x)
654
+ # x_1 = self.avgpool(x_1)
655
+ # x_1 = torch.flatten(x_1, 1)
656
+ # x_1 = self.fc_1(x_1)
657
+ #
658
+ # x_2 = self.layer4_2(x)
659
+ # x_2 = self.avgpool(x_2)
660
+ # x_2 = torch.flatten(x_2, 1)
661
+ # x_2 = self.fc_2(x_2)
662
+ #
663
+ # x_3 = self.layer4_3(x)
664
+ # x_3 = self.avgpool(x_3)
665
+ # x_3 = torch.flatten(x_3, 1)
666
+ # x_3 = self.fc_3(x_3)
667
+ #
668
+ # return x_1, x_2, x_3
669
+ #
670
+ def forward(self, x: Tensor, mode: int = 0) -> Tensor:
671
+ if mode == 0:
672
+ return self._forward_impl(x)
673
+ elif mode == 1:
674
+ return self._forward_feature(x)
675
+ elif mode == 2:
676
+ return self._forward_trans(x)
677
+
678
+ class ResPose4Net(nn.Module):
679
+ def __init__(
680
+ self,
681
+ block: Type[Union[BasicBlock, Bottleneck]],
682
+ zero_init_residual: bool = False,
683
+ groups: int = 1,
684
+ width_per_group: int = 64,
685
+ num_point: int = 12,
686
+ replace_stride_with_dilation: Optional[List[bool]] = None,
687
+ norm_layer: Optional[Callable[..., nn.Module]] = None,
688
+ ) -> None:
689
+ super().__init__()
690
+ _log_api_usage_once(self)
691
+
692
+ block.expansion = 1
693
+
694
+ if norm_layer is None:
695
+ norm_layer = nn.BatchNorm2d
696
+ self._norm_layer = norm_layer
697
+
698
+ self.inplanes = 98
699
+ self.dilation = 1
700
+ if replace_stride_with_dilation is None:
701
+ # each element in the tuple indicates if we should replace
702
+ # the 2x2 stride with a dilated convolution instead
703
+ replace_stride_with_dilation = [False, False, False]
704
+ if len(replace_stride_with_dilation) != 3:
705
+ raise ValueError(
706
+ "replace_stride_with_dilation should be None "
707
+ f"or a 3-element tuple, got {replace_stride_with_dilation}"
708
+ )
709
+ self.groups = groups
710
+ self.base_width = width_per_group
711
+ self.layer1 = self._make_layer(block, 98, 3, stride=2)
712
+ self.layer2 = self._make_layer(block, 49, 3, stride=2)
713
+ self.layer3 = self._make_layer(block, 7, 3, stride=2)
714
+ self.layer4 = self._make_layer(block, 1, 3, stride=2)
715
+
716
+ self.layer5 = self._make_layer(block, 32, 3, stride=2)
717
+ self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
718
+ self.fc = nn.Linear(32 * block.expansion, num_point * 2)
719
+
720
+ # self.layer4_1 = self._make_layer(block, 16, 3, stride=2)
721
+ # self.layer4_2 = self._make_layer(block, 16, 3, stride=2)
722
+ # self.layer4_3 = self._make_layer(block, 16, 3, stride=2)
723
+ # self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
724
+ # self.fc_1 = nn.Linear(16 * block.expansion, 2) # move
725
+ # self.fc_2 = nn.Linear(16 * block.expansion, 3) # pose
726
+ # self.fc_3 = nn.Linear(16 * block.expansion, 11) # attributes
727
+
728
+ for m in self.modules():
729
+ if isinstance(m, nn.Conv2d):
730
+ nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
731
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
732
+ nn.init.constant_(m.weight, 1)
733
+ nn.init.constant_(m.bias, 0)
734
+
735
+ # Zero-initialize the last BN in each residual branch,
736
+ # so that the residual branch starts with zeros, and each residual block behaves like an identity.
737
+ # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
738
+ if zero_init_residual:
739
+ for m in self.modules():
740
+ if isinstance(m, Bottleneck):
741
+ nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type]
742
+ elif isinstance(m, BasicBlock):
743
+ nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type]
744
+
745
+ def _make_layer(
746
+ self,
747
+ block: Type[Union[BasicBlock, Bottleneck]],
748
+ planes: int,
749
+ blocks: int,
750
+ stride: int = 1,
751
+ dilate: bool = False,
752
+ ) -> nn.Sequential:
753
+ norm_layer = self._norm_layer
754
+ downsample = None
755
+ previous_dilation = self.dilation
756
+ if dilate:
757
+ self.dilation *= stride
758
+ stride = 1
759
+ if stride != 1 or self.inplanes != planes * block.expansion:
760
+ downsample = nn.Sequential(
761
+ conv1x1(self.inplanes, planes * block.expansion, stride),
762
+ norm_layer(planes * block.expansion),
763
+ )
764
+
765
+ layers = []
766
+ layers.append(
767
+ block(
768
+ self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation, norm_layer
769
+ )
770
+ )
771
+ self.inplanes = planes * block.expansion
772
+ for _ in range(1, blocks):
773
+ layers.append(
774
+ block(
775
+ self.inplanes,
776
+ planes,
777
+ groups=self.groups,
778
+ base_width=self.base_width,
779
+ dilation=self.dilation,
780
+ norm_layer=norm_layer,
781
+ )
782
+ )
783
+
784
+ return nn.Sequential(*layers)
785
+
786
+ def _forward_impl(self, x: Tensor) -> Tensor:
787
+ # See note [TorchScript super()]
788
+ x = self.layer1(x)
789
+ x = self.layer2(x)
790
+ x = self.layer3(x)
791
+ x = self.layer4(x)
792
+
793
+ x = self.layer5(x)
794
+ x = self.avgpool(x)
795
+ x = torch.flatten(x, 1)
796
+ x = torch.sigmoid(self.fc(x))
797
+
798
+ return x
799
+
800
+ def _forward_feature(self, x: Tensor) -> Tensor:
801
+ # See note [TorchScript super()]
802
+ x = self.layer1(x)
803
+ x = self.layer2(x)
804
+ x = self.layer3(x)
805
+ x = self.layer4(x)
806
+
807
+ return x
808
+
809
+ def _forward_trans(self, x: Tensor) -> Tensor:
810
+ # See note [TorchScript super()]
811
+ x = self.layer5(x)
812
+ x = self.avgpool(x)
813
+ x = torch.flatten(x, 1)
814
+ x = torch.sigmoid(self.fc(x))
815
+
816
+ return x
817
+
818
+ def forward(self, x: Tensor, mode: int = 0) -> Tensor:
819
+ if mode == 0:
820
+ return self._forward_impl(x)
821
+ elif mode == 1:
822
+ return self._forward_feature(x)
823
+ elif mode == 2:
824
+ return self._forward_trans(x)
825
+
826
+ def _normresposenet(**kwargs: Any) -> NormResPoseNet:
827
+ r"""Wide ResNet-50-2 model from
828
+ `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_.
829
+ The model is the same as ResNet except for the bottleneck number of channels
830
+ which is twice larger in every block. The number of channels in outer 1x1
831
+ convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
832
+ channels, and in Wide ResNet-50-2 has 2048-1024-2048.
833
+ Args:
834
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
835
+ progress (bool): If True, displays a progress bar of the download to stderr
836
+ """
837
+ kwargs["width_per_group"] = 64 * 2
838
+ return NormResPoseNet(Bottleneck, **kwargs)
839
+
840
+ def _resposenet(**kwargs: Any) -> ResPoseNet:
841
+ r"""Wide ResNet-50-2 model from
842
+ `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_.
843
+ The model is the same as ResNet except for the bottleneck number of channels
844
+ which is twice larger in every block. The number of channels in outer 1x1
845
+ convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
846
+ channels, and in Wide ResNet-50-2 has 2048-1024-2048.
847
+ Args:
848
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
849
+ progress (bool): If True, displays a progress bar of the download to stderr
850
+ """
851
+ kwargs["width_per_group"] = 64 * 2
852
+ return ResPoseNet(Bottleneck, **kwargs)
853
+
854
+ def _respose4net(**kwargs: Any) -> ResPose4Net:
855
+ r"""Wide ResNet-50-2 model from
856
+ `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_.
857
+ The model is the same as ResNet except for the bottleneck number of channels
858
+ which is twice larger in every block. The number of channels in outer 1x1
859
+ convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
860
+ channels, and in Wide ResNet-50-2 has 2048-1024-2048.
861
+ Args:
862
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
863
+ progress (bool): If True, displays a progress bar of the download to stderr
864
+ """
865
+ kwargs["width_per_group"] = 64 * 2
866
+ return ResPose4Net(Bottleneck, **kwargs)
867
+
868
+ def _resposewnet(**kwargs: Any) -> ResPoseWNet:
869
+ r"""Wide ResNet-50-2 model from
870
+ `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_.
871
+ The model is the same as ResNet except for the bottleneck number of channels
872
+ which is twice larger in every block. The number of channels in outer 1x1
873
+ convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
874
+ channels, and in Wide ResNet-50-2 has 2048-1024-2048.
875
+ Args:
876
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
877
+ progress (bool): If True, displays a progress bar of the download to stderr
878
+ """
879
+ kwargs["width_per_group"] = 64 * 2
880
+ return ResPoseWNet(Bottleneck, layers = [3, 4, 6, 3], **kwargs)
881
+
882
+ class PoseMapBN(nn.Module):
883
+ def __init__(self, input_num, output_num):
884
+ super().__init__()
885
+
886
+ self.ln1 = nn.Linear(input_num, 256)
887
+ self.bn1 = nn.BatchNorm1d(256)
888
+ self.ac1 = nn.LeakyReLU()
889
+
890
+ self.ln2 = nn.Linear(256, 256)
891
+ self.bn2 = nn.BatchNorm1d(256)
892
+ self.ac2 = nn.LeakyReLU()
893
+
894
+ self.ln3 = nn.Linear(256, output_num)
895
+
896
+ def forward(self, x):
897
+ x = self.ac1(self.bn1(self.ln1(x)))
898
+ x = self.ac2(self.bn2(self.ln2(x)))
899
+ out = self.ln3(x)
900
+ return out
901
+
902
+ class PoseMap(nn.Module):
903
+ def __init__(self, input_num, output_num):
904
+ super().__init__()
905
+
906
+ self.ln1 = nn.Linear(input_num, 256)
907
+ self.ac1 = nn.LeakyReLU()
908
+
909
+ self.ln2 = nn.Linear(256, 256)
910
+ self.ac2 = nn.LeakyReLU()
911
+
912
+ self.ln3 = nn.Linear(256, output_num)
913
+
914
+ def forward(self, x):
915
+ x = self.ac1(self.ln1(x))
916
+ x = self.ac2(self.ln2(x))
917
+ out = self.ln3(x)
918
+ return out
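A usage sketch for the pose extractors and MLP mappers above; the 64x64 heatmap resolution, the batch size, and the 512-dimensional mapping target are assumptions, while the 98 input channels match the networks' fixed inplanes:

    import torch

    net = _resposenet(num_point=12)                 # Bottleneck blocks with expansion forced to 1
    heatmaps = torch.randn(2, 98, 64, 64)           # 98-channel landmark heatmaps (assumed size)

    pose = net(heatmaps, mode=0)                    # (2, 24): sigmoid-normalized coordinates
    feat = net(heatmaps, mode=1)                    # 1-channel spatial pose feature (layer3 output)
    pose_again = net(feat, mode=2)                  # the regression head applied to that feature

    mapper = PoseMap(input_num=24, output_num=512)  # plain MLP; 512 is an assumed target size
    code = mapper(pose)                             # (2, 512)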
models/lmcode_networks.py ADDED
@@ -0,0 +1,394 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import math
5
+
6
+ from collections import namedtuple
7
+ from munch import Munch
8
+ from copy import deepcopy
9
+ from functools import partial
10
+
11
+ IDXPAIR = namedtuple('IDXPAIR', 'start end')
12
+ index_map = Munch(chin=IDXPAIR(0 + 8, 33 - 8),
13
+ eyebrows=IDXPAIR(33, 51),
14
+ eyebrowsedges=IDXPAIR(33, 46),
15
+ nose=IDXPAIR(51, 55),
16
+ nostrils=IDXPAIR(55, 60),
17
+ eyes=IDXPAIR(60, 76),
18
+ lipedges=IDXPAIR(76, 82),
19
+ lipupper=IDXPAIR(77, 82),
20
+ liplower=IDXPAIR(83, 88),
21
+ lipinner=IDXPAIR(88, 96))
22
+ OPPAIR = namedtuple('OPPAIR', 'shift resize')
23
+
24
+ def conv3x3(in_planes, out_planes, strd=1, padding=1,
25
+ bias=False,dilation=1):
26
+ "3x3 convolution with padding"
27
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3,
28
+ stride=strd, padding=padding, bias=bias,
29
+ dilation=dilation)
30
+
31
+ class BasicBlock(nn.Module):
32
+ expansion = 1
33
+
34
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
35
+ super(BasicBlock, self).__init__()
36
+ self.conv1 = conv3x3(inplanes, planes, stride)
37
+ # self.bn1 = nn.BatchNorm2d(planes)
38
+ self.relu = nn.ReLU(inplace=True)
39
+ self.conv2 = conv3x3(planes, planes)
40
+ # self.bn2 = nn.BatchNorm2d(planes)
41
+ self.downsample = downsample
42
+ self.stride = stride
43
+
44
+ def forward(self, x):
45
+ residual = x
46
+
47
+ out = self.conv1(x)
48
+ # out = self.bn1(out)
49
+ out = self.relu(out)
50
+
51
+ out = self.conv2(out)
52
+ # out = self.bn2(out)
53
+
54
+ if self.downsample is not None:
55
+ residual = self.downsample(x)
56
+
57
+ out += residual
58
+ out = self.relu(out)
59
+
60
+ return out
61
+
62
+ class HourGlass(nn.Module):
63
+ def __init__(self, num_modules, depth, num_features, first_one=False):
64
+ super(HourGlass, self).__init__()
65
+ self.num_modules = num_modules
66
+ self.depth = depth
67
+ self.features = num_features
68
+ self.coordconv = CoordConvTh(64, 64, True, True, 256, first_one,
69
+ out_channels=256,
70
+ kernel_size=1, stride=1, padding=0)
71
+ self._generate_network(self.depth)
72
+
73
+ def _generate_network(self, level):
74
+ self.add_module('b1_' + str(level), ConvBlock(256, 256))
75
+ self.add_module('b2_' + str(level), ConvBlock(256, 256))
76
+ if level > 1:
77
+ self._generate_network(level - 1)
78
+ else:
79
+ self.add_module('b2_plus_' + str(level), ConvBlock(256, 256))
80
+ self.add_module('b3_' + str(level), ConvBlock(256, 256))
81
+
82
+ def _forward(self, level, inp):
83
+ up1 = inp
84
+ up1 = self._modules['b1_' + str(level)](up1)
85
+ low1 = F.avg_pool2d(inp, 2, stride=2)
86
+ low1 = self._modules['b2_' + str(level)](low1)
87
+
88
+ if level > 1:
89
+ low2 = self._forward(level - 1, low1)
90
+ else:
91
+ low2 = low1
92
+ low2 = self._modules['b2_plus_' + str(level)](low2)
93
+ low3 = low2
94
+ low3 = self._modules['b3_' + str(level)](low3)
95
+ up2 = F.interpolate(low3, scale_factor=2, mode='nearest')
96
+
97
+ return up1 + up2
98
+
99
+ def forward(self, x, heatmap):
100
+ x, last_channel = self.coordconv(x, heatmap)
101
+ return self._forward(self.depth, x), last_channel
102
+
103
+ class AddCoordsTh(nn.Module):
104
+ def __init__(self, height=64, width=64, with_r=False, with_boundary=False):
105
+ super(AddCoordsTh, self).__init__()
106
+ self.with_r = with_r
107
+ self.with_boundary = with_boundary
108
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
109
+
110
+ with torch.no_grad():
111
+ x_coords = torch.arange(height).unsqueeze(1).expand(height, width).float()
112
+ y_coords = torch.arange(width).unsqueeze(0).expand(height, width).float()
113
+ x_coords = (x_coords / (height - 1)) * 2 - 1
114
+ y_coords = (y_coords / (width - 1)) * 2 - 1
115
+ coords = torch.stack([x_coords, y_coords], dim=0) # (2, height, width)
116
+
117
+ if self.with_r:
118
+ rr = torch.sqrt(torch.pow(x_coords, 2) + torch.pow(y_coords, 2)) # (height, width)
119
+ rr = (rr / torch.max(rr)).unsqueeze(0)
120
+ coords = torch.cat([coords, rr], dim=0)
121
+
122
+ self.coords = coords.unsqueeze(0).to(device) # (1, 2 or 3, height, width)
123
+ self.x_coords = x_coords.to(device)
124
+ self.y_coords = y_coords.to(device)
125
+
126
+ def forward(self, x, heatmap=None):
127
+ """
128
+ x: (batch, c, x_dim, y_dim)
129
+ """
130
+ coords = self.coords.repeat(x.size(0), 1, 1, 1).to(x.device)
131
+
132
+ if self.with_boundary and heatmap is not None:
133
+ boundary_channel = torch.clamp(heatmap[:, -1:, :, :], 0.0, 1.0)
134
+ zero_tensor = torch.zeros_like(self.x_coords)
135
+ xx_boundary_channel = torch.where(boundary_channel > 0.05, self.x_coords, zero_tensor).to(x.device)
136
+ yy_boundary_channel = torch.where(boundary_channel > 0.05, self.y_coords, zero_tensor).to(x.device)
137
+ coords = torch.cat([coords, xx_boundary_channel, yy_boundary_channel], dim=1)
138
+
139
+ x_and_coords = torch.cat([x, coords], dim=1)
140
+ return x_and_coords
141
+
142
+
143
+ class CoordConvTh(nn.Module):
144
+ """CoordConv layer as in the paper."""
145
+ def __init__(self, height, width, with_r, with_boundary,
146
+ in_channels, first_one=False, *args, **kwargs):
147
+ super(CoordConvTh, self).__init__()
148
+ self.addcoords = AddCoordsTh(height, width, with_r, with_boundary)
149
+ in_channels += 2
150
+ if with_r:
151
+ in_channels += 1
152
+ if with_boundary and not first_one:
153
+ in_channels += 2
154
+ self.conv = nn.Conv2d(in_channels=in_channels, *args, **kwargs)
155
+
156
+ def forward(self, input_tensor, heatmap=None):
157
+ ret = self.addcoords(input_tensor, heatmap)
158
+ last_channel = ret[:, -2:, :, :]
159
+ ret = self.conv(ret)
160
+ return ret, last_channel
161
+
162
+
163
+ class ConvBlock(nn.Module):
164
+ def __init__(self, in_planes, out_planes):
165
+ super(ConvBlock, self).__init__()
166
+ self.bn1 = nn.BatchNorm2d(in_planes)
167
+ conv3x3 = partial(nn.Conv2d, kernel_size=3, stride=1, padding=1, bias=False, dilation=1)
168
+ self.conv1 = conv3x3(in_planes, int(out_planes / 2))
169
+ self.bn2 = nn.BatchNorm2d(int(out_planes / 2))
170
+ self.conv2 = conv3x3(int(out_planes / 2), int(out_planes / 4))
171
+ self.bn3 = nn.BatchNorm2d(int(out_planes / 4))
172
+ self.conv3 = conv3x3(int(out_planes / 4), int(out_planes / 4))
173
+
174
+ self.downsample = None
175
+ if in_planes != out_planes:
176
+ self.downsample = nn.Sequential(nn.BatchNorm2d(in_planes),
177
+ nn.ReLU(True),
178
+ nn.Conv2d(in_planes, out_planes, 1, 1, bias=False))
179
+
180
+ def forward(self, x):
181
+ residual = x
182
+
183
+ out1 = self.bn1(x)
184
+ out1 = F.relu(out1, True)
185
+ out1 = self.conv1(out1)
186
+
187
+ out2 = self.bn2(out1)
188
+ out2 = F.relu(out2, True)
189
+ out2 = self.conv2(out2)
190
+
191
+ out3 = self.bn3(out2)
192
+ out3 = F.relu(out3, True)
193
+ out3 = self.conv3(out3)
194
+
195
+ out3 = torch.cat((out1, out2, out3), 1)
196
+ if self.downsample is not None:
197
+ residual = self.downsample(residual)
198
+ out3 += residual
199
+ return out3
200
+
201
+
202
+ class FAN(nn.Module):
203
+ def __init__(self, num_modules=1, end_relu=False, num_landmarks=98, fname_pretrained=None):
204
+ super(FAN, self).__init__()
205
+ self.num_modules = num_modules
206
+ self.end_relu = end_relu
207
+
208
+ # Base part
209
+ self.conv1 = CoordConvTh(256, 256, True, False,
210
+ in_channels=3, out_channels=64,
211
+ kernel_size=7, stride=2, padding=3)
212
+ self.bn1 = nn.BatchNorm2d(64)
213
+ self.conv2 = ConvBlock(64, 128)
214
+ self.conv3 = ConvBlock(128, 128)
215
+ self.conv4 = ConvBlock(128, 256)
216
+
217
+ # Stacking part
218
+ self.add_module('m0', HourGlass(1, 4, 256, first_one=True))
219
+ self.add_module('top_m_0', ConvBlock(256, 256))
220
+ self.add_module('conv_last0', nn.Conv2d(256, 256, 1, 1, 0))
221
+ self.add_module('bn_end0', nn.BatchNorm2d(256))
222
+ self.add_module('l0', nn.Conv2d(256, num_landmarks+1, 1, 1, 0))
223
+
224
+ if fname_pretrained is not None:
225
+ self.load_pretrained_weights(fname_pretrained)
226
+
227
+ def load_pretrained_weights(self, fname):
228
+ if torch.cuda.is_available():
229
+ checkpoint = torch.load(fname)
230
+ else:
231
+ checkpoint = torch.load(fname, map_location=torch.device('cpu'))
232
+ model_weights = self.state_dict()
233
+ model_weights.update({k: v for k, v in checkpoint['state_dict'].items()
234
+ if k in model_weights})
235
+ self.load_state_dict(model_weights)
236
+
237
+ def forward(self, x):
238
+ x, _ = self.conv1(x)
239
+ x = F.relu(self.bn1(x), True)
240
+ x = F.avg_pool2d(self.conv2(x), 2, stride=2)
241
+ x = self.conv3(x)
242
+ x = self.conv4(x)
243
+
244
+ outputs = []
245
+ boundary_channels = []
246
+ tmp_out = None
247
+ ll, boundary_channel = self._modules['m0'](x, tmp_out)
248
+ ll = self._modules['top_m_0'](ll)
249
+ ll = F.relu(self._modules['bn_end0']
250
+ (self._modules['conv_last0'](ll)), True)
251
+
252
+ # Predict heatmaps
253
+ tmp_out = self._modules['l0'](ll)
254
+ if self.end_relu:
255
+ tmp_out = F.relu(tmp_out) # HACK: Added relu
256
+ outputs.append(tmp_out)
257
+ boundary_channels.append(boundary_channel)
258
+ return outputs, boundary_channels
259
+
260
+ @torch.no_grad()
261
+ def get_heatmap(self, x, b_preprocess=True):
262
+ ''' outputs 0-1 normalized heatmap '''
263
+ x = F.interpolate(x, size=256, mode='bilinear')
264
+ x_01 = x*0.5 + 0.5
265
+ outputs, _ = self(x_01)
266
+ heatmaps = outputs[-1][:, :-1, :, :]
267
+ scale_factor = x.size(2) // heatmaps.size(2)
268
+ if b_preprocess:
269
+ heatmaps = F.interpolate(heatmaps, scale_factor=scale_factor,
270
+ mode='bilinear', align_corners=True)
271
+ heatmaps = preprocess(heatmaps)
272
+ return heatmaps
273
+
274
+ @torch.no_grad()
275
+ def get_landmark(self, x):
276
+ ''' outputs landmarks of x.shape '''
277
+ heatmaps = self.get_heatmap(x, b_preprocess=False)
278
+ landmarks = []
279
+ for i in range(x.size(0)):
280
+ pred_landmarks = get_preds_fromhm(heatmaps[i].cpu().unsqueeze(0))
281
+ landmarks.append(pred_landmarks)
282
+ scale_factor = x.size(2) // heatmaps.size(2)
283
+ landmarks = torch.cat(landmarks) * scale_factor
284
+ return landmarks
285
+
286
+
287
+ def get_preds_fromhm(hm):
288
+ max, idx = torch.max(
289
+ hm.view(hm.size(0), hm.size(1), hm.size(2) * hm.size(3)), 2)
290
+ idx += 1
291
+ preds = idx.view(idx.size(0), idx.size(1), 1).repeat(1, 1, 2).float()
292
+ preds[..., 0].apply_(lambda x: (x - 1) % hm.size(3) + 1)
293
+ preds[..., 1].add_(-1).div_(hm.size(2)).floor_().add_(1)
294
+
295
+ for i in range(preds.size(0)):
296
+ for j in range(preds.size(1)):
297
+ hm_ = hm[i, j, :]
298
+ pX, pY = int(preds[i, j, 0]) - 1, int(preds[i, j, 1]) - 1
299
+ if pX > 0 and pX < 63 and pY > 0 and pY < 63:
300
+ diff = torch.FloatTensor(
301
+ [hm_[pY, pX + 1] - hm_[pY, pX - 1],
302
+ hm_[pY + 1, pX] - hm_[pY - 1, pX]])
303
+ preds[i, j].add_(diff.sign_().mul_(.25))
304
+
305
+ preds.add_(-0.5)
306
+ return preds
307
+
308
+ def truncate(x, thres=0.1):
309
+ """Remove small values in heatmaps."""
310
+ return torch.where(x < thres, torch.zeros_like(x), x)
311
+
312
+ def normalize(x, eps=1e-6):
313
+ """Apply min-max normalization."""
314
+ x = x.contiguous()
315
+ N, C, H, W = x.size()
316
+ x_ = x.view(N*C, -1)
317
+ max_val = torch.max(x_, dim=1, keepdim=True)[0]
318
+ min_val = torch.min(x_, dim=1, keepdim=True)[0]
319
+ x_ = (x_ - min_val) / (max_val - min_val + eps)
320
+ out = x_.view(N, C, H, W)
321
+ return out
322
+
323
+ def resize(x, p=2):
324
+ """Resize heatmaps."""
325
+ return x**p
326
+
327
+
328
+ def shift(x, N):
329
+ """Shift N pixels up or down."""
330
+ up = N >= 0
331
+ N = abs(N)
332
+ _, _, H, W = x.size()
333
+ head = torch.arange(N)
334
+ tail = torch.arange(H-N)
335
+
336
+ if up:
337
+ head = torch.arange(H-N)+N
338
+ tail = torch.arange(N)
339
+ else:
340
+ head = torch.arange(N) + (H-N)
341
+ tail = torch.arange(H-N)
342
+
343
+ # permutation indices
344
+ perm = torch.cat([head, tail]).to(x.device)
345
+ out = x[:, :, perm, :]
346
+ return out
347
+
348
+ def preprocess(x):
349
+ """Preprocess 98-dimensional heatmaps."""
350
+ N, C, H, W = x.size()
351
+ x = truncate(x)
352
+ x = normalize(x)
353
+
354
+ sw = H // 256
355
+ operations = Munch(chin=OPPAIR(0, 3),
356
+ eyebrows=OPPAIR(-7*sw, 2),
357
+ nostrils=OPPAIR(8*sw, 4),
358
+ lipupper=OPPAIR(-8*sw, 4),
359
+ liplower=OPPAIR(8*sw, 4),
360
+ lipinner=OPPAIR(-2*sw, 3))
361
+
362
+ for part, ops in operations.items():
363
+ start, end = index_map[part]
364
+ x[:, start:end] = resize(shift(x[:, start:end], ops.shift), ops.resize)
365
+
366
+ zero_out = torch.cat([torch.arange(0, index_map.chin.start),
367
+ torch.arange(index_map.chin.end, 33),
368
+ torch.LongTensor([index_map.eyebrowsedges.start,
369
+ index_map.eyebrowsedges.end,
370
+ index_map.lipedges.start,
371
+ index_map.lipedges.end])])
372
+ x[:, zero_out] = 0
373
+
374
+ start, end = index_map.nose
375
+ x[:, start+1:end] = shift(x[:, start+1:end], 4*sw)
376
+ x[:, start:end] = resize(x[:, start:end], 1)
377
+
378
+ start, end = index_map.eyes
379
+ x[:, start:end] = resize(x[:, start:end], 1)
380
+ x[:, start:end] = resize(shift(x[:, start:end], -8), 3) + \
381
+ shift(x[:, start:end], -24)
382
+
383
+ # Second-level mask
384
+ x2 = deepcopy(x)
385
+ x2[:, index_map.chin.start:index_map.chin.end] = 0 # start:end was 0:33
386
+ x2[:, index_map.lipedges.start:index_map.lipinner.end] = 0 # start:end was 76:96
387
+ x2[:, index_map.eyebrows.start:index_map.eyebrows.end] = 0 # start:end was 33:51
388
+
389
+ x = torch.sum(x, dim=1, keepdim=True) # (N, 1, H, W)
390
+ x2 = torch.sum(x2, dim=1, keepdim=True) # mask without faceline and mouth
391
+
392
+ x[x != x] = 0 # set nan to zero
393
+ x2[x != x] = 0 # set nan to zero
394
+ return x.clamp_(0, 1), x2.clamp_(0, 1)
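Taken together, truncate, normalize, resize, shift and preprocess collapse the 98 raw FAN channels into two soft masks: the first keeps every facial part, the second zeroes the face outline, eyebrow and lip channels before summing (the "mask without faceline and mouth" in the comment above). A rough usage sketch; it assumes the deepcopy import and the index_map / OPPAIR / Munch helpers referenced above are defined earlier in this file (they are outside this view), and it feeds random numbers purely as a stand-in for real heatmaps:

import torch

heatmaps = torch.rand(2, 98, 256, 256)       # stand-in for FAN output upsampled to image size
mask_full, mask_inner = preprocess(heatmaps)
print(mask_full.shape, mask_inner.shape)     # both (2, 1, 256, 256), clamped to [0, 1]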
models/resnet.py ADDED
@@ -0,0 +1,1452 @@
1
+ from typing import Type, Any, Callable, Union, List, Optional
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from torch import Tensor
6
+
7
+ try:
8
+ from torch.hub import load_state_dict_from_url # noqa: F401
9
+ except ImportError:
10
+ from torch.utils.model_zoo import load_url as load_state_dict_from_url # noqa: F401
11
+ from types import FunctionType
12
+
13
+ def _log_api_usage_once(obj: Any) -> None:
14
+ if not obj.__module__.startswith("torchvision"):
15
+ return
16
+ name = obj.__class__.__name__
17
+ if isinstance(obj, FunctionType):
18
+ name = obj.__name__
19
+ torch._C._log_api_usage_once(f"{obj.__module__}.{name}")
20
+
21
+ __all__ = [
22
+ "ResNet",
23
+ "resnet18",
24
+ "resnet34",
25
+ "resnet50",
26
+ "resnet101",
27
+ "resnet152",
28
+ "resnext50_32x4d",
29
+ "resnext101_32x8d",
30
+ "wide_resnet50_2",
31
+ "wide_resnet101_2",
32
+ ]
33
+
34
+
35
+ model_urls = {
36
+ "resnet18": "https://download.pytorch.org/models/resnet18-f37072fd.pth",
37
+ "resnet34": "https://download.pytorch.org/models/resnet34-b627a593.pth",
38
+ "resnet50": "https://download.pytorch.org/models/resnet50-0676ba61.pth",
39
+ "resnet101": "https://download.pytorch.org/models/resnet101-63fe2227.pth",
40
+ "resnet152": "https://download.pytorch.org/models/resnet152-394f9c45.pth",
41
+ "resnext50_32x4d": "https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth",
42
+ "resnext101_32x8d": "https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth",
43
+ "wide_resnet50_2": "https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth",
44
+ "wide_resnet101_2": "https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth",
45
+ }
46
+
47
+
48
+ def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
49
+ """3x3 convolution with padding"""
50
+ return nn.Conv2d(
51
+ in_planes,
52
+ out_planes,
53
+ kernel_size=3,
54
+ stride=stride,
55
+ padding=dilation,
56
+ groups=groups,
57
+ bias=False,
58
+ dilation=dilation,
59
+ )
60
+
61
+
62
+ def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
63
+ """1x1 convolution"""
64
+ return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
65
+
66
+
67
+ class BasicBlock(nn.Module):
68
+ expansion: int = 1
69
+
70
+ def __init__(
71
+ self,
72
+ inplanes: int,
73
+ planes: int,
74
+ stride: int = 1,
75
+ downsample: Optional[nn.Module] = None,
76
+ groups: int = 1,
77
+ base_width: int = 64,
78
+ dilation: int = 1,
79
+ norm_layer: Optional[Callable[..., nn.Module]] = None,
80
+ ) -> None:
81
+ super().__init__()
82
+ if norm_layer is None:
83
+ norm_layer = nn.BatchNorm2d
84
+ if groups != 1 or base_width != 64:
85
+ raise ValueError("BasicBlock only supports groups=1 and base_width=64")
86
+ if dilation > 1:
87
+ raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
88
+ # Both self.conv1 and self.downsample layers downsample the input when stride != 1
89
+ self.conv1 = conv3x3(inplanes, planes, stride)
90
+ self.bn1 = norm_layer(planes)
91
+ self.relu = nn.ReLU(inplace=True)
92
+ self.conv2 = conv3x3(planes, planes)
93
+ self.bn2 = norm_layer(planes)
94
+ self.downsample = downsample
95
+ self.stride = stride
96
+
97
+ def forward(self, x: Tensor) -> Tensor:
98
+ identity = x
99
+
100
+ out = self.conv1(x)
101
+ out = self.bn1(out)
102
+ out = self.relu(out)
103
+
104
+ out = self.conv2(out)
105
+ out = self.bn2(out)
106
+
107
+ if self.downsample is not None:
108
+ identity = self.downsample(x)
109
+
110
+ out += identity
111
+ out = self.relu(out)
112
+
113
+ return out
114
+
115
+
116
+ class Bottleneck(nn.Module):
117
+ # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
118
+ # while original implementation places the stride at the first 1x1 convolution(self.conv1)
119
+ # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
120
+ # This variant is also known as ResNet V1.5 and improves accuracy according to
121
+ # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
122
+
123
+ expansion: int = 4
124
+
125
+ def __init__(
126
+ self,
127
+ inplanes: int,
128
+ planes: int,
129
+ stride: int = 1,
130
+ downsample: Optional[nn.Module] = None,
131
+ groups: int = 1,
132
+ base_width: int = 64,
133
+ dilation: int = 1,
134
+ norm_layer: Optional[Callable[..., nn.Module]] = None,
135
+ ) -> None:
136
+ super().__init__()
137
+ if norm_layer is None:
138
+ norm_layer = nn.BatchNorm2d
139
+ width = int(planes * (base_width / 64.0)) * groups
140
+ # Both self.conv2 and self.downsample layers downsample the input when stride != 1
141
+ self.conv1 = conv1x1(inplanes, width)
142
+ self.bn1 = norm_layer(width)
143
+ self.conv2 = conv3x3(width, width, stride, groups, dilation)
144
+ self.bn2 = norm_layer(width)
145
+ self.conv3 = conv1x1(width, planes * self.expansion)
146
+ self.bn3 = norm_layer(planes * self.expansion)
147
+ self.relu = nn.ReLU(inplace=True)
148
+ self.downsample = downsample
149
+ self.stride = stride
150
+
151
+ def forward(self, x: Tensor) -> Tensor:
152
+ identity = x
153
+
154
+ out = self.conv1(x)
155
+ out = self.bn1(out)
156
+ out = self.relu(out)
157
+
158
+ out = self.conv2(out)
159
+ out = self.bn2(out)
160
+ out = self.relu(out)
161
+
162
+ out = self.conv3(out)
163
+ out = self.bn3(out)
164
+
165
+ if self.downsample is not None:
166
+ identity = self.downsample(x)
167
+
168
+ out += identity
169
+ out = self.relu(out)
170
+
171
+ return out
172
+
173
+
174
+ class ResNet(nn.Module):
175
+ def __init__(
176
+ self,
177
+ block: Type[Union[BasicBlock, Bottleneck]],
178
+ layers: List[int],
179
+ num_classes: int = 1000,
180
+ zero_init_residual: bool = False,
181
+ groups: int = 1,
182
+ width_per_group: int = 64,
183
+ replace_stride_with_dilation: Optional[List[bool]] = None,
184
+ norm_layer: Optional[Callable[..., nn.Module]] = None,
185
+ ) -> None:
186
+ super().__init__()
187
+ _log_api_usage_once(self)
188
+ if norm_layer is None:
189
+ norm_layer = nn.BatchNorm2d
190
+ self._norm_layer = norm_layer
191
+
192
+ self.inplanes = 64
193
+ self.dilation = 1
194
+ if replace_stride_with_dilation is None:
195
+ # each element in the tuple indicates if we should replace
196
+ # the 2x2 stride with a dilated convolution instead
197
+ replace_stride_with_dilation = [False, False, False]
198
+ if len(replace_stride_with_dilation) != 3:
199
+ raise ValueError(
200
+ "replace_stride_with_dilation should be None "
201
+ f"or a 3-element tuple, got {replace_stride_with_dilation}"
202
+ )
203
+ self.groups = groups
204
+ self.base_width = width_per_group
205
+ self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
206
+ self.bn1 = norm_layer(self.inplanes)
207
+ self.relu = nn.ReLU(inplace=True)
208
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
209
+ self.layer1 = self._make_layer(block, 64, layers[0])
210
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0])
211
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1])
212
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2])
213
+ self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
214
+ self.fc = nn.Linear(512 * block.expansion, num_classes)
215
+
216
+ for m in self.modules():
217
+ if isinstance(m, nn.Conv2d):
218
+ nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
219
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
220
+ nn.init.constant_(m.weight, 1)
221
+ nn.init.constant_(m.bias, 0)
222
+
223
+ # Zero-initialize the last BN in each residual branch,
224
+ # so that the residual branch starts with zeros, and each residual block behaves like an identity.
225
+ # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
226
+ if zero_init_residual:
227
+ for m in self.modules():
228
+ if isinstance(m, Bottleneck):
229
+ nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type]
230
+ elif isinstance(m, BasicBlock):
231
+ nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type]
232
+
233
+ def _make_layer(
234
+ self,
235
+ block: Type[Union[BasicBlock, Bottleneck]],
236
+ planes: int,
237
+ blocks: int,
238
+ stride: int = 1,
239
+ dilate: bool = False,
240
+ ) -> nn.Sequential:
241
+ norm_layer = self._norm_layer
242
+ downsample = None
243
+ previous_dilation = self.dilation
244
+ if dilate:
245
+ self.dilation *= stride
246
+ stride = 1
247
+ if stride != 1 or self.inplanes != planes * block.expansion:
248
+ downsample = nn.Sequential(
249
+ conv1x1(self.inplanes, planes * block.expansion, stride),
250
+ norm_layer(planes * block.expansion),
251
+ )
252
+
253
+ layers = []
254
+ layers.append(
255
+ block(
256
+ self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation, norm_layer
257
+ )
258
+ )
259
+ self.inplanes = planes * block.expansion
260
+ for _ in range(1, blocks):
261
+ layers.append(
262
+ block(
263
+ self.inplanes,
264
+ planes,
265
+ groups=self.groups,
266
+ base_width=self.base_width,
267
+ dilation=self.dilation,
268
+ norm_layer=norm_layer,
269
+ )
270
+ )
271
+
272
+ return nn.Sequential(*layers)
273
+
274
+ def _forward_impl(self, x: Tensor) -> Tensor:
275
+ # See note [TorchScript super()]
276
+ x = self.conv1(x)
277
+ x = self.bn1(x)
278
+ x = self.relu(x)
279
+ x = self.maxpool(x)
280
+
281
+ x = self.layer1(x)
282
+ x = self.layer2(x)
283
+ x = self.layer3(x)
284
+ x = self.layer4(x)
285
+
286
+ x = self.avgpool(x)
287
+ x = torch.flatten(x, 1)
288
+ x = self.fc(x)
289
+
290
+ return x
291
+
292
+ def forward(self, x: Tensor) -> Tensor:
293
+ return self._forward_impl(x)
294
+
295
+
296
+ class ResAppNet(nn.Module):
297
+ def __init__(
298
+ self,
299
+ block: Type[Union[BasicBlock, Bottleneck]],
300
+ layers: List[int],
301
+ num_classes: int = 1000,
302
+ zero_init_residual: bool = False,
303
+ groups: int = 1,
304
+ width_per_group: int = 64,
305
+ replace_stride_with_dilation: Optional[List[bool]] = None,
306
+ norm_layer: Optional[Callable[..., nn.Module]] = None,
307
+ ) -> None:
308
+ super().__init__()
309
+ _log_api_usage_once(self)
310
+ if norm_layer is None:
311
+ norm_layer = nn.BatchNorm2d
312
+ self._norm_layer = norm_layer
313
+
314
+ self.inplanes = 64
315
+ self.dilation = 1
316
+ if replace_stride_with_dilation is None:
317
+ # each element in the tuple indicates if we should replace
318
+ # the 2x2 stride with a dilated convolution instead
319
+ replace_stride_with_dilation = [False, False, False]
320
+ if len(replace_stride_with_dilation) != 3:
321
+ raise ValueError(
322
+ "replace_stride_with_dilation should be None "
323
+ f"or a 3-element tuple, got {replace_stride_with_dilation}"
324
+ )
325
+ self.groups = groups
326
+ self.base_width = width_per_group
327
+ self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
328
+ self.bn1 = norm_layer(self.inplanes)
329
+ self.relu = nn.ReLU(inplace=True)
330
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
331
+ self.layer1 = self._make_layer(block, 64, layers[0])
332
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0])
333
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=1, dilate=replace_stride_with_dilation[1])
334
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilate=replace_stride_with_dilation[2])
335
+ # self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
336
+ # self.fc = nn.Linear(512 * block.expansion, num_classes)
337
+
338
+ self.layer5 = self._make_layer(block, 1, layers[3], stride=1, dilate=replace_stride_with_dilation[2])
339
+ self.layer1_a = self._make_layer(block, 64, layers[3], stride=2, dilate=replace_stride_with_dilation[2])
340
+ self.layer2_a = self._make_layer(block, 128, layers[3], stride=2, dilate=replace_stride_with_dilation[2])
341
+ self.avgpool_a = nn.AdaptiveAvgPool2d((1, 1))
342
+ self.fc_a = nn.Linear(128 * block.expansion, num_classes)
343
+
344
+ for m in self.modules():
345
+ if isinstance(m, nn.Conv2d):
346
+ nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
347
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
348
+ nn.init.constant_(m.weight, 1)
349
+ nn.init.constant_(m.bias, 0)
350
+
351
+ # Zero-initialize the last BN in each residual branch,
352
+ # so that the residual branch starts with zeros, and each residual block behaves like an identity.
353
+ # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
354
+ if zero_init_residual:
355
+ for m in self.modules():
356
+ if isinstance(m, Bottleneck):
357
+ nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type]
358
+ elif isinstance(m, BasicBlock):
359
+ nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type]
360
+
361
+ def _make_layer(
362
+ self,
363
+ block: Type[Union[BasicBlock, Bottleneck]],
364
+ planes: int,
365
+ blocks: int,
366
+ stride: int = 1,
367
+ dilate: bool = False,
368
+ ) -> nn.Sequential:
369
+ norm_layer = self._norm_layer
370
+ downsample = None
371
+ previous_dilation = self.dilation
372
+ if dilate:
373
+ self.dilation *= stride
374
+ stride = 1
375
+ if stride != 1 or self.inplanes != planes * block.expansion:
376
+ downsample = nn.Sequential(
377
+ conv1x1(self.inplanes, planes * block.expansion, stride),
378
+ norm_layer(planes * block.expansion),
379
+ )
380
+
381
+ layers = []
382
+ layers.append(
383
+ block(
384
+ self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation, norm_layer
385
+ )
386
+ )
387
+ self.inplanes = planes * block.expansion
388
+ for _ in range(1, blocks):
389
+ layers.append(
390
+ block(
391
+ self.inplanes,
392
+ planes,
393
+ groups=self.groups,
394
+ base_width=self.base_width,
395
+ dilation=self.dilation,
396
+ norm_layer=norm_layer,
397
+ )
398
+ )
399
+
400
+ return nn.Sequential(*layers)
401
+
402
+ def _forward_impl(self, x: Tensor) -> Tensor:
403
+ # See note [TorchScript super()]
404
+ x = self.conv1(x)
405
+ x = self.bn1(x)
406
+ x = self.relu(x)
407
+ x = self.maxpool(x)
408
+
409
+ x = self.layer1(x)
410
+ x = self.layer2(x)
411
+ x = self.layer3(x)
412
+ x = self.layer4(x)
413
+ x = self.layer5(x)
414
+ # print(x.shape, flush = True)
415
+
416
+ x = self.layer1_a(x)
417
+ x = self.layer2_a(x)
418
+ x = self.avgpool_a(x)
419
+ x = torch.flatten(x, 1)
420
+ x = self.fc_a(x)
421
+
422
+ return x
423
+
424
+ def _forward_feature(self, x: Tensor) -> Tensor:
425
+ # See note [TorchScript super()]
426
+ x = self.conv1(x)
427
+ x = self.bn1(x)
428
+ x = self.relu(x)
429
+ x = self.maxpool(x)
430
+
431
+ x = self.layer1(x)
432
+ x = self.layer2(x)
433
+ x = self.layer3(x)
434
+ x = self.layer4(x)
435
+ x = self.layer5(x)
436
+
437
+ return x
438
+
439
+ def _forward_trans(self, x: Tensor) -> Tensor:
440
+
441
+ x = self.layer1_a(x)
442
+ x = self.layer2_a(x)
443
+ x = self.avgpool_a(x)
444
+ x = torch.flatten(x, 1)
445
+ x = self.fc_a(x)
446
+
447
+ return x
448
+
449
+ def forward(self, x: Tensor, mode: int = 0) -> Tensor:
450
+ if mode == 0:
451
+ return self._forward_impl(x)
452
+ elif mode == 1:
453
+ return self._forward_feature(x)
454
+ elif mode == 2:
455
+ return self._forward_trans(x)
456
+
457
+
458
+ class ResDisNet(nn.Module):
459
+ def __init__(
460
+ self,
461
+ block: Type[Union[BasicBlock, Bottleneck]],
462
+ layers: List[int],
463
+ num_classes: int = 1000,
464
+ zero_init_residual: bool = False,
465
+ groups: int = 1,
466
+ width_per_group: int = 64,
467
+ replace_stride_with_dilation: Optional[List[bool]] = None,
468
+ norm_layer: Optional[Callable[..., nn.Module]] = None,
469
+ ) -> None:
470
+ super().__init__()
471
+ _log_api_usage_once(self)
472
+
473
+ # from .attention_networks import Self_Attn
474
+ # self.attention_layer = Self_Attn(64, 'relu')
475
+
476
+ if norm_layer is None:
477
+ norm_layer = nn.BatchNorm2d
478
+ self._norm_layer = norm_layer
479
+
480
+ self.inplanes = 64
481
+ self.dilation = 1
482
+ if replace_stride_with_dilation is None:
483
+ # each element in the tuple indicates if we should replace
484
+ # the 2x2 stride with a dilated convolution instead
485
+ replace_stride_with_dilation = [False, False, False]
486
+ if len(replace_stride_with_dilation) != 3:
487
+ raise ValueError(
488
+ "replace_stride_with_dilation should be None "
489
+ f"or a 3-element tuple, got {replace_stride_with_dilation}"
490
+ )
491
+ self.groups = groups
492
+ self.base_width = width_per_group
493
+ self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
494
+ self.bn1 = norm_layer(self.inplanes)
495
+ self.relu = nn.ReLU(inplace=True)
496
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
497
+ self.layer1 = self._make_layer(block, 64, layers[0])
498
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0])
499
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1])
500
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2])
501
+ # self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
502
+ # self.fc = nn.Linear(512 * block.expansion, num_classes)
503
+
504
+ self.layer5 = self._make_layer(block, 16, layers[3], stride=1, dilate=replace_stride_with_dilation[2])
505
+
506
+ self.layer1_a = self._make_layer(block, 64, layers[3], stride=1, dilate=replace_stride_with_dilation[2])
507
+ self.layer2_a = self._make_layer(block, 128, layers[3], stride=1, dilate=replace_stride_with_dilation[2])
508
+ self.layer3_a = self._make_layer(block, 256, layers[3], stride=1, dilate=replace_stride_with_dilation[2])
509
+ self.layer4_a = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2])
510
+
511
+ self.inplanes = 1
512
+ self.layer5_b = self._make_layer(block, 16, layers[3], stride=1, dilate=replace_stride_with_dilation[2])
513
+
514
+ self.avgpool_a = nn.AdaptiveAvgPool2d((1, 1))
515
+ self.fc_a = nn.Linear(512 * block.expansion, num_classes)
516
+
517
+ for m in self.modules():
518
+ if isinstance(m, nn.Conv2d):
519
+ nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
520
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
521
+ nn.init.constant_(m.weight, 1)
522
+ nn.init.constant_(m.bias, 0)
523
+
524
+ # Zero-initialize the last BN in each residual branch,
525
+ # so that the residual branch starts with zeros, and each residual block behaves like an identity.
526
+ # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
527
+ if zero_init_residual:
528
+ for m in self.modules():
529
+ if isinstance(m, Bottleneck):
530
+ nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type]
531
+ elif isinstance(m, BasicBlock):
532
+ nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type]
533
+
534
+ def _make_layer(
535
+ self,
536
+ block: Type[Union[BasicBlock, Bottleneck]],
537
+ planes: int,
538
+ blocks: int,
539
+ stride: int = 1,
540
+ dilate: bool = False,
541
+ ) -> nn.Sequential:
542
+ norm_layer = self._norm_layer
543
+ downsample = None
544
+ previous_dilation = self.dilation
545
+ if dilate:
546
+ self.dilation *= stride
547
+ stride = 1
548
+ if stride != 1 or self.inplanes != planes * block.expansion:
549
+ downsample = nn.Sequential(
550
+ conv1x1(self.inplanes, planes * block.expansion, stride),
551
+ norm_layer(planes * block.expansion),
552
+ )
553
+
554
+ layers = []
555
+ layers.append(
556
+ block(
557
+ self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation, norm_layer
558
+ )
559
+ )
560
+ self.inplanes = planes * block.expansion
561
+ for _ in range(1, blocks):
562
+ layers.append(
563
+ block(
564
+ self.inplanes,
565
+ planes,
566
+ groups=self.groups,
567
+ base_width=self.base_width,
568
+ dilation=self.dilation,
569
+ norm_layer=norm_layer,
570
+ )
571
+ )
572
+
573
+ return nn.Sequential(*layers)
574
+
575
+ def _forward_impl(self, x: Tensor, y: Tensor) -> Tensor:
576
+ # See note [TorchScript super()]
577
+ x = self.conv1(x)
578
+ x = self.bn1(x)
579
+ x = self.relu(x)
580
+ x = self.maxpool(x)
581
+
582
+ x = self.layer1(x)
583
+ x = self.layer2(x)
584
+ x = self.layer3(x)
585
+ x = self.layer4(x)
586
+ x = self.layer5(x)
587
+ # print(x.shape, flush = True)
588
+
589
+ y = self.layer5_b(y)
590
+
591
+ x = self.layer1_a(x*y)
592
+ x = self.layer2_a(x)
593
+ # x = self.layer2_a(self.attention_layer(x))
594
+ x = self.layer3_a(x)
595
+ x = self.layer4_a(x)
596
+ x = self.avgpool_a(x)
597
+ x = torch.flatten(x, 1)
598
+ x = self.fc_a(x)
599
+
600
+ return x
601
+
602
+ def _forward_feature(self, x: Tensor) -> Tensor:
603
+ # See note [TorchScript super()]
604
+ x = self.conv1(x)
605
+ x = self.bn1(x)
606
+ x = self.relu(x)
607
+ x = self.maxpool(x)
608
+
609
+ x = self.layer1(x)
610
+ x = self.layer2(x)
611
+ x = self.layer3(x)
612
+ x = self.layer4(x)
613
+ x = self.layer5(x)
614
+
615
+ return x
616
+
617
+ def _forward_trans(self, x: Tensor, y: Tensor) -> Tensor:
618
+
619
+ y = self.layer5_b(y)
620
+
621
+ x = self.layer1_a(x*y)
622
+ x = self.layer2_a(x)
623
+ # x = self.layer2_a(self.attention_layer(x))
624
+ x = self.layer3_a(x)
625
+ x = self.layer4_a(x)
626
+ x = self.avgpool_a(x)
627
+ x = torch.flatten(x, 1)
628
+ x = self.fc_a(x)
629
+
630
+ return x
631
+
632
+ def forward(self, x: Tensor, y: Optional[Tensor] = None, mode: int = 0) -> Tensor:
633
+ if mode == 0:
634
+ return self._forward_impl(x, y)
635
+ elif mode == 1:
636
+ return self._forward_feature(x)
637
+ elif mode == 2:
638
+ return self._forward_trans(x, y)
639
+
640
+ class ResMDisNet(nn.Module):
641
+ def __init__(
642
+ self,
643
+ block: Type[Union[BasicBlock, Bottleneck]],
644
+ layers: List[int],
645
+ num_classes: int = 1000,
646
+ zero_init_residual: bool = False,
647
+ groups: int = 1,
648
+ width_per_group: int = 64,
649
+ replace_stride_with_dilation: Optional[List[bool]] = None,
650
+ norm_layer: Optional[Callable[..., nn.Module]] = None,
651
+ ) -> None:
652
+ super().__init__()
653
+ _log_api_usage_once(self)
654
+
655
+ # from .attention_networks import Self_Attn
656
+ # self.attention_layer = Self_Attn(64, 'relu')
657
+
658
+ if norm_layer is None:
659
+ norm_layer = nn.BatchNorm2d
660
+ self._norm_layer = norm_layer
661
+
662
+ self.inplanes = 64
663
+ self.dilation = 1
664
+ if replace_stride_with_dilation is None:
665
+ # each element in the tuple indicates if we should replace
666
+ # the 2x2 stride with a dilated convolution instead
667
+ replace_stride_with_dilation = [False, False, False]
668
+ if len(replace_stride_with_dilation) != 3:
669
+ raise ValueError(
670
+ "replace_stride_with_dilation should be None "
671
+ f"or a 3-element tuple, got {replace_stride_with_dilation}"
672
+ )
673
+ self.groups = groups
674
+ self.base_width = width_per_group
675
+ self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
676
+ self.bn1 = norm_layer(self.inplanes)
677
+ self.relu = nn.ReLU(inplace=True)
678
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
679
+ self.layer1 = self._make_layer(block, 64, layers[0])
680
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0])
681
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1])
682
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2])
683
+ # self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
684
+ # self.fc = nn.Linear(512 * block.expansion, num_classes)
685
+
686
+ self.layer5 = self._make_layer(block, 16, layers[3], stride=1, dilate=replace_stride_with_dilation[2])
687
+
688
+ self.layer1_a = self._make_layer(block, 64, layers[3], stride=1, dilate=replace_stride_with_dilation[2])
689
+ self.layer2_a = self._make_layer(block, 128, layers[3], stride=1, dilate=replace_stride_with_dilation[2])
690
+ self.layer3_a = self._make_layer(block, 256, layers[3], stride=1, dilate=replace_stride_with_dilation[2])
691
+ self.layer4_a = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2])
692
+
693
+ self.inplanes = 2
694
+ self.layer5_b = self._make_layer(block, 16, layers[3], stride=1, dilate=replace_stride_with_dilation[2])
695
+
696
+ self.avgpool_a = nn.AdaptiveAvgPool2d((1, 1))
697
+ self.fc_a = nn.Linear(512 * block.expansion, num_classes)
698
+
699
+ for m in self.modules():
700
+ if isinstance(m, nn.Conv2d):
701
+ nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
702
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
703
+ nn.init.constant_(m.weight, 1)
704
+ nn.init.constant_(m.bias, 0)
705
+
706
+ # Zero-initialize the last BN in each residual branch,
707
+ # so that the residual branch starts with zeros, and each residual block behaves like an identity.
708
+ # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
709
+ if zero_init_residual:
710
+ for m in self.modules():
711
+ if isinstance(m, Bottleneck):
712
+ nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type]
713
+ elif isinstance(m, BasicBlock):
714
+ nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type]
715
+
716
+ def _make_layer(
717
+ self,
718
+ block: Type[Union[BasicBlock, Bottleneck]],
719
+ planes: int,
720
+ blocks: int,
721
+ stride: int = 1,
722
+ dilate: bool = False,
723
+ ) -> nn.Sequential:
724
+ norm_layer = self._norm_layer
725
+ downsample = None
726
+ previous_dilation = self.dilation
727
+ if dilate:
728
+ self.dilation *= stride
729
+ stride = 1
730
+ if stride != 1 or self.inplanes != planes * block.expansion:
731
+ downsample = nn.Sequential(
732
+ conv1x1(self.inplanes, planes * block.expansion, stride),
733
+ norm_layer(planes * block.expansion),
734
+ )
735
+
736
+ layers = []
737
+ layers.append(
738
+ block(
739
+ self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation, norm_layer
740
+ )
741
+ )
742
+ self.inplanes = planes * block.expansion
743
+ for _ in range(1, blocks):
744
+ layers.append(
745
+ block(
746
+ self.inplanes,
747
+ planes,
748
+ groups=self.groups,
749
+ base_width=self.base_width,
750
+ dilation=self.dilation,
751
+ norm_layer=norm_layer,
752
+ )
753
+ )
754
+
755
+ return nn.Sequential(*layers)
756
+
757
+ def _forward_impl(self, x: Tensor, y: Tensor) -> Tensor:
758
+ # See note [TorchScript super()]
759
+ x = self.conv1(x)
760
+ x = self.bn1(x)
761
+ x = self.relu(x)
762
+ x = self.maxpool(x)
763
+
764
+ x = self.layer1(x)
765
+ x = self.layer2(x)
766
+ x = self.layer3(x)
767
+ x = self.layer4(x)
768
+ x = self.layer5(x)
769
+ # print(x.shape, flush = True)
770
+
771
+ y = self.layer5_b(y)
772
+
773
+ x = self.layer1_a(x*y)
774
+ x = self.layer2_a(x)
775
+ # x = self.layer2_a(self.attention_layer(x))
776
+ x = self.layer3_a(x)
777
+ x = self.layer4_a(x)
778
+ x = self.avgpool_a(x)
779
+ x = torch.flatten(x, 1)
780
+ x = self.fc_a(x)
781
+
782
+ return x
783
+
784
+ def _forward_feature(self, x: Tensor) -> Tensor:
785
+ # See note [TorchScript super()]
786
+ x = self.conv1(x)
787
+ x = self.bn1(x)
788
+ x = self.relu(x)
789
+ x = self.maxpool(x)
790
+
791
+ x = self.layer1(x)
792
+ x = self.layer2(x)
793
+ x = self.layer3(x)
794
+ x = self.layer4(x)
795
+ x = self.layer5(x)
796
+
797
+ return x
798
+
799
+ def _forward_trans(self, x: Tensor, y: Tensor) -> Tensor:
800
+
801
+ y = self.layer5_b(y)
802
+
803
+ x = self.layer1_a(x*y)
804
+ x = self.layer2_a(x)
805
+ # x = self.layer2_a(self.attention_layer(x))
806
+ x = self.layer3_a(x)
807
+ x = self.layer4_a(x)
808
+ x = self.avgpool_a(x)
809
+ x = torch.flatten(x, 1)
810
+ x = self.fc_a(x)
811
+
812
+ return x
813
+
814
+ def forward(self, x: Tensor, y: Optional[Tensor] = None, mode: int = 0) -> Tensor:
815
+ if mode == 0:
816
+ return self._forward_impl(x, y)
817
+ elif mode == 1:
818
+ return self._forward_feature(x)
819
+ elif mode == 2:
820
+ return self._forward_trans(x, y)
821
+
822
+ class ResDis2Net(nn.Module):
823
+ def __init__(
824
+ self,
825
+ block: Type[Union[BasicBlock, Bottleneck]],
826
+ layers: List[int],
827
+ num_classes: int = 1000,
828
+ zero_init_residual: bool = False,
829
+ groups: int = 1,
830
+ width_per_group: int = 64,
831
+ replace_stride_with_dilation: Optional[List[bool]] = None,
832
+ norm_layer: Optional[Callable[..., nn.Module]] = None,
833
+ ) -> None:
834
+ super().__init__()
835
+ _log_api_usage_once(self)
836
+
837
+ # from .attention_networks import Self_Attn
838
+ # self.attention_layer = Self_Attn(64, 'relu')
839
+
840
+ if norm_layer is None:
841
+ norm_layer = nn.BatchNorm2d
842
+ self._norm_layer = norm_layer
843
+
844
+ self.inplanes = 64
845
+ self.dilation = 1
846
+ if replace_stride_with_dilation is None:
847
+ # each element in the tuple indicates if we should replace
848
+ # the 2x2 stride with a dilated convolution instead
849
+ replace_stride_with_dilation = [False, False, False]
850
+ if len(replace_stride_with_dilation) != 3:
851
+ raise ValueError(
852
+ "replace_stride_with_dilation should be None "
853
+ f"or a 3-element tuple, got {replace_stride_with_dilation}"
854
+ )
855
+ self.groups = groups
856
+ self.base_width = width_per_group
857
+ self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
858
+ self.bn1 = norm_layer(self.inplanes)
859
+ self.relu = nn.ReLU(inplace=True)
860
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
861
+ self.layer1 = self._make_layer(block, 64, layers[0])
862
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0])
863
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1])
864
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2])
865
+ # self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
866
+ # self.fc = nn.Linear(512 * block.expansion, num_classes)
867
+
868
+ self.layer5 = self._make_layer(block, 16, layers[3], stride=1, dilate=replace_stride_with_dilation[2])
869
+
870
+ self.inplanes *= 2
871
+ self.layer1_a = self._make_layer(block, 64, layers[3], stride=1, dilate=replace_stride_with_dilation[2])
872
+ self.layer2_a = self._make_layer(block, 128, layers[3], stride=1, dilate=replace_stride_with_dilation[2])
873
+ self.layer3_a = self._make_layer(block, 256, layers[3], stride=1, dilate=replace_stride_with_dilation[2])
874
+ self.layer4_a = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2])
875
+
876
+ self.inplanes = 1
877
+ self.layer5_b = self._make_layer(block, 16, layers[3], stride=1, dilate=replace_stride_with_dilation[2])
878
+
879
+ self.avgpool_a = nn.AdaptiveAvgPool2d((1, 1))
880
+ self.fc_a = nn.Linear(512 * block.expansion, num_classes)
881
+
882
+ for m in self.modules():
883
+ if isinstance(m, nn.Conv2d):
884
+ nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
885
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
886
+ nn.init.constant_(m.weight, 1)
887
+ nn.init.constant_(m.bias, 0)
888
+
889
+ # Zero-initialize the last BN in each residual branch,
890
+ # so that the residual branch starts with zeros, and each residual block behaves like an identity.
891
+ # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
892
+ if zero_init_residual:
893
+ for m in self.modules():
894
+ if isinstance(m, Bottleneck):
895
+ nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type]
896
+ elif isinstance(m, BasicBlock):
897
+ nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type]
898
+
899
+ def _make_layer(
900
+ self,
901
+ block: Type[Union[BasicBlock, Bottleneck]],
902
+ planes: int,
903
+ blocks: int,
904
+ stride: int = 1,
905
+ dilate: bool = False,
906
+ ) -> nn.Sequential:
907
+ norm_layer = self._norm_layer
908
+ downsample = None
909
+ previous_dilation = self.dilation
910
+ if dilate:
911
+ self.dilation *= stride
912
+ stride = 1
913
+ if stride != 1 or self.inplanes != planes * block.expansion:
914
+ downsample = nn.Sequential(
915
+ conv1x1(self.inplanes, planes * block.expansion, stride),
916
+ norm_layer(planes * block.expansion),
917
+ )
918
+
919
+ layers = []
920
+ layers.append(
921
+ block(
922
+ self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation, norm_layer
923
+ )
924
+ )
925
+ self.inplanes = planes * block.expansion
926
+ for _ in range(1, blocks):
927
+ layers.append(
928
+ block(
929
+ self.inplanes,
930
+ planes,
931
+ groups=self.groups,
932
+ base_width=self.base_width,
933
+ dilation=self.dilation,
934
+ norm_layer=norm_layer,
935
+ )
936
+ )
937
+
938
+ return nn.Sequential(*layers)
939
+
940
+ def _forward_impl(self, x: Tensor, y: Tensor) -> Tensor:
941
+ # See note [TorchScript super()]
942
+ x = self.conv1(x)
943
+ x = self.bn1(x)
944
+ x = self.relu(x)
945
+ x = self.maxpool(x)
946
+
947
+ x = self.layer1(x)
948
+ x = self.layer2(x)
949
+ x = self.layer3(x)
950
+ x = self.layer4(x)
951
+ x = self.layer5(x)
952
+ # print(x.shape, flush = True)
953
+
954
+ y = self.layer5_b(y)
955
+
956
+ x = self.layer1_a(torch.cat([x,y], 1))
957
+ x = self.layer2_a(x)
958
+ # x = self.layer2_a(self.attention_layer(x))
959
+ x = self.layer3_a(x)
960
+ x = self.layer4_a(x)
961
+ x = self.avgpool_a(x)
962
+ x = torch.flatten(x, 1)
963
+ x = self.fc_a(x)
964
+
965
+ return x
966
+
967
+ def _forward_feature(self, x: Tensor) -> Tensor:
968
+ # See note [TorchScript super()]
969
+ x = self.conv1(x)
970
+ x = self.bn1(x)
971
+ x = self.relu(x)
972
+ x = self.maxpool(x)
973
+
974
+ x = self.layer1(x)
975
+ x = self.layer2(x)
976
+ x = self.layer3(x)
977
+ x = self.layer4(x)
978
+ x = self.layer5(x)
979
+
980
+ return x
981
+
982
+ def _forward_trans(self, x: Tensor, y: Tensor) -> Tensor:
983
+
984
+ y = self.layer5_b(y)
985
+
986
+ x = self.layer1_a(torch.cat([x,y], 1))
987
+ x = self.layer2_a(x)
988
+ # x = self.layer2_a(self.attention_layer(x))
989
+ x = self.layer3_a(x)
990
+ x = self.layer4_a(x)
991
+ x = self.avgpool_a(x)
992
+ x = torch.flatten(x, 1)
993
+ x = self.fc_a(x)
994
+
995
+ return x
996
+
997
+ def forward(self, x: Tensor, y: Optional[Tensor] = None, mode: int = 0) -> Tensor:
998
+ if mode == 0:
999
+ return self._forward_impl(x, y)
1000
+ elif mode == 1:
1001
+ return self._forward_feature(x)
1002
+ elif mode == 2:
1003
+ return self._forward_trans(x, y)
1004
+
1005
+ class ResDisAttNet(nn.Module):
1006
+ def __init__(
1007
+ self,
1008
+ block: Type[Union[BasicBlock, Bottleneck]],
1009
+ layers: List[int],
1010
+ num_classes: int = 1000,
1011
+ zero_init_residual: bool = False,
1012
+ groups: int = 1,
1013
+ width_per_group: int = 64,
1014
+ replace_stride_with_dilation: Optional[List[bool]] = None,
1015
+ norm_layer: Optional[Callable[..., nn.Module]] = None,
1016
+ ) -> None:
1017
+ super().__init__()
1018
+ _log_api_usage_once(self)
1019
+
1020
+ from .attention_networks import Pose_Attn
1021
+ self.attention_layer = Pose_Attn(64, 'relu')
1022
+
1023
+ if norm_layer is None:
1024
+ norm_layer = nn.BatchNorm2d
1025
+ self._norm_layer = norm_layer
1026
+
1027
+ self.inplanes = 64
1028
+ self.dilation = 1
1029
+ if replace_stride_with_dilation is None:
1030
+ # each element in the tuple indicates if we should replace
1031
+ # the 2x2 stride with a dilated convolution instead
1032
+ replace_stride_with_dilation = [False, False, False]
1033
+ if len(replace_stride_with_dilation) != 3:
1034
+ raise ValueError(
1035
+ "replace_stride_with_dilation should be None "
1036
+ f"or a 3-element tuple, got {replace_stride_with_dilation}"
1037
+ )
1038
+ self.groups = groups
1039
+ self.base_width = width_per_group
1040
+ self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
1041
+ self.bn1 = norm_layer(self.inplanes)
1042
+ self.relu = nn.ReLU(inplace=True)
1043
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
1044
+ self.layer1 = self._make_layer(block, 64, layers[0])
1045
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0])
1046
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1])
1047
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2])
1048
+ # self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
1049
+ # self.fc = nn.Linear(512 * block.expansion, num_classes)
1050
+
1051
+ self.layer5 = self._make_layer(block, 16, layers[3], stride=1, dilate=replace_stride_with_dilation[2])
1052
+
1053
+ self.layer1_a = self._make_layer(block, 64, layers[3], stride=1, dilate=replace_stride_with_dilation[2])
1054
+ self.layer2_a = self._make_layer(block, 128, layers[3], stride=1, dilate=replace_stride_with_dilation[2])
1055
+ self.layer3_a = self._make_layer(block, 256, layers[3], stride=1, dilate=replace_stride_with_dilation[2])
1056
+ self.layer4_a = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2])
1057
+
1058
+ self.inplanes = 1
1059
+ self.layer5_b = self._make_layer(block, 16, layers[3], stride=1, dilate=replace_stride_with_dilation[2])
1060
+
1061
+ self.avgpool_a = nn.AdaptiveAvgPool2d((1, 1))
1062
+ self.fc_a = nn.Linear(512 * block.expansion, num_classes)
1063
+
1064
+ for m in self.modules():
1065
+ if isinstance(m, nn.Conv2d):
1066
+ nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
1067
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
1068
+ nn.init.constant_(m.weight, 1)
1069
+ nn.init.constant_(m.bias, 0)
1070
+
1071
+ # Zero-initialize the last BN in each residual branch,
1072
+ # so that the residual branch starts with zeros, and each residual block behaves like an identity.
1073
+ # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
1074
+ if zero_init_residual:
1075
+ for m in self.modules():
1076
+ if isinstance(m, Bottleneck):
1077
+ nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type]
1078
+ elif isinstance(m, BasicBlock):
1079
+ nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type]
1080
+
1081
+ def _make_layer(
1082
+ self,
1083
+ block: Type[Union[BasicBlock, Bottleneck]],
1084
+ planes: int,
1085
+ blocks: int,
1086
+ stride: int = 1,
1087
+ dilate: bool = False,
1088
+ ) -> nn.Sequential:
1089
+ norm_layer = self._norm_layer
1090
+ downsample = None
1091
+ previous_dilation = self.dilation
1092
+ if dilate:
1093
+ self.dilation *= stride
1094
+ stride = 1
1095
+ if stride != 1 or self.inplanes != planes * block.expansion:
1096
+ downsample = nn.Sequential(
1097
+ conv1x1(self.inplanes, planes * block.expansion, stride),
1098
+ norm_layer(planes * block.expansion),
1099
+ )
1100
+
1101
+ layers = []
1102
+ layers.append(
1103
+ block(
1104
+ self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation, norm_layer
1105
+ )
1106
+ )
1107
+ self.inplanes = planes * block.expansion
1108
+ for _ in range(1, blocks):
1109
+ layers.append(
1110
+ block(
1111
+ self.inplanes,
1112
+ planes,
1113
+ groups=self.groups,
1114
+ base_width=self.base_width,
1115
+ dilation=self.dilation,
1116
+ norm_layer=norm_layer,
1117
+ )
1118
+ )
1119
+
1120
+ return nn.Sequential(*layers)
1121
+
1122
+ def _forward_impl(self, x: Tensor, y: Tensor) -> Tensor:
1123
+ # See note [TorchScript super()]
1124
+ x = self.conv1(x)
1125
+ x = self.bn1(x)
1126
+ x = self.relu(x)
1127
+ x = self.maxpool(x)
1128
+
1129
+ x = self.layer1(x)
1130
+ x = self.layer2(x)
1131
+ x = self.layer3(x)
1132
+ x = self.layer4(x)
1133
+ x = self.layer5(x)
1134
+ # print(x.shape, flush = True)
1135
+
1136
+ y = self.layer5_b(y)
1137
+
1138
+ x = self.layer1_a(self.attention_layer(x, y)[0])
1139
+ x = self.layer2_a(x)
1140
+ # x = self.layer2_a(self.attention_layer(x))
1141
+ x = self.layer3_a(x)
1142
+ x = self.layer4_a(x)
1143
+ x = self.avgpool_a(x)
1144
+ x = torch.flatten(x, 1)
1145
+ x = self.fc_a(x)
1146
+
1147
+ return x
1148
+
1149
+ def _forward_feature(self, x: Tensor) -> Tensor:
1150
+ # See note [TorchScript super()]
1151
+ x = self.conv1(x)
1152
+ x = self.bn1(x)
1153
+ x = self.relu(x)
1154
+ x = self.maxpool(x)
1155
+
1156
+ x = self.layer1(x)
1157
+ x = self.layer2(x)
1158
+ x = self.layer3(x)
1159
+ x = self.layer4(x)
1160
+ x = self.layer5(x)
1161
+
1162
+ return x
1163
+
1164
+ def _forward_trans(self, x: Tensor, y: Tensor) -> Tensor:
1165
+
1166
+ y = self.layer5_b(y)
1167
+
1168
+ x = self.layer1_a(self.attention_layer(x, y)[0])
1169
+ x = self.layer2_a(x)
1170
+ # x = self.layer2_a(self.attention_layer(x))
1171
+ x = self.layer3_a(x)
1172
+ x = self.layer4_a(x)
1173
+ x = self.avgpool_a(x)
1174
+ x = torch.flatten(x, 1)
1175
+ x = self.fc_a(x)
1176
+
1177
+ return x
1178
+
1179
+ def forward(self, x: Tensor, y: Optional[Tensor] = None, mode: int = 0) -> Tensor:
1180
+ if mode == 0:
1181
+ return self._forward_impl(x, y)
1182
+ elif mode == 1:
1183
+ return self._forward_feature(x)
1184
+ elif mode == 2:
1185
+ return self._forward_trans(x, y)
1186
+
1187
+ def _resnet(
1188
+ arch: str,
1189
+ block: Type[Union[BasicBlock, Bottleneck]],
1190
+ layers: List[int],
1191
+ pretrained: bool,
1192
+ progress: bool,
1193
+ **kwargs: Any,
1194
+ ) -> ResNet:
1195
+ model = ResNet(block, layers, **kwargs)
1196
+ if pretrained:
1197
+ state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
1198
+ model.load_state_dict(state_dict)
1199
+ return model
1200
+
1201
+ def _resappnet(
1202
+ arch: str,
1203
+ block: Type[Union[BasicBlock, Bottleneck]],
1204
+ layers: List[int],
1205
+ pretrained: bool,
1206
+ progress: bool,
1207
+ **kwargs: Any,
1208
+ ) -> ResAppNet:
1209
+ model = ResAppNet(block, layers, **kwargs)
1210
+ if pretrained:
1211
+ state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
1212
+ model.load_state_dict(state_dict)
1213
+ return model
1214
+
1215
+ def _resdisnet(
1216
+ arch: str,
1217
+ block: Type[Union[BasicBlock, Bottleneck]],
1218
+ layers: List[int],
1219
+ pretrained: bool,
1220
+ progress: bool,
1221
+ **kwargs: Any,
1222
+ ) -> ResDisNet:
1223
+ model = ResDisNet(block, layers, **kwargs)
1224
+ if pretrained:
1225
+ state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
1226
+ model.load_state_dict(state_dict)
1227
+ return model
1228
+
1229
+ def _resmdisnet(
1230
+ arch: str,
1231
+ block: Type[Union[BasicBlock, Bottleneck]],
1232
+ layers: List[int],
1233
+ pretrained: bool,
1234
+ progress: bool,
1235
+ **kwargs: Any,
1236
+ ) -> ResMDisNet:
1237
+ model = ResMDisNet(block, layers, **kwargs)
1238
+ if pretrained:
1239
+ state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
1240
+ model.load_state_dict(state_dict)
1241
+ return model
1242
+
1243
+ def _resdis2net(
1244
+ arch: str,
1245
+ block: Type[Union[BasicBlock, Bottleneck]],
1246
+ layers: List[int],
1247
+ pretrained: bool,
1248
+ progress: bool,
1249
+ **kwargs: Any,
1250
+ ) -> ResDis2Net:
1251
+ model = ResDis2Net(block, layers, **kwargs)
1252
+ if pretrained:
1253
+ state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
1254
+ model.load_state_dict(state_dict)
1255
+ return model
1256
+
1257
+ def _resdisattnet(
1258
+ arch: str,
1259
+ block: Type[Union[BasicBlock, Bottleneck]],
1260
+ layers: List[int],
1261
+ pretrained: bool,
1262
+ progress: bool,
1263
+ **kwargs: Any,
1264
+ ) -> ResDisAttNet:
1265
+ model = ResDisAttNet(block, layers, **kwargs)
1266
+ if pretrained:
1267
+ state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
1268
+ model.load_state_dict(state_dict)
1269
+ return model
1270
+
1271
+ def resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
1272
+ r"""ResNet-18 model from
1273
+ `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
1274
+ Args:
1275
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
1276
+ progress (bool): If True, displays a progress bar of the download to stderr
1277
+ """
1278
+ return _resnet("resnet18", BasicBlock, [2, 2, 2, 2], pretrained, progress, **kwargs)
1279
+
1280
+
1281
+ def resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
1282
+ r"""ResNet-34 model from
1283
+ `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
1284
+ Args:
1285
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
1286
+ progress (bool): If True, displays a progress bar of the download to stderr
1287
+ """
1288
+ return _resnet("resnet34", BasicBlock, [3, 4, 6, 3], pretrained, progress, **kwargs)
1289
+
1290
+
1291
+ def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
1292
+ r"""ResNet-50 model from
1293
+ `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
1294
+ Args:
1295
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
1296
+ progress (bool): If True, displays a progress bar of the download to stderr
1297
+ """
1298
+ return _resnet("resnet50", Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs)
1299
+
1300
+
1301
+ def resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
1302
+ r"""ResNet-101 model from
1303
+ `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
1304
+ Args:
1305
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
1306
+ progress (bool): If True, displays a progress bar of the download to stderr
1307
+ """
1308
+ return _resnet("resnet101", Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs)
1309
+
1310
+
1311
+ def resnet152(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
1312
+ r"""ResNet-152 model from
1313
+ `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
1314
+ Args:
1315
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
1316
+ progress (bool): If True, displays a progress bar of the download to stderr
1317
+ """
1318
+ return _resnet("resnet152", Bottleneck, [3, 8, 36, 3], pretrained, progress, **kwargs)
1319
+
1320
+
1321
+ def resnext50_32x4d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
1322
+ r"""ResNeXt-50 32x4d model from
1323
+ `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_.
1324
+ Args:
1325
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
1326
+ progress (bool): If True, displays a progress bar of the download to stderr
1327
+ """
1328
+ kwargs["groups"] = 32
1329
+ kwargs["width_per_group"] = 4
1330
+ return _resnet("resnext50_32x4d", Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs)
1331
+
1332
+
1333
+ def resnext101_32x8d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
1334
+ r"""ResNeXt-101 32x8d model from
1335
+ `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_.
1336
+ Args:
1337
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
1338
+ progress (bool): If True, displays a progress bar of the download to stderr
1339
+ """
1340
+ kwargs["groups"] = 32
1341
+ kwargs["width_per_group"] = 8
1342
+ return _resnet("resnext101_32x8d", Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs)
1343
+
1344
+
1345
+ def wide_resnet50_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
1346
+ r"""Wide ResNet-50-2 model from
1347
+ `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_.
1348
+ The model is the same as ResNet except for the bottleneck number of channels
1349
+ which is twice larger in every block. The number of channels in outer 1x1
1350
+ convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
1351
+ channels, and in Wide ResNet-50-2 has 2048-1024-2048.
1352
+ Args:
1353
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
1354
+ progress (bool): If True, displays a progress bar of the download to stderr
1355
+ """
1356
+ kwargs["width_per_group"] = 64 * 2
1357
+ return _resnet("wide_resnet50_2", Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs)
1358
+
1359
+
1360
+ def wide_resnet101_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
1361
+ r"""Wide ResNet-101-2 model from
1362
+ `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_.
1363
+ The model is the same as ResNet except for the bottleneck number of channels
1364
+ which is twice larger in every block. The number of channels in outer 1x1
1365
+ convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
1366
+ channels, and in Wide ResNet-50-2 has 2048-1024-2048.
1367
+ Args:
1368
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
1369
+ progress (bool): If True, displays a progress bar of the download to stderr
1370
+ """
1371
+ kwargs["width_per_group"] = 64 * 2
1372
+ return _resnet("wide_resnet101_2", Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs)
1373
+
1374
+
1375
+ def wide_resappnet50_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResAppNet:
1376
+ r"""Wide ResNet-50-2 model from
1377
+ `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_.
1378
+ The model is the same as ResNet except for the bottleneck number of channels
1379
+ which is twice larger in every block. The number of channels in outer 1x1
1380
+ convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
1381
+ channels, and in Wide ResNet-50-2 has 2048-1024-2048.
1382
+ Args:
1383
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
1384
+ progress (bool): If True, displays a progress bar of the download to stderr
1385
+ """
1386
+ kwargs["width_per_group"] = 64 * 2
1387
+ return _resappnet("wide_resappnet50_2", Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs)
1388
+
1389
+ def wide_resdisnet50_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResDisNet:
1390
+ r"""Wide ResNet-50-2 model from
1391
+ `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_.
1392
+ The model is the same as ResNet except for the bottleneck number of channels
1393
+ which is twice larger in every block. The number of channels in outer 1x1
1394
+ convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
1395
+ channels, and in Wide ResNet-50-2 has 2048-1024-2048.
1396
+ Args:
1397
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
1398
+ progress (bool): If True, displays a progress bar of the download to stderr
1399
+ """
1400
+ kwargs["width_per_group"] = 64 * 2
1401
+ return _resdisnet("wide_resdisnet50_2", Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs)
1402
+
1403
+ def wide_resmdisnet50_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResMDisNet:
1404
+ r"""Wide ResNet-50-2 model from
1405
+ `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_.
1406
+ The model is the same as ResNet except for the bottleneck number of channels
1407
+ which is twice larger in every block. The number of channels in outer 1x1
1408
+ convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
1409
+ channels, and in Wide ResNet-50-2 has 2048-1024-2048.
1410
+ Args:
1411
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
1412
+ progress (bool): If True, displays a progress bar of the download to stderr
1413
+ """
1414
+ kwargs["width_per_group"] = 64 * 2
1415
+ return _resmdisnet("wide_resmdisnet50_2", Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs)
1416
+
1417
+ def wide_resdis2net50_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResDis2Net:
1418
+ r"""Wide ResNet-50-2 model from
1419
+ `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_.
1420
+ The model is the same as ResNet except for the bottleneck number of channels
1421
+ which is twice larger in every block. The number of channels in outer 1x1
1422
+ convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
1423
+ channels, and in Wide ResNet-50-2 has 2048-1024-2048.
1424
+ Args:
1425
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
1426
+ progress (bool): If True, displays a progress bar of the download to stderr
1427
+ """
1428
+ kwargs["width_per_group"] = 64 * 2
1429
+ return _resdis2net("wide_resdis2net50_2", Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs)
1430
+
1431
+ def wide_resdisattnet50_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResDisAttNet:
1432
+ r"""Wide ResNet-50-2 model from
1433
+ `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_.
1434
+ The model is the same as ResNet except for the bottleneck number of channels
1435
+ which is twice larger in every block. The number of channels in outer 1x1
1436
+ convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
1437
+ channels, and in Wide ResNet-50-2 has 2048-1024-2048.
1438
+ Args:
1439
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
1440
+ progress (bool): If True, displays a progress bar of the download to stderr
1441
+ """
1442
+ kwargs["width_per_group"] = 64 * 2
1443
+ return _resdisattnet("wide_resdisattnet50_2", Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs)
1444
+
1445
+ def resappnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
1446
+ r"""ResNet-34 model from
1447
+ `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
1448
+ Args:
1449
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
1450
+ progress (bool): If True, displays a progress bar of the download to stderr
1451
+ """
1452
+ return _resappnet("resappnet34", BasicBlock, [3, 4, 6, 3], pretrained, progress, **kwargs)
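The wide_* factories above all follow the torchvision wide_resnet50_2 recipe: they double width_per_group and delegate to the matching _res*net builder, while resappnet34 is a plain ResNet-34 variant. Elsewhere in this commit, models/sample_model.py builds its encoder as resnet.wide_resdisnet50_2(num_classes=512 * 16). A minimal, hedged usage sketch of that same call (it assumes this repo's layout and the mode=1 calling convention seen in SampleModel.forward; the exact output shape is not asserted):

# Sketch only: mirrors the encoder construction in models/sample_model.py.
import torch
from models import resnet

encoder = resnet.wide_resdisnet50_2(num_classes=512 * 16)   # 16 x 512 = one W+ latent per image
encoder.eval()

with torch.no_grad():
    frames = torch.randn(2, 3, 256, 256)     # two RGB frames at the 256x256 working resolution
    app_code = encoder(frames, mode=1)       # mode=1 returns the appearance code, as in SampleModel
print(app_code.shape)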
models/rnn_net.py ADDED
@@ -0,0 +1,99 @@
+ """
+ Copyright Snap Inc. 2021. This sample code is made available by Snap Inc. for informational purposes only.
+ No license, whether implied or otherwise, is granted in or to such code (including any rights to copy, modify,
+ publish, distribute and/or commercialize such code), unless you have entered into a separate agreement for such rights.
+ Such code is provided as-is, without warranty of any kind, express or implied, including any warranties of merchantability,
+ title, fitness for a particular purpose, non-infringement, or that such code is free of defects, errors or viruses.
+ In no event will Snap Inc. be liable for any damages or losses of any kind arising from the sample code or your use thereof.
+ """
+
+ import torch
+ import torch.nn as nn
+ from torch.nn import init
+ import torch.optim as optim
+
+
+ class RNNModule(nn.Module):
+ def __init__(self,
+ z_dim=64,
+ h_dim=64,
+ w_residual=0.2):
+ super(RNNModule, self).__init__()
+
+ self.z_dim = z_dim
+ self.h_dim = h_dim
+ self.w_residual = w_residual
+
+ self.enc_cell = nn.LSTMCell(z_dim, h_dim)
+ self.cell = nn.LSTMCell(z_dim, h_dim)
+ self.w = nn.Parameter(torch.FloatTensor(h_dim, h_dim))
+ self.b = nn.Parameter(torch.FloatTensor(h_dim))
+ self.fc1 = nn.Linear(h_dim * 2, z_dim)
+ self.relu = nn.ReLU()
+ self.fc2 = nn.Linear(z_dim, z_dim)
+
+ self.init_weights()
+
+ def init_optim(self, lr, beta1, beta2):
+ self.optim = optim.Adam(params=self.parameters(),
+ lr=lr,
+ betas=(beta1, beta2),
+ weight_decay=0,
+ eps=1e-8)
+
+ def init_weights(self):
+ for module in self.modules():
+ if (isinstance(module, nn.LSTMCell)):
+ for name, param in module.named_parameters():
+ if ('weight_ih' in name) or ('weight_hh' in name):
+ mul = param.shape[0] // 4
+ for idx in range(4):
+ init.orthogonal_(param[idx * mul:(idx + 1) * mul])
+ elif 'bias' in name:
+ param.data.fill_(0)
+ if (isinstance(module, nn.Linear)):
+ init.orthogonal_(module.weight)
+
+ nn.init.normal_(self.w, std=0.02)
+ self.b.data.fill_(0.0)
+
+ def forward(self, z, n_frame):
+
+ out = [z]
+ h_, c_ = self.enc_cell(z)
+ h = [h_]
+ c = [c_]
+ e = []
+ for i in range(n_frame - 1):
+ e_ = self.get_initial_state_z(z.shape[0])
+ h_, c_ = self.cell(e_, (h[-1], c[-1]))
+ mul = torch.matmul(h_, self.w) + self.b
+ mul = torch.tanh(mul)
+ e.append(e_)
+ h.append(h_)
+ c.append(c_)
+ out_ = out[-1] + self.w_residual * mul
+ out.append(out_)
+
+ out = [item.unsqueeze(1) for item in out]
+
+ out = torch.cat(out, dim=1).view(-1, self.z_dim)
+
+ e = [item.unsqueeze(1) for item in e]
+ e = torch.cat(e, dim=1).view(-1, self.z_dim)
+
+ hh = h[1:]
+ hh = [item.unsqueeze(1) for item in hh]
+ hh = torch.cat(hh, dim=1).view(-1, self.h_dim)
+
+ cc = c[1:]
+ cc = [item.unsqueeze(1) for item in cc]
+ cc = torch.cat(cc, dim=1).view(-1, self.h_dim)
+
+ hc = torch.cat((hh, cc), dim=1)
+ e_rec = self.fc2(self.relu(self.fc1(hc)))
+
+ return out, e, e_rec
+
+ def get_initial_state_z(self, batchSize):
+ return torch.cuda.FloatTensor(batchSize, self.z_dim).normal_()
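RNNModule.forward unrolls the LSTM for n_frame steps: the first code is the input z itself, and each later code adds a residual update w_residual * tanh(h W + b) on top of the previous one, driven by fresh Gaussian noise from get_initial_state_z. It returns the flattened code sequence together with the injected noise and its reconstruction from the concatenated hidden/cell states. A usage sketch, assuming a CUDA device since get_initial_state_z allocates noise with torch.cuda.FloatTensor:

# Usage sketch (CUDA required by get_initial_state_z).
import torch
from models.rnn_net import RNNModule

netR = RNNModule(z_dim=64, h_dim=64, w_residual=0.2).cuda()

batch, n_frame = 4, 60                       # n_frames_G defaults to 60 in sample_model.py
z0 = torch.randn(batch, 64, device='cuda')   # initial pose code (the 8x8 map flattened to 64)
out, e, e_rec = netR(z0, n_frame)

print(out.shape)     # (batch * n_frame, 64): per-frame codes, frame 0 equals z0
print(e.shape)       # (batch * (n_frame - 1), 64): noise injected at each later step
print(e_rec.shape)   # (batch * (n_frame - 1), 64): noise reconstructed from (h, c) via fc1/fc2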
models/sample_model.py ADDED
@@ -0,0 +1,243 @@
1
+ import torch
2
+ from .base_model import BaseModel
3
+ from . import networks
4
+ from . import lmcode_networks
5
+ from . import diy_networks
6
+ from . import resnet
7
+
8
+ import dnnlib
9
+ import legacy
10
+ import torch.nn.functional as F
11
+ import numpy as np
12
+ import random
13
+ import os
14
+
15
+ from . import rnn_net
16
+
17
+ def make_transform(translate, angle):
18
+ m = np.eye(3)
19
+ s = np.sin(angle/360.0*np.pi*2)
20
+ c = np.cos(angle/360.0*np.pi*2)
21
+ m[0][0] = c
22
+ m[0][1] = s
23
+ m[0][2] = translate[0]
24
+ m[1][0] = -s
25
+ m[1][1] = c
26
+ m[1][2] = translate[1]
27
+ return m
28
+
29
+ class SampleModel(BaseModel):
30
+ """ This class implements the pix2pix model, for learning a mapping from input images to output images given paired data.
31
+
32
+ The model training requires '--dataset_mode aligned' dataset.
33
+ By default, it uses a '--netG unet256' U-Net generator,
34
+ a '--netD basic' discriminator (PatchGAN),
35
+ and a '--gan_mode' vanilla GAN loss (the cross-entropy objective used in the original GAN paper).
36
+
37
+ pix2pix paper: https://arxiv.org/pdf/1611.07004.pdf
38
+ """
39
+ @staticmethod
40
+ def modify_commandline_options(parser, is_train=True):
41
+ """Add new dataset-specific options, and rewrite default values for existing options.
42
+
43
+ Parameters:
44
+ parser -- original option parser
45
+ is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
46
+
47
+ Returns:
48
+ the modified parser.
49
+
50
+ For pix2pix, we do not use image buffer
51
+ The training objective is: GAN Loss + lambda_L1 * ||G(A)-B||_1
52
+ By default, we use vanilla GAN loss, UNet with batchnorm, and aligned datasets.
53
+ """
54
+ # changing the default values to match the pix2pix paper (https://phillipi.github.io/pix2pix/)
55
+ parser.set_defaults(norm='batch', netG='unet_256', dataset_mode='noiseshufflevideo', num_test = 32)
56
+ parser.add_argument('--pose_path', type=str, default='', help='path for pose net')
57
+ parser.add_argument('--rnn_path', type=str, default='', help='path for rnn net')
58
+ parser.add_argument('--n_frames_G', type=int, default=60)
59
+ parser.add_argument('--w_residual', type=float, default=0.2)
60
+ parser.add_argument('--num_point', type=int, default=14)
61
+ parser.add_argument('--model_names', type=str, default='')
62
+ if is_train:
63
+ parser.set_defaults(pool_size=0, gan_mode='vanilla')
64
+ parser.add_argument('--lambda_L1', type=float, default=1.0, help='weight for L1 loss')
65
+
66
+ return parser
67
+
68
+ def __init__(self, opt):
69
+ """Initialize the pix2pix class.
70
+
71
+ Parameters:
72
+ opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
73
+ """
74
+ BaseModel.__init__(self, opt)
75
+ # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
76
+ self.loss_names = ['G_L1', 'G_VGG', 'G_W']
77
+ # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
78
+ self.visual_names = ['real_vid_B', 'fake_vid_AR', 'fake_vid_BR', 'fake_vid_AR1', 'fake_vid_BR1', 'fake_vid_AR2', 'fake_vid_BR2', 'fake_vid_AB', 'fake_vid_B', 'fake_vid']
79
+ # specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>
80
+ if self.isTrain:
81
+ self.model_names = ['FE']
82
+ else: # during test time, only load G
83
+ self.model_names = ['FE']
84
+ if opt.model_names != '':
85
+ str_models = opt.model_names.split(',')
86
+ self.model_names = []
87
+ for str_model in str_models:
88
+ self.model_names.append(str_model)
89
+ # define networks (both generator and discriminator)
90
+ with dnnlib.util.open_url(opt.network_pkl) as f:
91
+ self.netG = legacy.load_network_pkl(f)['G_ema'].eval().to(self.gpu_ids[0]) # type: ignore
92
+
93
+ lm_path = 'pretrained_models/wing.ckpt'
94
+ self.netFE_lm = lmcode_networks.FAN(fname_pretrained=lm_path).eval().to(self.gpu_ids[0])
95
+ self.netFE_pose = diy_networks._resposenet(num_point=opt.num_point).eval().to(self.gpu_ids[0])
96
+ if opt.pose_path != '':
97
+ self.netFE_pose.load_state_dict(torch.load(opt.pose_path))
98
+
99
+ self.netFE = resnet.wide_resdisnet50_2(num_classes=512 * 16).to(self.gpu_ids[0])
100
+ self.netFE = networks.init_net(self.netFE, opt.init_type, opt.init_gain, self.gpu_ids)
101
+
102
+ self.netR = rnn_net.RNNModule(w_residual = opt.w_residual).to(self.gpu_ids[0])
103
+ if opt.rnn_path != '':
104
+ self.netR.load_state_dict(torch.load(opt.rnn_path))
105
+ self.n_frames_G = opt.n_frames_G
106
+ self.style_gan_size = 8
107
+
108
+ self.m_zero = make_transform((0.0,0.0),(0.0))
109
+ self.count = 0
110
+
111
+
112
+ def set_input(self, input):
113
+ """Unpack input data from the dataloader and perform necessary pre-processing steps.
114
+
115
+ Parameters:
116
+ input (dict): include the data itself and its metadata information.
117
+
118
+ The option 'direction' can be used to swap images in domain A and domain B.
119
+ """
120
+ self.real_Bs = input['A'].to(self.device)
121
+ self.image_paths = input['A_paths']
122
+ self.count += 1
123
+ self.image_paths[0] = os.path.split(self.image_paths[0])[0] + '/' + str(self.count) + '.png'
124
+
125
+ real_v_list = []
126
+ with torch.no_grad():
127
+ for i in range(self.real_Bs.shape[1]):
128
+ real_v_list.append(self.netFE_pose(self.netFE_lm.get_heatmap(self.real_Bs[:,i,...], b_preprocess=False), mode = 1).unsqueeze(1))
129
+
130
+ self.real_v = torch.cat(real_v_list, 1).detach()
131
+
132
+ self.real_z = input['B'].to(self.device)
133
+
134
+ def forward(self):
135
+ """Run forward pass; called by both functions <optimize_parameters> and <test>."""
136
+
137
+ self.real_A_w = self.netG.mapping(self.real_z, None)
138
+ self.real_A = self.netG.synthesis(self.real_A_w, noise_mode='const').detach().clamp(-1, 1)
139
+ if self.real_A.shape[2] != 256:
140
+ self.real_A = F.interpolate(self.real_A, size=(256, 256), mode='area')
141
+ self.real_A_heat = self.netFE_lm.get_heatmap(self.real_A, b_preprocess=False)
142
+ self.real_A_pose = self.netFE_pose(self.real_A_heat, mode=1).detach()
143
+ self.real_A_app = self.netFE(self.real_A, mode=1).detach()
144
+ self.fake_A_w = self.netFE(self.real_A_app, self.real_A_pose, mode=2).view(-1, 16, 512)
145
+ self.fake_A = self.netG.synthesis(self.fake_A_w, noise_mode='const') # G(A)
146
+
147
+ self.real_B_app = self.netFE(self.real_Bs[:, 0, ...], mode=1)
148
+
149
+ x_fake, self.rand_in, self.rand_rec = self.netR(self.real_v[:, 0].view(self.opt.batch_size, self.style_gan_size * self.style_gan_size), self.n_frames_G)
150
+ x_fake = x_fake.view(self.opt.batch_size, self.n_frames_G, 1, self.style_gan_size,
151
+ self.style_gan_size)
152
+
153
+ self.real_R_pose = x_fake.clone()
154
+
155
+ x_fake, self.rand_in, self.rand_rec = self.netR(self.real_v[:, 29].view(self.opt.batch_size, self.style_gan_size * self.style_gan_size), self.n_frames_G)
156
+ x_fake = x_fake.view(self.opt.batch_size, self.n_frames_G, 1, self.style_gan_size,
157
+ self.style_gan_size)
158
+
159
+ self.real_R1_pose = x_fake.clone()
160
+
161
+ x_fake, self.rand_in, self.rand_rec = self.netR(self.real_v[:, 59].view(self.opt.batch_size, self.style_gan_size * self.style_gan_size), self.n_frames_G)
162
+ x_fake = x_fake.view(self.opt.batch_size, self.n_frames_G, 1, self.style_gan_size,
163
+ self.style_gan_size)
164
+
165
+ self.real_R2_pose = x_fake.clone()
166
+
167
+ x_fake_A, self.rand_in, self.rand_rec = self.netR(self.real_A_pose.view(self.opt.batch_size, self.style_gan_size * self.style_gan_size), self.n_frames_G)
168
+ x_fake_A = x_fake_A.view(self.opt.batch_size, self.n_frames_G, 1, self.style_gan_size,
169
+ self.style_gan_size)
170
+
171
+ self.real_R_pose_A = x_fake_A
172
+
173
+ if hasattr(self.netG.synthesis, 'input'):
174
+ self.netG.synthesis.input.transform.copy_(torch.from_numpy(self.m_zero))
175
+
176
+ self.real_A_list = []
177
+ self.real_B_list = []
178
+ self.fake_AR_list = []
179
+ self.fake_BR_list = []
180
+ self.fake_AR1_list = []
181
+ self.fake_BR1_list = []
182
+ self.fake_AR2_list = []
183
+ self.fake_BR2_list = []
184
+ self.fake_AB_list = []
185
+ self.fake_B_list = []
186
+ # for i in range(self.real_Bs.shape[1]):
187
+ self.real_B_app = self.netFE(self.real_Bs[:,0,...], mode=1)
188
+ for i in range(self.n_frames_G):
189
+ self.real_B = self.real_Bs[:,i,...]
190
+ if self.real_B.shape[2] != 256:
191
+ self.real_B = F.interpolate(self.real_B, size=(256, 256), mode='area')
192
+
193
+ self.fake_AR_w = self.netFE(self.real_A_app, self.real_R_pose[:,i,...], mode=2).view(-1, 16, 512)
194
+ self.fake_BR_w = self.netFE(self.real_B_app, self.real_R_pose[:,i,...], mode=2).view(-1, 16, 512)
195
+ self.fake_AR1_w = self.netFE(self.real_A_app, self.real_R1_pose[:,i,...], mode=2).view(-1, 16, 512)
196
+ self.fake_BR1_w = self.netFE(self.real_B_app, self.real_R1_pose[:,i,...], mode=2).view(-1, 16, 512)
197
+ self.fake_AR2_w = self.netFE(self.real_A_app, self.real_R2_pose[:,i,...], mode=2).view(-1, 16, 512)
198
+ self.fake_BR2_w = self.netFE(self.real_B_app, self.real_R2_pose[:,i,...], mode=2).view(-1, 16, 512)
199
+ self.fake_AB_w = self.netFE(self.real_A_app, self.real_R_pose_A[:,i,...], mode=2).view(-1, 16, 512)
200
+ self.fake_B_w = self.netFE(self.real_B_app, self.real_R_pose_A[:,i,...], mode=2).view(-1, 16, 512)
201
+
202
+ self.fake_AR = self.netG.synthesis(self.fake_AR_w, noise_mode='const') # G(A)
203
+ self.fake_BR = self.netG.synthesis(self.fake_BR_w, noise_mode='const') # G(A)
204
+ self.fake_AR1 = self.netG.synthesis(self.fake_AR1_w, noise_mode='const') # G(A)
205
+ self.fake_BR1 = self.netG.synthesis(self.fake_BR1_w, noise_mode='const') # G(A)
206
+ self.fake_AR2 = self.netG.synthesis(self.fake_AR2_w, noise_mode='const') # G(A)
207
+ self.fake_BR2 = self.netG.synthesis(self.fake_BR2_w, noise_mode='const') # G(A)
208
+ self.fake_AB = self.netG.synthesis(self.fake_AB_w, noise_mode='const') # G(A)
209
+ self.fake_B = self.netG.synthesis(self.fake_B_w, noise_mode='const') # G(A)
210
+
211
+ self.real_A_list.append(self.real_A.clamp(-1, 1))
212
+ self.real_B_list.append(self.real_B.clamp(-1, 1))
213
+ self.fake_AR_list.append(self.fake_AR.clamp(-1, 1))
214
+ self.fake_BR_list.append(self.fake_BR.clamp(-1, 1))
215
+ self.fake_AR1_list.append(self.fake_AR1.clamp(-1, 1))
216
+ self.fake_BR1_list.append(self.fake_BR1.clamp(-1, 1))
217
+ self.fake_AR2_list.append(self.fake_AR2.clamp(-1, 1))
218
+ self.fake_BR2_list.append(self.fake_BR2.clamp(-1, 1))
219
+ self.fake_AB_list.append(self.fake_AB.clamp(-1, 1))
220
+ self.fake_B_list.append(self.fake_B.clamp(-1, 1))
221
+
222
+ def optimize_parameters(self):
223
+ self.forward() # compute fake images: G(A)
224
+ # update G
225
+ self.optimizer_FE.zero_grad() # set G's gradients to zero
226
+ self.backward_G() # calculate gradients for G
227
+ self.optimizer_FE.step() # update G's weights
228
+
229
+ def compute_visuals(self):
230
+
231
+ self.real_vid_A = torch.cat(self.real_A_list, 0)
232
+ self.real_vid_B = torch.cat(self.real_B_list, 0)
233
+ self.fake_vid_AR = torch.cat(self.fake_AR_list, 0)
234
+ self.fake_vid_BR = torch.cat(self.fake_BR_list, 0)
235
+ self.fake_vid_AR1 = torch.cat(self.fake_AR1_list, 0)
236
+ self.fake_vid_BR1 = torch.cat(self.fake_BR1_list, 0)
237
+ self.fake_vid_AR2 = torch.cat(self.fake_AR2_list, 0)
238
+ self.fake_vid_BR2 = torch.cat(self.fake_BR2_list, 0)
239
+ self.fake_vid_AB = torch.cat(self.fake_AB_list, 0)
240
+ self.fake_vid_B = torch.cat(self.fake_B_list, 0)
241
+
242
+ self.fake_vid = torch.cat([torch.cat([self.fake_vid_BR, self.fake_vid_BR1, self.fake_vid_BR2, self.fake_vid_B], dim = 3), torch.cat([self.fake_vid_AR, self.fake_vid_AR1, self.fake_vid_AR2, self.fake_vid_AB], dim = 3)], dim = 2)
243
+
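compute_visuals stacks the four motion variants of each identity along the width and the two identities along the height, producing the single fake_vid grid that gets saved. One self-contained piece worth a quick check is make_transform, which builds the 3x3 affine matrix expected by StyleGAN3's synthesis input; forward() copies make_transform((0.0, 0.0), 0.0) into netG.synthesis.input.transform to pin the sampling grid. The snippet below restates the helper verbatim so the identity case can be verified in isolation:

# Standalone restatement of make_transform (copied from this file) for a quick sanity check.
import numpy as np

def make_transform(translate, angle):
    m = np.eye(3)
    s = np.sin(angle / 360.0 * np.pi * 2)
    c = np.cos(angle / 360.0 * np.pi * 2)
    m[0][0] = c
    m[0][1] = s
    m[0][2] = translate[0]
    m[1][0] = -s
    m[1][1] = c
    m[1][2] = translate[1]
    return m

print(make_transform((0.0, 0.0), 0.0))   # 3x3 identity: no rotation or translation of the grid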
options/__init__.py ADDED
@@ -0,0 +1 @@
+ """This package includes the option modules: training options, test options, and basic options (used in both training and test)."""
options/base_options.py ADDED
@@ -0,0 +1,138 @@
1
+ import argparse
2
+ import os
3
+ from util import util
4
+ import torch
5
+ import models
6
+ import data
7
+
8
+
9
+ class BaseOptions():
10
+ """This class defines options used during both training and test time.
11
+
12
+ It also implements several helper functions such as parsing, printing, and saving the options.
13
+ It also gathers additional options defined in <modify_commandline_options> functions in both dataset class and model class.
14
+ """
15
+
16
+ def __init__(self):
17
+ """Reset the class; indicates the class hasn't been initailized"""
18
+ self.initialized = False
19
+
20
+ def initialize(self, parser):
21
+ """Define the common options that are used in both training and test."""
22
+ # basic parameters
23
+ parser.add_argument('--dataroot', required=True, help='path to images (should have subfolders trainA, trainB, valA, valB, etc)')
24
+ parser.add_argument('--name', type=str, default='experiment_name', help='name of the experiment. It decides where to store samples and models')
25
+ parser.add_argument('--network_pkl', type=str, help='Network pickle filename')
26
+ parser.add_argument('--use_wandb', action='store_true', help='use wandb')
27
+ parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU')
28
+ parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
29
+ # model parameters
30
+ parser.add_argument('--model', type=str, default='cycle_gan', help='chooses which model to use. [cycle_gan | pix2pix | test | colorization]')
31
+ parser.add_argument('--input_nc', type=int, default=3, help='# of input image channels: 3 for RGB and 1 for grayscale')
32
+ parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels: 3 for RGB and 1 for grayscale')
33
+ parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in the last conv layer')
34
+ parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in the first conv layer')
35
+ parser.add_argument('--netD', type=str, default='basic', help='specify discriminator architecture [basic | n_layers | pixel]. The basic model is a 70x70 PatchGAN. n_layers allows you to specify the layers in the discriminator')
36
+ parser.add_argument('--netG', type=str, default='resnet_9blocks', help='specify generator architecture [resnet_9blocks | resnet_6blocks | unet_256 | unet_128]')
37
+ parser.add_argument('--n_layers_D', type=int, default=3, help='only used if netD==n_layers')
38
+ parser.add_argument('--norm', type=str, default='instance', help='instance normalization or batch normalization [instance | batch | none]')
39
+ parser.add_argument('--init_type', type=str, default='normal', help='network initialization [normal | xavier | kaiming | orthogonal]')
40
+ parser.add_argument('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.')
41
+ parser.add_argument('--no_dropout', action='store_true', help='no dropout for the generator')
42
+ # dataset parameters
43
+ parser.add_argument('--dataset_mode', type=str, default='unaligned', help='chooses how datasets are loaded. [unaligned | aligned | single | colorization]')
44
+ parser.add_argument('--direction', type=str, default='AtoB', help='AtoB or BtoA')
45
+ parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')
46
+ parser.add_argument('--num_threads', default=4, type=int, help='# threads for loading data')
47
+ parser.add_argument('--batch_size', type=int, default=1, help='input batch size')
48
+ parser.add_argument('--load_size', type=int, default=256, help='scale images to this size')
49
+ parser.add_argument('--crop_size', type=int, default=256, help='then crop to this size')
50
+ parser.add_argument('--max_dataset_size', type=int, default=20000, help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')
51
+ parser.add_argument('--preprocess', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop | crop | scale_width | scale_width_and_crop | none]')
52
+ parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')
53
+ parser.add_argument('--display_winsize', type=int, default=256, help='display window size for both visdom and HTML')
54
+ # additional parameters
55
+ parser.add_argument('--epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
56
+ parser.add_argument('--load_iter', type=int, default='0', help='which iteration to load? if load_iter > 0, the code will load models by iter_[load_iter]; otherwise, the code will load models by [epoch]')
57
+ parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information')
58
+ parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}')
59
+ self.initialized = True
60
+ return parser
61
+
62
+ def gather_options(self):
63
+ """Initialize our parser with basic options(only once).
64
+ Add additional model-specific and dataset-specific options.
65
+ These options are defined in the <modify_commandline_options> function
66
+ in model and dataset classes.
67
+ """
68
+ if not self.initialized: # check if it has been initialized
69
+ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
70
+ parser = self.initialize(parser)
71
+
72
+ # get the basic options
73
+ opt, _ = parser.parse_known_args()
74
+
75
+ # modify model-related parser options
76
+ model_name = opt.model
77
+ model_option_setter = models.get_option_setter(model_name)
78
+ parser = model_option_setter(parser, self.isTrain)
79
+ opt, _ = parser.parse_known_args() # parse again with new defaults
80
+
81
+ # modify dataset-related parser options
82
+ dataset_name = opt.dataset_mode
83
+ dataset_option_setter = data.get_option_setter(dataset_name)
84
+ parser = dataset_option_setter(parser, self.isTrain)
85
+
86
+ # save and return the parser
87
+ self.parser = parser
88
+ return parser.parse_args()
89
+
90
+ def print_options(self, opt):
91
+ """Print and save options
92
+
93
+ It will print both current options and default values(if different).
94
+ It will save options into a text file / [checkpoints_dir] / opt.txt
95
+ """
96
+ message = ''
97
+ message += '----------------- Options ---------------\n'
98
+ for k, v in sorted(vars(opt).items()):
99
+ comment = ''
100
+ default = self.parser.get_default(k)
101
+ if v != default:
102
+ comment = '\t[default: %s]' % str(default)
103
+ message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
104
+ message += '----------------- End -------------------'
105
+ print(message)
106
+
107
+ # save to the disk
108
+ expr_dir = os.path.join(opt.checkpoints_dir, opt.name)
109
+ util.mkdirs(expr_dir)
110
+ file_name = os.path.join(expr_dir, '{}_opt.txt'.format(opt.phase))
111
+ with open(file_name, 'wt') as opt_file:
112
+ opt_file.write(message)
113
+ opt_file.write('\n')
114
+
115
+ def parse(self):
116
+ """Parse our options, create checkpoints directory suffix, and set up gpu device."""
117
+ opt = self.gather_options()
118
+ opt.isTrain = self.isTrain # train or test
119
+
120
+ # process opt.suffix
121
+ if opt.suffix:
122
+ suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else ''
123
+ opt.name = opt.name + suffix
124
+
125
+ self.print_options(opt)
126
+
127
+ # set gpu ids
128
+ str_ids = opt.gpu_ids.split(',')
129
+ opt.gpu_ids = []
130
+ for str_id in str_ids:
131
+ id = int(str_id)
132
+ if id >= 0:
133
+ opt.gpu_ids.append(id)
134
+ if len(opt.gpu_ids) > 0:
135
+ torch.cuda.set_device(opt.gpu_ids[0])
136
+
137
+ self.opt = opt
138
+ return self.opt
options/test_options.py ADDED
@@ -0,0 +1,23 @@
+ from .base_options import BaseOptions
+
+
+ class TestOptions(BaseOptions):
+ """This class includes test options.
+
+ It also includes shared options defined in BaseOptions.
+ """
+
+ def initialize(self, parser):
+ parser = BaseOptions.initialize(self, parser) # define shared options
+ parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.')
+ parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images')
+ parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc')
+ # Dropout and Batchnorm have different behavior during training and test.
+ parser.add_argument('--eval', action='store_true', help='use eval mode during test time.')
+ parser.add_argument('--num_test', type=int, default=50, help='how many test images to run')
+ # rewrite default values
+ parser.set_defaults(model='test')
+ # To avoid cropping, the load_size should be the same as crop_size
+ parser.set_defaults(load_size=parser.get_default('crop_size'))
+ self.isTrain = False
+ return parser
options/train_options.py ADDED
@@ -0,0 +1,43 @@
1
+ from .base_options import BaseOptions
2
+
3
+
4
+ class TrainOptions(BaseOptions):
5
+ """This class includes training options.
6
+
7
+ It also includes shared options defined in BaseOptions.
8
+ """
9
+
10
+ def initialize(self, parser):
11
+ parser = BaseOptions.initialize(self, parser)
12
+ # visdom and HTML visualization parameters
13
+ parser.add_argument('--display_freq', type=int, default=400, help='frequency of showing training results on screen')
14
+ parser.add_argument('--display_ncols', type=int, default=4, help='if positive, display all images in a single visdom web panel with certain number of images per row.')
15
+ parser.add_argument('--display_id', type=int, default=0, help='window id of the web display')
16
+ parser.add_argument('--display_server', type=str, default="http://localhost", help='visdom server of the web display')
17
+ parser.add_argument('--display_env', type=str, default='main', help='visdom display environment name (default is "main")')
18
+ parser.add_argument('--display_port', type=int, default=8097, help='visdom port of the web display')
19
+ parser.add_argument('--update_html_freq', type=int, default=1000, help='frequency of saving training results to html')
20
+ parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console')
21
+ parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')
22
+ # network saving and loading parameters
23
+ parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results')
24
+ parser.add_argument('--save_epoch_freq', type=int, default=5, help='frequency of saving checkpoints at the end of epochs')
25
+ parser.add_argument('--save_by_iter', action='store_true', help='whether saves model by iteration')
26
+ parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model')
27
+ parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>, ...')
28
+ parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
29
+ # training parameters
30
+ parser.add_argument('--n_epochs', type=int, default=50, help='number of epochs with the initial learning rate')
31
+ parser.add_argument('--n_epochs_decay', type=int, default=50, help='number of epochs to linearly decay learning rate to zero')
32
+ parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam')
33
+ parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam')
34
+ parser.add_argument('--gan_mode', type=str, default='lsgan', help='the type of GAN objective. [vanilla| lsgan | wgangp]. vanilla GAN loss is the cross-entropy objective used in the original GAN paper.')
35
+ parser.add_argument('--pool_size', type=int, default=50, help='the size of image buffer that stores previously generated images')
36
+ parser.add_argument('--lr_policy', type=str, default='linear', help='learning rate policy. [linear | step | plateau | cosine]')
37
+ parser.add_argument('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations')
38
+
39
+ parser.add_argument('--epoch_gan', type=int, default=0,
40
+ help='finetune the whole model with GAN loss finally')
41
+
42
+ self.isTrain = True
43
+ return parser
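TrainOptions and TestOptions only add or re-default arguments; parsing, option printing, and GPU selection all live in BaseOptions.parse(). A typical driver therefore looks like the sketch below. It assumes the usual pix2pix-style factories (create_dataset in data/__init__.py, create_model and BaseModel.setup in models/); those files are part of this commit but not reproduced in this excerpt, so treat the exact names as assumptions:

# Illustrative training driver under the pix2pix-style layout assumed above.
from options.train_options import TrainOptions
from data import create_dataset
from models import create_model

if __name__ == '__main__':
    opt = TrainOptions().parse()          # Base + model-specific + dataset-specific options
    dataset = create_dataset(opt)
    model = create_model(opt)
    model.setup(opt)                      # load/print networks, build schedulers
    for epoch in range(opt.epoch_count, opt.n_epochs + opt.n_epochs_decay + 1):
        for data in dataset:
            model.set_input(data)
            model.optimize_parameters()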
pretrained_models/.DS_Store ADDED
Binary file (6.15 kB)
pretrained_models/motion_net.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:59d0dc583c1979aba667b74761f704f073427400448b2347f2766bd0e317095b
+ size 336251
pretrained_models/network-snapshot-005000.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6f4b61e718d80495ad0864e4aa1107e4b731a423baedc5ebb0d5eb813fa990f0
+ size 222508337
pretrained_models/wing.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:bbfd137307a4c7debd5c283b9b0ce539466cee417ac0a155e184d857f9f2899c
+ size 193670248
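The files under pretrained_models/ (and the checkpoints/ weights above) are committed as Git LFS pointers: three text lines recording the spec version, the SHA-256 object id, and the payload size, while the actual binaries live on the LFS server. A small hedged helper for sanity-checking such a pointer before pulling the real file:

# Minimal LFS pointer reader (illustrative helper, not part of the repo).
def read_lfs_pointer(path):
    fields = {}
    with open(path) as f:
        for line in f:
            key, _, value = line.strip().partition(' ')
            fields[key] = value
    return fields   # expected keys: 'version', 'oid', 'size'

info = read_lfs_pointer('pretrained_models/motion_net.pth')
print(info['oid'], int(info['size']))   # matches the oid/size above while the file is still a pointer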
torch_utils/__init__.py ADDED
@@ -0,0 +1,9 @@
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ #
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
+ # and proprietary rights in and to this software, related documentation
+ # and any modifications thereto. Any use, reproduction, disclosure or
+ # distribution of this software and related documentation without an express
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+ # empty
torch_utils/custom_ops.py ADDED
@@ -0,0 +1,157 @@
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ import glob
10
+ import hashlib
11
+ import importlib
12
+ import os
13
+ import re
14
+ import shutil
15
+ import uuid
16
+
17
+ import torch
18
+ import torch.utils.cpp_extension
19
+ from torch.utils.file_baton import FileBaton
20
+
21
+ #----------------------------------------------------------------------------
22
+ # Global options.
23
+
24
+ verbosity = 'brief' # Verbosity level: 'none', 'brief', 'full'
25
+
26
+ #----------------------------------------------------------------------------
27
+ # Internal helper funcs.
28
+
29
+ def _find_compiler_bindir():
30
+ patterns = [
31
+ 'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64',
32
+ 'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64',
33
+ 'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64',
34
+ 'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin',
35
+ ]
36
+ for pattern in patterns:
37
+ matches = sorted(glob.glob(pattern))
38
+ if len(matches):
39
+ return matches[-1]
40
+ return None
41
+
42
+ #----------------------------------------------------------------------------
43
+
44
+ def _get_mangled_gpu_name():
45
+ name = torch.cuda.get_device_name().lower()
46
+ out = []
47
+ for c in name:
48
+ if re.match('[a-z0-9_-]+', c):
49
+ out.append(c)
50
+ else:
51
+ out.append('-')
52
+ return ''.join(out)
53
+
54
+ #----------------------------------------------------------------------------
55
+ # Main entry point for compiling and loading C++/CUDA plugins.
56
+
57
+ _cached_plugins = dict()
58
+
59
+ def get_plugin(module_name, sources, headers=None, source_dir=None, **build_kwargs):
60
+ assert verbosity in ['none', 'brief', 'full']
61
+ if headers is None:
62
+ headers = []
63
+ if source_dir is not None:
64
+ sources = [os.path.join(source_dir, fname) for fname in sources]
65
+ headers = [os.path.join(source_dir, fname) for fname in headers]
66
+
67
+ # Already cached?
68
+ if module_name in _cached_plugins:
69
+ return _cached_plugins[module_name]
70
+
71
+ # Print status.
72
+ if verbosity == 'full':
73
+ print(f'Setting up PyTorch plugin "{module_name}"...')
74
+ elif verbosity == 'brief':
75
+ print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True)
76
+ verbose_build = (verbosity == 'full')
77
+
78
+ # Compile and load.
79
+ try: # pylint: disable=too-many-nested-blocks
80
+ # Make sure we can find the necessary compiler binaries.
81
+ if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0:
82
+ compiler_bindir = _find_compiler_bindir()
83
+ if compiler_bindir is None:
84
+ raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".')
85
+ os.environ['PATH'] += ';' + compiler_bindir
86
+
87
+ # Some containers set TORCH_CUDA_ARCH_LIST to a list that can either
88
+ # break the build or unnecessarily restrict what's available to nvcc.
89
+ # Unset it to let nvcc decide based on what's available on the
90
+ # machine.
91
+ os.environ['TORCH_CUDA_ARCH_LIST'] = ''
92
+
93
+ # Incremental build md5sum trickery. Copies all the input source files
94
+ # into a cached build directory under a combined md5 digest of the input
95
+ # source files. Copying is done only if the combined digest has changed.
96
+ # This keeps input file timestamps and filenames the same as in previous
97
+ # extension builds, allowing for fast incremental rebuilds.
98
+ #
99
+ # This optimization is done only in case all the source files reside in
100
+ # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR
101
+ # environment variable is set (we take this as a signal that the user
102
+ # actually cares about this.)
103
+ #
104
+ # EDIT: We now do it regardless of TORCH_EXTENSIONS_DIR, in order to work
105
+ # around the *.cu dependency bug in ninja config.
106
+ #
107
+ all_source_files = sorted(sources + headers)
108
+ all_source_dirs = set(os.path.dirname(fname) for fname in all_source_files)
109
+ if len(all_source_dirs) == 1: # and ('TORCH_EXTENSIONS_DIR' in os.environ):
110
+
111
+ # Compute combined hash digest for all source files.
112
+ hash_md5 = hashlib.md5()
113
+ for src in all_source_files:
114
+ with open(src, 'rb') as f:
115
+ hash_md5.update(f.read())
116
+
117
+ # Select cached build directory name.
118
+ source_digest = hash_md5.hexdigest()
119
+ build_top_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access
120
+ cached_build_dir = os.path.join(build_top_dir, f'{source_digest}-{_get_mangled_gpu_name()}')
121
+
122
+ if not os.path.isdir(cached_build_dir):
123
+ tmpdir = f'{build_top_dir}/srctmp-{uuid.uuid4().hex}'
124
+ os.makedirs(tmpdir)
125
+ for src in all_source_files:
126
+ shutil.copyfile(src, os.path.join(tmpdir, os.path.basename(src)))
127
+ try:
128
+ os.replace(tmpdir, cached_build_dir) # atomic
129
+ except OSError:
130
+ # source directory already exists, delete tmpdir and its contents.
131
+ shutil.rmtree(tmpdir)
132
+ if not os.path.isdir(cached_build_dir): raise
133
+
134
+ # Compile.
135
+ cached_sources = [os.path.join(cached_build_dir, os.path.basename(fname)) for fname in sources]
136
+ torch.utils.cpp_extension.load(name=module_name, build_directory=cached_build_dir,
137
+ verbose=verbose_build, sources=cached_sources, **build_kwargs)
138
+ else:
139
+ torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs)
140
+
141
+ # Load.
142
+ module = importlib.import_module(module_name)
143
+
144
+ except:
145
+ if verbosity == 'brief':
146
+ print('Failed!')
147
+ raise
148
+
149
+ # Print status and add to cache dict.
150
+ if verbosity == 'full':
151
+ print(f'Done setting up PyTorch plugin "{module_name}".')
152
+ elif verbosity == 'brief':
153
+ print('Done.')
154
+ _cached_plugins[module_name] = module
155
+ return module
156
+
157
+ #----------------------------------------------------------------------------
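get_plugin compiles a C++/CUDA extension on first use and caches it by module name; when every source sits in one directory it also copies the sources into an md5-named build directory so ninja's incremental rebuilds keep working across runs. The ops wrappers in this commit (bias_act, upfirdn2d, filtered_lrelu) call it roughly as sketched below; the keyword values follow NVIDIA's convention and should be read as an assumption rather than a copy of those files:

# Hedged sketch of a get_plugin call site (the real ones live in torch_utils/ops/*.py).
import os
from torch_utils import custom_ops

_plugin = custom_ops.get_plugin(
    module_name='bias_act_plugin',
    sources=['bias_act.cpp', 'bias_act.cu'],
    headers=['bias_act.h'],
    source_dir=os.path.dirname(__file__),        # resolve paths relative to the calling module
    extra_cuda_cflags=['--use_fast_math'],       # forwarded to torch.utils.cpp_extension.load
)
# _plugin.bias_act(...) then dispatches to the op bound in bias_act.cpp.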
torch_utils/misc.py ADDED
@@ -0,0 +1,266 @@
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ import re
10
+ import contextlib
11
+ import numpy as np
12
+ import torch
13
+ import warnings
14
+ import dnnlib
15
+
16
+ #----------------------------------------------------------------------------
17
+ # Cached construction of constant tensors. Avoids CPU=>GPU copy when the
18
+ # same constant is used multiple times.
19
+
20
+ _constant_cache = dict()
21
+
22
+ def constant(value, shape=None, dtype=None, device=None, memory_format=None):
23
+ value = np.asarray(value)
24
+ if shape is not None:
25
+ shape = tuple(shape)
26
+ if dtype is None:
27
+ dtype = torch.get_default_dtype()
28
+ if device is None:
29
+ device = torch.device('cpu')
30
+ if memory_format is None:
31
+ memory_format = torch.contiguous_format
32
+
33
+ key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format)
34
+ tensor = _constant_cache.get(key, None)
35
+ if tensor is None:
36
+ tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device)
37
+ if shape is not None:
38
+ tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape))
39
+ tensor = tensor.contiguous(memory_format=memory_format)
40
+ _constant_cache[key] = tensor
41
+ return tensor
42
+
43
+ #----------------------------------------------------------------------------
44
+ # Replace NaN/Inf with specified numerical values.
45
+
46
+ try:
47
+ nan_to_num = torch.nan_to_num # 1.8.0a0
48
+ except AttributeError:
49
+ def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin
50
+ assert isinstance(input, torch.Tensor)
51
+ if posinf is None:
52
+ posinf = torch.finfo(input.dtype).max
53
+ if neginf is None:
54
+ neginf = torch.finfo(input.dtype).min
55
+ assert nan == 0
56
+ return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out)
57
+
58
+ #----------------------------------------------------------------------------
59
+ # Symbolic assert.
60
+
61
+ try:
62
+ symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access
63
+ except AttributeError:
64
+ symbolic_assert = torch.Assert # 1.7.0
65
+
66
+ #----------------------------------------------------------------------------
67
+ # Context manager to temporarily suppress known warnings in torch.jit.trace().
68
+ # Note: Cannot use catch_warnings because of https://bugs.python.org/issue29672
69
+
70
+ @contextlib.contextmanager
71
+ def suppress_tracer_warnings():
72
+ flt = ('ignore', None, torch.jit.TracerWarning, None, 0)
73
+ warnings.filters.insert(0, flt)
74
+ yield
75
+ warnings.filters.remove(flt)
76
+
77
+ #----------------------------------------------------------------------------
78
+ # Assert that the shape of a tensor matches the given list of integers.
79
+ # None indicates that the size of a dimension is allowed to vary.
80
+ # Performs symbolic assertion when used in torch.jit.trace().
81
+
82
+ def assert_shape(tensor, ref_shape):
83
+ if tensor.ndim != len(ref_shape):
84
+ raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}')
85
+ for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)):
86
+ if ref_size is None:
87
+ pass
88
+ elif isinstance(ref_size, torch.Tensor):
89
+ with suppress_tracer_warnings(): # as_tensor results are registered as constants
90
+ symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}')
91
+ elif isinstance(size, torch.Tensor):
92
+ with suppress_tracer_warnings(): # as_tensor results are registered as constants
93
+ symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}')
94
+ elif size != ref_size:
95
+ raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}')
96
+
97
+ #----------------------------------------------------------------------------
98
+ # Function decorator that calls torch.autograd.profiler.record_function().
99
+
100
+ def profiled_function(fn):
101
+ def decorator(*args, **kwargs):
102
+ with torch.autograd.profiler.record_function(fn.__name__):
103
+ return fn(*args, **kwargs)
104
+ decorator.__name__ = fn.__name__
105
+ return decorator
106
+
107
+ #----------------------------------------------------------------------------
108
+ # Sampler for torch.utils.data.DataLoader that loops over the dataset
109
+ # indefinitely, shuffling items as it goes.
110
+
111
+ class InfiniteSampler(torch.utils.data.Sampler):
112
+ def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5):
113
+ assert len(dataset) > 0
114
+ assert num_replicas > 0
115
+ assert 0 <= rank < num_replicas
116
+ assert 0 <= window_size <= 1
117
+ super().__init__(dataset)
118
+ self.dataset = dataset
119
+ self.rank = rank
120
+ self.num_replicas = num_replicas
121
+ self.shuffle = shuffle
122
+ self.seed = seed
123
+ self.window_size = window_size
124
+
125
+ def __iter__(self):
126
+ order = np.arange(len(self.dataset))
127
+ rnd = None
128
+ window = 0
129
+ if self.shuffle:
130
+ rnd = np.random.RandomState(self.seed)
131
+ rnd.shuffle(order)
132
+ window = int(np.rint(order.size * self.window_size))
133
+
134
+ idx = 0
135
+ while True:
136
+ i = idx % order.size
137
+ if idx % self.num_replicas == self.rank:
138
+ yield order[i]
139
+ if window >= 2:
140
+ j = (i - rnd.randint(window)) % order.size
141
+ order[i], order[j] = order[j], order[i]
142
+ idx += 1
143
+
144
+ #----------------------------------------------------------------------------
145
+ # Utilities for operating with torch.nn.Module parameters and buffers.
146
+
147
+ def params_and_buffers(module):
148
+ assert isinstance(module, torch.nn.Module)
149
+ return list(module.parameters()) + list(module.buffers())
150
+
151
+ def named_params_and_buffers(module):
152
+ assert isinstance(module, torch.nn.Module)
153
+ return list(module.named_parameters()) + list(module.named_buffers())
154
+
155
+ def copy_params_and_buffers(src_module, dst_module, require_all=False):
156
+ assert isinstance(src_module, torch.nn.Module)
157
+ assert isinstance(dst_module, torch.nn.Module)
158
+ src_tensors = dict(named_params_and_buffers(src_module))
159
+ for name, tensor in named_params_and_buffers(dst_module):
160
+ assert (name in src_tensors) or (not require_all)
161
+ if name in src_tensors:
162
+ tensor.copy_(src_tensors[name].detach()).requires_grad_(tensor.requires_grad)
163
+
164
+ #----------------------------------------------------------------------------
165
+ # Context manager for easily enabling/disabling DistributedDataParallel
166
+ # synchronization.
167
+
168
+ @contextlib.contextmanager
169
+ def ddp_sync(module, sync):
170
+ assert isinstance(module, torch.nn.Module)
171
+ if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel):
172
+ yield
173
+ else:
174
+ with module.no_sync():
175
+ yield
176
+
177
+ #----------------------------------------------------------------------------
178
+ # Check DistributedDataParallel consistency across processes.
179
+
180
+ def check_ddp_consistency(module, ignore_regex=None):
181
+ assert isinstance(module, torch.nn.Module)
182
+ for name, tensor in named_params_and_buffers(module):
183
+ fullname = type(module).__name__ + '.' + name
184
+ if ignore_regex is not None and re.fullmatch(ignore_regex, fullname):
185
+ continue
186
+ tensor = tensor.detach()
187
+ if tensor.is_floating_point():
188
+ tensor = nan_to_num(tensor)
189
+ other = tensor.clone()
190
+ torch.distributed.broadcast(tensor=other, src=0)
191
+ assert (tensor == other).all(), fullname
192
+
193
+ #----------------------------------------------------------------------------
194
+ # Print summary table of module hierarchy.
195
+
196
+ def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True):
197
+ assert isinstance(module, torch.nn.Module)
198
+ assert not isinstance(module, torch.jit.ScriptModule)
199
+ assert isinstance(inputs, (tuple, list))
200
+
201
+ # Register hooks.
202
+ entries = []
203
+ nesting = [0]
204
+ def pre_hook(_mod, _inputs):
205
+ nesting[0] += 1
206
+ def post_hook(mod, _inputs, outputs):
207
+ nesting[0] -= 1
208
+ if nesting[0] <= max_nesting:
209
+ outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs]
210
+ outputs = [t for t in outputs if isinstance(t, torch.Tensor)]
211
+ entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs))
212
+ hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()]
213
+ hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()]
214
+
215
+ # Run module.
216
+ outputs = module(*inputs)
217
+ for hook in hooks:
218
+ hook.remove()
219
+
220
+ # Identify unique outputs, parameters, and buffers.
221
+ tensors_seen = set()
222
+ for e in entries:
223
+ e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen]
224
+ e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen]
225
+ e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen]
226
+ tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs}
227
+
228
+ # Filter out redundant entries.
229
+ if skip_redundant:
230
+ entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)]
231
+
232
+ # Construct table.
233
+ rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']]
234
+ rows += [['---'] * len(rows[0])]
235
+ param_total = 0
236
+ buffer_total = 0
237
+ submodule_names = {mod: name for name, mod in module.named_modules()}
238
+ for e in entries:
239
+ name = '<top-level>' if e.mod is module else submodule_names[e.mod]
240
+ param_size = sum(t.numel() for t in e.unique_params)
241
+ buffer_size = sum(t.numel() for t in e.unique_buffers)
242
+ output_shapes = [str(list(t.shape)) for t in e.outputs]
243
+ output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs]
244
+ rows += [[
245
+ name + (':0' if len(e.outputs) >= 2 else ''),
246
+ str(param_size) if param_size else '-',
247
+ str(buffer_size) if buffer_size else '-',
248
+ (output_shapes + ['-'])[0],
249
+ (output_dtypes + ['-'])[0],
250
+ ]]
251
+ for idx in range(1, len(e.outputs)):
252
+ rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]]
253
+ param_total += param_size
254
+ buffer_total += buffer_size
255
+ rows += [['---'] * len(rows[0])]
256
+ rows += [['Total', str(param_total), str(buffer_total), '-', '-']]
257
+
258
+ # Print table.
259
+ widths = [max(len(cell) for cell in column) for column in zip(*rows)]
260
+ print()
261
+ for row in rows:
262
+ print(' '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths)))
263
+ print()
264
+ return outputs
265
+
266
+ #----------------------------------------------------------------------------
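misc.py collects small training utilities: constant() caches constant tensors per device/dtype, assert_shape() validates tensor shapes (symbolically under torch.jit.trace), and InfiniteSampler drives a DataLoader forever with windowed shuffling so distributed ranks stay disjoint. A short usage sketch of the two pieces most likely to be reused outside StyleGAN training code (sizes are arbitrary):

# Usage sketch for assert_shape and InfiniteSampler.
import torch
from torch_utils import misc

x = torch.randn(8, 3, 256, 256)
misc.assert_shape(x, [None, 3, 256, 256])        # None lets the batch dimension vary

dataset = torch.utils.data.TensorDataset(torch.arange(100))
sampler = misc.InfiniteSampler(dataset, rank=0, num_replicas=1, shuffle=True, seed=0)
loader = torch.utils.data.DataLoader(dataset, batch_size=16, sampler=sampler)

it = iter(loader)
for _ in range(3):                               # the sampler never raises StopIteration
    batch, = next(it)
    print(batch.shape)                           # torch.Size([16])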
torch_utils/ops/__init__.py ADDED
@@ -0,0 +1,9 @@
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ #
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
+ # and proprietary rights in and to this software, related documentation
+ # and any modifications thereto. Any use, reproduction, disclosure or
+ # distribution of this software and related documentation without an express
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+ # empty
torch_utils/ops/bias_act.cpp ADDED
@@ -0,0 +1,99 @@
1
+ // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ //
3
+ // NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ // and proprietary rights in and to this software, related documentation
5
+ // and any modifications thereto. Any use, reproduction, disclosure or
6
+ // distribution of this software and related documentation without an express
7
+ // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ #include <torch/extension.h>
10
+ #include <ATen/cuda/CUDAContext.h>
11
+ #include <c10/cuda/CUDAGuard.h>
12
+ #include "bias_act.h"
13
+
14
+ //------------------------------------------------------------------------
15
+
16
+ static bool has_same_layout(torch::Tensor x, torch::Tensor y)
17
+ {
18
+ if (x.dim() != y.dim())
19
+ return false;
20
+ for (int64_t i = 0; i < x.dim(); i++)
21
+ {
22
+ if (x.size(i) != y.size(i))
23
+ return false;
24
+ if (x.size(i) >= 2 && x.stride(i) != y.stride(i))
25
+ return false;
26
+ }
27
+ return true;
28
+ }
29
+
30
+ //------------------------------------------------------------------------
31
+
32
+ static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp)
33
+ {
34
+ // Validate arguments.
35
+ TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
36
+ TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x");
37
+ TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x");
38
+ TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x");
39
+ TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same dtype and device as x");
40
+ TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
41
+ TORCH_CHECK(b.dim() == 1, "b must have rank 1");
42
+ TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds");
43
+ TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements");
44
+ TORCH_CHECK(grad >= 0, "grad must be non-negative");
45
+
46
+ // Validate layout.
47
+ TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense");
48
+ TORCH_CHECK(b.is_contiguous(), "b must be contiguous");
49
+ TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x");
50
+ TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x");
51
+ TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x");
52
+
53
+ // Create output tensor.
54
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
55
+ torch::Tensor y = torch::empty_like(x);
56
+ TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x");
57
+
58
+ // Initialize CUDA kernel parameters.
59
+ bias_act_kernel_params p;
60
+ p.x = x.data_ptr();
61
+ p.b = (b.numel()) ? b.data_ptr() : NULL;
62
+ p.xref = (xref.numel()) ? xref.data_ptr() : NULL;
63
+ p.yref = (yref.numel()) ? yref.data_ptr() : NULL;
64
+ p.dy = (dy.numel()) ? dy.data_ptr() : NULL;
65
+ p.y = y.data_ptr();
66
+ p.grad = grad;
67
+ p.act = act;
68
+ p.alpha = alpha;
69
+ p.gain = gain;
70
+ p.clamp = clamp;
71
+ p.sizeX = (int)x.numel();
72
+ p.sizeB = (int)b.numel();
73
+ p.stepB = (b.numel()) ? (int)x.stride(dim) : 1;
74
+
75
+ // Choose CUDA kernel.
76
+ void* kernel;
77
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
78
+ {
79
+ kernel = choose_bias_act_kernel<scalar_t>(p);
80
+ });
81
+ TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func");
82
+
83
+ // Launch CUDA kernel.
84
+ p.loopX = 4;
85
+ int blockSize = 4 * 32;
86
+ int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1;
87
+ void* args[] = {&p};
88
+ AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
89
+ return y;
90
+ }
91
+
92
+ //------------------------------------------------------------------------
93
+
94
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
95
+ {
96
+ m.def("bias_act", &bias_act);
97
+ }
98
+
99
+ //------------------------------------------------------------------------
torch_utils/ops/bias_act.cu ADDED
@@ -0,0 +1,173 @@
1
+ // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ //
3
+ // NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ // and proprietary rights in and to this software, related documentation
5
+ // and any modifications thereto. Any use, reproduction, disclosure or
6
+ // distribution of this software and related documentation without an express
7
+ // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ #include <c10/util/Half.h>
10
+ #include "bias_act.h"
11
+
12
+ //------------------------------------------------------------------------
13
+ // Helpers.
14
+
15
+ template <class T> struct InternalType;
16
+ template <> struct InternalType<double> { typedef double scalar_t; };
17
+ template <> struct InternalType<float> { typedef float scalar_t; };
18
+ template <> struct InternalType<c10::Half> { typedef float scalar_t; };
19
+
20
+ //------------------------------------------------------------------------
21
+ // CUDA kernel.
22
+
23
+ template <class T, int A>
24
+ __global__ void bias_act_kernel(bias_act_kernel_params p)
25
+ {
26
+ typedef typename InternalType<T>::scalar_t scalar_t;
27
+ int G = p.grad;
28
+ scalar_t alpha = (scalar_t)p.alpha;
29
+ scalar_t gain = (scalar_t)p.gain;
30
+ scalar_t clamp = (scalar_t)p.clamp;
31
+ scalar_t one = (scalar_t)1;
32
+ scalar_t two = (scalar_t)2;
33
+ scalar_t expRange = (scalar_t)80;
34
+ scalar_t halfExpRange = (scalar_t)40;
35
+ scalar_t seluScale = (scalar_t)1.0507009873554804934193349852946;
36
+ scalar_t seluAlpha = (scalar_t)1.6732632423543772848170429916717;
37
+
38
+ // Loop over elements.
39
+ int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x;
40
+ for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x)
41
+ {
42
+ // Load.
43
+ scalar_t x = (scalar_t)((const T*)p.x)[xi];
44
+ scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0;
45
+ scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0;
46
+ scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0;
47
+ scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one;
48
+ scalar_t yy = (gain != 0) ? yref / gain : 0;
49
+ scalar_t y = 0;
50
+
51
+ // Apply bias.
52
+ ((G == 0) ? x : xref) += b;
53
+
54
+ // linear
55
+ if (A == 1)
56
+ {
57
+ if (G == 0) y = x;
58
+ if (G == 1) y = x;
59
+ }
60
+
61
+ // relu
62
+ if (A == 2)
63
+ {
64
+ if (G == 0) y = (x > 0) ? x : 0;
65
+ if (G == 1) y = (yy > 0) ? x : 0;
66
+ }
67
+
68
+ // lrelu
69
+ if (A == 3)
70
+ {
71
+ if (G == 0) y = (x > 0) ? x : x * alpha;
72
+ if (G == 1) y = (yy > 0) ? x : x * alpha;
73
+ }
74
+
75
+ // tanh
76
+ if (A == 4)
77
+ {
78
+ if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); }
79
+ if (G == 1) y = x * (one - yy * yy);
80
+ if (G == 2) y = x * (one - yy * yy) * (-two * yy);
81
+ }
82
+
83
+ // sigmoid
84
+ if (A == 5)
85
+ {
86
+ if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one);
87
+ if (G == 1) y = x * yy * (one - yy);
88
+ if (G == 2) y = x * yy * (one - yy) * (one - two * yy);
89
+ }
90
+
91
+ // elu
92
+ if (A == 6)
93
+ {
94
+ if (G == 0) y = (x >= 0) ? x : exp(x) - one;
95
+ if (G == 1) y = (yy >= 0) ? x : x * (yy + one);
96
+ if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one);
97
+ }
98
+
99
+ // selu
100
+ if (A == 7)
101
+ {
102
+ if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one);
103
+ if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha);
104
+ if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha);
105
+ }
106
+
107
+ // softplus
108
+ if (A == 8)
109
+ {
110
+ if (G == 0) y = (x > expRange) ? x : log(exp(x) + one);
111
+ if (G == 1) y = x * (one - exp(-yy));
112
+ if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); }
113
+ }
114
+
115
+ // swish
116
+ if (A == 9)
117
+ {
118
+ if (G == 0)
119
+ y = (x < -expRange) ? 0 : x / (exp(-x) + one);
120
+ else
121
+ {
122
+ scalar_t c = exp(xref);
123
+ scalar_t d = c + one;
124
+ if (G == 1)
125
+ y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d);
126
+ else
127
+ y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d);
128
+ yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain;
129
+ }
130
+ }
131
+
132
+ // Apply gain.
133
+ y *= gain * dy;
134
+
135
+ // Clamp.
136
+ if (clamp >= 0)
137
+ {
138
+ if (G == 0)
139
+ y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp;
140
+ else
141
+ y = (yref > -clamp & yref < clamp) ? y : 0;
142
+ }
143
+
144
+ // Store.
145
+ ((T*)p.y)[xi] = (T)y;
146
+ }
147
+ }
148
+
149
+ //------------------------------------------------------------------------
150
+ // CUDA kernel selection.
151
+
152
+ template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p)
153
+ {
154
+ if (p.act == 1) return (void*)bias_act_kernel<T, 1>;
155
+ if (p.act == 2) return (void*)bias_act_kernel<T, 2>;
156
+ if (p.act == 3) return (void*)bias_act_kernel<T, 3>;
157
+ if (p.act == 4) return (void*)bias_act_kernel<T, 4>;
158
+ if (p.act == 5) return (void*)bias_act_kernel<T, 5>;
159
+ if (p.act == 6) return (void*)bias_act_kernel<T, 6>;
160
+ if (p.act == 7) return (void*)bias_act_kernel<T, 7>;
161
+ if (p.act == 8) return (void*)bias_act_kernel<T, 8>;
162
+ if (p.act == 9) return (void*)bias_act_kernel<T, 9>;
163
+ return NULL;
164
+ }
165
+
166
+ //------------------------------------------------------------------------
167
+ // Template specializations.
168
+
169
+ template void* choose_bias_act_kernel<double> (const bias_act_kernel_params& p);
170
+ template void* choose_bias_act_kernel<float> (const bias_act_kernel_params& p);
171
+ template void* choose_bias_act_kernel<c10::Half> (const bias_act_kernel_params& p);
172
+
173
+ //------------------------------------------------------------------------
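
As a cross-check of the lrelu branch above (A == 3, grad == 0), here is a plain-PyTorch sketch of the same arithmetic: add the bias, apply leaky ReLU, scale by the gain, then clamp. It mirrors only the kernel's forward path; the defaults (alpha=0.2, gain=sqrt(2)) are taken from the Python wrapper later in this commit.

import torch

def lrelu_forward_ref(x, b, dim=1, alpha=0.2, gain=2 ** 0.5, clamp=-1.0):
    # Same steps as the grad==0, A==3 path of bias_act_kernel:
    # "Apply bias." -> lrelu -> "Apply gain." -> "Clamp."
    x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)])
    y = torch.where(x > 0, x, x * alpha)
    y = y * gain
    if clamp >= 0:
        y = y.clamp(-clamp, clamp)
    return y

y = lrelu_forward_ref(torch.randn(8, 512), torch.zeros(512))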
torch_utils/ops/bias_act.h ADDED
@@ -0,0 +1,38 @@
1
+ // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ //
3
+ // NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ // and proprietary rights in and to this software, related documentation
5
+ // and any modifications thereto. Any use, reproduction, disclosure or
6
+ // distribution of this software and related documentation without an express
7
+ // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ //------------------------------------------------------------------------
10
+ // CUDA kernel parameters.
11
+
12
+ struct bias_act_kernel_params
13
+ {
14
+ const void* x; // [sizeX]
15
+ const void* b; // [sizeB] or NULL
16
+ const void* xref; // [sizeX] or NULL
17
+ const void* yref; // [sizeX] or NULL
18
+ const void* dy; // [sizeX] or NULL
19
+ void* y; // [sizeX]
20
+
21
+ int grad;
22
+ int act;
23
+ float alpha;
24
+ float gain;
25
+ float clamp;
26
+
27
+ int sizeX;
28
+ int sizeB;
29
+ int stepB;
30
+ int loopX;
31
+ };
32
+
33
+ //------------------------------------------------------------------------
34
+ // CUDA kernel selection.
35
+
36
+ template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p);
37
+
38
+ //------------------------------------------------------------------------
torch_utils/ops/bias_act.py ADDED
@@ -0,0 +1,209 @@
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Custom PyTorch ops for efficient bias and activation."""
10
+
11
+ import os
12
+ import numpy as np
13
+ import torch
14
+ import dnnlib
15
+
16
+ from .. import custom_ops
17
+ from .. import misc
18
+
19
+ #----------------------------------------------------------------------------
20
+
21
+ activation_funcs = {
22
+ 'linear': dnnlib.EasyDict(func=lambda x, **_: x, def_alpha=0, def_gain=1, cuda_idx=1, ref='', has_2nd_grad=False),
23
+ 'relu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.relu(x), def_alpha=0, def_gain=np.sqrt(2), cuda_idx=2, ref='y', has_2nd_grad=False),
24
+ 'lrelu': dnnlib.EasyDict(func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', has_2nd_grad=False),
25
+ 'tanh': dnnlib.EasyDict(func=lambda x, **_: torch.tanh(x), def_alpha=0, def_gain=1, cuda_idx=4, ref='y', has_2nd_grad=True),
26
+ 'sigmoid': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x), def_alpha=0, def_gain=1, cuda_idx=5, ref='y', has_2nd_grad=True),
27
+ 'elu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.elu(x), def_alpha=0, def_gain=1, cuda_idx=6, ref='y', has_2nd_grad=True),
28
+ 'selu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.selu(x), def_alpha=0, def_gain=1, cuda_idx=7, ref='y', has_2nd_grad=True),
29
+ 'softplus': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.softplus(x), def_alpha=0, def_gain=1, cuda_idx=8, ref='y', has_2nd_grad=True),
30
+ 'swish': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x) * x, def_alpha=0, def_gain=np.sqrt(2), cuda_idx=9, ref='x', has_2nd_grad=True),
31
+ }
32
+
33
+ #----------------------------------------------------------------------------
34
+
35
+ _plugin = None
36
+ _null_tensor = torch.empty([0])
37
+
38
+ def _init():
39
+ global _plugin
40
+ if _plugin is None:
41
+ _plugin = custom_ops.get_plugin(
42
+ module_name='bias_act_plugin',
43
+ sources=['bias_act.cpp', 'bias_act.cu'],
44
+ headers=['bias_act.h'],
45
+ source_dir=os.path.dirname(__file__),
46
+ extra_cuda_cflags=['--use_fast_math'],
47
+ )
48
+ return True
49
+
50
+ #----------------------------------------------------------------------------
51
+
52
+ def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='cuda'):
53
+ r"""Fused bias and activation function.
54
+
55
+ Adds bias `b` to activation tensor `x`, evaluates activation function `act`,
56
+ and scales the result by `gain`. Each of the steps is optional. In most cases,
57
+ the fused op is considerably more efficient than performing the same calculation
58
+ using standard PyTorch ops. It supports first and second order gradients,
59
+ but not third order gradients.
60
+
61
+ Args:
62
+ x: Input activation tensor. Can be of any shape.
63
+ b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type
64
+ as `x`. The shape must be known, and it must match the dimension of `x`
65
+ corresponding to `dim`.
66
+ dim: The dimension in `x` corresponding to the elements of `b`.
67
+ The value of `dim` is ignored if `b` is not specified.
68
+ act: Name of the activation function to evaluate, or `"linear"` to disable.
69
+ Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc.
70
+ See `activation_funcs` for a full list. `None` is not allowed.
71
+ alpha: Shape parameter for the activation function, or `None` to use the default.
72
+ gain: Scaling factor for the output tensor, or `None` to use default.
73
+ See `activation_funcs` for the default scaling of each activation function.
74
+ If unsure, consider specifying 1.
75
+ clamp: Clamp the output values to `[-clamp, +clamp]`, or `None` to disable
76
+ the clamping (default).
77
+ impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).
78
+
79
+ Returns:
80
+ Tensor of the same shape and datatype as `x`.
81
+ """
82
+ assert isinstance(x, torch.Tensor)
83
+ assert impl in ['ref', 'cuda']
84
+ if impl == 'cuda' and x.device.type == 'cuda' and _init():
85
+ return _bias_act_cuda(dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp).apply(x, b)
86
+ return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp)
87
+
88
+ #----------------------------------------------------------------------------
89
+
90
+ @misc.profiled_function
91
+ def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None):
92
+ """Slow reference implementation of `bias_act()` using standard PyTorch ops.
93
+ """
94
+ assert isinstance(x, torch.Tensor)
95
+ assert clamp is None or clamp >= 0
96
+ spec = activation_funcs[act]
97
+ alpha = float(alpha if alpha is not None else spec.def_alpha)
98
+ gain = float(gain if gain is not None else spec.def_gain)
99
+ clamp = float(clamp if clamp is not None else -1)
100
+
101
+ # Add bias.
102
+ if b is not None:
103
+ assert isinstance(b, torch.Tensor) and b.ndim == 1
104
+ assert 0 <= dim < x.ndim
105
+ assert b.shape[0] == x.shape[dim]
106
+ x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)])
107
+
108
+ # Evaluate activation function.
109
+ alpha = float(alpha)
110
+ x = spec.func(x, alpha=alpha)
111
+
112
+ # Scale by gain.
113
+ gain = float(gain)
114
+ if gain != 1:
115
+ x = x * gain
116
+
117
+ # Clamp.
118
+ if clamp >= 0:
119
+ x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type
120
+ return x
121
+
122
+ #----------------------------------------------------------------------------
123
+
124
+ _bias_act_cuda_cache = dict()
125
+
126
+ def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None):
127
+ """Fast CUDA implementation of `bias_act()` using custom ops.
128
+ """
129
+ # Parse arguments.
130
+ assert clamp is None or clamp >= 0
131
+ spec = activation_funcs[act]
132
+ alpha = float(alpha if alpha is not None else spec.def_alpha)
133
+ gain = float(gain if gain is not None else spec.def_gain)
134
+ clamp = float(clamp if clamp is not None else -1)
135
+
136
+ # Lookup from cache.
137
+ key = (dim, act, alpha, gain, clamp)
138
+ if key in _bias_act_cuda_cache:
139
+ return _bias_act_cuda_cache[key]
140
+
141
+ # Forward op.
142
+ class BiasActCuda(torch.autograd.Function):
143
+ @staticmethod
144
+ def forward(ctx, x, b): # pylint: disable=arguments-differ
145
+ ctx.memory_format = torch.channels_last if x.ndim > 2 and x.stride(1) == 1 else torch.contiguous_format
146
+ x = x.contiguous(memory_format=ctx.memory_format)
147
+ b = b.contiguous() if b is not None else _null_tensor
148
+ y = x
149
+ if act != 'linear' or gain != 1 or clamp >= 0 or b is not _null_tensor:
150
+ y = _plugin.bias_act(x, b, _null_tensor, _null_tensor, _null_tensor, 0, dim, spec.cuda_idx, alpha, gain, clamp)
151
+ ctx.save_for_backward(
152
+ x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
153
+ b if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
154
+ y if 'y' in spec.ref else _null_tensor)
155
+ return y
156
+
157
+ @staticmethod
158
+ def backward(ctx, dy): # pylint: disable=arguments-differ
159
+ dy = dy.contiguous(memory_format=ctx.memory_format)
160
+ x, b, y = ctx.saved_tensors
161
+ dx = None
162
+ db = None
163
+
164
+ if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
165
+ dx = dy
166
+ if act != 'linear' or gain != 1 or clamp >= 0:
167
+ dx = BiasActCudaGrad.apply(dy, x, b, y)
168
+
169
+ if ctx.needs_input_grad[1]:
170
+ db = dx.sum([i for i in range(dx.ndim) if i != dim])
171
+
172
+ return dx, db
173
+
174
+ # Backward op.
175
+ class BiasActCudaGrad(torch.autograd.Function):
176
+ @staticmethod
177
+ def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ
178
+ ctx.memory_format = torch.channels_last if dy.ndim > 2 and dy.stride(1) == 1 else torch.contiguous_format
179
+ dx = _plugin.bias_act(dy, b, x, y, _null_tensor, 1, dim, spec.cuda_idx, alpha, gain, clamp)
180
+ ctx.save_for_backward(
181
+ dy if spec.has_2nd_grad else _null_tensor,
182
+ x, b, y)
183
+ return dx
184
+
185
+ @staticmethod
186
+ def backward(ctx, d_dx): # pylint: disable=arguments-differ
187
+ d_dx = d_dx.contiguous(memory_format=ctx.memory_format)
188
+ dy, x, b, y = ctx.saved_tensors
189
+ d_dy = None
190
+ d_x = None
191
+ d_b = None
192
+ d_y = None
193
+
194
+ if ctx.needs_input_grad[0]:
195
+ d_dy = BiasActCudaGrad.apply(d_dx, x, b, y)
196
+
197
+ if spec.has_2nd_grad and (ctx.needs_input_grad[1] or ctx.needs_input_grad[2]):
198
+ d_x = _plugin.bias_act(d_dx, b, x, y, dy, 2, dim, spec.cuda_idx, alpha, gain, clamp)
199
+
200
+ if spec.has_2nd_grad and ctx.needs_input_grad[2]:
201
+ d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim])
202
+
203
+ return d_dy, d_x, d_b, d_y
204
+
205
+ # Add to cache.
206
+ _bias_act_cuda_cache[key] = BiasActCuda
207
+ return BiasActCuda
208
+
209
+ #----------------------------------------------------------------------------
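
A minimal usage sketch of the wrapper above, assuming the repo root is on PYTHONPATH. `impl='ref'` exercises the plain-PyTorch path (and works on CPU); with `impl='cuda'` on a CUDA tensor, the bias_act_plugin is JIT-built on first use.

import torch
from torch_utils.ops import bias_act

x = torch.randn(8, 512)
b = torch.zeros(512)

# Fused bias + leaky ReLU with the default lrelu gain of sqrt(2).
y = bias_act.bias_act(x, b, dim=1, act='lrelu', impl='ref')
print(y.shape)  # torch.Size([8, 512])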
torch_utils/ops/conv2d_gradfix.py ADDED
@@ -0,0 +1,198 @@
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Custom replacement for `torch.nn.functional.conv2d` that supports
10
+ arbitrarily high order gradients with zero performance penalty."""
11
+
12
+ import contextlib
13
+ import torch
14
+
15
+ # pylint: disable=redefined-builtin
16
+ # pylint: disable=arguments-differ
17
+ # pylint: disable=protected-access
18
+
19
+ #----------------------------------------------------------------------------
20
+
21
+ enabled = False # Enable the custom op by setting this to true.
22
+ weight_gradients_disabled = False # Forcefully disable computation of gradients with respect to the weights.
23
+
24
+ @contextlib.contextmanager
25
+ def no_weight_gradients(disable=True):
26
+ global weight_gradients_disabled
27
+ old = weight_gradients_disabled
28
+ if disable:
29
+ weight_gradients_disabled = True
30
+ yield
31
+ weight_gradients_disabled = old
32
+
33
+ #----------------------------------------------------------------------------
34
+
35
+ def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
36
+ if _should_use_custom_op(input):
37
+ return _conv2d_gradfix(transpose=False, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=0, dilation=dilation, groups=groups).apply(input, weight, bias)
38
+ return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
39
+
40
+ def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
41
+ if _should_use_custom_op(input):
42
+ return _conv2d_gradfix(transpose=True, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation).apply(input, weight, bias)
43
+ return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation)
44
+
45
+ #----------------------------------------------------------------------------
46
+
47
+ def _should_use_custom_op(input):
48
+ assert isinstance(input, torch.Tensor)
49
+ if (not enabled) or (not torch.backends.cudnn.enabled):
50
+ return False
51
+ if input.device.type != 'cuda':
52
+ return False
53
+ return True
54
+
55
+ def _tuple_of_ints(xs, ndim):
56
+ xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim
57
+ assert len(xs) == ndim
58
+ assert all(isinstance(x, int) for x in xs)
59
+ return xs
60
+
61
+ #----------------------------------------------------------------------------
62
+
63
+ _conv2d_gradfix_cache = dict()
64
+ _null_tensor = torch.empty([0])
65
+
66
+ def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, dilation, groups):
67
+ # Parse arguments.
68
+ ndim = 2
69
+ weight_shape = tuple(weight_shape)
70
+ stride = _tuple_of_ints(stride, ndim)
71
+ padding = _tuple_of_ints(padding, ndim)
72
+ output_padding = _tuple_of_ints(output_padding, ndim)
73
+ dilation = _tuple_of_ints(dilation, ndim)
74
+
75
+ # Lookup from cache.
76
+ key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups)
77
+ if key in _conv2d_gradfix_cache:
78
+ return _conv2d_gradfix_cache[key]
79
+
80
+ # Validate arguments.
81
+ assert groups >= 1
82
+ assert len(weight_shape) == ndim + 2
83
+ assert all(stride[i] >= 1 for i in range(ndim))
84
+ assert all(padding[i] >= 0 for i in range(ndim))
85
+ assert all(dilation[i] >= 0 for i in range(ndim))
86
+ if not transpose:
87
+ assert all(output_padding[i] == 0 for i in range(ndim))
88
+ else: # transpose
89
+ assert all(0 <= output_padding[i] < max(stride[i], dilation[i]) for i in range(ndim))
90
+
91
+ # Helpers.
92
+ common_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups)
93
+ def calc_output_padding(input_shape, output_shape):
94
+ if transpose:
95
+ return [0, 0]
96
+ return [
97
+ input_shape[i + 2]
98
+ - (output_shape[i + 2] - 1) * stride[i]
99
+ - (1 - 2 * padding[i])
100
+ - dilation[i] * (weight_shape[i + 2] - 1)
101
+ for i in range(ndim)
102
+ ]
103
+
104
+ # Forward & backward.
105
+ class Conv2d(torch.autograd.Function):
106
+ @staticmethod
107
+ def forward(ctx, input, weight, bias):
108
+ assert weight.shape == weight_shape
109
+ ctx.save_for_backward(
110
+ input if weight.requires_grad else _null_tensor,
111
+ weight if input.requires_grad else _null_tensor,
112
+ )
113
+ ctx.input_shape = input.shape
114
+
115
+ # Simple 1x1 convolution => cuBLAS (only on Volta, not on Ampere).
116
+ if weight_shape[2:] == stride == dilation == (1, 1) and padding == (0, 0) and torch.cuda.get_device_capability(input.device) < (8, 0):
117
+ a = weight.reshape(groups, weight_shape[0] // groups, weight_shape[1])
118
+ b = input.reshape(input.shape[0], groups, input.shape[1] // groups, -1)
119
+ c = (a.transpose(1, 2) if transpose else a) @ b.permute(1, 2, 0, 3).flatten(2)
120
+ c = c.reshape(-1, input.shape[0], *input.shape[2:]).transpose(0, 1)
121
+ c = c if bias is None else c + bias.unsqueeze(0).unsqueeze(2).unsqueeze(3)
122
+ return c.contiguous(memory_format=(torch.channels_last if input.stride(1) == 1 else torch.contiguous_format))
123
+
124
+ # General case => cuDNN.
125
+ if transpose:
126
+ return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, output_padding=output_padding, **common_kwargs)
127
+ return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, **common_kwargs)
128
+
129
+ @staticmethod
130
+ def backward(ctx, grad_output):
131
+ input, weight = ctx.saved_tensors
132
+ input_shape = ctx.input_shape
133
+ grad_input = None
134
+ grad_weight = None
135
+ grad_bias = None
136
+
137
+ if ctx.needs_input_grad[0]:
138
+ p = calc_output_padding(input_shape=input_shape, output_shape=grad_output.shape)
139
+ op = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs)
140
+ grad_input = op.apply(grad_output, weight, None)
141
+ assert grad_input.shape == input_shape
142
+
143
+ if ctx.needs_input_grad[1] and not weight_gradients_disabled:
144
+ grad_weight = Conv2dGradWeight.apply(grad_output, input)
145
+ assert grad_weight.shape == weight_shape
146
+
147
+ if ctx.needs_input_grad[2]:
148
+ grad_bias = grad_output.sum([0, 2, 3])
149
+
150
+ return grad_input, grad_weight, grad_bias
151
+
152
+ # Gradient with respect to the weights.
153
+ class Conv2dGradWeight(torch.autograd.Function):
154
+ @staticmethod
155
+ def forward(ctx, grad_output, input):
156
+ ctx.save_for_backward(
157
+ grad_output if input.requires_grad else _null_tensor,
158
+ input if grad_output.requires_grad else _null_tensor,
159
+ )
160
+ ctx.grad_output_shape = grad_output.shape
161
+ ctx.input_shape = input.shape
162
+
163
+ # Simple 1x1 convolution => cuBLAS (on both Volta and Ampere).
164
+ if weight_shape[2:] == stride == dilation == (1, 1) and padding == (0, 0):
165
+ a = grad_output.reshape(grad_output.shape[0], groups, grad_output.shape[1] // groups, -1).permute(1, 2, 0, 3).flatten(2)
166
+ b = input.reshape(input.shape[0], groups, input.shape[1] // groups, -1).permute(1, 2, 0, 3).flatten(2)
167
+ c = (b @ a.transpose(1, 2) if transpose else a @ b.transpose(1, 2)).reshape(weight_shape)
168
+ return c.contiguous(memory_format=(torch.channels_last if input.stride(1) == 1 else torch.contiguous_format))
169
+
170
+ # General case => cuDNN.
171
+ name = 'aten::cudnn_convolution_transpose_backward_weight' if transpose else 'aten::cudnn_convolution_backward_weight'
172
+ flags = [torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic, torch.backends.cudnn.allow_tf32]
173
+ return torch._C._jit_get_operation(name)(weight_shape, grad_output, input, padding, stride, dilation, groups, *flags)
174
+
175
+ @staticmethod
176
+ def backward(ctx, grad2_grad_weight):
177
+ grad_output, input = ctx.saved_tensors
178
+ grad_output_shape = ctx.grad_output_shape
179
+ input_shape = ctx.input_shape
180
+ grad2_grad_output = None
181
+ grad2_input = None
182
+
183
+ if ctx.needs_input_grad[0]:
184
+ grad2_grad_output = Conv2d.apply(input, grad2_grad_weight, None)
185
+ assert grad2_grad_output.shape == grad_output_shape
186
+
187
+ if ctx.needs_input_grad[1]:
188
+ p = calc_output_padding(input_shape=input_shape, output_shape=grad_output_shape)
189
+ op = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs)
190
+ grad2_input = op.apply(grad_output, grad2_grad_weight, None)
191
+ assert grad2_input.shape == input_shape
192
+
193
+ return grad2_grad_output, grad2_input
194
+
195
+ _conv2d_gradfix_cache[key] = Conv2d
196
+ return Conv2d
197
+
198
+ #----------------------------------------------------------------------------
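
A short usage sketch, assuming the repo root is on PYTHONPATH. The custom op is opt-in via the module-level `enabled` flag and only engages for CUDA tensors; otherwise the calls fall through to the standard torch.nn.functional ops, and `no_weight_gradients()` likewise suppresses weight gradients only on the custom path.

import torch
from torch_utils.ops import conv2d_gradfix

conv2d_gradfix.enabled = True  # off by default

x = torch.randn(1, 3, 64, 64, requires_grad=True)
w = torch.randn(8, 3, 3, 3, requires_grad=True)

y = conv2d_gradfix.conv2d(x, w, padding=1)  # [1, 8, 64, 64]

# On the custom CUDA path this skips dL/dw (useful for regularization passes that
# only need dL/dx); the CPU fallback still uses plain autograd.
with conv2d_gradfix.no_weight_gradients():
    conv2d_gradfix.conv2d(x, w, padding=1).sum().backward()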
torch_utils/ops/conv2d_resample.py ADDED
@@ -0,0 +1,143 @@
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """2D convolution with optional up/downsampling."""
10
+
11
+ import torch
12
+
13
+ from .. import misc
14
+ from . import conv2d_gradfix
15
+ from . import upfirdn2d
16
+ from .upfirdn2d import _parse_padding
17
+ from .upfirdn2d import _get_filter_size
18
+
19
+ #----------------------------------------------------------------------------
20
+
21
+ def _get_weight_shape(w):
22
+ with misc.suppress_tracer_warnings(): # this value will be treated as a constant
23
+ shape = [int(sz) for sz in w.shape]
24
+ misc.assert_shape(w, shape)
25
+ return shape
26
+
27
+ #----------------------------------------------------------------------------
28
+
29
+ def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True):
30
+ """Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations.
31
+ """
32
+ _out_channels, _in_channels_per_group, kh, kw = _get_weight_shape(w)
33
+
34
+ # Flip weight if requested.
35
+ # Note: conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False).
36
+ if not flip_weight and (kw > 1 or kh > 1):
37
+ w = w.flip([2, 3])
38
+
39
+ # Execute using conv2d_gradfix.
40
+ op = conv2d_gradfix.conv_transpose2d if transpose else conv2d_gradfix.conv2d
41
+ return op(x, w, stride=stride, padding=padding, groups=groups)
42
+
43
+ #----------------------------------------------------------------------------
44
+
45
+ @misc.profiled_function
46
+ def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False):
47
+ r"""2D convolution with optional up/downsampling.
48
+
49
+ Padding is performed only once at the beginning, not between the operations.
50
+
51
+ Args:
52
+ x: Input tensor of shape
53
+ `[batch_size, in_channels, in_height, in_width]`.
54
+ w: Weight tensor of shape
55
+ `[out_channels, in_channels//groups, kernel_height, kernel_width]`.
56
+ f: Low-pass filter for up/downsampling. Must be prepared beforehand by
57
+ calling upfirdn2d.setup_filter(). None = identity (default).
58
+ up: Integer upsampling factor (default: 1).
59
+ down: Integer downsampling factor (default: 1).
60
+ padding: Padding with respect to the upsampled image. Can be a single number
61
+ or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
62
+ (default: 0).
63
+ groups: Split input channels into N groups (default: 1).
64
+ flip_weight: False = convolution, True = correlation (default: True).
65
+ flip_filter: False = convolution, True = correlation (default: False).
66
+
67
+ Returns:
68
+ Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
69
+ """
70
+ # Validate arguments.
71
+ assert isinstance(x, torch.Tensor) and (x.ndim == 4)
72
+ assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype)
73
+ assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32)
74
+ assert isinstance(up, int) and (up >= 1)
75
+ assert isinstance(down, int) and (down >= 1)
76
+ assert isinstance(groups, int) and (groups >= 1)
77
+ out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)
78
+ fw, fh = _get_filter_size(f)
79
+ px0, px1, py0, py1 = _parse_padding(padding)
80
+
81
+ # Adjust padding to account for up/downsampling.
82
+ if up > 1:
83
+ px0 += (fw + up - 1) // 2
84
+ px1 += (fw - up) // 2
85
+ py0 += (fh + up - 1) // 2
86
+ py1 += (fh - up) // 2
87
+ if down > 1:
88
+ px0 += (fw - down + 1) // 2
89
+ px1 += (fw - down) // 2
90
+ py0 += (fh - down + 1) // 2
91
+ py1 += (fh - down) // 2
92
+
93
+ # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve.
94
+ if kw == 1 and kh == 1 and (down > 1 and up == 1):
95
+ x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
96
+ x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
97
+ return x
98
+
99
+ # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample.
100
+ if kw == 1 and kh == 1 and (up > 1 and down == 1):
101
+ x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
102
+ x = upfirdn2d.upfirdn2d(x=x, f=f, up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
103
+ return x
104
+
105
+ # Fast path: downsampling only => use strided convolution.
106
+ if down > 1 and up == 1:
107
+ x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
108
+ x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight)
109
+ return x
110
+
111
+ # Fast path: upsampling with optional downsampling => use transpose strided convolution.
112
+ if up > 1:
113
+ if groups == 1:
114
+ w = w.transpose(0, 1)
115
+ else:
116
+ w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw)
117
+ w = w.transpose(1, 2)
118
+ w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw)
119
+ px0 -= kw - 1
120
+ px1 -= kw - up
121
+ py0 -= kh - 1
122
+ py1 -= kh - up
123
+ pxt = max(min(-px0, -px1), 0)
124
+ pyt = max(min(-py0, -py1), 0)
125
+ x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt,pxt], groups=groups, transpose=True, flip_weight=(not flip_weight))
126
+ x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0+pxt,px1+pxt,py0+pyt,py1+pyt], gain=up**2, flip_filter=flip_filter)
127
+ if down > 1:
128
+ x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
129
+ return x
130
+
131
+ # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d.
132
+ if up == 1 and down == 1:
133
+ if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0:
134
+ return _conv2d_wrapper(x=x, w=w, padding=[py0,px0], groups=groups, flip_weight=flip_weight)
135
+
136
+ # Fallback: Generic reference implementation.
137
+ x = upfirdn2d.upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
138
+ x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
139
+ if down > 1:
140
+ x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
141
+ return x
142
+
143
+ #----------------------------------------------------------------------------
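
A usage sketch combining this with `upfirdn2d.setup_filter()` as the docstring requires, assuming the repo root is on PYTHONPATH; on CPU it runs through the reference paths of the underlying ops.

import torch
from torch_utils.ops import upfirdn2d, conv2d_resample

x = torch.randn(1, 3, 32, 32)
w = torch.randn(8, 3, 3, 3)
f = upfirdn2d.setup_filter([1, 3, 3, 1])  # separable low-pass filter, prepared once

# 3x3 convolution fused with 2x upsampling.
y = conv2d_resample.conv2d_resample(x=x, w=w, f=f, up=2, padding=1)
print(y.shape)  # torch.Size([1, 8, 64, 64])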
torch_utils/ops/filtered_lrelu.cpp ADDED
@@ -0,0 +1,300 @@
1
+ // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ //
3
+ // NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ // and proprietary rights in and to this software, related documentation
5
+ // and any modifications thereto. Any use, reproduction, disclosure or
6
+ // distribution of this software and related documentation without an express
7
+ // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ #include <torch/extension.h>
10
+ #include <ATen/cuda/CUDAContext.h>
11
+ #include <c10/cuda/CUDAGuard.h>
12
+ #include "filtered_lrelu.h"
13
+
14
+ //------------------------------------------------------------------------
15
+
16
+ static std::tuple<torch::Tensor, torch::Tensor, int> filtered_lrelu(
17
+ torch::Tensor x, torch::Tensor fu, torch::Tensor fd, torch::Tensor b, torch::Tensor si,
18
+ int up, int down, int px0, int px1, int py0, int py1, int sx, int sy, float gain, float slope, float clamp, bool flip_filters, bool writeSigns)
19
+ {
20
+ // Set CUDA device.
21
+ TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
22
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
23
+
24
+ // Validate arguments.
25
+ TORCH_CHECK(fu.device() == x.device() && fd.device() == x.device() && b.device() == x.device(), "all input tensors must reside on the same device");
26
+ TORCH_CHECK(fu.dtype() == torch::kFloat && fd.dtype() == torch::kFloat, "fu and fd must be float32");
27
+ TORCH_CHECK(b.dtype() == x.dtype(), "x and b must have the same dtype");
28
+ TORCH_CHECK(x.dtype() == torch::kHalf || x.dtype() == torch::kFloat, "x and b must be float16 or float32");
29
+ TORCH_CHECK(x.dim() == 4, "x must be rank 4");
30
+ TORCH_CHECK(x.size(0) * x.size(1) <= INT_MAX && x.size(2) <= INT_MAX && x.size(3) <= INT_MAX, "x is too large");
31
+ TORCH_CHECK(x.numel() > 0, "x is empty");
32
+ TORCH_CHECK((fu.dim() == 1 || fu.dim() == 2) && (fd.dim() == 1 || fd.dim() == 2), "fu and fd must be rank 1 or 2");
33
+ TORCH_CHECK(fu.size(0) <= INT_MAX && fu.size(-1) <= INT_MAX, "fu is too large");
34
+ TORCH_CHECK(fd.size(0) <= INT_MAX && fd.size(-1) <= INT_MAX, "fd is too large");
35
+ TORCH_CHECK(fu.numel() > 0, "fu is empty");
36
+ TORCH_CHECK(fd.numel() > 0, "fd is empty");
37
+ TORCH_CHECK(b.dim() == 1 && b.size(0) == x.size(1), "b must be a vector with the same number of channels as x");
38
+ TORCH_CHECK(up >= 1 && down >= 1, "up and down must be at least 1");
39
+
40
+ // Figure out how much shared memory is available on the device.
41
+ int maxSharedBytes = 0;
42
+ AT_CUDA_CHECK(cudaDeviceGetAttribute(&maxSharedBytes, cudaDevAttrMaxSharedMemoryPerBlockOptin, x.device().index()));
43
+ int sharedKB = maxSharedBytes >> 10;
44
+
45
+ // Populate enough launch parameters to check if a CUDA kernel exists.
46
+ filtered_lrelu_kernel_params p;
47
+ p.up = up;
48
+ p.down = down;
49
+ p.fuShape = make_int2((int)fu.size(-1), fu.dim() == 2 ? (int)fu.size(0) : 0); // shape [n, 0] indicates separable filter.
50
+ p.fdShape = make_int2((int)fd.size(-1), fd.dim() == 2 ? (int)fd.size(0) : 0);
51
+ filtered_lrelu_kernel_spec test_spec = choose_filtered_lrelu_kernel<float, int32_t, false, false>(p, sharedKB);
52
+ if (!test_spec.exec)
53
+ {
54
+ // No kernel found - return empty tensors and indicate missing kernel with return code of -1.
55
+ return std::make_tuple(torch::Tensor(), torch::Tensor(), -1);
56
+ }
57
+
58
+ // Input/output element size.
59
+ int64_t sz = (x.dtype() == torch::kHalf) ? 2 : 4;
60
+
61
+ // Input sizes.
62
+ int64_t xw = (int)x.size(3);
63
+ int64_t xh = (int)x.size(2);
64
+ int64_t fut_w = (int)fu.size(-1) - 1;
65
+ int64_t fut_h = (int)fu.size(0) - 1;
66
+ int64_t fdt_w = (int)fd.size(-1) - 1;
67
+ int64_t fdt_h = (int)fd.size(0) - 1;
68
+
69
+ // Logical size of upsampled buffer.
70
+ int64_t cw = xw * up + (px0 + px1) - fut_w;
71
+ int64_t ch = xh * up + (py0 + py1) - fut_h;
72
+ TORCH_CHECK(cw > fdt_w && ch > fdt_h, "upsampled buffer must be at least the size of downsampling filter");
73
+ TORCH_CHECK(cw <= INT_MAX && ch <= INT_MAX, "upsampled buffer is too large");
74
+
75
+ // Compute output size and allocate.
76
+ int64_t yw = (cw - fdt_w + (down - 1)) / down;
77
+ int64_t yh = (ch - fdt_h + (down - 1)) / down;
78
+ TORCH_CHECK(yw > 0 && yh > 0, "output must be at least 1x1");
79
+ TORCH_CHECK(yw <= INT_MAX && yh <= INT_MAX, "output is too large");
80
+ torch::Tensor y = torch::empty({x.size(0), x.size(1), yh, yw}, x.options(), x.suggest_memory_format());
81
+
82
+ // Allocate sign tensor.
83
+ torch::Tensor so;
84
+ torch::Tensor s = si;
85
+ bool readSigns = !!s.numel();
86
+ int64_t sw_active = 0; // Active width of sign tensor.
87
+ if (writeSigns)
88
+ {
89
+ sw_active = yw * down - (down - 1) + fdt_w; // Active width in elements.
90
+ int64_t sh = yh * down - (down - 1) + fdt_h; // Height = active height.
91
+ int64_t sw = (sw_active + 15) & ~15; // Width = active width in elements, rounded up to multiple of 16.
92
+ TORCH_CHECK(sh <= INT_MAX && (sw >> 2) <= INT_MAX, "signs is too large");
93
+ s = so = torch::empty({x.size(0), x.size(1), sh, sw >> 2}, x.options().dtype(torch::kUInt8), at::MemoryFormat::Contiguous);
94
+ }
95
+ else if (readSigns)
96
+ sw_active = s.size(3) << 2;
97
+
98
+ // Validate sign tensor if in use.
99
+ if (readSigns || writeSigns)
100
+ {
101
+ TORCH_CHECK(s.is_contiguous(), "signs must be contiguous");
102
+ TORCH_CHECK(s.dtype() == torch::kUInt8, "signs must be uint8");
103
+ TORCH_CHECK(s.device() == x.device(), "signs must reside on the same device as x");
104
+ TORCH_CHECK(s.dim() == 4, "signs must be rank 4");
105
+ TORCH_CHECK(s.size(0) == x.size(0) && s.size(1) == x.size(1), "signs must have same batch & channels as x");
106
+ TORCH_CHECK(s.size(2) <= INT_MAX && s.size(3) <= INT_MAX, "signs is too large");
107
+ }
108
+
109
+ // Populate rest of CUDA kernel parameters.
110
+ p.x = x.data_ptr();
111
+ p.y = y.data_ptr();
112
+ p.b = b.data_ptr();
113
+ p.s = (readSigns || writeSigns) ? s.data_ptr<unsigned char>() : 0;
114
+ p.fu = fu.data_ptr<float>();
115
+ p.fd = fd.data_ptr<float>();
116
+ p.pad0 = make_int2(px0, py0);
117
+ p.gain = gain;
118
+ p.slope = slope;
119
+ p.clamp = clamp;
120
+ p.flip = (flip_filters) ? 1 : 0;
121
+ p.xShape = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
122
+ p.yShape = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0));
123
+ p.sShape = (readSigns || writeSigns) ? make_int2((int)s.size(3), (int)s.size(2)) : make_int2(0, 0); // Width is in bytes. Contiguous.
124
+ p.sOfs = make_int2(sx, sy);
125
+ p.swLimit = (sw_active + 3) >> 2; // Rounded up to bytes.
126
+
127
+ // x, y, b strides are in bytes.
128
+ p.xStride = make_longlong4(sz * x.stride(3), sz * x.stride(2), sz * x.stride(1), sz * x.stride(0));
129
+ p.yStride = make_longlong4(sz * y.stride(3), sz * y.stride(2), sz * y.stride(1), sz * y.stride(0));
130
+ p.bStride = sz * b.stride(0);
131
+
132
+ // fu, fd strides are in elements.
133
+ p.fuStride = make_longlong3(fu.stride(-1), fu.dim() == 2 ? fu.stride(0) : 0, 0);
134
+ p.fdStride = make_longlong3(fd.stride(-1), fd.dim() == 2 ? fd.stride(0) : 0, 0);
135
+
136
+ // Determine if indices don't fit in int32. Support negative strides although Torch currently never produces those.
137
+ bool index64b = false;
138
+ if (std::abs(p.bStride * x.size(1)) > INT_MAX) index64b = true;
139
+ if (std::min(x.size(0) * p.xStride.w, 0ll) + std::min(x.size(1) * p.xStride.z, 0ll) + std::min(x.size(2) * p.xStride.y, 0ll) + std::min(x.size(3) * p.xStride.x, 0ll) < -INT_MAX) index64b = true;
140
+ if (std::max(x.size(0) * p.xStride.w, 0ll) + std::max(x.size(1) * p.xStride.z, 0ll) + std::max(x.size(2) * p.xStride.y, 0ll) + std::max(x.size(3) * p.xStride.x, 0ll) > INT_MAX) index64b = true;
141
+ if (std::min(y.size(0) * p.yStride.w, 0ll) + std::min(y.size(1) * p.yStride.z, 0ll) + std::min(y.size(2) * p.yStride.y, 0ll) + std::min(y.size(3) * p.yStride.x, 0ll) < -INT_MAX) index64b = true;
142
+ if (std::max(y.size(0) * p.yStride.w, 0ll) + std::max(y.size(1) * p.yStride.z, 0ll) + std::max(y.size(2) * p.yStride.y, 0ll) + std::max(y.size(3) * p.yStride.x, 0ll) > INT_MAX) index64b = true;
143
+ if (s.numel() > INT_MAX) index64b = true;
144
+
145
+ // Choose CUDA kernel.
146
+ filtered_lrelu_kernel_spec spec = { 0 };
147
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "filtered_lrelu_cuda", [&]
148
+ {
149
+ if constexpr (sizeof(scalar_t) <= 4) // Exclude doubles. constexpr prevents template instantiation.
150
+ {
151
+ // Choose kernel based on index type, datatype and sign read/write modes.
152
+ if (!index64b && writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int32_t, true, false>(p, sharedKB);
153
+ else if (!index64b && !writeSigns && readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int32_t, false, true >(p, sharedKB);
154
+ else if (!index64b && !writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int32_t, false, false>(p, sharedKB);
155
+ else if ( index64b && writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int64_t, true, false>(p, sharedKB);
156
+ else if ( index64b && !writeSigns && readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int64_t, false, true >(p, sharedKB);
157
+ else if ( index64b && !writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int64_t, false, false>(p, sharedKB);
158
+ }
159
+ });
160
+ TORCH_CHECK(spec.exec, "internal error - CUDA kernel not found") // This should not happen because we tested earlier that kernel exists.
161
+
162
+ // Launch CUDA kernel.
163
+ void* args[] = {&p};
164
+ int bx = spec.numWarps * 32;
165
+ int gx = (p.yShape.x - 1) / spec.tileOut.x + 1;
166
+ int gy = (p.yShape.y - 1) / spec.tileOut.y + 1;
167
+ int gz = p.yShape.z * p.yShape.w;
168
+
169
+ // Repeat multiple horizontal tiles in a CTA?
170
+ if (spec.xrep)
171
+ {
172
+ p.tilesXrep = spec.xrep;
173
+ p.tilesXdim = gx;
174
+
175
+ gx = (gx + p.tilesXrep - 1) / p.tilesXrep;
176
+ std::swap(gx, gy);
177
+ }
178
+ else
179
+ {
180
+ p.tilesXrep = 0;
181
+ p.tilesXdim = 0;
182
+ }
183
+
184
+ // Launch filter setup kernel.
185
+ AT_CUDA_CHECK(cudaLaunchKernel(spec.setup, 1, 1024, args, 0, at::cuda::getCurrentCUDAStream()));
186
+
187
+ // Copy kernels to constant memory.
188
+ if ( writeSigns && !readSigns) AT_CUDA_CHECK((copy_filters<true, false>(at::cuda::getCurrentCUDAStream())));
189
+ else if (!writeSigns && readSigns) AT_CUDA_CHECK((copy_filters<false, true >(at::cuda::getCurrentCUDAStream())));
190
+ else if (!writeSigns && !readSigns) AT_CUDA_CHECK((copy_filters<false, false>(at::cuda::getCurrentCUDAStream())));
191
+
192
+ // Set cache and shared memory configurations for main kernel.
193
+ AT_CUDA_CHECK(cudaFuncSetCacheConfig(spec.exec, cudaFuncCachePreferShared));
194
+ if (spec.dynamicSharedKB) // Need dynamically allocated shared memory?
195
+ AT_CUDA_CHECK(cudaFuncSetAttribute(spec.exec, cudaFuncAttributeMaxDynamicSharedMemorySize, spec.dynamicSharedKB << 10));
196
+ AT_CUDA_CHECK(cudaFuncSetSharedMemConfig(spec.exec, cudaSharedMemBankSizeFourByte));
197
+
198
+ // Launch main kernel.
199
+ const int maxSubGz = 65535; // CUDA maximum for block z dimension.
200
+ for (int zofs=0; zofs < gz; zofs += maxSubGz) // Do multiple launches if gz is too big.
201
+ {
202
+ p.blockZofs = zofs;
203
+ int subGz = std::min(maxSubGz, gz - zofs);
204
+ AT_CUDA_CHECK(cudaLaunchKernel(spec.exec, dim3(gx, gy, subGz), bx, args, spec.dynamicSharedKB << 10, at::cuda::getCurrentCUDAStream()));
205
+ }
206
+
207
+ // Done.
208
+ return std::make_tuple(y, so, 0);
209
+ }
210
+
211
+ //------------------------------------------------------------------------
212
+
213
+ static torch::Tensor filtered_lrelu_act(torch::Tensor x, torch::Tensor si, int sx, int sy, float gain, float slope, float clamp, bool writeSigns)
214
+ {
215
+ // Set CUDA device.
216
+ TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
217
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
218
+
219
+ // Validate arguments.
220
+ TORCH_CHECK(x.dim() == 4, "x must be rank 4");
221
+ TORCH_CHECK(x.size(0) * x.size(1) <= INT_MAX && x.size(2) <= INT_MAX && x.size(3) <= INT_MAX, "x is too large");
222
+ TORCH_CHECK(x.numel() > 0, "x is empty");
223
+ TORCH_CHECK(x.dtype() == torch::kHalf || x.dtype() == torch::kFloat || x.dtype() == torch::kDouble, "x must be float16, float32 or float64");
224
+
225
+ // Output signs if we don't have sign input.
226
+ torch::Tensor so;
227
+ torch::Tensor s = si;
228
+ bool readSigns = !!s.numel();
229
+ if (writeSigns)
230
+ {
231
+ int64_t sw = x.size(3);
232
+ sw = (sw + 15) & ~15; // Round to a multiple of 16 for coalescing.
233
+ s = so = torch::empty({x.size(0), x.size(1), x.size(2), sw >> 2}, x.options().dtype(torch::kUInt8), at::MemoryFormat::Contiguous);
234
+ }
235
+
236
+ // Validate sign tensor if in use.
237
+ if (readSigns || writeSigns)
238
+ {
239
+ TORCH_CHECK(s.is_contiguous(), "signs must be contiguous");
240
+ TORCH_CHECK(s.dtype() == torch::kUInt8, "signs must be uint8");
241
+ TORCH_CHECK(s.device() == x.device(), "signs must reside on the same device as x");
242
+ TORCH_CHECK(s.dim() == 4, "signs must be rank 4");
243
+ TORCH_CHECK(s.size(0) == x.size(0) && s.size(1) == x.size(1), "signs must have same batch & channels as x");
244
+ TORCH_CHECK(s.size(2) <= INT_MAX && (s.size(3) << 2) <= INT_MAX, "signs tensor is too large");
245
+ }
246
+
247
+ // Initialize CUDA kernel parameters.
248
+ filtered_lrelu_act_kernel_params p;
249
+ p.x = x.data_ptr();
250
+ p.s = (readSigns || writeSigns) ? s.data_ptr<unsigned char>() : 0;
251
+ p.gain = gain;
252
+ p.slope = slope;
253
+ p.clamp = clamp;
254
+ p.xShape = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
255
+ p.xStride = make_longlong4(x.stride(3), x.stride(2), x.stride(1), x.stride(0));
256
+ p.sShape = (readSigns || writeSigns) ? make_int2((int)s.size(3) << 2, (int)s.size(2)) : make_int2(0, 0); // Width is in elements. Contiguous.
257
+ p.sOfs = make_int2(sx, sy);
258
+
259
+ // Choose CUDA kernel.
260
+ void* func = 0;
261
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "filtered_lrelu_act_cuda", [&]
262
+ {
263
+ if (writeSigns)
264
+ func = choose_filtered_lrelu_act_kernel<scalar_t, true, false>();
265
+ else if (readSigns)
266
+ func = choose_filtered_lrelu_act_kernel<scalar_t, false, true>();
267
+ else
268
+ func = choose_filtered_lrelu_act_kernel<scalar_t, false, false>();
269
+ });
270
+ TORCH_CHECK(func, "internal error - CUDA kernel not found");
271
+
272
+ // Launch CUDA kernel.
273
+ void* args[] = {&p};
274
+ int bx = 128; // 4 warps per block.
275
+
276
+ // Logical size of launch = writeSigns ? p.s : p.x
277
+ uint32_t gx = writeSigns ? p.sShape.x : p.xShape.x;
278
+ uint32_t gy = writeSigns ? p.sShape.y : p.xShape.y;
279
+ uint32_t gz = p.xShape.z * p.xShape.w; // Same as in p.sShape if signs are in use.
280
+ gx = (gx - 1) / bx + 1;
281
+
282
+ // Make sure grid y and z dimensions are within CUDA launch limits. Kernel loops internally to do the rest.
283
+ const uint32_t gmax = 65535;
284
+ gy = std::min(gy, gmax);
285
+ gz = std::min(gz, gmax);
286
+
287
+ // Launch.
288
+ AT_CUDA_CHECK(cudaLaunchKernel(func, dim3(gx, gy, gz), bx, args, 0, at::cuda::getCurrentCUDAStream()));
289
+ return so;
290
+ }
291
+
292
+ //------------------------------------------------------------------------
293
+
294
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
295
+ {
296
+ m.def("filtered_lrelu", &filtered_lrelu); // The whole thing.
297
+ m.def("filtered_lrelu_act_", &filtered_lrelu_act); // Activation and sign tensor handling only. Modifies data tensor in-place.
298
+ }
299
+
300
+ //------------------------------------------------------------------------
torch_utils/ops/filtered_lrelu.cu ADDED
@@ -0,0 +1,1284 @@
1
+ // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ //
3
+ // NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ // and proprietary rights in and to this software, related documentation
5
+ // and any modifications thereto. Any use, reproduction, disclosure or
6
+ // distribution of this software and related documentation without an express
7
+ // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ #include <c10/util/Half.h>
10
+ #include "filtered_lrelu.h"
11
+ #include <cstdint>
12
+
13
+ //------------------------------------------------------------------------
14
+ // Helpers.
15
+
16
+ enum // Filter modes.
17
+ {
18
+ MODE_SUSD = 0, // Separable upsampling, separable downsampling.
19
+ MODE_FUSD = 1, // Full upsampling, separable downsampling.
20
+ MODE_SUFD = 2, // Separable upsampling, full downsampling.
21
+ MODE_FUFD = 3, // Full upsampling, full downsampling.
22
+ };
23
+
24
+ template <class T> struct InternalType;
25
+ template <> struct InternalType<double>
26
+ {
27
+ typedef double scalar_t; typedef double2 vec2_t; typedef double4 vec4_t;
28
+ __device__ __forceinline__ static vec2_t zero_vec2(void) { return make_double2(0, 0); }
29
+ __device__ __forceinline__ static vec4_t zero_vec4(void) { return make_double4(0, 0, 0, 0); }
30
+ __device__ __forceinline__ static double clamp(double x, double c) { return fmin(fmax(x, -c), c); }
31
+ };
32
+ template <> struct InternalType<float>
33
+ {
34
+ typedef float scalar_t; typedef float2 vec2_t; typedef float4 vec4_t;
35
+ __device__ __forceinline__ static vec2_t zero_vec2(void) { return make_float2(0, 0); }
36
+ __device__ __forceinline__ static vec4_t zero_vec4(void) { return make_float4(0, 0, 0, 0); }
37
+ __device__ __forceinline__ static float clamp(float x, float c) { return fminf(fmaxf(x, -c), c); }
38
+ };
39
+ template <> struct InternalType<c10::Half>
40
+ {
41
+ typedef float scalar_t; typedef float2 vec2_t; typedef float4 vec4_t;
42
+ __device__ __forceinline__ static vec2_t zero_vec2(void) { return make_float2(0, 0); }
43
+ __device__ __forceinline__ static vec4_t zero_vec4(void) { return make_float4(0, 0, 0, 0); }
44
+ __device__ __forceinline__ static float clamp(float x, float c) { return fminf(fmaxf(x, -c), c); }
45
+ };
46
+
47
+ #define MIN(A, B) ((A) < (B) ? (A) : (B))
48
+ #define MAX(A, B) ((A) > (B) ? (A) : (B))
49
+ #define CEIL_DIV(A, B) (((B)==1) ? (A) : \
50
+ ((B)==2) ? ((int)((A)+1) >> 1) : \
51
+ ((B)==4) ? ((int)((A)+3) >> 2) : \
52
+ (((A) + ((A) > 0 ? (B) - 1 : 0)) / (B)))
53
+
54
+ // This works only up to blocks of size 256 x 256 and for all N that are powers of two.
55
+ template <int N> __device__ __forceinline__ void fast_div_mod(int& x, int& y, unsigned int i)
56
+ {
57
+ if ((N & (N-1)) && N <= 256)
58
+ y = (i * ((1<<24)/N + 1)) >> 24; // Assumes N <= 256, i < N*256.
59
+ else
60
+ y = i/N;
61
+
62
+ x = i - y*N;
63
+ }
64
+
65
+ // Type cast stride before reading it.
66
+ template <class T> __device__ __forceinline__ T get_stride(const int64_t& x)
67
+ {
68
+ return *reinterpret_cast<const T*>(&x);
69
+ }
70
+
71
+ //------------------------------------------------------------------------
72
+ // Filters, setup kernel, copying function.
73
+
74
+ #define MAX_FILTER_SIZE 32
75
+
76
+ // Combined up/down filter buffers so that transfer can be done with one copy.
77
+ __device__ float g_fbuf[2 * MAX_FILTER_SIZE * MAX_FILTER_SIZE]; // Filters in global memory, written by setup kernel.
78
+ __device__ __constant__ float c_fbuf[2 * MAX_FILTER_SIZE * MAX_FILTER_SIZE]; // Filters in constant memory, read by main kernel.
79
+
80
+ // Accessors to combined buffers to index up/down filters individually.
81
+ #define c_fu (c_fbuf)
82
+ #define c_fd (c_fbuf + MAX_FILTER_SIZE * MAX_FILTER_SIZE)
83
+ #define g_fu (g_fbuf)
84
+ #define g_fd (g_fbuf + MAX_FILTER_SIZE * MAX_FILTER_SIZE)
85
+
86
+ // Set up filters into global memory buffer.
87
+ static __global__ void setup_filters_kernel(filtered_lrelu_kernel_params p)
88
+ {
89
+ for (int idx = threadIdx.x; idx < MAX_FILTER_SIZE * MAX_FILTER_SIZE; idx += blockDim.x)
90
+ {
91
+ int x, y;
92
+ fast_div_mod<MAX_FILTER_SIZE>(x, y, idx);
93
+
94
+ int fu_x = p.flip ? x : (p.fuShape.x - 1 - x);
95
+ int fu_y = p.flip ? y : (p.fuShape.y - 1 - y);
96
+ if (p.fuShape.y > 0)
97
+ g_fu[idx] = (x >= p.fuShape.x || y >= p.fuShape.y) ? 0.0f : p.fu[fu_x * p.fuStride.x + fu_y * p.fuStride.y];
98
+ else
99
+ g_fu[idx] = (x >= p.fuShape.x || y > 0) ? 0.0f : p.fu[fu_x * p.fuStride.x];
100
+
101
+ int fd_x = p.flip ? x : (p.fdShape.x - 1 - x);
102
+ int fd_y = p.flip ? y : (p.fdShape.y - 1 - y);
103
+ if (p.fdShape.y > 0)
104
+ g_fd[idx] = (x >= p.fdShape.x || y >= p.fdShape.y) ? 0.0f : p.fd[fd_x * p.fdStride.x + fd_y * p.fdStride.y];
105
+ else
106
+ g_fd[idx] = (x >= p.fdShape.x || y > 0) ? 0.0f : p.fd[fd_x * p.fdStride.x];
107
+ }
108
+ }
109
+
110
+ // Host function to copy filters written by setup kernel into constant buffer for main kernel.
111
+ template <bool, bool> static cudaError_t copy_filters(cudaStream_t stream)
112
+ {
113
+ void* src = 0;
114
+ cudaError_t err = cudaGetSymbolAddress(&src, g_fbuf);
115
+ if (err) return err;
116
+ return cudaMemcpyToSymbolAsync(c_fbuf, src, 2 * MAX_FILTER_SIZE * MAX_FILTER_SIZE * sizeof(float), 0, cudaMemcpyDeviceToDevice, stream);
117
+ }
118
+
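The setup kernel above writes both FIR filters into one combined global buffer so that a single `cudaMemcpyToSymbolAsync` moves them into constant memory. A small numpy sketch of the same layout, assuming for illustration that each filter is zero-padded to `MAX_FILTER_SIZE x MAX_FILTER_SIZE` and the two slabs are stored back to back (the helper name `pack_filters` is hypothetical):

```python
import numpy as np

MAX_FILTER_SIZE = 32

def pack_filters(fu, fd, flip=False):
    # Two 32x32 float slabs: slot 0 = upsampling filter, slot 1 = downsampling filter.
    buf = np.zeros((2, MAX_FILTER_SIZE, MAX_FILTER_SIZE), dtype=np.float32)
    for slot, f in enumerate((fu, fd)):
        f = np.atleast_2d(np.asarray(f, dtype=np.float32))
        if not flip:
            f = f[::-1, ::-1]                  # the setup kernel stores taps reversed unless p.flip is set
        buf[slot, :f.shape[0], :f.shape[1]] = f
    return buf.reshape(-1)                     # one contiguous 2*32*32 float buffer, copied in one go

fu = np.array([1., 3., 3., 1.]) / 8.0          # separable 4-tap example filter
fd = np.array([1., 2., 1.]) / 4.0
print(pack_filters(fu, fd).shape)              # (2048,)
```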
119
+ //------------------------------------------------------------------------
120
+ // Coordinate spaces:
121
+ // - Relative to input tensor: inX, inY, tileInX, tileInY
122
+ // - Relative to input tile: relInX, relInY, tileInW, tileInH
123
+ // - Relative to upsampled tile: relUpX, relUpY, tileUpW, tileUpH
124
+ // - Relative to output tile: relOutX, relOutY, tileOutW, tileOutH
125
+ // - Relative to output tensor: outX, outY, tileOutX, tileOutY
126
+ //
127
+ // Relationships between coordinate spaces:
128
+ // - inX = tileInX + relInX
129
+ // - inY = tileInY + relInY
130
+ // - relUpX = relInX * up + phaseInX
131
+ // - relUpY = relInY * up + phaseInY
132
+ // - relUpX = relOutX * down
133
+ // - relUpY = relOutY * down
134
+ // - outX = tileOutX + relOutX
135
+ // - outY = tileOutY + relOutY
136
+
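The relationships listed above can be exercised directly. A minimal Python sketch (illustration only; the function name is hypothetical) that maps an input-tile sample to the output pixel it contributes to, following the listed identities:

```python
def in_to_out(rel_in_x, tile_in_x, tile_out_x, up, down, phase_in_x):
    in_x = tile_in_x + rel_in_x                 # inX = tileInX + relInX
    rel_up_x = rel_in_x * up + phase_in_x       # relUpX = relInX * up + phaseInX
    if rel_up_x % down != 0:
        return in_x, None                       # this upsampled column is dropped by downsampling
    rel_out_x = rel_up_x // down                # relUpX = relOutX * down
    return in_x, tile_out_x + rel_out_x         # outX = tileOutX + relOutX

# With up = down = 2 and zero phase, every input sample lands on exactly one output sample.
print([in_to_out(i, tile_in_x=0, tile_out_x=0, up=2, down=2, phase_in_x=0) for i in range(4)])
```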
137
+ extern __shared__ char s_buf_raw[]; // When sharedKB <= 48, allocate shared memory statically inside the kernel, otherwise use the externally allocated shared memory buffer.
138
+
139
+ template <class T, class index_t, int sharedKB, bool signWrite, bool signRead, int filterMode, int up, int fuSize, int down, int fdSize, int tileOutW, int tileOutH, int threadsPerBlock, bool enableXrep, bool enableWriteSkip>
140
+ static __global__ void filtered_lrelu_kernel(filtered_lrelu_kernel_params p)
141
+ {
142
+ // Check that we don't try to support non-existing filter modes.
143
+ static_assert(up == 1 || up == 2 || up == 4, "only up=1, up=2, up=4 scales supported");
144
+ static_assert(down == 1 || down == 2 || down == 4, "only down=1, down=2, down=4 scales supported");
145
+ static_assert(fuSize >= up, "upsampling filter size must be at least upsampling factor");
146
+ static_assert(fdSize >= down, "downsampling filter size must be at least downsampling factor");
147
+ static_assert(fuSize % up == 0, "upsampling filter size must be divisible with upsampling factor");
148
+ static_assert(fdSize % down == 0, "downsampling filter size must be divisible with downsampling factor");
149
+ static_assert(fuSize <= MAX_FILTER_SIZE && fdSize <= MAX_FILTER_SIZE, "filter size greater than MAX_FILTER_SIZE");
150
+ static_assert(up != 1 || (fuSize == 1 && (filterMode == MODE_FUFD || filterMode == MODE_FUSD)), "up=1 supported only for 1x1 full filters");
151
+ static_assert(down != 1 || (fdSize == 1 && (filterMode == MODE_FUFD || filterMode == MODE_SUFD)), "down=1 supported only for 1x1 full filters");
152
+ static_assert(!(up == 4 && (filterMode == MODE_FUFD || filterMode == MODE_FUSD)), "full filters not supported for up=4");
153
+ static_assert(!(down == 4 && (filterMode == MODE_FUFD || filterMode == MODE_SUFD)), "full filters not supported for down=4");
154
+
155
+ // Static definitions.
156
+ typedef typename InternalType<T>::scalar_t scalar_t;
157
+ typedef typename InternalType<T>::vec2_t vec2_t;
158
+ typedef typename InternalType<T>::vec4_t vec4_t;
159
+ const int tileUpW = (tileOutW * down + (fdSize - 1) - (down - 1) + 3) & ~3; // Upsampled tile width, rounded up to multiple of 4.
160
+ const int tileUpH = tileOutH * down + (fdSize - 1) - (down - 1); // Upsampled tile height.
161
+ const int tileInW = CEIL_DIV(tileUpW + (fuSize - 1), up); // Input tile width.
162
+ const int tileInH = CEIL_DIV(tileUpH + (fuSize - 1), up); // Input tile height.
163
+ const int tileUpH_up = CEIL_DIV(tileUpH, up) * up; // Upsampled tile height rounded up to a multiple of up.
164
+ const int tileInH_up = CEIL_DIV(tileUpH_up + (fuSize - 1), up); // For allocations only, to avoid shared memory read overruns with up=2 and up=4.
165
+
166
+ // Merge 1x1 downsampling into last upsampling step for upf1 and ups2.
167
+ const bool downInline = (down == 1) && ((up == 1 && filterMode == MODE_FUFD) || (up == 2 && filterMode == MODE_SUFD));
168
+
169
+ // Sizes of logical buffers.
170
+ const int szIn = tileInH_up * tileInW;
171
+ const int szUpX = tileInH_up * tileUpW;
172
+ const int szUpXY = downInline ? 0 : (tileUpH * tileUpW);
173
+ const int szDownX = tileUpH * tileOutW;
174
+
175
+ // Sizes for shared memory arrays.
176
+ const int s_buf0_size_base =
177
+ (filterMode == MODE_SUSD) ? MAX(szIn, szUpXY) :
178
+ (filterMode == MODE_FUSD) ? MAX(szIn, szDownX) :
179
+ (filterMode == MODE_SUFD) ? MAX(szIn, szUpXY) :
180
+ (filterMode == MODE_FUFD) ? szIn :
181
+ -1;
182
+ const int s_buf1_size_base =
183
+ (filterMode == MODE_SUSD) ? MAX(szUpX, szDownX) :
184
+ (filterMode == MODE_FUSD) ? szUpXY :
185
+ (filterMode == MODE_SUFD) ? szUpX :
186
+ (filterMode == MODE_FUFD) ? szUpXY :
187
+ -1;
188
+
189
+ // Ensure U128 alignment.
190
+ const int s_buf0_size = (s_buf0_size_base + 3) & ~3;
191
+ const int s_buf1_size = (s_buf1_size_base + 3) & ~3;
192
+
193
+ // Check at compile time that we don't use too much shared memory.
194
+ static_assert((s_buf0_size + s_buf1_size) * sizeof(scalar_t) <= (sharedKB << 10), "shared memory overflow");
195
+
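The constants above fix the tile geometry and shared-memory footprint at compile time. A Python sketch (illustration only; the numbers mirror the 4t-ups2-downs2 specialization listed further down) that reproduces the same arithmetic for the separable up/down mode:

```python
def ceil_div(a, b):
    return -(-a // b)

def susd_shared_bytes(tile_out_w, tile_out_h, up, down, fu_size, fd_size, scalar_bytes=4):
    tile_up_w = (tile_out_w * down + (fd_size - 1) - (down - 1) + 3) & ~3   # rounded to multiple of 4
    tile_up_h = tile_out_h * down + (fd_size - 1) - (down - 1)
    tile_in_w = ceil_div(tile_up_w + (fu_size - 1), up)
    tile_up_h_up = ceil_div(tile_up_h, up) * up
    tile_in_h_up = ceil_div(tile_up_h_up + (fu_size - 1), up)
    sz_in, sz_up_x = tile_in_h_up * tile_in_w, tile_in_h_up * tile_up_w
    sz_up_xy, sz_down_x = tile_up_h * tile_up_w, tile_up_h * tile_out_w
    s_buf0 = (max(sz_in, sz_up_xy) + 3) & ~3                                # MODE_SUSD buffer layout
    s_buf1 = (max(sz_up_x, sz_down_x) + 3) & ~3
    return (s_buf0 + s_buf1) * scalar_bytes

print(susd_shared_bytes(56, 29, up=2, down=2, fu_size=8, fd_size=8))        # 48000 bytes, within 48 kB
```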
196
+ // Declare shared memory arrays.
197
+ scalar_t* s_buf0;
198
+ scalar_t* s_buf1;
199
+ if (sharedKB <= 48)
200
+ {
201
+ // Allocate shared memory arrays here.
202
+ __shared__ scalar_t s_buf0_st[(sharedKB > 48) ? (1<<24) : (s_buf0_size + s_buf1_size)]; // Prevent launching if this isn't optimized away when unused.
203
+ s_buf0 = s_buf0_st;
204
+ s_buf1 = s_buf0 + s_buf0_size;
205
+ }
206
+ else
207
+ {
208
+ // Use the dynamically allocated shared memory array.
209
+ s_buf0 = (scalar_t*)s_buf_raw;
210
+ s_buf1 = s_buf0 + s_buf0_size;
211
+ }
212
+
213
+ // Pointers to the buffers.
214
+ scalar_t* s_tileIn; // Input tile: [relInX * tileInH + relInY]
215
+ scalar_t* s_tileUpX; // After horizontal upsampling: [relInY * tileUpW + relUpX]
216
+ scalar_t* s_tileUpXY; // After upsampling: [relUpY * tileUpW + relUpX]
217
+ scalar_t* s_tileDownX; // After horizontal downsampling: [relUpY * tileOutW + relOutX]
218
+ if (filterMode == MODE_SUSD)
219
+ {
220
+ s_tileIn = s_buf0;
221
+ s_tileUpX = s_buf1;
222
+ s_tileUpXY = s_buf0;
223
+ s_tileDownX = s_buf1;
224
+ }
225
+ else if (filterMode == MODE_FUSD)
226
+ {
227
+ s_tileIn = s_buf0;
228
+ s_tileUpXY = s_buf1;
229
+ s_tileDownX = s_buf0;
230
+ }
231
+ else if (filterMode == MODE_SUFD)
232
+ {
233
+ s_tileIn = s_buf0;
234
+ s_tileUpX = s_buf1;
235
+ s_tileUpXY = s_buf0;
236
+ }
237
+ else if (filterMode == MODE_FUFD)
238
+ {
239
+ s_tileIn = s_buf0;
240
+ s_tileUpXY = s_buf1;
241
+ }
242
+
243
+ // Allow large grids in z direction via per-launch offset.
244
+ int channelIdx = blockIdx.z + p.blockZofs;
245
+ int batchIdx = channelIdx / p.yShape.z;
246
+ channelIdx -= batchIdx * p.yShape.z;
247
+
248
+ // Offset to output feature map. In bytes.
249
+ index_t mapOfsOut = channelIdx * get_stride<index_t>(p.yStride.z) + batchIdx * get_stride<index_t>(p.yStride.w);
250
+
251
+ // Sign shift amount.
252
+ uint32_t signXo = ((threadIdx.x + p.sOfs.x) << 1) & 6;
253
+
254
+ // Inner tile loop.
255
+ #pragma unroll 1
256
+ for (int tileIdx = 0; !enableXrep || (tileIdx < MIN(p.tilesXrep, p.tilesXdim - p.tilesXrep * blockIdx.y)); tileIdx++)
257
+ {
258
+ // Locate output tile.
259
+ int tileX = enableXrep ? blockIdx.y * p.tilesXrep + tileIdx : blockIdx.x;
260
+ int tileOutX = tileX * tileOutW;
261
+ int tileOutY = (enableXrep ? blockIdx.x : blockIdx.y) * tileOutH;
262
+
263
+ // Locate input tile.
264
+ int tmpX = tileOutX * down - p.pad0.x;
265
+ int tmpY = tileOutY * down - p.pad0.y;
266
+ int tileInX = CEIL_DIV(tmpX, up);
267
+ int tileInY = CEIL_DIV(tmpY, up);
268
+ const int phaseInX = tileInX * up - tmpX;
269
+ const int phaseInY = tileInY * up - tmpY;
270
+
271
+ // Extra sync if input and output buffers are the same and we are not on first tile.
272
+ if (enableXrep && tileIdx > 0 && (filterMode == MODE_FUSD || (filterMode == MODE_SUFD && !downInline) || (filterMode == MODE_FUFD && downInline)))
273
+ __syncthreads();
274
+
275
+ // Load input tile & apply bias. Unrolled.
276
+ scalar_t b = (scalar_t)*(const T*)((const char*)p.b + (channelIdx * get_stride<index_t>(p.bStride)));
277
+ index_t mapOfsIn = channelIdx * get_stride<index_t>(p.xStride.z) + batchIdx * get_stride<index_t>(p.xStride.w);
278
+ int idx = threadIdx.x;
279
+ const int loopCountIN = CEIL_DIV(tileInW * tileInH, threadsPerBlock);
280
+ #pragma unroll
281
+ for (int loop = 0; loop < loopCountIN; loop++)
282
+ {
283
+ int relInX, relInY;
284
+ fast_div_mod<tileInW>(relInX, relInY, idx);
285
+ int inX = tileInX + relInX;
286
+ int inY = tileInY + relInY;
287
+ scalar_t v = 0;
288
+
289
+ if ((uint32_t)inX < p.xShape.x && (uint32_t)inY < p.xShape.y)
290
+ v = (scalar_t)*((const T*)((const char*)p.x + (inX * get_stride<index_t>(p.xStride.x) + inY * get_stride<index_t>(p.xStride.y) + mapOfsIn))) + b;
291
+
292
+ bool skip = (loop == loopCountIN-1) && (idx >= tileInW * tileInH);
293
+ if (!skip)
294
+ s_tileIn[idx] = v;
295
+
296
+ idx += threadsPerBlock;
297
+ }
298
+
299
+ if (filterMode == MODE_SUSD || filterMode == MODE_SUFD) // Separable upsampling filter.
300
+ {
301
+ // Horizontal upsampling.
302
+ __syncthreads();
303
+ if (up == 4)
304
+ {
305
+ for (int idx = threadIdx.x*up; idx < tileUpW * tileInH; idx += blockDim.x*up)
306
+ {
307
+ int relUpX0, relInY;
308
+ fast_div_mod<tileUpW>(relUpX0, relInY, idx);
309
+ int relInX0 = relUpX0 / up;
310
+ int src0 = relInX0 + tileInW * relInY;
311
+ int dst = relInY * tileUpW + relUpX0;
312
+ vec4_t v = InternalType<T>::zero_vec4();
313
+ scalar_t a = s_tileIn[src0];
314
+ if (phaseInX == 0)
315
+ {
316
+ #pragma unroll
317
+ for (int step = 0; step < fuSize / up; step++)
318
+ {
319
+ v.x += a * (scalar_t)c_fu[step * up + 0];
320
+ a = s_tileIn[src0 + step + 1];
321
+ v.y += a * (scalar_t)c_fu[step * up + 3];
322
+ v.z += a * (scalar_t)c_fu[step * up + 2];
323
+ v.w += a * (scalar_t)c_fu[step * up + 1];
324
+ }
325
+ }
326
+ else if (phaseInX == 1)
327
+ {
328
+ #pragma unroll
329
+ for (int step = 0; step < fuSize / up; step++)
330
+ {
331
+ v.x += a * (scalar_t)c_fu[step * up + 1];
332
+ v.y += a * (scalar_t)c_fu[step * up + 0];
333
+ a = s_tileIn[src0 + step + 1];
334
+ v.z += a * (scalar_t)c_fu[step * up + 3];
335
+ v.w += a * (scalar_t)c_fu[step * up + 2];
336
+ }
337
+ }
338
+ else if (phaseInX == 2)
339
+ {
340
+ #pragma unroll
341
+ for (int step = 0; step < fuSize / up; step++)
342
+ {
343
+ v.x += a * (scalar_t)c_fu[step * up + 2];
344
+ v.y += a * (scalar_t)c_fu[step * up + 1];
345
+ v.z += a * (scalar_t)c_fu[step * up + 0];
346
+ a = s_tileIn[src0 + step + 1];
347
+ v.w += a * (scalar_t)c_fu[step * up + 3];
348
+ }
349
+ }
350
+ else // (phaseInX == 3)
351
+ {
352
+ #pragma unroll
353
+ for (int step = 0; step < fuSize / up; step++)
354
+ {
355
+ v.x += a * (scalar_t)c_fu[step * up + 3];
356
+ v.y += a * (scalar_t)c_fu[step * up + 2];
357
+ v.z += a * (scalar_t)c_fu[step * up + 1];
358
+ v.w += a * (scalar_t)c_fu[step * up + 0];
359
+ a = s_tileIn[src0 + step + 1];
360
+ }
361
+ }
362
+ s_tileUpX[dst+0] = v.x;
363
+ s_tileUpX[dst+1] = v.y;
364
+ s_tileUpX[dst+2] = v.z;
365
+ s_tileUpX[dst+3] = v.w;
366
+ }
367
+ }
368
+ else if (up == 2)
369
+ {
370
+ bool p0 = (phaseInX == 0);
371
+ for (int idx = threadIdx.x*up; idx < tileUpW * tileInH; idx += blockDim.x*up)
372
+ {
373
+ int relUpX0, relInY;
374
+ fast_div_mod<tileUpW>(relUpX0, relInY, idx);
375
+ int relInX0 = relUpX0 / up;
376
+ int src0 = relInX0 + tileInW * relInY;
377
+ int dst = relInY * tileUpW + relUpX0;
378
+ vec2_t v = InternalType<T>::zero_vec2();
379
+ scalar_t a = s_tileIn[src0];
380
+ if (p0) // (phaseInX == 0)
381
+ {
382
+ #pragma unroll
383
+ for (int step = 0; step < fuSize / up; step++)
384
+ {
385
+ v.x += a * (scalar_t)c_fu[step * up + 0];
386
+ a = s_tileIn[src0 + step + 1];
387
+ v.y += a * (scalar_t)c_fu[step * up + 1];
388
+ }
389
+ }
390
+ else // (phaseInX == 1)
391
+ {
392
+ #pragma unroll
393
+ for (int step = 0; step < fuSize / up; step++)
394
+ {
395
+ v.x += a * (scalar_t)c_fu[step * up + 1];
396
+ v.y += a * (scalar_t)c_fu[step * up + 0];
397
+ a = s_tileIn[src0 + step + 1];
398
+ }
399
+ }
400
+ s_tileUpX[dst+0] = v.x;
401
+ s_tileUpX[dst+1] = v.y;
402
+ }
403
+ }
404
+
405
+ // Vertical upsampling & nonlinearity.
406
+
407
+ __syncthreads();
408
+ int groupMask = 15 << ((threadIdx.x & 31) & ~3);
409
+ int minY = tileOutY ? (tileOutY - tileOutH) * down + tileUpH : 0; // Skip already written signs.
410
+ int sShapeMaxY = MIN(p.sShape.y, tileOutY * down + tileUpH); // Avoid out-of-tile sign writes.
411
+ if (up == 4)
412
+ {
413
+ minY -= 3; // Adjust according to block height.
414
+ for (int idx = threadIdx.x; idx < tileUpW * tileUpH_up / up; idx += blockDim.x)
415
+ {
416
+ int relUpX, relInY0;
417
+ fast_div_mod<tileUpW>(relUpX, relInY0, idx);
418
+ int relUpY0 = relInY0 * up;
419
+ int src0 = relInY0 * tileUpW + relUpX;
420
+ int dst = relUpY0 * tileUpW + relUpX;
421
+ vec4_t v = InternalType<T>::zero_vec4();
422
+
423
+ scalar_t a = s_tileUpX[src0];
424
+ if (phaseInY == 0)
425
+ {
426
+ #pragma unroll
427
+ for (int step = 0; step < fuSize / up; step++)
428
+ {
429
+ v.x += a * (scalar_t)c_fu[step * up + 0];
430
+ a = s_tileUpX[src0 + (step + 1) * tileUpW];
431
+ v.y += a * (scalar_t)c_fu[step * up + 3];
432
+ v.z += a * (scalar_t)c_fu[step * up + 2];
433
+ v.w += a * (scalar_t)c_fu[step * up + 1];
434
+ }
435
+ }
436
+ else if (phaseInY == 1)
437
+ {
438
+ #pragma unroll
439
+ for (int step = 0; step < fuSize / up; step++)
440
+ {
441
+ v.x += a * (scalar_t)c_fu[step * up + 1];
442
+ v.y += a * (scalar_t)c_fu[step * up + 0];
443
+ a = s_tileUpX[src0 + (step + 1) * tileUpW];
444
+ v.z += a * (scalar_t)c_fu[step * up + 3];
445
+ v.w += a * (scalar_t)c_fu[step * up + 2];
446
+ }
447
+ }
448
+ else if (phaseInY == 2)
449
+ {
450
+ #pragma unroll
451
+ for (int step = 0; step < fuSize / up; step++)
452
+ {
453
+ v.x += a * (scalar_t)c_fu[step * up + 2];
454
+ v.y += a * (scalar_t)c_fu[step * up + 1];
455
+ v.z += a * (scalar_t)c_fu[step * up + 0];
456
+ a = s_tileUpX[src0 + (step + 1) * tileUpW];
457
+ v.w += a * (scalar_t)c_fu[step * up + 3];
458
+ }
459
+ }
460
+ else // (phaseInY == 3)
461
+ {
462
+ #pragma unroll
463
+ for (int step = 0; step < fuSize / up; step++)
464
+ {
465
+ v.x += a * (scalar_t)c_fu[step * up + 3];
466
+ v.y += a * (scalar_t)c_fu[step * up + 2];
467
+ v.z += a * (scalar_t)c_fu[step * up + 1];
468
+ v.w += a * (scalar_t)c_fu[step * up + 0];
469
+ a = s_tileUpX[src0 + (step + 1) * tileUpW];
470
+ }
471
+ }
472
+
473
+ int x = tileOutX * down + relUpX;
474
+ int y = tileOutY * down + relUpY0;
475
+ int signX = x + p.sOfs.x;
476
+ int signY = y + p.sOfs.y;
477
+ int signZ = blockIdx.z + p.blockZofs;
478
+ int signXb = signX >> 2;
479
+ index_t si0 = signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ);
480
+ index_t si1 = si0 + p.sShape.x;
481
+ index_t si2 = si0 + p.sShape.x * 2;
482
+ index_t si3 = si0 + p.sShape.x * 3;
483
+
484
+ v.x *= (scalar_t)((float)up * (float)up * p.gain);
485
+ v.y *= (scalar_t)((float)up * (float)up * p.gain);
486
+ v.z *= (scalar_t)((float)up * (float)up * p.gain);
487
+ v.w *= (scalar_t)((float)up * (float)up * p.gain);
488
+
489
+ if (signWrite)
490
+ {
491
+ if (!enableWriteSkip)
492
+ {
493
+ // Determine and write signs.
494
+ int sx = __float_as_uint(v.x) >> 31 << 0;
495
+ int sy = __float_as_uint(v.y) >> 31 << 8;
496
+ int sz = __float_as_uint(v.z) >> 31 << 16;
497
+ int sw = __float_as_uint(v.w) >> 31 << 24;
498
+ if (sx) v.x *= p.slope;
499
+ if (sy) v.y *= p.slope;
500
+ if (sz) v.z *= p.slope;
501
+ if (sw) v.w *= p.slope;
502
+ if (fabsf(v.x) > p.clamp) { sx = 2 << 0; v.x = InternalType<T>::clamp(v.x, p.clamp); }
503
+ if (fabsf(v.y) > p.clamp) { sy = 2 << 8; v.y = InternalType<T>::clamp(v.y, p.clamp); }
504
+ if (fabsf(v.z) > p.clamp) { sz = 2 << 16; v.z = InternalType<T>::clamp(v.z, p.clamp); }
505
+ if (fabsf(v.w) > p.clamp) { sw = 2 << 24; v.w = InternalType<T>::clamp(v.w, p.clamp); }
506
+
507
+ if ((uint32_t)signXb < p.swLimit && signY >= minY)
508
+ {
509
+ // Combine signs.
510
+ uint32_t s = sx + sy + sw + sz;
511
+ s <<= (signX & 3) << 1;
512
+ s |= __shfl_xor_sync(groupMask, s, 1);
513
+ s |= __shfl_xor_sync(groupMask, s, 2);
514
+
515
+ // Write signs.
516
+ if ((uint32_t)(signY + 0) < sShapeMaxY) { p.s[si0] = (unsigned char)(s >> 0); }
517
+ if ((uint32_t)(signY + 1) < sShapeMaxY) { p.s[si1] = (unsigned char)(s >> 8); }
518
+ if ((uint32_t)(signY + 2) < sShapeMaxY) { p.s[si2] = (unsigned char)(s >> 16); }
519
+ if ((uint32_t)(signY + 3) < sShapeMaxY) { p.s[si3] = (unsigned char)(s >> 24); }
520
+ }
521
+ }
522
+ else
523
+ {
524
+ // Determine and write signs.
525
+ if ((uint32_t)signXb < p.swLimit && signY >= minY)
526
+ {
527
+ int sx = __float_as_uint(v.x) >> 31 << 0;
528
+ int sy = __float_as_uint(v.y) >> 31 << 8;
529
+ int sz = __float_as_uint(v.z) >> 31 << 16;
530
+ int sw = __float_as_uint(v.w) >> 31 << 24;
531
+ if (sx) v.x *= p.slope;
532
+ if (sy) v.y *= p.slope;
533
+ if (sz) v.z *= p.slope;
534
+ if (sw) v.w *= p.slope;
535
+ if (fabsf(v.x) > p.clamp) { sx = 2 << 0; v.x = InternalType<T>::clamp(v.x, p.clamp); }
536
+ if (fabsf(v.y) > p.clamp) { sy = 2 << 8; v.y = InternalType<T>::clamp(v.y, p.clamp); }
537
+ if (fabsf(v.z) > p.clamp) { sz = 2 << 16; v.z = InternalType<T>::clamp(v.z, p.clamp); }
538
+ if (fabsf(v.w) > p.clamp) { sw = 2 << 24; v.w = InternalType<T>::clamp(v.w, p.clamp); }
539
+
540
+ // Combine signs.
541
+ uint32_t s = sx + sy + sw + sz;
542
+ s <<= (signX & 3) << 1;
543
+ s |= __shfl_xor_sync(groupMask, s, 1);
544
+ s |= __shfl_xor_sync(groupMask, s, 2);
545
+
546
+ // Write signs.
547
+ if ((uint32_t)(signY + 0) < sShapeMaxY) { p.s[si0] = (unsigned char)(s >> 0); }
548
+ if ((uint32_t)(signY + 1) < sShapeMaxY) { p.s[si1] = (unsigned char)(s >> 8); }
549
+ if ((uint32_t)(signY + 2) < sShapeMaxY) { p.s[si2] = (unsigned char)(s >> 16); }
550
+ if ((uint32_t)(signY + 3) < sShapeMaxY) { p.s[si3] = (unsigned char)(s >> 24); }
551
+ }
552
+ else
553
+ {
554
+ // Just compute the values.
555
+ if (v.x < 0.f) v.x *= p.slope; v.x = InternalType<T>::clamp(v.x, p.clamp);
556
+ if (v.y < 0.f) v.y *= p.slope; v.y = InternalType<T>::clamp(v.y, p.clamp);
557
+ if (v.z < 0.f) v.z *= p.slope; v.z = InternalType<T>::clamp(v.z, p.clamp);
558
+ if (v.w < 0.f) v.w *= p.slope; v.w = InternalType<T>::clamp(v.w, p.clamp);
559
+ }
560
+ }
561
+ }
562
+ else if (signRead) // Read signs and apply.
563
+ {
564
+ if ((uint32_t)signXb < p.swLimit)
565
+ {
566
+ int ss = (signX & 3) << 1;
567
+ if ((uint32_t)(signY + 0) < p.sShape.y) { int s = p.s[si0] >> ss; if (s & 1) v.x *= p.slope; if (s & 2) v.x = 0.f; }
568
+ if ((uint32_t)(signY + 1) < p.sShape.y) { int s = p.s[si1] >> ss; if (s & 1) v.y *= p.slope; if (s & 2) v.y = 0.f; }
569
+ if ((uint32_t)(signY + 2) < p.sShape.y) { int s = p.s[si2] >> ss; if (s & 1) v.z *= p.slope; if (s & 2) v.z = 0.f; }
570
+ if ((uint32_t)(signY + 3) < p.sShape.y) { int s = p.s[si3] >> ss; if (s & 1) v.w *= p.slope; if (s & 2) v.w = 0.f; }
571
+ }
572
+ }
573
+ else // Forward pass with no sign write.
574
+ {
575
+ if (v.x < 0.f) v.x *= p.slope; v.x = InternalType<T>::clamp(v.x, p.clamp);
576
+ if (v.y < 0.f) v.y *= p.slope; v.y = InternalType<T>::clamp(v.y, p.clamp);
577
+ if (v.z < 0.f) v.z *= p.slope; v.z = InternalType<T>::clamp(v.z, p.clamp);
578
+ if (v.w < 0.f) v.w *= p.slope; v.w = InternalType<T>::clamp(v.w, p.clamp);
579
+ }
580
+
581
+ s_tileUpXY[dst + 0 * tileUpW] = v.x;
582
+ if (relUpY0 + 1 < tileUpH) s_tileUpXY[dst + 1 * tileUpW] = v.y;
583
+ if (relUpY0 + 2 < tileUpH) s_tileUpXY[dst + 2 * tileUpW] = v.z;
584
+ if (relUpY0 + 3 < tileUpH) s_tileUpXY[dst + 3 * tileUpW] = v.w;
585
+ }
586
+ }
587
+ else if (up == 2)
588
+ {
589
+ minY -= 1; // Adjust according to block height.
590
+ for (int idx = threadIdx.x; idx < tileUpW * tileUpH_up / up; idx += blockDim.x)
591
+ {
592
+ int relUpX, relInY0;
593
+ fast_div_mod<tileUpW>(relUpX, relInY0, idx);
594
+ int relUpY0 = relInY0 * up;
595
+ int src0 = relInY0 * tileUpW + relUpX;
596
+ int dst = relUpY0 * tileUpW + relUpX;
597
+ vec2_t v = InternalType<T>::zero_vec2();
598
+
599
+ scalar_t a = s_tileUpX[src0];
600
+ if (phaseInY == 0)
601
+ {
602
+ #pragma unroll
603
+ for (int step = 0; step < fuSize / up; step++)
604
+ {
605
+ v.x += a * (scalar_t)c_fu[step * up + 0];
606
+ a = s_tileUpX[src0 + (step + 1) * tileUpW];
607
+ v.y += a * (scalar_t)c_fu[step * up + 1];
608
+ }
609
+ }
610
+ else // (phaseInY == 1)
611
+ {
612
+ #pragma unroll
613
+ for (int step = 0; step < fuSize / up; step++)
614
+ {
615
+ v.x += a * (scalar_t)c_fu[step * up + 1];
616
+ v.y += a * (scalar_t)c_fu[step * up + 0];
617
+ a = s_tileUpX[src0 + (step + 1) * tileUpW];
618
+ }
619
+ }
620
+
621
+ int x = tileOutX * down + relUpX;
622
+ int y = tileOutY * down + relUpY0;
623
+ int signX = x + p.sOfs.x;
624
+ int signY = y + p.sOfs.y;
625
+ int signZ = blockIdx.z + p.blockZofs;
626
+ int signXb = signX >> 2;
627
+ index_t si0 = signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ);
628
+ index_t si1 = si0 + p.sShape.x;
629
+
630
+ v.x *= (scalar_t)((float)up * (float)up * p.gain);
631
+ v.y *= (scalar_t)((float)up * (float)up * p.gain);
632
+
633
+ if (signWrite)
634
+ {
635
+ if (!enableWriteSkip)
636
+ {
637
+ // Determine and write signs.
638
+ int sx = __float_as_uint(v.x) >> 31 << 0;
639
+ int sy = __float_as_uint(v.y) >> 31 << 8;
640
+ if (sx) v.x *= p.slope;
641
+ if (sy) v.y *= p.slope;
642
+ if (fabsf(v.x) > p.clamp) { sx = 2 << 0; v.x = InternalType<T>::clamp(v.x, p.clamp); }
643
+ if (fabsf(v.y) > p.clamp) { sy = 2 << 8; v.y = InternalType<T>::clamp(v.y, p.clamp); }
644
+
645
+ if ((uint32_t)signXb < p.swLimit && signY >= minY)
646
+ {
647
+ // Combine signs.
648
+ int s = sx + sy;
649
+ s <<= signXo;
650
+ s |= __shfl_xor_sync(groupMask, s, 1);
651
+ s |= __shfl_xor_sync(groupMask, s, 2);
652
+
653
+ // Write signs.
654
+ if ((uint32_t)(signY + 0) < sShapeMaxY) { p.s[si0] = (unsigned char)(s >> 0); }
655
+ if ((uint32_t)(signY + 1) < sShapeMaxY) { p.s[si1] = (unsigned char)(s >> 8); }
656
+ }
657
+ }
658
+ else
659
+ {
660
+ // Determine and write signs.
661
+ if ((uint32_t)signXb < p.swLimit && signY >= minY)
662
+ {
663
+ int sx = __float_as_uint(v.x) >> 31 << 0;
664
+ int sy = __float_as_uint(v.y) >> 31 << 8;
665
+ if (sx) v.x *= p.slope;
666
+ if (sy) v.y *= p.slope;
667
+ if (fabsf(v.x) > p.clamp) { sx = 2 << 0; v.x = InternalType<T>::clamp(v.x, p.clamp); }
668
+ if (fabsf(v.y) > p.clamp) { sy = 2 << 8; v.y = InternalType<T>::clamp(v.y, p.clamp); }
669
+
670
+ // Combine signs.
671
+ int s = sx + sy;
672
+ s <<= signXo;
673
+ s |= __shfl_xor_sync(groupMask, s, 1);
674
+ s |= __shfl_xor_sync(groupMask, s, 2);
675
+
676
+ // Write signs.
677
+ if ((uint32_t)(signY + 0) < sShapeMaxY) { p.s[si0] = (unsigned char)(s >> 0); }
678
+ if ((uint32_t)(signY + 1) < sShapeMaxY) { p.s[si1] = (unsigned char)(s >> 8); }
679
+ }
680
+ else
681
+ {
682
+ // Just compute the values.
683
+ if (v.x < 0.f) v.x *= p.slope; v.x = InternalType<T>::clamp(v.x, p.clamp);
684
+ if (v.y < 0.f) v.y *= p.slope; v.y = InternalType<T>::clamp(v.y, p.clamp);
685
+ }
686
+ }
687
+ }
688
+ else if (signRead) // Read signs and apply.
689
+ {
690
+ if ((uint32_t)signXb < p.swLimit)
691
+ {
692
+ if ((uint32_t)(signY + 0) < p.sShape.y) { int s = p.s[si0] >> signXo; if (s & 1) v.x *= p.slope; if (s & 2) v.x = 0.f; }
693
+ if ((uint32_t)(signY + 1) < p.sShape.y) { int s = p.s[si1] >> signXo; if (s & 1) v.y *= p.slope; if (s & 2) v.y = 0.f; }
694
+ }
695
+ }
696
+ else // Forward pass with no sign write.
697
+ {
698
+ if (v.x < 0.f) v.x *= p.slope; v.x = InternalType<T>::clamp(v.x, p.clamp);
699
+ if (v.y < 0.f) v.y *= p.slope; v.y = InternalType<T>::clamp(v.y, p.clamp);
700
+ }
701
+
702
+ if (!downInline)
703
+ {
704
+ // Write into temporary buffer.
705
+ s_tileUpXY[dst] = v.x;
706
+ if (relUpY0 < tileUpH - 1)
707
+ s_tileUpXY[dst + tileUpW] = v.y;
708
+ }
709
+ else
710
+ {
711
+ // Write directly into output buffer.
712
+ if ((uint32_t)x < p.yShape.x)
713
+ {
714
+ int ymax = MIN(p.yShape.y, tileUpH + tileOutY * down);
715
+ index_t ofs = x * get_stride<index_t>(p.yStride.x) + y * get_stride<index_t>(p.yStride.y) + mapOfsOut;
716
+ if ((uint32_t)y + 0 < p.yShape.y) *((T*)((char*)p.y + ofs)) = (T)(v.x * (scalar_t)c_fd[0]);
717
+ if ((uint32_t)y + 1 < ymax) *((T*)((char*)p.y + ofs + get_stride<index_t>(p.yStride.y))) = (T)(v.y * (scalar_t)c_fd[0]);
718
+ }
719
+ }
720
+ }
721
+ }
722
+ }
723
+ else if (filterMode == MODE_FUSD || filterMode == MODE_FUFD)
724
+ {
725
+ // Full upsampling filter.
726
+
727
+ if (up == 2)
728
+ {
729
+ // 2 x 2-wide.
730
+ __syncthreads();
731
+ int minY = tileOutY ? (tileOutY - tileOutH) * down + tileUpH + p.sOfs.y : 0; // Skip already written signs.
732
+ for (int idx = threadIdx.x * 4; idx < tileUpW * tileUpH; idx += blockDim.x * 4)
733
+ {
734
+ int relUpX0, relUpY0;
735
+ fast_div_mod<tileUpW>(relUpX0, relUpY0, idx);
736
+ int relInX0 = CEIL_DIV(relUpX0 - phaseInX, up);
737
+ int relInY0 = CEIL_DIV(relUpY0 - phaseInY, up);
738
+ int src0 = relInX0 + tileInW * relInY0;
739
+ int tap0y = (relInY0 * up + phaseInY - relUpY0);
740
+
741
+ #define X_LOOP(TAPY, PX) \
742
+ for (int sx = 0; sx < fuSize / up; sx++) \
743
+ { \
744
+ v.x += a * (scalar_t)c_fu[(sx * up + (((PX) - 0) & (up - 1))) + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; \
745
+ v.z += b * (scalar_t)c_fu[(sx * up + (((PX) - 0) & (up - 1))) + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; if ((PX) == 0) { a = b; b = s_tileIn[src0 + 2 + sx + sy * tileInW]; } \
746
+ v.y += a * (scalar_t)c_fu[(sx * up + (((PX) - 1) & (up - 1))) + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; \
747
+ v.w += b * (scalar_t)c_fu[(sx * up + (((PX) - 1) & (up - 1))) + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; if ((PX) == 1) { a = b; b = s_tileIn[src0 + 2 + sx + sy * tileInW]; } \
748
+ }
749
+
750
+ vec4_t v = InternalType<T>::zero_vec4();
751
+ if (tap0y == 0 && phaseInX == 0)
752
+ #pragma unroll
753
+ for (int sy = 0; sy < fuSize / up; sy++) { scalar_t a = s_tileIn[src0 + sy * tileInW]; scalar_t b = s_tileIn[src0 + sy * tileInW + 1];
754
+ #pragma unroll
755
+ X_LOOP(0, 0) }
756
+ if (tap0y == 0 && phaseInX == 1)
757
+ #pragma unroll
758
+ for (int sy = 0; sy < fuSize / up; sy++) { scalar_t a = s_tileIn[src0 + sy * tileInW]; scalar_t b = s_tileIn[src0 + sy * tileInW + 1];
759
+ #pragma unroll
760
+ X_LOOP(0, 1) }
761
+ if (tap0y == 1 && phaseInX == 0)
762
+ #pragma unroll
763
+ for (int sy = 0; sy < fuSize / up; sy++) { scalar_t a = s_tileIn[src0 + sy * tileInW]; scalar_t b = s_tileIn[src0 + sy * tileInW + 1];
764
+ #pragma unroll
765
+ X_LOOP(1, 0) }
766
+ if (tap0y == 1 && phaseInX == 1)
767
+ #pragma unroll
768
+ for (int sy = 0; sy < fuSize / up; sy++) { scalar_t a = s_tileIn[src0 + sy * tileInW]; scalar_t b = s_tileIn[src0 + sy * tileInW + 1];
769
+ #pragma unroll
770
+ X_LOOP(1, 1) }
771
+
772
+ #undef X_LOOP
773
+
774
+ int x = tileOutX * down + relUpX0;
775
+ int y = tileOutY * down + relUpY0;
776
+ int signX = x + p.sOfs.x;
777
+ int signY = y + p.sOfs.y;
778
+ int signZ = blockIdx.z + p.blockZofs;
779
+ int signXb = signX >> 2;
780
+ index_t si = signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ);
781
+
782
+ v.x *= (scalar_t)((float)up * (float)up * p.gain);
783
+ v.y *= (scalar_t)((float)up * (float)up * p.gain);
784
+ v.z *= (scalar_t)((float)up * (float)up * p.gain);
785
+ v.w *= (scalar_t)((float)up * (float)up * p.gain);
786
+
787
+ if (signWrite)
788
+ {
789
+ if (!enableWriteSkip)
790
+ {
791
+ // Determine and write signs.
792
+ int sx = __float_as_uint(v.x) >> 31;
793
+ int sy = __float_as_uint(v.y) >> 31;
794
+ int sz = __float_as_uint(v.z) >> 31;
795
+ int sw = __float_as_uint(v.w) >> 31;
796
+ if (sx) v.x *= p.slope; if (fabsf(v.x) > p.clamp) { sx = 2; v.x = InternalType<T>::clamp(v.x, p.clamp); }
797
+ if (sy) v.y *= p.slope; if (fabsf(v.y) > p.clamp) { sy = 2; v.y = InternalType<T>::clamp(v.y, p.clamp); }
798
+ if (sz) v.z *= p.slope; if (fabsf(v.z) > p.clamp) { sz = 2; v.z = InternalType<T>::clamp(v.z, p.clamp); }
799
+ if (sw) v.w *= p.slope; if (fabsf(v.w) > p.clamp) { sw = 2; v.w = InternalType<T>::clamp(v.w, p.clamp); }
800
+
801
+ if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y && signY >= minY)
802
+ {
803
+ p.s[si] = sx + (sy << 2) + (sz << 4) + (sw << 6);
804
+ }
805
+ }
806
+ else
807
+ {
808
+ // Determine and write signs.
809
+ if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y && signY >= minY)
810
+ {
811
+ int sx = __float_as_uint(v.x) >> 31;
812
+ int sy = __float_as_uint(v.y) >> 31;
813
+ int sz = __float_as_uint(v.z) >> 31;
814
+ int sw = __float_as_uint(v.w) >> 31;
815
+ if (sx) v.x *= p.slope; if (fabsf(v.x) > p.clamp) { sx = 2; v.x = InternalType<T>::clamp(v.x, p.clamp); }
816
+ if (sy) v.y *= p.slope; if (fabsf(v.y) > p.clamp) { sy = 2; v.y = InternalType<T>::clamp(v.y, p.clamp); }
817
+ if (sz) v.z *= p.slope; if (fabsf(v.z) > p.clamp) { sz = 2; v.z = InternalType<T>::clamp(v.z, p.clamp); }
818
+ if (sw) v.w *= p.slope; if (fabsf(v.w) > p.clamp) { sw = 2; v.w = InternalType<T>::clamp(v.w, p.clamp); }
819
+
820
+ p.s[si] = sx + (sy << 2) + (sz << 4) + (sw << 6);
821
+ }
822
+ else
823
+ {
824
+ // Just compute the values.
825
+ if (v.x < 0.f) v.x *= p.slope; v.x = InternalType<T>::clamp(v.x, p.clamp);
826
+ if (v.y < 0.f) v.y *= p.slope; v.y = InternalType<T>::clamp(v.y, p.clamp);
827
+ if (v.z < 0.f) v.z *= p.slope; v.z = InternalType<T>::clamp(v.z, p.clamp);
828
+ if (v.w < 0.f) v.w *= p.slope; v.w = InternalType<T>::clamp(v.w, p.clamp);
829
+ }
830
+ }
831
+ }
832
+ else if (signRead) // Read sign and apply.
833
+ {
834
+ if ((uint32_t)signY < p.sShape.y)
835
+ {
836
+ int s = 0;
837
+ if ((uint32_t)signXb < p.swLimit) s = p.s[si];
838
+ if ((uint32_t)signXb + 1 < p.swLimit) s |= p.s[si + 1] << 8;
839
+ s >>= (signX & 3) << 1;
840
+ if (s & 0x01) v.x *= p.slope; if (s & 0x02) v.x = 0.f;
841
+ if (s & 0x04) v.y *= p.slope; if (s & 0x08) v.y = 0.f;
842
+ if (s & 0x10) v.z *= p.slope; if (s & 0x20) v.z = 0.f;
843
+ if (s & 0x40) v.w *= p.slope; if (s & 0x80) v.w = 0.f;
844
+ }
845
+ }
846
+ else // Forward pass with no sign write.
847
+ {
848
+ if (v.x < 0.f) v.x *= p.slope; v.x = InternalType<T>::clamp(v.x, p.clamp);
849
+ if (v.y < 0.f) v.y *= p.slope; v.y = InternalType<T>::clamp(v.y, p.clamp);
850
+ if (v.z < 0.f) v.z *= p.slope; v.z = InternalType<T>::clamp(v.z, p.clamp);
851
+ if (v.w < 0.f) v.w *= p.slope; v.w = InternalType<T>::clamp(v.w, p.clamp);
852
+ }
853
+
854
+ s_tileUpXY[idx + 0] = v.x;
855
+ s_tileUpXY[idx + 1] = v.y;
856
+ s_tileUpXY[idx + 2] = v.z;
857
+ s_tileUpXY[idx + 3] = v.w;
858
+ }
859
+ }
860
+ else if (up == 1)
861
+ {
862
+ __syncthreads();
863
+ uint32_t groupMask = 15 << ((threadIdx.x & 31) & ~3);
864
+ int minY = tileOutY ? (tileOutY - tileOutH) * down + tileUpH : 0; // Skip already written signs.
865
+ for (int idx = threadIdx.x; idx < tileUpW * tileUpH; idx += blockDim.x)
866
+ {
867
+ int relUpX0, relUpY0;
868
+ fast_div_mod<tileUpW>(relUpX0, relUpY0, idx);
869
+ scalar_t v = s_tileIn[idx] * (scalar_t)c_fu[0]; // 1x1 filter.
870
+
871
+ int x = tileOutX * down + relUpX0;
872
+ int y = tileOutY * down + relUpY0;
873
+ int signX = x + p.sOfs.x;
874
+ int signY = y + p.sOfs.y;
875
+ int signZ = blockIdx.z + p.blockZofs;
876
+ int signXb = signX >> 2;
877
+ index_t si = signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ);
878
+ v *= (scalar_t)((float)up * (float)up * p.gain);
879
+
880
+ if (signWrite)
881
+ {
882
+ if (!enableWriteSkip)
883
+ {
884
+ // Determine and write sign.
885
+ uint32_t s = 0;
886
+ uint32_t signXbit = (1u << signXo);
887
+ if (v < 0.f)
888
+ {
889
+ s = signXbit;
890
+ v *= p.slope;
891
+ }
892
+ if (fabsf(v) > p.clamp)
893
+ {
894
+ s = signXbit * 2;
895
+ v = InternalType<T>::clamp(v, p.clamp);
896
+ }
897
+ if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y && signY >= minY)
898
+ {
899
+ s += __shfl_xor_sync(groupMask, s, 1); // Coalesce.
900
+ s += __shfl_xor_sync(groupMask, s, 2); // Coalesce.
901
+ p.s[si] = s; // Write.
902
+ }
903
+ }
904
+ else
905
+ {
906
+ // Determine and write sign.
907
+ if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y && signY >= minY)
908
+ {
909
+ uint32_t s = 0;
910
+ uint32_t signXbit = (1u << signXo);
911
+ if (v < 0.f)
912
+ {
913
+ s = signXbit;
914
+ v *= p.slope;
915
+ }
916
+ if (fabsf(v) > p.clamp)
917
+ {
918
+ s = signXbit * 2;
919
+ v = InternalType<T>::clamp(v, p.clamp);
920
+ }
921
+ s += __shfl_xor_sync(groupMask, s, 1); // Coalesce.
922
+ s += __shfl_xor_sync(groupMask, s, 2); // Coalesce.
923
+ p.s[si] = s; // Write.
924
+ }
925
+ else
926
+ {
927
+ // Just compute the value.
928
+ if (v < 0.f) v *= p.slope;
929
+ v = InternalType<T>::clamp(v, p.clamp);
930
+ }
931
+ }
932
+ }
933
+ else if (signRead)
934
+ {
935
+ // Read sign and apply if within sign tensor bounds.
936
+ if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y)
937
+ {
938
+ int s = p.s[si];
939
+ s >>= signXo;
940
+ if (s & 1) v *= p.slope;
941
+ if (s & 2) v = 0.f;
942
+ }
943
+ }
944
+ else // Forward pass with no sign write.
945
+ {
946
+ if (v < 0.f) v *= p.slope;
947
+ v = InternalType<T>::clamp(v, p.clamp);
948
+ }
949
+
950
+ if (!downInline) // Write into temporary buffer.
951
+ s_tileUpXY[idx] = v;
952
+ else if ((uint32_t)x < p.yShape.x && (uint32_t)y < p.yShape.y) // Write directly into output buffer
953
+ *((T*)((char*)p.y + (x * get_stride<index_t>(p.yStride.x) + y * get_stride<index_t>(p.yStride.y) + mapOfsOut))) = (T)(v * (scalar_t)c_fd[0]);
954
+ }
955
+ }
956
+ }
957
+
958
+ // Downsampling.
959
+ if (filterMode == MODE_SUSD || filterMode == MODE_FUSD)
960
+ {
961
+ // Horizontal downsampling.
962
+ __syncthreads();
963
+ if (down == 4 && tileOutW % 4 == 0)
964
+ {
965
+ // Calculate 4 pixels at a time.
966
+ for (int idx = threadIdx.x * 4; idx < tileOutW * tileUpH; idx += blockDim.x * 4)
967
+ {
968
+ int relOutX0, relUpY;
969
+ fast_div_mod<tileOutW>(relOutX0, relUpY, idx);
970
+ int relUpX0 = relOutX0 * down;
971
+ int src0 = relUpY * tileUpW + relUpX0;
972
+ vec4_t v = InternalType<T>::zero_vec4();
973
+ #pragma unroll
974
+ for (int step = 0; step < fdSize; step++)
975
+ {
976
+ v.x += s_tileUpXY[src0 + 0 + step] * (scalar_t)c_fd[step];
977
+ v.y += s_tileUpXY[src0 + 4 + step] * (scalar_t)c_fd[step];
978
+ v.z += s_tileUpXY[src0 + 8 + step] * (scalar_t)c_fd[step];
979
+ v.w += s_tileUpXY[src0 + 12 + step] * (scalar_t)c_fd[step];
980
+ }
981
+ s_tileDownX[idx+0] = v.x;
982
+ s_tileDownX[idx+1] = v.y;
983
+ s_tileDownX[idx+2] = v.z;
984
+ s_tileDownX[idx+3] = v.w;
985
+ }
986
+ }
987
+ else if ((down == 2 || down == 4) && (tileOutW % 2 == 0))
988
+ {
989
+ // Calculate 2 pixels at a time.
990
+ for (int idx = threadIdx.x * 2; idx < tileOutW * tileUpH; idx += blockDim.x * 2)
991
+ {
992
+ int relOutX0, relUpY;
993
+ fast_div_mod<tileOutW>(relOutX0, relUpY, idx);
994
+ int relUpX0 = relOutX0 * down;
995
+ int src0 = relUpY * tileUpW + relUpX0;
996
+ vec2_t v = InternalType<T>::zero_vec2();
997
+ #pragma unroll
998
+ for (int step = 0; step < fdSize; step++)
999
+ {
1000
+ v.x += s_tileUpXY[src0 + 0 + step] * (scalar_t)c_fd[step];
1001
+ v.y += s_tileUpXY[src0 + down + step] * (scalar_t)c_fd[step];
1002
+ }
1003
+ s_tileDownX[idx+0] = v.x;
1004
+ s_tileDownX[idx+1] = v.y;
1005
+ }
1006
+ }
1007
+ else
1008
+ {
1009
+ // Calculate 1 pixel at a time.
1010
+ for (int idx = threadIdx.x; idx < tileOutW * tileUpH; idx += blockDim.x)
1011
+ {
1012
+ int relOutX0, relUpY;
1013
+ fast_div_mod<tileOutW>(relOutX0, relUpY, idx);
1014
+ int relUpX0 = relOutX0 * down;
1015
+ int src = relUpY * tileUpW + relUpX0;
1016
+ scalar_t v = 0.f;
1017
+ #pragma unroll
1018
+ for (int step = 0; step < fdSize; step++)
1019
+ v += s_tileUpXY[src + step] * (scalar_t)c_fd[step];
1020
+ s_tileDownX[idx] = v;
1021
+ }
1022
+ }
1023
+
1024
+ // Vertical downsampling & store output tile.
1025
+ __syncthreads();
1026
+ for (int idx = threadIdx.x; idx < tileOutW * tileOutH; idx += blockDim.x)
1027
+ {
1028
+ int relOutX, relOutY0;
1029
+ fast_div_mod<tileOutW>(relOutX, relOutY0, idx);
1030
+ int relUpY0 = relOutY0 * down;
1031
+ int src0 = relUpY0 * tileOutW + relOutX;
1032
+ scalar_t v = 0;
1033
+ #pragma unroll
1034
+ for (int step = 0; step < fdSize; step++)
1035
+ v += s_tileDownX[src0 + step * tileOutW] * (scalar_t)c_fd[step];
1036
+
1037
+ int outX = tileOutX + relOutX;
1038
+ int outY = tileOutY + relOutY0;
1039
+
1040
+ if (outX < p.yShape.x & outY < p.yShape.y)
1041
+ *((T*)((char*)p.y + (outX * get_stride<index_t>(p.yStride.x) + outY * get_stride<index_t>(p.yStride.y) + mapOfsOut))) = (T)v;
1042
+ }
1043
+ }
1044
+ else if (filterMode == MODE_SUFD || filterMode == MODE_FUFD)
1045
+ {
1046
+ // Full downsampling filter.
1047
+ if (down == 2)
1048
+ {
1049
+ // 2-wide.
1050
+ __syncthreads();
1051
+ for (int idx = threadIdx.x * 2; idx < tileOutW * tileOutH; idx += blockDim.x * 2)
1052
+ {
1053
+ int relOutX0, relOutY0;
1054
+ fast_div_mod<tileOutW>(relOutX0, relOutY0, idx);
1055
+ int relUpX0 = relOutX0 * down;
1056
+ int relUpY0 = relOutY0 * down;
1057
+ int src0 = relUpY0 * tileUpW + relUpX0;
1058
+ vec2_t v = InternalType<T>::zero_vec2();
1059
+ #pragma unroll
1060
+ for (int sy = 0; sy < fdSize; sy++)
1061
+ #pragma unroll
1062
+ for (int sx = 0; sx < fdSize; sx++)
1063
+ {
1064
+ v.x += s_tileUpXY[src0 + 0 + sx + sy * tileUpW] * (scalar_t)c_fd[sx + sy * MAX_FILTER_SIZE];
1065
+ v.y += s_tileUpXY[src0 + 2 + sx + sy * tileUpW] * (scalar_t)c_fd[sx + sy * MAX_FILTER_SIZE];
1066
+ }
1067
+
1068
+ int outX = tileOutX + relOutX0;
1069
+ int outY = tileOutY + relOutY0;
1070
+ if ((uint32_t)outY < p.yShape.y)
1071
+ {
1072
+ index_t ofs = outX * get_stride<index_t>(p.yStride.x) + outY * get_stride<index_t>(p.yStride.y) + mapOfsOut;
1073
+ if (outX + 0 < p.yShape.x) *((T*)((char*)p.y + ofs)) = (T)v.x;
1074
+ if (outX + 1 < p.yShape.x) *((T*)((char*)p.y + ofs + get_stride<index_t>(p.yStride.x))) = (T)v.y;
1075
+ }
1076
+ }
1077
+ }
1078
+ else if (down == 1 && !downInline)
1079
+ {
1080
+ // Thread per pixel.
1081
+ __syncthreads();
1082
+ for (int idx = threadIdx.x; idx < tileOutW * tileOutH; idx += blockDim.x)
1083
+ {
1084
+ int relOutX0, relOutY0;
1085
+ fast_div_mod<tileOutW>(relOutX0, relOutY0, idx);
1086
+ scalar_t v = s_tileUpXY[idx] * (scalar_t)c_fd[0]; // 1x1 filter.
1087
+
1088
+ int outX = tileOutX + relOutX0;
1089
+ int outY = tileOutY + relOutY0;
1090
+ if ((uint32_t)outX < p.yShape.x && (uint32_t)outY < p.yShape.y)
1091
+ *((T*)((char*)p.y + (outX * get_stride<index_t>(p.yStride.x) + outY * get_stride<index_t>(p.yStride.y) + mapOfsOut))) = (T)v;
1092
+ }
1093
+ }
1094
+ }
1095
+
1096
+ if (!enableXrep)
1097
+ break;
1098
+ }
1099
+ }
1100
+
1101
+ //------------------------------------------------------------------------
1102
+ // Compute activation function and signs for upsampled data tensor, modifying data tensor in-place. Used for accelerating the generic variant.
1103
+ // Sign tensor is known to be contiguous, and p.x and p.s have the same z, w dimensions. 64-bit indexing is always used.
1104
+
1105
+ template <class T, bool signWrite, bool signRead>
1106
+ static __global__ void filtered_lrelu_act_kernel(filtered_lrelu_act_kernel_params p)
1107
+ {
1108
+ typedef typename InternalType<T>::scalar_t scalar_t;
1109
+
1110
+ // Indexing.
1111
+ int32_t x = threadIdx.x + blockIdx.x * blockDim.x;
1112
+ int32_t ymax = signWrite ? p.sShape.y : p.xShape.y;
1113
+ int32_t qmax = p.xShape.z * p.xShape.w; // Combined minibatch*channel maximum index.
1114
+
1115
+ // Loop to accommodate oversized tensors.
1116
+ for (int32_t q = blockIdx.z; q < qmax; q += gridDim.z)
1117
+ for (int32_t y = blockIdx.y; y < ymax; y += gridDim.y)
1118
+ {
1119
+ // Extract z and w (channel, minibatch index).
1120
+ int32_t w = q / p.xShape.z;
1121
+ int32_t z = q - w * p.xShape.z;
1122
+
1123
+ // Choose behavior based on sign read/write mode.
1124
+ if (signWrite)
1125
+ {
1126
+ // Process value if in p.x.
1127
+ uint32_t s = 0;
1128
+ if (x < p.xShape.x && y < p.xShape.y)
1129
+ {
1130
+ int64_t ix = x * p.xStride.x + y * p.xStride.y + z * p.xStride.z + w * p.xStride.w;
1131
+ T* pv = ((T*)p.x) + ix;
1132
+ scalar_t v = (scalar_t)(*pv);
1133
+
1134
+ // Gain, LReLU, clamp.
1135
+ v *= p.gain;
1136
+ if (v < 0.f)
1137
+ {
1138
+ v *= p.slope;
1139
+ s = 1; // Sign.
1140
+ }
1141
+ if (fabsf(v) > p.clamp)
1142
+ {
1143
+ v = InternalType<T>::clamp(v, p.clamp);
1144
+ s = 2; // Clamp.
1145
+ }
1146
+
1147
+ *pv = (T)v; // Write value.
1148
+ }
1149
+
1150
+ // Coalesce into threads 0 and 16 of warp.
1151
+ uint32_t m = (threadIdx.x & 16) ? 0xffff0000u : 0x0000ffffu;
1152
+ s <<= ((threadIdx.x & 15) << 1); // Shift into place.
1153
+ s |= __shfl_xor_sync(m, s, 1); // Distribute.
1154
+ s |= __shfl_xor_sync(m, s, 2);
1155
+ s |= __shfl_xor_sync(m, s, 4);
1156
+ s |= __shfl_xor_sync(m, s, 8);
1157
+
1158
+ // Write signs if leader and in p.s.
1159
+ if (!(threadIdx.x & 15) && x < p.sShape.x) // y is always in.
1160
+ {
1161
+ uint64_t is = x + p.sShape.x * (y + (int64_t)p.sShape.y * q); // Contiguous.
1162
+ ((uint32_t*)p.s)[is >> 4] = s;
1163
+ }
1164
+ }
1165
+ else if (signRead)
1166
+ {
1167
+ // Process value if in p.x.
1168
+ if (x < p.xShape.x) // y is always in.
1169
+ {
1170
+ int64_t ix = x * p.xStride.x + y * p.xStride.y + z * p.xStride.z + w * p.xStride.w;
1171
+ T* pv = ((T*)p.x) + ix;
1172
+ scalar_t v = (scalar_t)(*pv);
1173
+ v *= p.gain;
1174
+
1175
+ // Apply sign buffer offset.
1176
+ uint32_t sx = x + p.sOfs.x;
1177
+ uint32_t sy = y + p.sOfs.y;
1178
+
1179
+ // Read and apply signs if we land inside valid region of sign buffer.
1180
+ if (sx < p.sShape.x && sy < p.sShape.y)
1181
+ {
1182
+ uint64_t is = (sx >> 2) + (p.sShape.x >> 2) * (sy + (uint64_t)p.sShape.y * q); // Contiguous.
1183
+ unsigned char s = p.s[is];
1184
+ s >>= (sx & 3) << 1; // Shift into place.
1185
+ if (s & 1) // Sign?
1186
+ v *= p.slope;
1187
+ if (s & 2) // Clamp?
1188
+ v = 0.f;
1189
+ }
1190
+
1191
+ *pv = (T)v; // Write value.
1192
+ }
1193
+ }
1194
+ else
1195
+ {
1196
+ // Forward pass with no sign write. Process value if in p.x.
1197
+ if (x < p.xShape.x) // y is always in.
1198
+ {
1199
+ int64_t ix = x * p.xStride.x + y * p.xStride.y + z * p.xStride.z + w * p.xStride.w;
1200
+ T* pv = ((T*)p.x) + ix;
1201
+ scalar_t v = (scalar_t)(*pv);
1202
+ v *= p.gain;
1203
+ if (v < 0.f)
1204
+ v *= p.slope;
1205
+ if (fabsf(v) > p.clamp)
1206
+ v = InternalType<T>::clamp(v, p.clamp);
1207
+ *pv = (T)v; // Write value.
1208
+ }
1209
+ }
1210
+ }
1211
+ }
1212
+
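The kernels above encode two bits per activation (bit 0: negative input, slope applied; bit 1: magnitude clamped) and pack four codes per sign-buffer byte, so a later pass can reapply the slope and zero out clamped positions without recomputing the forward activation. A small numpy sketch of that encoding (illustration only; the helper names are hypothetical):

```python
import numpy as np

def lrelu_clamp_write_signs(x, gain, slope, clamp):
    v = np.asarray(x, dtype=np.float32) * gain
    codes = np.zeros(v.shape, dtype=np.uint8)
    neg = v < 0
    v = np.where(neg, v * slope, v)
    codes[neg] = 1                                    # bit 0: slope was applied
    over = np.abs(v) > clamp
    v = np.clip(v, -clamp, clamp)
    codes[over] = 2                                   # clamping overrides the sign code
    packed = np.zeros((codes.size + 3) // 4, dtype=np.uint8)
    for i, c in enumerate(codes.tolist()):
        packed[i >> 2] |= c << ((i & 3) << 1)         # four 2-bit codes per byte
    return v, packed

def read_signs_and_apply(g, packed, gain, slope):
    v = np.asarray(g, dtype=np.float32) * gain
    for i in range(v.size):
        s = (packed[i >> 2] >> ((i & 3) << 1)) & 3
        if s & 1: v[i] *= slope                       # was negative on the forward pass
        if s & 2: v[i] = 0.0                          # was clamped -> zero out
    return v

x = np.array([-1.5, 0.3, 4.0, -0.1])
y, signs = lrelu_clamp_write_signs(x, gain=np.sqrt(2), slope=0.2, clamp=0.5)
print(y, read_signs_and_apply(np.ones_like(x), signs, gain=1.0, slope=0.2))
```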
1213
+ template <class T, bool signWrite, bool signRead> void* choose_filtered_lrelu_act_kernel(void)
1214
+ {
1215
+ return (void*)filtered_lrelu_act_kernel<T, signWrite, signRead>;
1216
+ }
1217
+
1218
+ //------------------------------------------------------------------------
1219
+ // CUDA kernel selection.
1220
+
1221
+ template <class T, class index_t, bool signWrite, bool signRead> filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB)
1222
+ {
1223
+ filtered_lrelu_kernel_spec s = { 0 };
1224
+
1225
+ // Return the first matching kernel.
1226
+ #define CASE(SH, U, FU, D, FD, MODE, TW, TH, W, XR, WS) \
1227
+ if (sharedKB >= SH) \
1228
+ if ((p.fuShape.y == 0 && (MODE == MODE_SUSD || MODE == MODE_SUFD)) || (p.fuShape.y > 0 && (MODE == MODE_FUSD || MODE == MODE_FUFD))) \
1229
+ if ((p.fdShape.y == 0 && (MODE == MODE_SUSD || MODE == MODE_FUSD)) || (p.fdShape.y > 0 && (MODE == MODE_SUFD || MODE == MODE_FUFD))) \
1230
+ if (p.up == U && p.fuShape.x <= FU && p.fuShape.y <= FU && p.down == D && p.fdShape.x <= FD && p.fdShape.y <= FD) \
1231
+ { \
1232
+ static_assert((D*TW % 4) == 0, "down * tileWidth must be divisible by 4"); \
1233
+ static_assert(FU % U == 0, "upscaling filter size must be multiple of upscaling factor"); \
1234
+ static_assert(FD % D == 0, "downscaling filter size must be multiple of downscaling factor"); \
1235
+ s.setup = (void*)setup_filters_kernel; \
1236
+ s.exec = (void*)filtered_lrelu_kernel<T, index_t, SH, signWrite, signRead, MODE, U, FU, D, FD, TW, TH, W*32, !!XR, !!WS>; \
1237
+ s.tileOut = make_int2(TW, TH); \
1238
+ s.numWarps = W; \
1239
+ s.xrep = XR; \
1240
+ s.dynamicSharedKB = (SH == 48) ? 0 : SH; \
1241
+ return s; \
1242
+ }
1243
+
1244
+ // Launch parameters for various kernel specializations.
1245
+ // Small filters must be listed before large filters, otherwise the kernel for the larger filter will always match first.
1246
+ // Kernels that use more shared memory must be listed before those that use less, for the same reason.
1247
+
1248
+ CASE(/*sharedKB*/48, /*up,fu*/1,1, /*down,fd*/1,1, /*mode*/MODE_FUFD, /*tw,th,warps,xrep,wskip*/64, 178, 32, 0, 0) // 1t-upf1-downf1
1249
+ CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/1,1, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/152, 95, 16, 0, 0) // 4t-ups2-downf1
1250
+ CASE(/*sharedKB*/48, /*up,fu*/1,1, /*down,fd*/2,8, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/56, 22, 16, 0, 0) // 4t-upf1-downs2
1251
+ CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/2,8, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/56, 29, 16, 11, 0) // 4t-ups2-downs2
1252
+ CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/2,8, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/60, 28, 16, 0, 0) // 4t-upf2-downs2
1253
+ CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/2,8, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/56, 28, 16, 0, 0) // 4t-ups2-downf2
1254
+ CASE(/*sharedKB*/48, /*up,fu*/4,16, /*down,fd*/2,8, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/56, 31, 16, 11, 0) // 4t-ups4-downs2
1255
+ CASE(/*sharedKB*/48, /*up,fu*/4,16, /*down,fd*/2,8, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/56, 36, 16, 0, 0) // 4t-ups4-downf2
1256
+ CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/4,16, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/16, 22, 16, 12, 0) // 4t-ups2-downs4
1257
+ CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/4,16, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/29, 15, 16, 0, 0) // 4t-upf2-downs4
1258
+ CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/1,1, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/96, 150, 28, 0, 0) // 6t-ups2-downf1
1259
+ CASE(/*sharedKB*/48, /*up,fu*/1,1, /*down,fd*/2,12, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/32, 35, 24, 0, 0) // 6t-upf1-downs2
1260
+ CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/2,12, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/32, 46, 16, 10, 0) // 6t-ups2-downs2
1261
+ CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/2,12, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/58, 28, 24, 8, 0) // 6t-upf2-downs2
1262
+ CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/2,12, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/52, 28, 16, 0, 0) // 6t-ups2-downf2
1263
+ CASE(/*sharedKB*/48, /*up,fu*/4,24, /*down,fd*/2,12, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/32, 51, 16, 5, 0) // 6t-ups4-downs2
1264
+ CASE(/*sharedKB*/48, /*up,fu*/4,24, /*down,fd*/2,12, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/32, 56, 16, 6, 0) // 6t-ups4-downf2
1265
+ CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/4,24, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/16, 18, 16, 12, 0) // 6t-ups2-downs4
1266
+ CASE(/*sharedKB*/96, /*up,fu*/2,12, /*down,fd*/4,24, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/27, 31, 32, 6, 0) // 6t-upf2-downs4 96kB
1267
+ CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/4,24, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/27, 13, 24, 0, 0) // 6t-upf2-downs4
1268
+ CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/1,1, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/148, 89, 24, 0, 0) // 8t-ups2-downf1
1269
+ CASE(/*sharedKB*/48, /*up,fu*/1,1, /*down,fd*/2,16, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/32, 31, 16, 5, 0) // 8t-upf1-downs2
1270
+ CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/2,16, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/32, 41, 16, 9, 0) // 8t-ups2-downs2
1271
+ CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/2,16, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/56, 26, 24, 0, 0) // 8t-upf2-downs2
1272
+ CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/2,16, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/32, 40, 16, 0, 0) // 8t-ups2-downf2
1273
+ CASE(/*sharedKB*/48, /*up,fu*/4,32, /*down,fd*/2,16, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/32, 46, 24, 5, 0) // 8t-ups4-downs2
1274
+ CASE(/*sharedKB*/48, /*up,fu*/4,32, /*down,fd*/2,16, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/32, 50, 16, 0, 0) // 8t-ups4-downf2
1275
+ CASE(/*sharedKB*/96, /*up,fu*/2,16, /*down,fd*/4,32, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/24, 24, 32, 12, 1) // 8t-ups2-downs4 96kB
1276
+ CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/4,32, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/16, 13, 16, 10, 1) // 8t-ups2-downs4
1277
+ CASE(/*sharedKB*/96, /*up,fu*/2,16, /*down,fd*/4,32, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/25, 28, 28, 4, 0) // 8t-upf2-downs4 96kB
1278
+ CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/4,32, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/25, 10, 24, 0, 0) // 8t-upf2-downs4
1279
+
1280
+ #undef CASE
1281
+ return s; // No kernel found.
1282
+ }
1283
+
1284
+ //------------------------------------------------------------------------
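As the comment above the CASE table notes, the list is scanned top to bottom and the first matching specialization wins, which is why tighter filter bounds and larger shared-memory variants come first. A small Python sketch of that selection rule (illustration only; the table entries are a hypothetical subset of the real ones):

```python
CASES = [
    # (sharedKB, up, max_fu, down, max_fd, label)
    (48, 2,  8, 2,  8, "4t-ups2-downs2"),
    (48, 2, 12, 2, 12, "6t-ups2-downs2"),
    (48, 2, 16, 2, 16, "8t-ups2-downs2"),
]

def choose(shared_kb, up, fu, down, fd):
    for sh, u, max_fu, d, max_fd, label in CASES:
        if shared_kb >= sh and up == u and fu <= max_fu and down == d and fd <= max_fd:
            return label          # first match wins
    return None                   # no specialization found; caller falls back to the generic path

print(choose(99, 2, 8, 2, 8))     # -> "4t-ups2-downs2": the 4-tap kernel, not the more general 8-tap one
```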
torch_utils/ops/filtered_lrelu.h ADDED
@@ -0,0 +1,90 @@
1
+ // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ //
3
+ // NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ // and proprietary rights in and to this software, related documentation
5
+ // and any modifications thereto. Any use, reproduction, disclosure or
6
+ // distribution of this software and related documentation without an express
7
+ // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ #include <cuda_runtime.h>
10
+
11
+ //------------------------------------------------------------------------
12
+ // CUDA kernel parameters.
13
+
14
+ struct filtered_lrelu_kernel_params
15
+ {
16
+ // These parameters decide which kernel to use.
17
+ int up; // upsampling ratio (1, 2, 4)
18
+ int down; // downsampling ratio (1, 2, 4)
19
+ int2 fuShape; // [size, 1] | [size, size]
20
+ int2 fdShape; // [size, 1] | [size, size]
21
+
22
+ int _dummy; // Alignment.
23
+
24
+ // Rest of the parameters.
25
+ const void* x; // Input tensor.
26
+ void* y; // Output tensor.
27
+ const void* b; // Bias tensor.
28
+ unsigned char* s; // Sign tensor in/out. NULL if unused.
29
+ const float* fu; // Upsampling filter.
30
+ const float* fd; // Downsampling filter.
31
+
32
+ int2 pad0; // Left/top padding.
33
+ float gain; // Additional gain factor.
34
+ float slope; // Leaky ReLU slope on negative side.
35
+ float clamp; // Clamp after nonlinearity.
36
+ int flip; // Filter kernel flip for gradient computation.
37
+
38
+ int tilesXdim; // Original number of horizontal output tiles.
39
+ int tilesXrep; // Number of horizontal tiles per CTA.
40
+ int blockZofs; // Block z offset to support large minibatch, channel dimensions.
41
+
42
+ int4 xShape; // [width, height, channel, batch]
43
+ int4 yShape; // [width, height, channel, batch]
44
+ int2 sShape; // [width, height] - width is in bytes. Contiguous. Zeros if unused.
45
+ int2 sOfs; // [ofs_x, ofs_y] - offset between upsampled data and sign tensor.
46
+ int swLimit; // Active width of sign tensor in bytes.
47
+
48
+ longlong4 xStride; // Strides of all tensors except signs, same component order as shapes.
49
+ longlong4 yStride; //
50
+ int64_t bStride; //
51
+ longlong3 fuStride; //
52
+ longlong3 fdStride; //
53
+ };
54
+
55
+ struct filtered_lrelu_act_kernel_params
56
+ {
57
+ void* x; // Input/output, modified in-place.
58
+ unsigned char* s; // Sign tensor in/out. NULL if unused.
59
+
60
+ float gain; // Additional gain factor.
61
+ float slope; // Leaky ReLU slope on negative side.
62
+ float clamp; // Clamp after nonlinearity.
63
+
64
+ int4 xShape; // [width, height, channel, batch]
65
+ longlong4 xStride; // Input/output tensor strides, same order as in shape.
66
+ int2 sShape; // [width, height] - width is in elements. Contiguous. Zeros if unused.
67
+ int2 sOfs; // [ofs_x, ofs_y] - offset between upsampled data and sign tensor.
68
+ };
69
+
70
+ //------------------------------------------------------------------------
71
+ // CUDA kernel specialization.
72
+
73
+ struct filtered_lrelu_kernel_spec
74
+ {
75
+ void* setup; // Function for filter kernel setup.
76
+ void* exec; // Function for main operation.
77
+ int2 tileOut; // Width/height of launch tile.
78
+ int numWarps; // Number of warps per thread block, determines launch block size.
79
+ int xrep; // For processing multiple horizontal tiles per thread block.
80
+ int dynamicSharedKB; // How much dynamic shared memory the exec kernel wants.
81
+ };
82
+
83
+ //------------------------------------------------------------------------
84
+ // CUDA kernel selection.
85
+
86
+ template <class T, class index_t, bool signWrite, bool signRead> filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB);
87
+ template <class T, bool signWrite, bool signRead> void* choose_filtered_lrelu_act_kernel(void);
88
+ template <bool signWrite, bool signRead> cudaError_t copy_filters(cudaStream_t stream);
89
+
90
+ //------------------------------------------------------------------------
torch_utils/ops/filtered_lrelu.py ADDED
@@ -0,0 +1,274 @@
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ import os
10
+ import numpy as np
11
+ import torch
12
+ import warnings
13
+
14
+ from .. import custom_ops
15
+ from .. import misc
16
+ from . import upfirdn2d
17
+ from . import bias_act
18
+
19
+ #----------------------------------------------------------------------------
20
+
21
+ _plugin = None
22
+
23
+ def _init():
24
+ global _plugin
25
+ if _plugin is None:
26
+ _plugin = custom_ops.get_plugin(
27
+ module_name='filtered_lrelu_plugin',
28
+ sources=['filtered_lrelu.cpp', 'filtered_lrelu_wr.cu', 'filtered_lrelu_rd.cu', 'filtered_lrelu_ns.cu'],
29
+ headers=['filtered_lrelu.h', 'filtered_lrelu.cu'],
30
+ source_dir=os.path.dirname(__file__),
31
+ extra_cuda_cflags=['--use_fast_math'],
32
+ )
33
+ return True
34
+
35
+ def _get_filter_size(f):
36
+ if f is None:
37
+ return 1, 1
38
+ assert isinstance(f, torch.Tensor)
39
+ assert 1 <= f.ndim <= 2
40
+ return f.shape[-1], f.shape[0] # width, height
41
+
42
+ def _parse_padding(padding):
43
+ if isinstance(padding, int):
44
+ padding = [padding, padding]
45
+ assert isinstance(padding, (list, tuple))
46
+ assert all(isinstance(x, (int, np.integer)) for x in padding)
47
+ padding = [int(x) for x in padding]
48
+ if len(padding) == 2:
49
+ px, py = padding
50
+ padding = [px, px, py, py]
51
+ px0, px1, py0, py1 = padding
52
+ return px0, px1, py0, py1
53
+
54
+ #----------------------------------------------------------------------------
55
+
56
+ def filtered_lrelu(x, fu=None, fd=None, b=None, up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False, impl='cuda'):
57
+ r"""Filtered leaky ReLU for a batch of 2D images.
58
+
59
+ Performs the following sequence of operations for each channel:
60
+
61
+ 1. Add channel-specific bias if provided (`b`).
62
+
63
+ 2. Upsample the image by inserting N-1 zeros after each pixel (`up`).
64
+
65
+ 3. Pad the image with the specified number of zeros on each side (`padding`).
66
+ Negative padding corresponds to cropping the image.
67
+
68
+ 4. Convolve the image with the specified upsampling FIR filter (`fu`), shrinking it
69
+ so that the footprint of all output pixels lies within the input image.
70
+
71
+ 5. Multiply each value by the provided gain factor (`gain`).
72
+
73
+ 6. Apply leaky ReLU activation function to each value.
74
+
75
+ 7. Clamp each value between -clamp and +clamp, if `clamp` parameter is provided.
76
+
77
+ 8. Convolve the image with the specified downsampling FIR filter (`fd`), shrinking
78
+ it so that the footprint of all output pixels lies within the input image.
79
+
80
+ 9. Downsample the image by keeping every Nth pixel (`down`).
81
+
82
+ The fused op is considerably more efficient than performing the same calculation
83
+ using standard PyTorch ops. It supports gradients of arbitrary order.
84
+
85
+ Args:
86
+ x: Float32/float16/float64 input tensor of the shape
87
+ `[batch_size, num_channels, in_height, in_width]`.
88
+ fu: Float32 upsampling FIR filter of the shape
89
+ `[filter_height, filter_width]` (non-separable),
90
+ `[filter_taps]` (separable), or
91
+ `None` (identity).
92
+ fd: Float32 downsampling FIR filter of the shape
93
+ `[filter_height, filter_width]` (non-separable),
94
+ `[filter_taps]` (separable), or
95
+ `None` (identity).
96
+ b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type
97
+ as `x`. The length of the vector must match the channel dimension of `x`.
98
+ up: Integer upsampling factor (default: 1).
99
+ down: Integer downsampling factor (default: 1).
100
+ padding: Padding with respect to the upsampled image. Can be a single number
101
+ or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
102
+ (default: 0).
103
+ gain: Overall scaling factor for signal magnitude (default: sqrt(2)).
104
+ slope: Slope on the negative side of leaky ReLU (default: 0.2).
105
+ clamp: Maximum magnitude for leaky ReLU output (default: None).
106
+ flip_filter: False = convolution, True = correlation (default: False).
107
+ impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
108
+
109
+ Returns:
110
+ Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
111
+ """
112
+ assert isinstance(x, torch.Tensor)
113
+ assert impl in ['ref', 'cuda']
114
+ if impl == 'cuda' and x.device.type == 'cuda' and _init():
115
+ return _filtered_lrelu_cuda(up=up, down=down, padding=padding, gain=gain, slope=slope, clamp=clamp, flip_filter=flip_filter).apply(x, fu, fd, b, None, 0, 0)
116
+ return _filtered_lrelu_ref(x, fu=fu, fd=fd, b=b, up=up, down=down, padding=padding, gain=gain, slope=slope, clamp=clamp, flip_filter=flip_filter)
117
+
118
+ #----------------------------------------------------------------------------
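For orientation, a minimal usage sketch of `filtered_lrelu()` (illustrative, not part of this commit): it takes the reference path via `impl='ref'` so no CUDA plugin is needed, and the 4-tap filters, shapes, and padding below are assumptions chosen so the output keeps the input resolution.

import numpy as np
import torch
from torch_utils.ops.filtered_lrelu import filtered_lrelu

x = torch.randn(2, 8, 16, 16)               # [batch, channels, height, width]
fu = torch.tensor([1., 3., 3., 1.]) / 8.0   # separable upsampling filter (assumed design)
fd = torch.tensor([1., 3., 3., 1.]) / 8.0   # separable downsampling filter (assumed design)
b = torch.zeros(8)                          # per-channel bias, same dtype as x

# 2x up, 2x down, padding 3 per side; with 4-tap filters this preserves H and W.
y = filtered_lrelu(x, fu=fu, fd=fd, b=b, up=2, down=2, padding=3,
                   gain=np.sqrt(2), slope=0.2, impl='ref')
print(y.shape)  # torch.Size([2, 8, 16, 16])

On a CUDA tensor with the default `impl='cuda'`, the same call routes to the compiled plugin once `_init()` succeeds.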
119
+
120
+ @misc.profiled_function
121
+ def _filtered_lrelu_ref(x, fu=None, fd=None, b=None, up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False):
122
+ """Slow and memory-inefficient reference implementation of `filtered_lrelu()` using
123
+ existing `upfirdn2d()` and `bias_act()` ops.
124
+ """
125
+ assert isinstance(x, torch.Tensor) and x.ndim == 4
126
+ fu_w, fu_h = _get_filter_size(fu)
127
+ fd_w, fd_h = _get_filter_size(fd)
128
+ if b is not None:
129
+ assert isinstance(b, torch.Tensor) and b.dtype == x.dtype
130
+ misc.assert_shape(b, [x.shape[1]])
131
+ assert isinstance(up, int) and up >= 1
132
+ assert isinstance(down, int) and down >= 1
133
+ px0, px1, py0, py1 = _parse_padding(padding)
134
+ assert gain == float(gain) and gain > 0
135
+ assert slope == float(slope) and slope >= 0
136
+ assert clamp is None or (clamp == float(clamp) and clamp >= 0)
137
+
138
+ # Calculate output size.
139
+ batch_size, channels, in_h, in_w = x.shape
140
+ in_dtype = x.dtype
141
+ out_w = (in_w * up + (px0 + px1) - (fu_w - 1) - (fd_w - 1) + (down - 1)) // down
142
+ out_h = (in_h * up + (py0 + py1) - (fu_h - 1) - (fd_h - 1) + (down - 1)) // down
143
+
144
+ # Compute using existing ops.
145
+ x = bias_act.bias_act(x=x, b=b) # Apply bias.
146
+ x = upfirdn2d.upfirdn2d(x=x, f=fu, up=up, padding=[px0, px1, py0, py1], gain=up**2, flip_filter=flip_filter) # Upsample.
147
+ x = bias_act.bias_act(x=x, act='lrelu', alpha=slope, gain=gain, clamp=clamp) # Bias, leaky ReLU, clamp.
148
+ x = upfirdn2d.upfirdn2d(x=x, f=fd, down=down, flip_filter=flip_filter) # Downsample.
149
+
150
+ # Check output shape & dtype.
151
+ misc.assert_shape(x, [batch_size, channels, out_h, out_w])
152
+ assert x.dtype == in_dtype
153
+ return x
154
+
155
+ #----------------------------------------------------------------------------
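The output-size arithmetic above can be checked by hand; a small sketch with illustrative numbers (the same choices as the usage example earlier):

# Spot check of the out_w formula in _filtered_lrelu_ref(); out_h is analogous.
in_w, up, down = 16, 2, 2
px0 = px1 = 3            # padding applied on the upsampled grid
fu_w = fd_w = 4          # upsampling / downsampling filter widths
out_w = (in_w * up + (px0 + px1) - (fu_w - 1) - (fd_w - 1) + (down - 1)) // down
assert out_w == 16       # 16 -> 32 (up) -> 38 (pad) -> 35 (fu) -> 32 (fd) -> 16 (down)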
156
+
157
+ _filtered_lrelu_cuda_cache = dict()
158
+
159
+ def _filtered_lrelu_cuda(up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False):
160
+ """Fast CUDA implementation of `filtered_lrelu()` using custom ops.
161
+ """
162
+ assert isinstance(up, int) and up >= 1
163
+ assert isinstance(down, int) and down >= 1
164
+ px0, px1, py0, py1 = _parse_padding(padding)
165
+ assert gain == float(gain) and gain > 0
166
+ gain = float(gain)
167
+ assert slope == float(slope) and slope >= 0
168
+ slope = float(slope)
169
+ assert clamp is None or (clamp == float(clamp) and clamp >= 0)
170
+ clamp = float(clamp if clamp is not None else 'inf')
171
+
172
+ # Lookup from cache.
173
+ key = (up, down, px0, px1, py0, py1, gain, slope, clamp, flip_filter)
174
+ if key in _filtered_lrelu_cuda_cache:
175
+ return _filtered_lrelu_cuda_cache[key]
176
+
177
+ # Forward op.
178
+ class FilteredLReluCuda(torch.autograd.Function):
179
+ @staticmethod
180
+ def forward(ctx, x, fu, fd, b, si, sx, sy): # pylint: disable=arguments-differ
181
+ assert isinstance(x, torch.Tensor) and x.ndim == 4
182
+
183
+ # Replace empty up/downsample kernels with full 1x1 kernels (faster than separable).
184
+ if fu is None:
185
+ fu = torch.ones([1, 1], dtype=torch.float32, device=x.device)
186
+ if fd is None:
187
+ fd = torch.ones([1, 1], dtype=torch.float32, device=x.device)
188
+ assert 1 <= fu.ndim <= 2
189
+ assert 1 <= fd.ndim <= 2
190
+
191
+ # Replace separable 1x1 kernels with full 1x1 kernels when scale factor is 1.
192
+ if up == 1 and fu.ndim == 1 and fu.shape[0] == 1:
193
+ fu = fu.square()[None]
194
+ if down == 1 and fd.ndim == 1 and fd.shape[0] == 1:
195
+ fd = fd.square()[None]
196
+
197
+ # Missing sign input tensor.
198
+ if si is None:
199
+ si = torch.empty([0])
200
+
201
+ # Missing bias tensor.
202
+ if b is None:
203
+ b = torch.zeros([x.shape[1]], dtype=x.dtype, device=x.device)
204
+
205
+ # Construct internal sign tensor only if gradients are needed.
206
+ write_signs = (si.numel() == 0) and (x.requires_grad or b.requires_grad)
207
+
208
+ # Warn if input storage strides are not in decreasing order due to e.g. channels-last layout.
209
+ strides = [x.stride(i) for i in range(x.ndim) if x.size(i) > 1]
210
+ if any(a < b for a, b in zip(strides[:-1], strides[1:])):
211
+ warnings.warn("low-performance memory layout detected in filtered_lrelu input", RuntimeWarning)
212
+
213
+ # Call C++/Cuda plugin if datatype is supported.
214
+ if x.dtype in [torch.float16, torch.float32]:
215
+ if torch.cuda.current_stream(x.device) != torch.cuda.default_stream(x.device):
216
+ warnings.warn("filtered_lrelu called with non-default cuda stream but concurrent execution is not supported", RuntimeWarning)
217
+ y, so, return_code = _plugin.filtered_lrelu(x, fu, fd, b, si, up, down, px0, px1, py0, py1, sx, sy, gain, slope, clamp, flip_filter, write_signs)
218
+ else:
219
+ return_code = -1
220
+
221
+ # No Cuda kernel found? Fall back to generic implementation. Still more memory efficient than the reference implementation because
222
+ # only the bit-packed sign tensor is retained for gradient computation.
223
+ if return_code < 0:
224
+ warnings.warn("filtered_lrelu called with parameters that have no optimized CUDA kernel, using generic fallback", RuntimeWarning)
225
+
226
+ y = x.add(b.unsqueeze(-1).unsqueeze(-1)) # Add bias.
227
+ y = upfirdn2d.upfirdn2d(x=y, f=fu, up=up, padding=[px0, px1, py0, py1], gain=up**2, flip_filter=flip_filter) # Upsample.
228
+ so = _plugin.filtered_lrelu_act_(y, si, sx, sy, gain, slope, clamp, write_signs) # Activation function and sign handling. Modifies y in-place.
229
+ y = upfirdn2d.upfirdn2d(x=y, f=fd, down=down, flip_filter=flip_filter) # Downsample.
230
+
231
+ # Prepare for gradient computation.
232
+ ctx.save_for_backward(fu, fd, (si if si.numel() else so))
233
+ ctx.x_shape = x.shape
234
+ ctx.y_shape = y.shape
235
+ ctx.s_ofs = sx, sy
236
+ return y
237
+
238
+ @staticmethod
239
+ def backward(ctx, dy): # pylint: disable=arguments-differ
240
+ fu, fd, si = ctx.saved_tensors
241
+ _, _, xh, xw = ctx.x_shape
242
+ _, _, yh, yw = ctx.y_shape
243
+ sx, sy = ctx.s_ofs
244
+ dx = None # 0
245
+ dfu = None; assert not ctx.needs_input_grad[1]
246
+ dfd = None; assert not ctx.needs_input_grad[2]
247
+ db = None # 3
248
+ dsi = None; assert not ctx.needs_input_grad[4]
249
+ dsx = None; assert not ctx.needs_input_grad[5]
250
+ dsy = None; assert not ctx.needs_input_grad[6]
251
+
252
+ if ctx.needs_input_grad[0] or ctx.needs_input_grad[3]:
253
+ pp = [
254
+ (fu.shape[-1] - 1) + (fd.shape[-1] - 1) - px0,
255
+ xw * up - yw * down + px0 - (up - 1),
256
+ (fu.shape[0] - 1) + (fd.shape[0] - 1) - py0,
257
+ xh * up - yh * down + py0 - (up - 1),
258
+ ]
259
+ gg = gain * (up ** 2) / (down ** 2)
260
+ ff = (not flip_filter)
261
+ sx = sx - (fu.shape[-1] - 1) + px0
262
+ sy = sy - (fu.shape[0] - 1) + py0
263
+ dx = _filtered_lrelu_cuda(up=down, down=up, padding=pp, gain=gg, slope=slope, clamp=None, flip_filter=ff).apply(dy, fd, fu, None, si, sx, sy)
264
+
265
+ if ctx.needs_input_grad[3]:
266
+ db = dx.sum([0, 2, 3])
267
+
268
+ return dx, dfu, dfd, db, dsi, dsx, dsy
269
+
270
+ # Add to cache.
271
+ _filtered_lrelu_cuda_cache[key] = FilteredLReluCuda
272
+ return FilteredLReluCuda
273
+
274
+ #----------------------------------------------------------------------------
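One design note worth surfacing: `_filtered_lrelu_cuda()` memoizes one `autograd.Function` subclass per parameter combination, so repeated calls with identical settings reuse the same class instead of rebuilding it. A tiny sketch (illustrative):

from torch_utils.ops.filtered_lrelu import _filtered_lrelu_cuda

op_a = _filtered_lrelu_cuda(up=2, down=2, padding=3)
op_b = _filtered_lrelu_cuda(up=2, down=2, padding=3)
assert op_a is op_b   # both served from _filtered_lrelu_cuda_cache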
torch_utils/ops/filtered_lrelu_ns.cu ADDED
@@ -0,0 +1,27 @@
1
+ // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ //
3
+ // NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ // and proprietary rights in and to this software, related documentation
5
+ // and any modifications thereto. Any use, reproduction, disclosure or
6
+ // distribution of this software and related documentation without an express
7
+ // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ #include "filtered_lrelu.cu"
10
+
11
+ // Template/kernel specializations for no signs mode (no gradients required).
12
+
13
+ // Full op, 32-bit indexing.
14
+ template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int32_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
15
+ template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int32_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
16
+
17
+ // Full op, 64-bit indexing.
18
+ template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int64_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
19
+ template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int64_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
20
+
21
+ // Activation/signs only for generic variant. 64-bit indexing.
22
+ template void* choose_filtered_lrelu_act_kernel<c10::Half, false, false>(void);
23
+ template void* choose_filtered_lrelu_act_kernel<float, false, false>(void);
24
+ template void* choose_filtered_lrelu_act_kernel<double, false, false>(void);
25
+
26
+ // Copy filters to constant memory.
27
+ template cudaError_t copy_filters<false, false>(cudaStream_t stream);
torch_utils/ops/filtered_lrelu_rd.cu ADDED
@@ -0,0 +1,27 @@
1
+ // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ //
3
+ // NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ // and proprietary rights in and to this software, related documentation
5
+ // and any modifications thereto. Any use, reproduction, disclosure or
6
+ // distribution of this software and related documentation without an express
7
+ // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ #include "filtered_lrelu.cu"
10
+
11
+ // Template/kernel specializations for sign read mode.
12
+
13
+ // Full op, 32-bit indexing.
14
+ template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int32_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
15
+ template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int32_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
16
+
17
+ // Full op, 64-bit indexing.
18
+ template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int64_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
19
+ template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int64_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
20
+
21
+ // Activation/signs only for generic variant. 64-bit indexing.
22
+ template void* choose_filtered_lrelu_act_kernel<c10::Half, false, true>(void);
23
+ template void* choose_filtered_lrelu_act_kernel<float, false, true>(void);
24
+ template void* choose_filtered_lrelu_act_kernel<double, false, true>(void);
25
+
26
+ // Copy filters to constant memory.
27
+ template cudaError_t copy_filters<false, true>(cudaStream_t stream);
torch_utils/ops/filtered_lrelu_wr.cu ADDED
@@ -0,0 +1,27 @@
1
+ // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ //
3
+ // NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ // and proprietary rights in and to this software, related documentation
5
+ // and any modifications thereto. Any use, reproduction, disclosure or
6
+ // distribution of this software and related documentation without an express
7
+ // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ #include "filtered_lrelu.cu"
10
+
11
+ // Template/kernel specializations for sign write mode.
12
+
13
+ // Full op, 32-bit indexing.
14
+ template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int32_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
15
+ template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int32_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
16
+
17
+ // Full op, 64-bit indexing.
18
+ template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int64_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
19
+ template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int64_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
20
+
21
+ // Activation/signs only for generic variant. 64-bit indexing.
22
+ template void* choose_filtered_lrelu_act_kernel<c10::Half, true, false>(void);
23
+ template void* choose_filtered_lrelu_act_kernel<float, true, false>(void);
24
+ template void* choose_filtered_lrelu_act_kernel<double, true, false>(void);
25
+
26
+ // Copy filters to constant memory.
27
+ template cudaError_t copy_filters<true, false>(cudaStream_t stream);
torch_utils/ops/fma.py ADDED
@@ -0,0 +1,60 @@
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Fused multiply-add, with slightly faster gradients than `torch.addcmul()`."""
10
+
11
+ import torch
12
+
13
+ #----------------------------------------------------------------------------
14
+
15
+ def fma(a, b, c): # => a * b + c
16
+ return _FusedMultiplyAdd.apply(a, b, c)
17
+
18
+ #----------------------------------------------------------------------------
19
+
20
+ class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c
21
+ @staticmethod
22
+ def forward(ctx, a, b, c): # pylint: disable=arguments-differ
23
+ out = torch.addcmul(c, a, b)
24
+ ctx.save_for_backward(a, b)
25
+ ctx.c_shape = c.shape
26
+ return out
27
+
28
+ @staticmethod
29
+ def backward(ctx, dout): # pylint: disable=arguments-differ
30
+ a, b = ctx.saved_tensors
31
+ c_shape = ctx.c_shape
32
+ da = None
33
+ db = None
34
+ dc = None
35
+
36
+ if ctx.needs_input_grad[0]:
37
+ da = _unbroadcast(dout * b, a.shape)
38
+
39
+ if ctx.needs_input_grad[1]:
40
+ db = _unbroadcast(dout * a, b.shape)
41
+
42
+ if ctx.needs_input_grad[2]:
43
+ dc = _unbroadcast(dout, c_shape)
44
+
45
+ return da, db, dc
46
+
47
+ #----------------------------------------------------------------------------
48
+
49
+ def _unbroadcast(x, shape):
50
+ extra_dims = x.ndim - len(shape)
51
+ assert extra_dims >= 0
52
+ dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)]
53
+ if len(dim):
54
+ x = x.sum(dim=dim, keepdim=True)
55
+ if extra_dims:
56
+ x = x.reshape(-1, *x.shape[extra_dims+1:])
57
+ assert x.shape == shape
58
+ return x
59
+
60
+ #----------------------------------------------------------------------------
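A minimal sketch of `fma()` in use (illustrative shapes; the broadcasting of `c` is an example, not a requirement): the forward matches `a * b + c`, and `_unbroadcast()` folds the gradient of `c` back to its original shape.

import torch
from torch_utils.ops.fma import fma

a = torch.randn(3, 4, requires_grad=True)
b = torch.randn(3, 4)
c = torch.randn(3, 1, requires_grad=True)   # broadcasts over the last dimension

out = fma(a, b, c)                          # same result as torch.addcmul(c, a, b)
assert torch.allclose(out, a * b + c)

out.sum().backward()
assert c.grad.shape == c.shape              # gradient un-broadcast back to [3, 1]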
torch_utils/ops/grid_sample_gradfix.py ADDED
@@ -0,0 +1,77 @@
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Custom replacement for `torch.nn.functional.grid_sample` that
10
+ supports arbitrarily high order gradients between the input and output.
11
+ Only works on 2D images and assumes
12
+ `mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`."""
13
+
14
+ import torch
15
+
16
+ # pylint: disable=redefined-builtin
17
+ # pylint: disable=arguments-differ
18
+ # pylint: disable=protected-access
19
+
20
+ #----------------------------------------------------------------------------
21
+
22
+ enabled = False # Enable the custom op by setting this to true.
23
+
24
+ #----------------------------------------------------------------------------
25
+
26
+ def grid_sample(input, grid):
27
+ if _should_use_custom_op():
28
+ return _GridSample2dForward.apply(input, grid)
29
+ return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
30
+
31
+ #----------------------------------------------------------------------------
32
+
33
+ def _should_use_custom_op():
34
+ return enabled
35
+
36
+ #----------------------------------------------------------------------------
37
+
38
+ class _GridSample2dForward(torch.autograd.Function):
39
+ @staticmethod
40
+ def forward(ctx, input, grid):
41
+ assert input.ndim == 4
42
+ assert grid.ndim == 4
43
+ output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
44
+ ctx.save_for_backward(input, grid)
45
+ return output
46
+
47
+ @staticmethod
48
+ def backward(ctx, grad_output):
49
+ input, grid = ctx.saved_tensors
50
+ grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid)
51
+ return grad_input, grad_grid
52
+
53
+ #----------------------------------------------------------------------------
54
+
55
+ class _GridSample2dBackward(torch.autograd.Function):
56
+ @staticmethod
57
+ def forward(ctx, grad_output, input, grid):
58
+ op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward')
59
+ grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False)
60
+ ctx.save_for_backward(grid)
61
+ return grad_input, grad_grid
62
+
63
+ @staticmethod
64
+ def backward(ctx, grad2_grad_input, grad2_grad_grid):
65
+ _ = grad2_grad_grid # unused
66
+ grid, = ctx.saved_tensors
67
+ grad2_grad_output = None
68
+ grad2_input = None
69
+ grad2_grid = None
70
+
71
+ if ctx.needs_input_grad[0]:
72
+ grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid)
73
+
74
+ assert not ctx.needs_input_grad[2]
75
+ return grad2_grad_output, grad2_input, grad2_grid
76
+
77
+ #----------------------------------------------------------------------------
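A minimal usage sketch (illustrative shapes), assuming a PyTorch version whose `aten::grid_sampler_2d_backward` signature matches the call in `_GridSample2dBackward`:

import torch
from torch_utils.ops import grid_sample_gradfix

grid_sample_gradfix.enabled = True           # module-level switch, off by default

input = torch.randn(1, 3, 8, 8, requires_grad=True)
grid = torch.rand(1, 4, 4, 2) * 2 - 1        # normalized sampling coords in [-1, 1]

out = grid_sample_gradfix.grid_sample(input, grid)   # fixed bilinear/zeros/align_corners=False
out.sum().backward()                                 # out shape [1, 3, 4, 4], grads reach input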
torch_utils/ops/upfirdn2d.cpp ADDED
@@ -0,0 +1,107 @@
1
+ // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ //
3
+ // NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ // and proprietary rights in and to this software, related documentation
5
+ // and any modifications thereto. Any use, reproduction, disclosure or
6
+ // distribution of this software and related documentation without an express
7
+ // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ #include <torch/extension.h>
10
+ #include <ATen/cuda/CUDAContext.h>
11
+ #include <c10/cuda/CUDAGuard.h>
12
+ #include "upfirdn2d.h"
13
+
14
+ //------------------------------------------------------------------------
15
+
16
+ static torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain)
17
+ {
18
+ // Validate arguments.
19
+ TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
20
+ TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x");
21
+ TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32");
22
+ TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
23
+ TORCH_CHECK(f.numel() <= INT_MAX, "f is too large");
24
+ TORCH_CHECK(x.numel() > 0, "x has zero size");
25
+ TORCH_CHECK(f.numel() > 0, "f has zero size");
26
+ TORCH_CHECK(x.dim() == 4, "x must be rank 4");
27
+ TORCH_CHECK(f.dim() == 2, "f must be rank 2");
28
+ TORCH_CHECK((x.size(0)-1)*x.stride(0) + (x.size(1)-1)*x.stride(1) + (x.size(2)-1)*x.stride(2) + (x.size(3)-1)*x.stride(3) <= INT_MAX, "x memory footprint is too large");
29
+ TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1");
30
+ TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1");
31
+ TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1");
32
+
33
+ // Create output tensor.
34
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
35
+ int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx;
36
+ int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy;
37
+ TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1");
38
+ torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format());
39
+ TORCH_CHECK(y.numel() <= INT_MAX, "output is too large");
40
+ TORCH_CHECK((y.size(0)-1)*y.stride(0) + (y.size(1)-1)*y.stride(1) + (y.size(2)-1)*y.stride(2) + (y.size(3)-1)*y.stride(3) <= INT_MAX, "output memory footprint is too large");
41
+
42
+ // Initialize CUDA kernel parameters.
43
+ upfirdn2d_kernel_params p;
44
+ p.x = x.data_ptr();
45
+ p.f = f.data_ptr<float>();
46
+ p.y = y.data_ptr();
47
+ p.up = make_int2(upx, upy);
48
+ p.down = make_int2(downx, downy);
49
+ p.pad0 = make_int2(padx0, pady0);
50
+ p.flip = (flip) ? 1 : 0;
51
+ p.gain = gain;
52
+ p.inSize = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
53
+ p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0));
54
+ p.filterSize = make_int2((int)f.size(1), (int)f.size(0));
55
+ p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0));
56
+ p.outSize = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0));
57
+ p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0));
58
+ p.sizeMajor = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z;
59
+ p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1;
60
+
61
+ // Choose CUDA kernel.
62
+ upfirdn2d_kernel_spec spec;
63
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
64
+ {
65
+ spec = choose_upfirdn2d_kernel<scalar_t>(p);
66
+ });
67
+
68
+ // Set looping options.
69
+ p.loopMajor = (p.sizeMajor - 1) / 16384 + 1;
70
+ p.loopMinor = spec.loopMinor;
71
+ p.loopX = spec.loopX;
72
+ p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1;
73
+ p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1;
74
+
75
+ // Compute grid size.
76
+ dim3 blockSize, gridSize;
77
+ if (spec.tileOutW < 0) // large
78
+ {
79
+ blockSize = dim3(4, 32, 1);
80
+ gridSize = dim3(
81
+ ((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor,
82
+ (p.outSize.x - 1) / (blockSize.y * p.loopX) + 1,
83
+ p.launchMajor);
84
+ }
85
+ else // small
86
+ {
87
+ blockSize = dim3(256, 1, 1);
88
+ gridSize = dim3(
89
+ ((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor,
90
+ (p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1,
91
+ p.launchMajor);
92
+ }
93
+
94
+ // Launch CUDA kernel.
95
+ void* args[] = {&p};
96
+ AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
97
+ return y;
98
+ }
99
+
100
+ //------------------------------------------------------------------------
101
+
102
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
103
+ {
104
+ m.def("upfirdn2d", &upfirdn2d);
105
+ }
106
+
107
+ //------------------------------------------------------------------------
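The output-size rule in the binding above can be checked numerically; illustrative values only:

# Spot check of the outW computation in upfirdn2d():
# 2x upsample, pad 1 on both sides, then a valid 4-tap convolution.
in_w, upx, downx = 16, 2, 1
padx0 = padx1 = 1
f_w = 4
out_w = (in_w * upx + padx0 + padx1 - f_w + downx) // downx
assert out_w == 31   # 16 -> 32 (zero-stuffed) -> 34 (padded) -> 31 (valid conv)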
torch_utils/ops/upfirdn2d.cu ADDED
@@ -0,0 +1,384 @@
1
+ // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ //
3
+ // NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ // and proprietary rights in and to this software, related documentation
5
+ // and any modifications thereto. Any use, reproduction, disclosure or
6
+ // distribution of this software and related documentation without an express
7
+ // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ #include <c10/util/Half.h>
10
+ #include "upfirdn2d.h"
11
+
12
+ //------------------------------------------------------------------------
13
+ // Helpers.
14
+
15
+ template <class T> struct InternalType;
16
+ template <> struct InternalType<double> { typedef double scalar_t; };
17
+ template <> struct InternalType<float> { typedef float scalar_t; };
18
+ template <> struct InternalType<c10::Half> { typedef float scalar_t; };
19
+
20
+ static __device__ __forceinline__ int floor_div(int a, int b)
21
+ {
22
+ int t = 1 - a / b;
23
+ return (a + t * b) / b - t;
24
+ }
25
+
26
+ //------------------------------------------------------------------------
27
+ // Generic CUDA implementation for large filters.
28
+
29
+ template <class T> static __global__ void upfirdn2d_kernel_large(upfirdn2d_kernel_params p)
30
+ {
31
+ typedef typename InternalType<T>::scalar_t scalar_t;
32
+
33
+ // Calculate thread index.
34
+ int minorBase = blockIdx.x * blockDim.x + threadIdx.x;
35
+ int outY = minorBase / p.launchMinor;
36
+ minorBase -= outY * p.launchMinor;
37
+ int outXBase = blockIdx.y * p.loopX * blockDim.y + threadIdx.y;
38
+ int majorBase = blockIdx.z * p.loopMajor;
39
+ if (outXBase >= p.outSize.x | outY >= p.outSize.y | majorBase >= p.sizeMajor)
40
+ return;
41
+
42
+ // Setup Y receptive field.
43
+ int midY = outY * p.down.y + p.up.y - 1 - p.pad0.y;
44
+ int inY = min(max(floor_div(midY, p.up.y), 0), p.inSize.y);
45
+ int h = min(max(floor_div(midY + p.filterSize.y, p.up.y), 0), p.inSize.y) - inY;
46
+ int filterY = midY + p.filterSize.y - (inY + 1) * p.up.y;
47
+ if (p.flip)
48
+ filterY = p.filterSize.y - 1 - filterY;
49
+
50
+ // Loop over major, minor, and X.
51
+ for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++)
52
+ for (int minorIdx = 0, minor = minorBase; minorIdx < p.loopMinor & minor < p.sizeMinor; minorIdx++, minor += p.launchMinor)
53
+ {
54
+ int nc = major * p.sizeMinor + minor;
55
+ int n = nc / p.inSize.z;
56
+ int c = nc - n * p.inSize.z;
57
+ for (int loopX = 0, outX = outXBase; loopX < p.loopX & outX < p.outSize.x; loopX++, outX += blockDim.y)
58
+ {
59
+ // Setup X receptive field.
60
+ int midX = outX * p.down.x + p.up.x - 1 - p.pad0.x;
61
+ int inX = min(max(floor_div(midX, p.up.x), 0), p.inSize.x);
62
+ int w = min(max(floor_div(midX + p.filterSize.x, p.up.x), 0), p.inSize.x) - inX;
63
+ int filterX = midX + p.filterSize.x - (inX + 1) * p.up.x;
64
+ if (p.flip)
65
+ filterX = p.filterSize.x - 1 - filterX;
66
+
67
+ // Initialize pointers.
68
+ const T* xp = &((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w];
69
+ const float* fp = &p.f[filterX * p.filterStride.x + filterY * p.filterStride.y];
70
+ int filterStepX = ((p.flip) ? p.up.x : -p.up.x) * p.filterStride.x;
71
+ int filterStepY = ((p.flip) ? p.up.y : -p.up.y) * p.filterStride.y;
72
+
73
+ // Inner loop.
74
+ scalar_t v = 0;
75
+ for (int y = 0; y < h; y++)
76
+ {
77
+ for (int x = 0; x < w; x++)
78
+ {
79
+ v += (scalar_t)(*xp) * (scalar_t)(*fp);
80
+ xp += p.inStride.x;
81
+ fp += filterStepX;
82
+ }
83
+ xp += p.inStride.y - w * p.inStride.x;
84
+ fp += filterStepY - w * filterStepX;
85
+ }
86
+
87
+ // Store result.
88
+ v *= p.gain;
89
+ ((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v;
90
+ }
91
+ }
92
+ }
93
+
94
+ //------------------------------------------------------------------------
95
+ // Specialized CUDA implementation for small filters.
96
+
97
+ template <class T, int upx, int upy, int downx, int downy, int filterW, int filterH, int tileOutW, int tileOutH, int loopMinor>
98
+ static __global__ void upfirdn2d_kernel_small(upfirdn2d_kernel_params p)
99
+ {
100
+ typedef typename InternalType<T>::scalar_t scalar_t;
101
+ const int tileInW = ((tileOutW - 1) * downx + filterW - 1) / upx + 1;
102
+ const int tileInH = ((tileOutH - 1) * downy + filterH - 1) / upy + 1;
103
+ __shared__ volatile scalar_t sf[filterH][filterW];
104
+ __shared__ volatile scalar_t sx[tileInH][tileInW][loopMinor];
105
+
106
+ // Calculate tile index.
107
+ int minorBase = blockIdx.x;
108
+ int tileOutY = minorBase / p.launchMinor;
109
+ minorBase -= tileOutY * p.launchMinor;
110
+ minorBase *= loopMinor;
111
+ tileOutY *= tileOutH;
112
+ int tileOutXBase = blockIdx.y * p.loopX * tileOutW;
113
+ int majorBase = blockIdx.z * p.loopMajor;
114
+ if (tileOutXBase >= p.outSize.x | tileOutY >= p.outSize.y | majorBase >= p.sizeMajor)
115
+ return;
116
+
117
+ // Load filter (flipped).
118
+ for (int tapIdx = threadIdx.x; tapIdx < filterH * filterW; tapIdx += blockDim.x)
119
+ {
120
+ int fy = tapIdx / filterW;
121
+ int fx = tapIdx - fy * filterW;
122
+ scalar_t v = 0;
123
+ if (fx < p.filterSize.x & fy < p.filterSize.y)
124
+ {
125
+ int ffx = (p.flip) ? fx : p.filterSize.x - 1 - fx;
126
+ int ffy = (p.flip) ? fy : p.filterSize.y - 1 - fy;
127
+ v = (scalar_t)p.f[ffx * p.filterStride.x + ffy * p.filterStride.y];
128
+ }
129
+ sf[fy][fx] = v;
130
+ }
131
+
132
+ // Loop over major and X.
133
+ for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++)
134
+ {
135
+ int baseNC = major * p.sizeMinor + minorBase;
136
+ int n = baseNC / p.inSize.z;
137
+ int baseC = baseNC - n * p.inSize.z;
138
+ for (int loopX = 0, tileOutX = tileOutXBase; loopX < p.loopX & tileOutX < p.outSize.x; loopX++, tileOutX += tileOutW)
139
+ {
140
+ // Load input pixels.
141
+ int tileMidX = tileOutX * downx + upx - 1 - p.pad0.x;
142
+ int tileMidY = tileOutY * downy + upy - 1 - p.pad0.y;
143
+ int tileInX = floor_div(tileMidX, upx);
144
+ int tileInY = floor_div(tileMidY, upy);
145
+ __syncthreads();
146
+ for (int inIdx = threadIdx.x; inIdx < tileInH * tileInW * loopMinor; inIdx += blockDim.x)
147
+ {
148
+ int relC = inIdx;
149
+ int relInX = relC / loopMinor;
150
+ int relInY = relInX / tileInW;
151
+ relC -= relInX * loopMinor;
152
+ relInX -= relInY * tileInW;
153
+ int c = baseC + relC;
154
+ int inX = tileInX + relInX;
155
+ int inY = tileInY + relInY;
156
+ scalar_t v = 0;
157
+ if (inX >= 0 & inY >= 0 & inX < p.inSize.x & inY < p.inSize.y & c < p.inSize.z)
158
+ v = (scalar_t)((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w];
159
+ sx[relInY][relInX][relC] = v;
160
+ }
161
+
162
+ // Loop over output pixels.
163
+ __syncthreads();
164
+ for (int outIdx = threadIdx.x; outIdx < tileOutH * tileOutW * loopMinor; outIdx += blockDim.x)
165
+ {
166
+ int relC = outIdx;
167
+ int relOutX = relC / loopMinor;
168
+ int relOutY = relOutX / tileOutW;
169
+ relC -= relOutX * loopMinor;
170
+ relOutX -= relOutY * tileOutW;
171
+ int c = baseC + relC;
172
+ int outX = tileOutX + relOutX;
173
+ int outY = tileOutY + relOutY;
174
+
175
+ // Setup receptive field.
176
+ int midX = tileMidX + relOutX * downx;
177
+ int midY = tileMidY + relOutY * downy;
178
+ int inX = floor_div(midX, upx);
179
+ int inY = floor_div(midY, upy);
180
+ int relInX = inX - tileInX;
181
+ int relInY = inY - tileInY;
182
+ int filterX = (inX + 1) * upx - midX - 1; // flipped
183
+ int filterY = (inY + 1) * upy - midY - 1; // flipped
184
+
185
+ // Inner loop.
186
+ if (outX < p.outSize.x & outY < p.outSize.y & c < p.outSize.z)
187
+ {
188
+ scalar_t v = 0;
189
+ #pragma unroll
190
+ for (int y = 0; y < filterH / upy; y++)
191
+ #pragma unroll
192
+ for (int x = 0; x < filterW / upx; x++)
193
+ v += sx[relInY + y][relInX + x][relC] * sf[filterY + y * upy][filterX + x * upx];
194
+ v *= p.gain;
195
+ ((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v;
196
+ }
197
+ }
198
+ }
199
+ }
200
+ }
201
+
202
+ //------------------------------------------------------------------------
203
+ // CUDA kernel selection.
204
+
205
+ template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p)
206
+ {
207
+ int s = p.inStride.z, fx = p.filterSize.x, fy = p.filterSize.y;
208
+ upfirdn2d_kernel_spec spec = {(void*)upfirdn2d_kernel_large<T>, -1,-1,1, 4}; // contiguous
209
+ if (s == 1) spec = {(void*)upfirdn2d_kernel_large<T>, -1,-1,4, 1}; // channels_last
210
+
211
+ // No up/downsampling.
212
+ if (p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1)
213
+ {
214
+ // contiguous
215
+ if (s != 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 24,24, 64,32,1>, 64,32,1, 1};
216
+ if (s != 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,16, 64,32,1>, 64,32,1, 1};
217
+ if (s != 1 && fx <= 7 && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 7,7, 64,16,1>, 64,16,1, 1};
218
+ if (s != 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 6,6, 64,16,1>, 64,16,1, 1};
219
+ if (s != 1 && fx <= 5 && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 5,5, 64,16,1>, 64,16,1, 1};
220
+ if (s != 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 64,16,1>, 64,16,1, 1};
221
+ if (s != 1 && fx <= 3 && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 3,3, 64,16,1>, 64,16,1, 1};
222
+ if (s != 1 && fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 24,1, 128,8,1>, 128,8,1, 1};
223
+ if (s != 1 && fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,1, 128,8,1>, 128,8,1, 1};
224
+ if (s != 1 && fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 8,1, 128,8,1>, 128,8,1, 1};
225
+ if (s != 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,24, 32,32,1>, 32,32,1, 1};
226
+ if (s != 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,16, 32,32,1>, 32,32,1, 1};
227
+ if (s != 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,8, 32,32,1>, 32,32,1, 1};
228
+ // channels_last
229
+ if (s == 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 24,24, 32,32,1>, 32,32,1, 1};
230
+ if (s == 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,16, 32,32,1>, 32,32,1, 1};
231
+ if (s == 1 && fx <= 7 && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 7,7, 16,16,8>, 16,16,8, 1};
232
+ if (s == 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 6,6, 16,16,8>, 16,16,8, 1};
233
+ if (s == 1 && fx <= 5 && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 5,5, 16,16,8>, 16,16,8, 1};
234
+ if (s == 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
235
+ if (s == 1 && fx <= 3 && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 3,3, 16,16,8>, 16,16,8, 1};
236
+ if (s == 1 && fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 24,1, 128,1,16>, 128,1,16, 1};
237
+ if (s == 1 && fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,1, 128,1,16>, 128,1,16, 1};
238
+ if (s == 1 && fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 8,1, 128,1,16>, 128,1,16, 1};
239
+ if (s == 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,24, 1,128,16>, 1,128,16, 1};
240
+ if (s == 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,16, 1,128,16>, 1,128,16, 1};
241
+ if (s == 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,8, 1,128,16>, 1,128,16, 1};
242
+ }
243
+
244
+ // 2x upsampling.
245
+ if (p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1)
246
+ {
247
+ // contiguous
248
+ if (s != 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 24,24, 64,32,1>, 64,32,1, 1};
249
+ if (s != 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 16,16, 64,32,1>, 64,32,1, 1};
250
+ if (s != 1 && fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 8,8, 64,16,1>, 64,16,1, 1};
251
+ if (s != 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 6,6, 64,16,1>, 64,16,1, 1};
252
+ if (s != 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 4,4, 64,16,1>, 64,16,1, 1};
253
+ if (s != 1 && fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 2,2, 64,16,1>, 64,16,1, 1};
254
+ // channels_last
255
+ if (s == 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 24,24, 32,32,1>, 32,32,1, 1};
256
+ if (s == 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 16,16, 32,32,1>, 32,32,1, 1};
257
+ if (s == 1 && fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 8,8, 16,16,8>, 16,16,8, 1};
258
+ if (s == 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 6,6, 16,16,8>, 16,16,8, 1};
259
+ if (s == 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
260
+ if (s == 1 && fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 2,2, 16,16,8>, 16,16,8, 1};
261
+ }
262
+ if (p.up.x == 2 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1)
263
+ {
264
+ // contiguous
265
+ if (s != 1 && fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 24,1, 128,8,1>, 128,8,1, 1};
266
+ if (s != 1 && fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 16,1, 128,8,1>, 128,8,1, 1};
267
+ if (s != 1 && fx <= 8 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 8,1, 128,8,1>, 128,8,1, 1};
268
+ // channels_last
269
+ if (s == 1 && fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 24,1, 128,1,16>, 128,1,16, 1};
270
+ if (s == 1 && fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 16,1, 128,1,16>, 128,1,16, 1};
271
+ if (s == 1 && fx <= 8 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 8,1, 128,1,16>, 128,1,16, 1};
272
+ }
273
+ if (p.up.x == 1 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1)
274
+ {
275
+ // contiguous
276
+ if (s != 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,24, 32,32,1>, 32,32,1, 1};
277
+ if (s != 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,16, 32,32,1>, 32,32,1, 1};
278
+ if (s != 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,8, 32,32,1>, 32,32,1, 1};
279
+ // channels_last
280
+ if (s == 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,24, 1,128,16>, 1,128,16, 1};
281
+ if (s == 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,16, 1,128,16>, 1,128,16, 1};
282
+ if (s == 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,8, 1,128,16>, 1,128,16, 1};
283
+ }
284
+
285
+ // 2x downsampling.
286
+ if (p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 2)
287
+ {
288
+ // contiguous
289
+ if (s != 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 24,24, 32,16,1>, 32,16,1, 1};
290
+ if (s != 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 16,16, 32,16,1>, 32,16,1, 1};
291
+ if (s != 1 && fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 8,8, 32,8,1>, 32,8,1, 1};
292
+ if (s != 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 6,6, 32,8,1>, 32,8,1, 1};
293
+ if (s != 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 4,4, 32,8,1>, 32,8,1, 1};
294
+ if (s != 1 && fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 2,2, 32,8,1>, 32,8,1, 1};
295
+ // channels_last
296
+ if (s == 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 24,24, 16,16,1>, 16,16,1, 1};
297
+ if (s == 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 16,16, 16,16,1>, 16,16,1, 1};
298
+ if (s == 1 && fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 8,8, 8,8,8>, 8,8,8, 1};
299
+ if (s == 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 6,6, 8,8,8>, 8,8,8, 1};
300
+ if (s == 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 4,4, 8,8,8>, 8,8,8, 1};
301
+ if (s == 1 && fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 2,2, 8,8,8>, 8,8,8, 1};
302
+ }
303
+ if (p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 1)
304
+ {
305
+ // contiguous
306
+ if (s != 1 && fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 24,1, 64,8,1>, 64,8,1, 1};
307
+ if (s != 1 && fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 16,1, 64,8,1>, 64,8,1, 1};
308
+ if (s != 1 && fx <= 8 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 8,1, 64,8,1>, 64,8,1, 1};
309
+ // channels_last
310
+ if (s == 1 && fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 24,1, 64,1,8>, 64,1,8, 1};
311
+ if (s == 1 && fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 16,1, 64,1,8>, 64,1,8, 1};
312
+ if (s == 1 && fx <= 8 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 8,1, 64,1,8>, 64,1,8, 1};
313
+ }
314
+ if (p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 2)
315
+ {
316
+ // contiguous
317
+ if (s != 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,24, 32,16,1>, 32,16,1, 1};
318
+ if (s != 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,16, 32,16,1>, 32,16,1, 1};
319
+ if (s != 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,8, 32,16,1>, 32,16,1, 1};
320
+ // channels_last
321
+ if (s == 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,24, 1,64,8>, 1,64,8, 1};
322
+ if (s == 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,16, 1,64,8>, 1,64,8, 1};
323
+ if (s == 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,8, 1,64,8>, 1,64,8, 1};
324
+ }
325
+
326
+ // 4x upsampling.
327
+ if (p.up.x == 4 && p.up.y == 4 && p.down.x == 1 && p.down.y == 1)
328
+ {
329
+ // contiguous
330
+ if (s != 1 && fx <= 48 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small<T, 4,4, 1,1, 48,48, 64,32,1>, 64,32,1, 1};
331
+ if (s != 1 && fx <= 32 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small<T, 4,4, 1,1, 32,32, 64,32,1>, 64,32,1, 1};
332
+ // channels_last
333
+ if (s == 1 && fx <= 48 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small<T, 4,4, 1,1, 48,48, 32,32,1>, 32,32,1, 1};
334
+ if (s == 1 && fx <= 32 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small<T, 4,4, 1,1, 32,32, 32,32,1>, 32,32,1, 1};
335
+ }
336
+ if (p.up.x == 4 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1)
337
+ {
338
+ // contiguous
339
+ if (s != 1 && fx <= 48 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 4,1, 1,1, 48,1, 128,8,1>, 128,8,1, 1};
340
+ if (s != 1 && fx <= 32 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 4,1, 1,1, 32,1, 128,8,1>, 128,8,1, 1};
341
+ // channels_last
342
+ if (s == 1 && fx <= 48 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 4,1, 1,1, 48,1, 128,1,16>, 128,1,16, 1};
343
+ if (s == 1 && fx <= 32 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 4,1, 1,1, 32,1, 128,1,16>, 128,1,16, 1};
344
+ }
345
+ if (p.up.x == 1 && p.up.y == 4 && p.down.x == 1 && p.down.y == 1)
346
+ {
347
+ // contiguous
348
+ if (s != 1 && fx <= 1 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small<T, 1,4, 1,1, 1,48, 32,32,1>, 32,32,1, 1};
349
+ if (s != 1 && fx <= 1 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small<T, 1,4, 1,1, 1,32, 32,32,1>, 32,32,1, 1};
350
+ // channels_last
351
+ if (s == 1 && fx <= 1 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small<T, 1,4, 1,1, 1,48, 1,128,16>, 1,128,16, 1};
352
+ if (s == 1 && fx <= 1 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small<T, 1,4, 1,1, 1,32, 1,128,16>, 1,128,16, 1};
353
+ }
354
+
355
+ // 4x downsampling (inefficient).
356
+ if (p.up.x == 1 && p.up.y == 1 && p.down.x == 4 && p.down.y == 1)
357
+ {
358
+ // contiguous
359
+ if (s != 1 && fx <= 48 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 4,1, 48,1, 32,8,1>, 32,8,1, 1};
360
+ if (s != 1 && fx <= 32 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 4,1, 32,1, 32,8,1>, 32,8,1, 1};
361
+ // channels_last
362
+ if (s == 1 && fx <= 48 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 4,1, 48,1, 32,1,8>, 32,1,8, 1};
363
+ if (s == 1 && fx <= 32 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 4,1, 32,1, 32,1,8>, 32,1,8, 1};
364
+ }
365
+ if (p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 4)
366
+ {
367
+ // contiguous
368
+ if (s != 1 && fx <= 1 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,4, 1,48, 32,8,1>, 32,8,1, 1};
369
+ if (s != 1 && fx <= 1 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,4, 1,32, 32,8,1>, 32,8,1, 1};
370
+ // channels_last
371
+ if (s == 1 && fx <= 1 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,4, 1,48, 1,32,8>, 1,32,8, 1};
372
+ if (s == 1 && fx <= 1 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,4, 1,32, 1,32,8>, 1,32,8, 1};
373
+ }
374
+ return spec;
375
+ }
376
+
377
+ //------------------------------------------------------------------------
378
+ // Template specializations.
379
+
380
+ template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<double> (const upfirdn2d_kernel_params& p);
381
+ template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<float> (const upfirdn2d_kernel_params& p);
382
+ template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<c10::Half>(const upfirdn2d_kernel_params& p);
383
+
384
+ //------------------------------------------------------------------------
torch_utils/ops/upfirdn2d.h ADDED
@@ -0,0 +1,59 @@
1
+ // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ //
3
+ // NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ // and proprietary rights in and to this software, related documentation
5
+ // and any modifications thereto. Any use, reproduction, disclosure or
6
+ // distribution of this software and related documentation without an express
7
+ // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ #include <cuda_runtime.h>
10
+
11
+ //------------------------------------------------------------------------
12
+ // CUDA kernel parameters.
13
+
14
+ struct upfirdn2d_kernel_params
15
+ {
16
+ const void* x;
17
+ const float* f;
18
+ void* y;
19
+
20
+ int2 up;
21
+ int2 down;
22
+ int2 pad0;
23
+ int flip;
24
+ float gain;
25
+
26
+ int4 inSize; // [width, height, channel, batch]
27
+ int4 inStride;
28
+ int2 filterSize; // [width, height]
29
+ int2 filterStride;
30
+ int4 outSize; // [width, height, channel, batch]
31
+ int4 outStride;
32
+ int sizeMinor;
33
+ int sizeMajor;
34
+
35
+ int loopMinor;
36
+ int loopMajor;
37
+ int loopX;
38
+ int launchMinor;
39
+ int launchMajor;
40
+ };
41
+
42
+ //------------------------------------------------------------------------
43
+ // CUDA kernel specialization.
44
+
45
+ struct upfirdn2d_kernel_spec
46
+ {
47
+ void* kernel;
48
+ int tileOutW;
49
+ int tileOutH;
50
+ int loopMinor;
51
+ int loopX;
52
+ };
53
+
54
+ //------------------------------------------------------------------------
55
+ // CUDA kernel selection.
56
+
57
+ template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p);
58
+
59
+ //------------------------------------------------------------------------