diff --git a/compute_direction.py b/compute_direction.py
new file mode 100644
index 0000000000000000000000000000000000000000..ab441f331739acae9606f34f42d4c33e69b952d3
--- /dev/null
+++ b/compute_direction.py
@@ -0,0 +1,96 @@
+# python3.7
+"""Computes the semantic directions regarding a specific image region."""
+
+import os
+import argparse
+import numpy as np
+from tqdm import tqdm
+
+from coordinate import COORDINATES
+from coordinate import get_mask
+from utils.image_utils import save_image
+
+
+def parse_args():
+    """Parses arguments."""
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument('jaco_path', type=str,
+                        help='Path to the Jacobian matrix.')
+    parser.add_argument('--region', type=str, default='eyes',
+                        help='The region used to compute the directions.')
+    parser.add_argument('--save_dir', type=str, default='',
+                        help='Directory to save the results. If not '
+                             'specified, the results will be saved to '
+                             '`work_dirs/{TASK_SPECIFIC}/` by default.')
+    parser.add_argument('--job', type=str, default='directions',
+                        help='Name for the job. (default: directions)')
+    parser.add_argument('--name', type=str, default='resefa',
+                        help='Name used to save the results.')
+    parser.add_argument('--data_name', type=str, default='ffhq',
+                        help='Name of the dataset.')
+    parser.add_argument('--full_rank', action='store_true',
+                        help='Whether to make the background matrix full '
+                             'rank. (default: False)')
+    parser.add_argument('--tao', type=float, default=1e-3,
+                        help='Coefficient for the identity matrix. '
+                             '(default: 1e-3)')
+    return parser.parse_args()
+
+
+def main():
+    """Main function."""
+    args = parse_args()
+    assert os.path.exists(args.jaco_path)
+    Jacobians = np.load(args.jaco_path)
+    image_size = Jacobians.shape[2]
+    w_dim = Jacobians.shape[-1]
+    coord_dict = COORDINATES[args.data_name]
+    assert args.region in coord_dict, \
+        f'{args.region} coordinate is not defined in ' \
+        f'COORDINATE_{args.data_name}. Please define this region first!'
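+    # For each Jacobian, directions come from a relaxed generalized
+    # eigenproblem: maximize the foreground response v^T M_fore v while
+    # suppressing the background response v^T M_back v. Its solutions are
+    # the eigenvectors of inv(M_back) @ M_fore computed below.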
+    coords = coord_dict[args.region]
+    mask = get_mask(image_size, coordinate=coords)
+    foreground_ind = np.where(mask == 1)
+    background_ind = np.where((1 - mask) == 1)
+    temp_dir = f'./work_dirs/{args.job}/{args.data_name}/{args.region}'
+    save_dir = args.save_dir or temp_dir
+    os.makedirs(save_dir, exist_ok=True)
+    for ind in tqdm(range(Jacobians.shape[0])):
+        Jacobian = Jacobians[ind]
+        if len(Jacobian.shape) == 4:  # [H, W, 1, latent_dim]
+            Jaco_fore = Jacobian[foreground_ind[0], foreground_ind[1], 0]
+            Jaco_back = Jacobian[background_ind[0], background_ind[1], 0]
+        elif len(Jacobian.shape) == 5:  # [channel, H, W, 1, latent_dim]
+            Jaco_fore = Jacobian[:, foreground_ind[0], foreground_ind[1], 0]
+            Jaco_back = Jacobian[:, background_ind[0], background_ind[1], 0]
+        else:
+            raise ValueError('Shape of the Jacobian is not correct!')
+        Jaco_fore = np.reshape(Jaco_fore, [-1, w_dim])
+        Jaco_back = np.reshape(Jaco_back, [-1, w_dim])
+        coef_f = 1 / Jaco_fore.shape[0]
+        coef_b = 1 / Jaco_back.shape[0]
+        M_fore = coef_f * Jaco_fore.T.dot(Jaco_fore)
+        M_back = coef_b * Jaco_back.T.dot(Jaco_back)
+        if args.full_rank:
+            # M_back = J_b^T J_b is rank-deficient in general, so
+            # regularize it: M_back += tao * trace(M_back) * I.
+            print('Using full rank')
+            coef = args.tao * np.trace(M_back)
+            M_back = M_back + coef * np.identity(M_back.shape[0])
+        # Solve inv(M_back) * M_fore * x = lambda * x.
+        temp = np.linalg.inv(M_back).dot(M_fore)
+        eig_val, eig_vec = np.linalg.eig(temp)
+        eig_val = np.real(eig_val)
+        eig_vec = np.real(eig_vec)
+        directions = eig_vec.T
+        directions = directions[np.argsort(-eig_val)]
+        save_name = f'{save_dir}/image_{ind:02d}_region_{args.region}' \
+                    f'_name_{args.name}'
+        np.save(f'{save_name}.npy', directions)
+        mask_i = np.tile(mask[:, :, np.newaxis], [1, 1, 3]) * 255
+        save_image(f'{save_name}_mask.png', mask_i.astype(np.uint8))
+
+
+if __name__ == '__main__':
+    main()
diff --git a/compute_jacobian.py b/compute_jacobian.py
new file mode 100644
index 0000000000000000000000000000000000000000..36749ef6a62b10bdfd3cb23a0f5881b81d66cc49
--- /dev/null
+++ b/compute_jacobian.py
@@ -0,0 +1,200 @@
+# python3.7
+"""Functions to compute Jacobians based on a pre-trained GAN generator.
+
+Supports StyleGAN2 and StyleGAN3.
+"""
+
+import os
+import argparse
+import warnings
+from tqdm import tqdm
+import numpy as np
+
+import torch
+import torch.nn.functional as F
+from torch.autograd.functional import jacobian
+from models import build_model
+from utils.image_utils import save_image
+from utils.image_utils import postprocess_image
+from utils.custom_utils import to_numpy
+
+
+warnings.filterwarnings(action='ignore', category=UserWarning)
+
+
+def parse_args():
+    """Parses arguments."""
+    parser = argparse.ArgumentParser()
+    group = parser.add_argument_group('General options.')
+    group.add_argument('weight_path', type=str,
+                       help='Weight path to the pre-trained model.')
+    group.add_argument('--save_dir', type=str, default=None,
+                       help='Directory to save the results. If not '
+                            'specified, the results will be saved to '
+                            '`work_dirs/{TASK_SPECIFIC}/` by default.')
+    group.add_argument('--job', type=str, default='jacobians',
+                       help='Name for the job. (default: jacobians)')
+    group.add_argument('--seed', type=int, default=4,
+                       help='Seed for sampling. (default: 4)')
+    group.add_argument('--nums', type=int, default=5,
+                       help='Number of samples to synthesize. (default: 5)')
+    group.add_argument('--img_size', type=int, default=1024,
+                       help='Size of the synthesized images. '
+                            '(default: 1024)')
+    group.add_argument('--w_dim', type=int, default=512,
+                       help='Dimension of the latent w. '
+                            '(default: 512)')
+    group.add_argument('--save_jpg', action='store_false',
+                       help='The images used to compute Jacobians are '
+                            'saved by default; set this flag to skip '
+                            'saving them. (default: True)')
+    group.add_argument('-d', '--data_name', type=str, default='ffhq',
+                       help='Name of the dataset. (default: ffhq)')
+    group.add_argument('--latent_path', type=str, default='',
+                       help='Path to the given latent codes. '
+                            '(default: None)')
+
+    group = parser.add_argument_group('StyleGAN2')
+    group.add_argument('--stylegan2', action='store_true',
+                       help='Whether to use StyleGAN2. (default: False)')
+    group.add_argument('--scale_stylegan2', type=float, default=1.0,
+                       help='Scale for the number of channels for '
+                            'StyleGAN2.')
+    group.add_argument('--randomize_noise', type=str, default='const',
+                       help='Noise type when computing. (const or random)')
+
+    group = parser.add_argument_group('StyleGAN3')
+    group.add_argument('--stylegan3', action='store_true',
+                       help='Whether to use StyleGAN3. (default: False)')
+    group.add_argument('--cfg', type=str, default='T',
+                       help='Config of StyleGAN3 (T/R).')
+    group.add_argument('--scale_stylegan3r', type=float, default=2.0,
+                       help='Scale for the number of channels for '
+                            'StyleGAN3-R.')
+    group.add_argument('--scale_stylegan3t', type=float, default=1.0,
+                       help='Scale for the number of channels for '
+                            'StyleGAN3-T.')
+    group.add_argument('--tx', type=float, default=0,
+                       help='Translate X-coordinate. (default: 0.0)')
+    group.add_argument('--ty', type=float, default=0,
+                       help='Translate Y-coordinate. (default: 0.0)')
+    group.add_argument('--rotate', type=float, default=0,
+                       help='Rotation angle in degrees. (default: 0)')
+
+    group = parser.add_argument_group('Jacobians')
+    group.add_argument('--b', type=float, default=1e-3,
+                       help='Constant used for the fast Jacobian '
+                            'computation. (default: 1e-3)')
+    group.add_argument('--batch_size', type=int, default=4,
+                       help='Batch size. (default: 4)')
+    return parser.parse_args()
+
+
+def main():
+    """Main function."""
+    args = parse_args()
+    # Parse model configuration.
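+    # Exactly one of `--stylegan2` / `--stylegan3` is expected; the assert
+    # below rejects passing both or neither.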
+ assert (args.stylegan2 and not args.stylegan3) or \ + (not args.stylegan2 and args.stylegan3) + job_disc = '' + if args.stylegan2: + config = dict(model_type='StyleGAN2Generator', + resolution=args.img_size, + w_dim=args.w_dim, + fmaps_base=int(args.scale_stylegan2 * (32 << 10)), + fmaps_max=512,) + job_disc += 'stylegan2' + else: + if args.stylegan3 and args.cfg == 'R': + config = dict(model_type='StyleGAN3Generator', + resolution=args.img_size, + w_dim=args.w_dim, + fmaps_base=int(args.scale_stylegan3r * (32 << 10)), + fmaps_max=1024, + use_radial_filter=True,) + job_disc += 'stylegan3r' + elif args.stylegan3 and args.cfg == 'T': + config = dict(model_type='StyleGAN3Generator', + resolution=args.img_size, + w_dim=args.w_dim, + fmaps_base=int(args.scale_stylegan3t * (32 << 10)), + fmaps_max=512, + use_radial_filter=False, + kernel_size=3,) + job_disc += 'stylegan3t' + else: + raise TypeError(f'StyleGAN3 config type error, need `R/T`,' + f' but got {args.cfg}') + job_name = f'seed_{args.seed}_num_{args.nums}_{job_disc}' + temp_dir = f'work_dirs/{args.job}/{args.data_name}/{job_name}' + save_dir = args.save_dir or temp_dir + os.makedirs(save_dir, exist_ok=True) + if args.save_jpg: + os.makedirs(f'{save_dir}/images', exist_ok=True) + + print('Building generator...') + generator = build_model(**config) + checkpoint_path = args.weight_path + print(f'Loading checkpoint from `{checkpoint_path}` ...') + checkpoint = torch.load(checkpoint_path, map_location='cpu')['models'] + if 'generator_smooth' in checkpoint: + generator.load_state_dict(checkpoint['generator_smooth']) + else: + generator.load_state_dict(checkpoint['generator']) + generator = generator.eval().cuda() + print('Finish loading checkpoint.') + + # Set random seed. + np.random.seed(args.seed) + torch.manual_seed(args.seed) + if os.path.exists(args.latent_path): + latent_zs = np.load(args.latent_path) + latent_zs = latent_zs[:args.nums] + else: + latent_zs = np.random.randn(args.nums, generator.z_dim) + latent_zs = torch.from_numpy(latent_zs.astype(np.float32)) + latent_zs = latent_zs.cuda() + with torch.no_grad(): + latent_ws = generator.mapping(latent_zs)['w'] + print(f'Shape of the latent w: {latent_ws.shape}') + + def syn2jaco(w): + """Wrap the synthesized function to compute the Jacobian easily. + + Basically, this function defines a generator that takes the input + from the W space and then synthesizes an image. If the image is + larger than 256, it will be resized to 256 to save the time and + storage. 
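+
+        NOTE: The channel dimension is summed out before returning, so the
+        per-sample Jacobian has layout [H, W, 1, latent_dim], which is
+        exactly what `compute_direction.py` expects.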
+ + Args: + w: latent code from the W space + + Returns: + An image with the size of [1, 256, 256] + """ + wp = w.unsqueeze(1).repeat((1, generator.num_layers, 1)) + image = generator.synthesis(wp)['image'] + if image.shape[-1] > 256: + scale = 256 / image.shape[-1] + image = F.interpolate(image, scale_factor=scale) + image = torch.sum(image, dim=1) + return image + + jacobians = [] + for idx in tqdm(range(latent_zs.shape[0])): + latent_w = latent_ws[idx:idx+1] + jac_i = jacobian(func=syn2jaco, + inputs=latent_w, + create_graph=False, + strict=False) + jacobians.append(jac_i) + if args.save_jpg: + wp = latent_w.unsqueeze(1).repeat((1, generator.num_layers, 1)) + syn_outputs = generator.synthesis(wp)['image'] + syn_outputs = to_numpy(syn_outputs) + images = postprocess_image(syn_outputs) + save_path = f'{save_dir}/images/{idx:06d}.jpg' + save_image(save_path, images[0]) + jacobians = torch.cat(jacobians, dim=0) + jacobians = to_numpy(jacobians) + print(f'shape of the jacobian: {jacobians.shape}') + latent_ws = to_numpy(latent_ws) + np.save(f'{save_dir}/latent_codes.npy', latent_ws) + np.save(f'{save_dir}/jacobians_w.npy', jacobians) + print(f'Finish computing {args.nums} jacobians.') + + +if __name__ == '__main__': + main() diff --git a/coordinate.py b/coordinate.py new file mode 100644 index 0000000000000000000000000000000000000000..905c08763b7cd7168eed7663102f0c092cb40e59 --- /dev/null +++ b/coordinate.py @@ -0,0 +1,142 @@ +# python3.7 +"""Utility functions to help define the region coordinates within an image.""" + +import os +from glob import glob +import argparse +import numpy as np +import cv2 +from tqdm import tqdm +from utils.parsing_utils import parse_index + + +def get_mask_by_coordinates(image_size, coordinate): + """Get mask using the provided coordinates.""" + mask = np.zeros([image_size, image_size], dtype=np.float32) + center_x, center_y = coordinate[0], coordinate[1] + crop_x, crop_y = coordinate[2], coordinate[3] + xx = center_x - crop_x // 2 + yy = center_y - crop_y // 2 + mask[xx:xx + crop_x, yy:yy + crop_y] = 1. + return mask + + +def get_mask_by_segmentation(seg_mask, label): + """Get the mask using the segmentation array and labels.""" + zeros = np.zeros_like(seg_mask) + ones = np.ones_like(seg_mask) + mask = np.where(seg_mask == label, ones, zeros) + return mask + + +def get_mask(image_size, coordinate=None, seg_mask=None, labels='1'): + """Get mask using either the coordinate or the segmentation array.""" + if coordinate is not None: + print('Using coordinate to get mask!') + mask = get_mask_by_coordinates(image_size, coordinate) + else: + print('Using segmentation to get the mask!') + print(f'Using label {labels}') + mask = np.zeros_like(seg_mask) + for label_ in labels: + mask += get_mask_by_segmentation(seg_mask, int(label_)) + mask = np.clip(mask, a_min=0, a_max=1) + + return mask + + +# For FFHQ [center_x, center_y, height, width] +# Those coordinates are suitable for both ffhq and metface. 
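+# Example: 'nose': [142, 131, 40, 46] marks a 40x46 crop centered at
+# row 142, column 131 (assuming the 256x256 resolution used when
+# computing Jacobians).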
+COORDINATE_ffhq = {'left_eye': [120, 95, 20, 38],
+                   'right_eye': [120, 159, 20, 38],
+                   'eyes': [120, 128, 20, 115],
+                   'nose': [142, 131, 40, 46],
+                   'mouth': [184, 127, 30, 70],
+                   'chin': [217, 130, 42, 110],
+                   'eyebrow': [126, 105, 15, 118],
+                   }
+
+
+# For FFHQ unaligned
+COORDINATE_ffhqu = {'eyesr2': [134, 116, 30, 115],
+                    'eyesr3': [64, 128, 26, 115],
+                    'eyest0': [70, 88, 30, 115],
+                    'eyest3': [108, 142, 26, 115],
+                    }
+
+# [center_x, center_y, height, width]
+COORDINATE_biggan = {'center0': [120, 120, 80, 80],
+                     'center1': [120, 120, 130, 130],
+                     'center2': [120, 120, 200, 200],
+                     'left_side': [128, 64, 256, 128],
+                     'top_side': [64, 128, 128, 256],
+                     'head0': [89, 115, 49, 70],
+                     'head1': [93, 110, 48, 70]}
+
+
+COORDINATES = {'ffhq': COORDINATE_ffhq,
+               'ffhqu': COORDINATE_ffhqu,
+               'biggan': COORDINATE_biggan
+               }
+
+
+def parse_args():
+    """Parses arguments."""
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--image_path', type=str, default='',
+                        help='Path to the image directory.')
+    parser.add_argument('--mask_path', type=str, default='',
+                        help='Path to the mask directory.')
+    parser.add_argument('--save_dir', type=str, default='',
+                        help='Directory to save the visualized results.')
+    parser.add_argument('--label', type=str, default=None,
+                        help='The label number in the mask.')
+    parser.add_argument('--data', type=str, default='ffhq',
+                        help='The name of the dataset to test.')
+    parser.add_argument('--num', type=int, default=0,
+                        help='Index of the first image to process. '
+                             '(default: 0)')
+    parser.add_argument('--img_type', type=str, default='jpeg',
+                        help='Format of the images.')
+
+    return parser.parse_args()
+
+
+def main():
+    """Main function to visualize images with masks."""
+    args = parse_args()
+    save_dir = args.save_dir or './temp_mask'
+    os.makedirs(save_dir, exist_ok=True)
+    images = sorted(glob(f'{args.image_path}/*.{args.img_type}'))[args.num:]
+    label_files = sorted(glob(f'{args.mask_path}/*.npy'))[args.num:]
+    COORDINATE = COORDINATES[args.data]
+    for i, image in tqdm(enumerate(images)):
+        img = cv2.imread(image)
+        im_name = image.split('/')[-1].split('.')[0]
+        if args.label is None:
+            for name, coord in COORDINATE.items():
+                if len(coord) == 0:
+                    continue
+                mask = np.zeros(img.shape, dtype=np.float32)
+                center_x, center_y = coord[0], coord[1]
+                crop_x, crop_y = coord[2], coord[3]
+                xx = center_x - crop_x // 2
+                yy = center_y - crop_y // 2
+                mask[xx:xx + crop_x, yy:yy + crop_y, :] = 1.
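+                # Zero out everything outside the region so that only the
+                # region stays visible in the saved image.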
+                img_ = img * mask
+                cv2.imwrite(f'{save_dir}/{im_name}_{name}.png', img_)
+        else:
+            print('Using segmentation to get the mask!')
+            seg_mask = np.load(label_files[i])
+            labels = parse_index(args.label)
+            print(f'Using label {labels}')
+            mask = np.zeros_like(seg_mask)
+            for label_ in labels:
+                mask += get_mask_by_segmentation(seg_mask, int(label_))
+            mask = np.clip(mask, a_min=0, a_max=1)
+            img_ = img * mask[:, :, np.newaxis]
+            cv2.imwrite(f'{save_dir}/{im_name}_{args.label}.png', img_)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/directions/afhq/stylegan3/eyes-r.npy b/directions/afhq/stylegan3/eyes-r.npy
new file mode 100644
index 0000000000000000000000000000000000000000..cd68e0511d19120816482f400284e9028ad2d1d8
--- /dev/null
+++ b/directions/afhq/stylegan3/eyes-r.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:959982185cd0401b8ab984ad2c1c22c01a494d8928a765e353c4fc34e9b079a2
+size 2176
diff --git a/directions/ffhq/stylegan2/eyebrows.npy b/directions/ffhq/stylegan2/eyebrows.npy
new file mode 100644
index 0000000000000000000000000000000000000000..37e074e6fee152e9d42d1b499f86a9798b948ed2
--- /dev/null
+++ b/directions/ffhq/stylegan2/eyebrows.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f8a7385ee951fa93b34044a768254a937327c87b0d86a569b62ae4593ae0b765
+size 2176
diff --git a/directions/ffhq/stylegan2/eyesize.npy b/directions/ffhq/stylegan2/eyesize.npy
new file mode 100644
index 0000000000000000000000000000000000000000..295ced310c83a5c4f9f30d34875fbae4b052ddcc
--- /dev/null
+++ b/directions/ffhq/stylegan2/eyesize.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:74d417489944e2f2c6ee9622c4a3745edf20edad5837f5d4b2d89847003d0ec5
+size 2176
diff --git a/directions/ffhq/stylegan2/gaze_direction.npy b/directions/ffhq/stylegan2/gaze_direction.npy
new file mode 100644
index 0000000000000000000000000000000000000000..9460c5549d56b84257030b78378ad01cb663d3a0
--- /dev/null
+++ b/directions/ffhq/stylegan2/gaze_direction.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:98c7787363716096706db4f6cc039cf5ffc773142a0f790ed185300ddaf8fc0e
+size 2176
diff --git a/directions/ffhq/stylegan2/lipstick.npy b/directions/ffhq/stylegan2/lipstick.npy
new file mode 100644
index 0000000000000000000000000000000000000000..47467ad63851177b7f4d0d5a6eab02d39fd0e43a
--- /dev/null
+++ b/directions/ffhq/stylegan2/lipstick.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9ce301bb70f35e10c06b9d6fbb485644a65ebf1ac57a080e70fc3448b3343dc6
+size 2176
diff --git a/directions/ffhq/stylegan2/mouth.npy b/directions/ffhq/stylegan2/mouth.npy
new file mode 100644
index 0000000000000000000000000000000000000000..d68667a5dee3a7199b754a4de3d5a037739f7ce1
--- /dev/null
+++ b/directions/ffhq/stylegan2/mouth.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5cf4fe4565fd2d3f363d0b2e7de1073f1379c8488113b99f8d3bad9ba53a0fa4
+size 2176
diff --git a/directions/ffhq/stylegan2/nose_length.npy b/directions/ffhq/stylegan2/nose_length.npy
new file mode 100644
index 0000000000000000000000000000000000000000..02a0f8174ac3dd69c849c56b6f5b33cd5234f144
--- /dev/null
+++ b/directions/ffhq/stylegan2/nose_length.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7a7d3685b38f18e87a49cc3a79106f18dbdf30b1e8d3db5df21c84d592566ea5
+size 2176
diff --git a/directions/ffhq/stylegan3/eyes-r.npy b/directions/ffhq/stylegan3/eyes-r.npy
new file mode 100644
index 0000000000000000000000000000000000000000..c18a51d878a4501b824ff37eafdf917d58138e4d
--- /dev/null
+++ b/directions/ffhq/stylegan3/eyes-r.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:91f7c4432a030568b1aac414c30e7f00cf4d441228a625eb8d971f5f28b036db
+size 2176
diff --git a/manipulate.py b/manipulate.py
new file mode 100644
index 0000000000000000000000000000000000000000..e15c2fa0c32102731ce0d085a00761fb4d25b8ac
--- /dev/null
+++ b/manipulate.py
@@ -0,0 +1,253 @@
+# python3.7
+"""Manipulates synthesized or real images with an existing boundary.
+
+Supports StyleGAN2 and StyleGAN3.
+"""
+
+import os.path
+import argparse
+import numpy as np
+from tqdm import tqdm
+import torch
+
+from models import build_model
+from utils.visualizers.html_visualizer import HtmlVisualizer
+from utils.image_utils import save_image
+from utils.parsing_utils import parse_index
+from utils.image_utils import postprocess_image
+from utils.custom_utils import to_numpy, linear_interpolate
+from utils.custom_utils import make_transform
+
+
+def parse_args():
+    """Parses arguments."""
+    parser = argparse.ArgumentParser()
+    group = parser.add_argument_group('General options.')
+    group.add_argument('weight_path', type=str,
+                       help='Weight path to the pre-trained model.')
+    group.add_argument('boundary_path', type=str,
+                       help='Path to the attribute vectors.')
+    group.add_argument('--save_dir', type=str, default=None,
+                       help='Directory to save the results. If not '
+                            'specified, the results will be saved to '
+                            '`work_dirs/{TASK_SPECIFIC}/` by default.')
+    group.add_argument('--job', type=str, default='manipulations',
+                       help='Name for the job. (default: manipulations)')
+    group.add_argument('--seed', type=int, default=4,
+                       help='Seed for sampling. (default: 4)')
+    group.add_argument('--nums', type=int, default=10,
+                       help='Number of samples to synthesize. '
+                            '(default: 10)')
+    group.add_argument('--img_size', type=int, default=1024,
+                       help='Size of the synthesized images. '
+                            '(default: 1024)')
+    group.add_argument('--vis_size', type=int, default=256,
+                       help='Size of the visualized images. (default: 256)')
+    group.add_argument('--w_dim', type=int, default=512,
+                       help='Dimension of the latent w. (default: 512)')
+    group.add_argument('--batch_size', type=int, default=4,
+                       help='Batch size. (default: 4)')
+    group.add_argument('--save_jpg', action='store_true', default=False,
+                       help='Whether to save the raw images. '
+                            '(default: False)')
+    group.add_argument('-d', '--data_name', type=str, default='ffhq',
+                       help='Name of the dataset. (default: ffhq)')
+    group.add_argument('--latent_path', type=str, default='',
+                       help='Path to the given latent codes. '
+                            '(default: None)')
+    group.add_argument('--trunc_psi', type=float, default=0.7,
+                       help='Psi factor used for truncation. '
+                            '(default: 0.7)')
+    group.add_argument('--trunc_layers', type=int, default=8,
+                       help='Number of layers to perform truncation.'
+ ' (default: 8)') + group.add_argument('--name', type=str, default='resefa', + help='Name of help save the results.') + + group = parser.add_argument_group('StyleGAN2') + group.add_argument('--stylegan2', action='store_true', + help='Whether or not using StyleGAN2. (default: False)') + group.add_argument('--scale_stylegan2', type=float, default=1.0, + help='Scale for the number of channel fro stylegan2.') + group.add_argument('--randomize_noise', type=str, default='const', + help='Noise type when editing. (const or random)') + + group = parser.add_argument_group('StyleGAN3') + group.add_argument('--stylegan3', action='store_true', + help='Whether or not using StyleGAN3. (default: False)') + group.add_argument('--cfg', type=str, default='T', + help='Config of the stylegan3 (T/R)') + group.add_argument('--scale_stylegan3r', type=float, default=2.0, + help='Scale for the number of channel for stylegan3 R.') + group.add_argument('--scale_stylegan3t', type=float, default=1.0, + help='Scale for the number of channel for stylegan3 T.') + group.add_argument('--tx', type=float, default=0, + help='Translate X-coordinate. (default: 0.0)') + group.add_argument('--ty', type=float, default=0, + help='Translate Y-coordinate. (default: 0.0)') + group.add_argument('--rotate', type=float, default=0, + help='Rotation angle in degrees. (default: 0)') + + group = parser.add_argument_group('Manipulation') + group.add_argument('--mani_layers', type=str, default='4,5,6,7', + help='The layers will be manipulated.' + '(default: 4,5,6,7). For the eyebrow and lipstick,' + 'using [8-11] layers instead.') + group.add_argument('--step', type=int, default=7, + help='Number of manipulation steps. (default: 7)') + group.add_argument('--start', type=int, default=0, + help='The start index of the manipulation directions.') + group.add_argument('--end', type=int, default=1, + help='The end index of the manipulation directions.') + group.add_argument('--start_distance', type=float, default=-10.0, + help='Start distance for manipulation. (default: -10.0)') + group.add_argument('--end_distance', type=float, default=10.0, + help='End distance for manipulation. (default: 10.0)') + + return parser.parse_args() + + +def main(): + """Main function.""" + args = parse_args() + # Parse model configuration. + assert (args.stylegan2 and not args.stylegan3) or \ + (not args.stylegan2 and args.stylegan3) + checkpoint_path = args.weight_path + boundary_path = args.boundary_path + assert os.path.exists(checkpoint_path) + assert os.path.exists(boundary_path) + boundary_name = os.path.splitext(os.path.basename(boundary_path))[0] + job_disc = '' + if args.stylegan2: + config = dict(model_type='StyleGAN2Generator', + resolution=args.img_size, + w_dim=args.w_dim, + fmaps_base=int(args.scale_stylegan2 * (32 << 10)), + fmaps_max=512,) + job_disc += 'stylegan2' + else: + if args.stylegan3 and args.cfg == 'R': + config = dict(model_type='StyleGAN3Generator', + resolution=args.img_size, + w_dim=args.w_dim, + fmaps_base=int(args.scale_stylegan3r * (32 << 10)), + fmaps_max=1024, + use_radial_filter=True,) + job_disc += 'stylegan3r' + elif args.stylegan3 and args.cfg == 'T': + config = dict(model_type='StyleGAN3Generator', + resolution=args.img_size, + w_dim=args.w_dim, + fmaps_base=int(args.scale_stylegan3t * (32 << 10)), + fmaps_max=512, + use_radial_filter=False, + kernel_size=3,) + job_disc += 'stylegan3t' + else: + raise TypeError(f'StyleGAN3 config type error, need `R/T`,' + f' but got {args.cfg} instead.') + + # Get work directory and job name. 
+ save_dir = args.save_dir or f'work_dirs/{args.job}/{args.data_name}' + os.makedirs(save_dir, exist_ok=True) + job_name = f'seed_{args.seed}_num_{args.nums}_{job_disc}_{boundary_name}' + os.makedirs(f'{save_dir}/{job_name}', exist_ok=True) + + print('Building generator...') + generator = build_model(**config) + print(f'Loading checkpoint from `{checkpoint_path}` ...') + checkpoint = torch.load(checkpoint_path, map_location='cpu')['models'] + if 'generator_smooth' in checkpoint: + generator.load_state_dict(checkpoint['generator_smooth']) + else: + generator.load_state_dict(checkpoint['generator']) + generator = generator.eval().cuda() + print('Finish loading checkpoint.') + if args.stylegan3 and hasattr(generator.synthesis, 'early_layer'): + m = make_transform(args.tx, args.ty, args.rotate) + m = np.linalg.inv(m) + generator.synthesis.early_layer.transform.copy_(torch.from_numpy(m)) + + np.random.seed(args.seed) + torch.manual_seed(args.seed) + if os.path.exists(args.latent_path): + print(f'Load latent codes from {args.latent_path}') + latent_zs = np.load(args.latent_path) + latent_zs = latent_zs[:args.nums] + else: + print('Sampling latent code randomly') + latent_zs = np.random.randn(args.nums, generator.z_dim) + latent_zs = torch.from_numpy(latent_zs.astype(np.float32)) + latent_zs = latent_zs.cuda() + num_images = latent_zs.shape[0] + wp = [] + for idx in range(0, num_images, args.batch_size): + latent_z = latent_zs[idx:idx+args.batch_size] + latent_w_ = generator.mapping(latent_z, None)['wp'] + wp.append(latent_w_) + wp = torch.cat(wp, dim=0) + trunc_psi = args.trunc_psi + trunc_layers = args.trunc_layers + if trunc_psi < 1.0 and trunc_layers > 0: + w_avg = generator.w_avg + w_avg = w_avg.reshape(1, -1, generator.w_dim)[:, :trunc_layers] + wp[:, :trunc_layers] = w_avg.lerp(wp[:, :trunc_layers], trunc_psi) + print(f'Shape of the latent ws: {wp.shape}') + image_list = [] + for i in range(num_images): + image_list.append(f'{i:06d}') + + print('Loading boundary.') + directions = np.load(boundary_path) + layer_index = parse_index(args.mani_layers) + if not layer_index: + layer_index = list(range(generator.num_layers - 1)) + print(f'Manipulating on layers `{layer_index}`.') + + vis_size = None if args.vis_size == 0 else args.vis_size + delta_num = args.end - args.start + visualizer = HtmlVisualizer(num_rows=num_images * delta_num, + num_cols=args.step + 2, + image_size=vis_size) + visualizer.set_headers( + ['Name', 'Origin'] + + [f'Step {i:02d}' for i in range(1, args.step + 1)] + ) + # Manipulate images. 
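+    # For every sample and every chosen direction, `linear_interpolate`
+    # yields `args.step` codes whose selected layers move from
+    # `start_distance` to `end_distance` along that direction.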
+ print('Start manipulation.') + for row in tqdm(range(num_images)): + latent_w = wp[row:row+1] + images_ori = generator.synthesis(latent_w)['image'] + images_ori = postprocess_image(to_numpy(images_ori)) + if args.save_jpg: + save_image(f'{save_dir}/{job_name}/{row:06d}_orin.jpg', + images_ori[0]) + for num_direc in range(args.start, args.end): + html_row = num_direc - args.start + direction = directions[num_direc:num_direc+1] + direction = np.tile(direction, [1, generator.num_layers, 1]) + visualizer.set_cell(row * delta_num + html_row, 0, + text=f'{image_list[row]}_{num_direc:03d}') + visualizer.set_cell(row * delta_num + html_row, 1, + image=images_ori[0]) + mani_codes = linear_interpolate(latent_code=to_numpy(latent_w), + boundary=direction, + layer_index=layer_index, + start_distance=args.start_distance, + end_distance=args.end_distance, + steps=args.step) + mani_codes = torch.from_numpy(mani_codes.astype(np.float32)).cuda() + for idx in range(0, mani_codes.shape[0], args.batch_size): + codes_ = mani_codes[idx:idx+args.batch_size] + images_ = generator.synthesis(codes_)['image'] + images_ = postprocess_image(to_numpy(images_)) + for i in range(images_.shape[0]): + visualizer.set_cell(row * delta_num + html_row, idx+i+2, + image=images_[i]) + if args.save_jpg: + save_image(f'{save_dir}/{job_name}/{row:06d}_ind_' + f'{num_direc:06d}_mani_{idx+i:06d}.jpg', + images_[i]) + # Save results. + np.save(f'{save_dir}/{job_name}/latent_codes.npy', to_numpy(wp)) + visualizer.save(f'{save_dir}/{job_name}_{args.name}.html') + + +if __name__ == '__main__': + main() diff --git a/models/__init__.py b/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..51eeba1c1a5f08d9af654c806421ec788c8d5377 --- /dev/null +++ b/models/__init__.py @@ -0,0 +1,45 @@ +# python3.7 +"""Collects all models.""" + +from .pggan_generator import PGGANGenerator +from .pggan_discriminator import PGGANDiscriminator +from .stylegan_generator import StyleGANGenerator +from .stylegan_discriminator import StyleGANDiscriminator +from .stylegan2_generator import StyleGAN2Generator +from .stylegan2_discriminator import StyleGAN2Discriminator +from .stylegan3_generator import StyleGAN3Generator +from .ghfeat_encoder import GHFeatEncoder +from .perceptual_model import PerceptualModel +from .inception_model import InceptionModel + +__all__ = ['build_model'] + +_MODELS = { + 'PGGANGenerator': PGGANGenerator, + 'PGGANDiscriminator': PGGANDiscriminator, + 'StyleGANGenerator': StyleGANGenerator, + 'StyleGANDiscriminator': StyleGANDiscriminator, + 'StyleGAN2Generator': StyleGAN2Generator, + 'StyleGAN2Discriminator': StyleGAN2Discriminator, + 'StyleGAN3Generator': StyleGAN3Generator, + 'GHFeatEncoder': GHFeatEncoder, + 'PerceptualModel': PerceptualModel.build_model, + 'InceptionModel': InceptionModel.build_model +} + + +def build_model(model_type, **kwargs): + """Builds a model based on its class type. + + Args: + model_type: Class type to which the model belongs, which is case + sensitive. + **kwargs: Additional arguments to build the model. + + Raises: + ValueError: If the `model_type` is not supported. 
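+
+    Example:
+        # A minimal sketch; the resolution and w_dim values are just
+        # placeholders.
+        generator = build_model('StyleGAN2Generator', resolution=256,
+                                w_dim=512)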
+ """ + if model_type not in _MODELS: + raise ValueError(f'Invalid model type: `{model_type}`!\n' + f'Types allowed: {list(_MODELS)}.') + return _MODELS[model_type](**kwargs) diff --git a/models/ghfeat_encoder.py b/models/ghfeat_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..dba2ec7fe8e03df2ae9fda81daf087deb218cdf0 --- /dev/null +++ b/models/ghfeat_encoder.py @@ -0,0 +1,563 @@ +# python3.7 +"""Contains the implementation of encoder used in GH-Feat (including IDInvert). + +ResNet is used as the backbone. + +GH-Feat paper: https://arxiv.org/pdf/2007.10379.pdf +IDInvert paper: https://arxiv.org/pdf/2004.00049.pdf + +NOTE: Please use `latent_num` and `num_latents_per_head` to control the +inversion space, such as Y-space used in GH-Feat and W-space used in IDInvert. +In addition, IDInvert sets `use_fpn` and `use_sam` as `False` by default. +""" + +import numpy as np + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.distributed as dist + +__all__ = ['GHFeatEncoder'] + +# Resolutions allowed. +_RESOLUTIONS_ALLOWED = [8, 16, 32, 64, 128, 256, 512, 1024] + +# pylint: disable=missing-function-docstring + +class BasicBlock(nn.Module): + """Implementation of ResNet BasicBlock.""" + + expansion = 1 + + def __init__(self, + inplanes, + planes, + base_width=64, + stride=1, + groups=1, + dilation=1, + norm_layer=None, + downsample=None): + super().__init__() + if base_width != 64: + raise ValueError(f'BasicBlock of ResNet only supports ' + f'`base_width=64`, but {base_width} received!') + if stride not in [1, 2]: + raise ValueError(f'BasicBlock of ResNet only supports `stride=1` ' + f'and `stride=2`, but {stride} received!') + if groups != 1: + raise ValueError(f'BasicBlock of ResNet only supports `groups=1`, ' + f'but {groups} received!') + if dilation != 1: + raise ValueError(f'BasicBlock of ResNet only supports ' + f'`dilation=1`, but {dilation} received!') + assert self.expansion == 1 + + self.stride = stride + if norm_layer is None: + norm_layer = nn.BatchNorm2d + self.conv1 = nn.Conv2d(in_channels=inplanes, + out_channels=planes, + kernel_size=3, + stride=stride, + padding=1, + groups=1, + dilation=1, + bias=False) + self.bn1 = norm_layer(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = nn.Conv2d(in_channels=planes, + out_channels=planes, + kernel_size=3, + stride=1, + padding=1, + groups=1, + dilation=1, + bias=False) + self.bn2 = norm_layer(planes) + self.downsample = downsample + + def forward(self, x): + identity = self.downsample(x) if self.downsample is not None else x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out + identity) + + return out + + +class Bottleneck(nn.Module): + """Implementation of ResNet Bottleneck.""" + + expansion = 4 + + def __init__(self, + inplanes, + planes, + base_width=64, + stride=1, + groups=1, + dilation=1, + norm_layer=None, + downsample=None): + super().__init__() + if stride not in [1, 2]: + raise ValueError(f'Bottleneck of ResNet only supports `stride=1` ' + f'and `stride=2`, but {stride} received!') + + width = int(planes * (base_width / 64)) * groups + self.stride = stride + if norm_layer is None: + norm_layer = nn.BatchNorm2d + self.conv1 = nn.Conv2d(in_channels=inplanes, + out_channels=width, + kernel_size=1, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=False) + self.bn1 = norm_layer(width) + self.conv2 = nn.Conv2d(in_channels=width, + out_channels=width, + kernel_size=3, + 
stride=stride, + padding=dilation, + groups=groups, + dilation=dilation, + bias=False) + self.bn2 = norm_layer(width) + self.conv3 = nn.Conv2d(in_channels=width, + out_channels=planes * self.expansion, + kernel_size=1, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=False) + self.bn3 = norm_layer(planes * self.expansion) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + + def forward(self, x): + identity = self.downsample(x) if self.downsample is not None else x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + out = self.relu(out + identity) + + return out + + +class GHFeatEncoder(nn.Module): + """Define the ResNet-based encoder network for GAN inversion. + + On top of the backbone, there are several task-heads to produce inverted + codes. Please use `latent_dim` and `num_latents_per_head` to define the + structure. For example, `latent_dim = [512] * 14` and + `num_latents_per_head = [4, 4, 6]` can be used for StyleGAN inversion with + 14-layer latent codes, where 3 task heads (corresponding to 4, 4, 6 layers, + respectively) are used. + + Settings for the encoder network: + + (1) resolution: The resolution of the output image. + (2) latent_dim: Dimension of the latent space. A number (one code will be + produced), or a list of numbers regarding layer-wise latent codes. + (3) num_latents_per_head: Number of latents that is produced by each head. + (4) image_channels: Number of channels of the output image. (default: 3) + (5) final_res: Final resolution of the convolutional layers. (default: 4) + + ResNet-related settings: + + (1) network_depth: Depth of the network, like 18 for ResNet18. (default: 18) + (2) inplanes: Number of channels of the first convolutional layer. + (default: 64) + (3) groups: Groups of the convolution, used in ResNet. (default: 1) + (4) width_per_group: Number of channels per group, used in ResNet. + (default: 64) + (5) replace_stride_with_dilation: Whether to replace stride with dilation, + used in ResNet. (default: None) + (6) norm_layer: Normalization layer used in the encoder. If set as `None`, + `nn.BatchNorm2d` will be used. Also, please NOTE that when using batch + normalization, the batch size is required to be larger than one for + training. (default: nn.BatchNorm2d) + (7) max_channels: Maximum number of channels in each layer. (default: 512) + + Task-head related settings: + + (1) use_fpn: Whether to use Feature Pyramid Network (FPN) before outputting + the latent code. (default: True) + (2) fpn_channels: Number of channels used in FPN. (default: 512) + (3) use_sam: Whether to use Spatial Alignment Module (SAM) before outputting + the latent code. (default: True) + (4) sam_channels: Number of channels used in SAM. 
(default: 512) + """ + + arch_settings = { + 18: (BasicBlock, [2, 2, 2, 2]), + 34: (BasicBlock, [3, 4, 6, 3]), + 50: (Bottleneck, [3, 4, 6, 3]), + 101: (Bottleneck, [3, 4, 23, 3]), + 152: (Bottleneck, [3, 8, 36, 3]) + } + + def __init__(self, + resolution, + latent_dim, + num_latents_per_head, + image_channels=3, + final_res=4, + network_depth=18, + inplanes=64, + groups=1, + width_per_group=64, + replace_stride_with_dilation=None, + norm_layer=nn.BatchNorm2d, + max_channels=512, + use_fpn=True, + fpn_channels=512, + use_sam=True, + sam_channels=512): + super().__init__() + + if resolution not in _RESOLUTIONS_ALLOWED: + raise ValueError(f'Invalid resolution: `{resolution}`!\n' + f'Resolutions allowed: {_RESOLUTIONS_ALLOWED}.') + if network_depth not in self.arch_settings: + raise ValueError(f'Invalid network depth: `{network_depth}`!\n' + f'Options allowed: ' + f'{list(self.arch_settings.keys())}.') + if isinstance(latent_dim, int): + latent_dim = [latent_dim] + assert isinstance(latent_dim, (list, tuple)) + assert isinstance(num_latents_per_head, (list, tuple)) + assert sum(num_latents_per_head) == len(latent_dim) + + self.resolution = resolution + self.latent_dim = latent_dim + self.num_latents_per_head = num_latents_per_head + self.num_heads = len(self.num_latents_per_head) + self.image_channels = image_channels + self.final_res = final_res + self.inplanes = inplanes + self.network_depth = network_depth + self.groups = groups + self.dilation = 1 + self.base_width = width_per_group + self.replace_stride_with_dilation = replace_stride_with_dilation + if norm_layer is None: + norm_layer = nn.BatchNorm2d + if norm_layer == nn.BatchNorm2d and dist.is_initialized(): + norm_layer = nn.SyncBatchNorm + self.norm_layer = norm_layer + self.max_channels = max_channels + self.use_fpn = use_fpn + self.fpn_channels = fpn_channels + self.use_sam = use_sam + self.sam_channels = sam_channels + + block_fn, num_blocks_per_stage = self.arch_settings[network_depth] + + self.num_stages = int(np.log2(resolution // final_res)) - 1 + # Add one block for additional stages. + for i in range(len(num_blocks_per_stage), self.num_stages): + num_blocks_per_stage.append(1) + if replace_stride_with_dilation is None: + replace_stride_with_dilation = [False] * self.num_stages + + # Backbone. + self.conv1 = nn.Conv2d(in_channels=self.image_channels, + out_channels=self.inplanes, + kernel_size=7, + stride=2, + padding=3, + bias=False) + self.bn1 = norm_layer(self.inplanes) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + + self.stage_channels = [self.inplanes] + self.stages = nn.ModuleList() + for i in range(self.num_stages): + inplanes = self.inplanes if i == 0 else planes * block_fn.expansion + planes = min(self.max_channels, self.inplanes * (2 ** i)) + num_blocks = num_blocks_per_stage[i] + stride = 1 if i == 0 else 2 + dilate = replace_stride_with_dilation[i] + self.stages.append(self._make_stage(block_fn=block_fn, + inplanes=inplanes, + planes=planes, + num_blocks=num_blocks, + stride=stride, + dilate=dilate)) + self.stage_channels.append(planes * block_fn.expansion) + + if self.num_heads > len(self.stage_channels): + raise ValueError('Number of task heads is larger than number of ' + 'stages! Please reduce the number of heads.') + + # Task-head. 
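+        # With a single head there is no pyramid to fuse, so FPN and SAM
+        # are disabled.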
+ if self.num_heads == 1: + self.use_fpn = False + self.use_sam = False + + if self.use_fpn: + fpn_pyramid_channels = self.stage_channels[-self.num_heads:] + self.fpn = FPN(pyramid_channels=fpn_pyramid_channels, + out_channels=self.fpn_channels) + if self.use_sam: + if self.use_fpn: + sam_pyramid_channels = [self.fpn_channels] * self.num_heads + else: + sam_pyramid_channels = self.stage_channels[-self.num_heads:] + self.sam = SAM(pyramid_channels=sam_pyramid_channels, + out_channels=self.sam_channels) + + self.heads = nn.ModuleList() + for head_idx in range(self.num_heads): + # Parse in_channels. + if self.use_sam: + in_channels = self.sam_channels + elif self.use_fpn: + in_channels = self.fpn_channels + else: + in_channels = self.stage_channels[head_idx - self.num_heads] + in_channels = in_channels * final_res * final_res + + # Parse out_channels. + start_latent_idx = sum(self.num_latents_per_head[:head_idx]) + end_latent_idx = sum(self.num_latents_per_head[:head_idx + 1]) + out_channels = sum(self.latent_dim[start_latent_idx:end_latent_idx]) + + self.heads.append(CodeHead(in_channels=in_channels, + out_channels=out_channels, + norm_layer=self.norm_layer)) + + def _make_stage(self, + block_fn, + inplanes, + planes, + num_blocks, + stride, + dilate): + norm_layer = self.norm_layer + downsample = None + previous_dilation = self.dilation + if dilate: + self.dilation *= stride + stride = 1 + if stride != 1 or inplanes != planes * block_fn.expansion: + downsample = nn.Sequential( + nn.Conv2d(in_channels=inplanes, + out_channels=planes * block_fn.expansion, + kernel_size=1, + stride=stride, + padding=0, + dilation=1, + groups=1, + bias=False), + norm_layer(planes * block_fn.expansion), + ) + + blocks = [] + blocks.append(block_fn(inplanes=inplanes, + planes=planes, + base_width=self.base_width, + stride=stride, + groups=self.groups, + dilation=previous_dilation, + norm_layer=norm_layer, + downsample=downsample)) + for _ in range(1, num_blocks): + blocks.append(block_fn(inplanes=planes * block_fn.expansion, + planes=planes, + base_width=self.base_width, + stride=1, + groups=self.groups, + dilation=self.dilation, + norm_layer=norm_layer, + downsample=None)) + + return nn.Sequential(*blocks) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + features = [x] + for i in range(self.num_stages): + x = self.stages[i](x) + features.append(x) + features = features[-self.num_heads:] + + if self.use_fpn: + features = self.fpn(features) + if self.use_sam: + features = self.sam(features) + else: + final_size = features[-1].shape[2:] + for i in range(self.num_heads - 1): + features[i] = F.adaptive_avg_pool2d(features[i], final_size) + + outputs = [] + for head_idx in range(self.num_heads): + codes = self.heads[head_idx](features[head_idx]) + start_latent_idx = sum(self.num_latents_per_head[:head_idx]) + end_latent_idx = sum(self.num_latents_per_head[:head_idx + 1]) + split_size = self.latent_dim[start_latent_idx:end_latent_idx] + outputs.extend(torch.split(codes, split_size, dim=1)) + max_dim = max(self.latent_dim) + for i, dim in enumerate(self.latent_dim): + if dim < max_dim: + outputs[i] = F.pad(outputs[i], (0, max_dim - dim)) + outputs[i] = outputs[i].unsqueeze(1) + + return torch.cat(outputs, dim=1) + + +class FPN(nn.Module): + """Implementation of Feature Pyramid Network (FPN). + + The input of this module is a pyramid of features with reducing resolutions. + Then, this module fuses these multi-level features from `top_level` to + `bottom_level`. 
In particular, starting from the `top_level`, each feature + is convoluted, upsampled, and fused into its previous feature (which is also + convoluted). + + Args: + pyramid_channels: A list of integers, each of which indicates the number + of channels of the feature from a particular level. + out_channels: Number of channels for each output. + + Returns: + A list of feature maps, each of which has `out_channels` channels. + """ + + def __init__(self, pyramid_channels, out_channels): + super().__init__() + assert isinstance(pyramid_channels, (list, tuple)) + self.num_levels = len(pyramid_channels) + + self.lateral_layers = nn.ModuleList() + self.feature_layers = nn.ModuleList() + for i in range(self.num_levels): + in_channels = pyramid_channels[i] + self.lateral_layers.append(nn.Conv2d(in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + padding=1, + bias=True)) + self.feature_layers.append(nn.Conv2d(in_channels=out_channels, + out_channels=out_channels, + kernel_size=3, + padding=1, + bias=True)) + + def forward(self, inputs): + if len(inputs) != self.num_levels: + raise ValueError('Number of inputs and `num_levels` mismatch!') + + # Project all related features to `out_channels`. + laterals = [] + for i in range(self.num_levels): + laterals.append(self.lateral_layers[i](inputs[i])) + + # Fusion, starting from `top_level`. + for i in range(self.num_levels - 1, 0, -1): + scale_factor = laterals[i - 1].shape[2] // laterals[i].shape[2] + laterals[i - 1] = (laterals[i - 1] + + F.interpolate(laterals[i], + mode='nearest', + scale_factor=scale_factor)) + + # Get outputs. + outputs = [] + for i, lateral in enumerate(laterals): + outputs.append(self.feature_layers[i](lateral)) + + return outputs + + +class SAM(nn.Module): + """Implementation of Spatial Alignment Module (SAM). + + The input of this module is a pyramid of features with reducing resolutions. + Then this module downsamples all levels of feature to the minimum resolution + and fuses it with the smallest feature map. + + Args: + pyramid_channels: A list of integers, each of which indicates the number + of channels of the feature from a particular level. + out_channels: Number of channels for each output. + + Returns: + A list of feature maps, each of which has `out_channels` channels. 
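+
+    Example:
+        # A minimal sketch; the channel counts are placeholders.
+        sam = SAM(pyramid_channels=[256, 512], out_channels=512)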
+ """ + + def __init__(self, pyramid_channels, out_channels): + super().__init__() + assert isinstance(pyramid_channels, (list, tuple)) + self.num_levels = len(pyramid_channels) + + self.fusion_layers = nn.ModuleList() + for i in range(self.num_levels): + in_channels = pyramid_channels[i] + self.fusion_layers.append(nn.Conv2d(in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + padding=1, + bias=True)) + + def forward(self, inputs): + if len(inputs) != self.num_levels: + raise ValueError('Number of inputs and `num_levels` mismatch!') + + output_res = inputs[-1].shape[2:] + for i in range(self.num_levels - 1, -1, -1): + if i != self.num_levels - 1: + inputs[i] = F.adaptive_avg_pool2d(inputs[i], output_res) + inputs[i] = self.fusion_layers[i](inputs[i]) + if i != self.num_levels - 1: + inputs[i] = inputs[i] + inputs[-1] + + return inputs + + +class CodeHead(nn.Module): + """Implementation of the task-head to produce inverted codes.""" + + def __init__(self, in_channels, out_channels, norm_layer): + super().__init__() + self.fc = nn.Linear(in_channels, out_channels, bias=True) + if norm_layer is None: + self.norm = nn.Identity() + else: + self.norm = norm_layer(out_channels) + + def forward(self, x): + if x.ndim > 2: + x = x.flatten(start_dim=1) + latent = self.fc(x) + latent = latent.unsqueeze(2).unsqueeze(3) + latent = self.norm(latent) + + return latent.flatten(start_dim=1) + +# pylint: enable=missing-function-docstring diff --git a/models/inception_model.py b/models/inception_model.py new file mode 100644 index 0000000000000000000000000000000000000000..68fe4ece6b6cdc864b7de49719d7714cabfacedf --- /dev/null +++ b/models/inception_model.py @@ -0,0 +1,562 @@ +# python3.7 +"""Contains the Inception V3 model, which is used for inference ONLY. + +This file is mostly borrowed from `torchvision/models/inception.py`. + +Inception model is widely used to compute FID or IS metric for evaluating +generative models. However, the pre-trained models from torchvision is slightly +different from the TensorFlow version + +http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz + +which is used by the official FID implementation + +https://github.com/bioinf-jku/TTUR + +In particular: + +(1) The number of classes in TensorFlow model is 1008 instead of 1000. +(2) The avg_pool() layers in TensorFlow model does not include the padded zero. +(3) The last Inception E Block in TensorFlow model use max_pool() instead of + avg_pool(). + +Hence, to align the evaluation results with those from TensorFlow +implementation, we modified the inception model to support both versions. Please +use `align_tf` argument to control the version. +""" + +import warnings + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.distributed as dist + +from utils.misc import download_url + +__all__ = ['InceptionModel'] + +# pylint: disable=line-too-long + +_MODEL_URL_SHA256 = { + # This model is provided by `torchvision`, which is ported from TensorFlow. 
+ 'torchvision_official': ( + 'https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth', + '1a9a5a14f40645a370184bd54f4e8e631351e71399112b43ad0294a79da290c8' # hash sha256 + ), + + # This model is provided by https://github.com/mseitzer/pytorch-fid + 'tf_inception_v3': ( + 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth', + '6726825d0af5f729cebd5821db510b11b1cfad8faad88a03f1befd49fb9129b2' # hash sha256 + ) +} + + +class InceptionModel(object): + """Defines the Inception (V3) model. + + This is a static class, which is used to avoid this model to be built + repeatedly. Consequently, this model is particularly used for inference, + like computing FID. If training is required, please use the model from + `torchvision.models` or implement by yourself. + + NOTE: The pre-trained model assumes the inputs to be with `RGB` channel + order and pixel range [-1, 1], and will also resize the images to shape + [299, 299] automatically. If your input is normalized by subtracting + (0.485, 0.456, 0.406) and dividing (0.229, 0.224, 0.225), please use + `transform_input` in the `forward()` function to un-normalize it. + """ + models = dict() + + @staticmethod + def build_model(align_tf=True): + """Builds the model and load pre-trained weights. + + If `align_tf` is set as True, the model will predict 1008 classes, and + the pre-trained weight from `https://github.com/mseitzer/pytorch-fid` + will be loaded. Otherwise, the model will predict 1000 classes, and will + load the model from `torchvision`. + + The built model supports following arguments when forwarding: + + - transform_input: Whether to transform the input back to pixel range + (-1, 1). Please disable this argument if your input is already with + pixel range (-1, 1). (default: False) + - output_logits: Whether to output the categorical logits instead of + features. (default: False) + - remove_logits_bias: Whether to remove the bias when computing the + logits. The official implementation removes the bias by default. + Please refer to + `https://github.com/openai/improved-gan/blob/master/inception_score/model.py`. + (default: False) + - output_predictions: Whether to output the final predictions, i.e., + `softmax(logits)`. (default: False) + """ + if align_tf: + num_classes = 1008 + model_source = 'tf_inception_v3' + else: + num_classes = 1000 + model_source = 'torchvision_official' + + fingerprint = model_source + + if fingerprint not in InceptionModel.models: + # Build model. + model = Inception3(num_classes=num_classes, + aux_logits=False, + init_weights=False, + align_tf=align_tf) + + # Download pre-trained weights. + if dist.is_initialized() and dist.get_rank() != 0: + dist.barrier() # Download by chief. + + url, sha256 = _MODEL_URL_SHA256[model_source] + filename = f'inception_model_{model_source}_{sha256}.pth' + model_path, hash_check = download_url(url, + filename=filename, + sha256=sha256) + state_dict = torch.load(model_path, map_location='cpu') + if hash_check is False: + warnings.warn(f'Hash check failed! The remote file from URL ' + f'`{url}` may be changed, or the downloading is ' + f'interrupted. The loaded inception model may ' + f'have unexpected behavior.') + + if dist.is_initialized() and dist.get_rank() == 0: + dist.barrier() # Wait for other replicas. + + # Load weights. + model.load_state_dict(state_dict, strict=False) + del state_dict + + # For inference only. 
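+            # Freeze the weights and cache the instance so later calls
+            # can reuse it.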
+ model.eval().requires_grad_(False).cuda() + InceptionModel.models[fingerprint] = model + + return InceptionModel.models[fingerprint] + +# pylint: disable=missing-function-docstring +# pylint: disable=missing-class-docstring +# pylint: disable=super-with-arguments +# pylint: disable=consider-merging-isinstance +# pylint: disable=import-outside-toplevel +# pylint: disable=no-else-return + +class Inception3(nn.Module): + + def __init__(self, num_classes=1000, aux_logits=True, inception_blocks=None, + init_weights=True, align_tf=True): + super(Inception3, self).__init__() + if inception_blocks is None: + inception_blocks = [ + BasicConv2d, InceptionA, InceptionB, InceptionC, + InceptionD, InceptionE, InceptionAux + ] + assert len(inception_blocks) == 7 + conv_block = inception_blocks[0] + inception_a = inception_blocks[1] + inception_b = inception_blocks[2] + inception_c = inception_blocks[3] + inception_d = inception_blocks[4] + inception_e = inception_blocks[5] + inception_aux = inception_blocks[6] + + self.aux_logits = aux_logits + self.align_tf = align_tf + self.Conv2d_1a_3x3 = conv_block(3, 32, kernel_size=3, stride=2) + self.Conv2d_2a_3x3 = conv_block(32, 32, kernel_size=3) + self.Conv2d_2b_3x3 = conv_block(32, 64, kernel_size=3, padding=1) + self.Conv2d_3b_1x1 = conv_block(64, 80, kernel_size=1) + self.Conv2d_4a_3x3 = conv_block(80, 192, kernel_size=3) + self.Mixed_5b = inception_a(192, pool_features=32, align_tf=self.align_tf) + self.Mixed_5c = inception_a(256, pool_features=64, align_tf=self.align_tf) + self.Mixed_5d = inception_a(288, pool_features=64, align_tf=self.align_tf) + self.Mixed_6a = inception_b(288) + self.Mixed_6b = inception_c(768, channels_7x7=128, align_tf=self.align_tf) + self.Mixed_6c = inception_c(768, channels_7x7=160, align_tf=self.align_tf) + self.Mixed_6d = inception_c(768, channels_7x7=160, align_tf=self.align_tf) + self.Mixed_6e = inception_c(768, channels_7x7=192, align_tf=self.align_tf) + if aux_logits: + self.AuxLogits = inception_aux(768, num_classes) + self.Mixed_7a = inception_d(768) + self.Mixed_7b = inception_e(1280, align_tf=self.align_tf) + self.Mixed_7c = inception_e(2048, use_max_pool=self.align_tf) + self.fc = nn.Linear(2048, num_classes) + if init_weights: + for m in self.modules(): + if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): + import scipy.stats as stats + stddev = m.stddev if hasattr(m, 'stddev') else 0.1 + X = stats.truncnorm(-2, 2, scale=stddev) + values = torch.as_tensor(X.rvs(m.weight.numel()), dtype=m.weight.dtype) + values = values.view(m.weight.size()) + with torch.no_grad(): + m.weight.copy_(values) + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + @staticmethod + def _transform_input(x, transform_input=False): + if transform_input: + x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5 + x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5 + x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5 + x = torch.cat((x_ch0, x_ch1, x_ch2), 1) + return x + + def _forward(self, + x, + output_logits=False, + remove_logits_bias=False, + output_predictions=False): + # Upsample if necessary. 
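+        # When aligning with TF, resample through an explicit affine grid
+        # so the sampling positions match TF's resize; otherwise fall back
+        # to plain bilinear interpolation.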
+ if x.shape[2] != 299 or x.shape[3] != 299: + if self.align_tf: + theta = torch.eye(2, 3).to(x) + theta[0, 2] += theta[0, 0] / x.shape[3] - theta[0, 0] / 299 + theta[1, 2] += theta[1, 1] / x.shape[2] - theta[1, 1] / 299 + theta = theta.unsqueeze(0).repeat(x.shape[0], 1, 1) + grid = F.affine_grid(theta, + size=(x.shape[0], x.shape[1], 299, 299), + align_corners=False) + x = F.grid_sample(x, grid, + mode='bilinear', + padding_mode='border', + align_corners=False) + else: + x = F.interpolate( + x, size=(299, 299), mode='bilinear', align_corners=False) + if x.shape[1] == 1: + x = x.repeat((1, 3, 1, 1)) + + if self.align_tf: + x = (x * 127.5 + 127.5 - 128) / 128 + + # N x 3 x 299 x 299 + x = self.Conv2d_1a_3x3(x) + # N x 32 x 149 x 149 + x = self.Conv2d_2a_3x3(x) + # N x 32 x 147 x 147 + x = self.Conv2d_2b_3x3(x) + # N x 64 x 147 x 147 + x = F.max_pool2d(x, kernel_size=3, stride=2) + # N x 64 x 73 x 73 + x = self.Conv2d_3b_1x1(x) + # N x 80 x 73 x 73 + x = self.Conv2d_4a_3x3(x) + # N x 192 x 71 x 71 + x = F.max_pool2d(x, kernel_size=3, stride=2) + # N x 192 x 35 x 35 + x = self.Mixed_5b(x) + # N x 256 x 35 x 35 + x = self.Mixed_5c(x) + # N x 288 x 35 x 35 + x = self.Mixed_5d(x) + # N x 288 x 35 x 35 + x = self.Mixed_6a(x) + # N x 768 x 17 x 17 + x = self.Mixed_6b(x) + # N x 768 x 17 x 17 + x = self.Mixed_6c(x) + # N x 768 x 17 x 17 + x = self.Mixed_6d(x) + # N x 768 x 17 x 17 + x = self.Mixed_6e(x) + # N x 768 x 17 x 17 + if self.training and self.aux_logits: + aux = self.AuxLogits(x) + else: + aux = None + # N x 768 x 17 x 17 + x = self.Mixed_7a(x) + # N x 1280 x 8 x 8 + x = self.Mixed_7b(x) + # N x 2048 x 8 x 8 + x = self.Mixed_7c(x) + # N x 2048 x 8 x 8 + # Adaptive average pooling + x = F.adaptive_avg_pool2d(x, (1, 1)) + # N x 2048 x 1 x 1 + x = F.dropout(x, training=self.training) + # N x 2048 x 1 x 1 + x = torch.flatten(x, 1) + # N x 2048 + if output_logits or output_predictions: + x = self.fc(x) + # N x 1000 (num_classes) + if remove_logits_bias: + x = x - self.fc.bias.view(1, -1) + if output_predictions: + x = F.softmax(x, dim=1) + return x, aux + + def forward(self, + x, + transform_input=False, + output_logits=False, + remove_logits_bias=False, + output_predictions=False): + x = self._transform_input(x, transform_input) + x, aux = self._forward( + x, output_logits, remove_logits_bias, output_predictions) + if self.training and self.aux_logits: + return x, aux + else: + return x + + +class InceptionA(nn.Module): + + def __init__(self, in_channels, pool_features, conv_block=None, align_tf=False): + super(InceptionA, self).__init__() + if conv_block is None: + conv_block = BasicConv2d + self.branch1x1 = conv_block(in_channels, 64, kernel_size=1) + + self.branch5x5_1 = conv_block(in_channels, 48, kernel_size=1) + self.branch5x5_2 = conv_block(48, 64, kernel_size=5, padding=2) + + self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1) + self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1) + self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, padding=1) + + self.branch_pool = conv_block(in_channels, pool_features, kernel_size=1) + self.pool_include_padding = not align_tf + + def _forward(self, x): + branch1x1 = self.branch1x1(x) + + branch5x5 = self.branch5x5_1(x) + branch5x5 = self.branch5x5_2(branch5x5) + + branch3x3dbl = self.branch3x3dbl_1(x) + branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) + branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) + + branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, + count_include_pad=self.pool_include_padding) 
+ branch_pool = self.branch_pool(branch_pool) + + outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool] + return outputs + + def forward(self, x): + outputs = self._forward(x) + return torch.cat(outputs, 1) + + +class InceptionB(nn.Module): + + def __init__(self, in_channels, conv_block=None): + super(InceptionB, self).__init__() + if conv_block is None: + conv_block = BasicConv2d + self.branch3x3 = conv_block(in_channels, 384, kernel_size=3, stride=2) + + self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1) + self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1) + self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, stride=2) + + def _forward(self, x): + branch3x3 = self.branch3x3(x) + + branch3x3dbl = self.branch3x3dbl_1(x) + branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) + branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) + + branch_pool = F.max_pool2d(x, kernel_size=3, stride=2) + + outputs = [branch3x3, branch3x3dbl, branch_pool] + return outputs + + def forward(self, x): + outputs = self._forward(x) + return torch.cat(outputs, 1) + + +class InceptionC(nn.Module): + + def __init__(self, in_channels, channels_7x7, conv_block=None, align_tf=False): + super(InceptionC, self).__init__() + if conv_block is None: + conv_block = BasicConv2d + self.branch1x1 = conv_block(in_channels, 192, kernel_size=1) + + c7 = channels_7x7 + self.branch7x7_1 = conv_block(in_channels, c7, kernel_size=1) + self.branch7x7_2 = conv_block(c7, c7, kernel_size=(1, 7), padding=(0, 3)) + self.branch7x7_3 = conv_block(c7, 192, kernel_size=(7, 1), padding=(3, 0)) + + self.branch7x7dbl_1 = conv_block(in_channels, c7, kernel_size=1) + self.branch7x7dbl_2 = conv_block(c7, c7, kernel_size=(7, 1), padding=(3, 0)) + self.branch7x7dbl_3 = conv_block(c7, c7, kernel_size=(1, 7), padding=(0, 3)) + self.branch7x7dbl_4 = conv_block(c7, c7, kernel_size=(7, 1), padding=(3, 0)) + self.branch7x7dbl_5 = conv_block(c7, 192, kernel_size=(1, 7), padding=(0, 3)) + + self.branch_pool = conv_block(in_channels, 192, kernel_size=1) + self.pool_include_padding = not align_tf + + def _forward(self, x): + branch1x1 = self.branch1x1(x) + + branch7x7 = self.branch7x7_1(x) + branch7x7 = self.branch7x7_2(branch7x7) + branch7x7 = self.branch7x7_3(branch7x7) + + branch7x7dbl = self.branch7x7dbl_1(x) + branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl) + branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl) + branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl) + branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl) + + branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, + count_include_pad=self.pool_include_padding) + branch_pool = self.branch_pool(branch_pool) + + outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool] + return outputs + + def forward(self, x): + outputs = self._forward(x) + return torch.cat(outputs, 1) + + +class InceptionD(nn.Module): + + def __init__(self, in_channels, conv_block=None): + super(InceptionD, self).__init__() + if conv_block is None: + conv_block = BasicConv2d + self.branch3x3_1 = conv_block(in_channels, 192, kernel_size=1) + self.branch3x3_2 = conv_block(192, 320, kernel_size=3, stride=2) + + self.branch7x7x3_1 = conv_block(in_channels, 192, kernel_size=1) + self.branch7x7x3_2 = conv_block(192, 192, kernel_size=(1, 7), padding=(0, 3)) + self.branch7x7x3_3 = conv_block(192, 192, kernel_size=(7, 1), padding=(3, 0)) + self.branch7x7x3_4 = conv_block(192, 192, kernel_size=3, stride=2) + + def _forward(self, x): + branch3x3 = self.branch3x3_1(x) + branch3x3 = 
self.branch3x3_2(branch3x3) + + branch7x7x3 = self.branch7x7x3_1(x) + branch7x7x3 = self.branch7x7x3_2(branch7x7x3) + branch7x7x3 = self.branch7x7x3_3(branch7x7x3) + branch7x7x3 = self.branch7x7x3_4(branch7x7x3) + + branch_pool = F.max_pool2d(x, kernel_size=3, stride=2) + outputs = [branch3x3, branch7x7x3, branch_pool] + return outputs + + def forward(self, x): + outputs = self._forward(x) + return torch.cat(outputs, 1) + + +class InceptionE(nn.Module): + + def __init__(self, in_channels, conv_block=None, align_tf=False, use_max_pool=False): + super(InceptionE, self).__init__() + if conv_block is None: + conv_block = BasicConv2d + self.branch1x1 = conv_block(in_channels, 320, kernel_size=1) + + self.branch3x3_1 = conv_block(in_channels, 384, kernel_size=1) + self.branch3x3_2a = conv_block(384, 384, kernel_size=(1, 3), padding=(0, 1)) + self.branch3x3_2b = conv_block(384, 384, kernel_size=(3, 1), padding=(1, 0)) + + self.branch3x3dbl_1 = conv_block(in_channels, 448, kernel_size=1) + self.branch3x3dbl_2 = conv_block(448, 384, kernel_size=3, padding=1) + self.branch3x3dbl_3a = conv_block(384, 384, kernel_size=(1, 3), padding=(0, 1)) + self.branch3x3dbl_3b = conv_block(384, 384, kernel_size=(3, 1), padding=(1, 0)) + + self.branch_pool = conv_block(in_channels, 192, kernel_size=1) + self.pool_include_padding = not align_tf + self.use_max_pool = use_max_pool + + def _forward(self, x): + branch1x1 = self.branch1x1(x) + + branch3x3 = self.branch3x3_1(x) + branch3x3 = [ + self.branch3x3_2a(branch3x3), + self.branch3x3_2b(branch3x3), + ] + branch3x3 = torch.cat(branch3x3, 1) + + branch3x3dbl = self.branch3x3dbl_1(x) + branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) + branch3x3dbl = [ + self.branch3x3dbl_3a(branch3x3dbl), + self.branch3x3dbl_3b(branch3x3dbl), + ] + branch3x3dbl = torch.cat(branch3x3dbl, 1) + + if self.use_max_pool: + branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1) + else: + branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, + count_include_pad=self.pool_include_padding) + branch_pool = self.branch_pool(branch_pool) + + outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] + return outputs + + def forward(self, x): + outputs = self._forward(x) + return torch.cat(outputs, 1) + + +class InceptionAux(nn.Module): + + def __init__(self, in_channels, num_classes, conv_block=None): + super(InceptionAux, self).__init__() + if conv_block is None: + conv_block = BasicConv2d + self.conv0 = conv_block(in_channels, 128, kernel_size=1) + self.conv1 = conv_block(128, 768, kernel_size=5) + self.conv1.stddev = 0.01 + self.fc = nn.Linear(768, num_classes) + self.fc.stddev = 0.001 + + def forward(self, x): + # N x 768 x 17 x 17 + x = F.avg_pool2d(x, kernel_size=5, stride=3) + # N x 768 x 5 x 5 + x = self.conv0(x) + # N x 128 x 5 x 5 + x = self.conv1(x) + # N x 768 x 1 x 1 + # Adaptive average pooling + x = F.adaptive_avg_pool2d(x, (1, 1)) + # N x 768 x 1 x 1 + x = torch.flatten(x, 1) + # N x 768 + x = self.fc(x) + # N x 1000 + return x + + +class BasicConv2d(nn.Module): + + def __init__(self, in_channels, out_channels, **kwargs): + super(BasicConv2d, self).__init__() + self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs) + self.bn = nn.BatchNorm2d(out_channels, eps=0.001) + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + return F.relu(x, inplace=True) + +# pylint: enable=line-too-long +# pylint: enable=missing-function-docstring +# pylint: enable=missing-class-docstring +# pylint: enable=super-with-arguments +# pylint: 
enable=consider-merging-isinstance
+# pylint: enable=import-outside-toplevel
+# pylint: enable=no-else-return
diff --git a/models/perceptual_model.py b/models/perceptual_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..7f0aaa82789f19e9f4760d3b42e00b44e3728ffa
--- /dev/null
+++ b/models/perceptual_model.py
@@ -0,0 +1,519 @@
+# python3.7
+"""Contains the VGG16 model, which is used for inference ONLY.
+
+VGG16 is commonly used for perceptual feature extraction. The model implemented
+in this file can be used for evaluation (like computing LPIPS, perceptual path
+length, etc.), OR be used in training for loss computation (like perceptual
+loss, etc.).
+
+The pre-trained model is officially shared by
+
+https://www.robots.ox.ac.uk/~vgg/research/very_deep/
+
+and ported by
+
+https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt
+
+Compared to the official VGG16 model, this ported model also supports
+evaluating LPIPS, which is introduced in
+
+https://github.com/richzhang/PerceptualSimilarity
+"""
+
+import warnings
+import numpy as np
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.distributed as dist
+
+from utils.misc import download_url
+
+__all__ = ['PerceptualModel']
+
+# pylint: disable=line-too-long
+_MODEL_URL_SHA256 = {
+    # This model is provided by `torchvision`, which is ported from TensorFlow.
+    'torchvision_official': (
+        'https://download.pytorch.org/models/vgg16-397923af.pth',
+        '397923af8e79cdbb6a7127f12361acd7a2f83e06b05044ddf496e83de57a5bf0'  # hash sha256
+    ),
+
+    # This model is provided by https://github.com/NVlabs/stylegan2-ada-pytorch
+    'vgg_perceptual_lpips': (
+        'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt',
+        'b437eb095feaeb0b83eb3fa11200ebca4548ee39a07fb944a417ddc516cc07c3'  # hash sha256
+    )
+}
+# pylint: enable=line-too-long
+
+
+class PerceptualModel(object):
+    """Defines the perceptual model, which is based on VGG16 structure.
+
+    This is a static class, which is used to avoid building this model
+    repeatedly. Consequently, this model is particularly used for inference,
+    like computing LPIPS, or for loss computation, like perceptual loss. If
+    training is required, please use the model from `torchvision.models` or
+    implement it by yourself.
+
+    NOTE: The pre-trained model assumes the inputs to be with `RGB` channel
+    order and pixel range [-1, 1], and will NOT resize the input automatically
+    if only perceptual feature is needed.
+    """
+    models = dict()
+
+    @staticmethod
+    def build_model(use_torchvision=False, no_top=True, enable_lpips=True):
+        """Builds the model and loads pre-trained weights.
+
+        1. If `use_torchvision` is set as True, the model released by
+           `torchvision` will be loaded; otherwise, the model released by
+           https://www.robots.ox.ac.uk/~vgg/research/very_deep/ will be used.
+           (default: False)
+
+        2. To save computing resources, there is an option to load only the
+           backbone (i.e., without the last three fully-connected layers). This
+           is commonly used for perceptual loss or LPIPS loss computation.
+           Please use argument `no_top` to control this. (default: True)
+
+        3. For LPIPS loss computation, some additional weights (which are used
+           for balancing the features from different resolutions) are employed
+           on top of the original VGG16 backbone. Details can be found at
+           https://github.com/richzhang/PerceptualSimilarity. Please use
+           `enable_lpips` to enable this feature.
(default: True) + + The built model supports following arguments when forwarding: + + - resize_input: Whether to resize the input image to size [224, 224] + before forwarding. For feature-based computation (i.e., only + convolutional layers are used), image resizing is not essential. + (default: False) + - return_tensor: This field resolves the model behavior. Following + options are supported: + `feature1`: Before the first max pooling layer. + `pool1`: After the first max pooling layer. + `feature2`: Before the second max pooling layer. + `pool2`: After the second max pooling layer. + `feature3`: Before the third max pooling layer. + `pool3`: After the third max pooling layer. + `feature4`: Before the fourth max pooling layer. + `pool4`: After the fourth max pooling layer. + `feature5`: Before the fifth max pooling layer. + `pool5`: After the fifth max pooling layer. + `flatten`: The flattened feature, after `adaptive_avgpool`. + `feature`: The 4096d feature for logits computation. (default) + `logits`: The 1000d categorical logits. + `prediction`: The 1000d predicted probability. + `lpips`: The LPIPS score between two input images. + """ + if use_torchvision: + model_source = 'torchvision_official' + align_tf_resize = False + is_torch_script = False + else: + model_source = 'vgg_perceptual_lpips' + align_tf_resize = True + is_torch_script = True + + if enable_lpips and model_source != 'vgg_perceptual_lpips': + warnings.warn('The pre-trained model officially released by ' + '`torchvision` does not support LPIPS computation! ' + 'Equal weights will be used for each resolution.') + + fingerprint = (model_source, no_top, enable_lpips) + + if fingerprint not in PerceptualModel.models: + # Build model. + model = VGG16(align_tf_resize=align_tf_resize, + no_top=no_top, + enable_lpips=enable_lpips) + + # Download pre-trained weights. + if dist.is_initialized() and dist.get_rank() != 0: + dist.barrier() # Download by chief. + + url, sha256 = _MODEL_URL_SHA256[model_source] + filename = f'perceptual_model_{model_source}_{sha256}.pth' + model_path, hash_check = download_url(url, + filename=filename, + sha256=sha256) + if is_torch_script: + src_state_dict = torch.jit.load(model_path, map_location='cpu') + else: + src_state_dict = torch.load(model_path, map_location='cpu') + if hash_check is False: + warnings.warn(f'Hash check failed! The remote file from URL ' + f'`{url}` may be changed, or the downloading is ' + f'interrupted. The loaded perceptual model may ' + f'have unexpected behavior.') + + if dist.is_initialized() and dist.get_rank() == 0: + dist.barrier() # Wait for other replicas. + + # Load weights. + dst_state_dict = _convert_weights(src_state_dict, model_source) + model.load_state_dict(dst_state_dict, strict=False) + del src_state_dict, dst_state_dict + + # For inference only. 
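+            # NOTE: `.cuda()` assumes a CUDA device is available; if the model
+            # is meant to run elsewhere (e.g., on CPU), this call would need
+            # to be adapted.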
+ model.eval().requires_grad_(False).cuda() + PerceptualModel.models[fingerprint] = model + + return PerceptualModel.models[fingerprint] + + +def _convert_weights(src_state_dict, model_source): + if model_source not in _MODEL_URL_SHA256: + raise ValueError(f'Invalid model source `{model_source}`!\n' + f'Sources allowed: {list(_MODEL_URL_SHA256.keys())}.') + if model_source == 'torchvision_official': + dst_to_src_var_mapping = { + 'conv11.weight': 'features.0.weight', + 'conv11.bias': 'features.0.bias', + 'conv12.weight': 'features.2.weight', + 'conv12.bias': 'features.2.bias', + 'conv21.weight': 'features.5.weight', + 'conv21.bias': 'features.5.bias', + 'conv22.weight': 'features.7.weight', + 'conv22.bias': 'features.7.bias', + 'conv31.weight': 'features.10.weight', + 'conv31.bias': 'features.10.bias', + 'conv32.weight': 'features.12.weight', + 'conv32.bias': 'features.12.bias', + 'conv33.weight': 'features.14.weight', + 'conv33.bias': 'features.14.bias', + 'conv41.weight': 'features.17.weight', + 'conv41.bias': 'features.17.bias', + 'conv42.weight': 'features.19.weight', + 'conv42.bias': 'features.19.bias', + 'conv43.weight': 'features.21.weight', + 'conv43.bias': 'features.21.bias', + 'conv51.weight': 'features.24.weight', + 'conv51.bias': 'features.24.bias', + 'conv52.weight': 'features.26.weight', + 'conv52.bias': 'features.26.bias', + 'conv53.weight': 'features.28.weight', + 'conv53.bias': 'features.28.bias', + 'fc1.weight': 'classifier.0.weight', + 'fc1.bias': 'classifier.0.bias', + 'fc2.weight': 'classifier.3.weight', + 'fc2.bias': 'classifier.3.bias', + 'fc3.weight': 'classifier.6.weight', + 'fc3.bias': 'classifier.6.bias', + } + elif model_source == 'vgg_perceptual_lpips': + src_state_dict = src_state_dict.state_dict() + dst_to_src_var_mapping = { + 'conv11.weight': 'layers.conv1.weight', + 'conv11.bias': 'layers.conv1.bias', + 'conv12.weight': 'layers.conv2.weight', + 'conv12.bias': 'layers.conv2.bias', + 'conv21.weight': 'layers.conv3.weight', + 'conv21.bias': 'layers.conv3.bias', + 'conv22.weight': 'layers.conv4.weight', + 'conv22.bias': 'layers.conv4.bias', + 'conv31.weight': 'layers.conv5.weight', + 'conv31.bias': 'layers.conv5.bias', + 'conv32.weight': 'layers.conv6.weight', + 'conv32.bias': 'layers.conv6.bias', + 'conv33.weight': 'layers.conv7.weight', + 'conv33.bias': 'layers.conv7.bias', + 'conv41.weight': 'layers.conv8.weight', + 'conv41.bias': 'layers.conv8.bias', + 'conv42.weight': 'layers.conv9.weight', + 'conv42.bias': 'layers.conv9.bias', + 'conv43.weight': 'layers.conv10.weight', + 'conv43.bias': 'layers.conv10.bias', + 'conv51.weight': 'layers.conv11.weight', + 'conv51.bias': 'layers.conv11.bias', + 'conv52.weight': 'layers.conv12.weight', + 'conv52.bias': 'layers.conv12.bias', + 'conv53.weight': 'layers.conv13.weight', + 'conv53.bias': 'layers.conv13.bias', + 'fc1.weight': 'layers.fc1.weight', + 'fc1.bias': 'layers.fc1.bias', + 'fc2.weight': 'layers.fc2.weight', + 'fc2.bias': 'layers.fc2.bias', + 'fc3.weight': 'layers.fc3.weight', + 'fc3.bias': 'layers.fc3.bias', + 'lpips.0.weight': 'lpips0', + 'lpips.1.weight': 'lpips1', + 'lpips.2.weight': 'lpips2', + 'lpips.3.weight': 'lpips3', + 'lpips.4.weight': 'lpips4', + } + else: + raise NotImplementedError(f'Not implemented model source ' + f'`{model_source}`!') + + dst_state_dict = {} + for dst_name, src_name in dst_to_src_var_mapping.items(): + if dst_name.startswith('lpips'): + dst_state_dict[dst_name] = src_state_dict[src_name].unsqueeze(0) + else: + dst_state_dict[dst_name] = src_state_dict[src_name].clone() + 
return dst_state_dict + + +_IMG_MEAN = (0.485, 0.456, 0.406) +_IMG_STD = (0.229, 0.224, 0.225) +_ALLOWED_RETURN = [ + 'feature1', 'pool1', 'feature2', 'pool2', 'feature3', 'pool3', 'feature4', + 'pool4', 'feature5', 'pool5', 'flatten', 'feature', 'logits', 'prediction', + 'lpips' +] + +# pylint: disable=missing-function-docstring + +class VGG16(nn.Module): + """Defines the VGG16 structure. + + This model takes `RGB` images with data format `NCHW` as the raw inputs. The + pixel range are assumed to be [-1, 1]. + """ + + def __init__(self, align_tf_resize=False, no_top=True, enable_lpips=True): + """Defines the network structure.""" + super().__init__() + + self.align_tf_resize = align_tf_resize + self.no_top = no_top + self.enable_lpips = enable_lpips + + self.conv11 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1) + self.relu11 = nn.ReLU(inplace=True) + self.conv12 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1) + self.relu12 = nn.ReLU(inplace=True) + # output `feature1`, with shape [N, 64, 224, 224] + + self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2) + # output `pool1`, with shape [N, 64, 112, 112] + + self.conv21 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1) + self.relu21 = nn.ReLU(inplace=True) + self.conv22 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1) + self.relu22 = nn.ReLU(inplace=True) + # output `feature2`, with shape [N, 128, 112, 112] + + self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2) + # output `pool2`, with shape [N, 128, 56, 56] + + self.conv31 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1) + self.relu31 = nn.ReLU(inplace=True) + self.conv32 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1) + self.relu32 = nn.ReLU(inplace=True) + self.conv33 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1) + self.relu33 = nn.ReLU(inplace=True) + # output `feature3`, with shape [N, 256, 56, 56] + + self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2) + # output `pool3`, with shape [N,256, 28, 28] + + self.conv41 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1) + self.relu41 = nn.ReLU(inplace=True) + self.conv42 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) + self.relu42 = nn.ReLU(inplace=True) + self.conv43 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) + self.relu43 = nn.ReLU(inplace=True) + # output `feature4`, with shape [N, 512, 28, 28] + + self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2) + # output `pool4`, with shape [N, 512, 14, 14] + + self.conv51 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) + self.relu51 = nn.ReLU(inplace=True) + self.conv52 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) + self.relu52 = nn.ReLU(inplace=True) + self.conv53 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) + self.relu53 = nn.ReLU(inplace=True) + # output `feature5`, with shape [N, 512, 14, 14] + + self.pool5 = nn.MaxPool2d(kernel_size=2, stride=2) + # output `pool5`, with shape [N, 512, 7, 7] + + if self.enable_lpips: + self.lpips = nn.ModuleList() + for idx, ch in enumerate([64, 128, 256, 512, 512]): + self.lpips.append(nn.Conv2d(ch, 1, kernel_size=1, bias=False)) + self.lpips[idx].weight.data.copy_(torch.ones(1, ch, 1, 1)) + + if not self.no_top: + self.avgpool = nn.AdaptiveAvgPool2d((7, 7)) + self.flatten = nn.Flatten(start_dim=1, end_dim=-1) + # output `flatten`, with shape [N, 25088] + + self.fc1 = nn.Linear(512 * 7 * 7, 4096) + self.fc1_relu = nn.ReLU(inplace=True) + self.fc1_dropout = nn.Dropout(0.5, inplace=False) + self.fc2 = nn.Linear(4096, 4096) + 
self.fc2_relu = nn.ReLU(inplace=True) + self.fc2_dropout = nn.Dropout(0.5, inplace=False) + # output `feature`, with shape [N, 4096] + + self.fc3 = nn.Linear(4096, 1000) + # output `logits`, with shape [N, 1000] + + self.out = nn.Softmax(dim=1) + # output `softmax`, with shape [N, 1000] + + img_mean = np.array(_IMG_MEAN).reshape((1, 3, 1, 1)).astype(np.float32) + img_std = np.array(_IMG_STD).reshape((1, 3, 1, 1)).astype(np.float32) + self.register_buffer('img_mean', torch.from_numpy(img_mean)) + self.register_buffer('img_std', torch.from_numpy(img_std)) + + def forward(self, + x, + y=None, + *, + resize_input=False, + return_tensor='feature'): + return_tensor = return_tensor.lower() + if return_tensor not in _ALLOWED_RETURN: + raise ValueError(f'Invalid output tensor name `{return_tensor}` ' + f'for perceptual model (VGG16)!\n' + f'Names allowed: {_ALLOWED_RETURN}.') + + if return_tensor == 'lpips' and y is None: + raise ValueError('Two images are required for LPIPS computation, ' + 'but only one is received!') + + if return_tensor == 'lpips': + assert x.shape == y.shape + x = torch.cat([x, y], dim=0) + features = [] + + if resize_input: + if self.align_tf_resize: + theta = torch.eye(2, 3).to(x) + theta[0, 2] += theta[0, 0] / x.shape[3] - theta[0, 0] / 224 + theta[1, 2] += theta[1, 1] / x.shape[2] - theta[1, 1] / 224 + theta = theta.unsqueeze(0).repeat(x.shape[0], 1, 1) + grid = F.affine_grid(theta, + size=(x.shape[0], x.shape[1], 224, 224), + align_corners=False) + x = F.grid_sample(x, grid, + mode='bilinear', + padding_mode='border', + align_corners=False) + else: + x = F.interpolate(x, + size=(224, 224), + mode='bilinear', + align_corners=False) + if x.shape[1] == 1: + x = x.repeat((1, 3, 1, 1)) + + x = (x + 1) / 2 + x = (x - self.img_mean) / self.img_std + + x = self.conv11(x) + x = self.relu11(x) + x = self.conv12(x) + x = self.relu12(x) + if return_tensor == 'feature1': + return x + if return_tensor == 'lpips': + features.append(x) + + x = self.pool1(x) + if return_tensor == 'pool1': + return x + + x = self.conv21(x) + x = self.relu21(x) + x = self.conv22(x) + x = self.relu22(x) + if return_tensor == 'feature2': + return x + if return_tensor == 'lpips': + features.append(x) + + x = self.pool2(x) + if return_tensor == 'pool2': + return x + + x = self.conv31(x) + x = self.relu31(x) + x = self.conv32(x) + x = self.relu32(x) + x = self.conv33(x) + x = self.relu33(x) + if return_tensor == 'feature3': + return x + if return_tensor == 'lpips': + features.append(x) + + x = self.pool3(x) + if return_tensor == 'pool3': + return x + + x = self.conv41(x) + x = self.relu41(x) + x = self.conv42(x) + x = self.relu42(x) + x = self.conv43(x) + x = self.relu43(x) + if return_tensor == 'feature4': + return x + if return_tensor == 'lpips': + features.append(x) + + x = self.pool4(x) + if return_tensor == 'pool4': + return x + + x = self.conv51(x) + x = self.relu51(x) + x = self.conv52(x) + x = self.relu52(x) + x = self.conv53(x) + x = self.relu53(x) + if return_tensor == 'feature5': + return x + if return_tensor == 'lpips': + features.append(x) + + x = self.pool5(x) + if return_tensor == 'pool5': + return x + + if return_tensor == 'lpips': + score = 0 + assert len(features) == 5 + for idx in range(5): + feature = features[idx] + norm = feature.norm(dim=1, keepdim=True) + feature = feature / (norm + 1e-10) + feature_x, feature_y = feature.chunk(2, dim=0) + diff = (feature_x - feature_y).square() + score += self.lpips[idx](diff).mean(dim=(2, 3), keepdim=False) + return score.sum(dim=1, keepdim=False) + + x 
= self.avgpool(x) + x = self.flatten(x) + if return_tensor == 'flatten': + return x + + x = self.fc1(x) + x = self.fc1_relu(x) + x = self.fc1_dropout(x) + x = self.fc2(x) + x = self.fc2_relu(x) + x = self.fc2_dropout(x) + if return_tensor == 'feature': + return x + + x = self.fc3(x) + if return_tensor == 'logits': + return x + + x = self.out(x) + if return_tensor == 'prediction': + return x + + raise NotImplementedError(f'Output tensor name `{return_tensor}` is ' + f'not implemented!') + +# pylint: enable=missing-function-docstring diff --git a/models/pggan_discriminator.py b/models/pggan_discriminator.py new file mode 100644 index 0000000000000000000000000000000000000000..30b0868dd6a753ba7f2712c10b4f19708b67eee3 --- /dev/null +++ b/models/pggan_discriminator.py @@ -0,0 +1,465 @@ +# python3.7 +"""Contains the implementation of discriminator described in PGGAN. + +Paper: https://arxiv.org/pdf/1710.10196.pdf + +Official TensorFlow implementation: +https://github.com/tkarras/progressive_growing_of_gans +""" + +import numpy as np + +import torch +import torch.nn as nn +import torch.nn.functional as F + +__all__ = ['PGGANDiscriminator'] + +# Resolutions allowed. +_RESOLUTIONS_ALLOWED = [8, 16, 32, 64, 128, 256, 512, 1024] + +# Default gain factor for weight scaling. +_WSCALE_GAIN = np.sqrt(2.0) + +# pylint: disable=missing-function-docstring + +class PGGANDiscriminator(nn.Module): + """Defines the discriminator network in PGGAN. + + NOTE: The discriminator takes images with `RGB` channel order and pixel + range [-1, 1] as inputs. + + Settings for the network: + + (1) resolution: The resolution of the input image. + (2) init_res: Smallest resolution of the convolutional backbone. + (default: 4) + (3) image_channels: Number of channels of the input image. (default: 3) + (4) label_dim: Dimension of the additional label for conditional generation. + In one-hot conditioning case, it is equal to the number of classes. If + set to 0, conditioning training will be disabled. (default: 0) + (5) fused_scale: Whether to fused `conv2d` and `downsample` together, + resulting in `conv2d` with strides. (default: False) + (6) use_wscale: Whether to use weight scaling. (default: True) + (7) wscale_gain: The factor to control weight scaling. (default: sqrt(2.0)) + (8) mbstd_groups: Group size for the minibatch standard deviation layer. + `0` means disable. (default: 16) + (9) fmaps_base: Factor to control number of feature maps for each layer. + (default: 16 << 10) + (10) fmaps_max: Maximum number of feature maps in each layer. (default: 512) + (11) eps: A small value to avoid divide overflow. (default: 1e-8) + """ + + def __init__(self, + resolution, + init_res=4, + image_channels=3, + label_dim=0, + fused_scale=False, + use_wscale=True, + wscale_gain=np.sqrt(2.0), + mbstd_groups=16, + fmaps_base=16 << 10, + fmaps_max=512, + eps=1e-8): + """Initializes with basic settings. + + Raises: + ValueError: If the `resolution` is not supported. 
+ """ + super().__init__() + + if resolution not in _RESOLUTIONS_ALLOWED: + raise ValueError(f'Invalid resolution: `{resolution}`!\n' + f'Resolutions allowed: {_RESOLUTIONS_ALLOWED}.') + + self.init_res = init_res + self.init_res_log2 = int(np.log2(self.init_res)) + self.resolution = resolution + self.final_res_log2 = int(np.log2(self.resolution)) + self.image_channels = image_channels + self.label_dim = label_dim + self.fused_scale = fused_scale + self.use_wscale = use_wscale + self.wscale_gain = wscale_gain + self.mbstd_groups = mbstd_groups + self.fmaps_base = fmaps_base + self.fmaps_max = fmaps_max + self.eps = eps + + # Level-of-details (used for progressive training). + self.register_buffer('lod', torch.zeros(())) + self.pth_to_tf_var_mapping = {'lod': 'lod'} + + for res_log2 in range(self.final_res_log2, self.init_res_log2 - 1, -1): + res = 2 ** res_log2 + in_channels = self.get_nf(res) + out_channels = self.get_nf(res // 2) + block_idx = self.final_res_log2 - res_log2 + + # Input convolution layer for each resolution. + self.add_module( + f'input{block_idx}', + ConvLayer(in_channels=self.image_channels, + out_channels=in_channels, + kernel_size=1, + add_bias=True, + downsample=False, + fused_scale=False, + use_wscale=use_wscale, + wscale_gain=wscale_gain, + activation_type='lrelu')) + self.pth_to_tf_var_mapping[f'input{block_idx}.weight'] = ( + f'FromRGB_lod{block_idx}/weight') + self.pth_to_tf_var_mapping[f'input{block_idx}.bias'] = ( + f'FromRGB_lod{block_idx}/bias') + + # Convolution block for each resolution (except the last one). + if res != self.init_res: + self.add_module( + f'layer{2 * block_idx}', + ConvLayer(in_channels=in_channels, + out_channels=in_channels, + kernel_size=3, + add_bias=True, + downsample=False, + fused_scale=False, + use_wscale=use_wscale, + wscale_gain=wscale_gain, + activation_type='lrelu')) + tf_layer0_name = 'Conv0' + self.add_module( + f'layer{2 * block_idx + 1}', + ConvLayer(in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + add_bias=True, + downsample=True, + fused_scale=fused_scale, + use_wscale=use_wscale, + wscale_gain=wscale_gain, + activation_type='lrelu')) + tf_layer1_name = 'Conv1_down' if fused_scale else 'Conv1' + + # Convolution block for last resolution. + else: + self.mbstd = MiniBatchSTDLayer(groups=mbstd_groups, eps=eps) + self.add_module( + f'layer{2 * block_idx}', + ConvLayer( + in_channels=in_channels + 1, + out_channels=in_channels, + kernel_size=3, + add_bias=True, + downsample=False, + fused_scale=False, + use_wscale=use_wscale, + wscale_gain=wscale_gain, + activation_type='lrelu')) + tf_layer0_name = 'Conv' + self.add_module( + f'layer{2 * block_idx + 1}', + DenseLayer(in_channels=in_channels * res * res, + out_channels=out_channels, + add_bias=True, + use_wscale=use_wscale, + wscale_gain=wscale_gain, + activation_type='lrelu')) + tf_layer1_name = 'Dense0' + + self.pth_to_tf_var_mapping[f'layer{2 * block_idx}.weight'] = ( + f'{res}x{res}/{tf_layer0_name}/weight') + self.pth_to_tf_var_mapping[f'layer{2 * block_idx}.bias'] = ( + f'{res}x{res}/{tf_layer0_name}/bias') + self.pth_to_tf_var_mapping[f'layer{2 * block_idx + 1}.weight'] = ( + f'{res}x{res}/{tf_layer1_name}/weight') + self.pth_to_tf_var_mapping[f'layer{2 * block_idx + 1}.bias'] = ( + f'{res}x{res}/{tf_layer1_name}/bias') + + # Final dense layer. 
+        self.output = DenseLayer(in_channels=out_channels,
+                                 out_channels=1 + self.label_dim,
+                                 add_bias=True,
+                                 use_wscale=self.use_wscale,
+                                 wscale_gain=1.0,
+                                 activation_type='linear')
+        self.pth_to_tf_var_mapping['output.weight'] = (
+            f'{res}x{res}/Dense1/weight')
+        self.pth_to_tf_var_mapping['output.bias'] = (
+            f'{res}x{res}/Dense1/bias')
+
+    def get_nf(self, res):
+        """Gets number of feature maps according to the given resolution."""
+        return min(self.fmaps_base // res, self.fmaps_max)
+
+    def forward(self, image, lod=None):
+        expected_shape = (self.image_channels, self.resolution, self.resolution)
+        if image.ndim != 4 or image.shape[1:] != expected_shape:
+            raise ValueError(f'The input tensor should be with shape '
+                             f'[batch_size, channel, height, width], where '
+                             f'`channel` equals to {self.image_channels}, '
+                             f'`height`, `width` equal to {self.resolution}!\n'
+                             f'But `{image.shape}` is received!')
+
+        lod = self.lod.item() if lod is None else lod
+        if lod + self.init_res_log2 > self.final_res_log2:
+            raise ValueError(f'Maximum level-of-details (lod) is '
+                             f'{self.final_res_log2 - self.init_res_log2}, '
+                             f'but `{lod}` is received!')
+
+        for res_log2 in range(self.final_res_log2, self.init_res_log2 - 1, -1):
+            block_idx = current_lod = self.final_res_log2 - res_log2
+            if current_lod <= lod < current_lod + 1:
+                x = getattr(self, f'input{block_idx}')(image)
+            elif current_lod - 1 < lod < current_lod:
+                alpha = lod - np.floor(lod)
+                y = getattr(self, f'input{block_idx}')(image)
+                x = y * alpha + x * (1 - alpha)
+            if lod < current_lod + 1:
+                if res_log2 == self.init_res_log2:
+                    x = self.mbstd(x)
+                x = getattr(self, f'layer{2 * block_idx}')(x)
+                x = getattr(self, f'layer{2 * block_idx + 1}')(x)
+            if lod > current_lod:
+                image = F.avg_pool2d(
+                    image, kernel_size=2, stride=2, padding=0)
+        x = self.output(x)
+
+        return {'score': x}
+
+
+class MiniBatchSTDLayer(nn.Module):
+    """Implements the minibatch standard deviation layer."""
+
+    def __init__(self, groups, eps):
+        super().__init__()
+        self.groups = groups
+        self.eps = eps
+
+    def extra_repr(self):
+        return f'groups={self.groups}, epsilon={self.eps}'
+
+    def forward(self, x):
+        if self.groups <= 1:
+            return x
+
+        N, C, H, W = x.shape
+        G = min(self.groups, N)  # Number of groups.
+
+        y = x.reshape(G, -1, C, H, W)            # [GnCHW]
+        y = y - y.mean(dim=0)                    # [GnCHW]
+        y = y.square().mean(dim=0)               # [nCHW]
+        y = (y + self.eps).sqrt()                # [nCHW]
+        y = y.mean(dim=(1, 2, 3), keepdim=True)  # [n111]
+        y = y.repeat(G, 1, H, W)                 # [N1HW]
+        x = torch.cat([x, y], dim=1)             # [N(C+1)HW]
+
+        return x
+
+
+class DownsamplingLayer(nn.Module):
+    """Implements the downsampling layer.
+
+    Basically, this layer can be used to downsample feature maps with average
+    pooling.
+    """
+
+    def __init__(self, scale_factor):
+        super().__init__()
+        self.scale_factor = scale_factor
+
+    def extra_repr(self):
+        return f'factor={self.scale_factor}'
+
+    def forward(self, x):
+        if self.scale_factor <= 1:
+            return x
+        return F.avg_pool2d(x,
+                            kernel_size=self.scale_factor,
+                            stride=self.scale_factor,
+                            padding=0)
+
+
+class ConvLayer(nn.Module):
+    """Implements the convolutional layer.
+
+    Basically, this layer executes convolution, activation, and downsampling (if
+    needed) in sequence.
+    """
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 kernel_size,
+                 add_bias,
+                 downsample,
+                 fused_scale,
+                 use_wscale,
+                 wscale_gain,
+                 activation_type):
+        """Initializes with layer settings.
+
+        Args:
+            in_channels: Number of channels of the input tensor.
+            out_channels: Number of channels of the output tensor.
+            kernel_size: Size of the convolutional kernels.
+            add_bias: Whether to add bias onto the convolutional result.
+            downsample: Whether to downsample the result after convolution.
+            fused_scale: Whether to fuse `conv2d` and `downsample` together,
+                resulting in `conv2d` with strides.
+            use_wscale: Whether to use weight scaling.
+            wscale_gain: Gain factor for weight scaling.
+            activation_type: Type of activation.
+        """
+        super().__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.kernel_size = kernel_size
+        self.add_bias = add_bias
+        self.downsample = downsample
+        self.fused_scale = fused_scale
+        self.use_wscale = use_wscale
+        self.wscale_gain = wscale_gain
+        self.activation_type = activation_type
+
+        if downsample and not fused_scale:
+            self.down = DownsamplingLayer(scale_factor=2)
+        else:
+            self.down = nn.Identity()
+
+        if downsample and fused_scale:
+            self.use_stride = True
+            self.stride = 2
+            self.padding = 1
+        else:
+            self.use_stride = False
+            self.stride = 1
+            self.padding = kernel_size // 2
+
+        weight_shape = (out_channels, in_channels, kernel_size, kernel_size)
+        fan_in = kernel_size * kernel_size * in_channels
+        wscale = wscale_gain / np.sqrt(fan_in)
+        if use_wscale:
+            self.weight = nn.Parameter(torch.randn(*weight_shape))
+            self.wscale = wscale
+        else:
+            self.weight = nn.Parameter(torch.randn(*weight_shape) * wscale)
+            self.wscale = 1.0
+
+        if add_bias:
+            self.bias = nn.Parameter(torch.zeros(out_channels))
+        else:
+            self.bias = None
+
+        assert activation_type in ['linear', 'relu', 'lrelu']
+
+    def extra_repr(self):
+        return (f'in_ch={self.in_channels}, '
+                f'out_ch={self.out_channels}, '
+                f'ksize={self.kernel_size}, '
+                f'wscale_gain={self.wscale_gain:.3f}, '
+                f'bias={self.add_bias}, '
+                f'downsample={self.downsample}, '
+                f'fused_scale={self.fused_scale}, '
+                f'act={self.activation_type}')
+
+    def forward(self, x):
+        weight = self.weight
+        if self.wscale != 1.0:
+            weight = weight * self.wscale
+
+        if self.use_stride:
+            weight = F.pad(weight, (1, 1, 1, 1, 0, 0, 0, 0), 'constant', 0.0)
+            weight = (weight[:, :, 1:, 1:] + weight[:, :, :-1, 1:] +
+                      weight[:, :, 1:, :-1] + weight[:, :, :-1, :-1]) * 0.25
+        x = F.conv2d(x,
+                     weight=weight,
+                     bias=self.bias,
+                     stride=self.stride,
+                     padding=self.padding)
+
+        if self.activation_type == 'linear':
+            pass
+        elif self.activation_type == 'relu':
+            x = F.relu(x, inplace=True)
+        elif self.activation_type == 'lrelu':
+            x = F.leaky_relu(x, negative_slope=0.2, inplace=True)
+        else:
+            raise NotImplementedError(f'Not implemented activation type '
+                                      f'`{self.activation_type}`!')
+        x = self.down(x)
+
+        return x
+
+
+class DenseLayer(nn.Module):
+    """Implements the dense layer."""
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 add_bias,
+                 use_wscale,
+                 wscale_gain,
+                 activation_type):
+        """Initializes with layer settings.
+
+        Args:
+            in_channels: Number of channels of the input tensor.
+            out_channels: Number of channels of the output tensor.
+            add_bias: Whether to add bias onto the fully-connected result.
+            use_wscale: Whether to use weight scaling.
+            wscale_gain: Gain factor for weight scaling.
+            activation_type: Type of activation.
+
+        Raises:
+            NotImplementedError: If the `activation_type` is not supported.
+ """ + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.add_bias = add_bias + self.use_wscale = use_wscale + self.wscale_gain = wscale_gain + self.activation_type = activation_type + + weight_shape = (out_channels, in_channels) + wscale = wscale_gain / np.sqrt(in_channels) + if use_wscale: + self.weight = nn.Parameter(torch.randn(*weight_shape)) + self.wscale = wscale + else: + self.weight = nn.Parameter(torch.randn(*weight_shape) * wscale) + self.wscale = 1.0 + + if add_bias: + self.bias = nn.Parameter(torch.zeros(out_channels)) + else: + self.bias = None + + assert activation_type in ['linear', 'relu', 'lrelu'] + + def forward(self, x): + if x.ndim != 2: + x = x.flatten(start_dim=1) + + weight = self.weight + if self.wscale != 1.0: + weight = weight * self.wscale + + x = F.linear(x, weight=weight, bias=self.bias) + + if self.activation_type == 'linear': + pass + elif self.activation_type == 'relu': + x = F.relu(x, inplace=True) + elif self.activation_type == 'lrelu': + x = F.leaky_relu(x, negative_slope=0.2, inplace=True) + else: + raise NotImplementedError(f'Not implemented activation type ' + f'`{self.activation_type}`!') + + return x + +# pylint: enable=missing-function-docstring diff --git a/models/pggan_generator.py b/models/pggan_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..771c5ee0e66304fcd21432a8d873e27b92ca1db5 --- /dev/null +++ b/models/pggan_generator.py @@ -0,0 +1,401 @@ +# python3.7 +"""Contains the implementation of generator described in PGGAN. + +Paper: https://arxiv.org/pdf/1710.10196.pdf + +Official TensorFlow implementation: +https://github.com/tkarras/progressive_growing_of_gans +""" + +import numpy as np + +import torch +import torch.nn as nn +import torch.nn.functional as F + +__all__ = ['PGGANGenerator'] + +# Resolutions allowed. +_RESOLUTIONS_ALLOWED = [8, 16, 32, 64, 128, 256, 512, 1024] + +# pylint: disable=missing-function-docstring + +class PGGANGenerator(nn.Module): + """Defines the generator network in PGGAN. + + NOTE: The synthesized images are with `RGB` channel order and pixel range + [-1, 1]. + + Settings for the network: + + (1) resolution: The resolution of the output image. + (2) init_res: The initial resolution to start with convolution. (default: 4) + (3) z_dim: Dimension of the input latent space, Z. (default: 512) + (4) image_channels: Number of channels of the output image. (default: 3) + (5) final_tanh: Whether to use `tanh` to control the final pixel range. + (default: False) + (6) label_dim: Dimension of the additional label for conditional generation. + In one-hot conditioning case, it is equal to the number of classes. If + set to 0, conditioning training will be disabled. (default: 0) + (7) fused_scale: Whether to fused `upsample` and `conv2d` together, + resulting in `conv2d_transpose`. (default: False) + (8) use_wscale: Whether to use weight scaling. (default: True) + (9) wscale_gain: The factor to control weight scaling. (default: sqrt(2.0)) + (10) fmaps_base: Factor to control number of feature maps for each layer. + (default: 16 << 10) + (11) fmaps_max: Maximum number of feature maps in each layer. (default: 512) + (12) eps: A small value to avoid divide overflow. 
(default: 1e-8) + """ + + def __init__(self, + resolution, + init_res=4, + z_dim=512, + image_channels=3, + final_tanh=False, + label_dim=0, + fused_scale=False, + use_wscale=True, + wscale_gain=np.sqrt(2.0), + fmaps_base=16 << 10, + fmaps_max=512, + eps=1e-8): + """Initializes with basic settings. + + Raises: + ValueError: If the `resolution` is not supported. + """ + super().__init__() + + if resolution not in _RESOLUTIONS_ALLOWED: + raise ValueError(f'Invalid resolution: `{resolution}`!\n' + f'Resolutions allowed: {_RESOLUTIONS_ALLOWED}.') + + self.init_res = init_res + self.init_res_log2 = int(np.log2(self.init_res)) + self.resolution = resolution + self.final_res_log2 = int(np.log2(self.resolution)) + self.z_dim = z_dim + self.image_channels = image_channels + self.final_tanh = final_tanh + self.label_dim = label_dim + self.fused_scale = fused_scale + self.use_wscale = use_wscale + self.wscale_gain = wscale_gain + self.fmaps_base = fmaps_base + self.fmaps_max = fmaps_max + self.eps = eps + + # Dimension of latent space, which is convenient for sampling. + self.latent_dim = (self.z_dim,) + + # Number of convolutional layers. + self.num_layers = (self.final_res_log2 - self.init_res_log2 + 1) * 2 + + # Level-of-details (used for progressive training). + self.register_buffer('lod', torch.zeros(())) + self.pth_to_tf_var_mapping = {'lod': 'lod'} + + for res_log2 in range(self.init_res_log2, self.final_res_log2 + 1): + res = 2 ** res_log2 + in_channels = self.get_nf(res // 2) + out_channels = self.get_nf(res) + block_idx = res_log2 - self.init_res_log2 + + # First convolution layer for each resolution. + if res == self.init_res: + self.add_module( + f'layer{2 * block_idx}', + ConvLayer(in_channels=z_dim + label_dim, + out_channels=out_channels, + kernel_size=init_res, + padding=init_res - 1, + add_bias=True, + upsample=False, + fused_scale=False, + use_wscale=use_wscale, + wscale_gain=wscale_gain, + activation_type='lrelu', + eps=eps)) + tf_layer_name = 'Dense' + else: + self.add_module( + f'layer{2 * block_idx}', + ConvLayer(in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + padding=1, + add_bias=True, + upsample=True, + fused_scale=fused_scale, + use_wscale=use_wscale, + wscale_gain=wscale_gain, + activation_type='lrelu', + eps=eps)) + tf_layer_name = 'Conv0_up' if fused_scale else 'Conv0' + self.pth_to_tf_var_mapping[f'layer{2 * block_idx}.weight'] = ( + f'{res}x{res}/{tf_layer_name}/weight') + self.pth_to_tf_var_mapping[f'layer{2 * block_idx}.bias'] = ( + f'{res}x{res}/{tf_layer_name}/bias') + + # Second convolution layer for each resolution. + self.add_module( + f'layer{2 * block_idx + 1}', + ConvLayer(in_channels=out_channels, + out_channels=out_channels, + kernel_size=3, + padding=1, + add_bias=True, + upsample=False, + fused_scale=False, + use_wscale=use_wscale, + wscale_gain=wscale_gain, + activation_type='lrelu', + eps=eps)) + tf_layer_name = 'Conv' if res == self.init_res else 'Conv1' + self.pth_to_tf_var_mapping[f'layer{2 * block_idx + 1}.weight'] = ( + f'{res}x{res}/{tf_layer_name}/weight') + self.pth_to_tf_var_mapping[f'layer{2 * block_idx + 1}.bias'] = ( + f'{res}x{res}/{tf_layer_name}/bias') + + # Output convolution layer for each resolution. 
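+            # These `output{block_idx}` modules are the ToRGB heads (one per
+            # resolution, cf. the `ToRGB_lod*` TF variable names below);
+            # `forward` blends the images from two adjacent heads according
+            # to `lod` during progressive training.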
+ self.add_module( + f'output{block_idx}', + ConvLayer(in_channels=out_channels, + out_channels=image_channels, + kernel_size=1, + padding=0, + add_bias=True, + upsample=False, + fused_scale=False, + use_wscale=use_wscale, + wscale_gain=1.0, + activation_type='linear', + eps=eps)) + self.pth_to_tf_var_mapping[f'output{block_idx}.weight'] = ( + f'ToRGB_lod{self.final_res_log2 - res_log2}/weight') + self.pth_to_tf_var_mapping[f'output{block_idx}.bias'] = ( + f'ToRGB_lod{self.final_res_log2 - res_log2}/bias') + + def get_nf(self, res): + """Gets number of feature maps according to the given resolution.""" + return min(self.fmaps_base // res, self.fmaps_max) + + def forward(self, z, label=None, lod=None): + if z.ndim != 2 or z.shape[1] != self.z_dim: + raise ValueError(f'Input latent code should be with shape ' + f'[batch_size, latent_dim], where ' + f'`latent_dim` equals to {self.z_dim}!\n' + f'But `{z.shape}` is received!') + z = self.layer0.pixel_norm(z) + if self.label_dim: + if label is None: + raise ValueError(f'Model requires an additional label ' + f'(with size {self.label_dim}) as input, ' + f'but no label is received!') + if label.ndim != 2 or label.shape != (z.shape[0], self.label_dim): + raise ValueError(f'Input label should be with shape ' + f'[batch_size, label_dim], where ' + f'`batch_size` equals to that of ' + f'latent codes ({z.shape[0]}) and ' + f'`label_dim` equals to {self.label_dim}!\n' + f'But `{label.shape}` is received!') + label = label.to(dtype=torch.float32) + z = torch.cat((z, label), dim=1) + + lod = self.lod.item() if lod is None else lod + if lod + self.init_res_log2 > self.final_res_log2: + raise ValueError(f'Maximum level-of-details (lod) is ' + f'{self.final_res_log2 - self.init_res_log2}, ' + f'but `{lod}` is received!') + + x = z.view(z.shape[0], self.z_dim + self.label_dim, 1, 1) + for res_log2 in range(self.init_res_log2, self.final_res_log2 + 1): + current_lod = self.final_res_log2 - res_log2 + block_idx = res_log2 - self.init_res_log2 + if lod < current_lod + 1: + x = getattr(self, f'layer{2 * block_idx}')(x) + x = getattr(self, f'layer{2 * block_idx + 1}')(x) + if current_lod - 1 < lod <= current_lod: + image = getattr(self, f'output{block_idx}')(x) + elif current_lod < lod < current_lod + 1: + alpha = np.ceil(lod) - lod + temp = getattr(self, f'output{block_idx}')(x) + image = F.interpolate(image, scale_factor=2, mode='nearest') + image = temp * alpha + image * (1 - alpha) + elif lod >= current_lod + 1: + image = F.interpolate(image, scale_factor=2, mode='nearest') + if self.final_tanh: + image = torch.tanh(image) + + results = { + 'z': z, + 'label': label, + 'image': image, + } + return results + + +class PixelNormLayer(nn.Module): + """Implements pixel-wise feature vector normalization layer.""" + + def __init__(self, dim, eps): + super().__init__() + self.dim = dim + self.eps = eps + + def extra_repr(self): + return f'dim={self.dim}, epsilon={self.eps}' + + def forward(self, x): + scale = (x.square().mean(dim=self.dim, keepdim=True) + self.eps).rsqrt() + return x * scale + + +class UpsamplingLayer(nn.Module): + """Implements the upsampling layer. + + Basically, this layer can be used to upsample feature maps with nearest + neighbor interpolation. 
+ """ + + def __init__(self, scale_factor): + super().__init__() + self.scale_factor = scale_factor + + def extra_repr(self): + return f'factor={self.scale_factor}' + + def forward(self, x): + if self.scale_factor <= 1: + return x + return F.interpolate(x, scale_factor=self.scale_factor, mode='nearest') + + +class ConvLayer(nn.Module): + """Implements the convolutional layer. + + Basically, this layer executes pixel-wise normalization, upsampling (if + needed), convolution, and activation in sequence. + """ + + def __init__(self, + in_channels, + out_channels, + kernel_size, + padding, + add_bias, + upsample, + fused_scale, + use_wscale, + wscale_gain, + activation_type, + eps): + """Initializes with layer settings. + + Args: + in_channels: Number of channels of the input tensor. + out_channels: Number of channels of the output tensor. + kernel_size: Size of the convolutional kernels. + padding: Padding used in convolution. + add_bias: Whether to add bias onto the convolutional result. + upsample: Whether to upsample the input tensor before convolution. + fused_scale: Whether to fused `upsample` and `conv2d` together, + resulting in `conv2d_transpose`. + use_wscale: Whether to use weight scaling. + wscale_gain: Gain factor for weight scaling. + activation_type: Type of activation. + eps: A small value to avoid divide overflow. + """ + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.padding = padding + self.add_bias = add_bias + self.upsample = upsample + self.fused_scale = fused_scale + self.use_wscale = use_wscale + self.wscale_gain = wscale_gain + self.activation_type = activation_type + self.eps = eps + + self.pixel_norm = PixelNormLayer(dim=1, eps=eps) + + if upsample and not fused_scale: + self.up = UpsamplingLayer(scale_factor=2) + else: + self.up = nn.Identity() + + if upsample and fused_scale: + self.use_conv2d_transpose = True + weight_shape = (in_channels, out_channels, kernel_size, kernel_size) + self.stride = 2 + self.padding = 1 + else: + self.use_conv2d_transpose = False + weight_shape = (out_channels, in_channels, kernel_size, kernel_size) + self.stride = 1 + + fan_in = kernel_size * kernel_size * in_channels + wscale = wscale_gain / np.sqrt(fan_in) + if use_wscale: + self.weight = nn.Parameter(torch.randn(*weight_shape)) + self.wscale = wscale + else: + self.weight = nn.Parameter(torch.randn(*weight_shape) * wscale) + self.wscale = 1.0 + + if add_bias: + self.bias = nn.Parameter(torch.zeros(out_channels)) + else: + self.bias = None + + assert activation_type in ['linear', 'relu', 'lrelu'] + + def extra_repr(self): + return (f'in_ch={self.in_channels}, ' + f'out_ch={self.out_channels}, ' + f'ksize={self.kernel_size}, ' + f'padding={self.padding}, ' + f'wscale_gain={self.wscale_gain:.3f}, ' + f'bias={self.add_bias}, ' + f'upsample={self.scale_factor}, ' + f'fused_scale={self.fused_scale}, ' + f'act={self.activation_type}') + + def forward(self, x): + x = self.pixel_norm(x) + x = self.up(x) + weight = self.weight + if self.wscale != 1.0: + weight = weight * self.wscale + if self.use_conv2d_transpose: + weight = F.pad(weight, (1, 1, 1, 1, 0, 0, 0, 0), 'constant', 0.0) + weight = (weight[:, :, 1:, 1:] + weight[:, :, :-1, 1:] + + weight[:, :, 1:, :-1] + weight[:, :, :-1, :-1]) + x = F.conv_transpose2d(x, + weight=weight, + bias=self.bias, + stride=self.stride, + padding=self.padding) + else: + x = F.conv2d(x, + weight=weight, + bias=self.bias, + stride=self.stride, + padding=self.padding) + + if 
self.activation_type == 'linear': + pass + elif self.activation_type == 'relu': + x = F.relu(x, inplace=True) + elif self.activation_type == 'lrelu': + x = F.leaky_relu(x, negative_slope=0.2, inplace=True) + else: + raise NotImplementedError(f'Not implemented activation type ' + f'`{self.activation_type}`!') + + return x + +# pylint: enable=missing-function-docstring diff --git a/models/stylegan2_discriminator.py b/models/stylegan2_discriminator.py new file mode 100644 index 0000000000000000000000000000000000000000..1802d44b68d0801290dcd691f09f605c3cfc9cbd --- /dev/null +++ b/models/stylegan2_discriminator.py @@ -0,0 +1,729 @@ +# python3.7 +"""Contains the implementation of discriminator described in StyleGAN2. + +Compared to that of StyleGAN, the discriminator in StyleGAN2 mainly adds skip +connections, increases model size and disables progressive growth. This script +ONLY supports config F in the original paper. + +Paper: https://arxiv.org/pdf/1912.04958.pdf + +Official TensorFlow implementation: https://github.com/NVlabs/stylegan2 +""" + +import numpy as np + +import torch +import torch.nn as nn + +from third_party.stylegan2_official_ops import bias_act +from third_party.stylegan2_official_ops import upfirdn2d +from third_party.stylegan2_official_ops import conv2d_gradfix + +__all__ = ['StyleGAN2Discriminator'] + +# Resolutions allowed. +_RESOLUTIONS_ALLOWED = [8, 16, 32, 64, 128, 256, 512, 1024] + +# Architectures allowed. +_ARCHITECTURES_ALLOWED = ['resnet', 'skip', 'origin'] + +# pylint: disable=missing-function-docstring + +class StyleGAN2Discriminator(nn.Module): + """Defines the discriminator network in StyleGAN2. + + NOTE: The discriminator takes images with `RGB` channel order and pixel + range [-1, 1] as inputs. + + Settings for the backbone: + + (1) resolution: The resolution of the input image. (default: -1) + (2) init_res: Smallest resolution of the convolutional backbone. + (default: 4) + (3) image_channels: Number of channels of the input image. (default: 3) + (4) architecture: Type of architecture. Support `origin`, `skip`, and + `resnet`. (default: `resnet`) + (5) use_wscale: Whether to use weight scaling. (default: True) + (6) wscale_gain: The factor to control weight scaling. (default: 1.0) + (7) lr_mul: Learning rate multiplier for backbone. (default: 1.0) + (8) mbstd_groups: Group size for the minibatch standard deviation layer. + `0` means disable. (default: 4) + (9) mbstd_channels: Number of new channels (appended to the original feature + map) after the minibatch standard deviation layer. (default: 1) + (10) fmaps_base: Factor to control number of feature maps for each layer. + (default: 32 << 10) + (11) fmaps_max: Maximum number of feature maps in each layer. (default: 512) + (12) filter_kernel: Kernel used for filtering (e.g., downsampling). + (default: (1, 3, 3, 1)) + (13) conv_clamp: A threshold to clamp the output of convolution layers to + avoid overflow under FP16 training. (default: None) + (14) eps: A small value to avoid divide overflow. (default: 1e-8) + + Settings for conditional model: + + (1) label_dim: Dimension of the additional label for conditional generation. + In one-hot conditioning case, it is equal to the number of classes. If + set to 0, conditioning training will be disabled. (default: 0) + (2) embedding_dim: Dimension of the embedding space, if needed. + (default: 512) + (3) embedding_bias: Whether to add bias to embedding learning. + (default: True) + (4) embedding_use_wscale: Whether to use weight scaling for embedding + learning. 
(default: True) + (5) embedding_lr_mul: Learning rate multiplier for the embedding learning. + (default: 1.0) + (6) normalize_embedding: Whether to normalize the embedding. (default: True) + (7) mapping_layers: Number of layers of the additional mapping network after + embedding. (default: 0) + (8) mapping_fmaps: Number of hidden channels of the additional mapping + network after embedding. (default: 512) + (9) mapping_use_wscale: Whether to use weight scaling for the additional + mapping network. (default: True) + (10) mapping_lr_mul: Learning rate multiplier for the additional mapping + network after embedding. (default: 0.1) + + Runtime settings: + + (1) fp16_res: Layers at resolution higher than (or equal to) this field will + use `float16` precision for computation. This is merely used for + acceleration. If set as `None`, all layers will use `float32` by + default. (default: None) + (2) impl: Implementation mode of some particular ops, e.g., `filtering`, + `bias_act`, etc. `cuda` means using the official CUDA implementation + from StyleGAN2, while `ref` means using the native PyTorch ops. + (default: `cuda`) + """ + + def __init__(self, + # Settings for backbone. + resolution=-1, + init_res=4, + image_channels=3, + architecture='resnet', + use_wscale=True, + wscale_gain=1.0, + lr_mul=1.0, + mbstd_groups=4, + mbstd_channels=1, + fmaps_base=32 << 10, + fmaps_max=512, + filter_kernel=(1, 3, 3, 1), + conv_clamp=None, + eps=1e-8, + # Settings for conditional model. + label_dim=0, + embedding_dim=512, + embedding_bias=True, + embedding_use_wscale=True, + embedding_lr_mul=1.0, + normalize_embedding=True, + mapping_layers=0, + mapping_fmaps=512, + mapping_use_wscale=True, + mapping_lr_mul=0.1): + """Initializes with basic settings. + + Raises: + ValueError: If the `resolution` is not supported, or `architecture` + is not supported. + """ + super().__init__() + + if resolution not in _RESOLUTIONS_ALLOWED: + raise ValueError(f'Invalid resolution: `{resolution}`!\n' + f'Resolutions allowed: {_RESOLUTIONS_ALLOWED}.') + architecture = architecture.lower() + if architecture not in _ARCHITECTURES_ALLOWED: + raise ValueError(f'Invalid architecture: `{architecture}`!\n' + f'Architectures allowed: ' + f'{_ARCHITECTURES_ALLOWED}.') + + self.init_res = init_res + self.init_res_log2 = int(np.log2(init_res)) + self.resolution = resolution + self.final_res_log2 = int(np.log2(resolution)) + self.image_channels = image_channels + self.architecture = architecture + self.use_wscale = use_wscale + self.wscale_gain = wscale_gain + self.lr_mul = lr_mul + self.mbstd_groups = mbstd_groups + self.mbstd_channels = mbstd_channels + self.fmaps_base = fmaps_base + self.fmaps_max = fmaps_max + self.filter_kernel = filter_kernel + self.conv_clamp = conv_clamp + self.eps = eps + + self.label_dim = label_dim + self.embedding_dim = embedding_dim + self.embedding_bias = embedding_bias + self.embedding_use_wscale = embedding_use_wscale + self.embedding_lr_mul = embedding_lr_mul + self.normalize_embedding = normalize_embedding + self.mapping_layers = mapping_layers + self.mapping_fmaps = mapping_fmaps + self.mapping_use_wscale = mapping_use_wscale + self.mapping_lr_mul = mapping_lr_mul + + self.pth_to_tf_var_mapping = {} + + # Embedding for conditional discrimination. 
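+        # When both `label_dim` and `embedding_dim` are positive, the label is
+        # embedded (and optionally pixel-normalized and refined by an extra
+        # mapping network) before being used to condition the final score, in
+        # the spirit of projection discriminators.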
+ self.use_embedding = label_dim > 0 and embedding_dim > 0 + if self.use_embedding: + self.embedding = DenseLayer(in_channels=label_dim, + out_channels=embedding_dim, + add_bias=embedding_bias, + init_bias=0.0, + use_wscale=embedding_use_wscale, + wscale_gain=wscale_gain, + lr_mul=embedding_lr_mul, + activation_type='linear') + self.pth_to_tf_var_mapping['embedding.weight'] = 'LabelEmbed/weight' + if self.embedding_bias: + self.pth_to_tf_var_mapping['embedding.bias'] = 'LabelEmbed/bias' + + if self.normalize_embedding: + self.norm = PixelNormLayer(dim=1, eps=eps) + + for i in range(mapping_layers): + in_channels = (embedding_dim if i == 0 else mapping_fmaps) + out_channels = (embedding_dim if i == (mapping_layers - 1) else + mapping_fmaps) + layer_name = f'mapping{i}' + self.add_module(layer_name, + DenseLayer(in_channels=in_channels, + out_channels=out_channels, + add_bias=True, + init_bias=0.0, + use_wscale=mapping_use_wscale, + wscale_gain=wscale_gain, + lr_mul=mapping_lr_mul, + activation_type='lrelu')) + self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = ( + f'Mapping{i}/weight') + self.pth_to_tf_var_mapping[f'{layer_name}.bias'] = ( + f'Mapping{i}/bias') + + # Convolutional backbone. + for res_log2 in range(self.final_res_log2, self.init_res_log2 - 1, -1): + res = 2 ** res_log2 + in_channels = self.get_nf(res) + out_channels = self.get_nf(res // 2) + block_idx = self.final_res_log2 - res_log2 + + # Input convolution layer for each resolution (if needed). + if res_log2 == self.final_res_log2 or self.architecture == 'skip': + layer_name = f'input{block_idx}' + self.add_module(layer_name, + ConvLayer(in_channels=image_channels, + out_channels=in_channels, + kernel_size=1, + add_bias=True, + scale_factor=1, + filter_kernel=None, + use_wscale=use_wscale, + wscale_gain=wscale_gain, + lr_mul=lr_mul, + activation_type='lrelu', + conv_clamp=conv_clamp)) + self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = ( + f'{res}x{res}/FromRGB/weight') + self.pth_to_tf_var_mapping[f'{layer_name}.bias'] = ( + f'{res}x{res}/FromRGB/bias') + + # Convolution block for each resolution (except the last one). + if res != self.init_res: + # First layer (kernel 3x3) without downsampling. + layer_name = f'layer{2 * block_idx}' + self.add_module(layer_name, + ConvLayer(in_channels=in_channels, + out_channels=in_channels, + kernel_size=3, + add_bias=True, + scale_factor=1, + filter_kernel=None, + use_wscale=use_wscale, + wscale_gain=wscale_gain, + lr_mul=lr_mul, + activation_type='lrelu', + conv_clamp=conv_clamp)) + self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = ( + f'{res}x{res}/Conv0/weight') + self.pth_to_tf_var_mapping[f'{layer_name}.bias'] = ( + f'{res}x{res}/Conv0/bias') + + # Second layer (kernel 3x3) with downsampling + layer_name = f'layer{2 * block_idx + 1}' + self.add_module(layer_name, + ConvLayer(in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + add_bias=True, + scale_factor=2, + filter_kernel=filter_kernel, + use_wscale=use_wscale, + wscale_gain=wscale_gain, + lr_mul=lr_mul, + activation_type='lrelu', + conv_clamp=conv_clamp)) + self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = ( + f'{res}x{res}/Conv1_down/weight') + self.pth_to_tf_var_mapping[f'{layer_name}.bias'] = ( + f'{res}x{res}/Conv1_down/bias') + + # Residual branch (kernel 1x1) with downsampling, without bias, + # with linear activation. 
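The loop above sizes every block via `get_nf` (defined later in this class). A standalone check of that schedule under the documented defaults, `fmaps_base=32 << 10` and `fmaps_max=512`:

```python
fmaps_base, fmaps_max = 32 << 10, 512

def get_nf(res):
    # Same rule as StyleGAN2Discriminator.get_nf: width halves as resolution
    # doubles, capped at fmaps_max for the coarse resolutions.
    return min(fmaps_base // res, fmaps_max)

print({2 ** k: get_nf(2 ** k) for k in range(2, 11)})
# {4: 512, 8: 512, 16: 512, 32: 512, 64: 512, 128: 256, 256: 128, 512: 64, 1024: 32}
```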
+ if self.architecture == 'resnet': + layer_name = f'residual{block_idx}' + self.add_module(layer_name, + ConvLayer(in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + add_bias=False, + scale_factor=2, + filter_kernel=filter_kernel, + use_wscale=use_wscale, + wscale_gain=wscale_gain, + lr_mul=lr_mul, + activation_type='linear', + conv_clamp=None)) + self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = ( + f'{res}x{res}/Skip/weight') + + # Convolution block for last resolution. + else: + self.mbstd = MiniBatchSTDLayer( + groups=mbstd_groups, new_channels=mbstd_channels, eps=eps) + + # First layer (kernel 3x3) without downsampling. + layer_name = f'layer{2 * block_idx}' + self.add_module( + layer_name, + ConvLayer(in_channels=in_channels + mbstd_channels, + out_channels=in_channels, + kernel_size=3, + add_bias=True, + scale_factor=1, + filter_kernel=None, + use_wscale=use_wscale, + wscale_gain=wscale_gain, + lr_mul=lr_mul, + activation_type='lrelu', + conv_clamp=conv_clamp)) + self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = ( + f'{res}x{res}/Conv/weight') + self.pth_to_tf_var_mapping[f'{layer_name}.bias'] = ( + f'{res}x{res}/Conv/bias') + + # Second layer, as a fully-connected layer. + layer_name = f'layer{2 * block_idx + 1}' + self.add_module(layer_name, + DenseLayer(in_channels=in_channels * res * res, + out_channels=in_channels, + add_bias=True, + init_bias=0.0, + use_wscale=use_wscale, + wscale_gain=wscale_gain, + lr_mul=lr_mul, + activation_type='lrelu')) + self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = ( + f'{res}x{res}/Dense0/weight') + self.pth_to_tf_var_mapping[f'{layer_name}.bias'] = ( + f'{res}x{res}/Dense0/bias') + + # Final dense layer to output score. + self.output = DenseLayer(in_channels=in_channels, + out_channels=(embedding_dim + if self.use_embedding + else max(label_dim, 1)), + add_bias=True, + init_bias=0.0, + use_wscale=use_wscale, + wscale_gain=wscale_gain, + lr_mul=lr_mul, + activation_type='linear') + self.pth_to_tf_var_mapping['output.weight'] = 'Output/weight' + self.pth_to_tf_var_mapping['output.bias'] = 'Output/bias' + + # Used for downsampling input image for `skip` architecture. + if self.architecture == 'skip': + self.register_buffer( + 'filter', upfirdn2d.setup_filter(filter_kernel)) + + def get_nf(self, res): + """Gets number of feature maps according to the given resolution.""" + return min(self.fmaps_base // res, self.fmaps_max) + + def forward(self, image, label=None, fp16_res=None, impl='cuda'): + # Check shape. 
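The residual branch above is later summed with the main branch under a runtime gain of `np.sqrt(0.5)` (see `forward`). The gain preserves variance: adding two roughly independent unit-variance activations doubles the variance, and `sqrt(0.5)` undoes that. A quick numeric check:

```python
import torch

a, b = torch.randn(1_000_000), torch.randn(1_000_000)
print((a + b).std())                 # ~1.414: the sum doubles the variance
print(((a + b) * 0.5 ** 0.5).std())  # ~1.0: sqrt(0.5) restores unit variance
```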
+ expected_shape = (self.image_channels, self.resolution, self.resolution) + if image.ndim != 4 or image.shape[1:] != expected_shape: + raise ValueError(f'The input tensor should be with shape ' + f'[batch_size, channel, height, width], where ' + f'`channel` equals to {self.image_channels}, ' + f'`height`, `width` equal to {self.resolution}!\n' + f'But `{image.shape}` is received!') + if self.label_dim > 0: + if label is None: + raise ValueError(f'Model requires an additional label ' + f'(with dimension {self.label_dim}) as input, ' + f'but no label is received!') + batch_size = image.shape[0] + if label.ndim != 2 or label.shape != (batch_size, self.label_dim): + raise ValueError(f'Input label should be with shape ' + f'[batch_size, label_dim], where ' + f'`batch_size` equals to that of ' + f'images ({image.shape[0]}) and ' + f'`label_dim` equals to {self.label_dim}!\n' + f'But `{label.shape}` is received!') + label = label.to(dtype=torch.float32) + if self.use_embedding: + embed = self.embedding(label, impl=impl) + if self.normalize_embedding: + embed = self.norm(embed) + for i in range(self.mapping_layers): + embed = getattr(self, f'mapping{i}')(embed, impl=impl) + + # Cast to `torch.float16` if needed. + if fp16_res is not None and self.resolution >= fp16_res: + image = image.to(torch.float16) + + x = self.input0(image, impl=impl) + + for res_log2 in range(self.final_res_log2, self.init_res_log2, -1): + res = 2 ** res_log2 + # Cast to `torch.float16` if needed. + if fp16_res is not None and res >= fp16_res: + x = x.to(torch.float16) + else: + x = x.to(torch.float32) + + idx = self.final_res_log2 - res_log2 # Block index + + if self.architecture == 'skip' and idx > 0: + image = upfirdn2d.downsample2d(image, self.filter, impl=impl) + # Cast to `torch.float16` if needed. + if fp16_res is not None and res >= fp16_res: + image = image.to(torch.float16) + else: + image = image.to(torch.float32) + y = getattr(self, f'input{idx}')(image, impl=impl) + x = x + y + + if self.architecture == 'resnet': + residual = getattr(self, f'residual{idx}')( + x, runtime_gain=np.sqrt(0.5), impl=impl) + x = getattr(self, f'layer{2 * idx}')(x, impl=impl) + x = getattr(self, f'layer{2 * idx + 1}')( + x, runtime_gain=np.sqrt(0.5), impl=impl) + x = x + residual + else: + x = getattr(self, f'layer{2 * idx}')(x, impl=impl) + x = getattr(self, f'layer{2 * idx + 1}')(x, impl=impl) + + # Final output. + idx += 1 + if fp16_res is not None: # Always use FP32 for the last block. + x = x.to(torch.float32) + if self.architecture == 'skip': + image = upfirdn2d.downsample2d(image, self.filter, impl=impl) + if fp16_res is not None: # Always use FP32 for the last block. 
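The dtype casts in `forward` implement the `fp16_res` policy from the class docstring: blocks at resolution greater than or equal to `fp16_res` run in `float16`, while everything else, including the final block, stays in `float32`. An illustrative trace:

```python
# Per-block precision for a 256-px discriminator with fp16_res=128 (illustration only).
fp16_res = 128
for res in [256, 128, 64, 32, 16, 8]:
    print(res, 'float16' if res >= fp16_res else 'float32')
# 256 and 128 run in float16; 64 and below (and always the last block) use float32.
```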
+ image = image.to(torch.float32) + y = getattr(self, f'input{idx}')(image, impl=impl) + x = x + y + x = self.mbstd(x) + x = getattr(self, f'layer{2 * idx}')(x, impl=impl) + x = getattr(self, f'layer{2 * idx + 1}')(x, impl=impl) + x = self.output(x, impl=impl) + + if self.use_embedding: + x = (x * embed).sum(dim=1, keepdim=True) + x = x / np.sqrt(self.embedding_dim) + elif self.label_dim > 0: + x = (x * label).sum(dim=1, keepdim=True) + + results = { + 'score': x, + 'label': label + } + if self.use_embedding: + results['embedding'] = embed + return results + + +class PixelNormLayer(nn.Module): + """Implements pixel-wise feature vector normalization layer.""" + + def __init__(self, dim, eps): + super().__init__() + self.dim = dim + self.eps = eps + + def extra_repr(self): + return f'dim={self.dim}, epsilon={self.eps}' + + def forward(self, x): + scale = (x.square().mean(dim=self.dim, keepdim=True) + self.eps).rsqrt() + return x * scale + + +class MiniBatchSTDLayer(nn.Module): + """Implements the minibatch standard deviation layer.""" + + def __init__(self, groups, new_channels, eps): + super().__init__() + self.groups = groups + self.new_channels = new_channels + self.eps = eps + + def extra_repr(self): + return (f'groups={self.groups}, ' + f'new_channels={self.new_channels}, ' + f'epsilon={self.eps}') + + def forward(self, x): + if self.groups <= 1 or self.new_channels < 1: + return x + + dtype = x.dtype + + N, C, H, W = x.shape + G = min(self.groups, N) # Number of groups. + nC = self.new_channels # Number of channel groups. + c = C // nC # Channels per channel group. + + y = x.reshape(G, -1, nC, c, H, W) # [GnFcHW] + y = y - y.mean(dim=0) # [GnFcHW] + y = y.square().mean(dim=0) # [nFcHW] + y = (y + self.eps).sqrt() # [nFcHW] + y = y.mean(dim=(2, 3, 4)) # [nF] + y = y.reshape(-1, nC, 1, 1) # [nF11] + y = y.repeat(G, 1, H, W) # [NFHW] + x = torch.cat((x, y), dim=1) # [N(C+F)HW] + + assert x.dtype == dtype + return x + + +class ConvLayer(nn.Module): + """Implements the convolutional layer. + + If downsampling is needed (i.e., `scale_factor = 2`), the feature map will + be filtered with `filter_kernel` first. + """ + + def __init__(self, + in_channels, + out_channels, + kernel_size, + add_bias, + scale_factor, + filter_kernel, + use_wscale, + wscale_gain, + lr_mul, + activation_type, + conv_clamp): + """Initializes with layer settings. + + Args: + in_channels: Number of channels of the input tensor. + out_channels: Number of channels of the output tensor. + kernel_size: Size of the convolutional kernels. + add_bias: Whether to add bias onto the convolutional result. + scale_factor: Scale factor for downsampling. `1` means skip + downsampling. + filter_kernel: Kernel used for filtering. + use_wscale: Whether to use weight scaling. + wscale_gain: Gain factor for weight scaling. + lr_mul: Learning multiplier for both weight and bias. + activation_type: Type of activation. + conv_clamp: A threshold to clamp the output of convolution layers to + avoid overflow under FP16 training. 
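`MiniBatchSTDLayer` above appends cross-sample standard-deviation statistics so the discriminator can penalize low-diversity batches. A toy shape check (assuming the class is importable from this module):

```python
import torch
from models.stylegan2_discriminator import MiniBatchSTDLayer

layer = MiniBatchSTDLayer(groups=4, new_channels=1, eps=1e-8)
x = torch.randn(8, 16, 4, 4)
y = layer(x)
print(y.shape)                    # torch.Size([8, 17, 4, 4]): one stddev channel added
print(torch.equal(y[:, :16], x))  # True: the original features pass through unchanged
```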
+ """ + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.add_bias = add_bias + self.scale_factor = scale_factor + self.filter_kernel = filter_kernel + self.use_wscale = use_wscale + self.wscale_gain = wscale_gain + self.lr_mul = lr_mul + self.activation_type = activation_type + self.conv_clamp = conv_clamp + + weight_shape = (out_channels, in_channels, kernel_size, kernel_size) + fan_in = kernel_size * kernel_size * in_channels + wscale = wscale_gain / np.sqrt(fan_in) + if use_wscale: + self.weight = nn.Parameter(torch.randn(*weight_shape) / lr_mul) + self.wscale = wscale * lr_mul + else: + self.weight = nn.Parameter( + torch.randn(*weight_shape) * wscale / lr_mul) + self.wscale = lr_mul + + if add_bias: + self.bias = nn.Parameter(torch.zeros(out_channels)) + self.bscale = lr_mul + else: + self.bias = None + self.act_gain = bias_act.activation_funcs[activation_type].def_gain + + if scale_factor > 1: + assert filter_kernel is not None + self.register_buffer( + 'filter', upfirdn2d.setup_filter(filter_kernel)) + fh, fw = self.filter.shape + self.filter_padding = ( + kernel_size // 2 + (fw - scale_factor + 1) // 2, + kernel_size // 2 + (fw - scale_factor) // 2, + kernel_size // 2 + (fh - scale_factor + 1) // 2, + kernel_size // 2 + (fh - scale_factor) // 2) + + def extra_repr(self): + return (f'in_ch={self.in_channels}, ' + f'out_ch={self.out_channels}, ' + f'ksize={self.kernel_size}, ' + f'wscale_gain={self.wscale_gain:.3f}, ' + f'bias={self.add_bias}, ' + f'lr_mul={self.lr_mul:.3f}, ' + f'downsample={self.scale_factor}, ' + f'downsample_filter={self.filter_kernel}, ' + f'act={self.activation_type}, ' + f'clamp={self.conv_clamp}') + + def forward(self, x, runtime_gain=1.0, impl='cuda'): + dtype = x.dtype + + weight = self.weight + if self.wscale != 1.0: + weight = weight * self.wscale + bias = None + if self.bias is not None: + bias = self.bias.to(dtype) + if self.bscale != 1.0: + bias = bias * self.bscale + + if self.scale_factor == 1: # Native convolution without downsampling. + padding = self.kernel_size // 2 + x = conv2d_gradfix.conv2d( + x, weight.to(dtype), stride=1, padding=padding, impl=impl) + else: # Convolution with downsampling. + down = self.scale_factor + f = self.filter + padding = self.filter_padding + # When kernel size = 1, use filtering function for downsampling. + if self.kernel_size == 1: + x = upfirdn2d.upfirdn2d( + x, f, down=down, padding=padding, impl=impl) + x = conv2d_gradfix.conv2d( + x, weight.to(dtype), stride=1, padding=0, impl=impl) + # When kernel size != 1, use stride convolution for downsampling. + else: + x = upfirdn2d.upfirdn2d( + x, f, down=1, padding=padding, impl=impl) + x = conv2d_gradfix.conv2d( + x, weight.to(dtype), stride=down, padding=0, impl=impl) + + act_gain = self.act_gain * runtime_gain + act_clamp = None + if self.conv_clamp is not None: + act_clamp = self.conv_clamp * runtime_gain + x = bias_act.bias_act(x, bias, + act=self.activation_type, + gain=act_gain, + clamp=act_clamp, + impl=impl) + + assert x.dtype == dtype + return x + + +class DenseLayer(nn.Module): + """Implements the dense layer.""" + + def __init__(self, + in_channels, + out_channels, + add_bias, + init_bias, + use_wscale, + wscale_gain, + lr_mul, + activation_type): + """Initializes with layer settings. + + Args: + in_channels: Number of channels of the input tensor. + out_channels: Number of channels of the output tensor. + add_bias: Whether to add bias onto the fully-connected result. 
+ init_bias: The initial bias value before training. + use_wscale: Whether to use weight scaling. + wscale_gain: Gain factor for weight scaling. + lr_mul: Learning multiplier for both weight and bias. + activation_type: Type of activation. + """ + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.add_bias = add_bias + self.init_bias = init_bias + self.use_wscale = use_wscale + self.wscale_gain = wscale_gain + self.lr_mul = lr_mul + self.activation_type = activation_type + + weight_shape = (out_channels, in_channels) + wscale = wscale_gain / np.sqrt(in_channels) + if use_wscale: + self.weight = nn.Parameter(torch.randn(*weight_shape) / lr_mul) + self.wscale = wscale * lr_mul + else: + self.weight = nn.Parameter( + torch.randn(*weight_shape) * wscale / lr_mul) + self.wscale = lr_mul + + if add_bias: + init_bias = np.float32(init_bias) / lr_mul + self.bias = nn.Parameter(torch.full([out_channels], init_bias)) + self.bscale = lr_mul + else: + self.bias = None + + def extra_repr(self): + return (f'in_ch={self.in_channels}, ' + f'out_ch={self.out_channels}, ' + f'wscale_gain={self.wscale_gain:.3f}, ' + f'bias={self.add_bias}, ' + f'init_bias={self.init_bias}, ' + f'lr_mul={self.lr_mul:.3f}, ' + f'act={self.activation_type}') + + def forward(self, x, impl='cuda'): + dtype = x.dtype + + if x.ndim != 2: + x = x.flatten(start_dim=1) + + weight = self.weight.to(dtype) * self.wscale + bias = None + if self.bias is not None: + bias = self.bias.to(dtype) + if self.bscale != 1.0: + bias = bias * self.bscale + + # Fast pass for linear activation. + if self.activation_type == 'linear' and bias is not None: + x = torch.addmm(bias.unsqueeze(0), x, weight.t()) + else: + x = x.matmul(weight.t()) + x = bias_act.bias_act(x, bias, act=self.activation_type, impl=impl) + + assert x.dtype == dtype + return x + +# pylint: enable=missing-function-docstring diff --git a/models/stylegan2_generator.py b/models/stylegan2_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..331ea320464912601e74bebb6432e0d7c09b6642 --- /dev/null +++ b/models/stylegan2_generator.py @@ -0,0 +1,1394 @@ +# python3.7 +"""Contains the implementation of generator described in StyleGAN2. + +Compared to that of StyleGAN, the generator in StyleGAN2 mainly introduces style +demodulation, adds skip connections, increases model size, and disables +progressive growth. This script ONLY supports config F in the original paper. + +Paper: https://arxiv.org/pdf/1912.04958.pdf + +Official TensorFlow implementation: https://github.com/NVlabs/stylegan2 +""" + +import numpy as np + +import torch +import torch.nn as nn + +from third_party.stylegan2_official_ops import fma +from third_party.stylegan2_official_ops import bias_act +from third_party.stylegan2_official_ops import upfirdn2d +from third_party.stylegan2_official_ops import conv2d_gradfix +from .utils.ops import all_gather + +__all__ = ['StyleGAN2Generator'] + +# Resolutions allowed. +_RESOLUTIONS_ALLOWED = [8, 16, 32, 64, 128, 256, 512, 1024] + +# Architectures allowed. +_ARCHITECTURES_ALLOWED = ['resnet', 'skip', 'origin'] + +# pylint: disable=missing-function-docstring + +class StyleGAN2Generator(nn.Module): + """Defines the generator network in StyleGAN2. + + NOTE: The synthesized images are with `RGB` channel order and pixel range + [-1, 1]. + + Settings for the mapping network: + + (1) z_dim: Dimension of the input latent space, Z. (default: 512) + (2) w_dim: Dimension of the output latent space, W. 
(default: 512) + (3) repeat_w: Repeat w-code for different layers. (default: True) + (4) normalize_z: Whether to normalize the z-code. (default: True) + (5) mapping_layers: Number of layers of the mapping network. (default: 8) + (6) mapping_fmaps: Number of hidden channels of the mapping network. + (default: 512) + (7) mapping_use_wscale: Whether to use weight scaling for the mapping + network. (default: True) + (8) mapping_wscale_gain: The factor to control weight scaling for the + mapping network. (default: 1.0) + (9) mapping_lr_mul: Learning rate multiplier for the mapping network. + (default: 0.01) + + Settings for conditional generation: + + (1) label_dim: Dimension of the additional label for conditional generation. + In one-hot conditioning case, it is equal to the number of classes. If + set to 0, conditioning training will be disabled. (default: 0) + (2) embedding_dim: Dimension of the embedding space, if needed. + (default: 512) + (3) embedding_bias: Whether to add bias to embedding learning. + (default: True) + (4) embedding_use_wscale: Whether to use weight scaling for embedding + learning. (default: True) + (5) embedding_wscale_gain: The factor to control weight scaling for + embedding. (default: 1.0) + (6) embedding_lr_mul: Learning rate multiplier for the embedding learning. + (default: 1.0) + (7) normalize_embedding: Whether to normalize the embedding. (default: True) + (8) normalize_embedding_latent: Whether to normalize the embedding together + with the latent. (default: False) + + Settings for the synthesis network: + + (1) resolution: The resolution of the output image. (default: -1) + (2) init_res: The initial resolution to start with convolution. (default: 4) + (3) image_channels: Number of channels of the output image. (default: 3) + (4) final_tanh: Whether to use `tanh` to control the final pixel range. + (default: False) + (5) const_input: Whether to use a constant in the first convolutional layer. + (default: True) + (6) architecture: Type of architecture. Support `origin`, `skip`, and + `resnet`. (default: `skip`) + (7) demodulate: Whether to perform style demodulation. (default: True) + (8) use_wscale: Whether to use weight scaling. (default: True) + (9) wscale_gain: The factor to control weight scaling. (default: 1.0) + (10) lr_mul: Learning rate multiplier for the synthesis network. + (default: 1.0) + (11) noise_type: Type of noise added to the convolutional results at each + layer. (default: `spatial`) + (12) fmaps_base: Factor to control number of feature maps for each layer. + (default: 32 << 10) + (13) fmaps_max: Maximum number of feature maps in each layer. (default: 512) + (14) filter_kernel: Kernel used for filtering (e.g., upsampling). + (default: (1, 3, 3, 1)) + (15) conv_clamp: A threshold to clamp the output of convolution layers to + avoid overflow under FP16 training. (default: None) + (16) eps: A small value to avoid divide overflow. (default: 1e-8) + + Runtime settings: + + (1) w_moving_decay: Decay factor for updating `w_avg`, which is used for + training only. Set `None` to disable. (default: None) + (2) sync_w_avg: Whether to synchronize the stats of `w_avg` across replicas. + If set as `True`, the stats will be more accurate, yet the speed may be a + little slower. (default: False) + (3) style_mixing_prob: Probability to perform style mixing as a training + regularization. Set `None` to disable. (default: None) + (4) trunc_psi: Truncation psi, set `None` to disable. (default: None) + (5) trunc_layers: Number of layers to perform truncation.
(default: None) + (6) noise_mode: Mode of the layer-wise noise. Support `none`, `random`, + `const`. (default: `const`) + (7) fused_modulate: Whether to fuse `style_modulate` and `conv2d` together. + (default: False) + (8) fp16_res: Layers at resolution higher than (or equal to) this field will + use `float16` precision for computation. This is merely used for + acceleration. If set as `None`, all layers will use `float32` by + default. (default: None) + (9) impl: Implementation mode of some particular ops, e.g., `filtering`, + `bias_act`, etc. `cuda` means using the official CUDA implementation + from StyleGAN2, while `ref` means using the native PyTorch ops. + (default: `cuda`) + """ + + def __init__(self, + # Settings for mapping network. + z_dim=512, + w_dim=512, + repeat_w=True, + normalize_z=True, + mapping_layers=8, + mapping_fmaps=512, + mapping_use_wscale=True, + mapping_wscale_gain=1.0, + mapping_lr_mul=0.01, + # Settings for conditional generation. + label_dim=0, + embedding_dim=512, + embedding_bias=True, + embedding_use_wscale=True, + embedding_wscale_gian=1.0, + embedding_lr_mul=1.0, + normalize_embedding=True, + normalize_embedding_latent=False, + # Settings for synthesis network. + resolution=-1, + init_res=4, + image_channels=3, + final_tanh=False, + const_input=True, + architecture='skip', + demodulate=True, + use_wscale=True, + wscale_gain=1.0, + lr_mul=1.0, + noise_type='spatial', + fmaps_base=32 << 10, + fmaps_max=512, + filter_kernel=(1, 3, 3, 1), + conv_clamp=None, + eps=1e-8): + """Initializes with basic settings. + + Raises: + ValueError: If the `resolution` is not supported, or `architecture` + is not supported. + """ + super().__init__() + + if resolution not in _RESOLUTIONS_ALLOWED: + raise ValueError(f'Invalid resolution: `{resolution}`!\n' + f'Resolutions allowed: {_RESOLUTIONS_ALLOWED}.') + architecture = architecture.lower() + if architecture not in _ARCHITECTURES_ALLOWED: + raise ValueError(f'Invalid architecture: `{architecture}`!\n' + f'Architectures allowed: ' + f'{_ARCHITECTURES_ALLOWED}.') + + self.z_dim = z_dim + self.w_dim = w_dim + self.repeat_w = repeat_w + self.normalize_z = normalize_z + self.mapping_layers = mapping_layers + self.mapping_fmaps = mapping_fmaps + self.mapping_use_wscale = mapping_use_wscale + self.mapping_wscale_gain = mapping_wscale_gain + self.mapping_lr_mul = mapping_lr_mul + + self.label_dim = label_dim + self.embedding_dim = embedding_dim + self.embedding_bias = embedding_bias + self.embedding_use_wscale = embedding_use_wscale + self.embedding_wscale_gain = embedding_wscale_gian + self.embedding_lr_mul = embedding_lr_mul + self.normalize_embedding = normalize_embedding + self.normalize_embedding_latent = normalize_embedding_latent + + self.resolution = resolution + self.init_res = init_res + self.image_channels = image_channels + self.final_tanh = final_tanh + self.const_input = const_input + self.architecture = architecture + self.demodulate = demodulate + self.use_wscale = use_wscale + self.wscale_gain = wscale_gain + self.lr_mul = lr_mul + self.noise_type = noise_type.lower() + self.fmaps_base = fmaps_base + self.fmaps_max = fmaps_max + self.filter_kernel = filter_kernel + self.conv_clamp = conv_clamp + self.eps = eps + + # Dimension of latent space, which is convenient for sampling. + self.latent_dim = (z_dim,) + + # Number of synthesis (convolutional) layers. 
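For reference, the layer count computed on the next line allocates two w codes per resolution step (the 4x4 block spends its pair on the single early convolution and the ToRGB head), matching the usual W+ dimensionality, e.g. 18 codes at 1024 px. A standalone check with `init_res=4`:

```python
import numpy as np

init_res = 4
for resolution in [256, 512, 1024]:
    num_layers = int(np.log2(resolution // init_res * 2)) * 2
    print(resolution, num_layers)  # 256 -> 14, 512 -> 16, 1024 -> 18
```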
+ self.num_layers = int(np.log2(resolution // init_res * 2)) * 2 + + self.mapping = MappingNetwork( + input_dim=z_dim, + output_dim=w_dim, + num_outputs=self.num_layers, + repeat_output=repeat_w, + normalize_input=normalize_z, + num_layers=mapping_layers, + hidden_dim=mapping_fmaps, + use_wscale=mapping_use_wscale, + wscale_gain=mapping_wscale_gain, + lr_mul=mapping_lr_mul, + label_dim=label_dim, + embedding_dim=embedding_dim, + embedding_bias=embedding_bias, + embedding_use_wscale=embedding_use_wscale, + embedding_wscale_gian=embedding_wscale_gian, + embedding_lr_mul=embedding_lr_mul, + normalize_embedding=normalize_embedding, + normalize_embedding_latent=normalize_embedding_latent, + eps=eps) + + # This is used for truncation trick. + if self.repeat_w: + self.register_buffer('w_avg', torch.zeros(w_dim)) + else: + self.register_buffer('w_avg', torch.zeros(self.num_layers * w_dim)) + + self.synthesis = SynthesisNetwork(resolution=resolution, + init_res=init_res, + w_dim=w_dim, + image_channels=image_channels, + final_tanh=final_tanh, + const_input=const_input, + architecture=architecture, + demodulate=demodulate, + use_wscale=use_wscale, + wscale_gain=wscale_gain, + lr_mul=lr_mul, + noise_type=noise_type, + fmaps_base=fmaps_base, + filter_kernel=filter_kernel, + fmaps_max=fmaps_max, + conv_clamp=conv_clamp, + eps=eps) + + self.pth_to_tf_var_mapping = {'w_avg': 'dlatent_avg'} + for key, val in self.mapping.pth_to_tf_var_mapping.items(): + self.pth_to_tf_var_mapping[f'mapping.{key}'] = val + for key, val in self.synthesis.pth_to_tf_var_mapping.items(): + self.pth_to_tf_var_mapping[f'synthesis.{key}'] = val + + def set_space_of_latent(self, space_of_latent): + """Sets the space to which the latent code belong. + + See `SynthesisNetwork` for more details. + """ + self.synthesis.set_space_of_latent(space_of_latent) + + def forward(self, + z, + label=None, + w_moving_decay=None, + sync_w_avg=False, + style_mixing_prob=None, + trunc_psi=None, + trunc_layers=None, + noise_mode='const', + fused_modulate=False, + fp16_res=None, + impl='cuda'): + """Connects mapping network and synthesis network. + + This forward function will also update the average `w_code`, perform + style mixing as a training regularizer, and do truncation trick, which + is specially designed for inference. 
+ + Concretely, the truncation trick acts as follows: + + For layers in range [0, truncation_layers), the truncated w-code is + computed as + + w_new = w_avg + (w - w_avg) * truncation_psi + + To disable truncation, please set + + (1) truncation_psi = 1.0 (None) OR + (2) truncation_layers = 0 (None) + """ + + mapping_results = self.mapping(z, label, impl=impl) + + w = mapping_results['w'] + if self.training and w_moving_decay is not None: + if sync_w_avg: + batch_w_avg = all_gather(w.detach()).mean(dim=0) + else: + batch_w_avg = w.detach().mean(dim=0) + self.w_avg.copy_(batch_w_avg.lerp(self.w_avg, w_moving_decay)) + + wp = mapping_results.pop('wp') + if self.training and style_mixing_prob is not None: + if np.random.uniform() < style_mixing_prob: + new_z = torch.randn_like(z) + new_wp = self.mapping(new_z, label, impl=impl)['wp'] + mixing_cutoff = np.random.randint(1, self.num_layers) + wp[:, mixing_cutoff:] = new_wp[:, mixing_cutoff:] + + if not self.training: + trunc_psi = 1.0 if trunc_psi is None else trunc_psi + trunc_layers = 0 if trunc_layers is None else trunc_layers + if trunc_psi < 1.0 and trunc_layers > 0: + w_avg = self.w_avg.reshape(1, -1, self.w_dim)[:, :trunc_layers] + wp[:, :trunc_layers] = w_avg.lerp( + wp[:, :trunc_layers], trunc_psi) + + synthesis_results = self.synthesis(wp, + noise_mode=noise_mode, + fused_modulate=fused_modulate, + impl=impl, + fp16_res=fp16_res) + + return {**mapping_results, **synthesis_results} + + +class MappingNetwork(nn.Module): + """Implements the latent space mapping network. + + Basically, this network executes several dense layers in sequence, and the + label embedding if needed. + """ + + def __init__(self, + input_dim, + output_dim, + num_outputs, + repeat_output, + normalize_input, + num_layers, + hidden_dim, + use_wscale, + wscale_gain, + lr_mul, + label_dim, + embedding_dim, + embedding_bias, + embedding_use_wscale, + embedding_wscale_gian, + embedding_lr_mul, + normalize_embedding, + normalize_embedding_latent, + eps): + super().__init__() + + self.input_dim = input_dim + self.output_dim = output_dim + self.num_outputs = num_outputs + self.repeat_output = repeat_output + self.normalize_input = normalize_input + self.num_layers = num_layers + self.hidden_dim = hidden_dim + self.use_wscale = use_wscale + self.wscale_gain = wscale_gain + self.lr_mul = lr_mul + self.label_dim = label_dim + self.embedding_dim = embedding_dim + self.embedding_bias = embedding_bias + self.embedding_use_wscale = embedding_use_wscale + self.embedding_wscale_gian = embedding_wscale_gian + self.embedding_lr_mul = embedding_lr_mul + self.normalize_embedding = normalize_embedding + self.normalize_embedding_latent = normalize_embedding_latent + self.eps = eps + + self.pth_to_tf_var_mapping = {} + + self.norm = PixelNormLayer(dim=1, eps=eps) + + if self.label_dim > 0: + input_dim = input_dim + embedding_dim + self.embedding = DenseLayer(in_channels=label_dim, + out_channels=embedding_dim, + add_bias=embedding_bias, + init_bias=0.0, + use_wscale=embedding_use_wscale, + wscale_gain=embedding_wscale_gian, + lr_mul=embedding_lr_mul, + activation_type='linear') + self.pth_to_tf_var_mapping['embedding.weight'] = 'LabelEmbed/weight' + if self.embedding_bias: + self.pth_to_tf_var_mapping['embedding.bias'] = 'LabelEmbed/bias' + + if num_outputs is not None and not repeat_output: + output_dim = output_dim * num_outputs + for i in range(num_layers): + in_channels = (input_dim if i == 0 else hidden_dim) + out_channels = (output_dim if i == (num_layers - 1) else hidden_dim) + 
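The truncation branch above relies on `Tensor.lerp`, which is exactly the docstring formula `w_new = w_avg + (w - w_avg) * psi`. A quick equivalence check:

```python
import torch

w_avg, w = torch.zeros(512), torch.randn(512)
psi = 0.7
assert torch.allclose(w_avg.lerp(w, psi), w_avg + (w - w_avg) * psi)
# psi=1.0 leaves w unchanged (truncation disabled); psi=0.0 collapses w to w_avg.
```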
self.add_module(f'dense{i}', + DenseLayer(in_channels=in_channels, + out_channels=out_channels, + add_bias=True, + init_bias=0.0, + use_wscale=use_wscale, + wscale_gain=wscale_gain, + lr_mul=lr_mul, + activation_type='lrelu')) + self.pth_to_tf_var_mapping[f'dense{i}.weight'] = f'Dense{i}/weight' + self.pth_to_tf_var_mapping[f'dense{i}.bias'] = f'Dense{i}/bias' + + def forward(self, z, label=None, impl='cuda'): + if z.ndim != 2 or z.shape[1] != self.input_dim: + raise ValueError(f'Input latent code should be with shape ' + f'[batch_size, input_dim], where ' + f'`input_dim` equals to {self.input_dim}!\n' + f'But `{z.shape}` is received!') + if self.normalize_input: + z = self.norm(z) + + if self.label_dim > 0: + if label is None: + raise ValueError(f'Model requires an additional label ' + f'(with dimension {self.label_dim}) as input, ' + f'but no label is received!') + if label.ndim != 2 or label.shape != (z.shape[0], self.label_dim): + raise ValueError(f'Input label should be with shape ' + f'[batch_size, label_dim], where ' + f'`batch_size` equals to that of ' + f'latent codes ({z.shape[0]}) and ' + f'`label_dim` equals to {self.label_dim}!\n' + f'But `{label.shape}` is received!') + label = label.to(dtype=torch.float32) + embedding = self.embedding(label, impl=impl) + if self.normalize_embedding: + embedding = self.norm(embedding) + w = torch.cat((z, embedding), dim=1) + else: + w = z + + if self.label_dim > 0 and self.normalize_embedding_latent: + w = self.norm(w) + + for i in range(self.num_layers): + w = getattr(self, f'dense{i}')(w, impl=impl) + + wp = None + if self.num_outputs is not None: + if self.repeat_output: + wp = w.unsqueeze(1).repeat((1, self.num_outputs, 1)) + else: + wp = w.reshape(-1, self.num_outputs, self.output_dim) + + results = { + 'z': z, + 'label': label, + 'w': w, + 'wp': wp, + } + if self.label_dim > 0: + results['embedding'] = embedding + return results + + +class SynthesisNetwork(nn.Module): + """Implements the image synthesis network. + + Basically, this network executes several convolutional layers in sequence. + """ + + def __init__(self, + resolution, + init_res, + w_dim, + image_channels, + final_tanh, + const_input, + architecture, + demodulate, + use_wscale, + wscale_gain, + lr_mul, + noise_type, + fmaps_base, + fmaps_max, + filter_kernel, + conv_clamp, + eps): + super().__init__() + + self.init_res = init_res + self.init_res_log2 = int(np.log2(init_res)) + self.resolution = resolution + self.final_res_log2 = int(np.log2(resolution)) + self.w_dim = w_dim + self.image_channels = image_channels + self.final_tanh = final_tanh + self.const_input = const_input + self.architecture = architecture.lower() + self.demodulate = demodulate + self.use_wscale = use_wscale + self.wscale_gain = wscale_gain + self.lr_mul = lr_mul + self.noise_type = noise_type.lower() + self.fmaps_base = fmaps_base + self.fmaps_max = fmaps_max + self.filter_kernel = filter_kernel + self.conv_clamp = conv_clamp + self.eps = eps + + self.num_layers = (self.final_res_log2 - self.init_res_log2 + 1) * 2 + + self.pth_to_tf_var_mapping = {} + + for res_log2 in range(self.init_res_log2, self.final_res_log2 + 1): + res = 2 ** res_log2 + in_channels = self.get_nf(res // 2) + out_channels = self.get_nf(res) + block_idx = res_log2 - self.init_res_log2 + + # Early layer. 
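The mapping network above first normalizes `z` with `PixelNormLayer`, which rescales every sample to unit mean square (i.e., onto the sphere of radius `sqrt(dim)`). The operation in isolation:

```python
import torch

z = torch.randn(4, 512)
z_norm = z * (z.square().mean(dim=1, keepdim=True) + 1e-8).rsqrt()
print(z_norm.square().mean(dim=1))  # ~[1., 1., 1., 1.]: unit mean square per sample
```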
+ if res == init_res: + if self.const_input: + self.add_module('early_layer', + InputLayer(init_res=res, + channels=out_channels)) + self.pth_to_tf_var_mapping['early_layer.const'] = ( + f'{res}x{res}/Const/const') + else: + channels = out_channels * res * res + self.add_module('early_layer', + DenseLayer(in_channels=w_dim, + out_channels=channels, + add_bias=True, + init_bias=0.0, + use_wscale=use_wscale, + wscale_gain=wscale_gain, + lr_mul=lr_mul, + activation_type='lrelu')) + self.pth_to_tf_var_mapping['early_layer.weight'] = ( + f'{res}x{res}/Dense/weight') + self.pth_to_tf_var_mapping['early_layer.bias'] = ( + f'{res}x{res}/Dense/bias') + else: + # Residual branch (kernel 1x1) with upsampling, without bias, + # with linear activation. + if self.architecture == 'resnet': + layer_name = f'residual{block_idx}' + self.add_module(layer_name, + ConvLayer(in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + add_bias=False, + scale_factor=2, + filter_kernel=filter_kernel, + use_wscale=use_wscale, + wscale_gain=wscale_gain, + lr_mul=lr_mul, + activation_type='linear', + conv_clamp=None)) + self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = ( + f'{res}x{res}/Skip/weight') + + # First layer (kernel 3x3) with upsampling. + layer_name = f'layer{2 * block_idx - 1}' + self.add_module(layer_name, + ModulateConvLayer(in_channels=in_channels, + out_channels=out_channels, + resolution=res, + w_dim=w_dim, + kernel_size=3, + add_bias=True, + scale_factor=2, + filter_kernel=filter_kernel, + demodulate=demodulate, + use_wscale=use_wscale, + wscale_gain=wscale_gain, + lr_mul=lr_mul, + noise_type=noise_type, + activation_type='lrelu', + conv_clamp=conv_clamp, + eps=eps)) + self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = ( + f'{res}x{res}/Conv0_up/weight') + self.pth_to_tf_var_mapping[f'{layer_name}.bias'] = ( + f'{res}x{res}/Conv0_up/bias') + self.pth_to_tf_var_mapping[f'{layer_name}.style.weight'] = ( + f'{res}x{res}/Conv0_up/mod_weight') + self.pth_to_tf_var_mapping[f'{layer_name}.style.bias'] = ( + f'{res}x{res}/Conv0_up/mod_bias') + self.pth_to_tf_var_mapping[f'{layer_name}.noise_strength'] = ( + f'{res}x{res}/Conv0_up/noise_strength') + self.pth_to_tf_var_mapping[f'{layer_name}.noise'] = ( + f'noise{2 * block_idx - 1}') + + # Second layer (kernel 3x3) without upsampling. + layer_name = f'layer{2 * block_idx}' + self.add_module(layer_name, + ModulateConvLayer(in_channels=out_channels, + out_channels=out_channels, + resolution=res, + w_dim=w_dim, + kernel_size=3, + add_bias=True, + scale_factor=1, + filter_kernel=None, + demodulate=demodulate, + use_wscale=use_wscale, + wscale_gain=wscale_gain, + lr_mul=lr_mul, + noise_type=noise_type, + activation_type='lrelu', + conv_clamp=conv_clamp, + eps=eps)) + tf_layer_name = 'Conv' if res == self.init_res else 'Conv1' + self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = ( + f'{res}x{res}/{tf_layer_name}/weight') + self.pth_to_tf_var_mapping[f'{layer_name}.bias'] = ( + f'{res}x{res}/{tf_layer_name}/bias') + self.pth_to_tf_var_mapping[f'{layer_name}.style.weight'] = ( + f'{res}x{res}/{tf_layer_name}/mod_weight') + self.pth_to_tf_var_mapping[f'{layer_name}.style.bias'] = ( + f'{res}x{res}/{tf_layer_name}/mod_bias') + self.pth_to_tf_var_mapping[f'{layer_name}.noise_strength'] = ( + f'{res}x{res}/{tf_layer_name}/noise_strength') + self.pth_to_tf_var_mapping[f'{layer_name}.noise'] = ( + f'noise{2 * block_idx}') + + # Output convolution layer for each resolution (if needed). 
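Each modulated convolution registered above also injects per-layer noise (`noise_type='spatial'` by default): a single `(1, 1, H, W)` map broadcast across channels, scaled by a learned strength that `ModulateConvLayer` initializes to zero, so noise starts as a no-op. A toy illustration:

```python
import torch

x = torch.randn(2, 512, 16, 16)
noise = torch.randn(1, 1, 16, 16)   # shared across channels and batch
noise_strength = torch.zeros(())    # learnable in the real layer, zero at init
print(torch.equal(x + noise * noise_strength, x))  # True: no-op at initialization
```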
+ if res_log2 == self.final_res_log2 or self.architecture == 'skip': + layer_name = f'output{block_idx}' + self.add_module(layer_name, + ModulateConvLayer(in_channels=out_channels, + out_channels=image_channels, + resolution=res, + w_dim=w_dim, + kernel_size=1, + add_bias=True, + scale_factor=1, + filter_kernel=None, + demodulate=False, + use_wscale=use_wscale, + wscale_gain=wscale_gain, + lr_mul=lr_mul, + noise_type='none', + activation_type='linear', + conv_clamp=conv_clamp, + eps=eps)) + self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = ( + f'{res}x{res}/ToRGB/weight') + self.pth_to_tf_var_mapping[f'{layer_name}.bias'] = ( + f'{res}x{res}/ToRGB/bias') + self.pth_to_tf_var_mapping[f'{layer_name}.style.weight'] = ( + f'{res}x{res}/ToRGB/mod_weight') + self.pth_to_tf_var_mapping[f'{layer_name}.style.bias'] = ( + f'{res}x{res}/ToRGB/mod_bias') + + # Used for upsampling output images for each resolution block for sum. + if self.architecture == 'skip': + self.register_buffer( + 'filter', upfirdn2d.setup_filter(filter_kernel)) + + def get_nf(self, res): + """Gets number of feature maps according to the given resolution.""" + return min(self.fmaps_base // res, self.fmaps_max) + + def set_space_of_latent(self, space_of_latent): + """Sets the space to which the latent code belong. + + This function is particularly used for choosing how to inject the latent + code into the convolutional layers. The original generator will take a + W-Space code and apply it for style modulation after an affine + transformation. But, sometimes, it may need to directly feed an already + affine-transformed code into the convolutional layer, e.g., when + training an encoder for GAN inversion. We term the transformed space as + Style Space (or Y-Space). This function is designed to tell the + convolutional layers how to use the input code. + + Args: + space_of_latent: The space to which the latent code belong. Case + insensitive. Support `W` and `Y`. + """ + space_of_latent = space_of_latent.upper() + for module in self.modules(): + if isinstance(module, ModulateConvLayer): + setattr(module, 'space_of_latent', space_of_latent) + + def forward(self, + wp, + noise_mode='const', + fused_modulate=False, + fp16_res=None, + impl='cuda'): + results = {'wp': wp} + + if self.const_input: + x = self.early_layer(wp[:, 0]) + else: + x = self.early_layer(wp[:, 0], impl=impl) + + # Cast to `torch.float16` if needed. + if fp16_res is not None and self.init_res >= fp16_res: + x = x.to(torch.float16) + + if self.architecture == 'origin': + for layer_idx in range(self.num_layers - 1): + layer = getattr(self, f'layer{layer_idx}') + x, style = layer(x, + wp[:, layer_idx], + noise_mode=noise_mode, + fused_modulate=fused_modulate, + impl=impl) + results[f'style{layer_idx}'] = style + + # Cast to `torch.float16` if needed. 
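Putting the pieces together, a hypothetical usage sketch (again assuming the repo is importable and using `impl='ref'`); the second half exercises `set_space_of_latent` as described above, feeding already affine-transformed Y-space codes, where a width of 512 covers every layer's `in_channels` under the default `fmaps_max`:

```python
import torch
from models.stylegan2_generator import StyleGAN2Generator

G = StyleGAN2Generator(resolution=256)
z = torch.randn(2, G.z_dim)
print(G(z, impl='ref')['image'].shape)  # torch.Size([2, 3, 256, 256])

# GAN-inversion style: bypass the per-layer affine heads with Y-space codes.
G.synthesis.set_space_of_latent('Y')
y = torch.randn(2, G.synthesis.num_layers, 512)
print(G.synthesis(y, impl='ref')['image'].shape)  # torch.Size([2, 3, 256, 256])
```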
+ if layer_idx % 2 == 0 and layer_idx != self.num_layers - 2: + res = self.init_res * (2 ** (layer_idx // 2)) + if fp16_res is not None and res * 2 >= fp16_res: + x = x.to(torch.float16) + else: + x = x.to(torch.float32) + output_layer = getattr(self, f'output{layer_idx // 2}') + image, style = output_layer(x, + wp[:, layer_idx + 1], + fused_modulate=fused_modulate, + impl=impl) + image = image.to(torch.float32) + results[f'output_style{layer_idx // 2}'] = style + + elif self.architecture == 'skip': + for layer_idx in range(self.num_layers - 1): + layer = getattr(self, f'layer{layer_idx}') + x, style = layer(x, + wp[:, layer_idx], + noise_mode=noise_mode, + fused_modulate=fused_modulate, + impl=impl) + results[f'style{layer_idx}'] = style + if layer_idx % 2 == 0: + output_layer = getattr(self, f'output{layer_idx // 2}') + y, style = output_layer(x, + wp[:, layer_idx + 1], + fused_modulate=fused_modulate, + impl=impl) + results[f'output_style{layer_idx // 2}'] = style + if layer_idx == 0: + image = y.to(torch.float32) + else: + image = y.to(torch.float32) + upfirdn2d.upsample2d( + image, self.filter, impl=impl) + + # Cast to `torch.float16` if needed. + if layer_idx != self.num_layers - 2: + res = self.init_res * (2 ** (layer_idx // 2)) + if fp16_res is not None and res * 2 >= fp16_res: + x = x.to(torch.float16) + else: + x = x.to(torch.float32) + + elif self.architecture == 'resnet': + x, style = self.layer0(x, + wp[:, 0], + noise_mode=noise_mode, + fused_modulate=fused_modulate, + impl=impl) + results['style0'] = style + for layer_idx in range(1, self.num_layers - 1, 2): + # Cast to `torch.float16` if needed. + if layer_idx % 2 == 1: + res = self.init_res * (2 ** (layer_idx // 2)) + if fp16_res is not None and res * 2 >= fp16_res: + x = x.to(torch.float16) + else: + x = x.to(torch.float32) + + skip_layer = getattr(self, f'residual{layer_idx // 2 + 1}') + residual = skip_layer(x, runtime_gain=np.sqrt(0.5), impl=impl) + layer = getattr(self, f'layer{layer_idx}') + x, style = layer(x, + wp[:, layer_idx], + noise_mode=noise_mode, + fused_modulate=fused_modulate, + impl=impl) + results[f'style{layer_idx}'] = style + layer = getattr(self, f'layer{layer_idx + 1}') + x, style = layer(x, + wp[:, layer_idx + 1], + runtime_gain=np.sqrt(0.5), + noise_mode=noise_mode, + fused_modulate=fused_modulate, + impl=impl) + results[f'style{layer_idx + 1}'] = style + x = x + residual + output_layer = getattr(self, f'output{layer_idx // 2 + 1}') + image, style = output_layer(x, + wp[:, layer_idx + 2], + fused_modulate=fused_modulate, + impl=impl) + image = image.to(torch.float32) + results[f'output_style{layer_idx // 2}'] = style + + if self.final_tanh: + image = torch.tanh(image) + results['image'] = image + return results + + +class PixelNormLayer(nn.Module): + """Implements pixel-wise feature vector normalization layer.""" + + def __init__(self, dim, eps): + super().__init__() + self.dim = dim + self.eps = eps + + def extra_repr(self): + return f'dim={self.dim}, epsilon={self.eps}' + + def forward(self, x): + scale = (x.square().mean(dim=self.dim, keepdim=True) + self.eps).rsqrt() + return x * scale + + +class InputLayer(nn.Module): + """Implements the input layer to start convolution with. + + Basically, this block starts from a const input, which is with shape + `(channels, init_res, init_res)`. 
+ """ + + def __init__(self, init_res, channels): + super().__init__() + self.const = nn.Parameter(torch.randn(1, channels, init_res, init_res)) + + def forward(self, w): + x = self.const.repeat(w.shape[0], 1, 1, 1) + return x + + +class ConvLayer(nn.Module): + """Implements the convolutional layer. + + If upsampling is needed (i.e., `scale_factor = 2`), the feature map will + be filtered with `filter_kernel` after convolution. This layer will only be + used for skip connection in `resnet` architecture. + """ + + def __init__(self, + in_channels, + out_channels, + kernel_size, + add_bias, + scale_factor, + filter_kernel, + use_wscale, + wscale_gain, + lr_mul, + activation_type, + conv_clamp): + """Initializes with layer settings. + + Args: + in_channels: Number of channels of the input tensor. + out_channels: Number of channels of the output tensor. + kernel_size: Size of the convolutional kernels. + add_bias: Whether to add bias onto the convolutional result. + scale_factor: Scale factor for upsampling. + filter_kernel: Kernel used for filtering. + use_wscale: Whether to use weight scaling. + wscale_gain: Gain factor for weight scaling. + lr_mul: Learning multiplier for both weight and bias. + activation_type: Type of activation. + conv_clamp: A threshold to clamp the output of convolution layers to + avoid overflow under FP16 training. + """ + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.add_bias = add_bias + self.scale_factor = scale_factor + self.filter_kernel = filter_kernel + self.use_wscale = use_wscale + self.wscale_gain = wscale_gain + self.lr_mul = lr_mul + self.activation_type = activation_type + self.conv_clamp = conv_clamp + + weight_shape = (out_channels, in_channels, kernel_size, kernel_size) + fan_in = kernel_size * kernel_size * in_channels + wscale = wscale_gain / np.sqrt(fan_in) + if use_wscale: + self.weight = nn.Parameter(torch.randn(*weight_shape) / lr_mul) + self.wscale = wscale * lr_mul + else: + self.weight = nn.Parameter( + torch.randn(*weight_shape) * wscale / lr_mul) + self.wscale = lr_mul + + if add_bias: + self.bias = nn.Parameter(torch.zeros(out_channels)) + self.bscale = lr_mul + else: + self.bias = None + self.act_gain = bias_act.activation_funcs[activation_type].def_gain + + if scale_factor > 1: + assert filter_kernel is not None + self.register_buffer( + 'filter', upfirdn2d.setup_filter(filter_kernel)) + fh, fw = self.filter.shape + self.filter_padding = ( + kernel_size // 2 + (fw + scale_factor - 1) // 2, + kernel_size // 2 + (fw - scale_factor) // 2, + kernel_size // 2 + (fh + scale_factor - 1) // 2, + kernel_size // 2 + (fh - scale_factor) // 2) + + def extra_repr(self): + return (f'in_ch={self.in_channels}, ' + f'out_ch={self.out_channels}, ' + f'ksize={self.kernel_size}, ' + f'wscale_gain={self.wscale_gain:.3f}, ' + f'bias={self.add_bias}, ' + f'lr_mul={self.lr_mul:.3f}, ' + f'upsample={self.scale_factor}, ' + f'upsample_filter={self.filter_kernel}, ' + f'act={self.activation_type}, ' + f'clamp={self.conv_clamp}') + + def forward(self, x, runtime_gain=1.0, impl='cuda'): + dtype = x.dtype + + weight = self.weight + if self.wscale != 1.0: + weight = weight * self.wscale + bias = None + if self.bias is not None: + bias = self.bias.to(dtype) + if self.bscale != 1.0: + bias = bias * self.bscale + + if self.scale_factor == 1: # Native convolution without upsampling. 
+ padding = self.kernel_size // 2 + x = conv2d_gradfix.conv2d( + x, weight.to(dtype), stride=1, padding=padding, impl=impl) + else: # Convolution with upsampling. + up = self.scale_factor + f = self.filter + # When kernel size = 1, use filtering function for upsampling. + if self.kernel_size == 1: + padding = self.filter_padding + x = conv2d_gradfix.conv2d( + x, weight.to(dtype), stride=1, padding=0, impl=impl) + x = upfirdn2d.upfirdn2d( + x, f, up=up, padding=padding, gain=up ** 2, impl=impl) + # When kernel size != 1, use transpose convolution for upsampling. + else: + # Following codes are borrowed from + # https://github.com/NVlabs/stylegan2-ada-pytorch + px0, px1, py0, py1 = self.filter_padding + kh, kw = weight.shape[2:] + px0 = px0 - (kw - 1) + px1 = px1 - (kw - up) + py0 = py0 - (kh - 1) + py1 = py1 - (kh - up) + pxt = max(min(-px0, -px1), 0) + pyt = max(min(-py0, -py1), 0) + weight = weight.transpose(0, 1) + padding = (pyt, pxt) + x = conv2d_gradfix.conv_transpose2d( + x, weight.to(dtype), stride=up, padding=padding, impl=impl) + padding = (px0 + pxt, px1 + pxt, py0 + pyt, py1 + pyt) + x = upfirdn2d.upfirdn2d( + x, f, up=1, padding=padding, gain=up ** 2, impl=impl) + + act_gain = self.act_gain * runtime_gain + act_clamp = None + if self.conv_clamp is not None: + act_clamp = self.conv_clamp * runtime_gain + x = bias_act.bias_act(x, bias, + act=self.activation_type, + gain=act_gain, + clamp=act_clamp, + impl=impl) + + assert x.dtype == dtype + return x + + +class ModulateConvLayer(nn.Module): + """Implements the convolutional layer with style modulation.""" + + def __init__(self, + in_channels, + out_channels, + resolution, + w_dim, + kernel_size, + add_bias, + scale_factor, + filter_kernel, + demodulate, + use_wscale, + wscale_gain, + lr_mul, + noise_type, + activation_type, + conv_clamp, + eps): + """Initializes with layer settings. + + Args: + in_channels: Number of channels of the input tensor. + out_channels: Number of channels of the output tensor. + resolution: Resolution of the output tensor. + w_dim: Dimension of W space for style modulation. + kernel_size: Size of the convolutional kernels. + add_bias: Whether to add bias onto the convolutional result. + scale_factor: Scale factor for upsampling. + filter_kernel: Kernel used for filtering. + demodulate: Whether to perform style demodulation. + use_wscale: Whether to use weight scaling. + wscale_gain: Gain factor for weight scaling. + lr_mul: Learning multiplier for both weight and bias. + noise_type: Type of noise added to the feature map after the + convolution (if needed). Support `none`, `spatial` and + `channel`. + activation_type: Type of activation. + conv_clamp: A threshold to clamp the output of convolution layers to + avoid overflow under FP16 training. + eps: A small value to avoid divide overflow. + """ + super().__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.resolution = resolution + self.w_dim = w_dim + self.kernel_size = kernel_size + self.add_bias = add_bias + self.scale_factor = scale_factor + self.filter_kernel = filter_kernel + self.demodulate = demodulate + self.use_wscale = use_wscale + self.wscale_gain = wscale_gain + self.lr_mul = lr_mul + self.noise_type = noise_type.lower() + self.activation_type = activation_type + self.conv_clamp = conv_clamp + self.eps = eps + + self.space_of_latent = 'W' + + # Set up weight. 
+ weight_shape = (out_channels, in_channels, kernel_size, kernel_size) + fan_in = kernel_size * kernel_size * in_channels + wscale = wscale_gain / np.sqrt(fan_in) + if use_wscale: + self.weight = nn.Parameter(torch.randn(*weight_shape) / lr_mul) + self.wscale = wscale * lr_mul + else: + self.weight = nn.Parameter( + torch.randn(*weight_shape) * wscale / lr_mul) + self.wscale = lr_mul + + # Set up bias. + if add_bias: + self.bias = nn.Parameter(torch.zeros(out_channels)) + self.bscale = lr_mul + else: + self.bias = None + self.act_gain = bias_act.activation_funcs[activation_type].def_gain + + # Set up style. + self.style = DenseLayer(in_channels=w_dim, + out_channels=in_channels, + add_bias=True, + init_bias=1.0, + use_wscale=use_wscale, + wscale_gain=wscale_gain, + lr_mul=lr_mul, + activation_type='linear') + + # Set up noise. + if self.noise_type != 'none': + self.noise_strength = nn.Parameter(torch.zeros(())) + if self.noise_type == 'spatial': + self.register_buffer( + 'noise', torch.randn(1, 1, resolution, resolution)) + elif self.noise_type == 'channel': + self.register_buffer( + 'noise', torch.randn(1, out_channels, 1, 1)) + else: + raise NotImplementedError(f'Not implemented noise type: ' + f'`{self.noise_type}`!') + + if scale_factor > 1: + assert filter_kernel is not None + self.register_buffer( + 'filter', upfirdn2d.setup_filter(filter_kernel)) + fh, fw = self.filter.shape + self.filter_padding = ( + kernel_size // 2 + (fw + scale_factor - 1) // 2, + kernel_size // 2 + (fw - scale_factor) // 2, + kernel_size // 2 + (fh + scale_factor - 1) // 2, + kernel_size // 2 + (fh - scale_factor) // 2) + + def extra_repr(self): + return (f'in_ch={self.in_channels}, ' + f'out_ch={self.out_channels}, ' + f'ksize={self.kernel_size}, ' + f'wscale_gain={self.wscale_gain:.3f}, ' + f'bias={self.add_bias}, ' + f'lr_mul={self.lr_mul:.3f}, ' + f'upsample={self.scale_factor}, ' + f'upsample_filter={self.filter_kernel}, ' + f'demodulate={self.demodulate}, ' + f'noise_type={self.noise_type}, ' + f'act={self.activation_type}, ' + f'clamp={self.conv_clamp}') + + def forward_style(self, w, impl='cuda'): + """Gets style code from the given input. + + More specifically, if the input is from W-Space, it will be projected by + an affine transformation. If it is from the Style Space (Y-Space), no + operation is required. + + NOTE: For codes from Y-Space, we use slicing to make sure the dimension + is correct, in case that the code is padded before fed into this layer. 
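Jumping ahead to what `forward` below does with this style code: the kernel is modulated per sample and then demodulated so every output filter has unit norm. The arithmetic in isolation, with toy shapes:

```python
import torch

N, in_ch, out_ch, k = 2, 8, 16, 3
weight = torch.randn(out_ch, in_ch, k, k)
style = torch.randn(N, in_ch)
w = weight.unsqueeze(0) * style.reshape(N, 1, in_ch, 1, 1)  # modulate
decoef = (w.square().sum(dim=(2, 3, 4)) + 1e-8).rsqrt()     # demodulation coefficient
w = w * decoef.reshape(N, out_ch, 1, 1, 1)                  # demodulate
print(w.square().sum(dim=(2, 3, 4)))  # ~1.0 everywhere: unit-norm output filters
```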
+ """ + space_of_latent = self.space_of_latent.upper() + if space_of_latent == 'W': + if w.ndim != 2 or w.shape[1] != self.w_dim: + raise ValueError(f'The input tensor should be with shape ' + f'[batch_size, w_dim], where ' + f'`w_dim` equals to {self.w_dim}!\n' + f'But `{w.shape}` is received!') + style = self.style(w, impl=impl) + elif space_of_latent == 'Y': + if w.ndim != 2 or w.shape[1] < self.in_channels: + raise ValueError(f'The input tensor should be with shape ' + f'[batch_size, y_dim], where ' + f'`y_dim` equals to {self.in_channels}!\n' + f'But `{w.shape}` is received!') + style = w[:, :self.in_channels] + else: + raise NotImplementedError(f'Not implemented `space_of_latent`: ' + f'`{space_of_latent}`!') + return style + + def forward(self, + x, + w, + runtime_gain=1.0, + noise_mode='const', + fused_modulate=False, + impl='cuda'): + dtype = x.dtype + N, C, H, W = x.shape + + fused_modulate = (fused_modulate and + not self.training and + (dtype == torch.float32 or N == 1)) + + weight = self.weight + out_ch, in_ch, kh, kw = weight.shape + assert in_ch == C + + # Affine on `w`. + style = self.forward_style(w, impl=impl) + if not self.demodulate: + _style = style * self.wscale # Equivalent to scaling weight. + else: + _style = style + + # Prepare noise. + noise = None + noise_mode = noise_mode.lower() + if self.noise_type != 'none' and noise_mode != 'none': + if noise_mode == 'random': + noise = torch.randn((N, *self.noise.shape[1:]), device=x.device) + elif noise_mode == 'const': + noise = self.noise + else: + raise ValueError(f'Unknown noise mode `{noise_mode}`!') + noise = (noise * self.noise_strength).to(dtype) + + # Pre-normalize inputs to avoid FP16 overflow. + if dtype == torch.float16 and self.demodulate: + weight_max = weight.norm(float('inf'), dim=(1, 2, 3), keepdim=True) + weight = weight * (self.wscale / weight_max) + style_max = _style.norm(float('inf'), dim=1, keepdim=True) + _style = _style / style_max + + if self.demodulate or fused_modulate: + _weight = weight.unsqueeze(0) + _weight = _weight * _style.reshape(N, 1, in_ch, 1, 1) + if self.demodulate: + decoef = (_weight.square().sum(dim=(2, 3, 4)) + self.eps).rsqrt() + if self.demodulate and fused_modulate: + _weight = _weight * decoef.reshape(N, out_ch, 1, 1, 1) + + if not fused_modulate: + x = x * _style.to(dtype).reshape(N, in_ch, 1, 1) + w = weight.to(dtype) + groups = 1 + else: # Use group convolution to fuse style modulation and convolution. + x = x.reshape(1, N * in_ch, H, W) + w = _weight.reshape(N * out_ch, in_ch, kh, kw).to(dtype) + groups = N + + if self.scale_factor == 1: # Native convolution without upsampling. + up = 1 + padding = self.kernel_size // 2 + x = conv2d_gradfix.conv2d( + x, w, stride=1, padding=padding, groups=groups, impl=impl) + else: # Convolution with upsampling. + up = self.scale_factor + f = self.filter + # When kernel size = 1, use filtering function for upsampling. + if self.kernel_size == 1: + padding = self.filter_padding + x = conv2d_gradfix.conv2d( + x, w, stride=1, padding=0, groups=groups, impl=impl) + x = upfirdn2d.upfirdn2d( + x, f, up=up, padding=padding, gain=up ** 2, impl=impl) + # When kernel size != 1, use stride convolution for upsampling. 
+ else: + # Following codes are borrowed from + # https://github.com/NVlabs/stylegan2-ada-pytorch + px0, px1, py0, py1 = self.filter_padding + px0 = px0 - (kw - 1) + px1 = px1 - (kw - up) + py0 = py0 - (kh - 1) + py1 = py1 - (kh - up) + pxt = max(min(-px0, -px1), 0) + pyt = max(min(-py0, -py1), 0) + if groups == 1: + w = w.transpose(0, 1) + else: + w = w.reshape(N, out_ch, in_ch, kh, kw) + w = w.transpose(1, 2) + w = w.reshape(N * in_ch, out_ch, kh, kw) + padding = (pyt, pxt) + x = conv2d_gradfix.conv_transpose2d( + x, w, stride=up, padding=padding, groups=groups, impl=impl) + padding = (px0 + pxt, px1 + pxt, py0 + pyt, py1 + pyt) + x = upfirdn2d.upfirdn2d( + x, f, up=1, padding=padding, gain=up ** 2, impl=impl) + + if not fused_modulate: + if self.demodulate: + decoef = decoef.to(dtype).reshape(N, out_ch, 1, 1) + if self.demodulate and noise is not None: + x = fma.fma(x, decoef, noise, impl=impl) + else: + if self.demodulate: + x = x * decoef + if noise is not None: + x = x + noise + else: + x = x.reshape(N, out_ch, H * up, W * up) + if noise is not None: + x = x + noise + + bias = None + if self.bias is not None: + bias = self.bias.to(dtype) + if self.bscale != 1.0: + bias = bias * self.bscale + + if self.activation_type == 'linear': # Shortcut for output layer. + x = bias_act.bias_act( + x, bias, act='linear', clamp=self.conv_clamp, impl=impl) + else: + act_gain = self.act_gain * runtime_gain + act_clamp = None + if self.conv_clamp is not None: + act_clamp = self.conv_clamp * runtime_gain + x = bias_act.bias_act(x, bias, + act=self.activation_type, + gain=act_gain, + clamp=act_clamp, + impl=impl) + + assert x.dtype == dtype + assert style.dtype == torch.float32 + return x, style + + +class DenseLayer(nn.Module): + """Implements the dense layer.""" + + def __init__(self, + in_channels, + out_channels, + add_bias, + init_bias, + use_wscale, + wscale_gain, + lr_mul, + activation_type): + """Initializes with layer settings. + + Args: + in_channels: Number of channels of the input tensor. + out_channels: Number of channels of the output tensor. + add_bias: Whether to add bias onto the fully-connected result. + init_bias: The initial bias value before training. + use_wscale: Whether to use weight scaling. + wscale_gain: Gain factor for weight scaling. + lr_mul: Learning multiplier for both weight and bias. + activation_type: Type of activation. 
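The `fused_modulate` path above folds per-sample modulated kernels into a single convolution by moving the batch dimension into channel groups. A minimal standalone sketch (not part of the patch) verifying the grouped call against a per-sample loop:

```python
# Batched per-sample convolution via grouped conv2d: with groups=N, each
# sample's channels only see that sample's kernels.
import torch
import torch.nn.functional as F

N, in_ch, out_ch, H, W, k = 4, 3, 5, 8, 8, 3
x = torch.randn(N, in_ch, H, W)
weight = torch.randn(N, out_ch, in_ch, k, k)  # One kernel per sample.

# Fused: fold the batch dimension into channel groups.
y_fused = F.conv2d(x.reshape(1, N * in_ch, H, W),
                   weight.reshape(N * out_ch, in_ch, k, k),
                   padding=k // 2, groups=N)
y_fused = y_fused.reshape(N, out_ch, H, W)

# Reference: loop over the batch.
y_loop = torch.stack([F.conv2d(x[i:i + 1], weight[i], padding=k // 2)[0]
                      for i in range(N)])

print(torch.allclose(y_fused, y_loop, atol=1e-4))  # True
```

The fused form avoids a Python-level loop over the batch at the cost of materializing one kernel tensor per sample.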
+ """ + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.add_bias = add_bias + self.init_bias = init_bias + self.use_wscale = use_wscale + self.wscale_gain = wscale_gain + self.lr_mul = lr_mul + self.activation_type = activation_type + + weight_shape = (out_channels, in_channels) + wscale = wscale_gain / np.sqrt(in_channels) + if use_wscale: + self.weight = nn.Parameter(torch.randn(*weight_shape) / lr_mul) + self.wscale = wscale * lr_mul + else: + self.weight = nn.Parameter( + torch.randn(*weight_shape) * wscale / lr_mul) + self.wscale = lr_mul + + if add_bias: + init_bias = np.float32(init_bias) / lr_mul + self.bias = nn.Parameter(torch.full([out_channels], init_bias)) + self.bscale = lr_mul + else: + self.bias = None + + def extra_repr(self): + return (f'in_ch={self.in_channels}, ' + f'out_ch={self.out_channels}, ' + f'wscale_gain={self.wscale_gain:.3f}, ' + f'bias={self.add_bias}, ' + f'init_bias={self.init_bias}, ' + f'lr_mul={self.lr_mul:.3f}, ' + f'act={self.activation_type}') + + def forward(self, x, impl='cuda'): + dtype = x.dtype + + if x.ndim != 2: + x = x.flatten(start_dim=1) + + weight = self.weight.to(dtype) * self.wscale + bias = None + if self.bias is not None: + bias = self.bias.to(dtype) + if self.bscale != 1.0: + bias = bias * self.bscale + + # Fast pass for linear activation. + if self.activation_type == 'linear' and bias is not None: + x = torch.addmm(bias.unsqueeze(0), x, weight.t()) + else: + x = x.matmul(weight.t()) + x = bias_act.bias_act(x, bias, act=self.activation_type, impl=impl) + + assert x.dtype == dtype + return x + +# pylint: enable=missing-function-docstring diff --git a/models/stylegan3_generator.py b/models/stylegan3_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..fc390bd070364d7423c51ea0f851c423827d761d --- /dev/null +++ b/models/stylegan3_generator.py @@ -0,0 +1,1332 @@ +# python3.7 +"""Contains the implementation of generator described in StyleGAN3. + +Compared to that of StyleGAN2, the generator in StyleGAN3 controls the frequency +flow along with the convolutional layers growing. + +Paper: https://arxiv.org/pdf/2106.12423.pdf + +Official implementation: https://github.com/NVlabs/stylegan3 +""" + +import numpy as np +import scipy.signal + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from third_party.stylegan3_official_ops import bias_act +from third_party.stylegan3_official_ops import filtered_lrelu +from third_party.stylegan3_official_ops import conv2d_gradfix +from .utils.ops import all_gather + +__all__ = ['StyleGAN3Generator'] + +# Resolutions allowed. +_RESOLUTIONS_ALLOWED = [8, 16, 32, 64, 128, 256, 512, 1024] + +# pylint: disable=missing-function-docstring + +class StyleGAN3Generator(nn.Module): + """Defines the generator network in StyleGAN3. + + NOTE: The synthesized images are with `RGB` channel order and pixel range + [-1, 1]. + + Settings for the mapping network: + + (1) z_dim: Dimension of the input latent space, Z. (default: 512) + (2) w_dim: Dimension of the output latent space, W. (default: 512) + (3) repeat_w: Repeat w-code for different layers. (default: True) + (4) normalize_z: Whether to normalize the z-code. (default: True) + (5) mapping_layers: Number of layers of the mapping network. (default: 2) + (6) mapping_fmaps: Number of hidden channels of the mapping network. + (default: 512) + (7) mapping_lr_mul: Learning rate multiplier for the mapping network. 
+        (default: 0.01)
+
+    Settings for conditional generation:
+
+    (1) label_dim: Dimension of the additional label for conditional generation.
+        In one-hot conditioning case, it is equal to the number of classes. If
+        set to 0, conditioning training will be disabled. (default: 0)
+    (2) embedding_dim: Dimension of the embedding space, if needed.
+        (default: 512)
+    (3) embedding_bias: Whether to add bias to embedding learning.
+        (default: True)
+    (4) embedding_lr_mul: Learning rate multiplier for the embedding learning.
+        (default: 1.0)
+    (5) normalize_embedding: Whether to normalize the embedding. (default: True)
+    (6) normalize_embedding_latent: Whether to normalize the embedding together
+        with the latent. (default: False)
+
+    Settings for the synthesis network:
+
+    (1) resolution: The resolution of the output image. (default: -1)
+    (2) image_channels: Number of channels of the output image. (default: 3)
+    (3) final_tanh: Whether to use `tanh` to control the final pixel range.
+        (default: False)
+    (4) output_scale: Factor for scaling the output image. (default: 0.25)
+    (5) num_layers: Number of synthesis layers, excluding the first positional
+        encoding layer and the last ToRGB layer. (default: 14)
+    (6) num_critical: Number of synthesis layers with critical sampling. These
+        layers are always the top (highest-resolution) ones. (default: 2)
+    (7) fmaps_base: Factor to control number of feature maps for each layer.
+        (default: 32 << 10)
+    (8) fmaps_max: Maximum number of feature maps in each layer. (default: 512)
+    (9) kernel_size: Size of convolutional kernels. (default: 1)
+    (10) conv_clamp: A threshold to clamp the output of convolution layers to
+        avoid overflow under FP16 training. (default: None)
+    (11) first_cutoff: Cutoff frequency of the first layer. (default: 2)
+    (12) first_stopband: Stopband of the first layer. (default: 2 ** 2.1)
+    (13) last_stopband_rel: Stopband of the last layer, relative to the last
+        cutoff, which is `resolution / 2`. Concretely, `last_stopband` will be
+        equal to `resolution / 2 * last_stopband_rel`. (default: 2 ** 0.3)
+    (14) margin_size: Size of margin for each feature map. (default: 10)
+    (15) filter_size: Size of filter for upsampling and downsampling around the
+        activation. (default: 6)
+    (16) act_upsampling: Factor used to upsample the feature map before
+        activation for anti-aliasing. (default: 2)
+    (17) use_radial_filter: Whether to use radial filter for downsampling after
+        the activation. (default: False)
+    (18) eps: A small value to avoid divide overflow. (default: 1e-8)
+
+    Runtime settings:
+
+    (1) w_moving_decay: Decay factor for updating `w_avg`, which is used for
+        training only. Set `None` to disable. (default: 0.998)
+    (2) sync_w_avg: Synchronizing the stats of `w_avg` across replicas. If set
+        as `True`, the stats will be more accurate, yet the speed may be a
+        little bit slower. (default: False)
+    (3) style_mixing_prob: Probability to perform style mixing as a training
+        regularization. Set `None` to disable. (default: None)
+    (4) trunc_psi: Truncation psi, set `None` to disable. (default: None)
+    (5) trunc_layers: Number of layers to perform truncation. (default: None)
+    (6) magnitude_moving_decay: Decay factor for updating `magnitude_ema` in
+        each `SynthesisLayer`, which is used for training only. Set `None` to
+        disable. (default: 0.999)
+    (7) update_ema: Whether to update `w_avg` in the `MappingNetwork` and
+        `magnitude_ema` in each `SynthesisLayer`. This field only takes effect
+        in `training` mode.
(default: False) + (8) fp16_res: Layers at resolution higher than (or equal to) this field will + use `float16` precision for computation. This is merely used for + acceleration. If set as `None`, all layers will use `float32` by + default. (default: None) + (9) impl: Implementation mode of some particular ops, e.g., `filtering`, + `bias_act`, etc. `cuda` means using the official CUDA implementation + from StyleGAN3, while `ref` means using the native PyTorch ops. + (default: `cuda`) + """ + + def __init__(self, + # Settings for mapping network. + z_dim=512, + w_dim=512, + repeat_w=True, + normalize_z=True, + mapping_layers=2, + mapping_fmaps=512, + mapping_lr_mul=0.01, + # Settings for conditional generation. + label_dim=0, + embedding_dim=512, + embedding_bias=True, + embedding_lr_mul=1.0, + normalize_embedding=True, + normalize_embedding_latent=False, + # Settings for synthesis network. + resolution=-1, + image_channels=3, + final_tanh=False, + output_scale=0.25, + num_layers=14, + num_critical=2, + fmaps_base=32 << 10, + fmaps_max=512, + kernel_size=1, + conv_clamp=256, + first_cutoff=2, + first_stopband=2 ** 2.1, + last_stopband_rel=2 ** 0.3, + margin_size=10, + filter_size=6, + act_upsampling=2, + use_radial_filter=False, + eps=1e-8): + """Initializes with basic settings.""" + super().__init__() + + if resolution not in _RESOLUTIONS_ALLOWED: + raise ValueError(f'Invalid resolution: `{resolution}`!\n' + f'Resolutions allowed: {_RESOLUTIONS_ALLOWED}.') + + self.z_dim = z_dim + self.w_dim = w_dim + self.repeat_w = repeat_w + self.normalize_z = normalize_z + self.mapping_layers = mapping_layers + self.mapping_fmaps = mapping_fmaps + self.mapping_lr_mul = mapping_lr_mul + + self.label_dim = label_dim + self.embedding_dim = embedding_dim + self.embedding_bias = embedding_bias + self.embedding_lr_mul = embedding_lr_mul + self.normalize_embedding = normalize_embedding + self.normalize_embedding_latent = normalize_embedding_latent + + self.resolution = resolution + self.image_channels = image_channels + self.final_tanh = final_tanh + self.output_scale = output_scale + self.num_layers = num_layers + 2 # Including InputLayer and ToRGBLayer. + self.num_critical = num_critical + self.fmaps_base = fmaps_base + self.fmaps_max = fmaps_max + self.kernel_size = kernel_size + self.conv_clamp = conv_clamp + self.first_cutoff = first_cutoff + self.first_stopband = first_stopband + self.last_stopband_rel = last_stopband_rel + self.margin_size = margin_size + self.filter_size = filter_size + self.act_upsampling = act_upsampling + self.use_radial_filter = use_radial_filter + self.eps = eps + + # Dimension of latent space, which is convenient for sampling. + self.latent_dim = (z_dim,) + + self.mapping = MappingNetwork( + input_dim=z_dim, + output_dim=w_dim, + num_outputs=self.num_layers, + repeat_output=repeat_w, + normalize_input=normalize_z, + num_layers=mapping_layers, + hidden_dim=mapping_fmaps, + lr_mul=mapping_lr_mul, + label_dim=label_dim, + embedding_dim=embedding_dim, + embedding_bias=embedding_bias, + embedding_lr_mul=embedding_lr_mul, + normalize_embedding=normalize_embedding, + normalize_embedding_latent=normalize_embedding_latent, + eps=eps) + + # This is used for truncation trick. 
+ if self.repeat_w: + self.register_buffer('w_avg', torch.zeros(w_dim)) + else: + self.register_buffer('w_avg', torch.zeros(self.num_layers * w_dim)) + + self.synthesis = SynthesisNetwork(resolution=resolution, + w_dim=w_dim, + image_channels=image_channels, + final_tanh=final_tanh, + output_scale=output_scale, + num_layers=num_layers, + num_critical=num_critical, + fmaps_base=fmaps_base, + fmaps_max=fmaps_max, + kernel_size=kernel_size, + conv_clamp=conv_clamp, + first_cutoff=first_cutoff, + first_stopband=first_stopband, + last_stopband_rel=last_stopband_rel, + margin_size=margin_size, + filter_size=filter_size, + act_upsampling=act_upsampling, + use_radial_filter=use_radial_filter, + eps=eps) + + self.var_mapping = {'w_avg': 'mapping.w_avg'} + for key, val in self.mapping.var_mapping.items(): + self.var_mapping[f'mapping.{key}'] = f'mapping.{val}' + for key, val in self.synthesis.var_mapping.items(): + self.var_mapping[f'synthesis.{key}'] = f'synthesis.{val}' + + def set_space_of_latent(self, space_of_latent): + """Sets the space to which the latent code belong. + + See `SynthesisNetwork` for more details. + """ + self.synthesis.set_space_of_latent(space_of_latent) + + def forward(self, + z, + label=None, + w_moving_decay=0.998, + sync_w_avg=False, + style_mixing_prob=None, + trunc_psi=None, + trunc_layers=None, + magnitude_moving_decay=0.999, + update_ema=False, + fp16_res=None, + impl='cuda'): + """Connects mapping network and synthesis network. + + This forward function will also update the average `w_code`, perform + style mixing as a training regularizer, and do truncation trick, which + is specially designed for inference. + + Concretely, the truncation trick acts as follows: + + For layers in range [0, truncation_layers), the truncated w-code is + computed as + + w_new = w_avg + (w - w_avg) * truncation_psi + + To disable truncation, please set + + (1) truncation_psi = 1.0 (None) OR + (2) truncation_layers = 0 (None) + """ + + mapping_results = self.mapping(z, label, impl=impl) + + w = mapping_results['w'] + if self.training and update_ema and w_moving_decay is not None: + if sync_w_avg: + batch_w_avg = all_gather(w.detach()).mean(dim=0) + else: + batch_w_avg = w.detach().mean(dim=0) + self.w_avg.copy_(batch_w_avg.lerp(self.w_avg, w_moving_decay)) + + wp = mapping_results.pop('wp') + if self.training and style_mixing_prob is not None: + if np.random.uniform() < style_mixing_prob: + new_z = torch.randn_like(z) + new_wp = self.mapping(new_z, label, impl=impl)['wp'] + mixing_cutoff = np.random.randint(1, self.num_layers) + wp[:, mixing_cutoff:] = new_wp[:, mixing_cutoff:] + + if not self.training: + trunc_psi = 1.0 if trunc_psi is None else trunc_psi + trunc_layers = 0 if trunc_layers is None else trunc_layers + if trunc_psi < 1.0 and trunc_layers > 0: + w_avg = self.w_avg.reshape(1, -1, self.w_dim)[:, :trunc_layers] + wp[:, :trunc_layers] = w_avg.lerp( + wp[:, :trunc_layers], trunc_psi) + + synthesis_results = self.synthesis( + wp, + magnitude_moving_decay=magnitude_moving_decay, + update_ema=update_ema, + fp16_res=fp16_res, + impl=impl) + + return {**mapping_results, **synthesis_results} + + +class MappingNetwork(nn.Module): + """Implements the latent space mapping network. + + Basically, this network executes several dense layers in sequence, and the + label embedding if needed. 
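The truncation logic in `forward()` above reduces to a single `lerp` toward the tracked `w_avg`. A minimal standalone sketch (not part of the patch), with `num_layers`, `trunc_psi`, and `trunc_layers` chosen arbitrarily for illustration:

```python
# Truncation trick: codes for the first `trunc_layers` layers are pulled
# toward the running average w_avg, trading sample diversity for fidelity.
import torch

num_layers, w_dim = 16, 512
trunc_psi, trunc_layers = 0.7, 8

w_avg = torch.zeros(1, 1, w_dim)         # Running average of w (a buffer).
wp = torch.randn(2, num_layers, w_dim)   # Per-layer codes for a batch of 2.

# w_new = w_avg + (w - w_avg) * trunc_psi, i.e., lerp(w_avg, w, trunc_psi).
wp[:, :trunc_layers] = w_avg.lerp(wp[:, :trunc_layers], trunc_psi)
```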
+ """ + + def __init__(self, + input_dim, + output_dim, + num_outputs, + repeat_output, + normalize_input, + num_layers, + hidden_dim, + lr_mul, + label_dim, + embedding_dim, + embedding_bias, + embedding_lr_mul, + normalize_embedding, + normalize_embedding_latent, + eps): + super().__init__() + + self.input_dim = input_dim + self.output_dim = output_dim + self.num_outputs = num_outputs + self.repeat_output = repeat_output + self.normalize_input = normalize_input + self.num_layers = num_layers + self.hidden_dim = hidden_dim + self.lr_mul = lr_mul + self.label_dim = label_dim + self.embedding_dim = embedding_dim + self.embedding_bias = embedding_bias + self.embedding_lr_mul = embedding_lr_mul + self.normalize_embedding = normalize_embedding + self.normalize_embedding_latent = normalize_embedding_latent + self.eps = eps + + self.var_mapping = {} + + self.norm = PixelNormLayer(dim=1, eps=eps) + + if self.label_dim > 0: + input_dim = input_dim + embedding_dim + self.embedding = DenseLayer(in_channels=label_dim, + out_channels=embedding_dim, + init_weight_std=1.0, + add_bias=embedding_bias, + init_bias=0.0, + lr_mul=embedding_lr_mul, + activation_type='linear') + self.var_mapping['embedding.weight'] = 'embed.weight' + if self.embedding_bias: + self.var_mapping['embedding.bias'] = 'embed.bias' + + if num_outputs is not None and not repeat_output: + output_dim = output_dim * num_outputs + for i in range(num_layers): + in_channels = (input_dim if i == 0 else hidden_dim) + out_channels = (output_dim if i == (num_layers - 1) else hidden_dim) + self.add_module(f'dense{i}', + DenseLayer(in_channels=in_channels, + out_channels=out_channels, + init_weight_std=1.0, + add_bias=True, + init_bias=0.0, + lr_mul=lr_mul, + activation_type='lrelu')) + self.var_mapping[f'dense{i}.weight'] = f'fc{i}.weight' + self.var_mapping[f'dense{i}.bias'] = f'fc{i}.bias' + + def forward(self, z, label=None, impl='cuda'): + if z.ndim != 2 or z.shape[1] != self.input_dim: + raise ValueError(f'Input latent code should be with shape ' + f'[batch_size, input_dim], where ' + f'`input_dim` equals to {self.input_dim}!\n' + f'But `{z.shape}` is received!') + if self.normalize_input: + z = self.norm(z) + + if self.label_dim > 0: + if label is None: + raise ValueError(f'Model requires an additional label ' + f'(with dimension {self.label_dim}) as input, ' + f'but no label is received!') + if label.ndim != 2 or label.shape != (z.shape[0], self.label_dim): + raise ValueError(f'Input label should be with shape ' + f'[batch_size, label_dim], where ' + f'`batch_size` equals to that of ' + f'latent codes ({z.shape[0]}) and ' + f'`label_dim` equals to {self.label_dim}!\n' + f'But `{label.shape}` is received!') + label = label.to(dtype=torch.float32) + embedding = self.embedding(label, impl=impl) + if self.normalize_embedding: + embedding = self.norm(embedding) + w = torch.cat((z, embedding), dim=1) + else: + w = z + + if self.label_dim > 0 and self.normalize_embedding_latent: + w = self.norm(w) + + for i in range(self.num_layers): + w = getattr(self, f'dense{i}')(w, impl=impl) + + wp = None + if self.num_outputs is not None: + if self.repeat_output: + wp = w.unsqueeze(1).repeat((1, self.num_outputs, 1)) + else: + wp = w.reshape(-1, self.num_outputs, self.output_dim) + + results = { + 'z': z, + 'label': label, + 'w': w, + 'wp': wp, + } + if self.label_dim > 0: + results['embedding'] = embedding + return results + + +class SynthesisNetwork(nn.Module): + """Implements the image synthesis network. 
+ + Basically, this network executes several convolutional layers in sequence. + """ + + def __init__(self, + resolution, + w_dim, + image_channels, + final_tanh, + output_scale, + num_layers, + num_critical, + fmaps_base, + fmaps_max, + kernel_size, + conv_clamp, + first_cutoff, + first_stopband, + last_stopband_rel, + margin_size, + filter_size, + act_upsampling, + use_radial_filter, + eps): + super().__init__() + + self.resolution = resolution + self.w_dim = w_dim + self.image_channels = image_channels + self.final_tanh = final_tanh + self.output_scale = output_scale + self.num_layers = num_layers + self.num_critical = num_critical + self.fmaps_base = fmaps_base + self.fmaps_max = fmaps_max + self.kernel_size = kernel_size + self.conv_clamp = conv_clamp + self.first_cutoff = first_cutoff + self.first_stopband = first_stopband + self.last_stopband_rel = last_stopband_rel + self.margin_size = margin_size + self.filter_size = filter_size + self.act_upsampling = act_upsampling + self.use_radial_filter = use_radial_filter + self.eps = eps + + self.var_mapping = {} + + # Get layer settings. + last_cutoff = resolution / 2 + last_stopband = last_cutoff * last_stopband_rel + layer_indices = np.arange(num_layers + 1) + exponents = np.minimum(layer_indices / (num_layers - num_critical), 1) + cutoffs = first_cutoff * (last_cutoff / first_cutoff) ** exponents + stopbands = ( + first_stopband * (last_stopband / first_stopband) ** exponents) + sampling_rates = np.exp2(np.ceil(np.log2( + np.minimum(stopbands * 2, self.resolution)))) + sampling_rates = np.int64(sampling_rates) + half_widths = np.maximum(stopbands, sampling_rates / 2) - cutoffs + sizes = sampling_rates + margin_size * 2 + sizes[-2:] = resolution + sizes = np.int64(sizes) + channels = np.rint(np.minimum((fmaps_base / 2) / cutoffs, fmaps_max)) + channels[-1] = image_channels + channels = np.int64(channels) + + self.cutoffs = cutoffs + self.stopbands = stopbands + self.sampling_rates = sampling_rates + self.half_widths = half_widths + self.sizes = sizes + self.channels = channels + + # Input layer, with positional encoding. + self.early_layer = InputLayer(w_dim=w_dim, + channels=channels[0], + size=sizes[0], + sampling_rate=sampling_rates[0], + cutoff=cutoffs[0]) + self.var_mapping['early_layer.weight'] = 'input.weight' + self.var_mapping['early_layer.affine.weight'] = 'input.affine.weight' + self.var_mapping['early_layer.affine.bias'] = 'input.affine.bias' + self.var_mapping['early_layer.transform'] = 'input.transform' + self.var_mapping['early_layer.frequency'] = 'input.freqs' + self.var_mapping['early_layer.phase'] = 'input.phases' + + # Convolutional layers. + for idx in range(num_layers + 1): + # Position related settings. + if idx < num_layers: + kernel_size = self.kernel_size + demodulate = True + act_upsampling = self.act_upsampling + else: # ToRGB layer. + kernel_size = 1 + demodulate = False + act_upsampling = 1 + if idx < num_layers - num_critical: # Non-critical sampling. + use_radial_filter = self.use_radial_filter + else: # Critical sampling. 
+ use_radial_filter = False + + prev_idx = max(idx - 1, 0) + layer_name = f'layer{idx}' + official_layer_name = f'L{idx}_{sizes[idx]}_{channels[idx]}' + self.add_module( + layer_name, + SynthesisLayer(in_channels=channels[prev_idx], + out_channels=channels[idx], + w_dim=w_dim, + kernel_size=kernel_size, + demodulate=demodulate, + eps=eps, + conv_clamp=conv_clamp, + in_size=sizes[prev_idx], + out_size=sizes[idx], + in_sampling_rate=sampling_rates[prev_idx], + out_sampling_rate=sampling_rates[idx], + in_cutoff=cutoffs[prev_idx], + out_cutoff=cutoffs[idx], + in_half_width=half_widths[prev_idx], + out_half_width=half_widths[idx], + filter_size=filter_size, + use_radial_filter=use_radial_filter, + act_upsampling=act_upsampling)) + + self.var_mapping[f'{layer_name}.magnitude_ema'] = ( + f'{official_layer_name}.magnitude_ema') + self.var_mapping[f'{layer_name}.conv.weight'] = ( + f'{official_layer_name}.weight') + self.var_mapping[f'{layer_name}.conv.style.weight'] = ( + f'{official_layer_name}.affine.weight') + self.var_mapping[f'{layer_name}.conv.style.bias'] = ( + f'{official_layer_name}.affine.bias') + self.var_mapping[f'{layer_name}.filter.bias'] = ( + f'{official_layer_name}.bias') + if idx < num_layers: # ToRGB layer does not need filters. + self.var_mapping[f'{layer_name}.filter.up_filter'] = ( + f'{official_layer_name}.up_filter') + self.var_mapping[f'{layer_name}.filter.down_filter'] = ( + f'{official_layer_name}.down_filter') + + def set_space_of_latent(self, space_of_latent): + """Sets the space to which the latent code belong. + + This function is particularly used for choosing how to inject the latent + code into the convolutional layers. The original generator will take a + W-Space code and apply it for style modulation after an affine + transformation. But, sometimes, it may need to directly feed an already + affine-transformed code into the convolutional layer, e.g., when + training an encoder for GAN inversion. We term the transformed space as + Style Space (or Y-Space). This function is designed to tell the + convolutional layers how to use the input code. + + Args: + space_of_latent: The space to which the latent code belong. Case + insensitive. Support `W` and `Y`. + """ + space_of_latent = space_of_latent.upper() + for module in self.modules(): + if isinstance(module, ModulateConvLayer): + setattr(module, 'space_of_latent', space_of_latent) + + def forward(self, + wp, + magnitude_moving_decay=0.999, + update_ema=False, + fp16_res=None, + impl='cuda'): + results = {'wp': wp} + + x = self.early_layer(wp[:, 0]) + for idx, sampling_rate in enumerate(self.sampling_rates): + if fp16_res is not None and sampling_rate >= fp16_res: + x = x.to(torch.float16) + layer = getattr(self, f'layer{idx}') + x, style = layer(x, wp[:, idx + 1], + magnitude_moving_decay=magnitude_moving_decay, + update_ema=update_ema, + impl=impl) + results[f'style{idx}'] = style + + if self.output_scale != 1: + x = x * self.output_scale + x = x.to(torch.float32) + if self.final_tanh: + x = torch.tanh(x) + results['image'] = x + return results + + +class PixelNormLayer(nn.Module): + """Implements pixel-wise feature vector normalization layer.""" + + def __init__(self, dim, eps): + super().__init__() + self.dim = dim + self.eps = eps + + def extra_repr(self): + return f'dim={self.dim}, epsilon={self.eps}' + + def forward(self, x): + scale = (x.square().mean(dim=self.dim, keepdim=True) + self.eps).rsqrt() + return x * scale + + +class InputLayer(nn.Module): + """Implements the input layer with positional encoding. 
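The anti-aliasing schedule computed in `SynthesisNetwork.__init__` above can be inspected in isolation. A standalone NumPy sketch (not part of the patch) reproducing it with the default hyper-parameters at resolution 256:

```python
# Per-layer frequency schedule: cutoffs grow geometrically from first_cutoff
# to resolution / 2, and each layer's sampling rate is the smallest power of
# two that can hold twice its stopband (capped at the output resolution).
import numpy as np

resolution, num_layers, num_critical = 256, 14, 2
first_cutoff, first_stopband, last_stopband_rel = 2, 2 ** 2.1, 2 ** 0.3

last_cutoff = resolution / 2
last_stopband = last_cutoff * last_stopband_rel
exponents = np.minimum(
    np.arange(num_layers + 1) / (num_layers - num_critical), 1)
cutoffs = first_cutoff * (last_cutoff / first_cutoff) ** exponents
stopbands = first_stopband * (last_stopband / first_stopband) ** exponents
sampling_rates = np.exp2(np.ceil(np.log2(
    np.minimum(stopbands * 2, resolution))))

for idx, (c, s, r) in enumerate(zip(cutoffs, stopbands, sampling_rates)):
    print(f'layer{idx}: cutoff={c:7.2f}  stopband={s:7.2f}  srate={int(r)}')
```

The last `num_critical + 1` entries hit exponent 1, so their cutoff equals `resolution / 2` (critical sampling), matching the loop above.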
+ + Basically, this block outputs a feature map with shape + `(channels, size, size)` based on the coordinate information. + `sampling_rate` and `cutoff` are used to control the coordinate range and + strength respectively. + + For a low-pass filter, `cutoff` is the same as the `bandwidth`. + The initial frequency of the starting feature map is controlled by the + positional encoding `sin(2 * pi * x)`, where + `x = trans(coord) * frequency + phase`. We would like to introduce rich + information (i.e. frequencies), but keep all frequencies lower than + stopband, which is `sampling_rate / 2`. + + Besides, this layer also supports learning a transformation from the latent + code w, and providing a customized transformation for inference. Please + use the buffer `transform`. + + NOTE: `size` is different from `sampling_rate`. `sampling_rate` is the + actual size of the current stage, which determines the maximum frequency + that the feature maps can hold. `size` is the actual height and width of the + current feature map, including the extended border. + """ + + def __init__(self, w_dim, channels, size, sampling_rate, cutoff): + super().__init__() + + self.w_dim = w_dim + self.channels = channels + self.size = size + self.sampling_rate = sampling_rate + self.cutoff = cutoff + + # Coordinate of the entire feature map, with resolution (size, size). + # The coordinate range for the central (sampling_rate, sampling_rate) + # region is set as (-0.0, 0.5), which extends to the remaining region. + theta = torch.eye(2, 3) + theta[0, 0] = 0.5 / sampling_rate * size + theta[1, 1] = 0.5 / sampling_rate * size + grid = F.affine_grid(theta=theta.unsqueeze(0), + size=(1, 1, size, size), + align_corners=False) + self.register_buffer('grid', grid) + + # Draw random frequency from a uniform 2D disc for each channel + # regarding X and Y dimension. And also draw a random phase for each + # channel. Accordingly, each channel has three pre-defined parameters, + # which are X-frequency, Y-frequency, and phase. + frequency = torch.randn(channels, 2) + radius = frequency.square().sum(dim=1, keepdim=True).sqrt() + frequency = frequency / (radius * radius.square().exp().pow(0.25)) + frequency = frequency * cutoff + self.register_buffer('frequency', frequency) + phase = torch.rand(channels) - 0.5 + self.register_buffer('phase', phase) + + # This layer is used to map the latent code w to transform factors, + # with order: cos(angle), sin(angle), transpose_x, transpose_y. + self.affine = DenseLayer(in_channels=w_dim, + out_channels=4, + init_weight_std=0.0, + add_bias=True, + init_bias=(1, 0, 0, 0), + lr_mul=1.0, + activation_type='linear') + + # It is possible to use this buffer to customize the transform of the + # output synthesis. + self.register_buffer('transform', torch.eye(3)) + + # Use 1x1 conv to convert positional encoding to features. + self.weight = nn.Parameter(torch.randn(channels, channels)) + self.weight_scale = 1 / np.sqrt(channels) + + def extra_repr(self): + return (f'channels={self.channels}, ' + f'size={self.size}, ' + f'sampling_rate={self.sampling_rate}, ' + f'cutoff={self.cutoff:.3f}, ') + + def forward(self, w): + batch = w.shape[0] + + # Get transformation matrix. + # Factor controlled by latent code. + transformation_factor = self.affine(w) + # Ensure the range of cosine and sine value (first two dimension). + _norm = transformation_factor[:, :2].norm(dim=1, keepdim=True) + transformation_factor = transformation_factor / _norm + # Rotation. 
+ rotation = torch.eye(3, device=w.device).unsqueeze(0) + rotation = rotation.repeat((batch, 1, 1)) + rotation[:, 0, 0] = transformation_factor[:, 0] + rotation[:, 0, 1] = -transformation_factor[:, 1] + rotation[:, 1, 0] = transformation_factor[:, 1] + rotation[:, 1, 1] = transformation_factor[:, 0] + # Translation. + translation = torch.eye(3, device=w.device).unsqueeze(0) + translation = translation.repeat((batch, 1, 1)) + translation[:, 0, 2] = -transformation_factor[:, 2] + translation[:, 1, 2] = -transformation_factor[:, 3] + # Customized transformation. + transform = rotation @ translation @ self.transform.unsqueeze(0) + + # Transform frequency and shift, which is equivalent to transforming + # the coordinate. For example, given a coordinate, X, we would like to + # first transform it with the rotation matrix, R, and the translation + # matrix, T, as X' = RX + T. Then, we will apply frequency, f, and + # phase, p, with sin(2 * pi * (fX' + p)). Natively, we have + # fX' + p = f(RX + T) + p = (fR)X + (fT + p) + frequency = self.frequency.unsqueeze(0) @ transform[:, :2, :2] # [NC2] + phase = self.frequency.unsqueeze(0) @ transform[:, :2, 2:] # [NC] + phase = phase.squeeze(2) + self.phase.unsqueeze(0) # [NC] + + # Positional encoding. + x = self.grid # [NHW2] + x = x.unsqueeze(3) # [NHW12] + x = x @ frequency.transpose(1, 2).unsqueeze(1).unsqueeze(2) # [NHW1C] + x = x.squeeze(3) # [NHWC] + x = x + phase.unsqueeze(1).unsqueeze(2) # [NHWC] + x = torch.sin(2 * np.pi * x) # [NHWC] + + # Dampen out-of-band frequency that may be introduced by the customized + # transform `self.transform`. + frequency_norm = frequency.norm(dim=2) + stopband = self.sampling_rate / 2 + factor = (frequency_norm - self.cutoff) / (stopband - self.cutoff) + amplitude = (1 - factor).clamp(0, 1) # [NC] + x = x * amplitude.unsqueeze(1).unsqueeze(2) # [NHWC] + + # Project positional encoding to features. + weight = self.weight * self.weight_scale + x = x @ weight.t() + + return x.permute(0, 3, 1, 2).contiguous() + + +class SynthesisLayer(nn.Module): + """Implements the synthesis layer. + + Each synthesis layer (including ToRGB layer) consists of a + `ModulateConvLayer` and a `FilteringActLayer`. Besides, this layer will + trace the magnitude (norm) of the input feature map, and update the + statistic with `magnitude_moving_decay`. + """ + + def __init__(self, + # Settings for modulated convolution. + in_channels, + out_channels, + w_dim, + kernel_size, + demodulate, + eps, + conv_clamp, + # Settings for filtering activation. + in_size, + out_size, + in_sampling_rate, + out_sampling_rate, + in_cutoff, + out_cutoff, + in_half_width, + out_half_width, + filter_size, + use_radial_filter, + act_upsampling): + """Initializes with layer settings. + + Args: + in_channels: Number of channels of the input tensor. + out_channels: Number of channels of the output tensor. + w_dim: Dimension of W space for style modulation. + kernel_size: Size of the convolutional kernels. + demodulate: Whether to perform style demodulation. + eps: A small value to avoid divide overflow. + conv_clamp: A threshold to clamp the output of convolution layers to + avoid overflow under FP16 training. + in_size: Size of the input feature map, i.e., height and width. + out_size: Size of the output feature map, i.e., height and width. + in_sampling_rate: Sampling rate of the input feature map. Different + from `in_size` that includes extended border, this field + controls the actual maximum frequency that can be represented + by the feature map. 
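The comment in `InputLayer.forward` above relies on the identity that transforming the coordinates is equivalent to transforming the per-channel frequencies and phases. A standalone NumPy check (not part of the patch), with `R`, `t`, `f`, and `p` drawn arbitrarily:

```python
# Verify: f @ (R @ x + t) + p == (f @ R) @ x + (f @ t + p) for every channel.
import numpy as np

rng = np.random.default_rng(0)
f = rng.normal(size=(16, 2))   # Per-channel 2D frequencies.
p = rng.normal(size=(16,))     # Per-channel phases.
angle = 0.3
R = np.array([[np.cos(angle), -np.sin(angle)],
              [np.sin(angle), np.cos(angle)]])
t = rng.normal(size=(2,))
x = rng.normal(size=(100, 2))  # A batch of coordinates.

lhs = (R @ x.T + t[:, None]).T @ f.T + p  # Frequencies on moved coordinates.
rhs = x @ (f @ R).T + (f @ t + p)         # Moved frequencies and phases.
print(np.allclose(lhs, rhs))              # True
```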
+ out_sampling_rate: Sampling rate of the output feature map. + in_cutoff: Cutoff frequency of the input feature map. + out_cutoff: Cutoff frequency of the output feature map. + in_half_width: Half-width of the transition band of the input + feature map. + out_half_width: Half-width of the transition band of the output + feature map. + filter_size: Size of the filter used in this layer. + use_radial_filter: Whether to use radial filter. + act_upsampling: Upsampling factor used before the activation. + `1` means do not wrap upsampling and downsampling around the + activation. + """ + super().__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.w_dim = w_dim + self.kernel_size = kernel_size + self.demodulate = demodulate + self.eps = eps + self.conv_clamp = conv_clamp + + self.in_size = in_size + self.out_size = out_size + self.in_sampling_rate = in_sampling_rate + self.out_sampling_rate = out_sampling_rate + self.in_cutoff = in_cutoff + self.out_cutoff = out_cutoff + self.in_half_width = in_half_width + self.out_half_width = out_half_width + self.filter_size = filter_size + self.use_radial_filter = use_radial_filter + self.act_upsampling = act_upsampling + + self.conv = ModulateConvLayer(in_channels=in_channels, + out_channels=out_channels, + w_dim=w_dim, + kernel_size=kernel_size, + demodulate=demodulate, + eps=eps) + self.register_buffer('magnitude_ema', torch.ones(())) + self.filter = FilteringActLayer(out_channels=out_channels, + in_size=in_size, + out_size=out_size, + in_sampling_rate=in_sampling_rate, + out_sampling_rate=out_sampling_rate, + in_cutoff=in_cutoff, + out_cutoff=out_cutoff, + in_half_width=in_half_width, + out_half_width=out_half_width, + filter_size=filter_size, + use_radial_filter=use_radial_filter, + conv_padding=self.conv.padding, + act_upsampling=act_upsampling) + + def extra_repr(self): + return f'conv_clamp={self.conv_clamp}' + + def forward(self, + x, + w, + magnitude_moving_decay=0.999, + update_ema=False, + impl='cuda'): + if self.training and update_ema and magnitude_moving_decay is not None: + magnitude = x.detach().to(torch.float32).square().mean() + self.magnitude_ema.copy_( + magnitude.lerp(self.magnitude_ema, magnitude_moving_decay)) + + input_gain = self.magnitude_ema.rsqrt() + x, style = self.conv(x, w, gain=input_gain, impl=impl) + if self.act_upsampling > 1: + x = self.filter(x, np.sqrt(2), 0.2, self.conv_clamp, impl=impl) + else: + x = self.filter(x, 1, 1, self.conv_clamp, impl=impl) + + return x, style + + +class ModulateConvLayer(nn.Module): + """Implements the convolutional layer with style modulation. + + Different from the one introduced in StyleGAN2, this layer has following + changes: + + (1) fusing `conv` and `style modulation` into one op by default + (2) NOT adding a noise onto the output feature map. + (3) NOT activating the feature map, which is moved to `FilteringActLayer`. + """ + + def __init__(self, + in_channels, + out_channels, + w_dim, + kernel_size, + demodulate, + eps): + """Initializes with layer settings. + + Args: + in_channels: Number of channels of the input tensor. + out_channels: Number of channels of the output tensor. + w_dim: Dimension of W space for style modulation. + kernel_size: Size of the convolutional kernels. + demodulate: Whether to perform style demodulation. + eps: A small value to avoid divide overflow. 
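`SynthesisLayer.forward` above normalizes its input by a tracked magnitude EMA rather than by per-batch statistics. A minimal standalone sketch (not part of the patch) of how the EMA converges and is then applied as an input gain:

```python
# Magnitude tracking: keep an EMA of the mean squared activation during
# training, and apply its inverse square root as a gain so feature magnitudes
# stay roughly unit-scale at inference.
import torch

magnitude_ema = torch.ones(())
decay = 0.999

for _ in range(5000):                    # Simulated training steps.
    x = torch.randn(8, 16, 4, 4) * 3.0   # Features with std ~ 3.
    magnitude = x.detach().square().mean()
    # new_ema = decay * old_ema + (1 - decay) * magnitude
    magnitude_ema = magnitude.lerp(magnitude_ema, decay)

gain = magnitude_ema.rsqrt()
print(float((x * gain).square().mean()))  # ~ 1.0 after normalization.
```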
+ """ + super().__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.w_dim = w_dim + self.kernel_size = kernel_size + self.demodulate = demodulate + self.eps = eps + + self.space_of_latent = 'W' + + # Set up weight. + weight_shape = (out_channels, in_channels, kernel_size, kernel_size) + self.weight = nn.Parameter(torch.randn(*weight_shape)) + self.wscale = 1.0 / np.sqrt(kernel_size * kernel_size * in_channels) + self.padding = kernel_size - 1 + + # Set up style. + self.style = DenseLayer(in_channels=w_dim, + out_channels=in_channels, + init_weight_std=1.0, + add_bias=True, + init_bias=1.0, + lr_mul=1.0, + activation_type='linear') + + def extra_repr(self): + return (f'in_ch={self.in_channels}, ' + f'out_ch={self.out_channels}, ' + f'ksize={self.kernel_size}, ' + f'demodulate={self.demodulate}') + + def forward_style(self, w, impl='cuda'): + """Gets style code from the given input. + + More specifically, if the input is from W-Space, it will be projected by + an affine transformation. If it is from the Style Space (Y-Space), no + operation is required. + + NOTE: For codes from Y-Space, we use slicing to make sure the dimension + is correct, in case that the code is padded before fed into this layer. + """ + space_of_latent = self.space_of_latent.upper() + if space_of_latent == 'W': + if w.ndim != 2 or w.shape[1] != self.w_dim: + raise ValueError(f'The input tensor should be with shape ' + f'[batch_size, w_dim], where ' + f'`w_dim` equals to {self.w_dim}!\n' + f'But `{w.shape}` is received!') + style = self.style(w, impl=impl) + elif space_of_latent == 'Y': + if w.ndim != 2 or w.shape[1] < self.in_channels: + raise ValueError(f'The input tensor should be with shape ' + f'[batch_size, y_dim], where ' + f'`y_dim` equals to {self.in_channels}!\n' + f'But `{w.shape}` is received!') + style = w[:, :self.in_channels] + else: + raise NotImplementedError(f'Not implemented `space_of_latent`: ' + f'`{space_of_latent}`!') + return style + + def forward(self, x, w, gain=None, impl='cuda'): + dtype = x.dtype + N, C, H, W = x.shape + + # Affine on `w`. + style = self.forward_style(w, impl=impl) + if not self.demodulate: + _style = style * self.wscale # Equivalent to scaling weight. + else: + _style = style + + weight = self.weight + out_ch, in_ch, kh, kw = weight.shape + assert in_ch == C + + # Pre-normalize inputs. + if self.demodulate: + weight = (weight * + weight.square().mean(dim=(1, 2, 3), keepdim=True).rsqrt()) + _style = _style * _style.square().mean().rsqrt() + + weight = weight.unsqueeze(0) + weight = weight * _style.reshape(N, 1, in_ch, 1, 1) # modulation + if self.demodulate: + decoef = (weight.square().sum(dim=(2, 3, 4)) + self.eps).rsqrt() + weight = weight * decoef.reshape(N, out_ch, 1, 1, 1) # demodulation + + if gain is not None: + gain = gain.expand(N, in_ch) + weight = weight * gain.reshape(N, 1, in_ch, 1, 1) + + # Fuse `conv` and `style modulation` as one op, using group convolution. + x = x.reshape(1, N * in_ch, H, W) + w = weight.reshape(N * out_ch, in_ch, kh, kw).to(dtype) + x = conv2d_gradfix.conv2d( + x, w, padding=self.padding, groups=N, impl=impl) + x = x.reshape(N, out_ch, x.shape[2], x.shape[3]) + + assert x.dtype == dtype + assert style.dtype == torch.float32 + return x, style + + +class FilteringActLayer(nn.Module): + """Implements the activation, wrapped with upsampling and downsampling. + + Basically, this layer executes the following operations in order: + + (1) Apply bias. + (2) Upsample the feature map to increase sampling rate. 
+ (3) Apply non-linearity as activation. + (4) Downsample the feature map to target size. + + This layer is mostly borrowed from the official implementation: + + https://github.com/NVlabs/stylegan3/blob/main/training/networks_stylegan3.py + """ + + def __init__(self, + out_channels, + in_size, + out_size, + in_sampling_rate, + out_sampling_rate, + in_cutoff, + out_cutoff, + in_half_width, + out_half_width, + filter_size, + use_radial_filter, + conv_padding, + act_upsampling): + """Initializes with layer settings. + + Args: + out_channels: Number of output channels, which is used for `bias`. + in_size: Size of the input feature map, i.e., height and width. + out_size: Size of the output feature map, i.e., height and width. + in_sampling_rate: Sampling rate of the input feature map. Different + from `in_size` that includes extended border, this field + controls the actual maximum frequency that can be represented + by the feature map. + out_sampling_rate: Sampling rate of the output feature map. + in_cutoff: Cutoff frequency of the input feature map. + out_cutoff: Cutoff frequency of the output feature map. + in_half_width: Half-width of the transition band of the input + feature map. + out_half_width: Half-width of the transition band of the output + feature map. + filter_size: Size of the filter used in this layer. + use_radial_filter: Whether to use radial filter. + conv_padding: The padding used in the previous convolutional layer. + act_upsampling: Upsampling factor used before the activation. + `1` means do not wrap upsampling and downsampling around the + activation. + """ + super().__init__() + + self.out_channels = out_channels + self.in_size = in_size + self.out_size = out_size + self.in_sampling_rate = in_sampling_rate + self.out_sampling_rate = out_sampling_rate + self.in_cutoff = in_cutoff + self.out_cutoff = out_cutoff + self.in_half_width = in_half_width + self.out_half_width = out_half_width + self.filter_size = filter_size + self.use_radial_filter = use_radial_filter + self.conv_padding = conv_padding + self.act_upsampling = act_upsampling + + # Define bias. + self.bias = nn.Parameter(torch.zeros(out_channels)) + + # This sampling rate describes the upsampled feature map before + # activation. + temp_sampling_rate = max(in_sampling_rate, out_sampling_rate) + temp_sampling_rate = temp_sampling_rate * act_upsampling + + # Design upsampling filter. + up_factor = int(np.rint(temp_sampling_rate / in_sampling_rate)) + assert in_sampling_rate * up_factor == temp_sampling_rate + if up_factor > 1: + self.up_factor = up_factor + self.up_taps = filter_size * up_factor + else: + self.up_factor = 1 + self.up_taps = 1 # No filtering. + self.register_buffer( + 'up_filter', + self.design_lowpass_filter(numtaps=self.up_taps, + cutoff=in_cutoff, + width=in_half_width * 2, + fs=temp_sampling_rate, + radial=False)) + + # Design downsampling filter. + down_factor = int(np.rint(temp_sampling_rate / out_sampling_rate)) + assert out_sampling_rate * down_factor == temp_sampling_rate + if down_factor > 1: + self.down_factor = down_factor + self.down_taps = filter_size * down_factor + else: + self.down_factor = 1 + self.down_taps = 1 # No filtering. + self.register_buffer( + 'down_filter', + self.design_lowpass_filter(numtaps=self.down_taps, + cutoff=out_cutoff, + width=out_half_width * 2, + fs=temp_sampling_rate, + radial=use_radial_filter)) + + # Compute padding. + # Desired output size before downsampling. + pad_total = (out_size - 1) * self.down_factor + 1 + # Input size after upsampling. 
+ pad_total = pad_total - (in_size + conv_padding) * self.up_factor + # Size reduction caused by the filters. + pad_total = pad_total + self.up_taps + self.down_taps - 2 + # Shift sample locations according to the symmetric interpretation. + pad_lo = (pad_total + self.up_factor) // 2 + pad_hi = pad_total - pad_lo + self.padding = list(map(int, (pad_lo, pad_hi, pad_lo, pad_hi))) + + def extra_repr(self): + return (f'in_size={self.in_size}, ' + f'out_size={self.out_size}, ' + f'in_srate={self.in_sampling_rate}, ' + f'out_srate={self.out_sampling_rate}, ' + f'in_cutoff={self.in_cutoff:.3f}, ' + f'out_cutoff={self.out_cutoff:.3f}, ' + f'in_half_width={self.in_half_width:.3f}, ' + f'out_half_width={self.out_half_width:.3f}, ' + f'up_factor={self.up_factor}, ' + f'up_taps={self.up_taps}, ' + f'down_factor={self.down_factor}, ' + f'down_taps={self.down_taps}, ' + f'filter_size={self.filter_size}, ' + f'radial_filter={self.use_radial_filter}, ' + f'conv_padding={self.conv_padding}, ' + f'act_upsampling={self.act_upsampling}') + + @staticmethod + def design_lowpass_filter(numtaps, cutoff, width, fs, radial=False): + """Designs a low-pass filter. + + Args: + numtaps: Length of the filter (number of coefficients, i.e., the + filter order + 1). + cutoff: Cutoff frequency of the output filter. + width: Width of the transition region. + fs: Sampling frequency. + radial: Whether to use radially symmetric jinc-based filter. + (default: False) + """ + if numtaps == 1: + return None + + assert numtaps > 1 + + if not radial: # Separable Kaiser low-pass filter. + f = scipy.signal.firwin(numtaps=numtaps, + cutoff=cutoff, + width=width, + fs=fs) + else: # Radially symmetric jinc-based filter. + x = (np.arange(numtaps) - (numtaps - 1) / 2) / fs + r = np.hypot(*np.meshgrid(x, x)) + f = scipy.special.j1(2 * cutoff * (np.pi * r)) / (np.pi * r) + beta = scipy.signal.kaiser_beta( + scipy.signal.kaiser_atten(numtaps, width / (fs / 2))) + w = np.kaiser(numtaps, beta) + f = f * np.outer(w, w) + f = f / np.sum(f) + return torch.as_tensor(f, dtype=torch.float32) + + def forward(self, x, gain, slope, clamp, impl='cuda'): + dtype = x.dtype + + x = filtered_lrelu.filtered_lrelu(x=x, + fu=self.up_filter, + fd=self.down_filter, + b=self.bias.to(dtype), + up=self.up_factor, + down=self.down_factor, + padding=self.padding, + gain=gain, + slope=slope, + clamp=clamp, + impl=impl) + + assert x.dtype == dtype + return x + + +class DenseLayer(nn.Module): + """Implements the dense layer.""" + + def __init__(self, + in_channels, + out_channels, + init_weight_std, + add_bias, + init_bias, + lr_mul, + activation_type): + """Initializes with layer settings. + + Args: + in_channels: Number of channels of the input tensor. + out_channels: Number of channels of the output tensor. + init_weight_std: The initial standard deviation of weight. + add_bias: Whether to add bias onto the fully-connected result. + init_bias: The initial bias value before training. + lr_mul: Learning multiplier for both weight and bias. + activation_type: Type of activation. 
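`design_lowpass_filter` above delegates the separable (non-radial) case to `scipy.signal.firwin`. A minimal standalone sketch (not part of the patch), with illustrative numbers, showing the resulting taps have unit DC gain:

```python
# Windowed-sinc low-pass design for the separable case: the 1D taps are
# applied along H and W independently, and sum to 1 so resampling preserves DC.
import scipy.signal

numtaps = 12   # e.g., filter_size * up_factor = 6 * 2.
fs = 32        # Temporary (upsampled) sampling rate.
cutoff = 8.0   # Passband edge, in the same units as fs.
width = 4.0    # Transition band width, i.e., 2 * half_width.

f = scipy.signal.firwin(numtaps=numtaps, cutoff=cutoff, width=width, fs=fs)
print(f.sum())    # ~ 1.0: unit DC gain.
print(f.shape)    # (12,): one 1D filter, applied separably.
```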
+ """ + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.init_weight_std = init_weight_std + self.add_bias = add_bias + self.init_bias = init_bias + self.lr_mul = lr_mul + self.activation_type = activation_type + + weight_shape = (out_channels, in_channels) + self.weight = nn.Parameter( + torch.randn(*weight_shape) * init_weight_std / lr_mul) + self.wscale = lr_mul / np.sqrt(in_channels) + + if add_bias: + init_bias = np.float32(np.float32(init_bias) / lr_mul) + if isinstance(init_bias, np.float32): + self.bias = nn.Parameter(torch.full([out_channels], init_bias)) + else: + assert isinstance(init_bias, np.ndarray) + self.bias = nn.Parameter(torch.from_numpy(init_bias)) + self.bscale = lr_mul + else: + self.bias = None + + def extra_repr(self): + return (f'in_ch={self.in_channels}, ' + f'out_ch={self.out_channels}, ' + f'init_weight_std={self.init_weight_std}, ' + f'bias={self.add_bias}, ' + f'init_bias={self.init_bias}, ' + f'lr_mul={self.lr_mul:.3f}, ' + f'act={self.activation_type}') + + def forward(self, x, impl='cuda'): + dtype = x.dtype + + if x.ndim != 2: + x = x.flatten(start_dim=1) + + weight = self.weight.to(dtype) * self.wscale + bias = None + if self.bias is not None: + bias = self.bias.to(dtype) + if self.bscale != 1.0: + bias = bias * self.bscale + + # Fast pass for linear activation. + if self.activation_type == 'linear' and bias is not None: + x = torch.addmm(bias.unsqueeze(0), x, weight.t()) + else: + x = x.matmul(weight.t()) + x = bias_act.bias_act(x, bias, act=self.activation_type, impl=impl) + + assert x.dtype == dtype + return x + +# pylint: enable=missing-function-docstring diff --git a/models/stylegan_discriminator.py b/models/stylegan_discriminator.py new file mode 100644 index 0000000000000000000000000000000000000000..c76ff3430af76839a259035fe37783ac3aa807de --- /dev/null +++ b/models/stylegan_discriminator.py @@ -0,0 +1,624 @@ +# python3.7 +"""Contains the implementation of discriminator described in StyleGAN. + +Paper: https://arxiv.org/pdf/1812.04948.pdf + +Official TensorFlow implementation: https://github.com/NVlabs/stylegan +""" + +import numpy as np + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.cuda.amp import autocast + +__all__ = ['StyleGANDiscriminator'] + +# Resolutions allowed. +_RESOLUTIONS_ALLOWED = [8, 16, 32, 64, 128, 256, 512, 1024] + +# Fused-scale options allowed. +_FUSED_SCALE_ALLOWED = [True, False, 'auto'] + +# pylint: disable=missing-function-docstring + +class StyleGANDiscriminator(nn.Module): + """Defines the discriminator network in StyleGAN. + + NOTE: The discriminator takes images with `RGB` channel order and pixel + range [-1, 1] as inputs. + + Settings for the backbone: + + (1) resolution: The resolution of the input image. (default: -1) + (2) init_res: Smallest resolution of the convolutional backbone. + (default: 4) + (3) image_channels: Number of channels of the input image. (default: 3) + (4) fused_scale: The strategy of fusing `conv2d` and `downsample` as one + operator. `True` means blocks from all resolutions will fuse. `False` + means blocks from all resolutions will not fuse. `auto` means blocks + from resolutions higher than (or equal to) `fused_scale_res` will fuse. + (default: `auto`) + (5) fused_scale_res: Minimum resolution to fuse `conv2d` and `downsample` + as one operator. This field only takes effect if `fused_scale` is set + as `auto`. (default: 128) + (6) use_wscale: Whether to use weight scaling. 
(default: True) + (7) wscale_gain: The factor to control weight scaling. (default: sqrt(2.0)) + (8) lr_mul: Learning rate multiplier for backbone. (default: 1.0) + (9) mbstd_groups: Group size for the minibatch standard deviation layer. + `0` means disable. (default: 4) + (10) mbstd_channels: Number of new channels (appended to the original + feature map) after the minibatch standard deviation layer. (default: 1) + (11) fmaps_base: Factor to control number of feature maps for each layer. + (default: 16 << 10) + (12) fmaps_max: Maximum number of feature maps in each layer. (default: 512) + (13) filter_kernel: Kernel used for filtering (e.g., downsampling). + (default: (1, 2, 1)) + (14) eps: A small value to avoid divide overflow. (default: 1e-8) + + Settings for conditional model: + + (1) label_dim: Dimension of the additional label for conditional generation. + In one-hot conditioning case, it is equal to the number of classes. If + set to 0, conditioning training will be disabled. (default: 0) + + Runtime settings: + + (1) enable_amp: Whether to enable automatic mixed precision training. + (default: False) + """ + + def __init__(self, + # Settings for backbone. + resolution=-1, + init_res=4, + image_channels=3, + fused_scale='auto', + fused_scale_res=128, + use_wscale=True, + wscale_gain=np.sqrt(2.0), + lr_mul=1.0, + mbstd_groups=4, + mbstd_channels=1, + fmaps_base=16 << 10, + fmaps_max=512, + filter_kernel=(1, 2, 1), + eps=1e-8, + # Settings for conditional model. + label_dim=0): + """Initializes with basic settings. + + Raises: + ValueError: If the `resolution` is not supported, or `fused_scale` + is not supported. + """ + super().__init__() + + if resolution not in _RESOLUTIONS_ALLOWED: + raise ValueError(f'Invalid resolution: `{resolution}`!\n' + f'Resolutions allowed: {_RESOLUTIONS_ALLOWED}.') + if fused_scale not in _FUSED_SCALE_ALLOWED: + raise ValueError(f'Invalid fused-scale option: `{fused_scale}`!\n' + f'Options allowed: {_FUSED_SCALE_ALLOWED}.') + + self.init_res = init_res + self.init_res_log2 = int(np.log2(init_res)) + self.resolution = resolution + self.final_res_log2 = int(np.log2(resolution)) + self.image_channels = image_channels + self.fused_scale = fused_scale + self.fused_scale_res = fused_scale_res + self.use_wscale = use_wscale + self.wscale_gain = wscale_gain + self.lr_mul = lr_mul + self.mbstd_groups = mbstd_groups + self.mbstd_channels = mbstd_channels + self.fmaps_base = fmaps_base + self.fmaps_max = fmaps_max + self.filter_kernel = filter_kernel + self.eps = eps + self.label_dim = label_dim + + # Level-of-details (used for progressive training). + self.register_buffer('lod', torch.zeros(())) + self.pth_to_tf_var_mapping = {'lod': 'lod'} + + for res_log2 in range(self.final_res_log2, self.init_res_log2 - 1, -1): + res = 2 ** res_log2 + in_channels = self.get_nf(res) + out_channels = self.get_nf(res // 2) + block_idx = self.final_res_log2 - res_log2 + + # Input convolution layer for each resolution. + self.add_module( + f'input{block_idx}', + ConvLayer(in_channels=image_channels, + out_channels=in_channels, + kernel_size=1, + add_bias=True, + scale_factor=1, + fused_scale=False, + filter_kernel=None, + use_wscale=use_wscale, + wscale_gain=wscale_gain, + lr_mul=lr_mul, + activation_type='lrelu')) + self.pth_to_tf_var_mapping[f'input{block_idx}.weight'] = ( + f'FromRGB_lod{block_idx}/weight') + self.pth_to_tf_var_mapping[f'input{block_idx}.bias'] = ( + f'FromRGB_lod{block_idx}/bias') + + # Convolution block for each resolution (except the last one). 
+ if res != self.init_res: + # First layer (kernel 3x3) without downsampling. + layer_name = f'layer{2 * block_idx}' + self.add_module( + layer_name, + ConvLayer(in_channels=in_channels, + out_channels=in_channels, + kernel_size=3, + add_bias=True, + scale_factor=1, + fused_scale=False, + filter_kernel=None, + use_wscale=use_wscale, + wscale_gain=wscale_gain, + lr_mul=lr_mul, + activation_type='lrelu')) + self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = ( + f'{res}x{res}/Conv0/weight') + self.pth_to_tf_var_mapping[f'{layer_name}.bias'] = ( + f'{res}x{res}/Conv0/bias') + + # Second layer (kernel 3x3) with downsampling + layer_name = f'layer{2 * block_idx + 1}' + self.add_module( + layer_name, + ConvLayer(in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + add_bias=True, + scale_factor=2, + fused_scale=(res >= fused_scale_res + if fused_scale == 'auto' + else fused_scale), + filter_kernel=filter_kernel, + use_wscale=use_wscale, + wscale_gain=wscale_gain, + lr_mul=lr_mul, + activation_type='lrelu')) + self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = ( + f'{res}x{res}/Conv1_down/weight') + self.pth_to_tf_var_mapping[f'{layer_name}.bias'] = ( + f'{res}x{res}/Conv1_down/bias') + + # Convolution block for last resolution. + else: + self.mbstd = MiniBatchSTDLayer(groups=mbstd_groups, + new_channels=mbstd_channels, + eps=eps) + + # First layer (kernel 3x3) without downsampling. + layer_name = f'layer{2 * block_idx}' + self.add_module( + layer_name, + ConvLayer(in_channels=in_channels + mbstd_channels, + out_channels=in_channels, + kernel_size=3, + add_bias=True, + scale_factor=1, + fused_scale=False, + filter_kernel=None, + use_wscale=use_wscale, + wscale_gain=wscale_gain, + lr_mul=lr_mul, + activation_type='lrelu')) + self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = ( + f'{res}x{res}/Conv/weight') + self.pth_to_tf_var_mapping[f'{layer_name}.bias'] = ( + f'{res}x{res}/Conv/bias') + + # Second layer, as a fully-connected layer. + layer_name = f'layer{2 * block_idx + 1}' + self.add_module( + f'layer{2 * block_idx + 1}', + DenseLayer(in_channels=in_channels * res * res, + out_channels=in_channels, + add_bias=True, + use_wscale=use_wscale, + wscale_gain=wscale_gain, + lr_mul=lr_mul, + activation_type='lrelu')) + self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = ( + f'{res}x{res}/Dense0/weight') + self.pth_to_tf_var_mapping[f'{layer_name}.bias'] = ( + f'{res}x{res}/Dense0/bias') + + # Final dense layer to output score. 
+ self.output = DenseLayer(in_channels=in_channels, + out_channels=max(label_dim, 1), + add_bias=True, + use_wscale=use_wscale, + wscale_gain=1.0, + lr_mul=lr_mul, + activation_type='linear') + self.pth_to_tf_var_mapping['output.weight'] = ( + f'{res}x{res}/Dense1/weight') + self.pth_to_tf_var_mapping['output.bias'] = ( + f'{res}x{res}/Dense1/bias') + + def get_nf(self, res): + """Gets number of feature maps according to the given resolution.""" + return min(self.fmaps_base // res, self.fmaps_max) + + def forward(self, image, label=None, lod=None, enable_amp=False): + expected_shape = (self.image_channels, self.resolution, self.resolution) + if image.ndim != 4 or image.shape[1:] != expected_shape: + raise ValueError(f'The input tensor should be with shape ' + f'[batch_size, channel, height, width], where ' + f'`channel` equals to {self.image_channels}, ' + f'`height`, `width` equal to {self.resolution}!\n' + f'But `{image.shape}` is received!') + + lod = self.lod.item() if lod is None else lod + if lod + self.init_res_log2 > self.final_res_log2: + raise ValueError(f'Maximum level-of-details (lod) is ' + f'{self.final_res_log2 - self.init_res_log2}, ' + f'but `{lod}` is received!') + + if self.label_dim: + if label is None: + raise ValueError(f'Model requires an additional label ' + f'(with dimension {self.label_dim}) as input, ' + f'but no label is received!') + batch = image.shape[0] + if (label.ndim != 2 or label.shape != (batch, self.label_dim)): + raise ValueError(f'Input label should be with shape ' + f'[batch_size, label_dim], where ' + f'`batch_size` equals to {batch}, and ' + f'`label_dim` equals to {self.label_dim}!\n' + f'But `{label.shape}` is received!') + label = label.to(dtype=torch.float32) + + with autocast(enabled=enable_amp): + for res_log2 in range( + self.final_res_log2, self.init_res_log2 - 1, -1): + block_idx = current_lod = self.final_res_log2 - res_log2 + if current_lod <= lod < current_lod + 1: + x = getattr(self, f'input{block_idx}')(image) + elif current_lod - 1 < lod < current_lod: + alpha = lod - np.floor(lod) + y = getattr(self, f'input{block_idx}')(image) + x = y * alpha + x * (1 - alpha) + if lod < current_lod + 1: + if res_log2 == self.init_res_log2: + x = self.mbstd(x) + x = getattr(self, f'layer{2 * block_idx}')(x) + x = getattr(self, f'layer{2 * block_idx + 1}')(x) + if lod > current_lod: + image = F.avg_pool2d( + image, kernel_size=2, stride=2, padding=0) + x = self.output(x) + + if self.label_dim: + x = (x * label).sum(dim=1, keepdim=True) + + results = { + 'score': x, + 'label': label + } + return results + + +class MiniBatchSTDLayer(nn.Module): + """Implements the minibatch standard deviation layer.""" + + def __init__(self, groups, new_channels, eps): + super().__init__() + self.groups = groups + self.new_channels = new_channels + self.eps = eps + + def extra_repr(self): + return (f'groups={self.groups}, ' + f'new_channels={self.new_channels}, ' + f'epsilon={self.eps}') + + def forward(self, x): + if self.groups <= 1 or self.new_channels < 1: + return x + + N, C, H, W = x.shape + G = min(self.groups, N) # Number of groups. + nC = self.new_channels # Number of channel groups. + c = C // nC # Channels per channel group. 
+ + y = x.reshape(G, -1, nC, c, H, W) # [GnFcHW] + y = y - y.mean(dim=0) # [GnFcHW] + y = y.square().mean(dim=0) # [nFcHW] + y = (y + self.eps).sqrt() # [nFcHW] + y = y.mean(dim=(2, 3, 4)) # [nF] + y = y.reshape(-1, nC, 1, 1) # [nF11] + y = y.repeat(G, 1, H, W) # [NFHW] + x = torch.cat((x, y), dim=1) # [N(C+F)HW] + + return x + + +class Blur(torch.autograd.Function): + """Defines blur operation with customized gradient computation.""" + + @staticmethod + def forward(ctx, x, kernel): + assert kernel.shape[2] == 3 and kernel.shape[3] == 3 + ctx.save_for_backward(kernel) + y = F.conv2d(input=x, + weight=kernel, + bias=None, + stride=1, + padding=1, + groups=x.shape[1]) + return y + + @staticmethod + def backward(ctx, dy): + kernel, = ctx.saved_tensors + dx = BlurBackPropagation.apply(dy, kernel) + return dx, None, None + + +class BlurBackPropagation(torch.autograd.Function): + """Defines the back propagation of blur operation. + + NOTE: This is used to speed up the backward of gradient penalty. + """ + + @staticmethod + def forward(ctx, dy, kernel): + ctx.save_for_backward(kernel) + dx = F.conv2d(input=dy, + weight=kernel.flip((2, 3)), + bias=None, + stride=1, + padding=1, + groups=dy.shape[1]) + return dx + + @staticmethod + def backward(ctx, ddx): + kernel, = ctx.saved_tensors + ddy = F.conv2d(input=ddx, + weight=kernel, + bias=None, + stride=1, + padding=1, + groups=ddx.shape[1]) + return ddy, None, None + + +class ConvLayer(nn.Module): + """Implements the convolutional layer. + + If downsampling is needed (i.e., `scale_factor = 2`), the feature map will + be filtered with `filter_kernel` first. If `fused_scale` is set as `True`, + `conv2d` and `downsample` will be fused as one operator, using stride + convolution. + """ + + def __init__(self, + in_channels, + out_channels, + kernel_size, + add_bias, + scale_factor, + fused_scale, + filter_kernel, + use_wscale, + wscale_gain, + lr_mul, + activation_type): + """Initializes with layer settings. + + Args: + in_channels: Number of channels of the input tensor. + out_channels: Number of channels of the output tensor. + kernel_size: Size of the convolutional kernels. + add_bias: Whether to add bias onto the convolutional result. + scale_factor: Scale factor for downsampling. `1` means skip + downsampling. + fused_scale: Whether to fuse `conv2d` and `downsample` as one + operator, using stride convolution. + filter_kernel: Kernel used for filtering. + use_wscale: Whether to use weight scaling. + wscale_gain: Gain factor for weight scaling. + lr_mul: Learning multiplier for both weight and bias. + activation_type: Type of activation. 
+ """ + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.add_bias = add_bias + self.scale_factor = scale_factor + self.fused_scale = fused_scale + self.filter_kernel = filter_kernel + self.use_wscale = use_wscale + self.wscale_gain = wscale_gain + self.lr_mul = lr_mul + self.activation_type = activation_type + + weight_shape = (out_channels, in_channels, kernel_size, kernel_size) + fan_in = kernel_size * kernel_size * in_channels + wscale = wscale_gain / np.sqrt(fan_in) + if use_wscale: + self.weight = nn.Parameter(torch.randn(*weight_shape) / lr_mul) + self.wscale = wscale * lr_mul + else: + self.weight = nn.Parameter( + torch.randn(*weight_shape) * wscale / lr_mul) + self.wscale = lr_mul + + if add_bias: + self.bias = nn.Parameter(torch.zeros(out_channels)) + self.bscale = lr_mul + else: + self.bias = None + + if scale_factor > 1: + assert filter_kernel is not None + kernel = np.array(filter_kernel, dtype=np.float32).reshape(1, -1) + kernel = kernel.T.dot(kernel) + kernel = kernel / np.sum(kernel) + kernel = kernel[np.newaxis, np.newaxis] + self.register_buffer('filter', torch.from_numpy(kernel)) + + if scale_factor > 1 and fused_scale: # use stride convolution. + self.stride = scale_factor + else: + self.stride = 1 + self.padding = kernel_size // 2 + + assert activation_type in ['linear', 'relu', 'lrelu'] + + def extra_repr(self): + return (f'in_ch={self.in_channels}, ' + f'out_ch={self.out_channels}, ' + f'ksize={self.kernel_size}, ' + f'wscale_gain={self.wscale_gain:.3f}, ' + f'bias={self.add_bias}, ' + f'lr_mul={self.lr_mul:.3f}, ' + f'downsample={self.scale_factor}, ' + f'fused_scale={self.fused_scale}, ' + f'downsample_filter={self.filter_kernel}, ' + f'act={self.activation_type}') + + def forward(self, x): + if self.scale_factor > 1: + # Disable `autocast` for customized autograd function. + # Please check reference: + # https://pytorch.org/docs/stable/notes/amp_examples.html#autocast-and-custom-autograd-functions + with autocast(enabled=False): + f = self.filter.repeat(self.in_channels, 1, 1, 1) + x = Blur.apply(x.float(), f) # Always use FP32. + + weight = self.weight + if self.wscale != 1.0: + weight = weight * self.wscale + bias = None + if self.bias is not None: + bias = self.bias + if self.bscale != 1.0: + bias = bias * self.bscale + + if self.scale_factor > 1 and self.fused_scale: + weight = F.pad(weight, (1, 1, 1, 1, 0, 0, 0, 0), 'constant', 0.0) + weight = (weight[:, :, 1:, 1:] + weight[:, :, :-1, 1:] + + weight[:, :, 1:, :-1] + weight[:, :, :-1, :-1]) * 0.25 + x = F.conv2d(x, + weight=weight, + bias=bias, + stride=self.stride, + padding=self.padding) + if self.scale_factor > 1 and not self.fused_scale: + down = self.scale_factor + x = F.avg_pool2d(x, kernel_size=down, stride=down, padding=0) + + if self.activation_type == 'linear': + pass + elif self.activation_type == 'relu': + x = F.relu(x, inplace=True) + elif self.activation_type == 'lrelu': + x = F.leaky_relu(x, negative_slope=0.2, inplace=True) + else: + raise NotImplementedError(f'Not implemented activation type ' + f'`{self.activation_type}`!') + + return x + + +class DenseLayer(nn.Module): + """Implements the dense layer.""" + + def __init__(self, + in_channels, + out_channels, + add_bias, + use_wscale, + wscale_gain, + lr_mul, + activation_type): + """Initializes with layer settings. + + Args: + in_channels: Number of channels of the input tensor. + out_channels: Number of channels of the output tensor. 
+ add_bias: Whether to add bias onto the fully-connected result. + use_wscale: Whether to use weight scaling. + wscale_gain: Gain factor for weight scaling. + lr_mul: Learning multiplier for both weight and bias. + activation_type: Type of activation. + """ + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.add_bias = add_bias + self.use_wscale = use_wscale + self.wscale_gain = wscale_gain + self.lr_mul = lr_mul + self.activation_type = activation_type + + weight_shape = (out_channels, in_channels) + wscale = wscale_gain / np.sqrt(in_channels) + if use_wscale: + self.weight = nn.Parameter(torch.randn(*weight_shape) / lr_mul) + self.wscale = wscale * lr_mul + else: + self.weight = nn.Parameter( + torch.randn(*weight_shape) * wscale / lr_mul) + self.wscale = lr_mul + + if add_bias: + self.bias = nn.Parameter(torch.zeros(out_channels)) + self.bscale = lr_mul + else: + self.bias = None + + assert activation_type in ['linear', 'relu', 'lrelu'] + + def extra_repr(self): + return (f'in_ch={self.in_channels}, ' + f'out_ch={self.out_channels}, ' + f'wscale_gain={self.wscale_gain:.3f}, ' + f'bias={self.add_bias}, ' + f'lr_mul={self.lr_mul:.3f}, ' + f'act={self.activation_type}') + + def forward(self, x): + if x.ndim != 2: + x = x.flatten(start_dim=1) + + weight = self.weight + if self.wscale != 1.0: + weight = weight * self.wscale + bias = None + if self.bias is not None: + bias = self.bias + if self.bscale != 1.0: + bias = bias * self.bscale + + x = F.linear(x, weight=weight, bias=bias) + + if self.activation_type == 'linear': + pass + elif self.activation_type == 'relu': + x = F.relu(x, inplace=True) + elif self.activation_type == 'lrelu': + x = F.leaky_relu(x, negative_slope=0.2, inplace=True) + else: + raise NotImplementedError(f'Not implemented activation type ' + f'`{self.activation_type}`!') + + return x + +# pylint: enable=missing-function-docstring diff --git a/models/stylegan_generator.py b/models/stylegan_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..c0034b34a5b72bfe6b305a9f6ff8d772b391c4f5 --- /dev/null +++ b/models/stylegan_generator.py @@ -0,0 +1,999 @@ +# python3.7 +"""Contains the implementation of generator described in StyleGAN. + +Paper: https://arxiv.org/pdf/1812.04948.pdf + +Official TensorFlow implementation: https://github.com/NVlabs/stylegan +""" + +import numpy as np + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.cuda.amp import autocast + +from .utils.ops import all_gather + +__all__ = ['StyleGANGenerator'] + +# Resolutions allowed. +_RESOLUTIONS_ALLOWED = [8, 16, 32, 64, 128, 256, 512, 1024] + +# Fused-scale options allowed. +_FUSED_SCALE_ALLOWED = [True, False, 'auto'] + +# pylint: disable=missing-function-docstring + +class StyleGANGenerator(nn.Module): + """Defines the generator network in StyleGAN. + + NOTE: The synthesized images are with `RGB` channel order and pixel range + [-1, 1]. + + Settings for the mapping network: + + (1) z_dim: Dimension of the input latent space, Z. (default: 512) + (2) w_dim: Dimension of the output latent space, W. (default: 512) + (3) repeat_w: Repeat w-code for different layers. (default: True) + (4) normalize_z: Whether to normalize the z-code. (default: True) + (5) mapping_layers: Number of layers of the mapping network. (default: 8) + (6) mapping_fmaps: Number of hidden channels of the mapping network. + (default: 512) + (7) mapping_use_wscale: Whether to use weight scaling for the mapping + network. 
(default: True)
+    (8) mapping_wscale_gain: The factor to control weight scaling for the
+        mapping network. (default: sqrt(2.0))
+    (9) mapping_lr_mul: Learning rate multiplier for the mapping network.
+        (default: 0.01)
+
+    Settings for conditional generation:
+
+    (1) label_dim: Dimension of the additional label for conditional
+        generation. In the one-hot conditioning case, it is equal to the
+        number of classes. If set to 0, conditional training will be
+        disabled. (default: 0)
+    (2) embedding_dim: Dimension of the embedding space, if needed.
+        (default: 512)
+
+    Settings for the synthesis network:
+
+    (1) resolution: The resolution of the output image. (default: -1)
+    (2) init_res: The initial resolution to start with convolution.
+        (default: 4)
+    (3) image_channels: Number of channels of the output image. (default: 3)
+    (4) final_tanh: Whether to use `tanh` to control the final pixel range.
+        (default: False)
+    (5) fused_scale: The strategy of fusing `upsample` and `conv2d` as one
+        operator. `True` means blocks from all resolutions will fuse. `False`
+        means blocks from all resolutions will not fuse. `auto` means blocks
+        from resolutions higher than (or equal to) `fused_scale_res` will
+        fuse. (default: `auto`)
+    (6) fused_scale_res: Minimum resolution to fuse `conv2d` and `upsample`
+        as one operator. This field only takes effect if `fused_scale` is set
+        as `auto`. (default: 128)
+    (7) use_wscale: Whether to use weight scaling. (default: True)
+    (8) wscale_gain: The factor to control weight scaling.
+        (default: sqrt(2.0))
+    (9) lr_mul: Learning rate multiplier for the synthesis network.
+        (default: 1.0)
+    (10) noise_type: Type of noise added to the convolutional results at each
+         layer. (default: `spatial`)
+    (11) fmaps_base: Factor to control number of feature maps for each layer.
+         (default: 16 << 10)
+    (12) fmaps_max: Maximum number of feature maps in each layer.
+         (default: 512)
+    (13) filter_kernel: Kernel used for filtering (e.g., downsampling).
+         (default: (1, 2, 1))
+    (14) eps: A small value to avoid divide overflow. (default: 1e-8)
+
+    Runtime settings:
+
+    (1) w_moving_decay: Decay factor for updating `w_avg`, which is used for
+        training only. Set `None` to disable. (default: None)
+    (2) sync_w_avg: Synchronizing the stats of `w_avg` across replicas. If
+        set as `True`, the stats will be more accurate, yet the speed may be
+        a little bit slower. (default: False)
+    (3) style_mixing_prob: Probability to perform style mixing as a training
+        regularization. Set `None` to disable. (default: None)
+    (4) trunc_psi: Truncation psi, set `None` to disable. (default: None)
+    (5) trunc_layers: Number of layers to perform truncation. (default: None)
+    (6) noise_mode: Mode of the layer-wise noise. Support `none`, `random`,
+        `const`. (default: `const`)
+    (7) enable_amp: Whether to enable automatic mixed precision training.
+        (default: False)
+    """
+
+    def __init__(self,
+                 # Settings for mapping network.
+                 z_dim=512,
+                 w_dim=512,
+                 repeat_w=True,
+                 normalize_z=True,
+                 mapping_layers=8,
+                 mapping_fmaps=512,
+                 mapping_use_wscale=True,
+                 mapping_wscale_gain=np.sqrt(2.0),
+                 mapping_lr_mul=0.01,
+                 # Settings for conditional generation.
+                 label_dim=0,
+                 embedding_dim=512,
+                 # Settings for synthesis network.
+ resolution=-1, + init_res=4, + image_channels=3, + final_tanh=False, + fused_scale='auto', + fused_scale_res=128, + use_wscale=True, + wscale_gain=np.sqrt(2.0), + lr_mul=1.0, + noise_type='spatial', + fmaps_base=16 << 10, + fmaps_max=512, + filter_kernel=(1, 2, 1), + eps=1e-8): + """Initializes with basic settings. + + Raises: + ValueError: If the `resolution` is not supported, or `fused_scale` + is not supported. + """ + super().__init__() + + if resolution not in _RESOLUTIONS_ALLOWED: + raise ValueError(f'Invalid resolution: `{resolution}`!\n' + f'Resolutions allowed: {_RESOLUTIONS_ALLOWED}.') + if fused_scale not in _FUSED_SCALE_ALLOWED: + raise ValueError(f'Invalid fused-scale option: `{fused_scale}`!\n' + f'Options allowed: {_FUSED_SCALE_ALLOWED}.') + + self.z_dim = z_dim + self.w_dim = w_dim + self.repeat_w = repeat_w + self.normalize_z = normalize_z + self.mapping_layers = mapping_layers + self.mapping_fmaps = mapping_fmaps + self.mapping_use_wscale = mapping_use_wscale + self.mapping_wscale_gain = mapping_wscale_gain + self.mapping_lr_mul = mapping_lr_mul + + self.label_dim = label_dim + self.embedding_dim = embedding_dim + + self.resolution = resolution + self.init_res = init_res + self.image_channels = image_channels + self.final_tanh = final_tanh + self.fused_scale = fused_scale + self.fused_scale_res = fused_scale_res + self.use_wscale = use_wscale + self.wscale_gain = wscale_gain + self.lr_mul = lr_mul + self.noise_type = noise_type.lower() + self.fmaps_base = fmaps_base + self.fmaps_max = fmaps_max + self.filter_kernel = filter_kernel + self.eps = eps + + # Dimension of latent space, which is convenient for sampling. + self.latent_dim = (z_dim,) + + # Number of synthesis (convolutional) layers. + self.num_layers = int(np.log2(resolution // init_res * 2)) * 2 + + self.mapping = MappingNetwork(input_dim=z_dim, + output_dim=w_dim, + num_outputs=self.num_layers, + repeat_output=repeat_w, + normalize_input=normalize_z, + num_layers=mapping_layers, + hidden_dim=mapping_fmaps, + use_wscale=mapping_use_wscale, + wscale_gain=mapping_wscale_gain, + lr_mul=mapping_lr_mul, + label_dim=label_dim, + embedding_dim=embedding_dim, + eps=eps) + + # This is used for truncation trick. + if self.repeat_w: + self.register_buffer('w_avg', torch.zeros(w_dim)) + else: + self.register_buffer('w_avg', torch.zeros(self.num_layers * w_dim)) + + self.synthesis = SynthesisNetwork(resolution=resolution, + init_res=init_res, + w_dim=w_dim, + image_channels=image_channels, + final_tanh=final_tanh, + fused_scale=fused_scale, + fused_scale_res=fused_scale_res, + use_wscale=use_wscale, + wscale_gain=wscale_gain, + lr_mul=lr_mul, + noise_type=noise_type, + fmaps_base=fmaps_base, + fmaps_max=fmaps_max, + filter_kernel=filter_kernel, + eps=eps) + + self.pth_to_tf_var_mapping = {'w_avg': 'dlatent_avg'} + for key, val in self.mapping.pth_to_tf_var_mapping.items(): + self.pth_to_tf_var_mapping[f'mapping.{key}'] = val + for key, val in self.synthesis.pth_to_tf_var_mapping.items(): + self.pth_to_tf_var_mapping[f'synthesis.{key}'] = val + + def set_space_of_latent(self, space_of_latent): + """Sets the space to which the latent code belong. + + See `SynthesisNetwork` for more details. 
+ """ + self.synthesis.set_space_of_latent(space_of_latent) + + def forward(self, + z, + label=None, + lod=None, + w_moving_decay=None, + sync_w_avg=False, + style_mixing_prob=None, + trunc_psi=None, + trunc_layers=None, + noise_mode='const', + enable_amp=False): + mapping_results = self.mapping(z, label) + + w = mapping_results['w'] + if self.training and w_moving_decay is not None: + if sync_w_avg: + batch_w_avg = all_gather(w.detach()).mean(dim=0) + else: + batch_w_avg = w.detach().mean(dim=0) + self.w_avg.copy_(batch_w_avg.lerp(self.w_avg, w_moving_decay)) + + wp = mapping_results.pop('wp') + if self.training and style_mixing_prob is not None: + if np.random.uniform() < style_mixing_prob: + new_z = torch.randn_like(z) + new_wp = self.mapping(new_z, label)['wp'] + lod = self.synthesis.lod.item() if lod is None else lod + current_layers = self.num_layers - int(lod) * 2 + mixing_cutoff = np.random.randint(1, current_layers) + wp[:, mixing_cutoff:] = new_wp[:, mixing_cutoff:] + + if not self.training: + trunc_psi = 1.0 if trunc_psi is None else trunc_psi + trunc_layers = 0 if trunc_layers is None else trunc_layers + if trunc_psi < 1.0 and trunc_layers > 0: + w_avg = self.w_avg.reshape(1, -1, self.w_dim)[:, :trunc_layers] + wp[:, :trunc_layers] = w_avg.lerp( + wp[:, :trunc_layers], trunc_psi) + + with autocast(enabled=enable_amp): + synthesis_results = self.synthesis(wp, + lod=lod, + noise_mode=noise_mode) + + return {**mapping_results, **synthesis_results} + + +class MappingNetwork(nn.Module): + """Implements the latent space mapping module. + + Basically, this module executes several dense layers in sequence, and the + label embedding if needed. + """ + + def __init__(self, + input_dim, + output_dim, + num_outputs, + repeat_output, + normalize_input, + num_layers, + hidden_dim, + use_wscale, + wscale_gain, + lr_mul, + label_dim, + embedding_dim, + eps): + super().__init__() + + self.input_dim = input_dim + self.output_dim = output_dim + self.num_outputs = num_outputs + self.repeat_output = repeat_output + self.normalize_input = normalize_input + self.num_layers = num_layers + self.hidden_dim = hidden_dim + self.use_wscale = use_wscale + self.wscale_gain = wscale_gain + self.lr_mul = lr_mul + self.label_dim = label_dim + self.embedding_dim = embedding_dim + self.eps = eps + + self.pth_to_tf_var_mapping = {} + + if normalize_input: + self.norm = PixelNormLayer(dim=1, eps=eps) + + if self.label_dim > 0: + input_dim = input_dim + embedding_dim + self.embedding = nn.Parameter( + torch.randn(label_dim, embedding_dim)) + self.pth_to_tf_var_mapping['embedding'] = 'LabelConcat/weight' + + if num_outputs is not None and not repeat_output: + output_dim = output_dim * num_outputs + for i in range(num_layers): + in_channels = (input_dim if i == 0 else hidden_dim) + out_channels = (output_dim if i == (num_layers - 1) else hidden_dim) + self.add_module(f'dense{i}', + DenseLayer(in_channels=in_channels, + out_channels=out_channels, + add_bias=True, + use_wscale=use_wscale, + wscale_gain=wscale_gain, + lr_mul=lr_mul, + activation_type='lrelu')) + self.pth_to_tf_var_mapping[f'dense{i}.weight'] = f'Dense{i}/weight' + self.pth_to_tf_var_mapping[f'dense{i}.bias'] = f'Dense{i}/bias' + + def forward(self, z, label=None): + if z.ndim != 2 or z.shape[1] != self.input_dim: + raise ValueError(f'Input latent code should be with shape ' + f'[batch_size, input_dim], where ' + f'`input_dim` equals to {self.input_dim}!\n' + f'But `{z.shape}` is received!') + + if self.label_dim > 0: + if label is None: + raise 
ValueError(f'Model requires an additional label ' + f'(with dimension {self.label_dim}) as input, ' + f'but no label is received!') + if label.ndim != 2 or label.shape != (z.shape[0], self.label_dim): + raise ValueError(f'Input label should be with shape ' + f'[batch_size, label_dim], where ' + f'`batch_size` equals to that of ' + f'latent codes ({z.shape[0]}) and ' + f'`label_dim` equals to {self.label_dim}!\n' + f'But `{label.shape}` is received!') + label = label.to(dtype=torch.float32) + embedding = torch.matmul(label, self.embedding) + z = torch.cat((z, embedding), dim=1) + + if self.normalize_input: + w = self.norm(z) + else: + w = z + + for i in range(self.num_layers): + w = getattr(self, f'dense{i}')(w) + + wp = None + if self.num_outputs is not None: + if self.repeat_output: + wp = w.unsqueeze(1).repeat((1, self.num_outputs, 1)) + else: + wp = w.reshape(-1, self.num_outputs, self.output_dim) + + results = { + 'z': z, + 'label': label, + 'w': w, + 'wp': wp, + } + if self.label_dim > 0: + results['embedding'] = embedding + return results + + +class SynthesisNetwork(nn.Module): + """Implements the image synthesis module. + + Basically, this module executes several convolutional layers in sequence. + """ + + def __init__(self, + resolution, + init_res, + w_dim, + image_channels, + final_tanh, + fused_scale, + fused_scale_res, + use_wscale, + wscale_gain, + lr_mul, + noise_type, + fmaps_base, + fmaps_max, + filter_kernel, + eps): + super().__init__() + + self.init_res = init_res + self.init_res_log2 = int(np.log2(init_res)) + self.resolution = resolution + self.final_res_log2 = int(np.log2(resolution)) + self.w_dim = w_dim + self.image_channels = image_channels + self.final_tanh = final_tanh + self.fused_scale = fused_scale + self.fused_scale_res = fused_scale_res + self.use_wscale = use_wscale + self.wscale_gain = wscale_gain + self.lr_mul = lr_mul + self.noise_type = noise_type.lower() + self.fmaps_base = fmaps_base + self.fmaps_max = fmaps_max + self.eps = eps + + self.num_layers = (self.final_res_log2 - self.init_res_log2 + 1) * 2 + + # Level-of-details (used for progressive training). 
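+        # `lod=0` renders at full resolution; each unit increase of `lod`
+        # halves the output resolution, and fractional values blend the two
+        # adjacent resolutions (see `forward`).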
+ self.register_buffer('lod', torch.zeros(())) + self.pth_to_tf_var_mapping = {'lod': 'lod'} + + for res_log2 in range(self.init_res_log2, self.final_res_log2 + 1): + res = 2 ** res_log2 + in_channels = self.get_nf(res // 2) + out_channels = self.get_nf(res) + block_idx = res_log2 - self.init_res_log2 + + # First layer (kernel 3x3) with upsampling + layer_name = f'layer{2 * block_idx}' + if res == self.init_res: + self.add_module(layer_name, + ModulateConvLayer(in_channels=0, + out_channels=out_channels, + resolution=res, + w_dim=w_dim, + kernel_size=None, + add_bias=True, + scale_factor=None, + fused_scale=None, + filter_kernel=None, + use_wscale=use_wscale, + wscale_gain=wscale_gain, + lr_mul=lr_mul, + noise_type=noise_type, + activation_type='lrelu', + use_style=True, + eps=eps)) + tf_layer_name = 'Const' + self.pth_to_tf_var_mapping[f'{layer_name}.const'] = ( + f'{res}x{res}/{tf_layer_name}/const') + else: + self.add_module( + layer_name, + ModulateConvLayer(in_channels=in_channels, + out_channels=out_channels, + resolution=res, + w_dim=w_dim, + kernel_size=3, + add_bias=True, + scale_factor=2, + fused_scale=(res >= fused_scale_res + if fused_scale == 'auto' + else fused_scale), + filter_kernel=filter_kernel, + use_wscale=use_wscale, + wscale_gain=wscale_gain, + lr_mul=lr_mul, + noise_type=noise_type, + activation_type='lrelu', + use_style=True, + eps=eps)) + tf_layer_name = 'Conv0_up' + self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = ( + f'{res}x{res}/{tf_layer_name}/weight') + self.pth_to_tf_var_mapping[f'{layer_name}.bias'] = ( + f'{res}x{res}/{tf_layer_name}/bias') + self.pth_to_tf_var_mapping[f'{layer_name}.style.weight'] = ( + f'{res}x{res}/{tf_layer_name}/StyleMod/weight') + self.pth_to_tf_var_mapping[f'{layer_name}.style.bias'] = ( + f'{res}x{res}/{tf_layer_name}/StyleMod/bias') + self.pth_to_tf_var_mapping[f'{layer_name}.noise_strength'] = ( + f'{res}x{res}/{tf_layer_name}/Noise/weight') + self.pth_to_tf_var_mapping[f'{layer_name}.noise'] = ( + f'noise{2 * block_idx}') + + # Second layer (kernel 3x3) without upsampling. + layer_name = f'layer{2 * block_idx + 1}' + self.add_module(layer_name, + ModulateConvLayer(in_channels=out_channels, + out_channels=out_channels, + resolution=res, + w_dim=w_dim, + kernel_size=3, + add_bias=True, + scale_factor=1, + fused_scale=False, + filter_kernel=None, + use_wscale=use_wscale, + wscale_gain=wscale_gain, + lr_mul=lr_mul, + noise_type=noise_type, + activation_type='lrelu', + use_style=True, + eps=eps)) + tf_layer_name = 'Conv' if res == self.init_res else 'Conv1' + self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = ( + f'{res}x{res}/{tf_layer_name}/weight') + self.pth_to_tf_var_mapping[f'{layer_name}.bias'] = ( + f'{res}x{res}/{tf_layer_name}/bias') + self.pth_to_tf_var_mapping[f'{layer_name}.style.weight'] = ( + f'{res}x{res}/{tf_layer_name}/StyleMod/weight') + self.pth_to_tf_var_mapping[f'{layer_name}.style.bias'] = ( + f'{res}x{res}/{tf_layer_name}/StyleMod/bias') + self.pth_to_tf_var_mapping[f'{layer_name}.noise_strength'] = ( + f'{res}x{res}/{tf_layer_name}/Noise/weight') + self.pth_to_tf_var_mapping[f'{layer_name}.noise'] = ( + f'noise{2 * block_idx + 1}') + + # Output convolution layer for each resolution. 
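+            # This 1x1 `ToRGB` layer maps features at the current resolution
+            # to an `image_channels`-channel image, which progressive training
+            # blends across adjacent resolutions.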
+ self.add_module(f'output{block_idx}', + ModulateConvLayer(in_channels=out_channels, + out_channels=image_channels, + resolution=res, + w_dim=w_dim, + kernel_size=1, + add_bias=True, + scale_factor=1, + fused_scale=False, + filter_kernel=None, + use_wscale=use_wscale, + wscale_gain=1.0, + lr_mul=lr_mul, + noise_type='none', + activation_type='linear', + use_style=False, + eps=eps)) + self.pth_to_tf_var_mapping[f'output{block_idx}.weight'] = ( + f'ToRGB_lod{self.final_res_log2 - res_log2}/weight') + self.pth_to_tf_var_mapping[f'output{block_idx}.bias'] = ( + f'ToRGB_lod{self.final_res_log2 - res_log2}/bias') + + def get_nf(self, res): + """Gets number of feature maps according to the given resolution.""" + return min(self.fmaps_base // res, self.fmaps_max) + + def set_space_of_latent(self, space_of_latent): + """Sets the space to which the latent code belong. + + This function is particularly used for choosing how to inject the latent + code into the convolutional layers. The original generator will take a + W-Space code and apply it for style modulation after an affine + transformation. But, sometimes, it may need to directly feed an already + affine-transformed code into the convolutional layer, e.g., when + training an encoder for GAN inversion. We term the transformed space as + Style Space (or Y-Space). This function is designed to tell the + convolutional layers how to use the input code. + + Args: + space_of_latent: The space to which the latent code belong. Case + insensitive. Support `W` and `Y`. + """ + space_of_latent = space_of_latent.upper() + for module in self.modules(): + if isinstance(module, ModulateConvLayer) and module.use_style: + setattr(module, 'space_of_latent', space_of_latent) + + def forward(self, wp, lod=None, noise_mode='const'): + lod = self.lod.item() if lod is None else lod + if lod + self.init_res_log2 > self.final_res_log2: + raise ValueError(f'Maximum level-of-details (lod) is ' + f'{self.final_res_log2 - self.init_res_log2}, ' + f'but `{lod}` is received!') + + results = {'wp': wp} + x = None + for res_log2 in range(self.init_res_log2, self.final_res_log2 + 1): + current_lod = self.final_res_log2 - res_log2 + block_idx = res_log2 - self.init_res_log2 + if lod < current_lod + 1: + layer = getattr(self, f'layer{2 * block_idx}') + x, style = layer(x, wp[:, 2 * block_idx], noise_mode) + results[f'style{2 * block_idx}'] = style + layer = getattr(self, f'layer{2 * block_idx + 1}') + x, style = layer(x, wp[:, 2 * block_idx + 1], noise_mode) + results[f'style{2 * block_idx + 1}'] = style + if current_lod - 1 < lod <= current_lod: + image = getattr(self, f'output{block_idx}')(x) + elif current_lod < lod < current_lod + 1: + alpha = np.ceil(lod) - lod + temp = getattr(self, f'output{block_idx}')(x) + image = F.interpolate(image, scale_factor=2, mode='nearest') + image = temp * alpha + image * (1 - alpha) + elif lod >= current_lod + 1: + image = F.interpolate(image, scale_factor=2, mode='nearest') + + if self.final_tanh: + image = torch.tanh(image) + results['image'] = image + return results + + +class PixelNormLayer(nn.Module): + """Implements pixel-wise feature vector normalization layer.""" + + def __init__(self, dim, eps): + super().__init__() + self.dim = dim + self.eps = eps + + def extra_repr(self): + return f'dim={self.dim}, epsilon={self.eps}' + + def forward(self, x): + scale = (x.square().mean(dim=self.dim, keepdim=True) + self.eps).rsqrt() + return x * scale + + +class Blur(torch.autograd.Function): + """Defines blur operation with customized gradient 
computation.""" + + @staticmethod + def forward(ctx, x, kernel): + assert kernel.shape[2] == 3 and kernel.shape[3] == 3 + ctx.save_for_backward(kernel) + y = F.conv2d(input=x, + weight=kernel, + bias=None, + stride=1, + padding=1, + groups=x.shape[1]) + return y + + @staticmethod + def backward(ctx, dy): + kernel, = ctx.saved_tensors + dx = F.conv2d(input=dy, + weight=kernel.flip((2, 3)), + bias=None, + stride=1, + padding=1, + groups=dy.shape[1]) + return dx, None, None + + +class ModulateConvLayer(nn.Module): + """Implements the convolutional layer with style modulation.""" + + def __init__(self, + in_channels, + out_channels, + resolution, + w_dim, + kernel_size, + add_bias, + scale_factor, + fused_scale, + filter_kernel, + use_wscale, + wscale_gain, + lr_mul, + noise_type, + activation_type, + use_style, + eps): + """Initializes with layer settings. + + Args: + in_channels: Number of channels of the input tensor. + out_channels: Number of channels of the output tensor. + resolution: Resolution of the output tensor. + w_dim: Dimension of W space for style modulation. + kernel_size: Size of the convolutional kernels. + add_bias: Whether to add bias onto the convolutional result. + scale_factor: Scale factor for upsampling. + fused_scale: Whether to fuse `upsample` and `conv2d` as one + operator, using transpose convolution. + filter_kernel: Kernel used for filtering. + use_wscale: Whether to use weight scaling. + wscale_gain: Gain factor for weight scaling. + lr_mul: Learning multiplier for both weight and bias. + noise_type: Type of noise added to the feature map after the + convolution (if needed). Support `none`, `spatial` and + `channel`. + activation_type: Type of activation. + use_style: Whether to apply style modulation. + eps: A small value to avoid divide overflow. + """ + super().__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.resolution = resolution + self.w_dim = w_dim + self.kernel_size = kernel_size + self.add_bias = add_bias + self.scale_factor = scale_factor + self.fused_scale = fused_scale + self.filter_kernel = filter_kernel + self.use_wscale = use_wscale + self.wscale_gain = wscale_gain + self.lr_mul = lr_mul + self.noise_type = noise_type.lower() + self.activation_type = activation_type + self.use_style = use_style + self.eps = eps + + # Set up noise. + if self.noise_type == 'none': + pass + elif self.noise_type == 'spatial': + self.register_buffer( + 'noise', torch.randn(1, 1, resolution, resolution)) + self.noise_strength = nn.Parameter( + torch.zeros(1, out_channels, 1, 1)) + elif self.noise_type == 'channel': + self.register_buffer( + 'noise', torch.randn(1, out_channels, 1, 1)) + self.noise_strength = nn.Parameter( + torch.zeros(1, 1, resolution, resolution)) + else: + raise NotImplementedError(f'Not implemented noise type: ' + f'`{noise_type}`!') + + # Set up bias. + if add_bias: + self.bias = nn.Parameter(torch.zeros(out_channels)) + self.bscale = lr_mul + else: + self.bias = None + + # Set up activation. + assert activation_type in ['linear', 'relu', 'lrelu'] + + # Set up style. + if use_style: + self.space_of_latent = 'W' + self.style = DenseLayer(in_channels=w_dim, + out_channels=out_channels * 2, + add_bias=True, + use_wscale=use_wscale, + wscale_gain=1.0, + lr_mul=1.0, + activation_type='linear') + + if in_channels == 0: # First layer. + self.const = nn.Parameter( + torch.ones(1, out_channels, resolution, resolution)) + return + + # Set up weight. 
+ weight_shape = (out_channels, in_channels, kernel_size, kernel_size) + fan_in = kernel_size * kernel_size * in_channels + wscale = wscale_gain / np.sqrt(fan_in) + if use_wscale: + self.weight = nn.Parameter(torch.randn(*weight_shape) / lr_mul) + self.wscale = wscale * lr_mul + else: + self.weight = nn.Parameter( + torch.randn(*weight_shape) * wscale / lr_mul) + self.wscale = lr_mul + + # Set up upsampling filter (if needed). + if scale_factor > 1: + assert filter_kernel is not None + kernel = np.array(filter_kernel, dtype=np.float32).reshape(1, -1) + kernel = kernel.T.dot(kernel) + kernel = kernel / np.sum(kernel) + kernel = kernel[np.newaxis, np.newaxis] + self.register_buffer('filter', torch.from_numpy(kernel)) + + if scale_factor > 1 and fused_scale: # use transpose convolution. + self.stride = scale_factor + else: + self.stride = 1 + self.padding = kernel_size // 2 + + def extra_repr(self): + return (f'in_ch={self.in_channels}, ' + f'out_ch={self.out_channels}, ' + f'ksize={self.kernel_size}, ' + f'wscale_gain={self.wscale_gain:.3f}, ' + f'bias={self.add_bias}, ' + f'lr_mul={self.lr_mul:.3f}, ' + f'upsample={self.scale_factor}, ' + f'fused_scale={self.fused_scale}, ' + f'upsample_filter={self.filter_kernel}, ' + f'noise_type={self.noise_type}, ' + f'act={self.activation_type}, ' + f'use_style={self.use_style}') + + def forward_style(self, w): + """Gets style code from the given input. + + More specifically, if the input is from W-Space, it will be projected by + an affine transformation. If it is from the Style Space (Y-Space), no + operation is required. + + NOTE: For codes from Y-Space, we use slicing to make sure the dimension + is correct, in case that the code is padded before fed into this layer. + """ + space_of_latent = self.space_of_latent.upper() + if space_of_latent == 'W': + if w.ndim != 2 or w.shape[1] != self.w_dim: + raise ValueError(f'The input tensor should be with shape ' + f'[batch_size, w_dim], where ' + f'`w_dim` equals to {self.w_dim}!\n' + f'But `{w.shape}` is received!') + style = self.style(w) + elif space_of_latent == 'Y': + if w.ndim != 2 or w.shape[1] < self.out_channels * 2: + raise ValueError(f'The input tensor should be with shape ' + f'[batch_size, y_dim], where ' + f'`y_dim` equals to {self.out_channels * 2}!\n' + f'But `{w.shape}` is received!') + style = w[:, :self.out_channels * 2] + else: + raise NotImplementedError(f'Not implemented `space_of_latent`: ' + f'`{space_of_latent}`!') + return style + + def forward(self, x, w=None, noise_mode='const'): + if self.in_channels == 0: + assert x is None + x = self.const.repeat(w.shape[0], 1, 1, 1) + else: + weight = self.weight + if self.wscale != 1.0: + weight = weight * self.wscale + + if self.scale_factor > 1 and self.fused_scale: + weight = F.pad(weight, (1, 1, 1, 1, 0, 0, 0, 0), 'constant', 0) + weight = (weight[:, :, 1:, 1:] + weight[:, :, :-1, 1:] + + weight[:, :, 1:, :-1] + weight[:, :, :-1, :-1]) + x = F.conv_transpose2d(x, + weight=weight.transpose(0, 1), + bias=None, + stride=self.stride, + padding=self.padding) + else: + if self.scale_factor > 1: + up = self.scale_factor + x = F.interpolate(x, scale_factor=up, mode='nearest') + x = F.conv2d(x, + weight=weight, + bias=None, + stride=self.stride, + padding=self.padding) + + if self.scale_factor > 1: + # Disable `autocast` for customized autograd function. 
+ # Please check reference: + # https://pytorch.org/docs/stable/notes/amp_examples.html#autocast-and-custom-autograd-functions + with autocast(enabled=False): + f = self.filter.repeat(self.out_channels, 1, 1, 1) + x = Blur.apply(x.float(), f) # Always use FP32. + + # Prepare noise. + noise_mode = noise_mode.lower() + if self.noise_type != 'none' and noise_mode != 'none': + if noise_mode == 'random': + noise = torch.randn( + (x.shape[0], *self.noise.shape[1:]), device=x.device) + elif noise_mode == 'const': + noise = self.noise + else: + raise ValueError(f'Unknown noise mode `{noise_mode}`!') + x = x + noise * self.noise_strength + + if self.bias is not None: + bias = self.bias + if self.bscale != 1.0: + bias = bias * self.bscale + x = x + bias.reshape(1, self.out_channels, 1, 1) + + if self.activation_type == 'linear': + pass + elif self.activation_type == 'relu': + x = F.relu(x, inplace=True) + elif self.activation_type == 'lrelu': + x = F.leaky_relu(x, negative_slope=0.2, inplace=True) + else: + raise NotImplementedError(f'Not implemented activation type ' + f'`{self.activation_type}`!') + + if not self.use_style: + return x + + # Instance normalization. + x = x - x.mean(dim=(2, 3), keepdim=True) + scale = (x.square().mean(dim=(2, 3), keepdim=True) + self.eps).rsqrt() + x = x * scale + # Style modulation. + style = self.forward_style(w) + style_split = style.unsqueeze(2).unsqueeze(3).chunk(2, dim=1) + x = x * (style_split[0] + 1) + style_split[1] + + return x, style + + +class DenseLayer(nn.Module): + """Implements the dense layer.""" + + def __init__(self, + in_channels, + out_channels, + add_bias, + use_wscale, + wscale_gain, + lr_mul, + activation_type): + """Initializes with layer settings. + + Args: + in_channels: Number of channels of the input tensor. + out_channels: Number of channels of the output tensor. + add_bias: Whether to add bias onto the fully-connected result. + use_wscale: Whether to use weight scaling. + wscale_gain: Gain factor for weight scaling. + lr_mul: Learning multiplier for both weight and bias. + activation_type: Type of activation. 
+ """ + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.add_bias = add_bias + self.use_wscale = use_wscale + self.wscale_gain = wscale_gain + self.lr_mul = lr_mul + self.activation_type = activation_type + + weight_shape = (out_channels, in_channels) + wscale = wscale_gain / np.sqrt(in_channels) + if use_wscale: + self.weight = nn.Parameter(torch.randn(*weight_shape) / lr_mul) + self.wscale = wscale * lr_mul + else: + self.weight = nn.Parameter( + torch.randn(*weight_shape) * wscale / lr_mul) + self.wscale = lr_mul + + if add_bias: + self.bias = nn.Parameter(torch.zeros(out_channels)) + self.bscale = lr_mul + else: + self.bias = None + + assert activation_type in ['linear', 'relu', 'lrelu'] + + def extra_repr(self): + return (f'in_ch={self.in_channels}, ' + f'out_ch={self.out_channels}, ' + f'wscale_gain={self.wscale_gain:.3f}, ' + f'bias={self.add_bias}, ' + f'lr_mul={self.lr_mul:.3f}, ' + f'act={self.activation_type}') + + def forward(self, x): + if x.ndim != 2: + x = x.flatten(start_dim=1) + + weight = self.weight + if self.wscale != 1.0: + weight = weight * self.wscale + bias = None + if self.bias is not None: + bias = self.bias + if self.bscale != 1.0: + bias = bias * self.bscale + + x = F.linear(x, weight=weight, bias=bias) + + if self.activation_type == 'linear': + pass + elif self.activation_type == 'relu': + x = F.relu(x, inplace=True) + elif self.activation_type == 'lrelu': + x = F.leaky_relu(x, negative_slope=0.2, inplace=True) + else: + raise NotImplementedError(f'Not implemented activation type ' + f'`{self.activation_type}`!') + + return x + +# pylint: enable=missing-function-docstring diff --git a/models/test.py b/models/test.py new file mode 100644 index 0000000000000000000000000000000000000000..3f1e0239e223537d299a2c52c65928b6c59406da --- /dev/null +++ b/models/test.py @@ -0,0 +1,146 @@ +# python3.7 +"""Unit test for loading pre-trained models. + +Basically, this file tests whether the perceptual model (VGG16) and the +inception model (InceptionV3), which are commonly used for loss computation and +evaluation, have the expected behavior after loading pre-trained weights. 
In +particular, we compare with the models from repo + +https://github.com/NVlabs/stylegan2-ada-pytorch +""" + +import torch + +from models import build_model +from utils.misc import download_url + +__all__ = ['test_model'] + +_BATCH_SIZE = 4 +# pylint: disable=line-too-long +_PERCEPTUAL_URL = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt' +_INCEPTION_URL = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt' +# pylint: enable=line-too-long + + +def test_model(): + """Collects all model tests.""" + torch.backends.cudnn.enabled = True + torch.backends.cudnn.allow_tf32 = False + torch.backends.cuda.matmul.allow_tf32 = False + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.deterministic = True + print('========== Start Model Test ==========') + test_perceptual() + test_inception() + print('========== Finish Model Test ==========') + + +def test_perceptual(): + """Test the perceptual model.""" + print('===== Testing Perceptual Model =====') + + print('Build test model.') + model = build_model('PerceptualModel', + use_torchvision=False, + no_top=False, + enable_lpips=True) + + print('Build reference model.') + ref_model_path, _, = download_url(_PERCEPTUAL_URL) + with open(ref_model_path, 'rb') as f: + ref_model = torch.jit.load(f).eval().cuda() + + print('Test performance: ') + for size in [224, 128, 256, 512, 1024]: + raw_img = torch.randint(0, 256, size=(_BATCH_SIZE, 3, size, size)) + raw_img_comp = torch.randint(0, 256, size=(_BATCH_SIZE, 3, size, size)) + + # The test model requires input images to have range [-1, 1]. + img = raw_img.to(torch.float32).cuda() / 127.5 - 1 + img_comp = raw_img_comp.to(torch.float32).cuda() / 127.5 - 1 + feat = model(img, resize_input=True, return_tensor='feature') + pred = model(img, resize_input=True, return_tensor='prediction') + lpips = model(img, img_comp, resize_input=False, return_tensor='lpips') + assert feat.shape == (_BATCH_SIZE, 4096) + assert pred.shape == (_BATCH_SIZE, 1000) + assert lpips.shape == (_BATCH_SIZE,) + + # The reference model requires input images to have range [0, 255]. 
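+        # LPIPS is recovered from the reference model by embedding both
+        # images with `return_lpips=True` and taking the squared distance
+        # between the two embeddings.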
+ img = raw_img.to(torch.float32).cuda() + img_comp = raw_img_comp.to(torch.float32).cuda() + ref_feat = ref_model(img, resize_images=True, return_features=True) + ref_pred = ref_model(img, resize_images=True, return_features=False) + temp = ref_model(torch.cat([img, img_comp], dim=0), + resize_images=False, return_lpips=True).chunk(2) + ref_lpips = (temp[0] - temp[1]).square().sum(dim=1, keepdim=False) + assert ref_feat.shape == (_BATCH_SIZE, 4096) + assert ref_pred.shape == (_BATCH_SIZE, 1000) + assert ref_lpips.shape == (_BATCH_SIZE,) + + print(f' Size {size}x{size}, feature (with resize):\n ' + f'mean: {(feat - ref_feat).abs().mean().item():.3e}, ' + f'max: {(feat - ref_feat).abs().max().item():.3e}, ' + f'ref_mean: {ref_feat.abs().mean().item():.3e}, ' + f'ref_max: {ref_feat.abs().max().item():.3e}.') + print(f' Size {size}x{size}, prediction (with resize):\n ' + f'mean: {(pred - ref_pred).abs().mean().item():.3e}, ' + f'max: {(pred - ref_pred).abs().max().item():.3e}, ' + f'ref_mean: {ref_pred.abs().mean().item():.3e}, ' + f'ref_max: {ref_pred.abs().max().item():.3e}.') + print(f' Size {size}x{size}, LPIPS (without resize):\n ' + f'mean: {(lpips - ref_lpips).abs().mean().item():.3e}, ' + f'max: {(lpips - ref_lpips).abs().max().item():.3e}, ' + f'ref_mean: {ref_lpips.abs().mean().item():.3e}, ' + f'ref_max: {ref_lpips.abs().max().item():.3e}.') + + +def test_inception(): + """Test the inception model.""" + print('===== Testing Inception Model =====') + + print('Build test model.') + model = build_model('InceptionModel', align_tf=True) + + print('Build reference model.') + ref_model_path, _, = download_url(_INCEPTION_URL) + with open(ref_model_path, 'rb') as f: + ref_model = torch.jit.load(f).eval().cuda() + + print('Test performance: ') + for size in [299, 128, 256, 512, 1024]: + raw_img = torch.randint(0, 256, size=(_BATCH_SIZE, 3, size, size)) + + # The test model requires input images to have range [-1, 1]. + img = raw_img.to(torch.float32).cuda() / 127.5 - 1 + feat = model(img) + pred = model(img, output_predictions=True) + pred_nb = model(img, output_predictions=True, remove_logits_bias=True) + assert feat.shape == (_BATCH_SIZE, 2048) + assert pred.shape == (_BATCH_SIZE, 1008) + assert pred_nb.shape == (_BATCH_SIZE, 1008) + + # The reference model requires input images to have range [0, 255]. 
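+        # `no_output_bias` on the reference model presumably mirrors
+        # `remove_logits_bias` on the test model, i.e., predictions are
+        # compared without the bias of the final logits layer.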
+ img = raw_img.to(torch.float32).cuda() + ref_feat = ref_model(img, return_features=True) + ref_pred = ref_model(img) + ref_pred_nb = ref_model(img, no_output_bias=True) + assert ref_feat.shape == (_BATCH_SIZE, 2048) + assert ref_pred.shape == (_BATCH_SIZE, 1008) + assert ref_pred_nb.shape == (_BATCH_SIZE, 1008) + + print(f' Size {size}x{size}, feature:\n ' + f'mean: {(feat - ref_feat).abs().mean().item():.3e}, ' + f'max: {(feat - ref_feat).abs().max().item():.3e}, ' + f'ref_mean: {ref_feat.abs().mean().item():.3e}, ' + f'ref_max: {ref_feat.abs().max().item():.3e}.') + print(f' Size {size}x{size}, prediction:\n ' + f'mean: {(pred - ref_pred).abs().mean().item():.3e}, ' + f'max: {(pred - ref_pred).abs().max().item():.3e}, ' + f'ref_mean: {ref_pred.abs().mean().item():.3e}, ' + f'ref_max: {ref_pred.abs().max().item():.3e}.') + print(f' Size {size}x{size}, prediction (without bias):\n ' + f'mean: {(pred_nb - ref_pred_nb).abs().mean().item():.3e}, ' + f'max: {(pred_nb - ref_pred_nb).abs().max().item():.3e}, ' + f'ref_mean: {ref_pred_nb.abs().mean().item():.3e}, ' + f'ref_max: {ref_pred_nb.abs().max().item():.3e}.') diff --git a/models/utils/__init__.py b/models/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/utils/ops.py b/models/utils/ops.py new file mode 100644 index 0000000000000000000000000000000000000000..6993346e9dc9bbccee150828c2231678c8159000 --- /dev/null +++ b/models/utils/ops.py @@ -0,0 +1,18 @@ +# python3.7 +"""Contains operators for neural networks.""" + +import torch +import torch.distributed as dist + +__all__ = ['all_gather'] + + +def all_gather(tensor): + """Gathers tensor from all devices and executes averaging.""" + if not dist.is_initialized(): + return tensor + + world_size = dist.get_world_size() + tensor_list = [torch.ones_like(tensor) for _ in range(world_size)] + dist.all_gather(tensor_list, tensor, async_op=False) + return torch.stack(tensor_list, dim=0).mean(dim=0) diff --git a/requirements/convert.txt b/requirements/convert.txt new file mode 100644 index 0000000000000000000000000000000000000000..3096fa8580734dcc1fc3e81c9abb289ec4ec86a5 --- /dev/null +++ b/requirements/convert.txt @@ -0,0 +1,11 @@ +torch==1.8.1 +tensorflow-gpu==1.15 +ninja==1.10.2 +scikit-video==1.1.11 +pillow==9.0.0 +opencv-python-headless==4.5.5.62 +requests +bs4 +tqdm +rich +easydict diff --git a/requirements/develop.txt b/requirements/develop.txt new file mode 100644 index 0000000000000000000000000000000000000000..e7238e44acc80a94c7052f1148b2631eac48f0d2 --- /dev/null +++ b/requirements/develop.txt @@ -0,0 +1,3 @@ +bpytop # Monitor system resources. +gpustat # Monitor GPU usage. +pylint # Check coding style. 
diff --git a/requirements/minimal.txt b/requirements/minimal.txt
new file mode 100644
index 0000000000000000000000000000000000000000..df29dadc6bb748edc59d05da333ac0e2fd740c6d
--- /dev/null
+++ b/requirements/minimal.txt
@@ -0,0 +1,21 @@
+torch==1.8.1+cu111 -f https://download.pytorch.org/whl/cu111/torch_stable.html
+torchvision==0.9.1+cu111 -f https://download.pytorch.org/whl/cu111/torch_stable.html
+tensorboard==2.7.0
+torch-tb-profiler==0.3.1
+ninja==1.10.2
+numpy==1.21.5
+scipy==1.7.3
+scikit-learn==1.0.2
+scikit-video==1.1.11
+pillow==9.0.0
+opencv-python-headless==4.5.5.62
+requests
+bs4
+tqdm
+rich
+click
+cloup
+psutil
+easydict
+lmdb
+matplotlib
diff --git a/synthesis.py b/synthesis.py
new file mode 100644
index 0000000000000000000000000000000000000000..398a65a25cf3c13861155292f03f87ef5ad01c8d
--- /dev/null
+++ b/synthesis.py
@@ -0,0 +1,178 @@
+# python3.7
+"""Script that synthesizes images with pre-trained models.
+
+Support StyleGAN2 and StyleGAN3.
+"""
+
+import os
+import argparse
+from tqdm import tqdm
+import numpy as np
+
+import torch
+from models import build_model
+from utils.visualizers.html_visualizer import HtmlVisualizer
+from utils.image_utils import save_image, resize_image
+from utils.image_utils import postprocess_image
+from utils.custom_utils import to_numpy
+
+
+def parse_args():
+    """Parses arguments."""
+    parser = argparse.ArgumentParser()
+    group = parser.add_argument_group('General options.')
+    group.add_argument('weight_path', type=str,
+                       help='Weight path to the pre-trained model.')
+    group.add_argument('--save_dir', type=str, default=None,
+                       help='Directory to save the results. If not specified, '
+                            'the results will be saved to '
+                            '`work_dirs/{TASK_SPECIFIC}/` by default.')
+    group.add_argument('--job', type=str, default='synthesize',
+                       help='Name for the job. (default: synthesize)')
+    group.add_argument('--seed', type=int, default=4,
+                       help='Seed for sampling. (default: 4)')
+    group.add_argument('--nums', type=int, default=100,
+                       help='Number of samples to be synthesized. '
+                            '(default: 100)')
+    group.add_argument('--img_size', type=int, default=1024,
+                       help='Size of the synthesized images. (default: 1024)')
+    group.add_argument('--vis_size', type=int, default=256,
+                       help='Size of the visualized images. (default: 256)')
+    group.add_argument('--w_dim', type=int, default=512,
+                       help='Dimension of the latent w. (default: 512)')
+    group.add_argument('--batch_size', type=int, default=4,
+                       help='Batch size. (default: 4)')
+    group.add_argument('--save_jpg', action='store_true', default=False,
+                       help='Whether to save raw images. (default: False)')
+    group.add_argument('-d', '--data_name', type=str, default='ffhq',
+                       help='Name of the dataset. (default: ffhq)')
+    group.add_argument('--latent_path', type=str, default='',
+                       help='Path to the given latent codes. (default: None)')
+    group.add_argument('--trunc_psi', type=float, default=0.7,
+                       help='Psi factor used for truncation. (default: 0.7)')
+    group.add_argument('--trunc_layers', type=int, default=8,
+                       help='Number of layers to perform truncation. '
+                            '(default: 8)')
+
+    group = parser.add_argument_group('StyleGAN2')
+    group.add_argument('--stylegan2', action='store_true',
+                       help='Whether or not to use StyleGAN2. '
+                            '(default: False)')
+    group.add_argument('--scale_stylegan2', type=float, default=1.0,
+                       help='Scale for the number of channels for StyleGAN2.')
+    group.add_argument('--randomize_noise', type=str, default='const',
+                       help='Noise type when synthesizing. (const or random)')
+
+    group = parser.add_argument_group('StyleGAN3')
+    group.add_argument('--stylegan3', action='store_true',
+                       help='Whether or not to use StyleGAN3. '
+                            '(default: False)')
+    group.add_argument('--cfg', type=str, default='T',
+                       help='Config of StyleGAN3 (T/R).')
+    group.add_argument('--scale_stylegan3r', type=float, default=2.0,
+                       help='Scale for the number of channels for '
+                            'StyleGAN3-R.')
+    group.add_argument('--scale_stylegan3t', type=float, default=1.0,
+                       help='Scale for the number of channels for '
+                            'StyleGAN3-T.')
+    group.add_argument('--tx', type=float, default=0,
+                       help='Translation along the X axis. (default: 0.0)')
+    group.add_argument('--ty', type=float, default=0,
+                       help='Translation along the Y axis. (default: 0.0)')
+    group.add_argument('--rotate', type=float, default=0,
+                       help='Rotation angle in degrees. (default: 0)')
+    return parser.parse_args()
+
+
+def main():
+    """Main function."""
+    args = parse_args()
+    # Parse model configuration. Exactly one of `--stylegan2` and
+    # `--stylegan3` must be specified.
+    assert args.stylegan2 != args.stylegan3
+    job_disc = ''
+    if args.stylegan2:
+        config = dict(model_type='StyleGAN2Generator',
+                      resolution=args.img_size,
+                      w_dim=args.w_dim,
+                      fmaps_base=int(args.scale_stylegan2 * (32 << 10)),
+                      fmaps_max=512)
+        job_disc += 'stylegan2'
+    else:
+        if args.stylegan3 and args.cfg == 'R':
+            config = dict(model_type='StyleGAN3Generator',
+                          resolution=args.img_size,
+                          w_dim=args.w_dim,
+                          fmaps_base=int(args.scale_stylegan3r * (32 << 10)),
+                          fmaps_max=1024,
+                          use_radial_filter=True)
+            job_disc += 'stylegan3r'
+        elif args.stylegan3 and args.cfg == 'T':
+            config = dict(model_type='StyleGAN3Generator',
+                          resolution=args.img_size,
+                          w_dim=args.w_dim,
+                          fmaps_base=int(args.scale_stylegan3t * (32 << 10)),
+                          fmaps_max=512,
+                          use_radial_filter=False,
+                          kernel_size=3)
+            job_disc += 'stylegan3t'
+        else:
+            raise ValueError(f'StyleGAN3 config must be `R` or `T`, '
+                             f'but got `{args.cfg}` instead.')
+
+    # Get work directory and job name.
+    save_dir = args.save_dir or f'work_dirs/{args.job}/{args.data_name}'
+    os.makedirs(save_dir, exist_ok=True)
+    job_name = f'seed_{args.seed}_num_{args.nums}_{job_disc}'
+    os.makedirs(f'{save_dir}/{job_name}', exist_ok=True)
+
+    # Build the generator and get synthesis kwargs.
+    print('Building generator...')
+    generator = build_model(**config)
+    synthesis_kwargs = dict(trunc_psi=args.trunc_psi,
+                            trunc_layers=args.trunc_layers)
+    # Load pre-trained weights.
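+    # The checkpoint is expected to store weights under the `models` key; the
+    # EMA copy (`generator_smooth`) is preferred over `generator` when
+    # present.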
+ checkpoint_path = args.weight_path + print(f'Loading checkpoint from `{checkpoint_path}` ...') + checkpoint = torch.load(checkpoint_path, map_location='cpu')['models'] + if 'generator_smooth' in checkpoint: + generator.load_state_dict(checkpoint['generator_smooth']) + else: + generator.load_state_dict(checkpoint['generator']) + generator = generator.eval().cuda() + print('Finish loading checkpoint.') + + np.random.seed(args.seed) + torch.manual_seed(args.seed) + if os.path.exists(args.latent_path): + latent_zs = np.load(args.latent_path) + latent_zs = latent_zs[:args.nums] + else: + latent_zs = np.random.randn(args.nums, generator.z_dim) + num_images = latent_zs.shape[0] + latent_zs = torch.from_numpy(latent_zs.astype(np.float32)) + html = HtmlVisualizer(grid_size=num_images) + print(f'Synthesizing {num_images} images ...') + latent_ws = [] + for batch_idx in tqdm(range(0, num_images, args.batch_size)): + latent_z = latent_zs[batch_idx:batch_idx + args.batch_size] + latent_z = latent_z.cuda() + with torch.no_grad(): + g_outputs = generator(latent_z, **synthesis_kwargs) + g_image = to_numpy(g_outputs['image']) + images = postprocess_image(g_image) + for idx in range(images.shape[0]): + sub_idx = batch_idx + idx + img = images[idx] + row_idx, col_idx = divmod(sub_idx, html.num_cols) + image = resize_image(img, (args.vis_size, args.vis_size)) + html.set_cell(row_idx, col_idx, image=image, + text=f'Sample {sub_idx:06d}') + if args.save_jpg: + save_path = f'{save_dir}/{job_name}/{sub_idx:06d}.jpg' + save_image(save_path, img) + latent_ws.append(to_numpy(g_outputs['wp'])) + latent_ws = np.concatenate(latent_ws, axis=0) + print(f'shape of the latent code: {latent_ws.shape}') + np.save(f'{save_dir}/{job_name}/latent_codes.npy', latent_ws) + html.save(f'{save_dir}/{job_name}.html') + print(f'Finish synthesizing {num_images} samples.') + + +if __name__ == '__main__': + main() diff --git a/third_party/__init__.py b/third_party/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/third_party/stylegan2_official_ops/README.md b/third_party/stylegan2_official_ops/README.md new file mode 100644 index 0000000000000000000000000000000000000000..ef251594d2abbcbbb84f09aa516c3d437283f126 --- /dev/null +++ b/third_party/stylegan2_official_ops/README.md @@ -0,0 +1,28 @@ +# Operators for StyleGAN2 + +All files in this directory are borrowed from repository [stylegan2-ada-pytorch](https://github.com/NVlabs/stylegan2-ada-pytorch). Basically, these files implement customized operators, which are faster than the native operators from PyTorch, especially for second-derivative computation, including + +- `bias_act.bias_act()`: Fuse adding bias and then performing activation as one operator. +- `upfirdn2d.setup_filter()`: Set up the kernel used for filtering. +- `upfirdn2d.filter2d()`: Filtering a 2D feature map with given kernel. +- `upfirdn2d.upsample2d()`: Upsampling a 2D feature map. +- `upfirdn2d.downsample2d()`: Downsampling a 2D feature map. +- `upfirdn2d.upfirdn2d()`: Upsampling, filtering, and then downsampling a 2D feature map. +- `conv2d_gradfix.conv2d()`: Convolutional layer, supporting arbitrarily high order gradients and fixing gradient when computing penalty. +- `conv2d_gradfix.conv_transpose2d()`: Transposed convolutional layer, supporting arbitrarily high order gradients and fixing gradient when computing penalty. +- `conv2d_resample.conv2d_resample()`: Wraps `upfirdn2d()` and `conv2d()` (or `conv_transpose2d()`). 
This is not used in our network implementation (*i.e.*, `models/stylegan2_generator.py` and `models/stylegan2_discriminator.py`).
+
+We make the following slight modifications beyond disabling some lint warnings:
+
+- Line 25 of file `misc.py`: Use `EasyDict` from module `easydict` to replace that from `dnnlib` from [stylegan2-ada-pytorch](https://github.com/NVlabs/stylegan2-ada-pytorch).
+- Line 35 of file `custom_ops.py`: Disable the log message when setting up customized operators.
+- Line 53/89 of file `custom_ops.py`: Add the necessary CUDA compiler path. (***NOTE**: If your CUDA binaries are not located at `/usr/local/cuda/bin`, please specify the path in function `_find_compiler_bindir_posix()`.*)
+- Line 24 of file `bias_act.py`: Use `EasyDict` from module `easydict` to replace that from `dnnlib` from [stylegan2-ada-pytorch](https://github.com/NVlabs/stylegan2-ada-pytorch).
+- Line 32 of file `grid_sample_gradfix.py`: Enable the customized grid sampling operator by default.
+- Line 36 of file `grid_sample_gradfix.py`: Use `impl` to disable the customized grid sampling operator.
+- Line 33 of file `conv2d_gradfix.py`: Enable the customized convolution operators by default.
+- Line 46/51 of file `conv2d_gradfix.py`: Use `impl` to disable the customized convolution operators.
+- Line 36/66 of file `conv2d_resample.py`: Use `impl` to disable the customized convolution operators.
+- Line 23 of file `fma.py`: Use `impl` to disable the customized add-multiply operator.
+
+Please use `ref` or `cuda` to choose which implementation to use: `ref` refers to the native PyTorch operators, while `cuda` refers to the customized operators from the official repository. `cuda` is used by default.
diff --git a/third_party/stylegan2_official_ops/__init__.py b/third_party/stylegan2_official_ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/third_party/stylegan2_official_ops/bias_act.cpp b/third_party/stylegan2_official_ops/bias_act.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..5d2425d8054991a8e8b6f7a940fd0ff7fa0bb330
--- /dev/null
+++ b/third_party/stylegan2_official_ops/bias_act.cpp
@@ -0,0 +1,99 @@
+// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+//
+// NVIDIA CORPORATION and its licensors retain all intellectual property
+// and proprietary rights in and to this software, related documentation
+// and any modifications thereto. Any use, reproduction, disclosure or
+// distribution of this software and related documentation without an express
+// license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+#include <torch/extension.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+#include "bias_act.h"
+
+//------------------------------------------------------------------------
+
+static bool has_same_layout(torch::Tensor x, torch::Tensor y)
+{
+    if (x.dim() != y.dim())
+        return false;
+    for (int64_t i = 0; i < x.dim(); i++)
+    {
+        if (x.size(i) != y.size(i))
+            return false;
+        if (x.size(i) >= 2 && x.stride(i) != y.stride(i))
+            return false;
+    }
+    return true;
+}
+
+//------------------------------------------------------------------------
+
+static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp)
+{
+    // Validate arguments.
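+    // The auxiliary tensors (`b`, `xref`, `yref`, `dy`) are optional: an
+    // empty tensor (numel() == 0) disables the corresponding term, which the
+    // checks below account for.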
+    TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
+    TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x");
+    TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x");
+    TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x");
+    TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same shape, dtype, and device as x");
+    TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
+    TORCH_CHECK(b.dim() == 1, "b must have rank 1");
+    TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds");
+    TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements");
+    TORCH_CHECK(grad >= 0, "grad must be non-negative");
+
+    // Validate layout.
+    TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense");
+    TORCH_CHECK(b.is_contiguous(), "b must be contiguous");
+    TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x");
+    TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x");
+    TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x");
+
+    // Create output tensor.
+    const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
+    torch::Tensor y = torch::empty_like(x);
+    TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x");
+
+    // Initialize CUDA kernel parameters.
+    bias_act_kernel_params p;
+    p.x     = x.data_ptr();
+    p.b     = (b.numel()) ? b.data_ptr() : NULL;
+    p.xref  = (xref.numel()) ? xref.data_ptr() : NULL;
+    p.yref  = (yref.numel()) ? yref.data_ptr() : NULL;
+    p.dy    = (dy.numel()) ? dy.data_ptr() : NULL;
+    p.y     = y.data_ptr();
+    p.grad  = grad;
+    p.act   = act;
+    p.alpha = alpha;
+    p.gain  = gain;
+    p.clamp = clamp;
+    p.sizeX = (int)x.numel();
+    p.sizeB = (int)b.numel();
+    p.stepB = (b.numel()) ? (int)x.stride(dim) : 1;
+
+    // Choose CUDA kernel.
+    void* kernel;
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
+    {
+        kernel = choose_bias_act_kernel<scalar_t>(p);
+    });
+    TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func");
+
+    // Launch CUDA kernel.
+    p.loopX = 4;
+    int blockSize = 4 * 32;
+    int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1;
+    void* args[] = {&p};
+    AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
+    return y;
+}
+
+//------------------------------------------------------------------------
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
+{
+    m.def("bias_act", &bias_act);
+}
+
+//------------------------------------------------------------------------
diff --git a/third_party/stylegan2_official_ops/bias_act.cu b/third_party/stylegan2_official_ops/bias_act.cu
new file mode 100644
index 0000000000000000000000000000000000000000..dd8fc4756d7d94727f94af738665b68d9c518880
--- /dev/null
+++ b/third_party/stylegan2_official_ops/bias_act.cu
@@ -0,0 +1,173 @@
+// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+//
+// NVIDIA CORPORATION and its licensors retain all intellectual property
+// and proprietary rights in and to this software, related documentation
+// and any modifications thereto. Any use, reproduction, disclosure or
+// distribution of this software and related documentation without an express
+// license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+#include <c10/util/Half.h>
+#include "bias_act.h"
+
+//------------------------------------------------------------------------
+// Helpers.
+
+template <class T> struct InternalType;
+template <> struct InternalType<double>    { typedef double scalar_t; };
+template <> struct InternalType<float>     { typedef float  scalar_t; };
+template <> struct InternalType<c10::Half> { typedef float  scalar_t; };
+
+//------------------------------------------------------------------------
+// CUDA kernel.
+
+template <class T, int A>
+__global__ void bias_act_kernel(bias_act_kernel_params p)
+{
+    typedef typename InternalType<T>::scalar_t scalar_t;
+    int G                 = p.grad;
+    scalar_t alpha        = (scalar_t)p.alpha;
+    scalar_t gain         = (scalar_t)p.gain;
+    scalar_t clamp        = (scalar_t)p.clamp;
+    scalar_t one          = (scalar_t)1;
+    scalar_t two          = (scalar_t)2;
+    scalar_t expRange     = (scalar_t)80;
+    scalar_t halfExpRange = (scalar_t)40;
+    scalar_t seluScale    = (scalar_t)1.0507009873554804934193349852946;
+    scalar_t seluAlpha    = (scalar_t)1.6732632423543772848170429916717;
+
+    // Loop over elements.
+    int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x;
+    for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x)
+    {
+        // Load.
+        scalar_t x = (scalar_t)((const T*)p.x)[xi];
+        scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0;
+        scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0;
+        scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0;
+        scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one;
+        scalar_t yy = (gain != 0) ? yref / gain : 0;
+        scalar_t y = 0;
+
+        // Apply bias.
+        ((G == 0) ? x : xref) += b;
+
+        // linear
+        if (A == 1)
+        {
+            if (G == 0) y = x;
+            if (G == 1) y = x;
+        }
+
+        // relu
+        if (A == 2)
+        {
+            if (G == 0) y = (x > 0) ? x : 0;
+            if (G == 1) y = (yy > 0) ? x : 0;
+        }
+
+        // lrelu
+        if (A == 3)
+        {
+            if (G == 0) y = (x > 0) ? x : x * alpha;
+            if (G == 1) y = (yy > 0) ? x : x * alpha;
+        }
+
+        // tanh
+        if (A == 4)
+        {
+            if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); }
+            if (G == 1) y = x * (one - yy * yy);
+            if (G == 2) y = x * (one - yy * yy) * (-two * yy);
+        }
+
+        // sigmoid
+        if (A == 5)
+        {
+            if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one);
+            if (G == 1) y = x * yy * (one - yy);
+            if (G == 2) y = x * yy * (one - yy) * (one - two * yy);
+        }
+
+        // elu
+        if (A == 6)
+        {
+            if (G == 0) y = (x >= 0) ? x : exp(x) - one;
+            if (G == 1) y = (yy >= 0) ? x : x * (yy + one);
+            if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one);
+        }
+
+        // selu
+        if (A == 7)
+        {
+            if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one);
+            if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha);
+            if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha);
+        }
+
+        // softplus
+        if (A == 8)
+        {
+            if (G == 0) y = (x > expRange) ? x : log(exp(x) + one);
+            if (G == 1) y = x * (one - exp(-yy));
+            if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); }
+        }
+
+        // swish
+        if (A == 9)
+        {
+            if (G == 0)
+                y = (x < -expRange) ? 0 : x / (exp(-x) + one);
+            else
+            {
+                scalar_t c = exp(xref);
+                scalar_t d = c + one;
+                if (G == 1)
+                    y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d);
+                else
+                    y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d);
+                yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain;
+            }
+        }
+
+        // Apply gain.
+        y *= gain * dy;
+
+        // Clamp.
+        if (clamp >= 0)
+        {
+            if (G == 0)
+                y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp;
+            else
+                y = (yref > -clamp & yref < clamp) ? y : 0;
+        }
+
+        // Store.
+        ((T*)p.y)[xi] = (T)y;
+    }
+}
+
+//------------------------------------------------------------------------
+// CUDA kernel selection.
+
+template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p)
+{
+    if (p.act == 1) return (void*)bias_act_kernel<T, 1>;
+    if (p.act == 2) return (void*)bias_act_kernel<T, 2>;
+    if (p.act == 3) return (void*)bias_act_kernel<T, 3>;
+    if (p.act == 4) return (void*)bias_act_kernel<T, 4>;
+    if (p.act == 5) return (void*)bias_act_kernel<T, 5>;
+    if (p.act == 6) return (void*)bias_act_kernel<T, 6>;
+    if (p.act == 7) return (void*)bias_act_kernel<T, 7>;
+    if (p.act == 8) return (void*)bias_act_kernel<T, 8>;
+    if (p.act == 9) return (void*)bias_act_kernel<T, 9>;
+    return NULL;
+}
+
+//------------------------------------------------------------------------
+// Template specializations.
+
+template void* choose_bias_act_kernel<double>    (const bias_act_kernel_params& p);
+template void* choose_bias_act_kernel<float>     (const bias_act_kernel_params& p);
+template void* choose_bias_act_kernel<c10::Half> (const bias_act_kernel_params& p);
+
+//------------------------------------------------------------------------
diff --git a/third_party/stylegan2_official_ops/bias_act.h b/third_party/stylegan2_official_ops/bias_act.h
new file mode 100644
index 0000000000000000000000000000000000000000..a32187e1fb7e3bae509d4eceaf900866866875a4
--- /dev/null
+++ b/third_party/stylegan2_official_ops/bias_act.h
@@ -0,0 +1,38 @@
+// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+//
+// NVIDIA CORPORATION and its licensors retain all intellectual property
+// and proprietary rights in and to this software, related documentation
+// and any modifications thereto. Any use, reproduction, disclosure or
+// distribution of this software and related documentation without an express
+// license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+//------------------------------------------------------------------------
+// CUDA kernel parameters.
+
+struct bias_act_kernel_params
+{
+    const void* x;      // [sizeX]
+    const void* b;      // [sizeB] or NULL
+    const void* xref;   // [sizeX] or NULL
+    const void* yref;   // [sizeX] or NULL
+    const void* dy;     // [sizeX] or NULL
+    void*       y;      // [sizeX]
+
+    int         grad;
+    int         act;
+    float       alpha;
+    float       gain;
+    float       clamp;
+
+    int         sizeX;
+    int         sizeB;
+    int         stepB;
+    int         loopX;
+};
+
+//------------------------------------------------------------------------
+// CUDA kernel selection.
+
+template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p);
+
+//------------------------------------------------------------------------
diff --git a/third_party/stylegan2_official_ops/bias_act.py b/third_party/stylegan2_official_ops/bias_act.py
new file mode 100644
index 0000000000000000000000000000000000000000..b94dca1fb0a7f3bc13dce952d8e97a211ec94a88
--- /dev/null
+++ b/third_party/stylegan2_official_ops/bias_act.py
@@ -0,0 +1,227 @@
+# python3.7
+
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto.
Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Custom ops to fuse bias and activation as one operator, which is efficient. + +Please refer to https://github.com/NVlabs/stylegan2-ada-pytorch +""" + +# pylint: disable=line-too-long +# pylint: disable=missing-class-docstring +# pylint: disable=global-statement +# pylint: disable=bare-except + +import os +import warnings +import traceback +from easydict import EasyDict +import numpy as np +import torch + +from . import custom_ops +from . import misc + +#---------------------------------------------------------------------------- + +activation_funcs = { + 'linear': EasyDict(func=lambda x, **_: x, def_alpha=0, def_gain=1, cuda_idx=1, ref='', has_2nd_grad=False), + 'relu': EasyDict(func=lambda x, **_: torch.nn.functional.relu(x), def_alpha=0, def_gain=np.sqrt(2), cuda_idx=2, ref='y', has_2nd_grad=False), + 'lrelu': EasyDict(func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', has_2nd_grad=False), + 'tanh': EasyDict(func=lambda x, **_: torch.tanh(x), def_alpha=0, def_gain=1, cuda_idx=4, ref='y', has_2nd_grad=True), + 'sigmoid': EasyDict(func=lambda x, **_: torch.sigmoid(x), def_alpha=0, def_gain=1, cuda_idx=5, ref='y', has_2nd_grad=True), + 'elu': EasyDict(func=lambda x, **_: torch.nn.functional.elu(x), def_alpha=0, def_gain=1, cuda_idx=6, ref='y', has_2nd_grad=True), + 'selu': EasyDict(func=lambda x, **_: torch.nn.functional.selu(x), def_alpha=0, def_gain=1, cuda_idx=7, ref='y', has_2nd_grad=True), + 'softplus': EasyDict(func=lambda x, **_: torch.nn.functional.softplus(x), def_alpha=0, def_gain=1, cuda_idx=8, ref='y', has_2nd_grad=True), + 'swish': EasyDict(func=lambda x, **_: torch.sigmoid(x) * x, def_alpha=0, def_gain=np.sqrt(2), cuda_idx=9, ref='x', has_2nd_grad=True), +} + +#---------------------------------------------------------------------------- + +_inited = False +_plugin = None +_null_tensor = torch.empty([0]) + +def _init(): + global _inited, _plugin + if not _inited: + _inited = True + sources = ['bias_act.cpp', 'bias_act.cu'] + sources = [os.path.join(os.path.dirname(__file__), s) for s in sources] + try: + _plugin = custom_ops.get_plugin('bias_act_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math']) + except: + warnings.warn('Failed to build CUDA kernels for bias_act. Falling back to slow reference implementation. Details:\n\n' + traceback.format_exc()) + return _plugin is not None + +#---------------------------------------------------------------------------- + +def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='cuda'): + r"""Fused bias and activation function. + + Adds bias `b` to activation tensor `x`, evaluates activation function `act`, + and scales the result by `gain`. Each of the steps is optional. In most cases, + the fused op is considerably more efficient than performing the same calculation + using standard PyTorch ops. It supports first and second order gradients, + but not third order gradients. + + Args: + x: Input activation tensor. Can be of any shape. + b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type + as `x`. The shape must be known, and it must match the dimension of `x` + corresponding to `dim`. + dim: The dimension in `x` corresponding to the elements of `b`. + The value of `dim` is ignored if `b` is not specified. 
+        act:    Name of the activation function to evaluate, or `"linear"` to disable.
+                Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc.
+                See `activation_funcs` for a full list. `None` is not allowed.
+        alpha:  Shape parameter for the activation function, or `None` to use the default.
+        gain:   Scaling factor for the output tensor, or `None` to use the default.
+                See `activation_funcs` for the default scaling of each activation function.
+                If unsure, consider specifying 1.
+        clamp:  Clamp the output values to `[-clamp, +clamp]`, or `None` to disable
+                the clamping (default).
+        impl:   Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).
+
+    Returns:
+        Tensor of the same shape and datatype as `x`.
+    """
+    assert isinstance(x, torch.Tensor)
+    assert impl in ['ref', 'cuda']
+    if impl == 'cuda' and x.device.type == 'cuda' and _init():
+        return _bias_act_cuda(dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp).apply(x, b)
+    return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp)
+
+#----------------------------------------------------------------------------
+
+@misc.profiled_function
+def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None):
+    """Slow reference implementation of `bias_act()` using standard PyTorch ops.
+    """
+    assert isinstance(x, torch.Tensor)
+    assert clamp is None or clamp >= 0
+    spec = activation_funcs[act]
+    alpha = float(alpha if alpha is not None else spec.def_alpha)
+    gain = float(gain if gain is not None else spec.def_gain)
+    clamp = float(clamp if clamp is not None else -1)
+
+    # Add bias.
+    if b is not None:
+        assert isinstance(b, torch.Tensor) and b.ndim == 1
+        assert 0 <= dim < x.ndim
+        assert b.shape[0] == x.shape[dim]
+        x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)])
+
+    # Evaluate activation function.
+    alpha = float(alpha)
+    x = spec.func(x, alpha=alpha)
+
+    # Scale by gain.
+    gain = float(gain)
+    if gain != 1:
+        x = x * gain
+
+    # Clamp.
+    if clamp >= 0:
+        x = x.clamp(-clamp, clamp)
+    return x
+
+#----------------------------------------------------------------------------
+
+_bias_act_cuda_cache = dict()
+
+def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None):
+    """Fast CUDA implementation of `bias_act()` using custom ops.
+    """
+    # Parse arguments.
+    assert clamp is None or clamp >= 0
+    spec = activation_funcs[act]
+    alpha = float(alpha if alpha is not None else spec.def_alpha)
+    gain = float(gain if gain is not None else spec.def_gain)
+    clamp = float(clamp if clamp is not None else -1)
+
+    # Lookup from cache.
+    key = (dim, act, alpha, gain, clamp)
+    if key in _bias_act_cuda_cache:
+        return _bias_act_cuda_cache[key]
+
+    # Forward op.
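+    # Design note: a dedicated autograd.Function subclass is created (and
+    # cached) for each (dim, act, alpha, gain, clamp) configuration, so that
+    # forward() and backward() can simply close over `spec`, `alpha`, `gain`,
+    # and `clamp` instead of stashing them in `ctx`.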
+ class BiasActCuda(torch.autograd.Function): + @staticmethod + def forward(ctx, x, b): # pylint: disable=arguments-differ + ctx.memory_format = torch.channels_last if x.ndim > 2 and x.stride()[1] == 1 else torch.contiguous_format + x = x.contiguous(memory_format=ctx.memory_format) + b = b.contiguous() if b is not None else _null_tensor + y = x + if act != 'linear' or gain != 1 or clamp >= 0 or b is not _null_tensor: + y = _plugin.bias_act(x, b, _null_tensor, _null_tensor, _null_tensor, 0, dim, spec.cuda_idx, alpha, gain, clamp) + ctx.save_for_backward( + x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor, + b if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor, + y if 'y' in spec.ref else _null_tensor) + return y + + @staticmethod + def backward(ctx, dy): # pylint: disable=arguments-differ + dy = dy.contiguous(memory_format=ctx.memory_format) + x, b, y = ctx.saved_tensors + dx = None + db = None + + if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]: + dx = dy + if act != 'linear' or gain != 1 or clamp >= 0: + dx = BiasActCudaGrad.apply(dy, x, b, y) + + if ctx.needs_input_grad[1]: + db = dx.sum([i for i in range(dx.ndim) if i != dim]) + + return dx, db + + # Backward op. + class BiasActCudaGrad(torch.autograd.Function): + @staticmethod + def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ + ctx.memory_format = torch.channels_last if dy.ndim > 2 and dy.stride()[1] == 1 else torch.contiguous_format + dx = _plugin.bias_act(dy, b, x, y, _null_tensor, 1, dim, spec.cuda_idx, alpha, gain, clamp) + ctx.save_for_backward( + dy if spec.has_2nd_grad else _null_tensor, + x, b, y) + return dx + + @staticmethod + def backward(ctx, d_dx): # pylint: disable=arguments-differ + d_dx = d_dx.contiguous(memory_format=ctx.memory_format) + dy, x, b, y = ctx.saved_tensors + d_dy = None + d_x = None + d_b = None + d_y = None + + if ctx.needs_input_grad[0]: + d_dy = BiasActCudaGrad.apply(d_dx, x, b, y) + + if spec.has_2nd_grad and (ctx.needs_input_grad[1] or ctx.needs_input_grad[2]): + d_x = _plugin.bias_act(d_dx, b, x, y, dy, 2, dim, spec.cuda_idx, alpha, gain, clamp) + + if spec.has_2nd_grad and ctx.needs_input_grad[2]: + d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim]) + + return d_dy, d_x, d_b, d_y + + # Add to cache. + _bias_act_cuda_cache[key] = BiasActCuda + return BiasActCuda + +#---------------------------------------------------------------------------- + +# pylint: enable=line-too-long +# pylint: enable=missing-class-docstring +# pylint: enable=global-statement +# pylint: enable=bare-except diff --git a/third_party/stylegan2_official_ops/conv2d_gradfix.py b/third_party/stylegan2_official_ops/conv2d_gradfix.py new file mode 100644 index 0000000000000000000000000000000000000000..512702aa29e9877798ac6a24231c6badd4ac7315 --- /dev/null +++ b/third_party/stylegan2_official_ops/conv2d_gradfix.py @@ -0,0 +1,189 @@ +# python3.7 + +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Custom replacement for convolution operators. + +Operators in this file support arbitrarily high order gradients with zero +performance penalty. 
Please set `impl` as `cuda` to use faster customized +operators, OR as `ref` to use native `torch.nn.functional.conv2d` and +`torch.nn.functional.conv_transpose2d`. + +Please refer to https://github.com/NVlabs/stylegan2-ada-pytorch +""" + +# pylint: disable=redefined-builtin +# pylint: disable=arguments-differ +# pylint: disable=protected-access +# pylint: disable=line-too-long +# pylint: disable=global-statement +# pylint: disable=missing-class-docstring +# pylint: disable=missing-function-docstring + +import warnings +import contextlib +import torch + +enabled = True # Enable the custom op by setting this to true. +weight_gradients_disabled = False # Forcefully disable computation of gradients with respect to the weights. + +@contextlib.contextmanager +def no_weight_gradients(): + global weight_gradients_disabled + old = weight_gradients_disabled + weight_gradients_disabled = True + yield + weight_gradients_disabled = old + +#---------------------------------------------------------------------------- + +def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1, impl='cuda'): + if impl == 'cuda' and _should_use_custom_op(input): + return _conv2d_gradfix(transpose=False, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=0, dilation=dilation, groups=groups).apply(input, weight, bias) + return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups) + +def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1, impl='cuda'): + if impl == 'cuda' and _should_use_custom_op(input): + return _conv2d_gradfix(transpose=True, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation).apply(input, weight, bias) + return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation) + +#---------------------------------------------------------------------------- + +def _should_use_custom_op(input): + assert isinstance(input, torch.Tensor) + if (not enabled) or (not torch.backends.cudnn.enabled): + return False + if input.device.type != 'cuda': + return False + if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9']): + return True + warnings.warn(f'conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d().') + return False + +def _tuple_of_ints(xs, ndim): + xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim + assert len(xs) == ndim + assert all(isinstance(x, int) for x in xs) + return xs + +#---------------------------------------------------------------------------- + +_conv2d_gradfix_cache = dict() + +def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, dilation, groups): + # Parse arguments. + ndim = 2 + weight_shape = tuple(weight_shape) + stride = _tuple_of_ints(stride, ndim) + padding = _tuple_of_ints(padding, ndim) + output_padding = _tuple_of_ints(output_padding, ndim) + dilation = _tuple_of_ints(dilation, ndim) + + # Lookup from cache. + key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups) + if key in _conv2d_gradfix_cache: + return _conv2d_gradfix_cache[key] + + # Validate arguments. 
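+    # (`weight_shape` follows conv2d's convention of
+    # [out_channels, in_channels // groups, kh, kw]; `output_padding` is only
+    # meaningful for the transposed case.)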
+ assert groups >= 1 + assert len(weight_shape) == ndim + 2 + assert all(stride[i] >= 1 for i in range(ndim)) + assert all(padding[i] >= 0 for i in range(ndim)) + assert all(dilation[i] >= 0 for i in range(ndim)) + if not transpose: + assert all(output_padding[i] == 0 for i in range(ndim)) + else: # transpose + assert all(0 <= output_padding[i] < max(stride[i], dilation[i]) for i in range(ndim)) + + # Helpers. + common_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups) + def calc_output_padding(input_shape, output_shape): + if transpose: + return [0, 0] + return [ + input_shape[i + 2] + - (output_shape[i + 2] - 1) * stride[i] + - (1 - 2 * padding[i]) + - dilation[i] * (weight_shape[i + 2] - 1) + for i in range(ndim) + ] + + # Forward & backward. + class Conv2d(torch.autograd.Function): + @staticmethod + def forward(ctx, input, weight, bias): + assert weight.shape == weight_shape + if not transpose: + output = torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, **common_kwargs) + else: # transpose + output = torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, output_padding=output_padding, **common_kwargs) + ctx.save_for_backward(input, weight) + return output + + @staticmethod + def backward(ctx, grad_output): + input, weight = ctx.saved_tensors + grad_input = None + grad_weight = None + grad_bias = None + + if ctx.needs_input_grad[0]: + p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape) + grad_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, weight, None) + assert grad_input.shape == input.shape + + if ctx.needs_input_grad[1] and not weight_gradients_disabled: + grad_weight = Conv2dGradWeight.apply(grad_output, input) + assert grad_weight.shape == weight_shape + + if ctx.needs_input_grad[2]: + grad_bias = grad_output.sum([0, 2, 3]) + + return grad_input, grad_weight, grad_bias + + # Gradient with respect to the weights. 
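+    # Design note: the weight gradient is a separate autograd.Function so
+    # that its own backward() can again be expressed through Conv2d.apply()
+    # and _conv2d_gradfix(), which is what makes arbitrarily high-order
+    # gradients possible.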
+ class Conv2dGradWeight(torch.autograd.Function): + @staticmethod + def forward(ctx, grad_output, input): + op = torch._C._jit_get_operation('aten::cudnn_convolution_backward_weight' if not transpose else 'aten::cudnn_convolution_transpose_backward_weight') + flags = [torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic, torch.backends.cudnn.allow_tf32] + grad_weight = op(weight_shape, grad_output, input, padding, stride, dilation, groups, *flags) + assert grad_weight.shape == weight_shape + ctx.save_for_backward(grad_output, input) + return grad_weight + + @staticmethod + def backward(ctx, grad2_grad_weight): + grad_output, input = ctx.saved_tensors + grad2_grad_output = None + grad2_input = None + + if ctx.needs_input_grad[0]: + grad2_grad_output = Conv2d.apply(input, grad2_grad_weight, None) + assert grad2_grad_output.shape == grad_output.shape + + if ctx.needs_input_grad[1]: + p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape) + grad2_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, grad2_grad_weight, None) + assert grad2_input.shape == input.shape + + return grad2_grad_output, grad2_input + + _conv2d_gradfix_cache[key] = Conv2d + return Conv2d + +#---------------------------------------------------------------------------- + +# pylint: enable=redefined-builtin +# pylint: enable=arguments-differ +# pylint: enable=protected-access +# pylint: enable=line-too-long +# pylint: enable=global-statement +# pylint: enable=missing-class-docstring +# pylint: enable=missing-function-docstring diff --git a/third_party/stylegan2_official_ops/conv2d_resample.py b/third_party/stylegan2_official_ops/conv2d_resample.py new file mode 100644 index 0000000000000000000000000000000000000000..fb76aa245dd4b2c99f79f24c30403c1a1958c90b --- /dev/null +++ b/third_party/stylegan2_official_ops/conv2d_resample.py @@ -0,0 +1,168 @@ +# python3.7 + +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""2D convolution with optional up/downsampling. + +Please refer to https://github.com/NVlabs/stylegan2-ada-pytorch +""" + +# pylint: disable=line-too-long + +import torch + +from . import misc +from . import conv2d_gradfix +from . import upfirdn2d +from .upfirdn2d import _parse_padding +from .upfirdn2d import _get_filter_size + +#---------------------------------------------------------------------------- + +def _get_weight_shape(w): + with misc.suppress_tracer_warnings(): # this value will be treated as a constant + shape = [int(sz) for sz in w.shape] + misc.assert_shape(w, shape) + return shape + +#---------------------------------------------------------------------------- + +def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True, impl='cuda'): + """Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations. + """ + out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w) + + # Flip weight if requested. + if not flip_weight: # conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False). 
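+        # Flipping the kernel along both spatial axes turns the
+        # cross-correlation computed by conv2d() into a true convolution.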
+ w = w.flip([2, 3]) + + # Workaround performance pitfall in cuDNN 8.0.5, triggered when using + # 1x1 kernel + memory_format=channels_last + less than 64 channels. + if kw == 1 and kh == 1 and stride == 1 and padding in [0, [0, 0], (0, 0)] and not transpose: + if x.stride()[1] == 1 and min(out_channels, in_channels_per_group) < 64: + if out_channels <= 4 and groups == 1: + in_shape = x.shape + x = w.squeeze(3).squeeze(2) @ x.reshape([in_shape[0], in_channels_per_group, -1]) + x = x.reshape([in_shape[0], out_channels, in_shape[2], in_shape[3]]) + else: + x = x.to(memory_format=torch.contiguous_format) + w = w.to(memory_format=torch.contiguous_format) + x = conv2d_gradfix.conv2d(x, w, groups=groups, impl=impl) + return x.to(memory_format=torch.channels_last) + + # Otherwise => execute using conv2d_gradfix. + op = conv2d_gradfix.conv_transpose2d if transpose else conv2d_gradfix.conv2d + return op(x, w, stride=stride, padding=padding, groups=groups, impl=impl) + +#---------------------------------------------------------------------------- + +@misc.profiled_function +def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False, impl='cuda'): + r"""2D convolution with optional up/downsampling. + + Padding is performed only once at the beginning, not between the operations. + + Args: + x: Input tensor of shape + `[batch_size, in_channels, in_height, in_width]`. + w: Weight tensor of shape + `[out_channels, in_channels//groups, kernel_height, kernel_width]`. + f: Low-pass filter for up/downsampling. Must be prepared beforehand by + calling upfirdn2d.setup_filter(). None = identity (default). + up: Integer upsampling factor (default: 1). + down: Integer downsampling factor (default: 1). + padding: Padding with respect to the upsampled image. Can be a single number + or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` + (default: 0). + groups: Split input channels into N groups (default: 1). + flip_weight: False = convolution, True = correlation (default: True). + flip_filter: False = convolution, True = correlation (default: False). + impl: Implementation mode of customized ops. 'ref' for native PyTorch + implementation, 'cuda' for `.cu` implementation + (default: 'cuda'). + + Returns: + Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. + """ + # Validate arguments. + assert isinstance(x, torch.Tensor) and (x.ndim == 4) + assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype) + assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32) + assert isinstance(up, int) and (up >= 1) + assert isinstance(down, int) and (down >= 1) + assert isinstance(groups, int) and (groups >= 1) + out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w) + fw, fh = _get_filter_size(f) + px0, px1, py0, py1 = _parse_padding(padding) + + # Adjust padding to account for up/downsampling. + if up > 1: + px0 += (fw + up - 1) // 2 + px1 += (fw - up) // 2 + py0 += (fh + up - 1) // 2 + py1 += (fh - up) // 2 + if down > 1: + px0 += (fw - down + 1) // 2 + px1 += (fw - down) // 2 + py0 += (fh - down + 1) // 2 + py1 += (fh - down) // 2 + + # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve. 
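+    # (A 1x1 kernel mixes channels but not spatial locations, while the
+    # filtering mixes spatial locations but not channels, so the two
+    # operations commute and the cheaper order can be chosen.)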
+ if kw == 1 and kh == 1 and (down > 1 and up == 1): + x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, padding=[px0,px1,py0,py1], flip_filter=flip_filter, impl=impl) + x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight, impl=impl) + return x + + # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample. + if kw == 1 and kh == 1 and (up > 1 and down == 1): + x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight, impl=impl) + x = upfirdn2d.upfirdn2d(x=x, f=f, up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter, impl=impl) + return x + + # Fast path: downsampling only => use strided convolution. + if down > 1 and up == 1: + x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0,px1,py0,py1], flip_filter=flip_filter, impl=impl) + x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight, impl=impl) + return x + + # Fast path: upsampling with optional downsampling => use transpose strided convolution. + if up > 1: + if groups == 1: + w = w.transpose(0, 1) + else: + w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw) + w = w.transpose(1, 2) + w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw) + px0 -= kw - 1 + px1 -= kw - up + py0 -= kh - 1 + py1 -= kh - up + pxt = max(min(-px0, -px1), 0) + pyt = max(min(-py0, -py1), 0) + x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt,pxt], groups=groups, transpose=True, flip_weight=(not flip_weight), impl=impl) + x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0+pxt,px1+pxt,py0+pyt,py1+pyt], gain=up**2, flip_filter=flip_filter, impl=impl) + if down > 1: + x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter, impl=impl) + return x + + # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d. + if up == 1 and down == 1: + if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0: + return _conv2d_wrapper(x=x, w=w, padding=[py0,px0], groups=groups, flip_weight=flip_weight, impl=impl) + + # Fallback: Generic reference implementation. + x = upfirdn2d.upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter, impl=impl) + x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight, impl=impl) + if down > 1: + x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter, impl=impl) + return x + +#---------------------------------------------------------------------------- + +# pylint: enable=line-too-long diff --git a/third_party/stylegan2_official_ops/custom_ops.py b/third_party/stylegan2_official_ops/custom_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..4b9a8ef3ec71d144eed7584378546d7ccc183748 --- /dev/null +++ b/third_party/stylegan2_official_ops/custom_ops.py @@ -0,0 +1,159 @@ +# python3.7 + +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Utility functions to setup customized operators. 
+
+Please refer to https://github.com/NVlabs/stylegan2-ada-pytorch
+"""
+
+# pylint: disable=line-too-long
+# pylint: disable=missing-function-docstring
+# pylint: disable=useless-suppression
+# pylint: disable=inconsistent-quotes
+
+import os
+import glob
+import importlib
+import hashlib
+import shutil
+from pathlib import Path
+
+import torch
+from torch.utils.file_baton import FileBaton
+import torch.utils.cpp_extension
+
+#----------------------------------------------------------------------------
+# Global options.
+
+verbosity = 'none' # Verbosity level: 'none', 'brief', 'full'
+
+#----------------------------------------------------------------------------
+# Internal helper funcs.
+
+def _find_compiler_bindir():
+    patterns = [
+        'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64',
+        'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64',
+        'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64',
+        'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin',
+    ]
+    for pattern in patterns:
+        matches = sorted(glob.glob(pattern))
+        if len(matches):
+            return matches[-1]
+    return None
+
+def _find_compiler_bindir_posix():
+    patterns = [
+        '/usr/local/cuda/bin'
+    ]
+    for pattern in patterns:
+        matches = sorted(glob.glob(pattern))
+        if len(matches):
+            return matches[-1]
+    return None
+
+#----------------------------------------------------------------------------
+# Main entry point for compiling and loading C++/CUDA plugins.
+
+_cached_plugins = dict()
+
+def get_plugin(module_name, sources, **build_kwargs):
+    assert verbosity in ['none', 'brief', 'full']
+
+    # Already cached?
+    if module_name in _cached_plugins:
+        return _cached_plugins[module_name]
+
+    # Print status.
+    if verbosity == 'full':
+        print(f'Setting up PyTorch plugin "{module_name}"...')
+    elif verbosity == 'brief':
+        print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True)
+
+    try: # pylint: disable=too-many-nested-blocks
+        # Make sure we can find the necessary compiler binaries.
+        if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0:
+            compiler_bindir = _find_compiler_bindir()
+            if compiler_bindir is None:
+                raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".')
+            os.environ['PATH'] += ';' + compiler_bindir
+
+        elif os.name == 'posix':
+            compiler_bindir = _find_compiler_bindir_posix()
+            if compiler_bindir is None:
+                raise RuntimeError(f'Could not find NVCC installation on this computer. Check _find_compiler_bindir_posix() in "{__file__}".')
+            os.environ['PATH'] += ':' + compiler_bindir  # POSIX uses ':' as the PATH separator.
+
+        # Compile and load.
+        verbose_build = (verbosity == 'full')
+
+        # Incremental build md5sum trickery. Copies all the input source files
+        # into a cached build directory under a combined md5 digest of the input
+        # source files. Copying is done only if the combined digest has changed.
+        # This keeps input file timestamps and filenames the same as in previous
+        # extension builds, allowing for fast incremental rebuilds.
+        #
+        # This optimization is done only in case all the source files reside in
+        # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR
+        # environment variable is set (we take this as a signal that the user
+        # actually cares about this.)
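+        # For example, the sources of `bias_act_plugin` would be copied to
+        # `<build_dir>/<md5 of sources>/` and reused for as long as the
+        # sources stay unchanged.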
+ source_dirs_set = set(os.path.dirname(source) for source in sources) + if len(source_dirs_set) == 1 and ('TORCH_EXTENSIONS_DIR' in os.environ): + all_source_files = sorted(list(x for x in Path(list(source_dirs_set)[0]).iterdir() if x.is_file())) + + # Compute a combined hash digest for all source files in the same + # custom op directory (usually .cu, .cpp, .py and .h files). + hash_md5 = hashlib.md5() + for src in all_source_files: + with open(src, 'rb') as f: + hash_md5.update(f.read()) + build_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access + digest_build_dir = os.path.join(build_dir, hash_md5.hexdigest()) + + if not os.path.isdir(digest_build_dir): + os.makedirs(digest_build_dir, exist_ok=True) + baton = FileBaton(os.path.join(digest_build_dir, 'lock')) + if baton.try_acquire(): + try: + for src in all_source_files: + shutil.copyfile(src, os.path.join(digest_build_dir, os.path.basename(src))) + finally: + baton.release() + else: + # Someone else is copying source files under the digest dir, + # wait until done and continue. + baton.wait() + digest_sources = [os.path.join(digest_build_dir, os.path.basename(x)) for x in sources] + torch.utils.cpp_extension.load(name=module_name, build_directory=build_dir, + verbose=verbose_build, sources=digest_sources, **build_kwargs) + else: + torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs) + module = importlib.import_module(module_name) + + except: + if verbosity == 'brief': + print('Failed!') + raise + + # Print status and add to cache. + if verbosity == 'full': + print(f'Done setting up PyTorch plugin "{module_name}".') + elif verbosity == 'brief': + print('Done.') + _cached_plugins[module_name] = module + return module + +#---------------------------------------------------------------------------- + +# pylint: enable=line-too-long +# pylint: enable=missing-function-docstring +# pylint: enable=useless-suppression +# pylint: enable=inconsistent-quotes diff --git a/third_party/stylegan2_official_ops/fma.py b/third_party/stylegan2_official_ops/fma.py new file mode 100644 index 0000000000000000000000000000000000000000..7304d85825d16612eec488242b220c2dbd83b6d7 --- /dev/null +++ b/third_party/stylegan2_official_ops/fma.py @@ -0,0 +1,73 @@ +# python3.7 + +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Fused multiply-add, with slightly faster gradients than `torch.addcmul()`. 
+ +Please refer to https://github.com/NVlabs/stylegan2-ada-pytorch +""" + +# pylint: disable=line-too-long +# pylint: disable=missing-function-docstring + +import torch + +#---------------------------------------------------------------------------- + +def fma(a, b, c, impl='cuda'): # => a * b + c + if impl == 'cuda': + return _FusedMultiplyAdd.apply(a, b, c) + return torch.addcmul(c, a, b) + +#---------------------------------------------------------------------------- + +class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c + @staticmethod + def forward(ctx, a, b, c): # pylint: disable=arguments-differ + out = torch.addcmul(c, a, b) + ctx.save_for_backward(a, b) + ctx.c_shape = c.shape + return out + + @staticmethod + def backward(ctx, dout): # pylint: disable=arguments-differ + a, b = ctx.saved_tensors + c_shape = ctx.c_shape + da = None + db = None + dc = None + + if ctx.needs_input_grad[0]: + da = _unbroadcast(dout * b, a.shape) + + if ctx.needs_input_grad[1]: + db = _unbroadcast(dout * a, b.shape) + + if ctx.needs_input_grad[2]: + dc = _unbroadcast(dout, c_shape) + + return da, db, dc + +#---------------------------------------------------------------------------- + +def _unbroadcast(x, shape): + extra_dims = x.ndim - len(shape) + assert extra_dims >= 0 + dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)] + if len(dim): + x = x.sum(dim=dim, keepdim=True) + if extra_dims: + x = x.reshape(-1, *x.shape[extra_dims+1:]) + assert x.shape == shape + return x + +#---------------------------------------------------------------------------- + +# pylint: enable=line-too-long +# pylint: enable=missing-function-docstring diff --git a/third_party/stylegan2_official_ops/grid_sample_gradfix.py b/third_party/stylegan2_official_ops/grid_sample_gradfix.py new file mode 100644 index 0000000000000000000000000000000000000000..a41c14dc415b3a991973f3d30ca0bc6dd0b84423 --- /dev/null +++ b/third_party/stylegan2_official_ops/grid_sample_gradfix.py @@ -0,0 +1,98 @@ +# python3.7 + +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Custom replacement for `torch.nn.functional.grid_sample`. + +This is useful for differentiable augmentation. This customized operator +supports arbitrarily high order gradients between the input and output. Only +works on 2D images and assumes `mode=bilinear`, `padding_mode=zeros`, and +`align_corners=False`. + +Please refer to https://github.com/NVlabs/stylegan2-ada-pytorch +""" + +# pylint: disable=redefined-builtin +# pylint: disable=arguments-differ +# pylint: disable=protected-access +# pylint: disable=line-too-long +# pylint: disable=missing-function-docstring + +import warnings +import torch + +#---------------------------------------------------------------------------- + +enabled = True # Enable the custom op by setting this to true. 
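+
+# Usage sketch (for illustration only; `x` and `grid` are placeholder
+# tensors, and `grid_sample()` is defined below):
+#   x = torch.randn(1, 3, 64, 64, device='cuda')
+#   grid = torch.zeros(1, 64, 64, 2, device='cuda', requires_grad=True)
+#   y = grid_sample(x, grid, impl='cuda')  # differentiable w.r.t. x and grid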
+ +#---------------------------------------------------------------------------- + +def grid_sample(input, grid, impl='cuda'): + if impl == 'cuda' and _should_use_custom_op(): + return _GridSample2dForward.apply(input, grid) + return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False) + +#---------------------------------------------------------------------------- + +def _should_use_custom_op(): + if not enabled: + return False + if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9']): + return True + warnings.warn(f'grid_sample_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.grid_sample().') + return False + +#---------------------------------------------------------------------------- + +class _GridSample2dForward(torch.autograd.Function): + @staticmethod + def forward(ctx, input, grid): + assert input.ndim == 4 + assert grid.ndim == 4 + output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False) + ctx.save_for_backward(input, grid) + return output + + @staticmethod + def backward(ctx, grad_output): + input, grid = ctx.saved_tensors + grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid) + return grad_input, grad_grid + +#---------------------------------------------------------------------------- + +class _GridSample2dBackward(torch.autograd.Function): + @staticmethod + def forward(ctx, grad_output, input, grid): + op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward') + grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False) + ctx.save_for_backward(grid) + return grad_input, grad_grid + + @staticmethod + def backward(ctx, grad2_grad_input, grad2_grad_grid): + _ = grad2_grad_grid # unused + grid, = ctx.saved_tensors + grad2_grad_output = None + grad2_input = None + grad2_grid = None + + if ctx.needs_input_grad[0]: + grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid) + + assert not ctx.needs_input_grad[2] + return grad2_grad_output, grad2_input, grad2_grid + +#---------------------------------------------------------------------------- + +# pylint: enable=redefined-builtin +# pylint: enable=arguments-differ +# pylint: enable=protected-access +# pylint: enable=line-too-long +# pylint: enable=missing-function-docstring diff --git a/third_party/stylegan2_official_ops/misc.py b/third_party/stylegan2_official_ops/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..7973619f8db5a41f18ef42c83c4c5e5e013e7ff7 --- /dev/null +++ b/third_party/stylegan2_official_ops/misc.py @@ -0,0 +1,281 @@ +# python3.7 + +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Misc functions for customized operations. 
+ +Please refer to https://github.com/NVlabs/stylegan2-ada-pytorch +""" + +# pylint: disable=line-too-long +# pylint: disable=missing-class-docstring +# pylint: disable=missing-function-docstring +# pylint: disable=use-maxsplit-arg +# pylint: disable=unnecessary-comprehension + +import re +import contextlib +import warnings +from easydict import EasyDict +import numpy as np +import torch + +#---------------------------------------------------------------------------- +# Cached construction of constant tensors. Avoids CPU=>GPU copy when the +# same constant is used multiple times. + +_constant_cache = dict() + +def constant(value, shape=None, dtype=None, device=None, memory_format=None): + value = np.asarray(value) + if shape is not None: + shape = tuple(shape) + if dtype is None: + dtype = torch.get_default_dtype() + if device is None: + device = torch.device('cpu') + if memory_format is None: + memory_format = torch.contiguous_format + + key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format) + tensor = _constant_cache.get(key, None) + if tensor is None: + tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device) + if shape is not None: + tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape)) + tensor = tensor.contiguous(memory_format=memory_format) + _constant_cache[key] = tensor + return tensor + +#---------------------------------------------------------------------------- +# Replace NaN/Inf with specified numerical values. + +try: + nan_to_num = torch.nan_to_num # 1.8.0a0 +except AttributeError: + def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin + assert isinstance(input, torch.Tensor) + if posinf is None: + posinf = torch.finfo(input.dtype).max + if neginf is None: + neginf = torch.finfo(input.dtype).min + assert nan == 0 + return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out) + +#---------------------------------------------------------------------------- +# Symbolic assert. + +try: + symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access +except AttributeError: + symbolic_assert = torch.Assert # 1.7.0 + +#---------------------------------------------------------------------------- +# Context manager to suppress known warnings in torch.jit.trace(). + +class suppress_tracer_warnings(warnings.catch_warnings): + def __enter__(self): + super().__enter__() + warnings.simplefilter('ignore', category=torch.jit.TracerWarning) + return self + +#---------------------------------------------------------------------------- +# Assert that the shape of a tensor matches the given list of integers. +# None indicates that the size of a dimension is allowed to vary. +# Performs symbolic assertion when used in torch.jit.trace(). 
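+# Example: assert_shape(x, [None, 3, 256, 256]) accepts any batch size but
+# requires each sample to be 3x256x256.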
+ +def assert_shape(tensor, ref_shape): + if tensor.ndim != len(ref_shape): + raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}') + for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)): + if ref_size is None: + pass + elif isinstance(ref_size, torch.Tensor): + with suppress_tracer_warnings(): # as_tensor results are registered as constants + symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}') + elif isinstance(size, torch.Tensor): + with suppress_tracer_warnings(): # as_tensor results are registered as constants + symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}') + elif size != ref_size: + raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}') + +#---------------------------------------------------------------------------- +# Function decorator that calls torch.autograd.profiler.record_function(). + +def profiled_function(fn): + def decorator(*args, **kwargs): + with torch.autograd.profiler.record_function(fn.__name__): + return fn(*args, **kwargs) + decorator.__name__ = fn.__name__ + return decorator + +#---------------------------------------------------------------------------- +# Sampler for torch.utils.data.DataLoader that loops over the dataset +# indefinitely, shuffling items as it goes. + +class InfiniteSampler(torch.utils.data.Sampler): + def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5): + assert len(dataset) > 0 + assert num_replicas > 0 + assert 0 <= rank < num_replicas + assert 0 <= window_size <= 1 + super().__init__(dataset) + self.dataset = dataset + self.rank = rank + self.num_replicas = num_replicas + self.shuffle = shuffle + self.seed = seed + self.window_size = window_size + + def __iter__(self): + order = np.arange(len(self.dataset)) + rnd = None + window = 0 + if self.shuffle: + rnd = np.random.RandomState(self.seed) + rnd.shuffle(order) + window = int(np.rint(order.size * self.window_size)) + + idx = 0 + while True: + i = idx % order.size + if idx % self.num_replicas == self.rank: + yield order[i] + if window >= 2: + j = (i - rnd.randint(window)) % order.size + order[i], order[j] = order[j], order[i] + idx += 1 + +#---------------------------------------------------------------------------- +# Utilities for operating with torch.nn.Module parameters and buffers. + +def params_and_buffers(module): + assert isinstance(module, torch.nn.Module) + return list(module.parameters()) + list(module.buffers()) + +def named_params_and_buffers(module): + assert isinstance(module, torch.nn.Module) + return list(module.named_parameters()) + list(module.named_buffers()) + +def copy_params_and_buffers(src_module, dst_module, require_all=False): + assert isinstance(src_module, torch.nn.Module) + assert isinstance(dst_module, torch.nn.Module) + src_tensors = {name: tensor for name, tensor in named_params_and_buffers(src_module)} + for name, tensor in named_params_and_buffers(dst_module): + assert (name in src_tensors) or (not require_all) + if name in src_tensors: + tensor.copy_(src_tensors[name].detach()).requires_grad_(tensor.requires_grad) + +#---------------------------------------------------------------------------- +# Context manager for easily enabling/disabling DistributedDataParallel +# synchronization. 
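+#
+# Usage sketch (names are placeholders): synchronize gradients only on the
+# last micro-batch of a gradient-accumulation loop:
+#
+#   for i, x in enumerate(batches):
+#       with ddp_sync(ddp_module, sync=(i == len(batches) - 1)):
+#           ddp_module(x).sum().backward()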
+
+@contextlib.contextmanager
+def ddp_sync(module, sync):
+    assert isinstance(module, torch.nn.Module)
+    if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel):
+        yield
+    else:
+        with module.no_sync():
+            yield
+
+#----------------------------------------------------------------------------
+# Check DistributedDataParallel consistency across processes.
+
+def check_ddp_consistency(module, ignore_regex=None):
+    assert isinstance(module, torch.nn.Module)
+    for name, tensor in named_params_and_buffers(module):
+        fullname = type(module).__name__ + '.' + name
+        if ignore_regex is not None and re.fullmatch(ignore_regex, fullname):
+            continue
+        tensor = tensor.detach()
+        other = tensor.clone()
+        torch.distributed.broadcast(tensor=other, src=0)
+        assert (nan_to_num(tensor) == nan_to_num(other)).all(), fullname
+
+#----------------------------------------------------------------------------
+# Print summary table of module hierarchy.
+
+def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True):
+    assert isinstance(module, torch.nn.Module)
+    assert not isinstance(module, torch.jit.ScriptModule)
+    assert isinstance(inputs, (tuple, list))
+
+    # Register hooks.
+    entries = []
+    nesting = [0]
+    def pre_hook(_mod, _inputs):
+        nesting[0] += 1
+    def post_hook(mod, _inputs, outputs):
+        nesting[0] -= 1
+        if nesting[0] <= max_nesting:
+            outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs]
+            outputs = [t for t in outputs if isinstance(t, torch.Tensor)]
+            entries.append(EasyDict(mod=mod, outputs=outputs))
+    hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()]
+    hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()]
+
+    # Run module.
+    outputs = module(*inputs)
+    for hook in hooks:
+        hook.remove()
+
+    # Identify unique outputs, parameters, and buffers.
+    tensors_seen = set()
+    for e in entries:
+        e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen]
+        e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen]
+        e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen]
+        tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs}
+
+    # Filter out redundant entries.
+    if skip_redundant:
+        entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)]
+
+    # Construct table.
+    rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']]
+    rows += [['---'] * len(rows[0])]
+    param_total = 0
+    buffer_total = 0
+    submodule_names = {mod: name for name, mod in module.named_modules()}
+    for e in entries:
+        name = '<top-level>' if e.mod is module else submodule_names[e.mod]
+        param_size = sum(t.numel() for t in e.unique_params)
+        buffer_size = sum(t.numel() for t in e.unique_buffers)
+        output_shapes = [str(list(t.shape)) for t in e.outputs]
+        output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs]
+        rows += [[
+            name + (':0' if len(e.outputs) >= 2 else ''),
+            str(param_size) if param_size else '-',
+            str(buffer_size) if buffer_size else '-',
+            (output_shapes + ['-'])[0],
+            (output_dtypes + ['-'])[0],
+        ]]
+        for idx in range(1, len(e.outputs)):
+            rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]]
+        param_total += param_size
+        buffer_total += buffer_size
+    rows += [['---'] * len(rows[0])]
+    rows += [['Total', str(param_total), str(buffer_total), '-', '-']]
+
+    # Print table.
+    widths = [max(len(cell) for cell in column) for column in zip(*rows)]
+    print()
+    for row in rows:
+        print('  '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths)))
+    print()
+    return outputs
+
+#----------------------------------------------------------------------------
+
+# pylint: enable=line-too-long
+# pylint: enable=missing-class-docstring
+# pylint: enable=missing-function-docstring
+# pylint: enable=use-maxsplit-arg
+# pylint: enable=unnecessary-comprehension
diff --git a/third_party/stylegan2_official_ops/upfirdn2d.cpp b/third_party/stylegan2_official_ops/upfirdn2d.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..2d7177fc60040751d20e9a8da0301fa3ab64968a
--- /dev/null
+++ b/third_party/stylegan2_official_ops/upfirdn2d.cpp
@@ -0,0 +1,103 @@
+// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+//
+// NVIDIA CORPORATION and its licensors retain all intellectual property
+// and proprietary rights in and to this software, related documentation
+// and any modifications thereto. Any use, reproduction, disclosure or
+// distribution of this software and related documentation without an express
+// license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+#include <torch/extension.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+#include "upfirdn2d.h"
+
+//------------------------------------------------------------------------
+
+static torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain)
+{
+    // Validate arguments.
+    TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
+    TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x");
+    TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32");
+    TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
+    TORCH_CHECK(f.numel() <= INT_MAX, "f is too large");
+    TORCH_CHECK(x.dim() == 4, "x must be rank 4");
+    TORCH_CHECK(f.dim() == 2, "f must be rank 2");
+    TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1");
+    TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1");
+    TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1");
+
+    // Create output tensor.
+    const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
+    int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx;
+    int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy;
+    TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1");
+    torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format());
+    TORCH_CHECK(y.numel() <= INT_MAX, "output is too large");
+
+    // Initialize CUDA kernel parameters.
+    upfirdn2d_kernel_params p;
+    p.x = x.data_ptr();
+    p.f = f.data_ptr<float>();
+    p.y = y.data_ptr();
+    p.up = make_int2(upx, upy);
+    p.down = make_int2(downx, downy);
+    p.pad0 = make_int2(padx0, pady0);
+    p.flip = (flip) ? 1 : 0;
+    p.gain = gain;
+    p.inSize = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
+    p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0));
+    p.filterSize = make_int2((int)f.size(1), (int)f.size(0));
+    p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0));
+    p.outSize = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0));
+    p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0));
+    p.sizeMajor = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z;
+    p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1;
+
+    // Choose CUDA kernel.
+    upfirdn2d_kernel_spec spec;
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
+    {
+        spec = choose_upfirdn2d_kernel<scalar_t>(p);
+    });
+
+    // Set looping options.
+    p.loopMajor = (p.sizeMajor - 1) / 16384 + 1;
+    p.loopMinor = spec.loopMinor;
+    p.loopX = spec.loopX;
+    p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1;
+    p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1;
+
+    // Compute grid size.
+    dim3 blockSize, gridSize;
+    if (spec.tileOutW < 0) // large
+    {
+        blockSize = dim3(4, 32, 1);
+        gridSize = dim3(
+            ((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor,
+            (p.outSize.x - 1) / (blockSize.y * p.loopX) + 1,
+            p.launchMajor);
+    }
+    else // small
+    {
+        blockSize = dim3(256, 1, 1);
+        gridSize = dim3(
+            ((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor,
+            (p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1,
+            p.launchMajor);
+    }
+
+    // Launch CUDA kernel.
+    void* args[] = {&p};
+    AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
+    return y;
+}
+
+//------------------------------------------------------------------------
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
+{
+    m.def("upfirdn2d", &upfirdn2d);
+}
+
+//------------------------------------------------------------------------
diff --git a/third_party/stylegan2_official_ops/upfirdn2d.cu b/third_party/stylegan2_official_ops/upfirdn2d.cu
new file mode 100644
index 0000000000000000000000000000000000000000..ebdd9879f4bb16fc57a23cbc81f9de8ef54e4916
--- /dev/null
+++ b/third_party/stylegan2_official_ops/upfirdn2d.cu
@@ -0,0 +1,350 @@
+// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+//
+// NVIDIA CORPORATION and its licensors retain all intellectual property
+// and proprietary rights in and to this software, related documentation
+// and any modifications thereto. Any use, reproduction, disclosure or
+// distribution of this software and related documentation without an express
+// license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+#include <c10/util/Half.h>
+#include "upfirdn2d.h"
+
+//------------------------------------------------------------------------
+// Helpers.
+
+template <class T> struct InternalType;
+template <> struct InternalType<double>    { typedef double scalar_t; };
+template <> struct InternalType<float>     { typedef float  scalar_t; };
+template <> struct InternalType<c10::Half> { typedef float  scalar_t; };
+
+static __device__ __forceinline__ int floor_div(int a, int b)
+{
+    int t = 1 - a / b;
+    return (a + t * b) / b - t;
+}
+
+//------------------------------------------------------------------------
+// Generic CUDA implementation for large filters.
+
+template <class T> static __global__ void upfirdn2d_kernel_large(upfirdn2d_kernel_params p)
+{
+    typedef typename InternalType<T>::scalar_t scalar_t;
+
+    // Calculate thread index.
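+    // The x thread index packs (output row, minor/channel slice); blockIdx.y
+    // walks output columns in chunks of p.loopX, and blockIdx.z walks the
+    // major (batch) dimension, with each thread accumulating one output
+    // sample at a time directly from global memory.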
+ int minorBase = blockIdx.x * blockDim.x + threadIdx.x; + int outY = minorBase / p.launchMinor; + minorBase -= outY * p.launchMinor; + int outXBase = blockIdx.y * p.loopX * blockDim.y + threadIdx.y; + int majorBase = blockIdx.z * p.loopMajor; + if (outXBase >= p.outSize.x | outY >= p.outSize.y | majorBase >= p.sizeMajor) + return; + + // Setup Y receptive field. + int midY = outY * p.down.y + p.up.y - 1 - p.pad0.y; + int inY = min(max(floor_div(midY, p.up.y), 0), p.inSize.y); + int h = min(max(floor_div(midY + p.filterSize.y, p.up.y), 0), p.inSize.y) - inY; + int filterY = midY + p.filterSize.y - (inY + 1) * p.up.y; + if (p.flip) + filterY = p.filterSize.y - 1 - filterY; + + // Loop over major, minor, and X. + for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++) + for (int minorIdx = 0, minor = minorBase; minorIdx < p.loopMinor & minor < p.sizeMinor; minorIdx++, minor += p.launchMinor) + { + int nc = major * p.sizeMinor + minor; + int n = nc / p.inSize.z; + int c = nc - n * p.inSize.z; + for (int loopX = 0, outX = outXBase; loopX < p.loopX & outX < p.outSize.x; loopX++, outX += blockDim.y) + { + // Setup X receptive field. + int midX = outX * p.down.x + p.up.x - 1 - p.pad0.x; + int inX = min(max(floor_div(midX, p.up.x), 0), p.inSize.x); + int w = min(max(floor_div(midX + p.filterSize.x, p.up.x), 0), p.inSize.x) - inX; + int filterX = midX + p.filterSize.x - (inX + 1) * p.up.x; + if (p.flip) + filterX = p.filterSize.x - 1 - filterX; + + // Initialize pointers. + const T* xp = &((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w]; + const float* fp = &p.f[filterX * p.filterStride.x + filterY * p.filterStride.y]; + int filterStepX = ((p.flip) ? p.up.x : -p.up.x) * p.filterStride.x; + int filterStepY = ((p.flip) ? p.up.y : -p.up.y) * p.filterStride.y; + + // Inner loop. + scalar_t v = 0; + for (int y = 0; y < h; y++) + { + for (int x = 0; x < w; x++) + { + v += (scalar_t)(*xp) * (scalar_t)(*fp); + xp += p.inStride.x; + fp += filterStepX; + } + xp += p.inStride.y - w * p.inStride.x; + fp += filterStepY - w * filterStepX; + } + + // Store result. + v *= p.gain; + ((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v; + } + } +} + +//------------------------------------------------------------------------ +// Specialized CUDA implementation for small filters. + +template +static __global__ void upfirdn2d_kernel_small(upfirdn2d_kernel_params p) +{ + typedef typename InternalType::scalar_t scalar_t; + const int tileInW = ((tileOutW - 1) * downx + filterW - 1) / upx + 1; + const int tileInH = ((tileOutH - 1) * downy + filterH - 1) / upy + 1; + __shared__ volatile scalar_t sf[filterH][filterW]; + __shared__ volatile scalar_t sx[tileInH][tileInW][loopMinor]; + + // Calculate tile index. + int minorBase = blockIdx.x; + int tileOutY = minorBase / p.launchMinor; + minorBase -= tileOutY * p.launchMinor; + minorBase *= loopMinor; + tileOutY *= tileOutH; + int tileOutXBase = blockIdx.y * p.loopX * tileOutW; + int majorBase = blockIdx.z * p.loopMajor; + if (tileOutXBase >= p.outSize.x | tileOutY >= p.outSize.y | majorBase >= p.sizeMajor) + return; + + // Load filter (flipped). + for (int tapIdx = threadIdx.x; tapIdx < filterH * filterW; tapIdx += blockDim.x) + { + int fy = tapIdx / filterW; + int fx = tapIdx - fy * filterW; + scalar_t v = 0; + if (fx < p.filterSize.x & fy < p.filterSize.y) + { + int ffx = (p.flip) ? 
fx : p.filterSize.x - 1 - fx; + int ffy = (p.flip) ? fy : p.filterSize.y - 1 - fy; + v = (scalar_t)p.f[ffx * p.filterStride.x + ffy * p.filterStride.y]; + } + sf[fy][fx] = v; + } + + // Loop over major and X. + for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++) + { + int baseNC = major * p.sizeMinor + minorBase; + int n = baseNC / p.inSize.z; + int baseC = baseNC - n * p.inSize.z; + for (int loopX = 0, tileOutX = tileOutXBase; loopX < p.loopX & tileOutX < p.outSize.x; loopX++, tileOutX += tileOutW) + { + // Load input pixels. + int tileMidX = tileOutX * downx + upx - 1 - p.pad0.x; + int tileMidY = tileOutY * downy + upy - 1 - p.pad0.y; + int tileInX = floor_div(tileMidX, upx); + int tileInY = floor_div(tileMidY, upy); + __syncthreads(); + for (int inIdx = threadIdx.x; inIdx < tileInH * tileInW * loopMinor; inIdx += blockDim.x) + { + int relC = inIdx; + int relInX = relC / loopMinor; + int relInY = relInX / tileInW; + relC -= relInX * loopMinor; + relInX -= relInY * tileInW; + int c = baseC + relC; + int inX = tileInX + relInX; + int inY = tileInY + relInY; + scalar_t v = 0; + if (inX >= 0 & inY >= 0 & inX < p.inSize.x & inY < p.inSize.y & c < p.inSize.z) + v = (scalar_t)((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w]; + sx[relInY][relInX][relC] = v; + } + + // Loop over output pixels. + __syncthreads(); + for (int outIdx = threadIdx.x; outIdx < tileOutH * tileOutW * loopMinor; outIdx += blockDim.x) + { + int relC = outIdx; + int relOutX = relC / loopMinor; + int relOutY = relOutX / tileOutW; + relC -= relOutX * loopMinor; + relOutX -= relOutY * tileOutW; + int c = baseC + relC; + int outX = tileOutX + relOutX; + int outY = tileOutY + relOutY; + + // Setup receptive field. + int midX = tileMidX + relOutX * downx; + int midY = tileMidY + relOutY * downy; + int inX = floor_div(midX, upx); + int inY = floor_div(midY, upy); + int relInX = inX - tileInX; + int relInY = inY - tileInY; + int filterX = (inX + 1) * upx - midX - 1; // flipped + int filterY = (inY + 1) * upy - midY - 1; // flipped + + // Inner loop. + if (outX < p.outSize.x & outY < p.outSize.y & c < p.outSize.z) + { + scalar_t v = 0; + #pragma unroll + for (int y = 0; y < filterH / upy; y++) + #pragma unroll + for (int x = 0; x < filterW / upx; x++) + v += sx[relInY + y][relInX + x][relC] * sf[filterY + y * upy][filterX + x * upx]; + v *= p.gain; + ((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v; + } + } + } + } +} + +//------------------------------------------------------------------------ +// CUDA kernel selection. 
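+//
+// The layout test (`p.inStride.z == 1` means channels_last) together with the
+// up/down factors and the filter extent decides which specialization to
+// launch; later, tighter matches below deliberately overwrite earlier ones,
+// and the generic "large" kernel remains the fallback.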
+
+template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p)
+{
+    int s = p.inStride.z, fx = p.filterSize.x, fy = p.filterSize.y;
+
+    upfirdn2d_kernel_spec spec = {(void*)upfirdn2d_kernel_large<T>, -1,-1,1, 4}; // contiguous
+    if (s == 1) spec = {(void*)upfirdn2d_kernel_large<T>, -1,-1,4, 1}; // channels_last
+
+    if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // contiguous
+    {
+        if (fx <= 7  && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 7,7,  64,16,1>,  64,16,1,  1};
+        if (fx <= 6  && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 6,6,  64,16,1>,  64,16,1,  1};
+        if (fx <= 5  && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 5,5,  64,16,1>,  64,16,1,  1};
+        if (fx <= 4  && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4,  64,16,1>,  64,16,1,  1};
+        if (fx <= 3  && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 3,3,  64,16,1>,  64,16,1,  1};
+        if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 24,1, 128,8,1>,  128,8,1,  1};
+        if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 20,1, 128,8,1>,  128,8,1,  1};
+        if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,1, 128,8,1>,  128,8,1,  1};
+        if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 12,1, 128,8,1>,  128,8,1,  1};
+        if (fx <= 8  && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 8,1,  128,8,1>,  128,8,1,  1};
+        if (fx <= 1  && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,24, 32,32,1>,  32,32,1,  1};
+        if (fx <= 1  && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,20, 32,32,1>,  32,32,1,  1};
+        if (fx <= 1  && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,16, 32,32,1>,  32,32,1,  1};
+        if (fx <= 1  && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,12, 32,32,1>,  32,32,1,  1};
+        if (fx <= 1  && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,8,  32,32,1>,  32,32,1,  1};
+    }
+    if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // channels_last
+    {
+        if (fx <= 7  && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 7,7,  16,16,8>,  16,16,8,  1};
+        if (fx <= 6  && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 6,6,  16,16,8>,  16,16,8,  1};
+        if (fx <= 5  && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 5,5,  16,16,8>,  16,16,8,  1};
+        if (fx <= 4  && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4,  16,16,8>,  16,16,8,  1};
+        if (fx <= 3  && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 3,3,  16,16,8>,  16,16,8,  1};
+        if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 24,1, 128,1,16>, 128,1,16, 1};
+        if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 20,1, 128,1,16>, 128,1,16, 1};
+        if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,1, 128,1,16>, 128,1,16, 1};
+        if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 12,1, 128,1,16>, 128,1,16, 1};
+        if (fx <= 8  && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 8,1,  128,1,16>, 128,1,16, 1};
+        if (fx <= 1  && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,24, 1,128,16>, 1,128,16, 1};
+        if (fx <= 1  && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,20, 1,128,16>, 1,128,16, 1};
+        if (fx <= 1  && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,16, 1,128,16>, 1,128,16, 1};
+        if (fx <= 1  && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,12, 1,128,16>, 1,128,16, 1};
+        if (fx <= 1  && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,8,  1,128,16>, 1,128,16, 1};
+    }
+    if (s != 1 && p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // contiguous
+    {
+        if (fx <= 8 && fy <= 8) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 8,8, 64,16,1>, 64,16,1, 1};
+        if (fx <= 6 && fy <= 6) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 6,6, 64,16,1>, 64,16,1, 1};
+        if (fx <= 4 && fy <= 4) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 4,4, 64,16,1>, 64,16,1, 1};
+        if (fx <= 2 && fy <= 2) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 2,2, 64,16,1>, 64,16,1, 1};
+    }
+    if (s == 1 && p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // channels_last
+    {
+        if (fx <= 8 && fy <= 8) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 8,8, 16,16,8>, 16,16,8, 1};
+        if (fx <= 6 && fy <= 6) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 6,6, 16,16,8>, 16,16,8, 1};
+        if (fx <= 4 && fy <= 4) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
+        if (fx <= 2 && fy <= 2) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 2,2, 16,16,8>, 16,16,8, 1};
+    }
+    if (s != 1 && p.up.x == 2 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // contiguous
+    {
+        if (fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 24,1, 128,8,1>, 128,8,1, 1};
+        if (fx <= 20 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 20,1, 128,8,1>, 128,8,1, 1};
+        if (fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 16,1, 128,8,1>, 128,8,1, 1};
+        if (fx <= 12 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 12,1, 128,8,1>, 128,8,1, 1};
+        if (fx <= 8  && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 8,1,  128,8,1>, 128,8,1, 1};
+    }
+    if (s == 1 && p.up.x == 2 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // channels_last
+    {
+        if (fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 24,1, 128,1,16>, 128,1,16, 1};
+        if (fx <= 20 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 20,1, 128,1,16>, 128,1,16, 1};
+        if (fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 16,1, 128,1,16>, 128,1,16, 1};
+        if (fx <= 12 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 12,1, 128,1,16>, 128,1,16, 1};
+        if (fx <= 8  && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 8,1,  128,1,16>, 128,1,16, 1};
+    }
+    if (s != 1 && p.up.x == 1 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // contiguous
+    {
+        if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,24, 32,32,1>, 32,32,1, 1};
+        if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,20, 32,32,1>, 32,32,1, 1};
+        if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,16, 32,32,1>, 32,32,1, 1};
+        if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,12, 32,32,1>, 32,32,1, 1};
+        if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,8,  32,32,1>, 32,32,1, 1};
+    }
+    if (s == 1 && p.up.x == 1 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // channels_last
+    {
+        if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,24, 1,128,16>, 1,128,16, 1};
+        if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,20, 1,128,16>, 1,128,16, 1};
+        if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,16, 1,128,16>, 1,128,16, 1};
+        if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,12, 1,128,16>, 1,128,16, 1};
+        if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,8,  1,128,16>, 1,128,16, 1};
+    }
+    if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 2) // contiguous
+    {
+        if (fx <= 8 && fy <= 8) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 8,8, 32,8,1>, 32,8,1, 1};
+        if (fx <= 6 && fy <= 6) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 6,6, 32,8,1>, 32,8,1, 1};
+        if (fx <= 4 && fy <= 4) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 4,4, 32,8,1>, 32,8,1, 1};
+        if (fx <= 2 && fy <= 2) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 2,2, 32,8,1>, 32,8,1, 1};
+    }
+    if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 2) // channels_last
+    {
+        if (fx <= 8 && fy <= 8) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 8,8, 8,8,8>, 8,8,8, 1};
+        if (fx <= 6 && fy <= 6) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 6,6, 8,8,8>, 8,8,8, 1};
+        if (fx <= 4 && fy <= 4) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 4,4, 8,8,8>, 8,8,8, 1};
+        if (fx <= 2 && fy <= 2) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 2,2, 8,8,8>, 8,8,8, 1};
+    }
+    if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 1) // contiguous
+    {
+        if (fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 24,1, 64,8,1>, 64,8,1, 1};
+        if (fx <= 20 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 20,1, 64,8,1>, 64,8,1, 1};
+        if (fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 16,1, 64,8,1>, 64,8,1, 1};
+        if (fx <= 12 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 12,1, 64,8,1>, 64,8,1, 1};
+        if (fx <= 8  && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 8,1,  64,8,1>, 64,8,1, 1};
+    }
+    if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 1) // channels_last
+    {
+        if (fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 24,1, 64,1,8>, 64,1,8, 1};
+        if (fx <= 20 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 20,1, 64,1,8>, 64,1,8, 1};
+        if (fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 16,1, 64,1,8>, 64,1,8, 1};
+        if (fx <= 12 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 12,1, 64,1,8>, 64,1,8, 1};
+        if (fx <= 8  && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 8,1,  64,1,8>, 64,1,8, 1};
+    }
+    if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 2) // contiguous
+    {
+        if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,24, 32,16,1>, 32,16,1, 1};
+        if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,20, 32,16,1>, 32,16,1, 1};
+        if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,16, 32,16,1>, 32,16,1, 1};
+        if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,12, 32,16,1>, 32,16,1, 1};
+        if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,8,  32,16,1>, 32,16,1, 1};
+    }
+    if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 2) // channels_last
+    {
+        if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,24, 1,64,8>, 1,64,8, 1};
+        if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,20, 1,64,8>, 1,64,8, 1};
+        if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,16, 1,64,8>, 1,64,8, 1};
+        if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,12, 1,64,8>, 1,64,8, 1};
+        if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,8,  1,64,8>, 1,64,8, 1};
+    }
+    return spec;
+}
+
+//------------------------------------------------------------------------
+// Template specializations.
+
+template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<double>   (const upfirdn2d_kernel_params& p);
+template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<float>    (const upfirdn2d_kernel_params& p);
+template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<c10::Half>(const upfirdn2d_kernel_params& p);
+
+//------------------------------------------------------------------------
diff --git a/third_party/stylegan2_official_ops/upfirdn2d.h b/third_party/stylegan2_official_ops/upfirdn2d.h
new file mode 100644
index 0000000000000000000000000000000000000000..c9e2032bcac9d2abde7a75eea4d812da348afadd
--- /dev/null
+++ b/third_party/stylegan2_official_ops/upfirdn2d.h
@@ -0,0 +1,59 @@
+// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+//
+// NVIDIA CORPORATION and its licensors retain all intellectual property
+// and proprietary rights in and to this software, related documentation
+// and any modifications thereto. Any use, reproduction, disclosure or
+// distribution of this software and related documentation without an express
+// license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+#include <cuda_runtime.h>
+
+//------------------------------------------------------------------------
+// CUDA kernel parameters.
+
+struct upfirdn2d_kernel_params
+{
+    const void* x;
+    const float* f;
+    void* y;
+
+    int2 up;
+    int2 down;
+    int2 pad0;
+    int flip;
+    float gain;
+
+    int4 inSize; // [width, height, channel, batch]
+    int4 inStride;
+    int2 filterSize; // [width, height]
+    int2 filterStride;
+    int4 outSize; // [width, height, channel, batch]
+    int4 outStride;
+    int sizeMinor;
+    int sizeMajor;
+
+    int loopMinor;
+    int loopMajor;
+    int loopX;
+    int launchMinor;
+    int launchMajor;
+};
+
+//------------------------------------------------------------------------
+// CUDA kernel specialization.
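+//
+// A spec couples the chosen kernel pointer with the tile/loop configuration
+// that the launcher in `upfirdn2d.cpp` uses to derive block and grid sizes.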
+
+struct upfirdn2d_kernel_spec
+{
+    void* kernel;
+    int tileOutW;
+    int tileOutH;
+    int loopMinor;
+    int loopX;
+};
+
+//------------------------------------------------------------------------
+// CUDA kernel selection.
+
+template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p);
+
+//------------------------------------------------------------------------
diff --git a/third_party/stylegan2_official_ops/upfirdn2d.py b/third_party/stylegan2_official_ops/upfirdn2d.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d154bc17430366f375f6e7263854f7063285250
--- /dev/null
+++ b/third_party/stylegan2_official_ops/upfirdn2d.py
@@ -0,0 +1,401 @@
+# python3.7
+
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Custom operators for efficient resampling of 2D images.
+
+`upfirdn` means executing upsampling, FIR filtering, downsampling in sequence.
+
+Please refer to https://github.com/NVlabs/stylegan2-ada-pytorch
+"""
+
+# pylint: disable=line-too-long
+# pylint: disable=missing-class-docstring
+# pylint: disable=global-variable-not-assigned
+# pylint: disable=bare-except
+
+import os
+import warnings
+import traceback
+import numpy as np
+import torch
+
+from . import custom_ops
+from . import misc
+from . import conv2d_gradfix
+
+#----------------------------------------------------------------------------
+
+_inited = False
+_plugin = None
+
+def _init():
+    global _inited, _plugin
+    if not _inited:
+        _inited = True
+        sources = ['upfirdn2d.cpp', 'upfirdn2d.cu']
+        sources = [os.path.join(os.path.dirname(__file__), s) for s in sources]
+        try:
+            _plugin = custom_ops.get_plugin('upfirdn2d_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math'])
+        except:
+            warnings.warn('Failed to build CUDA kernels for upfirdn2d. Falling back to slow reference implementation. Details:\n\n' + traceback.format_exc())
+    return _plugin is not None
+
+def _parse_scaling(scaling):
+    if isinstance(scaling, int):
+        scaling = [scaling, scaling]
+    assert isinstance(scaling, (list, tuple))
+    assert all(isinstance(x, int) for x in scaling)
+    sx, sy = scaling
+    assert sx >= 1 and sy >= 1
+    return sx, sy
+
+def _parse_padding(padding):
+    if isinstance(padding, int):
+        padding = [padding, padding]
+    assert isinstance(padding, (list, tuple))
+    assert all(isinstance(x, int) for x in padding)
+    if len(padding) == 2:
+        padx, pady = padding
+        padding = [padx, padx, pady, pady]
+    padx0, padx1, pady0, pady1 = padding
+    return padx0, padx1, pady0, pady1
+
+def _get_filter_size(f):
+    if f is None:
+        return 1, 1
+    assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
+    fw = f.shape[-1]
+    fh = f.shape[0]
+    with misc.suppress_tracer_warnings():
+        fw = int(fw)
+        fh = int(fh)
+    misc.assert_shape(f, [fh, fw][:f.ndim])
+    assert fw >= 1 and fh >= 1
+    return fw, fh
+
+#----------------------------------------------------------------------------
+
+def setup_filter(f, device=torch.device('cpu'), normalize=True, flip_filter=False, gain=1, separable=None):
+    r"""Convenience function to setup 2D FIR filter for `upfirdn2d()`.
+ + Args: + f: Torch tensor, numpy array, or python list of the shape + `[filter_height, filter_width]` (non-separable), + `[filter_taps]` (separable), + `[]` (impulse), or + `None` (identity). + device: Result device (default: cpu). + normalize: Normalize the filter so that it retains the magnitude + for constant input signal (DC)? (default: True). + flip_filter: Flip the filter? (default: False). + gain: Overall scaling factor for signal magnitude (default: 1). + separable: Return a separable filter? (default: select automatically). + + Returns: + Float32 tensor of the shape + `[filter_height, filter_width]` (non-separable) or + `[filter_taps]` (separable). + """ + # Validate. + if f is None: + f = 1 + f = torch.as_tensor(f, dtype=torch.float32) + assert f.ndim in [0, 1, 2] + assert f.numel() > 0 + if f.ndim == 0: + f = f[np.newaxis] + + # Separable? + if separable is None: + separable = (f.ndim == 1 and f.numel() >= 8) + if f.ndim == 1 and not separable: + f = f.ger(f) + assert f.ndim == (1 if separable else 2) + + # Apply normalize, flip, gain, and device. + if normalize: + f /= f.sum() + if flip_filter: + f = f.flip(list(range(f.ndim))) + f = f * (gain ** (f.ndim / 2)) + f = f.to(device=device) + return f + +#---------------------------------------------------------------------------- + +def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl='cuda'): + r"""Pad, upsample, filter, and downsample a batch of 2D images. + + Performs the following sequence of operations for each channel: + + 1. Upsample the image by inserting N-1 zeros after each pixel (`up`). + + 2. Pad the image with the specified number of zeros on each side (`padding`). + Negative padding corresponds to cropping the image. + + 3. Convolve the image with the specified 2D FIR filter (`f`), shrinking it + so that the footprint of all output pixels lies within the input image. + + 4. Downsample the image by keeping every Nth pixel (`down`). + + This sequence of operations bears close resemblance to scipy.signal.upfirdn(). + The fused op is considerably more efficient than performing the same calculation + using standard PyTorch ops. It supports gradients of arbitrary order. + + Args: + x: Float32/float64/float16 input tensor of the shape + `[batch_size, num_channels, in_height, in_width]`. + f: Float32 FIR filter of the shape + `[filter_height, filter_width]` (non-separable), + `[filter_taps]` (separable), or + `None` (identity). + up: Integer upsampling factor. Can be a single int or a list/tuple + `[x, y]` (default: 1). + down: Integer downsampling factor. Can be a single int or a list/tuple + `[x, y]` (default: 1). + padding: Padding with respect to the upsampled image. Can be a single number + or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` + (default: 0). + flip_filter: False = convolution, True = correlation (default: False). + gain: Overall scaling factor for signal magnitude (default: 1). + impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). + + Returns: + Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. 
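+
+    Example (illustrative; `x` stands for any `[batch_size, num_channels, H, W]`
+    float tensor, and the explicit pads below are exactly what `upsample2d()`
+    would derive for a 4-tap filter):
+
+        f = setup_filter([1, 3, 3, 1])                   # binomial lowpass filter
+        y = upfirdn2d(x, f, up=2, padding=[2, 1, 2, 1])  # 2x upsampled output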
+ """ + assert isinstance(x, torch.Tensor) + assert impl in ['ref', 'cuda'] + if impl == 'cuda' and x.device.type == 'cuda' and _init(): + return _upfirdn2d_cuda(up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain).apply(x, f) + return _upfirdn2d_ref(x, f, up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain) + +#---------------------------------------------------------------------------- + +@misc.profiled_function +def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1): + """Slow reference implementation of `upfirdn2d()` using standard PyTorch ops. + """ + # Validate arguments. + assert isinstance(x, torch.Tensor) and x.ndim == 4 + if f is None: + f = torch.ones([1, 1], dtype=torch.float32, device=x.device) + assert isinstance(f, torch.Tensor) and f.ndim in [1, 2] + assert f.dtype == torch.float32 and not f.requires_grad + batch_size, num_channels, in_height, in_width = x.shape + upx, upy = _parse_scaling(up) + downx, downy = _parse_scaling(down) + padx0, padx1, pady0, pady1 = _parse_padding(padding) + + # Upsample by inserting zeros. + x = x.reshape([batch_size, num_channels, in_height, 1, in_width, 1]) + x = torch.nn.functional.pad(x, [0, upx - 1, 0, 0, 0, upy - 1]) + x = x.reshape([batch_size, num_channels, in_height * upy, in_width * upx]) + + # Pad or crop. + x = torch.nn.functional.pad(x, [max(padx0, 0), max(padx1, 0), max(pady0, 0), max(pady1, 0)]) + x = x[:, :, max(-pady0, 0) : x.shape[2] - max(-pady1, 0), max(-padx0, 0) : x.shape[3] - max(-padx1, 0)] + + # Setup filter. + f = f * (gain ** (f.ndim / 2)) + f = f.to(x.dtype) + if not flip_filter: + f = f.flip(list(range(f.ndim))) + + # Convolve with the filter. + f = f[np.newaxis, np.newaxis].repeat([num_channels, 1] + [1] * f.ndim) + if f.ndim == 4: + x = conv2d_gradfix.conv2d(input=x, weight=f, groups=num_channels) + else: + x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(2), groups=num_channels) + x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(3), groups=num_channels) + + # Downsample by throwing away pixels. + x = x[:, :, ::downy, ::downx] + return x + +#---------------------------------------------------------------------------- + +_upfirdn2d_cuda_cache = dict() + +def _upfirdn2d_cuda(up=1, down=1, padding=0, flip_filter=False, gain=1): + """Fast CUDA implementation of `upfirdn2d()` using custom ops. + """ + # Parse arguments. + upx, upy = _parse_scaling(up) + downx, downy = _parse_scaling(down) + padx0, padx1, pady0, pady1 = _parse_padding(padding) + + # Lookup from cache. + key = (upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain) + if key in _upfirdn2d_cuda_cache: + return _upfirdn2d_cuda_cache[key] + + # Forward op. 
+ class Upfirdn2dCuda(torch.autograd.Function): + @staticmethod + def forward(ctx, x, f): # pylint: disable=arguments-differ + assert isinstance(x, torch.Tensor) and x.ndim == 4 + if f is None: + f = torch.ones([1, 1], dtype=torch.float32, device=x.device) + assert isinstance(f, torch.Tensor) and f.ndim in [1, 2] + y = x + if f.ndim == 2: + y = _plugin.upfirdn2d(y, f, upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain) + else: + y = _plugin.upfirdn2d(y, f.unsqueeze(0), upx, 1, downx, 1, padx0, padx1, 0, 0, flip_filter, np.sqrt(gain)) + y = _plugin.upfirdn2d(y, f.unsqueeze(1), 1, upy, 1, downy, 0, 0, pady0, pady1, flip_filter, np.sqrt(gain)) + ctx.save_for_backward(f) + ctx.x_shape = x.shape + return y + + @staticmethod + def backward(ctx, dy): # pylint: disable=arguments-differ + f, = ctx.saved_tensors + _, _, ih, iw = ctx.x_shape + _, _, oh, ow = dy.shape + fw, fh = _get_filter_size(f) + p = [ + fw - padx0 - 1, + iw * upx - ow * downx + padx0 - upx + 1, + fh - pady0 - 1, + ih * upy - oh * downy + pady0 - upy + 1, + ] + dx = None + df = None + + if ctx.needs_input_grad[0]: + dx = _upfirdn2d_cuda(up=down, down=up, padding=p, flip_filter=(not flip_filter), gain=gain).apply(dy, f) + + assert not ctx.needs_input_grad[1] + return dx, df + + # Add to cache. + _upfirdn2d_cuda_cache[key] = Upfirdn2dCuda + return Upfirdn2dCuda + +#---------------------------------------------------------------------------- + +def filter2d(x, f, padding=0, flip_filter=False, gain=1, impl='cuda'): + r"""Filter a batch of 2D images using the given 2D FIR filter. + + By default, the result is padded so that its shape matches the input. + User-specified padding is applied on top of that, with negative values + indicating cropping. Pixels outside the image are assumed to be zero. + + Args: + x: Float32/float64/float16 input tensor of the shape + `[batch_size, num_channels, in_height, in_width]`. + f: Float32 FIR filter of the shape + `[filter_height, filter_width]` (non-separable), + `[filter_taps]` (separable), or + `None` (identity). + padding: Padding with respect to the output. Can be a single number or a + list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` + (default: 0). + flip_filter: False = convolution, True = correlation (default: False). + gain: Overall scaling factor for signal magnitude (default: 1). + impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). + + Returns: + Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. + """ + padx0, padx1, pady0, pady1 = _parse_padding(padding) + fw, fh = _get_filter_size(f) + p = [ + padx0 + fw // 2, + padx1 + (fw - 1) // 2, + pady0 + fh // 2, + pady1 + (fh - 1) // 2, + ] + return upfirdn2d(x, f, padding=p, flip_filter=flip_filter, gain=gain, impl=impl) + +#---------------------------------------------------------------------------- + +def upsample2d(x, f, up=2, padding=0, flip_filter=False, gain=1, impl='cuda'): + r"""Upsample a batch of 2D images using the given 2D FIR filter. + + By default, the result is padded so that its shape is a multiple of the input. + User-specified padding is applied on top of that, with negative values + indicating cropping. Pixels outside the image are assumed to be zero. + + Args: + x: Float32/float64/float16 input tensor of the shape + `[batch_size, num_channels, in_height, in_width]`. + f: Float32 FIR filter of the shape + `[filter_height, filter_width]` (non-separable), + `[filter_taps]` (separable), or + `None` (identity). + up: Integer upsampling factor. 
Can be a single int or a list/tuple
+                     `[x, y]` (default: 2).
+        padding:     Padding with respect to the output. Can be a single number or a
+                     list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
+                     (default: 0).
+        flip_filter: False = convolution, True = correlation (default: False).
+        gain:        Overall scaling factor for signal magnitude (default: 1).
+        impl:        Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
+
+    Returns:
+        Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
+    """
+    upx, upy = _parse_scaling(up)
+    padx0, padx1, pady0, pady1 = _parse_padding(padding)
+    fw, fh = _get_filter_size(f)
+    p = [
+        padx0 + (fw + upx - 1) // 2,
+        padx1 + (fw - upx) // 2,
+        pady0 + (fh + upy - 1) // 2,
+        pady1 + (fh - upy) // 2,
+    ]
+    return upfirdn2d(x, f, up=up, padding=p, flip_filter=flip_filter, gain=gain*upx*upy, impl=impl)
+
+#----------------------------------------------------------------------------
+
+def downsample2d(x, f, down=2, padding=0, flip_filter=False, gain=1, impl='cuda'):
+    r"""Downsample a batch of 2D images using the given 2D FIR filter.
+
+    By default, the result is padded so that its shape is a fraction of the input.
+    User-specified padding is applied on top of that, with negative values
+    indicating cropping. Pixels outside the image are assumed to be zero.
+
+    Args:
+        x:           Float32/float64/float16 input tensor of the shape
+                     `[batch_size, num_channels, in_height, in_width]`.
+        f:           Float32 FIR filter of the shape
+                     `[filter_height, filter_width]` (non-separable),
+                     `[filter_taps]` (separable), or
+                     `None` (identity).
+        down:        Integer downsampling factor. Can be a single int or a list/tuple
+                     `[x, y]` (default: 2).
+        padding:     Padding with respect to the input. Can be a single number or a
+                     list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
+                     (default: 0).
+        flip_filter: False = convolution, True = correlation (default: False).
+        gain:        Overall scaling factor for signal magnitude (default: 1).
+        impl:        Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
+
+    Returns:
+        Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
+    """
+    downx, downy = _parse_scaling(down)
+    padx0, padx1, pady0, pady1 = _parse_padding(padding)
+    fw, fh = _get_filter_size(f)
+    p = [
+        padx0 + (fw - downx + 1) // 2,
+        padx1 + (fw - downx) // 2,
+        pady0 + (fh - downy + 1) // 2,
+        pady1 + (fh - downy) // 2,
+    ]
+    return upfirdn2d(x, f, down=down, padding=p, flip_filter=flip_filter, gain=gain, impl=impl)
+
+#----------------------------------------------------------------------------
+
+# pylint: enable=line-too-long
+# pylint: enable=missing-class-docstring
+# pylint: enable=global-variable-not-assigned
+# pylint: enable=bare-except
diff --git a/third_party/stylegan3_official_ops/README.md b/third_party/stylegan3_official_ops/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..417b50ba3dde490f4a4c5c6dfc6afbc28ba640d0
--- /dev/null
+++ b/third_party/stylegan3_official_ops/README.md
@@ -0,0 +1,30 @@
+# Operators for StyleGAN3
+
+All files in this directory are borrowed from the repository [stylegan3](https://github.com/NVlabs/stylegan3). Basically, these files implement customized operators, which are faster than the native operators from PyTorch, especially for second-derivative computation, including:
+
+- `bias_act.bias_act()`: Fuse adding bias and then performing activation as one operator.
+- `upfirdn2d.setup_filter()`: Set up the kernel used for filtering.
+- `upfirdn2d.filter2d()`: Filtering a 2D feature map with a given kernel.
+- `upfirdn2d.upsample2d()`: Upsampling a 2D feature map.
+- `upfirdn2d.downsample2d()`: Downsampling a 2D feature map.
+- `upfirdn2d.upfirdn2d()`: Upsampling, filtering, and then downsampling a 2D feature map.
+- `filtered_lrelu.filtered_lrelu()`: Leaky ReLU layer, wrapped with upsampling and downsampling for anti-aliasing.
+- `conv2d_gradfix.conv2d()`: Convolutional layer, supporting arbitrarily high order gradients and fixing the gradient when computing a penalty.
+- `conv2d_gradfix.conv_transpose2d()`: Transposed convolutional layer, supporting arbitrarily high order gradients and fixing the gradient when computing a penalty.
+- `conv2d_resample.conv2d_resample()`: Wraps `upfirdn2d()` and `conv2d()` (or `conv_transpose2d()`). This is not used in our network implementation (*i.e.*, `models/stylegan2_generator.py` and `models/stylegan2_discriminator.py`).
+
+We make the following slight modifications beyond disabling some lint warnings:
+
+- Line 24 of file `misc.py`: Use `EasyDict` from module `easydict` to replace the one from `dnnlib` of [stylegan3](https://github.com/NVlabs/stylegan3).
+- Line 36 of file `custom_ops.py`: Disable log message when setting up customized operators.
+- Line 54/109 of file `custom_ops.py`: Add necessary CUDA compiler path. (***NOTE**: If your CUDA binaries are not located at `/usr/local/cuda/bin`, please specify the path in function `_find_compiler_bindir_posix()`.*)
+- Line 21 of file `bias_act.py`: Use `EasyDict` from module `easydict` to replace the one from `dnnlib` of [stylegan3](https://github.com/NVlabs/stylegan3).
+- Lines 162-165 of file `filtered_lrelu.py`: Change some implementations in `_filtered_lrelu_ref()` to `ref`.
+- Line 31 of file `grid_sample_gradfix.py`: Enable the customized grid sampling operator by default.
+- Line 35 of file `grid_sample_gradfix.py`: Use `impl` to disable the customized grid sample operator.
+- Line 34 of file `conv2d_gradfix.py`: Enable the customized convolution operators by default.
+- Line 48/53 of file `conv2d_gradfix.py`: Use `impl` to disable the customized convolution operators.
+- Line 36/53 of file `conv2d_resample.py`: Use `impl` to disable the customized convolution operators.
+- Line 23 of file `fma.py`: Use `impl` to disable the customized add-multiply operator.
+
+Please use `ref` or `cuda` to choose which implementation to use. `ref` refers to the native PyTorch operators, while `cuda` refers to the customized operators from the official repository. `cuda` is used by default.
diff --git a/third_party/stylegan3_official_ops/__init__.py b/third_party/stylegan3_official_ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/third_party/stylegan3_official_ops/bias_act.cpp b/third_party/stylegan3_official_ops/bias_act.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..3adaeee2ae44e96655d354c2bdfb81de8ebfe6c6
--- /dev/null
+++ b/third_party/stylegan3_official_ops/bias_act.cpp
@@ -0,0 +1,99 @@
+// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+//
+// NVIDIA CORPORATION and its licensors retain all intellectual property
+// and proprietary rights in and to this software, related documentation
+// and any modifications thereto. Any use, reproduction, disclosure or
+// distribution of this software and related documentation without an express
+// license agreement from NVIDIA CORPORATION is strictly prohibited.
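+
+// C++ glue for the fused bias-and-activation CUDA op; the Python-side wrapper
+// in `bias_act.py` is the intended entry point and documents the semantics.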
+
+#include <torch/extension.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+#include "bias_act.h"
+
+//------------------------------------------------------------------------
+
+static bool has_same_layout(torch::Tensor x, torch::Tensor y)
+{
+    if (x.dim() != y.dim())
+        return false;
+    for (int64_t i = 0; i < x.dim(); i++)
+    {
+        if (x.size(i) != y.size(i))
+            return false;
+        if (x.size(i) >= 2 && x.stride(i) != y.stride(i))
+            return false;
+    }
+    return true;
+}
+
+//------------------------------------------------------------------------
+
+static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp)
+{
+    // Validate arguments.
+    TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
+    TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x");
+    TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x");
+    TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x");
+    TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same dtype and device as x");
+    TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
+    TORCH_CHECK(b.dim() == 1, "b must have rank 1");
+    TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds");
+    TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements");
+    TORCH_CHECK(grad >= 0, "grad must be non-negative");
+
+    // Validate layout.
+    TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense");
+    TORCH_CHECK(b.is_contiguous(), "b must be contiguous");
+    TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x");
+    TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x");
+    TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x");
+
+    // Create output tensor.
+    const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
+    torch::Tensor y = torch::empty_like(x);
+    TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x");
+
+    // Initialize CUDA kernel parameters.
+    bias_act_kernel_params p;
+    p.x = x.data_ptr();
+    p.b = (b.numel()) ? b.data_ptr() : NULL;
+    p.xref = (xref.numel()) ? xref.data_ptr() : NULL;
+    p.yref = (yref.numel()) ? yref.data_ptr() : NULL;
+    p.dy = (dy.numel()) ? dy.data_ptr() : NULL;
+    p.y = y.data_ptr();
+    p.grad = grad;
+    p.act = act;
+    p.alpha = alpha;
+    p.gain = gain;
+    p.clamp = clamp;
+    p.sizeX = (int)x.numel();
+    p.sizeB = (int)b.numel();
+    p.stepB = (b.numel()) ? (int)x.stride(dim) : 1;
+
+    // Choose CUDA kernel.
+    void* kernel;
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "bias_act_cuda", [&]
+    {
+        kernel = choose_bias_act_kernel<scalar_t>(p);
+    });
+    TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func");
+
+    // Launch CUDA kernel.
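+    // Each thread handles p.loopX elements spaced blockDim.x apart, so one
+    // 128-thread block covers loopX * 128 elements of the flattened tensor.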
+ p.loopX = 4; + int blockSize = 4 * 32; + int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1; + void* args[] = {&p}; + AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream())); + return y; +} + +//------------------------------------------------------------------------ + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + m.def("bias_act", &bias_act); +} + +//------------------------------------------------------------------------ diff --git a/third_party/stylegan3_official_ops/bias_act.cu b/third_party/stylegan3_official_ops/bias_act.cu new file mode 100644 index 0000000000000000000000000000000000000000..ed1d16f14eadd1344939e074ace1375cfd936cea --- /dev/null +++ b/third_party/stylegan3_official_ops/bias_act.cu @@ -0,0 +1,173 @@ +// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#include +#include "bias_act.h" + +//------------------------------------------------------------------------ +// Helpers. + +template struct InternalType; +template <> struct InternalType { typedef double scalar_t; }; +template <> struct InternalType { typedef float scalar_t; }; +template <> struct InternalType { typedef float scalar_t; }; + +//------------------------------------------------------------------------ +// CUDA kernel. + +template +__global__ void bias_act_kernel(bias_act_kernel_params p) +{ + typedef typename InternalType::scalar_t scalar_t; + int G = p.grad; + scalar_t alpha = (scalar_t)p.alpha; + scalar_t gain = (scalar_t)p.gain; + scalar_t clamp = (scalar_t)p.clamp; + scalar_t one = (scalar_t)1; + scalar_t two = (scalar_t)2; + scalar_t expRange = (scalar_t)80; + scalar_t halfExpRange = (scalar_t)40; + scalar_t seluScale = (scalar_t)1.0507009873554804934193349852946; + scalar_t seluAlpha = (scalar_t)1.6732632423543772848170429916717; + + // Loop over elements. + int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x; + for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x) + { + // Load. + scalar_t x = (scalar_t)((const T*)p.x)[xi]; + scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0; + scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0; + scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0; + scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one; + scalar_t yy = (gain != 0) ? yref / gain : 0; + scalar_t y = 0; + + // Apply bias. + ((G == 0) ? x : xref) += b; + + // linear + if (A == 1) + { + if (G == 0) y = x; + if (G == 1) y = x; + } + + // relu + if (A == 2) + { + if (G == 0) y = (x > 0) ? x : 0; + if (G == 1) y = (yy > 0) ? x : 0; + } + + // lrelu + if (A == 3) + { + if (G == 0) y = (x > 0) ? x : x * alpha; + if (G == 1) y = (yy > 0) ? x : x * alpha; + } + + // tanh + if (A == 4) + { + if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); } + if (G == 1) y = x * (one - yy * yy); + if (G == 2) y = x * (one - yy * yy) * (-two * yy); + } + + // sigmoid + if (A == 5) + { + if (G == 0) y = (x < -expRange) ? 
+            if (G == 1) y = x * yy * (one - yy);
+            if (G == 2) y = x * yy * (one - yy) * (one - two * yy);
+        }
+
+        // elu
+        if (A == 6)
+        {
+            if (G == 0) y = (x >= 0) ? x : exp(x) - one;
+            if (G == 1) y = (yy >= 0) ? x : x * (yy + one);
+            if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one);
+        }
+
+        // selu
+        if (A == 7)
+        {
+            if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one);
+            if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha);
+            if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha);
+        }
+
+        // softplus
+        if (A == 8)
+        {
+            if (G == 0) y = (x > expRange) ? x : log(exp(x) + one);
+            if (G == 1) y = x * (one - exp(-yy));
+            if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); }
+        }
+
+        // swish
+        if (A == 9)
+        {
+            if (G == 0)
+                y = (x < -expRange) ? 0 : x / (exp(-x) + one);
+            else
+            {
+                scalar_t c = exp(xref);
+                scalar_t d = c + one;
+                if (G == 1)
+                    y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d);
+                else
+                    y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d);
+                yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain;
+            }
+        }
+
+        // Apply gain.
+        y *= gain * dy;
+
+        // Clamp.
+        if (clamp >= 0)
+        {
+            if (G == 0)
+                y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp;
+            else
+                y = (yref > -clamp & yref < clamp) ? y : 0;
+        }
+
+        // Store.
+        ((T*)p.y)[xi] = (T)y;
+    }
+}
+
+//------------------------------------------------------------------------
+// CUDA kernel selection.
+
+template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p)
+{
+    if (p.act == 1) return (void*)bias_act_kernel<T, 1>;
+    if (p.act == 2) return (void*)bias_act_kernel<T, 2>;
+    if (p.act == 3) return (void*)bias_act_kernel<T, 3>;
+    if (p.act == 4) return (void*)bias_act_kernel<T, 4>;
+    if (p.act == 5) return (void*)bias_act_kernel<T, 5>;
+    if (p.act == 6) return (void*)bias_act_kernel<T, 6>;
+    if (p.act == 7) return (void*)bias_act_kernel<T, 7>;
+    if (p.act == 8) return (void*)bias_act_kernel<T, 8>;
+    if (p.act == 9) return (void*)bias_act_kernel<T, 9>;
+    return NULL;
+}
+
+//------------------------------------------------------------------------
+// Template specializations.
+
+template void* choose_bias_act_kernel<double>   (const bias_act_kernel_params& p);
+template void* choose_bias_act_kernel<float>    (const bias_act_kernel_params& p);
+template void* choose_bias_act_kernel<c10::Half>(const bias_act_kernel_params& p);
+
+//------------------------------------------------------------------------
diff --git a/third_party/stylegan3_official_ops/bias_act.h b/third_party/stylegan3_official_ops/bias_act.h
new file mode 100644
index 0000000000000000000000000000000000000000..60b81c6058d54638a6d74a13046fa388442d767d
--- /dev/null
+++ b/third_party/stylegan3_official_ops/bias_act.h
@@ -0,0 +1,38 @@
+// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+//
+// NVIDIA CORPORATION and its licensors retain all intellectual property
+// and proprietary rights in and to this software, related documentation
+// and any modifications thereto. Any use, reproduction, disclosure or
+// distribution of this software and related documentation without an express
+// license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+//------------------------------------------------------------------------
+// CUDA kernel parameters.
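+//
+// All tensors are passed as raw, flattened pointers; `sizeB` and `stepB`
+// describe how the 1D bias is broadcast over the `sizeX` elements of `x`
+// (the kernel indexes it as `(xi / stepB) % sizeB`).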
+
+struct bias_act_kernel_params
+{
+    const void* x;      // [sizeX]
+    const void* b;      // [sizeB] or NULL
+    const void* xref;   // [sizeX] or NULL
+    const void* yref;   // [sizeX] or NULL
+    const void* dy;     // [sizeX] or NULL
+    void*       y;      // [sizeX]
+
+    int grad;
+    int act;
+    float alpha;
+    float gain;
+    float clamp;
+
+    int sizeX;
+    int sizeB;
+    int stepB;
+    int loopX;
+};
+
+//------------------------------------------------------------------------
+// CUDA kernel selection.
+
+template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p);
+
+//------------------------------------------------------------------------
diff --git a/third_party/stylegan3_official_ops/bias_act.py b/third_party/stylegan3_official_ops/bias_act.py
new file mode 100644
index 0000000000000000000000000000000000000000..c90e4f0fcc22b2eeb0e5b6a10d1d3f700f808e00
--- /dev/null
+++ b/third_party/stylegan3_official_ops/bias_act.py
@@ -0,0 +1,222 @@
+# python3.7
+
+# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Custom ops to fuse bias and activation as one operator, which is efficient.
+
+Please refer to https://github.com/NVlabs/stylegan3
+"""
+
+# pylint: disable=line-too-long
+# pylint: disable=missing-class-docstring
+# pylint: disable=global-statement
+
+import os
+from easydict import EasyDict
+import numpy as np
+import torch
+
+from . import custom_ops
+from . import misc
+
+#----------------------------------------------------------------------------
+
+activation_funcs = {
+    'linear':   EasyDict(func=lambda x, **_: x, def_alpha=0, def_gain=1, cuda_idx=1, ref='', has_2nd_grad=False),
+    'relu':     EasyDict(func=lambda x, **_: torch.nn.functional.relu(x), def_alpha=0, def_gain=np.sqrt(2), cuda_idx=2, ref='y', has_2nd_grad=False),
+    'lrelu':    EasyDict(func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', has_2nd_grad=False),
+    'tanh':     EasyDict(func=lambda x, **_: torch.tanh(x), def_alpha=0, def_gain=1, cuda_idx=4, ref='y', has_2nd_grad=True),
+    'sigmoid':  EasyDict(func=lambda x, **_: torch.sigmoid(x), def_alpha=0, def_gain=1, cuda_idx=5, ref='y', has_2nd_grad=True),
+    'elu':      EasyDict(func=lambda x, **_: torch.nn.functional.elu(x), def_alpha=0, def_gain=1, cuda_idx=6, ref='y', has_2nd_grad=True),
+    'selu':     EasyDict(func=lambda x, **_: torch.nn.functional.selu(x), def_alpha=0, def_gain=1, cuda_idx=7, ref='y', has_2nd_grad=True),
+    'softplus': EasyDict(func=lambda x, **_: torch.nn.functional.softplus(x), def_alpha=0, def_gain=1, cuda_idx=8, ref='y', has_2nd_grad=True),
+    'swish':    EasyDict(func=lambda x, **_: torch.sigmoid(x) * x, def_alpha=0, def_gain=np.sqrt(2), cuda_idx=9, ref='x', has_2nd_grad=True),
+}
+
+#----------------------------------------------------------------------------
+
+_plugin = None
+_null_tensor = torch.empty([0])
+
+def _init():
+    global _plugin
+    if _plugin is None:
+        _plugin = custom_ops.get_plugin(
+            module_name='bias_act_plugin',
+            sources=['bias_act.cpp', 'bias_act.cu'],
+            headers=['bias_act.h'],
+            source_dir=os.path.dirname(__file__),
+            extra_cuda_cflags=['--use_fast_math'],
+        )
+    return True
+
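+# For orientation: with the default `dim=1` and a 4D input, a call such as
+# `bias_act(x, b, act='lrelu')` matches the unfused PyTorch expression below
+# (a sketch; 0.2 and sqrt(2) are the `def_alpha`/`def_gain` defaults listed
+# for 'lrelu' in `activation_funcs` above):
+#
+#   y = torch.nn.functional.leaky_relu(x + b.reshape(1, -1, 1, 1), 0.2) * np.sqrt(2)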
+#----------------------------------------------------------------------------
+
+def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='cuda'):
+    r"""Fused bias and activation function.
+
+    Adds bias `b` to activation tensor `x`, evaluates activation function `act`,
+    and scales the result by `gain`. Each of the steps is optional. In most cases,
+    the fused op is considerably more efficient than performing the same calculation
+    using standard PyTorch ops. It supports first and second order gradients,
+    but not third order gradients.
+
+    Args:
+        x:     Input activation tensor. Can be of any shape.
+        b:     Bias vector, or `None` to disable. Must be a 1D tensor of the same type
+               as `x`. The shape must be known, and it must match the dimension of `x`
+               corresponding to `dim`.
+        dim:   The dimension in `x` corresponding to the elements of `b`.
+               The value of `dim` is ignored if `b` is not specified.
+        act:   Name of the activation function to evaluate, or `"linear"` to disable.
+               Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc.
+               See `activation_funcs` for a full list. `None` is not allowed.
+        alpha: Shape parameter for the activation function, or `None` to use the default.
+        gain:  Scaling factor for the output tensor, or `None` to use default.
+               See `activation_funcs` for the default scaling of each activation function.
+               If unsure, consider specifying 1.
+        clamp: Clamp the output values to `[-clamp, +clamp]`, or `None` to disable
+               the clamping (default).
+        impl:  Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).
+
+    Returns:
+        Tensor of the same shape and datatype as `x`.
+    """
+    assert isinstance(x, torch.Tensor)
+    assert impl in ['ref', 'cuda']
+    if impl == 'cuda' and x.device.type == 'cuda' and _init():
+        return _bias_act_cuda(dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp).apply(x, b)
+    return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp)
+
+#----------------------------------------------------------------------------
+
+@misc.profiled_function
+def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None):
+    """Slow reference implementation of `bias_act()` using standard PyTorch ops.
+    """
+    assert isinstance(x, torch.Tensor)
+    assert clamp is None or clamp >= 0
+    spec = activation_funcs[act]
+    alpha = float(alpha if alpha is not None else spec.def_alpha)
+    gain = float(gain if gain is not None else spec.def_gain)
+    clamp = float(clamp if clamp is not None else -1)
+
+    # Add bias.
+    if b is not None:
+        assert isinstance(b, torch.Tensor) and b.ndim == 1
+        assert 0 <= dim < x.ndim
+        assert b.shape[0] == x.shape[dim]
+        x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)])
+
+    # Evaluate activation function.
+    alpha = float(alpha)
+    x = spec.func(x, alpha=alpha)
+
+    # Scale by gain.
+    gain = float(gain)
+    if gain != 1:
+        x = x * gain
+
+    # Clamp.
+    if clamp >= 0:
+        x = x.clamp(-clamp, clamp)
+    return x
+
+#----------------------------------------------------------------------------
+
+_bias_act_cuda_cache = dict()
+
+def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None):
+    """Fast CUDA implementation of `bias_act()` using custom ops.
+    """
+    # Parse arguments.
+    assert clamp is None or clamp >= 0
+    spec = activation_funcs[act]
+    alpha = float(alpha if alpha is not None else spec.def_alpha)
+    gain = float(gain if gain is not None else spec.def_gain)
+    clamp = float(clamp if clamp is not None else -1)
+
+    # Lookup from cache.
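+    # The same (dim, act, alpha, gain, clamp) combination always resolves to a
+    # single cached autograd Function class, so repeated calls with identical
+    # settings reuse it.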
+ key = (dim, act, alpha, gain, clamp) + if key in _bias_act_cuda_cache: + return _bias_act_cuda_cache[key] + + # Forward op. + class BiasActCuda(torch.autograd.Function): + @staticmethod + def forward(ctx, x, b): # pylint: disable=arguments-differ + ctx.memory_format = torch.channels_last if x.ndim > 2 and x.stride(1) == 1 else torch.contiguous_format + x = x.contiguous(memory_format=ctx.memory_format) + b = b.contiguous() if b is not None else _null_tensor + y = x + if act != 'linear' or gain != 1 or clamp >= 0 or b is not _null_tensor: + y = _plugin.bias_act(x, b, _null_tensor, _null_tensor, _null_tensor, 0, dim, spec.cuda_idx, alpha, gain, clamp) + ctx.save_for_backward( + x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor, + b if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor, + y if 'y' in spec.ref else _null_tensor) + return y + + @staticmethod + def backward(ctx, dy): # pylint: disable=arguments-differ + dy = dy.contiguous(memory_format=ctx.memory_format) + x, b, y = ctx.saved_tensors + dx = None + db = None + + if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]: + dx = dy + if act != 'linear' or gain != 1 or clamp >= 0: + dx = BiasActCudaGrad.apply(dy, x, b, y) + + if ctx.needs_input_grad[1]: + db = dx.sum([i for i in range(dx.ndim) if i != dim]) + + return dx, db + + # Backward op. + class BiasActCudaGrad(torch.autograd.Function): + @staticmethod + def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ + ctx.memory_format = torch.channels_last if dy.ndim > 2 and dy.stride(1) == 1 else torch.contiguous_format + dx = _plugin.bias_act(dy, b, x, y, _null_tensor, 1, dim, spec.cuda_idx, alpha, gain, clamp) + ctx.save_for_backward( + dy if spec.has_2nd_grad else _null_tensor, + x, b, y) + return dx + + @staticmethod + def backward(ctx, d_dx): # pylint: disable=arguments-differ + d_dx = d_dx.contiguous(memory_format=ctx.memory_format) + dy, x, b, y = ctx.saved_tensors + d_dy = None + d_x = None + d_b = None + d_y = None + + if ctx.needs_input_grad[0]: + d_dy = BiasActCudaGrad.apply(d_dx, x, b, y) + + if spec.has_2nd_grad and (ctx.needs_input_grad[1] or ctx.needs_input_grad[2]): + d_x = _plugin.bias_act(d_dx, b, x, y, dy, 2, dim, spec.cuda_idx, alpha, gain, clamp) + + if spec.has_2nd_grad and ctx.needs_input_grad[2]: + d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim]) + + return d_dy, d_x, d_b, d_y + + # Add to cache. + _bias_act_cuda_cache[key] = BiasActCuda + return BiasActCuda + +#---------------------------------------------------------------------------- + +# pylint: enable=line-too-long +# pylint: enable=missing-class-docstring +# pylint: enable=global-statement diff --git a/third_party/stylegan3_official_ops/conv2d_gradfix.py b/third_party/stylegan3_official_ops/conv2d_gradfix.py new file mode 100644 index 0000000000000000000000000000000000000000..19aba5ca78f1228e4b8e3aafccbbe072c747f007 --- /dev/null +++ b/third_party/stylegan3_official_ops/conv2d_gradfix.py @@ -0,0 +1,219 @@ +# python3.7 + +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Custom replacement for convolution operators. 
+ +Operators in this file support arbitrarily high order gradients with zero +performance penalty. Please set `impl` as `cuda` to use faster customized +operators, OR as `ref` to use native `torch.nn.functional.conv2d` and +`torch.nn.functional.conv_transpose2d`. + +Please refer to https://github.com/NVlabs/stylegan3 +""" + +# pylint: disable=redefined-builtin +# pylint: disable=arguments-differ +# pylint: disable=protected-access +# pylint: disable=line-too-long +# pylint: disable=global-statement +# pylint: disable=missing-class-docstring +# pylint: disable=missing-function-docstring + +import contextlib +import torch + +#---------------------------------------------------------------------------- + +enabled = True # Enable the custom op by setting this to true. +weight_gradients_disabled = False # Forcefully disable computation of gradients with respect to the weights. + +@contextlib.contextmanager +def no_weight_gradients(disable=True): + global weight_gradients_disabled + old = weight_gradients_disabled + if disable: + weight_gradients_disabled = True + yield + weight_gradients_disabled = old + +#---------------------------------------------------------------------------- + +def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1, impl='cuda'): + if impl == 'cuda' and _should_use_custom_op(input): + return _conv2d_gradfix(transpose=False, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=0, dilation=dilation, groups=groups).apply(input, weight, bias) + return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups) + +def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1, impl='cuda'): + if impl == 'cuda' and _should_use_custom_op(input): + return _conv2d_gradfix(transpose=True, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation).apply(input, weight, bias) + return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation) + +#---------------------------------------------------------------------------- + +def _should_use_custom_op(input): + assert isinstance(input, torch.Tensor) + if (not enabled) or (not torch.backends.cudnn.enabled): + return False + if input.device.type != 'cuda': + return False + return True + +def _tuple_of_ints(xs, ndim): + xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim + assert len(xs) == ndim + assert all(isinstance(x, int) for x in xs) + return xs + +#---------------------------------------------------------------------------- + +_conv2d_gradfix_cache = dict() +_null_tensor = torch.empty([0]) + +def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, dilation, groups): + # Parse arguments. + ndim = 2 + weight_shape = tuple(weight_shape) + stride = _tuple_of_ints(stride, ndim) + padding = _tuple_of_ints(padding, ndim) + output_padding = _tuple_of_ints(output_padding, ndim) + dilation = _tuple_of_ints(dilation, ndim) + + # Lookup from cache. + key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups) + if key in _conv2d_gradfix_cache: + return _conv2d_gradfix_cache[key] + + # Validate arguments. 
+ assert groups >= 1 + assert len(weight_shape) == ndim + 2 + assert all(stride[i] >= 1 for i in range(ndim)) + assert all(padding[i] >= 0 for i in range(ndim)) + assert all(dilation[i] >= 0 for i in range(ndim)) + if not transpose: + assert all(output_padding[i] == 0 for i in range(ndim)) + else: # transpose + assert all(0 <= output_padding[i] < max(stride[i], dilation[i]) for i in range(ndim)) + + # Helpers. + common_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups) + def calc_output_padding(input_shape, output_shape): + if transpose: + return [0, 0] + return [ + input_shape[i + 2] + - (output_shape[i + 2] - 1) * stride[i] + - (1 - 2 * padding[i]) + - dilation[i] * (weight_shape[i + 2] - 1) + for i in range(ndim) + ] + + # Forward & backward. + class Conv2d(torch.autograd.Function): + @staticmethod + def forward(ctx, input, weight, bias): + assert weight.shape == weight_shape + ctx.save_for_backward( + input if weight.requires_grad else _null_tensor, + weight if input.requires_grad else _null_tensor, + ) + ctx.input_shape = input.shape + + # Simple 1x1 convolution => cuBLAS (only on Volta, not on Ampere). + if weight_shape[2:] == stride == dilation == (1, 1) and padding == (0, 0) and torch.cuda.get_device_capability(input.device) < (8, 0): + a = weight.reshape(groups, weight_shape[0] // groups, weight_shape[1]) + b = input.reshape(input.shape[0], groups, input.shape[1] // groups, -1) + c = (a.transpose(1, 2) if transpose else a) @ b.permute(1, 2, 0, 3).flatten(2) + c = c.reshape(-1, input.shape[0], *input.shape[2:]).transpose(0, 1) + c = c if bias is None else c + bias.unsqueeze(0).unsqueeze(2).unsqueeze(3) + return c.contiguous(memory_format=(torch.channels_last if input.stride(1) == 1 else torch.contiguous_format)) + + # General case => cuDNN. + if transpose: + return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, output_padding=output_padding, **common_kwargs) + return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, **common_kwargs) + + @staticmethod + def backward(ctx, grad_output): + input, weight = ctx.saved_tensors + input_shape = ctx.input_shape + grad_input = None + grad_weight = None + grad_bias = None + + if ctx.needs_input_grad[0]: + p = calc_output_padding(input_shape=input_shape, output_shape=grad_output.shape) + op = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs) + grad_input = op.apply(grad_output, weight, None) + assert grad_input.shape == input_shape + + if ctx.needs_input_grad[1] and not weight_gradients_disabled: + grad_weight = Conv2dGradWeight.apply(grad_output, input) + assert grad_weight.shape == weight_shape + + if ctx.needs_input_grad[2]: + grad_bias = grad_output.sum([0, 2, 3]) + + return grad_input, grad_weight, grad_bias + + # Gradient with respect to the weights. + class Conv2dGradWeight(torch.autograd.Function): + @staticmethod + def forward(ctx, grad_output, input): + ctx.save_for_backward( + grad_output if input.requires_grad else _null_tensor, + input if grad_output.requires_grad else _null_tensor, + ) + ctx.grad_output_shape = grad_output.shape + ctx.input_shape = input.shape + + # Simple 1x1 convolution => cuBLAS (on both Volta and Ampere). 
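+            # For a 1x1 kernel the weight gradient reduces to a per-group matrix
+            # product over the flattened batch/spatial axis, roughly
+            #     grad_w[g] = grad_out[g] @ input[g].T
+            # (operands swapped for the transposed case), which is what the
+            # reshape/permute/flatten sequence below sets up for cuBLAS.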
+ if weight_shape[2:] == stride == dilation == (1, 1) and padding == (0, 0): + a = grad_output.reshape(grad_output.shape[0], groups, grad_output.shape[1] // groups, -1).permute(1, 2, 0, 3).flatten(2) + b = input.reshape(input.shape[0], groups, input.shape[1] // groups, -1).permute(1, 2, 0, 3).flatten(2) + c = (b @ a.transpose(1, 2) if transpose else a @ b.transpose(1, 2)).reshape(weight_shape) + return c.contiguous(memory_format=(torch.channels_last if input.stride(1) == 1 else torch.contiguous_format)) + + # General case => cuDNN. + name = 'aten::cudnn_convolution_transpose_backward_weight' if transpose else 'aten::cudnn_convolution_backward_weight' + flags = [torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic, torch.backends.cudnn.allow_tf32] + return torch._C._jit_get_operation(name)(weight_shape, grad_output, input, padding, stride, dilation, groups, *flags) + + @staticmethod + def backward(ctx, grad2_grad_weight): + grad_output, input = ctx.saved_tensors + grad_output_shape = ctx.grad_output_shape + input_shape = ctx.input_shape + grad2_grad_output = None + grad2_input = None + + if ctx.needs_input_grad[0]: + grad2_grad_output = Conv2d.apply(input, grad2_grad_weight, None) + assert grad2_grad_output.shape == grad_output_shape + + if ctx.needs_input_grad[1]: + p = calc_output_padding(input_shape=input_shape, output_shape=grad_output_shape) + op = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs) + grad2_input = op.apply(grad_output, grad2_grad_weight, None) + assert grad2_input.shape == input_shape + + return grad2_grad_output, grad2_input + + _conv2d_gradfix_cache[key] = Conv2d + return Conv2d + +#---------------------------------------------------------------------------- + +# pylint: enable=redefined-builtin +# pylint: enable=arguments-differ +# pylint: enable=protected-access +# pylint: enable=line-too-long +# pylint: enable=global-statement +# pylint: enable=missing-class-docstring +# pylint: enable=missing-function-docstring diff --git a/third_party/stylegan3_official_ops/conv2d_resample.py b/third_party/stylegan3_official_ops/conv2d_resample.py new file mode 100644 index 0000000000000000000000000000000000000000..dfde81ee19204a7993fd1c3cd21055a51418231b --- /dev/null +++ b/third_party/stylegan3_official_ops/conv2d_resample.py @@ -0,0 +1,154 @@ +# python3.7 + +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""2D convolution with optional up/downsampling. + +Please refer to https://github.com/NVlabs/stylegan3 +""" + +# pylint: disable=line-too-long + +import torch + +from . import misc +from . import conv2d_gradfix +from . 
import upfirdn2d +from .upfirdn2d import _parse_padding +from .upfirdn2d import _get_filter_size + +#---------------------------------------------------------------------------- + +def _get_weight_shape(w): + with misc.suppress_tracer_warnings(): # this value will be treated as a constant + shape = [int(sz) for sz in w.shape] + misc.assert_shape(w, shape) + return shape + +#---------------------------------------------------------------------------- + +def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True, impl='cuda'): + """Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations. + """ + _out_channels, _in_channels_per_group, kh, kw = _get_weight_shape(w) + + # Flip weight if requested. + # Note: conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False). + if not flip_weight and (kw > 1 or kh > 1): + w = w.flip([2, 3]) + + # Execute using conv2d_gradfix. + op = conv2d_gradfix.conv_transpose2d if transpose else conv2d_gradfix.conv2d + return op(x, w, stride=stride, padding=padding, groups=groups, impl=impl) + +#---------------------------------------------------------------------------- + +@misc.profiled_function +def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False, impl='cuda'): + r"""2D convolution with optional up/downsampling. + + Padding is performed only once at the beginning, not between the operations. + + Args: + x: Input tensor of shape + `[batch_size, in_channels, in_height, in_width]`. + w: Weight tensor of shape + `[out_channels, in_channels//groups, kernel_height, kernel_width]`. + f: Low-pass filter for up/downsampling. Must be prepared beforehand by + calling upfirdn2d.setup_filter(). None = identity (default). + up: Integer upsampling factor (default: 1). + down: Integer downsampling factor (default: 1). + padding: Padding with respect to the upsampled image. Can be a single number + or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` + (default: 0). + groups: Split input channels into N groups (default: 1). + flip_weight: False = convolution, True = correlation (default: True). + flip_filter: False = convolution, True = correlation (default: False). + impl: Implementation mode, 'cuda' for CUDA implementation, and 'ref' for + native PyTorch implementation (default: 'cuda'). + + Returns: + Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. + """ + # Validate arguments. + assert isinstance(x, torch.Tensor) and (x.ndim == 4) + assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype) + assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32) + assert isinstance(up, int) and (up >= 1) + assert isinstance(down, int) and (down >= 1) + assert isinstance(groups, int) and (groups >= 1) + out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w) + fw, fh = _get_filter_size(f) + px0, px1, py0, py1 = _parse_padding(padding) + + # Adjust padding to account for up/downsampling. + if up > 1: + px0 += (fw + up - 1) // 2 + px1 += (fw - up) // 2 + py0 += (fh + up - 1) // 2 + py1 += (fh - up) // 2 + if down > 1: + px0 += (fw - down + 1) // 2 + px1 += (fw - down) // 2 + py0 += (fh - down + 1) // 2 + py1 += (fh - down) // 2 + + # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve. 
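+    # A 1x1 convolution acts pointwise across channels and therefore commutes
+    # with spatial resampling; reordering the two ops lets the convolution run
+    # on the smaller of the two tensors in each fast path below.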
+ if kw == 1 and kh == 1 and (down > 1 and up == 1): + x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, padding=[px0,px1,py0,py1], flip_filter=flip_filter, impl=impl) + x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight, impl=impl) + return x + + # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample. + if kw == 1 and kh == 1 and (up > 1 and down == 1): + x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight, impl=impl) + x = upfirdn2d.upfirdn2d(x=x, f=f, up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter, impl=impl) + return x + + # Fast path: downsampling only => use strided convolution. + if down > 1 and up == 1: + x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0,px1,py0,py1], flip_filter=flip_filter, impl=impl) + x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight, impl=impl) + return x + + # Fast path: upsampling with optional downsampling => use transpose strided convolution. + if up > 1: + if groups == 1: + w = w.transpose(0, 1) + else: + w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw) + w = w.transpose(1, 2) + w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw) + px0 -= kw - 1 + px1 -= kw - up + py0 -= kh - 1 + py1 -= kh - up + pxt = max(min(-px0, -px1), 0) + pyt = max(min(-py0, -py1), 0) + x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt,pxt], groups=groups, transpose=True, flip_weight=(not flip_weight), impl=impl) + x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0+pxt,px1+pxt,py0+pyt,py1+pyt], gain=up**2, flip_filter=flip_filter, impl=impl) + if down > 1: + x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter, impl=impl) + return x + + # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d. + if up == 1 and down == 1: + if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0: + return _conv2d_wrapper(x=x, w=w, padding=[py0,px0], groups=groups, flip_weight=flip_weight, impl=impl) + + # Fallback: Generic reference implementation. + x = upfirdn2d.upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter, impl=impl) + x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight, impl=impl) + if down > 1: + x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter, impl=impl) + return x + +#---------------------------------------------------------------------------- + +# pylint: enable=line-too-long diff --git a/third_party/stylegan3_official_ops/custom_ops.py b/third_party/stylegan3_official_ops/custom_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..c5853ac187e6e3ae522b0ef1aabefc7b188f7083 --- /dev/null +++ b/third_party/stylegan3_official_ops/custom_ops.py @@ -0,0 +1,191 @@ +# python3.7 + +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Utility functions to setup customized operators. 
+
+Please refer to https://github.com/NVlabs/stylegan3
+"""
+
+# pylint: disable=line-too-long
+# pylint: disable=multiple-statements
+# pylint: disable=missing-function-docstring
+# pylint: disable=useless-suppression
+# pylint: disable=inconsistent-quotes
+
+import glob
+import hashlib
+import importlib
+import os
+import re
+import shutil
+import uuid
+
+import torch
+import torch.utils.cpp_extension
+
+#----------------------------------------------------------------------------
+# Global options.
+
+verbosity = 'none'  # Verbosity level: 'none', 'brief', 'full'
+
+#----------------------------------------------------------------------------
+# Internal helper funcs.
+
+def _find_compiler_bindir():
+    patterns = [
+        'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64',
+        'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64',
+        'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64',
+        'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin',
+    ]
+    for pattern in patterns:
+        matches = sorted(glob.glob(pattern))
+        if len(matches):
+            return matches[-1]
+    return None
+
+def _find_compiler_bindir_posix():
+    patterns = [
+        '/usr/local/cuda/bin'
+    ]
+    for pattern in patterns:
+        matches = sorted(glob.glob(pattern))
+        if len(matches):
+            return matches[-1]
+    return None
+
+#----------------------------------------------------------------------------
+
+def _get_mangled_gpu_name():
+    name = torch.cuda.get_device_name().lower()
+    out = []
+    for c in name:
+        if re.match('[a-z0-9_-]+', c):
+            out.append(c)
+        else:
+            out.append('-')
+    return ''.join(out)
+
+#----------------------------------------------------------------------------
+# Main entry point for compiling and loading C++/CUDA plugins.
+
+_cached_plugins = dict()
+
+def get_plugin(module_name, sources, headers=None, source_dir=None, **build_kwargs):
+    assert verbosity in ['none', 'brief', 'full']
+    if headers is None:
+        headers = []
+    if source_dir is not None:
+        sources = [os.path.join(source_dir, fname) for fname in sources]
+        headers = [os.path.join(source_dir, fname) for fname in headers]
+
+    # Already cached?
+    if module_name in _cached_plugins:
+        return _cached_plugins[module_name]
+
+    # Print status.
+    if verbosity == 'full':
+        print(f'Setting up PyTorch plugin "{module_name}"...')
+    elif verbosity == 'brief':
+        print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True)
+    verbose_build = (verbosity == 'full')
+
+    # Compile and load.
+    try: # pylint: disable=too-many-nested-blocks
+        # Make sure we can find the necessary compiler binaries.
+        if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0:
+            compiler_bindir = _find_compiler_bindir()
+            if compiler_bindir is None:
+                raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".')
+            os.environ['PATH'] += ';' + compiler_bindir
+
+        elif os.name == 'posix':
+            compiler_bindir = _find_compiler_bindir_posix()
+            if compiler_bindir is None:
+                raise RuntimeError(f'Could not find NVCC installation on this computer. Check _find_compiler_bindir_posix() in "{__file__}".')
+            os.environ['PATH'] += ':' + compiler_bindir  # POSIX uses ':' as the PATH separator, not ';'.
+
+        # Some containers set TORCH_CUDA_ARCH_LIST to a list that can either
+        # break the build or unnecessarily restrict what's available to nvcc.
+        # Unset it to let nvcc decide based on what's available on the
+        # machine.
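+        # (For example, a base image exporting TORCH_CUDA_ARCH_LIST="7.0 7.5+PTX"
+        # would restrict the build to those architectures only.)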
+ os.environ['TORCH_CUDA_ARCH_LIST'] = '' + + # Incremental build md5sum trickery. Copies all the input source files + # into a cached build directory under a combined md5 digest of the input + # source files. Copying is done only if the combined digest has changed. + # This keeps input file timestamps and filenames the same as in previous + # extension builds, allowing for fast incremental rebuilds. + # + # This optimization is done only in case all the source files reside in + # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR + # environment variable is set (we take this as a signal that the user + # actually cares about this.) + # + # EDIT: We now do it regardless of TORCH_EXTENSIONS_DIR, in order to work + # around the *.cu dependency bug in ninja config. + # + all_source_files = sorted(sources + headers) + all_source_dirs = set(os.path.dirname(fname) for fname in all_source_files) + if len(all_source_dirs) == 1: # and ('TORCH_EXTENSIONS_DIR' in os.environ): + + # Compute combined hash digest for all source files. + hash_md5 = hashlib.md5() + for src in all_source_files: + with open(src, 'rb') as f: + hash_md5.update(f.read()) + + # Select cached build directory name. + source_digest = hash_md5.hexdigest() + build_top_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access + cached_build_dir = os.path.join(build_top_dir, f'{source_digest}-{_get_mangled_gpu_name()}') + + if not os.path.isdir(cached_build_dir): + tmpdir = f'{build_top_dir}/srctmp-{uuid.uuid4().hex}' + os.makedirs(tmpdir) + for src in all_source_files: + shutil.copyfile(src, os.path.join(tmpdir, os.path.basename(src))) + try: + os.replace(tmpdir, cached_build_dir) # atomic + except OSError: + # source directory already exists, delete tmpdir and its contents. + shutil.rmtree(tmpdir) + if not os.path.isdir(cached_build_dir): raise + + # Compile. + cached_sources = [os.path.join(cached_build_dir, os.path.basename(fname)) for fname in sources] + torch.utils.cpp_extension.load(name=module_name, build_directory=cached_build_dir, + verbose=verbose_build, sources=cached_sources, **build_kwargs) + else: + torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs) + + # Load. + module = importlib.import_module(module_name) + + except: + if verbosity == 'brief': + print('Failed!') + raise + + # Print status and add to cache dict. + if verbosity == 'full': + print(f'Done setting up PyTorch plugin "{module_name}".') + elif verbosity == 'brief': + print('Done.') + _cached_plugins[module_name] = module + return module + +#---------------------------------------------------------------------------- + +# pylint: enable=line-too-long +# pylint: enable=multiple-statements +# pylint: enable=missing-function-docstring +# pylint: enable=useless-suppression +# pylint: enable=inconsistent-quotes diff --git a/third_party/stylegan3_official_ops/filtered_lrelu.cpp b/third_party/stylegan3_official_ops/filtered_lrelu.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ff4149b8b46b54d2f400ae10e44d19f20503ba1f --- /dev/null +++ b/third_party/stylegan3_official_ops/filtered_lrelu.cpp @@ -0,0 +1,300 @@ +// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. 
Any use, reproduction, disclosure or
+// distribution of this software and related documentation without an express
+// license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+#include <torch/extension.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+#include "filtered_lrelu.h"
+
+//------------------------------------------------------------------------
+
+static std::tuple<torch::Tensor, torch::Tensor, int> filtered_lrelu(
+    torch::Tensor x, torch::Tensor fu, torch::Tensor fd, torch::Tensor b, torch::Tensor si,
+    int up, int down, int px0, int px1, int py0, int py1, int sx, int sy, float gain, float slope, float clamp, bool flip_filters, bool writeSigns)
+{
+    // Set CUDA device.
+    TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
+    const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
+
+    // Validate arguments.
+    TORCH_CHECK(fu.device() == x.device() && fd.device() == x.device() && b.device() == x.device(), "all input tensors must reside on the same device");
+    TORCH_CHECK(fu.dtype() == torch::kFloat && fd.dtype() == torch::kFloat, "fu and fd must be float32");
+    TORCH_CHECK(b.dtype() == x.dtype(), "x and b must have the same dtype");
+    TORCH_CHECK(x.dtype() == torch::kHalf || x.dtype() == torch::kFloat, "x and b must be float16 or float32");
+    TORCH_CHECK(x.dim() == 4, "x must be rank 4");
+    TORCH_CHECK(x.size(0) * x.size(1) <= INT_MAX && x.size(2) <= INT_MAX && x.size(3) <= INT_MAX, "x is too large");
+    TORCH_CHECK(x.numel() > 0, "x is empty");
+    TORCH_CHECK((fu.dim() == 1 || fu.dim() == 2) && (fd.dim() == 1 || fd.dim() == 2), "fu and fd must be rank 1 or 2");
+    TORCH_CHECK(fu.size(0) <= INT_MAX && fu.size(-1) <= INT_MAX, "fu is too large");
+    TORCH_CHECK(fd.size(0) <= INT_MAX && fd.size(-1) <= INT_MAX, "fd is too large");
+    TORCH_CHECK(fu.numel() > 0, "fu is empty");
+    TORCH_CHECK(fd.numel() > 0, "fd is empty");
+    TORCH_CHECK(b.dim() == 1 && b.size(0) == x.size(1), "b must be a vector with the same number of channels as x");
+    TORCH_CHECK(up >= 1 && down >= 1, "up and down must be at least 1");
+
+    // Figure out how much shared memory is available on the device.
+    int maxSharedBytes = 0;
+    AT_CUDA_CHECK(cudaDeviceGetAttribute(&maxSharedBytes, cudaDevAttrMaxSharedMemoryPerBlockOptin, x.device().index()));
+    int sharedKB = maxSharedBytes >> 10;
+
+    // Populate enough launch parameters to check if a CUDA kernel exists.
+    filtered_lrelu_kernel_params p;
+    p.up = up;
+    p.down = down;
+    p.fuShape = make_int2((int)fu.size(-1), fu.dim() == 2 ? (int)fu.size(0) : 0); // shape [n, 0] indicates separable filter.
+    p.fdShape = make_int2((int)fd.size(-1), fd.dim() == 2 ? (int)fd.size(0) : 0);
+    filtered_lrelu_kernel_spec test_spec = choose_filtered_lrelu_kernel<float, int32_t, false, false>(p, sharedKB);
+    if (!test_spec.exec)
+    {
+        // No kernel found - return empty tensors and indicate missing kernel with return code of -1.
+        return std::make_tuple(torch::Tensor(), torch::Tensor(), -1);
+    }
+
+    // Input/output element size.
+    int64_t sz = (x.dtype() == torch::kHalf) ? 2 : 4;
+
+    // Input sizes.
+    int64_t xw = (int)x.size(3);
+    int64_t xh = (int)x.size(2);
+    int64_t fut_w = (int)fu.size(-1) - 1;
+    int64_t fut_h = (int)fu.size(0) - 1;
+    int64_t fdt_w = (int)fd.size(-1) - 1;
+    int64_t fdt_h = (int)fd.size(0) - 1;
+
+    // Logical size of upsampled buffer.
+    int64_t cw = xw * up + (px0 + px1) - fut_w;
+    int64_t ch = xh * up + (py0 + py1) - fut_h;
+    TORCH_CHECK(cw > fdt_w && ch > fdt_h, "upsampled buffer must be at least the size of downsampling filter");
+    TORCH_CHECK(cw <= INT_MAX && ch <= INT_MAX, "upsampled buffer is too large");
+
+    // Compute output size and allocate.
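+    // yw/yh below are ceil((cw - fdt_w) / down) and ceil((ch - fdt_h) / down),
+    // i.e. the number of down-strided positions at which the downsampling
+    // filter still fits entirely inside the upsampled buffer.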
+    int64_t yw = (cw - fdt_w + (down - 1)) / down;
+    int64_t yh = (ch - fdt_h + (down - 1)) / down;
+    TORCH_CHECK(yw > 0 && yh > 0, "output must be at least 1x1");
+    TORCH_CHECK(yw <= INT_MAX && yh <= INT_MAX, "output is too large");
+    torch::Tensor y = torch::empty({x.size(0), x.size(1), yh, yw}, x.options(), x.suggest_memory_format());
+
+    // Allocate sign tensor.
+    torch::Tensor so;
+    torch::Tensor s = si;
+    bool readSigns = !!s.numel();
+    int64_t sw_active = 0; // Active width of sign tensor.
+    if (writeSigns)
+    {
+        sw_active = yw * down - (down - 1) + fdt_w; // Active width in elements.
+        int64_t sh = yh * down - (down - 1) + fdt_h; // Height = active height.
+        int64_t sw = (sw_active + 15) & ~15; // Width = active width in elements, rounded up to multiple of 16.
+        TORCH_CHECK(sh <= INT_MAX && (sw >> 2) <= INT_MAX, "signs is too large");
+        s = so = torch::empty({x.size(0), x.size(1), sh, sw >> 2}, x.options().dtype(torch::kUInt8), at::MemoryFormat::Contiguous);
+    }
+    else if (readSigns)
+        sw_active = s.size(3) << 2;
+
+    // Validate sign tensor if in use.
+    if (readSigns || writeSigns)
+    {
+        TORCH_CHECK(s.is_contiguous(), "signs must be contiguous");
+        TORCH_CHECK(s.dtype() == torch::kUInt8, "signs must be uint8");
+        TORCH_CHECK(s.device() == x.device(), "signs must reside on the same device as x");
+        TORCH_CHECK(s.dim() == 4, "signs must be rank 4");
+        TORCH_CHECK(s.size(0) == x.size(0) && s.size(1) == x.size(1), "signs must have same batch & channels as x");
+        TORCH_CHECK(s.size(2) <= INT_MAX && s.size(3) <= INT_MAX, "signs is too large");
+    }
+
+    // Populate rest of CUDA kernel parameters.
+    p.x = x.data_ptr();
+    p.y = y.data_ptr();
+    p.b = b.data_ptr();
+    p.s = (readSigns || writeSigns) ? s.data_ptr<unsigned char>() : 0;
+    p.fu = fu.data_ptr<float>();
+    p.fd = fd.data_ptr<float>();
+    p.pad0 = make_int2(px0, py0);
+    p.gain = gain;
+    p.slope = slope;
+    p.clamp = clamp;
+    p.flip = (flip_filters) ? 1 : 0;
+    p.xShape = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
+    p.yShape = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0));
+    p.sShape = (readSigns || writeSigns) ? make_int2((int)s.size(3), (int)s.size(2)) : make_int2(0, 0); // Width is in bytes. Contiguous.
+    p.sOfs = make_int2(sx, sy);
+    p.swLimit = (sw_active + 3) >> 2; // Rounded up to bytes.
+
+    // x, y, b strides are in bytes.
+    p.xStride = make_longlong4(sz * x.stride(3), sz * x.stride(2), sz * x.stride(1), sz * x.stride(0));
+    p.yStride = make_longlong4(sz * y.stride(3), sz * y.stride(2), sz * y.stride(1), sz * y.stride(0));
+    p.bStride = sz * b.stride(0);
+
+    // fu, fd strides are in elements.
+    p.fuStride = make_longlong3(fu.stride(-1), fu.dim() == 2 ? fu.stride(0) : 0, 0);
+    p.fdStride = make_longlong3(fd.stride(-1), fd.dim() == 2 ? fd.stride(0) : 0, 0);
+
+    // Determine if indices don't fit in int32. Support negative strides although Torch currently never produces those.
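+    // The min()/max() sums below bound the most negative and most positive byte
+    // offsets any element of x/y can reach; 32-bit indexing is used only if both
+    // extremes fit in an int.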
+    bool index64b = false;
+    if (std::abs(p.bStride * x.size(1)) > INT_MAX) index64b = true;
+    if (std::min(x.size(0) * p.xStride.w, 0ll) + std::min(x.size(1) * p.xStride.z, 0ll) + std::min(x.size(2) * p.xStride.y, 0ll) + std::min(x.size(3) * p.xStride.x, 0ll) < -INT_MAX) index64b = true;
+    if (std::max(x.size(0) * p.xStride.w, 0ll) + std::max(x.size(1) * p.xStride.z, 0ll) + std::max(x.size(2) * p.xStride.y, 0ll) + std::max(x.size(3) * p.xStride.x, 0ll) > INT_MAX) index64b = true;
+    if (std::min(y.size(0) * p.yStride.w, 0ll) + std::min(y.size(1) * p.yStride.z, 0ll) + std::min(y.size(2) * p.yStride.y, 0ll) + std::min(y.size(3) * p.yStride.x, 0ll) < -INT_MAX) index64b = true;
+    if (std::max(y.size(0) * p.yStride.w, 0ll) + std::max(y.size(1) * p.yStride.z, 0ll) + std::max(y.size(2) * p.yStride.y, 0ll) + std::max(y.size(3) * p.yStride.x, 0ll) > INT_MAX) index64b = true;
+    if (s.numel() > INT_MAX) index64b = true;
+
+    // Choose CUDA kernel.
+    filtered_lrelu_kernel_spec spec = { 0 };
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "filtered_lrelu_cuda", [&]
+    {
+        if constexpr (sizeof(scalar_t) <= 4) // Exclude doubles. constexpr prevents template instantiation.
+        {
+            // Choose kernel based on index type, datatype and sign read/write modes.
+            if      (!index64b &&  writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int32_t, true,  false>(p, sharedKB);
+            else if (!index64b && !writeSigns &&  readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int32_t, false, true>(p, sharedKB);
+            else if (!index64b && !writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int32_t, false, false>(p, sharedKB);
+            else if ( index64b &&  writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int64_t, true,  false>(p, sharedKB);
+            else if ( index64b && !writeSigns &&  readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int64_t, false, true>(p, sharedKB);
+            else if ( index64b && !writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int64_t, false, false>(p, sharedKB);
+        }
+    });
+    TORCH_CHECK(spec.exec, "internal error - CUDA kernel not found"); // This should not happen because we tested earlier that kernel exists.
+
+    // Launch CUDA kernel.
+    void* args[] = {&p};
+    int bx = spec.numWarps * 32;
+    int gx = (p.yShape.x - 1) / spec.tileOut.x + 1;
+    int gy = (p.yShape.y - 1) / spec.tileOut.y + 1;
+    int gz = p.yShape.z * p.yShape.w;
+
+    // Repeat multiple horizontal tiles in a CTA?
+    if (spec.xrep)
+    {
+        p.tilesXrep = spec.xrep;
+        p.tilesXdim = gx;
+
+        gx = (gx + p.tilesXrep - 1) / p.tilesXrep;
+        std::swap(gx, gy);
+    }
+    else
+    {
+        p.tilesXrep = 0;
+        p.tilesXdim = 0;
+    }
+
+    // Launch filter setup kernel.
+    AT_CUDA_CHECK(cudaLaunchKernel(spec.setup, 1, 1024, args, 0, at::cuda::getCurrentCUDAStream()));
+
+    // Copy kernels to constant memory.
+    if      ( writeSigns && !readSigns) AT_CUDA_CHECK((copy_filters<true,  false>(at::cuda::getCurrentCUDAStream())));
+    else if (!writeSigns &&  readSigns) AT_CUDA_CHECK((copy_filters<false, true>(at::cuda::getCurrentCUDAStream())));
+    else if (!writeSigns && !readSigns) AT_CUDA_CHECK((copy_filters<false, false>(at::cuda::getCurrentCUDAStream())));
+
+    // Set cache and shared memory configurations for main kernel.
+    AT_CUDA_CHECK(cudaFuncSetCacheConfig(spec.exec, cudaFuncCachePreferShared));
+    if (spec.dynamicSharedKB) // Need dynamically allocated shared memory?
+        AT_CUDA_CHECK(cudaFuncSetAttribute(spec.exec, cudaFuncAttributeMaxDynamicSharedMemorySize, spec.dynamicSharedKB << 10));
+    AT_CUDA_CHECK(cudaFuncSetSharedMemConfig(spec.exec, cudaSharedMemBankSizeFourByte));
+
+    // Launch main kernel.
+    const int maxSubGz = 65535; // CUDA maximum for block z dimension.
+    for (int zofs=0; zofs < gz; zofs += maxSubGz) // Do multiple launches if gz is too big.
+    {
+        p.blockZofs = zofs;
+        int subGz = std::min(maxSubGz, gz - zofs);
+        AT_CUDA_CHECK(cudaLaunchKernel(spec.exec, dim3(gx, gy, subGz), bx, args, spec.dynamicSharedKB << 10, at::cuda::getCurrentCUDAStream()));
+    }
+
+    // Done.
+    return std::make_tuple(y, so, 0);
+}
+
+//------------------------------------------------------------------------
+
+static torch::Tensor filtered_lrelu_act(torch::Tensor x, torch::Tensor si, int sx, int sy, float gain, float slope, float clamp, bool writeSigns)
+{
+    // Set CUDA device.
+    TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
+    const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
+
+    // Validate arguments.
+    TORCH_CHECK(x.dim() == 4, "x must be rank 4");
+    TORCH_CHECK(x.size(0) * x.size(1) <= INT_MAX && x.size(2) <= INT_MAX && x.size(3) <= INT_MAX, "x is too large");
+    TORCH_CHECK(x.numel() > 0, "x is empty");
+    TORCH_CHECK(x.dtype() == torch::kHalf || x.dtype() == torch::kFloat || x.dtype() == torch::kDouble, "x must be float16, float32 or float64");
+
+    // Output signs if we don't have sign input.
+    torch::Tensor so;
+    torch::Tensor s = si;
+    bool readSigns = !!s.numel();
+    if (writeSigns)
+    {
+        int64_t sw = x.size(3);
+        sw = (sw + 15) & ~15; // Round to a multiple of 16 for coalescing.
+        s = so = torch::empty({x.size(0), x.size(1), x.size(2), sw >> 2}, x.options().dtype(torch::kUInt8), at::MemoryFormat::Contiguous);
+    }
+
+    // Validate sign tensor if in use.
+    if (readSigns || writeSigns)
+    {
+        TORCH_CHECK(s.is_contiguous(), "signs must be contiguous");
+        TORCH_CHECK(s.dtype() == torch::kUInt8, "signs must be uint8");
+        TORCH_CHECK(s.device() == x.device(), "signs must reside on the same device as x");
+        TORCH_CHECK(s.dim() == 4, "signs must be rank 4");
+        TORCH_CHECK(s.size(0) == x.size(0) && s.size(1) == x.size(1), "signs must have same batch & channels as x");
+        TORCH_CHECK(s.size(2) <= INT_MAX && (s.size(3) << 2) <= INT_MAX, "signs tensor is too large");
+    }
+
+    // Initialize CUDA kernel parameters.
+    filtered_lrelu_act_kernel_params p;
+    p.x = x.data_ptr();
+    p.s = (readSigns || writeSigns) ? s.data_ptr<unsigned char>() : 0;
+    p.gain = gain;
+    p.slope = slope;
+    p.clamp = clamp;
+    p.xShape = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
+    p.xStride = make_longlong4(x.stride(3), x.stride(2), x.stride(1), x.stride(0));
+    p.sShape = (readSigns || writeSigns) ? make_int2((int)s.size(3) << 2, (int)s.size(2)) : make_int2(0, 0); // Width is in elements. Contiguous.
+    p.sOfs = make_int2(sx, sy);
+
+    // Choose CUDA kernel.
+    void* func = 0;
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "filtered_lrelu_act_cuda", [&]
+    {
+        if (writeSigns)
+            func = choose_filtered_lrelu_act_kernel<scalar_t, true, false>();
+        else if (readSigns)
+            func = choose_filtered_lrelu_act_kernel<scalar_t, false, true>();
+        else
+            func = choose_filtered_lrelu_act_kernel<scalar_t, false, false>();
+    });
+    TORCH_CHECK(func, "internal error - CUDA kernel not found");
+
+    // Launch CUDA kernel.
+    void* args[] = {&p};
+    int bx = 128; // 4 warps per block.
+
+    // Logical size of launch = writeSigns ? p.s : p.x
+    uint32_t gx = writeSigns ? p.sShape.x : p.xShape.x;
+    uint32_t gy = writeSigns ? p.sShape.y : p.xShape.y;
+    uint32_t gz = p.xShape.z * p.xShape.w; // Same as in p.sShape if signs are in use.
+    gx = (gx - 1) / bx + 1;
+
+    // Make sure grid y and z dimensions are within CUDA launch limits. Kernel loops internally to do the rest.
+    const uint32_t gmax = 65535;
+    gy = std::min(gy, gmax);
+    gz = std::min(gz, gmax);
+
+    // Launch.
+    AT_CUDA_CHECK(cudaLaunchKernel(func, dim3(gx, gy, gz), bx, args, 0, at::cuda::getCurrentCUDAStream()));
+    return so;
+}
+
+//------------------------------------------------------------------------
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
+{
+    m.def("filtered_lrelu", &filtered_lrelu); // The whole thing.
+    m.def("filtered_lrelu_act_", &filtered_lrelu_act); // Activation and sign tensor handling only. Modifies data tensor in-place.
+}
+
+//------------------------------------------------------------------------
diff --git a/third_party/stylegan3_official_ops/filtered_lrelu.cu b/third_party/stylegan3_official_ops/filtered_lrelu.cu
new file mode 100644
index 0000000000000000000000000000000000000000..8e6f47f873d42f7181a0faf64779377e70be3012
--- /dev/null
+++ b/third_party/stylegan3_official_ops/filtered_lrelu.cu
@@ -0,0 +1,1284 @@
+// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+//
+// NVIDIA CORPORATION and its licensors retain all intellectual property
+// and proprietary rights in and to this software, related documentation
+// and any modifications thereto. Any use, reproduction, disclosure or
+// distribution of this software and related documentation without an express
+// license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+#include <c10/util/Half.h>
+#include "filtered_lrelu.h"
+#include <cstdint>
+
+//------------------------------------------------------------------------
+// Helpers.
+
+enum // Filter modes.
+{
+    MODE_SUSD = 0,  // Separable upsampling, separable downsampling.
+    MODE_FUSD = 1,  // Full upsampling, separable downsampling.
+    MODE_SUFD = 2,  // Separable upsampling, full downsampling.
+    MODE_FUFD = 3,  // Full upsampling, full downsampling.
+};
+
+template <class T> struct InternalType;
+template <> struct InternalType<double>
+{
+    typedef double scalar_t; typedef double2 vec2_t; typedef double4 vec4_t;
+    __device__ __forceinline__ static vec2_t zero_vec2(void) { return make_double2(0, 0); }
+    __device__ __forceinline__ static vec4_t zero_vec4(void) { return make_double4(0, 0, 0, 0); }
+    __device__ __forceinline__ static double clamp(double x, double c) { return fmin(fmax(x, -c), c); }
+};
+template <> struct InternalType<float>
+{
+    typedef float scalar_t; typedef float2 vec2_t; typedef float4 vec4_t;
+    __device__ __forceinline__ static vec2_t zero_vec2(void) { return make_float2(0, 0); }
+    __device__ __forceinline__ static vec4_t zero_vec4(void) { return make_float4(0, 0, 0, 0); }
+    __device__ __forceinline__ static float clamp(float x, float c) { return fminf(fmaxf(x, -c), c); }
+};
+template <> struct InternalType<c10::Half>
+{
+    typedef float scalar_t; typedef float2 vec2_t; typedef float4 vec4_t;
+    __device__ __forceinline__ static vec2_t zero_vec2(void) { return make_float2(0, 0); }
+    __device__ __forceinline__ static vec4_t zero_vec4(void) { return make_float4(0, 0, 0, 0); }
+    __device__ __forceinline__ static float clamp(float x, float c) { return fminf(fmaxf(x, -c), c); }
+};
+
+#define MIN(A, B) ((A) < (B) ? (A) : (B))
+#define MAX(A, B) ((A) > (B) ? (A) : (B))
+#define CEIL_DIV(A, B) (((B)==1) ? (A) : \
+                        ((B)==2) ? ((int)((A)+1) >> 1) : \
+                        ((B)==4) ? ((int)((A)+3) >> 2) : \
+                        (((A) + ((A) > 0 ? (B) - 1 : 0)) / (B)))
+
+// This works only up to blocks of size 256 x 256 and for all N that are powers of two.
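+// For non-power-of-two N (up to 256) the multiply-and-shift path below,
+// y = (i * (2^24/N + 1)) >> 24, reproduces i / N over the i < N*256 range
+// assumed here without an integer division; power-of-two N take the plain
+// division, which the compiler reduces to a shift.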
+template <int N> __device__ __forceinline__ void fast_div_mod(int& x, int& y, unsigned int i)
+{
+    if ((N & (N-1)) && N <= 256)
+        y = (i * ((1<<24)/N + 1)) >> 24; // Assumes N <= 256, i < N*256.
+    else
+        y = i/N;
+
+    x = i - y*N;
+}
+
+// Type cast stride before reading it.
+template <class T> __device__ __forceinline__ T get_stride(const int64_t& x)
+{
+    return *reinterpret_cast<const T*>(&x);
+}
+
+//------------------------------------------------------------------------
+// Filters, setup kernel, copying function.
+
+#define MAX_FILTER_SIZE 32
+
+// Combined up/down filter buffers so that transfer can be done with one copy.
+__device__ float g_fbuf[2 * MAX_FILTER_SIZE * MAX_FILTER_SIZE]; // Filters in global memory, written by setup kernel.
+__device__ __constant__ float c_fbuf[2 * MAX_FILTER_SIZE * MAX_FILTER_SIZE]; // Filters in constant memory, read by main kernel.
+
+// Accessors to combined buffers to index up/down filters individually.
+#define c_fu (c_fbuf)
+#define c_fd (c_fbuf + MAX_FILTER_SIZE * MAX_FILTER_SIZE)
+#define g_fu (g_fbuf)
+#define g_fd (g_fbuf + MAX_FILTER_SIZE * MAX_FILTER_SIZE)
+
+// Set up filters into global memory buffer.
+static __global__ void setup_filters_kernel(filtered_lrelu_kernel_params p)
+{
+    for (int idx = threadIdx.x; idx < MAX_FILTER_SIZE * MAX_FILTER_SIZE; idx += blockDim.x)
+    {
+        int x, y;
+        fast_div_mod<MAX_FILTER_SIZE>(x, y, idx);
+
+        int fu_x = p.flip ? x : (p.fuShape.x - 1 - x);
+        int fu_y = p.flip ? y : (p.fuShape.y - 1 - y);
+        if (p.fuShape.y > 0)
+            g_fu[idx] = (x >= p.fuShape.x || y >= p.fuShape.y) ? 0.0f : p.fu[fu_x * p.fuStride.x + fu_y * p.fuStride.y];
+        else
+            g_fu[idx] = (x >= p.fuShape.x || y > 0) ? 0.0f : p.fu[fu_x * p.fuStride.x];
+
+        int fd_x = p.flip ? x : (p.fdShape.x - 1 - x);
+        int fd_y = p.flip ? y : (p.fdShape.y - 1 - y);
+        if (p.fdShape.y > 0)
+            g_fd[idx] = (x >= p.fdShape.x || y >= p.fdShape.y) ? 0.0f : p.fd[fd_x * p.fdStride.x + fd_y * p.fdStride.y];
+        else
+            g_fd[idx] = (x >= p.fdShape.x || y > 0) ? 0.0f : p.fd[fd_x * p.fdStride.x];
+    }
+}
+
+// Host function to copy filters written by setup kernel into constant buffer for main kernel.
+template <bool, bool> static cudaError_t copy_filters(cudaStream_t stream)
+{
+    void* src = 0;
+    cudaError_t err = cudaGetSymbolAddress(&src, g_fbuf);
+    if (err) return err;
+    return cudaMemcpyToSymbolAsync(c_fbuf, src, 2 * MAX_FILTER_SIZE * MAX_FILTER_SIZE * sizeof(float), 0, cudaMemcpyDeviceToDevice, stream);
+}
+
+//------------------------------------------------------------------------
+// Coordinate spaces:
+// - Relative to input tensor:   inX, inY, tileInX, tileInY
+// - Relative to input tile:     relInX, relInY, tileInW, tileInH
+// - Relative to upsampled tile: relUpX, relUpY, tileUpW, tileUpH
+// - Relative to output tile:    relOutX, relOutY, tileOutW, tileOutH
+// - Relative to output tensor:  outX, outY, tileOutX, tileOutY
+//
+// Relationships between coordinate spaces:
+// - inX = tileInX + relInX
+// - inY = tileInY + relInY
+// - relUpX = relInX * up + phaseInX
+// - relUpY = relInY * up + phaseInY
+// - relUpX = relOutX * down
+// - relUpY = relOutY * down
+// - outX = tileOutX + relOutX
+// - outY = tileOutY + relOutY
+
+extern __shared__ char s_buf_raw[]; // When sharedKB <= 48, allocate shared memory statically inside the kernel, otherwise use the externally allocated shared memory buffer.
+
+template <class T, class index_t, int sharedKB, bool signWrite, bool signRead, int filterMode, int up, int fuSize, int down, int fdSize, int tileOutW, int tileOutH, int threadsPerBlock, bool enableXrep, bool enableWriteSkip>
+static __global__ void filtered_lrelu_kernel(filtered_lrelu_kernel_params p)
+{
+    // Check that we don't try to support non-existing filter modes.
+    static_assert(up == 1 || up == 2 || up == 4, "only up=1, up=2, up=4 scales supported");
+    static_assert(down == 1 || down == 2 || down == 4, "only down=1, down=2, down=4 scales supported");
+    static_assert(fuSize >= up, "upsampling filter size must be at least upsampling factor");
+    static_assert(fdSize >= down, "downsampling filter size must be at least downsampling factor");
+    static_assert(fuSize % up == 0, "upsampling filter size must be divisible with upsampling factor");
+    static_assert(fdSize % down == 0, "downsampling filter size must be divisible with downsampling factor");
+    static_assert(fuSize <= MAX_FILTER_SIZE && fdSize <= MAX_FILTER_SIZE, "filter size greater than MAX_FILTER_SIZE");
+    static_assert(up != 1 || (fuSize == 1 && (filterMode == MODE_FUFD || filterMode == MODE_FUSD)), "up=1 supported only for 1x1 full filters");
+    static_assert(down != 1 || (fdSize == 1 && (filterMode == MODE_FUFD || filterMode == MODE_SUFD)), "down=1 supported only for 1x1 full filters");
+    static_assert(!(up == 4 && (filterMode == MODE_FUFD || filterMode == MODE_FUSD)), "full filters not supported for up=4");
+    static_assert(!(down == 4 && (filterMode == MODE_FUFD || filterMode == MODE_SUFD)), "full filters not supported for down=4");
+
+    // Static definitions.
+    typedef typename InternalType<T>::scalar_t scalar_t;
+    typedef typename InternalType<T>::vec2_t vec2_t;
+    typedef typename InternalType<T>::vec4_t vec4_t;
+    const int tileUpW    = (tileOutW * down + (fdSize - 1) - (down - 1) + 3) & ~3; // Upsampled tile width, rounded up to multiple of 4.
+    const int tileUpH    = tileOutH * down + (fdSize - 1) - (down - 1);            // Upsampled tile height.
+    const int tileInW    = CEIL_DIV(tileUpW + (fuSize - 1), up);                   // Input tile width.
+    const int tileInH    = CEIL_DIV(tileUpH + (fuSize - 1), up);                   // Input tile height.
+    const int tileUpH_up = CEIL_DIV(tileUpH, up) * up;                             // Upsampled tile height rounded up to a multiple of up.
+    const int tileInH_up = CEIL_DIV(tileUpH_up + (fuSize - 1), up);                // For allocations only, to avoid shared memory read overruns with up=2 and up=4.
+
+    // Merge 1x1 downsampling into last upsampling step for upf1 and ups2.
+    const bool downInline = (down == 1) && ((up == 1 && filterMode == MODE_FUFD) || (up == 2 && filterMode == MODE_SUFD));
+
+    // Sizes of logical buffers.
+    const int szIn    = tileInH_up * tileInW;
+    const int szUpX   = tileInH_up * tileUpW;
+    const int szUpXY  = downInline ? 0 : (tileUpH * tileUpW);
+    const int szDownX = tileUpH * tileOutW;
+
+    // Sizes for shared memory arrays.
+    const int s_buf0_size_base =
+        (filterMode == MODE_SUSD) ? MAX(szIn, szUpXY) :
+        (filterMode == MODE_FUSD) ? MAX(szIn, szDownX) :
+        (filterMode == MODE_SUFD) ? MAX(szIn, szUpXY) :
+        (filterMode == MODE_FUFD) ? szIn :
+        -1;
+    const int s_buf1_size_base =
+        (filterMode == MODE_SUSD) ? MAX(szUpX, szDownX) :
+        (filterMode == MODE_FUSD) ? szUpXY :
+        (filterMode == MODE_SUFD) ? szUpX :
+        (filterMode == MODE_FUFD) ? szUpXY :
+        -1;
+
+    // Ensure U128 alignment.
+    const int s_buf0_size = (s_buf0_size_base + 3) & ~3;
+    const int s_buf1_size = (s_buf1_size_base + 3) & ~3;
+
+    // Check at compile time that we don't use too much shared memory.
+    static_assert((s_buf0_size + s_buf1_size) * sizeof(scalar_t) <= (sharedKB << 10), "shared memory overflow");
+
+    // Declare shared memory arrays.
+    scalar_t* s_buf0;
+    scalar_t* s_buf1;
+    if (sharedKB <= 48)
+    {
+        // Allocate shared memory arrays here.
+        __shared__ scalar_t s_buf0_st[(sharedKB > 48) ?
(1<<24) : (s_buf0_size + s_buf1_size)]; // Prevent launching if this isn't optimized away when unused. + s_buf0 = s_buf0_st; + s_buf1 = s_buf0 + s_buf0_size; + } + else + { + // Use the dynamically allocated shared memory array. + s_buf0 = (scalar_t*)s_buf_raw; + s_buf1 = s_buf0 + s_buf0_size; + } + + // Pointers to the buffers. + scalar_t* s_tileIn; // Input tile: [relInX * tileInH + relInY] + scalar_t* s_tileUpX; // After horizontal upsampling: [relInY * tileUpW + relUpX] + scalar_t* s_tileUpXY; // After upsampling: [relUpY * tileUpW + relUpX] + scalar_t* s_tileDownX; // After horizontal downsampling: [relUpY * tileOutW + relOutX] + if (filterMode == MODE_SUSD) + { + s_tileIn = s_buf0; + s_tileUpX = s_buf1; + s_tileUpXY = s_buf0; + s_tileDownX = s_buf1; + } + else if (filterMode == MODE_FUSD) + { + s_tileIn = s_buf0; + s_tileUpXY = s_buf1; + s_tileDownX = s_buf0; + } + else if (filterMode == MODE_SUFD) + { + s_tileIn = s_buf0; + s_tileUpX = s_buf1; + s_tileUpXY = s_buf0; + } + else if (filterMode == MODE_FUFD) + { + s_tileIn = s_buf0; + s_tileUpXY = s_buf1; + } + + // Allow large grids in z direction via per-launch offset. + int channelIdx = blockIdx.z + p.blockZofs; + int batchIdx = channelIdx / p.yShape.z; + channelIdx -= batchIdx * p.yShape.z; + + // Offset to output feature map. In bytes. + index_t mapOfsOut = channelIdx * get_stride(p.yStride.z) + batchIdx * get_stride(p.yStride.w); + + // Sign shift amount. + uint32_t signXo = ((threadIdx.x + p.sOfs.x) << 1) & 6; + + // Inner tile loop. + #pragma unroll 1 + for (int tileIdx = 0; !enableXrep || (tileIdx < MIN(p.tilesXrep, p.tilesXdim - p.tilesXrep * blockIdx.y)); tileIdx++) + { + // Locate output tile. + int tileX = enableXrep ? blockIdx.y * p.tilesXrep + tileIdx : blockIdx.x; + int tileOutX = tileX * tileOutW; + int tileOutY = (enableXrep ? blockIdx.x : blockIdx.y) * tileOutH; + + // Locate input tile. + int tmpX = tileOutX * down - p.pad0.x; + int tmpY = tileOutY * down - p.pad0.y; + int tileInX = CEIL_DIV(tmpX, up); + int tileInY = CEIL_DIV(tmpY, up); + const int phaseInX = tileInX * up - tmpX; + const int phaseInY = tileInY * up - tmpY; + + // Extra sync if input and output buffers are the same and we are not on first tile. + if (enableXrep && tileIdx > 0 && (filterMode == MODE_FUSD || (filterMode == MODE_SUFD && !downInline) || (filterMode == MODE_FUFD && downInline))) + __syncthreads(); + + // Load input tile & apply bias. Unrolled. + scalar_t b = (scalar_t)*(const T*)((const char*)p.b + (channelIdx * get_stride(p.bStride))); + index_t mapOfsIn = channelIdx * get_stride(p.xStride.z) + batchIdx * get_stride(p.xStride.w); + int idx = threadIdx.x; + const int loopCountIN = CEIL_DIV(tileInW * tileInH, threadsPerBlock); + #pragma unroll + for (int loop = 0; loop < loopCountIN; loop++) + { + int relInX, relInY; + fast_div_mod(relInX, relInY, idx); + int inX = tileInX + relInX; + int inY = tileInY + relInY; + scalar_t v = 0; + + if ((uint32_t)inX < p.xShape.x && (uint32_t)inY < p.xShape.y) + v = (scalar_t)*((const T*)((const char*)p.x + (inX * get_stride(p.xStride.x) + inY * get_stride(p.xStride.y) + mapOfsIn))) + b; + + bool skip = (loop == loopCountIN-1) && (idx >= tileInW * tileInH); + if (!skip) + s_tileIn[idx] = v; + + idx += threadsPerBlock; + } + + if (filterMode == MODE_SUSD || filterMode == MODE_SUFD) // Separable upsampling filter. + { + // Horizontal upsampling. 
+ __syncthreads(); + if (up == 4) + { + for (int idx = threadIdx.x*up; idx < tileUpW * tileInH; idx += blockDim.x*up) + { + int relUpX0, relInY; + fast_div_mod(relUpX0, relInY, idx); + int relInX0 = relUpX0 / up; + int src0 = relInX0 + tileInW * relInY; + int dst = relInY * tileUpW + relUpX0; + vec4_t v = InternalType::zero_vec4(); + scalar_t a = s_tileIn[src0]; + if (phaseInX == 0) + { + #pragma unroll + for (int step = 0; step < fuSize / up; step++) + { + v.x += a * (scalar_t)c_fu[step * up + 0]; + a = s_tileIn[src0 + step + 1]; + v.y += a * (scalar_t)c_fu[step * up + 3]; + v.z += a * (scalar_t)c_fu[step * up + 2]; + v.w += a * (scalar_t)c_fu[step * up + 1]; + } + } + else if (phaseInX == 1) + { + #pragma unroll + for (int step = 0; step < fuSize / up; step++) + { + v.x += a * (scalar_t)c_fu[step * up + 1]; + v.y += a * (scalar_t)c_fu[step * up + 0]; + a = s_tileIn[src0 + step + 1]; + v.z += a * (scalar_t)c_fu[step * up + 3]; + v.w += a * (scalar_t)c_fu[step * up + 2]; + } + } + else if (phaseInX == 2) + { + #pragma unroll + for (int step = 0; step < fuSize / up; step++) + { + v.x += a * (scalar_t)c_fu[step * up + 2]; + v.y += a * (scalar_t)c_fu[step * up + 1]; + v.z += a * (scalar_t)c_fu[step * up + 0]; + a = s_tileIn[src0 + step + 1]; + v.w += a * (scalar_t)c_fu[step * up + 3]; + } + } + else // (phaseInX == 3) + { + #pragma unroll + for (int step = 0; step < fuSize / up; step++) + { + v.x += a * (scalar_t)c_fu[step * up + 3]; + v.y += a * (scalar_t)c_fu[step * up + 2]; + v.z += a * (scalar_t)c_fu[step * up + 1]; + v.w += a * (scalar_t)c_fu[step * up + 0]; + a = s_tileIn[src0 + step + 1]; + } + } + s_tileUpX[dst+0] = v.x; + s_tileUpX[dst+1] = v.y; + s_tileUpX[dst+2] = v.z; + s_tileUpX[dst+3] = v.w; + } + } + else if (up == 2) + { + bool p0 = (phaseInX == 0); + for (int idx = threadIdx.x*up; idx < tileUpW * tileInH; idx += blockDim.x*up) + { + int relUpX0, relInY; + fast_div_mod(relUpX0, relInY, idx); + int relInX0 = relUpX0 / up; + int src0 = relInX0 + tileInW * relInY; + int dst = relInY * tileUpW + relUpX0; + vec2_t v = InternalType::zero_vec2(); + scalar_t a = s_tileIn[src0]; + if (p0) // (phaseInX == 0) + { + #pragma unroll + for (int step = 0; step < fuSize / up; step++) + { + v.x += a * (scalar_t)c_fu[step * up + 0]; + a = s_tileIn[src0 + step + 1]; + v.y += a * (scalar_t)c_fu[step * up + 1]; + } + } + else // (phaseInX == 1) + { + #pragma unroll + for (int step = 0; step < fuSize / up; step++) + { + v.x += a * (scalar_t)c_fu[step * up + 1]; + v.y += a * (scalar_t)c_fu[step * up + 0]; + a = s_tileIn[src0 + step + 1]; + } + } + s_tileUpX[dst+0] = v.x; + s_tileUpX[dst+1] = v.y; + } + } + + // Vertical upsampling & nonlinearity. + + __syncthreads(); + int groupMask = 15 << ((threadIdx.x & 31) & ~3); + int minY = tileOutY ? (tileOutY - tileOutH) * down + tileUpH : 0; // Skip already written signs. + int sShapeMaxY = MIN(p.sShape.y, tileOutY * down + tileUpH); // Avoid out-of-tile sign writes. + if (up == 4) + { + minY -= 3; // Adjust according to block height. 
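+                // The vertical pass below also applies the leaky-ReLU and the
+                // up*up factor folded into the gain: zero-stuffing by `up` in
+                // each dimension dilutes the signal energy by up^2, which this
+                // scaling compensates.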
+ for (int idx = threadIdx.x; idx < tileUpW * tileUpH_up / up; idx += blockDim.x) + { + int relUpX, relInY0; + fast_div_mod(relUpX, relInY0, idx); + int relUpY0 = relInY0 * up; + int src0 = relInY0 * tileUpW + relUpX; + int dst = relUpY0 * tileUpW + relUpX; + vec4_t v = InternalType::zero_vec4(); + + scalar_t a = s_tileUpX[src0]; + if (phaseInY == 0) + { + #pragma unroll + for (int step = 0; step < fuSize / up; step++) + { + v.x += a * (scalar_t)c_fu[step * up + 0]; + a = s_tileUpX[src0 + (step + 1) * tileUpW]; + v.y += a * (scalar_t)c_fu[step * up + 3]; + v.z += a * (scalar_t)c_fu[step * up + 2]; + v.w += a * (scalar_t)c_fu[step * up + 1]; + } + } + else if (phaseInY == 1) + { + #pragma unroll + for (int step = 0; step < fuSize / up; step++) + { + v.x += a * (scalar_t)c_fu[step * up + 1]; + v.y += a * (scalar_t)c_fu[step * up + 0]; + a = s_tileUpX[src0 + (step + 1) * tileUpW]; + v.z += a * (scalar_t)c_fu[step * up + 3]; + v.w += a * (scalar_t)c_fu[step * up + 2]; + } + } + else if (phaseInY == 2) + { + #pragma unroll + for (int step = 0; step < fuSize / up; step++) + { + v.x += a * (scalar_t)c_fu[step * up + 2]; + v.y += a * (scalar_t)c_fu[step * up + 1]; + v.z += a * (scalar_t)c_fu[step * up + 0]; + a = s_tileUpX[src0 + (step + 1) * tileUpW]; + v.w += a * (scalar_t)c_fu[step * up + 3]; + } + } + else // (phaseInY == 3) + { + #pragma unroll + for (int step = 0; step < fuSize / up; step++) + { + v.x += a * (scalar_t)c_fu[step * up + 3]; + v.y += a * (scalar_t)c_fu[step * up + 2]; + v.z += a * (scalar_t)c_fu[step * up + 1]; + v.w += a * (scalar_t)c_fu[step * up + 0]; + a = s_tileUpX[src0 + (step + 1) * tileUpW]; + } + } + + int x = tileOutX * down + relUpX; + int y = tileOutY * down + relUpY0; + int signX = x + p.sOfs.x; + int signY = y + p.sOfs.y; + int signZ = blockIdx.z + p.blockZofs; + int signXb = signX >> 2; + index_t si0 = signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ); + index_t si1 = si0 + p.sShape.x; + index_t si2 = si0 + p.sShape.x * 2; + index_t si3 = si0 + p.sShape.x * 3; + + v.x *= (scalar_t)((float)up * (float)up * p.gain); + v.y *= (scalar_t)((float)up * (float)up * p.gain); + v.z *= (scalar_t)((float)up * (float)up * p.gain); + v.w *= (scalar_t)((float)up * (float)up * p.gain); + + if (signWrite) + { + if (!enableWriteSkip) + { + // Determine and write signs. + int sx = __float_as_uint(v.x) >> 31 << 0; + int sy = __float_as_uint(v.y) >> 31 << 8; + int sz = __float_as_uint(v.z) >> 31 << 16; + int sw = __float_as_uint(v.w) >> 31 << 24; + if (sx) v.x *= p.slope; + if (sy) v.y *= p.slope; + if (sz) v.z *= p.slope; + if (sw) v.w *= p.slope; + if (fabsf(v.x) > p.clamp) { sx = 2 << 0; v.x = InternalType::clamp(v.x, p.clamp); } + if (fabsf(v.y) > p.clamp) { sy = 2 << 8; v.y = InternalType::clamp(v.y, p.clamp); } + if (fabsf(v.z) > p.clamp) { sz = 2 << 16; v.z = InternalType::clamp(v.z, p.clamp); } + if (fabsf(v.w) > p.clamp) { sw = 2 << 24; v.w = InternalType::clamp(v.w, p.clamp); } + + if ((uint32_t)signXb < p.swLimit && signY >= minY) + { + // Combine signs. + uint32_t s = sx + sy + sw + sz; + s <<= (signX & 3) << 1; + s |= __shfl_xor_sync(groupMask, s, 1); + s |= __shfl_xor_sync(groupMask, s, 2); + + // Write signs. 
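+                        // Sign encoding: 2 bits per pixel (0 = positive,
+                        // 1 = negative so the slope was applied, 2 = clamped).
+                        // Each byte of `s` holds one of the 4 output rows; the
+                        // shuffles above OR together the codes of the 4
+                        // horizontally adjacent pixels owned by this 4-lane
+                        // group, so one byte covers 4 pixels of a row.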
+ if ((uint32_t)(signY + 0) < sShapeMaxY) { p.s[si0] = (unsigned char)(s >> 0); } + if ((uint32_t)(signY + 1) < sShapeMaxY) { p.s[si1] = (unsigned char)(s >> 8); } + if ((uint32_t)(signY + 2) < sShapeMaxY) { p.s[si2] = (unsigned char)(s >> 16); } + if ((uint32_t)(signY + 3) < sShapeMaxY) { p.s[si3] = (unsigned char)(s >> 24); } + } + } + else + { + // Determine and write signs. + if ((uint32_t)signXb < p.swLimit && signY >= minY) + { + int sx = __float_as_uint(v.x) >> 31 << 0; + int sy = __float_as_uint(v.y) >> 31 << 8; + int sz = __float_as_uint(v.z) >> 31 << 16; + int sw = __float_as_uint(v.w) >> 31 << 24; + if (sx) v.x *= p.slope; + if (sy) v.y *= p.slope; + if (sz) v.z *= p.slope; + if (sw) v.w *= p.slope; + if (fabsf(v.x) > p.clamp) { sx = 2 << 0; v.x = InternalType::clamp(v.x, p.clamp); } + if (fabsf(v.y) > p.clamp) { sy = 2 << 8; v.y = InternalType::clamp(v.y, p.clamp); } + if (fabsf(v.z) > p.clamp) { sz = 2 << 16; v.z = InternalType::clamp(v.z, p.clamp); } + if (fabsf(v.w) > p.clamp) { sw = 2 << 24; v.w = InternalType::clamp(v.w, p.clamp); } + + // Combine signs. + uint32_t s = sx + sy + sw + sz; + s <<= (signX & 3) << 1; + s |= __shfl_xor_sync(groupMask, s, 1); + s |= __shfl_xor_sync(groupMask, s, 2); + + // Write signs. + if ((uint32_t)(signY + 0) < sShapeMaxY) { p.s[si0] = (unsigned char)(s >> 0); } + if ((uint32_t)(signY + 1) < sShapeMaxY) { p.s[si1] = (unsigned char)(s >> 8); } + if ((uint32_t)(signY + 2) < sShapeMaxY) { p.s[si2] = (unsigned char)(s >> 16); } + if ((uint32_t)(signY + 3) < sShapeMaxY) { p.s[si3] = (unsigned char)(s >> 24); } + } + else + { + // Just compute the values. + if (v.x < 0.f) v.x *= p.slope; v.x = InternalType::clamp(v.x, p.clamp); + if (v.y < 0.f) v.y *= p.slope; v.y = InternalType::clamp(v.y, p.clamp); + if (v.z < 0.f) v.z *= p.slope; v.z = InternalType::clamp(v.z, p.clamp); + if (v.w < 0.f) v.w *= p.slope; v.w = InternalType::clamp(v.w, p.clamp); + } + } + } + else if (signRead) // Read signs and apply. + { + if ((uint32_t)signXb < p.swLimit) + { + int ss = (signX & 3) << 1; + if ((uint32_t)(signY + 0) < p.sShape.y) { int s = p.s[si0] >> ss; if (s & 1) v.x *= p.slope; if (s & 2) v.x = 0.f; } + if ((uint32_t)(signY + 1) < p.sShape.y) { int s = p.s[si1] >> ss; if (s & 1) v.y *= p.slope; if (s & 2) v.y = 0.f; } + if ((uint32_t)(signY + 2) < p.sShape.y) { int s = p.s[si2] >> ss; if (s & 1) v.z *= p.slope; if (s & 2) v.z = 0.f; } + if ((uint32_t)(signY + 3) < p.sShape.y) { int s = p.s[si3] >> ss; if (s & 1) v.w *= p.slope; if (s & 2) v.w = 0.f; } + } + } + else // Forward pass with no sign write. + { + if (v.x < 0.f) v.x *= p.slope; v.x = InternalType::clamp(v.x, p.clamp); + if (v.y < 0.f) v.y *= p.slope; v.y = InternalType::clamp(v.y, p.clamp); + if (v.z < 0.f) v.z *= p.slope; v.z = InternalType::clamp(v.z, p.clamp); + if (v.w < 0.f) v.w *= p.slope; v.w = InternalType::clamp(v.w, p.clamp); + } + + s_tileUpXY[dst + 0 * tileUpW] = v.x; + if (relUpY0 + 1 < tileUpH) s_tileUpXY[dst + 1 * tileUpW] = v.y; + if (relUpY0 + 2 < tileUpH) s_tileUpXY[dst + 2 * tileUpW] = v.z; + if (relUpY0 + 3 < tileUpH) s_tileUpXY[dst + 3 * tileUpW] = v.w; + } + } + else if (up == 2) + { + minY -= 1; // Adjust according to block height. 
+ for (int idx = threadIdx.x; idx < tileUpW * tileUpH_up / up; idx += blockDim.x) + { + int relUpX, relInY0; + fast_div_mod(relUpX, relInY0, idx); + int relUpY0 = relInY0 * up; + int src0 = relInY0 * tileUpW + relUpX; + int dst = relUpY0 * tileUpW + relUpX; + vec2_t v = InternalType::zero_vec2(); + + scalar_t a = s_tileUpX[src0]; + if (phaseInY == 0) + { + #pragma unroll + for (int step = 0; step < fuSize / up; step++) + { + v.x += a * (scalar_t)c_fu[step * up + 0]; + a = s_tileUpX[src0 + (step + 1) * tileUpW]; + v.y += a * (scalar_t)c_fu[step * up + 1]; + } + } + else // (phaseInY == 1) + { + #pragma unroll + for (int step = 0; step < fuSize / up; step++) + { + v.x += a * (scalar_t)c_fu[step * up + 1]; + v.y += a * (scalar_t)c_fu[step * up + 0]; + a = s_tileUpX[src0 + (step + 1) * tileUpW]; + } + } + + int x = tileOutX * down + relUpX; + int y = tileOutY * down + relUpY0; + int signX = x + p.sOfs.x; + int signY = y + p.sOfs.y; + int signZ = blockIdx.z + p.blockZofs; + int signXb = signX >> 2; + index_t si0 = signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ); + index_t si1 = si0 + p.sShape.x; + + v.x *= (scalar_t)((float)up * (float)up * p.gain); + v.y *= (scalar_t)((float)up * (float)up * p.gain); + + if (signWrite) + { + if (!enableWriteSkip) + { + // Determine and write signs. + int sx = __float_as_uint(v.x) >> 31 << 0; + int sy = __float_as_uint(v.y) >> 31 << 8; + if (sx) v.x *= p.slope; + if (sy) v.y *= p.slope; + if (fabsf(v.x) > p.clamp) { sx = 2 << 0; v.x = InternalType::clamp(v.x, p.clamp); } + if (fabsf(v.y) > p.clamp) { sy = 2 << 8; v.y = InternalType::clamp(v.y, p.clamp); } + + if ((uint32_t)signXb < p.swLimit && signY >= minY) + { + // Combine signs. + int s = sx + sy; + s <<= signXo; + s |= __shfl_xor_sync(groupMask, s, 1); + s |= __shfl_xor_sync(groupMask, s, 2); + + // Write signs. + if ((uint32_t)(signY + 0) < sShapeMaxY) { p.s[si0] = (unsigned char)(s >> 0); } + if ((uint32_t)(signY + 1) < sShapeMaxY) { p.s[si1] = (unsigned char)(s >> 8); } + } + } + else + { + // Determine and write signs. + if ((uint32_t)signXb < p.swLimit && signY >= minY) + { + int sx = __float_as_uint(v.x) >> 31 << 0; + int sy = __float_as_uint(v.y) >> 31 << 8; + if (sx) v.x *= p.slope; + if (sy) v.y *= p.slope; + if (fabsf(v.x) > p.clamp) { sx = 2 << 0; v.x = InternalType::clamp(v.x, p.clamp); } + if (fabsf(v.y) > p.clamp) { sy = 2 << 8; v.y = InternalType::clamp(v.y, p.clamp); } + + // Combine signs. + int s = sx + sy; + s <<= signXo; + s |= __shfl_xor_sync(groupMask, s, 1); + s |= __shfl_xor_sync(groupMask, s, 2); + + // Write signs. + if ((uint32_t)(signY + 0) < sShapeMaxY) { p.s[si0] = (unsigned char)(s >> 0); } + if ((uint32_t)(signY + 1) < sShapeMaxY) { p.s[si1] = (unsigned char)(s >> 8); } + } + else + { + // Just compute the values. + if (v.x < 0.f) v.x *= p.slope; v.x = InternalType::clamp(v.x, p.clamp); + if (v.y < 0.f) v.y *= p.slope; v.y = InternalType::clamp(v.y, p.clamp); + } + } + } + else if (signRead) // Read signs and apply. + { + if ((uint32_t)signXb < p.swLimit) + { + if ((uint32_t)(signY + 0) < p.sShape.y) { int s = p.s[si0] >> signXo; if (s & 1) v.x *= p.slope; if (s & 2) v.x = 0.f; } + if ((uint32_t)(signY + 1) < p.sShape.y) { int s = p.s[si1] >> signXo; if (s & 1) v.y *= p.slope; if (s & 2) v.y = 0.f; } + } + } + else // Forward pass with no sign write. + { + if (v.x < 0.f) v.x *= p.slope; v.x = InternalType::clamp(v.x, p.clamp); + if (v.y < 0.f) v.y *= p.slope; v.y = InternalType::clamp(v.y, p.clamp); + } + + if (!downInline) + { + // Write into temporary buffer. 
+ s_tileUpXY[dst] = v.x; + if (relUpY0 < tileUpH - 1) + s_tileUpXY[dst + tileUpW] = v.y; + } + else + { + // Write directly into output buffer. + if ((uint32_t)x < p.yShape.x) + { + int ymax = MIN(p.yShape.y, tileUpH + tileOutY * down); + index_t ofs = x * get_stride(p.yStride.x) + y * get_stride(p.yStride.y) + mapOfsOut; + if ((uint32_t)y + 0 < p.yShape.y) *((T*)((char*)p.y + ofs)) = (T)(v.x * (scalar_t)c_fd[0]); + if ((uint32_t)y + 1 < ymax) *((T*)((char*)p.y + ofs + get_stride(p.yStride.y))) = (T)(v.y * (scalar_t)c_fd[0]); + } + } + } + } + } + else if (filterMode == MODE_FUSD || filterMode == MODE_FUFD) + { + // Full upsampling filter. + + if (up == 2) + { + // 2 x 2-wide. + __syncthreads(); + int minY = tileOutY ? (tileOutY - tileOutH) * down + tileUpH + p.sOfs.y : 0; // Skip already written signs. + for (int idx = threadIdx.x * 4; idx < tileUpW * tileUpH; idx += blockDim.x * 4) + { + int relUpX0, relUpY0; + fast_div_mod(relUpX0, relUpY0, idx); + int relInX0 = CEIL_DIV(relUpX0 - phaseInX, up); + int relInY0 = CEIL_DIV(relUpY0 - phaseInY, up); + int src0 = relInX0 + tileInW * relInY0; + int tap0y = (relInY0 * up + phaseInY - relUpY0); + + #define X_LOOP(TAPY, PX) \ + for (int sx = 0; sx < fuSize / up; sx++) \ + { \ + v.x += a * (scalar_t)c_fu[(sx * up + (((PX) - 0) & (up - 1))) + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; \ + v.z += b * (scalar_t)c_fu[(sx * up + (((PX) - 0) & (up - 1))) + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; if ((PX) == 0) { a = b; b = s_tileIn[src0 + 2 + sx + sy * tileInW]; } \ + v.y += a * (scalar_t)c_fu[(sx * up + (((PX) - 1) & (up - 1))) + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; \ + v.w += b * (scalar_t)c_fu[(sx * up + (((PX) - 1) & (up - 1))) + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; if ((PX) == 1) { a = b; b = s_tileIn[src0 + 2 + sx + sy * tileInW]; } \ + } + + vec4_t v = InternalType::zero_vec4(); + if (tap0y == 0 && phaseInX == 0) + #pragma unroll + for (int sy = 0; sy < fuSize / up; sy++) { scalar_t a = s_tileIn[src0 + sy * tileInW]; scalar_t b = s_tileIn[src0 + sy * tileInW + 1]; + #pragma unroll + X_LOOP(0, 0) } + if (tap0y == 0 && phaseInX == 1) + #pragma unroll + for (int sy = 0; sy < fuSize / up; sy++) { scalar_t a = s_tileIn[src0 + sy * tileInW]; scalar_t b = s_tileIn[src0 + sy * tileInW + 1]; + #pragma unroll + X_LOOP(0, 1) } + if (tap0y == 1 && phaseInX == 0) + #pragma unroll + for (int sy = 0; sy < fuSize / up; sy++) { scalar_t a = s_tileIn[src0 + sy * tileInW]; scalar_t b = s_tileIn[src0 + sy * tileInW + 1]; + #pragma unroll + X_LOOP(1, 0) } + if (tap0y == 1 && phaseInX == 1) + #pragma unroll + for (int sy = 0; sy < fuSize / up; sy++) { scalar_t a = s_tileIn[src0 + sy * tileInW]; scalar_t b = s_tileIn[src0 + sy * tileInW + 1]; + #pragma unroll + X_LOOP(1, 1) } + + #undef X_LOOP + + int x = tileOutX * down + relUpX0; + int y = tileOutY * down + relUpY0; + int signX = x + p.sOfs.x; + int signY = y + p.sOfs.y; + int signZ = blockIdx.z + p.blockZofs; + int signXb = signX >> 2; + index_t si = signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ); + + v.x *= (scalar_t)((float)up * (float)up * p.gain); + v.y *= (scalar_t)((float)up * (float)up * p.gain); + v.z *= (scalar_t)((float)up * (float)up * p.gain); + v.w *= (scalar_t)((float)up * (float)up * p.gain); + + if (signWrite) + { + if (!enableWriteSkip) + { + // Determine and write signs. 
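+                        // In this full-filter path each thread owns 4
+                        // horizontally adjacent pixels, so the four 2-bit
+                        // codes pack into a single byte directly
+                        // (sx + (sy << 2) + (sz << 4) + (sw << 6)) without
+                        // the warp shuffles used by the separable path.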
+ int sx = __float_as_uint(v.x) >> 31; + int sy = __float_as_uint(v.y) >> 31; + int sz = __float_as_uint(v.z) >> 31; + int sw = __float_as_uint(v.w) >> 31; + if (sx) v.x *= p.slope; if (fabsf(v.x) > p.clamp) { sx = 2; v.x = InternalType::clamp(v.x, p.clamp); } + if (sy) v.y *= p.slope; if (fabsf(v.y) > p.clamp) { sy = 2; v.y = InternalType::clamp(v.y, p.clamp); } + if (sz) v.z *= p.slope; if (fabsf(v.z) > p.clamp) { sz = 2; v.z = InternalType::clamp(v.z, p.clamp); } + if (sw) v.w *= p.slope; if (fabsf(v.w) > p.clamp) { sw = 2; v.w = InternalType::clamp(v.w, p.clamp); } + + if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y && signY >= minY) + { + p.s[si] = sx + (sy << 2) + (sz << 4) + (sw << 6); + } + } + else + { + // Determine and write signs. + if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y && signY >= minY) + { + int sx = __float_as_uint(v.x) >> 31; + int sy = __float_as_uint(v.y) >> 31; + int sz = __float_as_uint(v.z) >> 31; + int sw = __float_as_uint(v.w) >> 31; + if (sx) v.x *= p.slope; if (fabsf(v.x) > p.clamp) { sx = 2; v.x = InternalType::clamp(v.x, p.clamp); } + if (sy) v.y *= p.slope; if (fabsf(v.y) > p.clamp) { sy = 2; v.y = InternalType::clamp(v.y, p.clamp); } + if (sz) v.z *= p.slope; if (fabsf(v.z) > p.clamp) { sz = 2; v.z = InternalType::clamp(v.z, p.clamp); } + if (sw) v.w *= p.slope; if (fabsf(v.w) > p.clamp) { sw = 2; v.w = InternalType::clamp(v.w, p.clamp); } + + p.s[si] = sx + (sy << 2) + (sz << 4) + (sw << 6); + } + else + { + // Just compute the values. + if (v.x < 0.f) v.x *= p.slope; v.x = InternalType::clamp(v.x, p.clamp); + if (v.y < 0.f) v.y *= p.slope; v.y = InternalType::clamp(v.y, p.clamp); + if (v.z < 0.f) v.z *= p.slope; v.z = InternalType::clamp(v.z, p.clamp); + if (v.w < 0.f) v.w *= p.slope; v.w = InternalType::clamp(v.w, p.clamp); + } + } + } + else if (signRead) // Read sign and apply. + { + if ((uint32_t)signY < p.sShape.y) + { + int s = 0; + if ((uint32_t)signXb < p.swLimit) s = p.s[si]; + if ((uint32_t)signXb + 1 < p.swLimit) s |= p.s[si + 1] << 8; + s >>= (signX & 3) << 1; + if (s & 0x01) v.x *= p.slope; if (s & 0x02) v.x = 0.f; + if (s & 0x04) v.y *= p.slope; if (s & 0x08) v.y = 0.f; + if (s & 0x10) v.z *= p.slope; if (s & 0x20) v.z = 0.f; + if (s & 0x40) v.w *= p.slope; if (s & 0x80) v.w = 0.f; + } + } + else // Forward pass with no sign write. + { + if (v.x < 0.f) v.x *= p.slope; v.x = InternalType::clamp(v.x, p.clamp); + if (v.y < 0.f) v.y *= p.slope; v.y = InternalType::clamp(v.y, p.clamp); + if (v.z < 0.f) v.z *= p.slope; v.z = InternalType::clamp(v.z, p.clamp); + if (v.w < 0.f) v.w *= p.slope; v.w = InternalType::clamp(v.w, p.clamp); + } + + s_tileUpXY[idx + 0] = v.x; + s_tileUpXY[idx + 1] = v.y; + s_tileUpXY[idx + 2] = v.z; + s_tileUpXY[idx + 3] = v.w; + } + } + else if (up == 1) + { + __syncthreads(); + uint32_t groupMask = 15 << ((threadIdx.x & 31) & ~3); + int minY = tileOutY ? (tileOutY - tileOutH) * down + tileUpH : 0; // Skip already written signs. + for (int idx = threadIdx.x; idx < tileUpW * tileUpH; idx += blockDim.x) + { + int relUpX0, relUpY0; + fast_div_mod(relUpX0, relUpY0, idx); + scalar_t v = s_tileIn[idx] * (scalar_t)c_fu[0]; // 1x1 filter. 
+ + int x = tileOutX * down + relUpX0; + int y = tileOutY * down + relUpY0; + int signX = x + p.sOfs.x; + int signY = y + p.sOfs.y; + int signZ = blockIdx.z + p.blockZofs; + int signXb = signX >> 2; + index_t si = signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ); + v *= (scalar_t)((float)up * (float)up * p.gain); + + if (signWrite) + { + if (!enableWriteSkip) + { + // Determine and write sign. + uint32_t s = 0; + uint32_t signXbit = (1u << signXo); + if (v < 0.f) + { + s = signXbit; + v *= p.slope; + } + if (fabsf(v) > p.clamp) + { + s = signXbit * 2; + v = InternalType::clamp(v, p.clamp); + } + if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y && signY >= minY) + { + s += __shfl_xor_sync(groupMask, s, 1); // Coalesce. + s += __shfl_xor_sync(groupMask, s, 2); // Coalesce. + p.s[si] = s; // Write. + } + } + else + { + // Determine and write sign. + if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y && signY >= minY) + { + uint32_t s = 0; + uint32_t signXbit = (1u << signXo); + if (v < 0.f) + { + s = signXbit; + v *= p.slope; + } + if (fabsf(v) > p.clamp) + { + s = signXbit * 2; + v = InternalType::clamp(v, p.clamp); + } + s += __shfl_xor_sync(groupMask, s, 1); // Coalesce. + s += __shfl_xor_sync(groupMask, s, 2); // Coalesce. + p.s[si] = s; // Write. + } + else + { + // Just compute the value. + if (v < 0.f) v *= p.slope; + v = InternalType::clamp(v, p.clamp); + } + } + } + else if (signRead) + { + // Read sign and apply if within sign tensor bounds. + if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y) + { + int s = p.s[si]; + s >>= signXo; + if (s & 1) v *= p.slope; + if (s & 2) v = 0.f; + } + } + else // Forward pass with no sign write. + { + if (v < 0.f) v *= p.slope; + v = InternalType::clamp(v, p.clamp); + } + + if (!downInline) // Write into temporary buffer. + s_tileUpXY[idx] = v; + else if ((uint32_t)x < p.yShape.x && (uint32_t)y < p.yShape.y) // Write directly into output buffer + *((T*)((char*)p.y + (x * get_stride(p.yStride.x) + y * get_stride(p.yStride.y) + mapOfsOut))) = (T)(v * (scalar_t)c_fd[0]); + } + } + } + + // Downsampling. + if (filterMode == MODE_SUSD || filterMode == MODE_FUSD) + { + // Horizontal downsampling. + __syncthreads(); + if (down == 4 && tileOutW % 4 == 0) + { + // Calculate 4 pixels at a time. + for (int idx = threadIdx.x * 4; idx < tileOutW * tileUpH; idx += blockDim.x * 4) + { + int relOutX0, relUpY; + fast_div_mod(relOutX0, relUpY, idx); + int relUpX0 = relOutX0 * down; + int src0 = relUpY * tileUpW + relUpX0; + vec4_t v = InternalType::zero_vec4(); + #pragma unroll + for (int step = 0; step < fdSize; step++) + { + v.x += s_tileUpXY[src0 + 0 + step] * (scalar_t)c_fd[step]; + v.y += s_tileUpXY[src0 + 4 + step] * (scalar_t)c_fd[step]; + v.z += s_tileUpXY[src0 + 8 + step] * (scalar_t)c_fd[step]; + v.w += s_tileUpXY[src0 + 12 + step] * (scalar_t)c_fd[step]; + } + s_tileDownX[idx+0] = v.x; + s_tileDownX[idx+1] = v.y; + s_tileDownX[idx+2] = v.z; + s_tileDownX[idx+3] = v.w; + } + } + else if ((down == 2 || down == 4) && (tileOutW % 2 == 0)) + { + // Calculate 2 pixels at a time. 
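+        // Two output pixels are accumulated per thread below; their filter
+        // windows in the upsampled buffer start `down` samples apart
+        // (src0 and src0 + down), giving each thread two independent dot
+        // products and better instruction-level parallelism.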
+ for (int idx = threadIdx.x * 2; idx < tileOutW * tileUpH; idx += blockDim.x * 2) + { + int relOutX0, relUpY; + fast_div_mod(relOutX0, relUpY, idx); + int relUpX0 = relOutX0 * down; + int src0 = relUpY * tileUpW + relUpX0; + vec2_t v = InternalType::zero_vec2(); + #pragma unroll + for (int step = 0; step < fdSize; step++) + { + v.x += s_tileUpXY[src0 + 0 + step] * (scalar_t)c_fd[step]; + v.y += s_tileUpXY[src0 + down + step] * (scalar_t)c_fd[step]; + } + s_tileDownX[idx+0] = v.x; + s_tileDownX[idx+1] = v.y; + } + } + else + { + // Calculate 1 pixel at a time. + for (int idx = threadIdx.x; idx < tileOutW * tileUpH; idx += blockDim.x) + { + int relOutX0, relUpY; + fast_div_mod(relOutX0, relUpY, idx); + int relUpX0 = relOutX0 * down; + int src = relUpY * tileUpW + relUpX0; + scalar_t v = 0.f; + #pragma unroll + for (int step = 0; step < fdSize; step++) + v += s_tileUpXY[src + step] * (scalar_t)c_fd[step]; + s_tileDownX[idx] = v; + } + } + + // Vertical downsampling & store output tile. + __syncthreads(); + for (int idx = threadIdx.x; idx < tileOutW * tileOutH; idx += blockDim.x) + { + int relOutX, relOutY0; + fast_div_mod(relOutX, relOutY0, idx); + int relUpY0 = relOutY0 * down; + int src0 = relUpY0 * tileOutW + relOutX; + scalar_t v = 0; + #pragma unroll + for (int step = 0; step < fdSize; step++) + v += s_tileDownX[src0 + step * tileOutW] * (scalar_t)c_fd[step]; + + int outX = tileOutX + relOutX; + int outY = tileOutY + relOutY0; + + if (outX < p.yShape.x & outY < p.yShape.y) + *((T*)((char*)p.y + (outX * get_stride(p.yStride.x) + outY * get_stride(p.yStride.y) + mapOfsOut))) = (T)v; + } + } + else if (filterMode == MODE_SUFD || filterMode == MODE_FUFD) + { + // Full downsampling filter. + if (down == 2) + { + // 2-wide. + __syncthreads(); + for (int idx = threadIdx.x * 2; idx < tileOutW * tileOutH; idx += blockDim.x * 2) + { + int relOutX0, relOutY0; + fast_div_mod(relOutX0, relOutY0, idx); + int relUpX0 = relOutX0 * down; + int relUpY0 = relOutY0 * down; + int src0 = relUpY0 * tileUpW + relUpX0; + vec2_t v = InternalType::zero_vec2(); + #pragma unroll + for (int sy = 0; sy < fdSize; sy++) + #pragma unroll + for (int sx = 0; sx < fdSize; sx++) + { + v.x += s_tileUpXY[src0 + 0 + sx + sy * tileUpW] * (scalar_t)c_fd[sx + sy * MAX_FILTER_SIZE]; + v.y += s_tileUpXY[src0 + 2 + sx + sy * tileUpW] * (scalar_t)c_fd[sx + sy * MAX_FILTER_SIZE]; + } + + int outX = tileOutX + relOutX0; + int outY = tileOutY + relOutY0; + if ((uint32_t)outY < p.yShape.y) + { + index_t ofs = outX * get_stride(p.yStride.x) + outY * get_stride(p.yStride.y) + mapOfsOut; + if (outX + 0 < p.yShape.x) *((T*)((char*)p.y + ofs)) = (T)v.x; + if (outX + 1 < p.yShape.x) *((T*)((char*)p.y + ofs + get_stride(p.yStride.x))) = (T)v.y; + } + } + } + else if (down == 1 && !downInline) + { + // Thread per pixel. + __syncthreads(); + for (int idx = threadIdx.x; idx < tileOutW * tileOutH; idx += blockDim.x) + { + int relOutX0, relOutY0; + fast_div_mod(relOutX0, relOutY0, idx); + scalar_t v = s_tileUpXY[idx] * (scalar_t)c_fd[0]; // 1x1 filter. + + int outX = tileOutX + relOutX0; + int outY = tileOutY + relOutY0; + if ((uint32_t)outX < p.yShape.x && (uint32_t)outY < p.yShape.y) + *((T*)((char*)p.y + (outX * get_stride(p.yStride.x) + outY * get_stride(p.yStride.y) + mapOfsOut))) = (T)v; + } + } + } + + if (!enableXrep) + break; + } +} + +//------------------------------------------------------------------------ +// Compute activation function and signs for upsampled data tensor, modifying data tensor in-place. 
Used for accelerating the generic variant.
+// Sign tensor is known to be contiguous, and p.x and p.s have the same z, w dimensions. 64-bit indexing is always used.
+
+template <class T, bool signWrite, bool signRead>
+static __global__ void filtered_lrelu_act_kernel(filtered_lrelu_act_kernel_params p)
+{
+    typedef typename InternalType<T>::scalar_t scalar_t;
+
+    // Indexing.
+    int32_t x = threadIdx.x + blockIdx.x * blockDim.x;
+    int32_t ymax = signWrite ? p.sShape.y : p.xShape.y;
+    int32_t qmax = p.xShape.z * p.xShape.w; // Combined minibatch*channel maximum index.
+
+    // Loop to accommodate oversized tensors.
+    for (int32_t q = blockIdx.z; q < qmax; q += gridDim.z)
+    for (int32_t y = blockIdx.y; y < ymax; y += gridDim.y)
+    {
+        // Extract z and w (channel, minibatch index).
+        int32_t w = q / p.xShape.z;
+        int32_t z = q - w * p.xShape.z;
+
+        // Choose behavior based on sign read/write mode.
+        if (signWrite)
+        {
+            // Process value if in p.x.
+            uint32_t s = 0;
+            if (x < p.xShape.x && y < p.xShape.y)
+            {
+                int64_t ix = x * p.xStride.x + y * p.xStride.y + z * p.xStride.z + w * p.xStride.w;
+                T* pv = ((T*)p.x) + ix;
+                scalar_t v = (scalar_t)(*pv);
+
+                // Gain, LReLU, clamp.
+                v *= p.gain;
+                if (v < 0.f)
+                {
+                    v *= p.slope;
+                    s = 1; // Sign.
+                }
+                if (fabsf(v) > p.clamp)
+                {
+                    v = InternalType<T>::clamp(v, p.clamp);
+                    s = 2; // Clamp.
+                }
+
+                *pv = (T)v; // Write value.
+            }
+
+            // Coalesce into threads 0 and 16 of warp.
+            uint32_t m = (threadIdx.x & 16) ? 0xffff0000u : 0x0000ffffu;
+            s <<= ((threadIdx.x & 15) << 1); // Shift into place.
+            s |= __shfl_xor_sync(m, s, 1); // Distribute.
+            s |= __shfl_xor_sync(m, s, 2);
+            s |= __shfl_xor_sync(m, s, 4);
+            s |= __shfl_xor_sync(m, s, 8);
+
+            // Write signs if leader and in p.s.
+            if (!(threadIdx.x & 15) && x < p.sShape.x) // y is always in.
+            {
+                uint64_t is = x + p.sShape.x * (y + (int64_t)p.sShape.y * q); // Contiguous.
+                ((uint32_t*)p.s)[is >> 4] = s;
+            }
+        }
+        else if (signRead)
+        {
+            // Process value if in p.x.
+            if (x < p.xShape.x) // y is always in.
+            {
+                int64_t ix = x * p.xStride.x + y * p.xStride.y + z * p.xStride.z + w * p.xStride.w;
+                T* pv = ((T*)p.x) + ix;
+                scalar_t v = (scalar_t)(*pv);
+                v *= p.gain;
+
+                // Apply sign buffer offset.
+                uint32_t sx = x + p.sOfs.x;
+                uint32_t sy = y + p.sOfs.y;
+
+                // Read and apply signs if we land inside valid region of sign buffer.
+                if (sx < p.sShape.x && sy < p.sShape.y)
+                {
+                    uint64_t is = (sx >> 2) + (p.sShape.x >> 2) * (sy + (uint64_t)p.sShape.y * q); // Contiguous.
+                    unsigned char s = p.s[is];
+                    s >>= (sx & 3) << 1; // Shift into place.
+                    if (s & 1) // Sign?
+                        v *= p.slope;
+                    if (s & 2) // Clamp?
+                        v = 0.f;
+                }
+
+                *pv = (T)v; // Write value.
+            }
+        }
+        else
+        {
+            // Forward pass with no sign write. Process value if in p.x.
+            if (x < p.xShape.x) // y is always in.
+            {
+                int64_t ix = x * p.xStride.x + y * p.xStride.y + z * p.xStride.z + w * p.xStride.w;
+                T* pv = ((T*)p.x) + ix;
+                scalar_t v = (scalar_t)(*pv);
+                v *= p.gain;
+                if (v < 0.f)
+                    v *= p.slope;
+                if (fabsf(v) > p.clamp)
+                    v = InternalType<T>::clamp(v, p.clamp);
+                *pv = (T)v; // Write value.
+            }
+        }
+    }
+}
+
+template <class T, bool signWrite, bool signRead> void* choose_filtered_lrelu_act_kernel(void)
+{
+    return (void*)filtered_lrelu_act_kernel<T, signWrite, signRead>;
+}
+
+//------------------------------------------------------------------------
+// CUDA kernel selection.
+
+template <class T, class index_t, bool signWrite, bool signRead> filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB)
+{
+    filtered_lrelu_kernel_spec s = { 0 };
+
+    // Return the first matching kernel.
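+    // How the returned spec is consumed (an assumption based on the fields
+    // above; the host launcher lives outside this file): s.exec is presumably
+    // launched with 32 * s.numWarps threads per block and
+    // s.dynamicSharedKB << 10 bytes of dynamic shared memory, after running
+    // s.setup once to copy the FIR filters into constant memory.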
+    #define CASE(SH, U, FU, D, FD, MODE, TW, TH, W, XR, WS) \
+        if (sharedKB >= SH) \
+        if ((p.fuShape.y == 0 && (MODE == MODE_SUSD || MODE == MODE_SUFD)) || (p.fuShape.y > 0 && (MODE == MODE_FUSD || MODE == MODE_FUFD))) \
+        if ((p.fdShape.y == 0 && (MODE == MODE_SUSD || MODE == MODE_FUSD)) || (p.fdShape.y > 0 && (MODE == MODE_SUFD || MODE == MODE_FUFD))) \
+        if (p.up == U && p.fuShape.x <= FU && p.fuShape.y <= FU && p.down == D && p.fdShape.x <= FD && p.fdShape.y <= FD) \
+        { \
+            static_assert((D*TW % 4) == 0, "down * tileWidth must be divisible by 4"); \
+            static_assert(FU % U == 0, "upscaling filter size must be multiple of upscaling factor"); \
+            static_assert(FD % D == 0, "downscaling filter size must be multiple of downscaling factor"); \
+            s.setup = (void*)setup_filters_kernel; \
+            s.exec = (void*)filtered_lrelu_kernel<T, index_t, SH, signWrite, signRead, MODE, U, FU, D, FD, TW, TH, W*32, !!XR, !!WS>; \
+            s.tileOut = make_int2(TW, TH); \
+            s.numWarps = W; \
+            s.xrep = XR; \
+            s.dynamicSharedKB = (SH == 48) ? 0 : SH; \
+            return s; \
+        }
+
+    // Launch parameters for various kernel specializations.
+    // Small filters must be listed before large filters, otherwise the kernel for larger filter will always match first.
+    // Kernels that use more shared memory must be listed before those that use less, for the same reason.
+
+    CASE(/*sharedKB*/48, /*up,fu*/1,1,  /*down,fd*/1,1,  /*mode*/MODE_FUFD, /*tw,th,warps,xrep,wskip*/64,  178, 32,  0, 0) // 1t-upf1-downf1
+    CASE(/*sharedKB*/48, /*up,fu*/2,8,  /*down,fd*/1,1,  /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/152, 95,  16,  0, 0) // 4t-ups2-downf1
+    CASE(/*sharedKB*/48, /*up,fu*/1,1,  /*down,fd*/2,8,  /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/56,  22,  16,  0, 0) // 4t-upf1-downs2
+    CASE(/*sharedKB*/48, /*up,fu*/2,8,  /*down,fd*/2,8,  /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/56,  29,  16, 11, 0) // 4t-ups2-downs2
+    CASE(/*sharedKB*/48, /*up,fu*/2,8,  /*down,fd*/2,8,  /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/60,  28,  16,  0, 0) // 4t-upf2-downs2
+    CASE(/*sharedKB*/48, /*up,fu*/2,8,  /*down,fd*/2,8,  /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/56,  28,  16,  0, 0) // 4t-ups2-downf2
+    CASE(/*sharedKB*/48, /*up,fu*/4,16, /*down,fd*/2,8,  /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/56,  31,  16, 11, 0) // 4t-ups4-downs2
+    CASE(/*sharedKB*/48, /*up,fu*/4,16, /*down,fd*/2,8,  /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/56,  36,  16,  0, 0) // 4t-ups4-downf2
+    CASE(/*sharedKB*/48, /*up,fu*/2,8,  /*down,fd*/4,16, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/16,  22,  16, 12, 0) // 4t-ups2-downs4
+    CASE(/*sharedKB*/48, /*up,fu*/2,8,  /*down,fd*/4,16, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/29,  15,  16,  0, 0) // 4t-upf2-downs4
+    CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/1,1,  /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/96,  150, 28,  0, 0) // 6t-ups2-downf1
+    CASE(/*sharedKB*/48, /*up,fu*/1,1,  /*down,fd*/2,12, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/32,  35,  24,  0, 0) // 6t-upf1-downs2
+    CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/2,12, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/32,  46,  16, 10, 0) // 6t-ups2-downs2
+    CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/2,12, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/58,  28,  24,  8, 0) // 6t-upf2-downs2
+    CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/2,12, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/52,  28,  16,  0, 0) // 6t-ups2-downf2
+    CASE(/*sharedKB*/48, /*up,fu*/4,24, /*down,fd*/2,12, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/32,  51,  16,  5, 0) // 6t-ups4-downs2
+    CASE(/*sharedKB*/48, /*up,fu*/4,24, /*down,fd*/2,12, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/32,  56,  16,  6, 0) // 6t-ups4-downf2
+    CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/4,24, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/16,  18,  16, 12, 0) // 6t-ups2-downs4
+    CASE(/*sharedKB*/96, /*up,fu*/2,12, /*down,fd*/4,24, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/27,  31,  32,  6, 0) // 6t-upf2-downs4 96kB
+    CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/4,24, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/27,  13,  24,  0, 0) // 6t-upf2-downs4
+    CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/1,1,  /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/148, 89,  24,  0, 0) // 8t-ups2-downf1
+    CASE(/*sharedKB*/48, /*up,fu*/1,1,  /*down,fd*/2,16, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/32,  31,  16,  5, 0) // 8t-upf1-downs2
+    CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/2,16, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/32,  41,  16,  9, 0) // 8t-ups2-downs2
+    CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/2,16, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/56,  26,  24,  0, 0) // 8t-upf2-downs2
+    CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/2,16, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/32,  40,  16,  0, 0) // 8t-ups2-downf2
+    CASE(/*sharedKB*/48, /*up,fu*/4,32, /*down,fd*/2,16, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/32,  46,  24,  5, 0) // 8t-ups4-downs2
+    CASE(/*sharedKB*/48, /*up,fu*/4,32, /*down,fd*/2,16, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/32,  50,  16,  0, 0) // 8t-ups4-downf2
+    CASE(/*sharedKB*/96, /*up,fu*/2,16, /*down,fd*/4,32, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/24,  24,  32, 12, 1) // 8t-ups2-downs4 96kB
+    CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/4,32, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/16,  13,  16, 10, 1) // 8t-ups2-downs4
+    CASE(/*sharedKB*/96, /*up,fu*/2,16, /*down,fd*/4,32, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/25,  28,  28,  4, 0) // 8t-upf2-downs4 96kB
+    CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/4,32, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/25,  10,  24,  0, 0) // 8t-upf2-downs4
+
+    #undef CASE
+    return s; // No kernel found.
+}
+
+//------------------------------------------------------------------------
diff --git a/third_party/stylegan3_official_ops/filtered_lrelu.h b/third_party/stylegan3_official_ops/filtered_lrelu.h
new file mode 100644
index 0000000000000000000000000000000000000000..2c403e3f275f472315662321cad54dd0dbc56d00
--- /dev/null
+++ b/third_party/stylegan3_official_ops/filtered_lrelu.h
@@ -0,0 +1,90 @@
+// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+//
+// NVIDIA CORPORATION and its licensors retain all intellectual property
+// and proprietary rights in and to this software, related documentation
+// and any modifications thereto. Any use, reproduction, disclosure or
+// distribution of this software and related documentation without an express
+// license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+#include <cuda_runtime.h>
+
+//------------------------------------------------------------------------
+// CUDA kernel parameters.
+
+struct filtered_lrelu_kernel_params
+{
+    // These parameters decide which kernel to use.
+    int             up;         // upsampling ratio (1, 2, 4)
+    int             down;       // downsampling ratio (1, 2, 4)
+    int2            fuShape;    // [size, 1] | [size, size]
+    int2            fdShape;    // [size, 1] | [size, size]
+
+    int             _dummy;     // Alignment.
+
+    // Rest of the parameters.
+    const void*     x;          // Input tensor.
+    void*           y;          // Output tensor.
+    const void*     b;          // Bias tensor.
+    unsigned char*  s;          // Sign tensor in/out. NULL if unused.
+    const float*    fu;         // Upsampling filter.
+    const float*    fd;         // Downsampling filter.
+
+    int2            pad0;       // Left/top padding.
+    float           gain;       // Additional gain factor.
+    float           slope;      // Leaky ReLU slope on negative side.
+    float           clamp;      // Clamp after nonlinearity.
+    int             flip;       // Filter kernel flip for gradient computation.
+
+    int             tilesXdim;  // Original number of horizontal output tiles.
+    int             tilesXrep;  // Number of horizontal tiles per CTA.
+    int             blockZofs;  // Block z offset to support large minibatch, channel dimensions.
+
+    int4            xShape;     // [width, height, channel, batch]
+    int4            yShape;     // [width, height, channel, batch]
+    int2            sShape;     // [width, height] - width is in bytes. Contiguous. Zeros if unused.
+    int2            sOfs;       // [ofs_x, ofs_y] - offset between upsampled data and sign tensor.
+    int             swLimit;    // Active width of sign tensor in bytes.
+
+    longlong4       xStride;    // Strides of all tensors except signs, same component order as shapes.
+    longlong4       yStride;    //
+    int64_t         bStride;    //
+    longlong3       fuStride;   //
+    longlong3       fdStride;   //
+};
+
+struct filtered_lrelu_act_kernel_params
+{
+    void*           x;          // Input/output, modified in-place.
+    unsigned char*  s;          // Sign tensor in/out. NULL if unused.
+
+    float           gain;       // Additional gain factor.
+    float           slope;      // Leaky ReLU slope on negative side.
+    float           clamp;      // Clamp after nonlinearity.
+
+    int4            xShape;     // [width, height, channel, batch]
+    longlong4       xStride;    // Input/output tensor strides, same order as in shape.
+    int2            sShape;     // [width, height] - width is in elements. Contiguous. Zeros if unused.
+    int2            sOfs;       // [ofs_x, ofs_y] - offset between upsampled data and sign tensor.
+};
+
+//------------------------------------------------------------------------
+// CUDA kernel specialization.
+
+struct filtered_lrelu_kernel_spec
+{
+    void*   setup;              // Function for filter kernel setup.
+    void*   exec;               // Function for main operation.
+    int2    tileOut;            // Width/height of launch tile.
+    int     numWarps;           // Number of warps per thread block, determines launch block size.
+    int     xrep;               // For processing multiple horizontal tiles per thread block.
+    int     dynamicSharedKB;    // How much dynamic shared memory the exec kernel wants.
+};
+
+//------------------------------------------------------------------------
+// CUDA kernel selection.
+
+template <class T, class index_t, bool signWrite, bool signRead> filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB);
+template <class T, bool signWrite, bool signRead> void* choose_filtered_lrelu_act_kernel(void);
+template <bool signWrite, bool signRead> cudaError_t copy_filters(cudaStream_t stream);
+
+//------------------------------------------------------------------------
diff --git a/third_party/stylegan3_official_ops/filtered_lrelu.py b/third_party/stylegan3_official_ops/filtered_lrelu.py
new file mode 100644
index 0000000000000000000000000000000000000000..ec924b630622f9e945baa2d3c674cf158b524005
--- /dev/null
+++ b/third_party/stylegan3_official_ops/filtered_lrelu.py
@@ -0,0 +1,297 @@
+# python3.7
+
+# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Custom operators for Leaky ReLU, wrapped with upsampling and downsampling.
+
+Leaky ReLU introduces high-frequency components into the source feature map.
+To suppress the resulting aliasing, the Leaky ReLU operator is wrapped with an
+upsampling layer before it and a low-pass downsampling layer after it.
+
+Please refer to https://github.com/NVlabs/stylegan3
+"""
+
+# pylint: disable=line-too-long
+# pylint: disable=missing-class-docstring
+# pylint: disable=global-statement
+# pylint: disable=multiple-statements
+# pylint: disable=inconsistent-quotes
+
+import os
+import warnings
+import numpy as np
+import torch
+
+from . import custom_ops
+from . import misc
+from . import upfirdn2d
+from . import bias_act
+
+#----------------------------------------------------------------------------
+
+_plugin = None
+
+def _init():
+    global _plugin
+    if _plugin is None:
+        _plugin = custom_ops.get_plugin(
+            module_name='filtered_lrelu_plugin',
+            sources=['filtered_lrelu.cpp', 'filtered_lrelu_wr.cu', 'filtered_lrelu_rd.cu', 'filtered_lrelu_ns.cu'],
+            headers=['filtered_lrelu.h', 'filtered_lrelu.cu'],
+            source_dir=os.path.dirname(__file__),
+            extra_cuda_cflags=['--use_fast_math'],
+        )
+    return True
+
+def _get_filter_size(f):
+    if f is None:
+        return 1, 1
+    assert isinstance(f, torch.Tensor)
+    assert 1 <= f.ndim <= 2
+    return f.shape[-1], f.shape[0]  # width, height
+
+def _parse_padding(padding):
+    if isinstance(padding, int):
+        padding = [padding, padding]
+    assert isinstance(padding, (list, tuple))
+    assert all(isinstance(x, (int, np.integer)) for x in padding)
+    padding = [int(x) for x in padding]
+    if len(padding) == 2:
+        px, py = padding
+        padding = [px, px, py, py]
+    px0, px1, py0, py1 = padding
+    return px0, px1, py0, py1
+
+#----------------------------------------------------------------------------
+
+def filtered_lrelu(x, fu=None, fd=None, b=None, up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False, impl='cuda'):
+    r"""Filtered leaky ReLU for a batch of 2D images.
+
+    Performs the following sequence of operations for each channel:
+
+    1. Add channel-specific bias if provided (`b`).
+
+    2. Upsample the image by inserting N-1 zeros after each pixel (`up`).
+
+    3. Pad the image with the specified number of zeros on each side (`padding`).
+       Negative padding corresponds to cropping the image.
+
+    4. Convolve the image with the specified upsampling FIR filter (`fu`), shrinking it
+       so that the footprint of all output pixels lies within the input image.
+
+    5. Multiply each value by the provided gain factor (`gain`).
+
+    6. Apply leaky ReLU activation function to each value.
+
+    7. Clamp each value between -clamp and +clamp, if `clamp` parameter is provided.
+
+    8. Convolve the image with the specified downsampling FIR filter (`fd`), shrinking
+       it so that the footprint of all output pixels lies within the input image.
+
+    9. Downsample the image by keeping every Nth pixel (`down`).
+
+    The fused op is considerably more efficient than performing the same calculation
+    using standard PyTorch ops. It supports gradients of arbitrary order.
+
+    Args:
+        x:           Float32/float16/float64 input tensor of the shape
+                     `[batch_size, num_channels, in_height, in_width]`.
+        fu:          Float32 upsampling FIR filter of the shape
+                     `[filter_height, filter_width]` (non-separable),
+                     `[filter_taps]` (separable), or
+                     `None` (identity).
+        fd:          Float32 downsampling FIR filter of the shape
+                     `[filter_height, filter_width]` (non-separable),
+                     `[filter_taps]` (separable), or
+                     `None` (identity).
+        b:           Bias vector, or `None` to disable. Must be a 1D tensor of the same type
+                     as `x`. The length of the vector must match the channel dimension of `x`.
+        up:          Integer upsampling factor (default: 1).
+        down:        Integer downsampling factor (default: 1).
+        padding:     Padding with respect to the upsampled image. Can be a single number
+                     or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
+                     (default: 0).
+        gain:        Overall scaling factor for signal magnitude (default: sqrt(2)).
+        slope:       Slope on the negative side of leaky ReLU (default: 0.2).
+        clamp:       Maximum magnitude for leaky ReLU output (default: None).
+        flip_filter: False = convolution, True = correlation (default: False).
+        impl:        Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
+
+    Returns:
+        Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
+    """
+    assert isinstance(x, torch.Tensor)
+    assert impl in ['ref', 'cuda']
+    if impl == 'cuda' and x.device.type == 'cuda' and _init():
+        return _filtered_lrelu_cuda(up=up, down=down, padding=padding, gain=gain, slope=slope, clamp=clamp, flip_filter=flip_filter).apply(x, fu, fd, b, None, 0, 0)
+    return _filtered_lrelu_ref(x, fu=fu, fd=fd, b=b, up=up, down=down, padding=padding, gain=gain, slope=slope, clamp=clamp, flip_filter=flip_filter)
+
+#----------------------------------------------------------------------------
+
+@misc.profiled_function
+def _filtered_lrelu_ref(x, fu=None, fd=None, b=None, up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False):
+    """Slow and memory-inefficient reference implementation of `filtered_lrelu()` using
+    existing `upfirdn2d()` and `bias_act()` ops.
+    """
+    assert isinstance(x, torch.Tensor) and x.ndim == 4
+    fu_w, fu_h = _get_filter_size(fu)
+    fd_w, fd_h = _get_filter_size(fd)
+    if b is not None:
+        assert isinstance(b, torch.Tensor) and b.dtype == x.dtype
+        misc.assert_shape(b, [x.shape[1]])
+    assert isinstance(up, int) and up >= 1
+    assert isinstance(down, int) and down >= 1
+    px0, px1, py0, py1 = _parse_padding(padding)
+    assert gain == float(gain) and gain > 0
+    assert slope == float(slope) and slope >= 0
+    assert clamp is None or (clamp == float(clamp) and clamp >= 0)
+
+    # Calculate output size.
+    batch_size, channels, in_h, in_w = x.shape
+    in_dtype = x.dtype
+    out_w = (in_w * up + (px0 + px1) - (fu_w - 1) - (fd_w - 1) + (down - 1)) // down
+    out_h = (in_h * up + (py0 + py1) - (fu_h - 1) - (fd_h - 1) + (down - 1)) // down
+
+    # Compute using existing ops.
+    x = bias_act.bias_act(x=x, b=b, impl='ref')  # Apply bias.
+    x = upfirdn2d.upfirdn2d(x=x, f=fu, up=up, padding=[px0, px1, py0, py1], gain=up**2, flip_filter=flip_filter, impl='ref')  # Upsample.
+    x = bias_act.bias_act(x=x, act='lrelu', alpha=slope, gain=gain, clamp=clamp, impl='ref')  # Bias, leaky ReLU, clamp.
+    x = upfirdn2d.upfirdn2d(x=x, f=fd, down=down, flip_filter=flip_filter, impl='ref')  # Downsample.
+
+    # Check output shape & dtype.
+    misc.assert_shape(x, [batch_size, channels, out_h, out_w])
+    assert x.dtype == in_dtype
+    return x
+
+#----------------------------------------------------------------------------
+
+_filtered_lrelu_cuda_cache = dict()
+
+def _filtered_lrelu_cuda(up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False):
+    """Fast CUDA implementation of `filtered_lrelu()` using custom ops.
+    """
+    assert isinstance(up, int) and up >= 1
+    assert isinstance(down, int) and down >= 1
+    px0, px1, py0, py1 = _parse_padding(padding)
+    assert gain == float(gain) and gain > 0
+    gain = float(gain)
+    assert slope == float(slope) and slope >= 0
+    slope = float(slope)
+    assert clamp is None or (clamp == float(clamp) and clamp >= 0)
+    clamp = float(clamp if clamp is not None else 'inf')
+
+    # Lookup from cache.
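+    # One autograd.Function subclass is created and memoized per parameter
+    # combination, so backward() can close over `up`, `down`, padding, etc.
+    # instead of stashing them in `ctx`. Illustrative use (matching the call
+    # in `filtered_lrelu()` above):
+    #   op = _filtered_lrelu_cuda(up=2, down=2, padding=1)
+    #   y = op.apply(x, fu, fd, b, None, 0, 0)  # same class on every lookup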
+ key = (up, down, px0, px1, py0, py1, gain, slope, clamp, flip_filter) + if key in _filtered_lrelu_cuda_cache: + return _filtered_lrelu_cuda_cache[key] + + # Forward op. + class FilteredLReluCuda(torch.autograd.Function): + @staticmethod + def forward(ctx, x, fu, fd, b, si, sx, sy): # pylint: disable=arguments-differ + assert isinstance(x, torch.Tensor) and x.ndim == 4 + + # Replace empty up/downsample kernels with full 1x1 kernels (faster than separable). + if fu is None: + fu = torch.ones([1, 1], dtype=torch.float32, device=x.device) + if fd is None: + fd = torch.ones([1, 1], dtype=torch.float32, device=x.device) + assert 1 <= fu.ndim <= 2 + assert 1 <= fd.ndim <= 2 + + # Replace separable 1x1 kernels with full 1x1 kernels when scale factor is 1. + if up == 1 and fu.ndim == 1 and fu.shape[0] == 1: + fu = fu.square()[None] + if down == 1 and fd.ndim == 1 and fd.shape[0] == 1: + fd = fd.square()[None] + + # Missing sign input tensor. + if si is None: + si = torch.empty([0]) + + # Missing bias tensor. + if b is None: + b = torch.zeros([x.shape[1]], dtype=x.dtype, device=x.device) + + # Construct internal sign tensor only if gradients are needed. + write_signs = (si.numel() == 0) and (x.requires_grad or b.requires_grad) + + # Warn if input storage strides are not in decreasing order due to e.g. channels-last layout. + strides = [x.stride(i) for i in range(x.ndim) if x.size(i) > 1] + if any(a < b for a, b in zip(strides[:-1], strides[1:])): + warnings.warn("low-performance memory layout detected in filtered_lrelu input", RuntimeWarning) + + # Call C++/Cuda plugin if datatype is supported. + if x.dtype in [torch.float16, torch.float32]: + if torch.cuda.current_stream(x.device) != torch.cuda.default_stream(x.device): + warnings.warn("filtered_lrelu called with non-default cuda stream but concurrent execution is not supported", RuntimeWarning) + y, so, return_code = _plugin.filtered_lrelu(x, fu, fd, b, si, up, down, px0, px1, py0, py1, sx, sy, gain, slope, clamp, flip_filter, write_signs) + else: + return_code = -1 + + # No Cuda kernel found? Fall back to generic implementation. Still more memory efficient than the reference implementation because + # only the bit-packed sign tensor is retained for gradient computation. + if return_code < 0: + warnings.warn("filtered_lrelu called with parameters that have no optimized CUDA kernel, using generic fallback", RuntimeWarning) + + y = x.add(b.unsqueeze(-1).unsqueeze(-1)) # Add bias. + y = upfirdn2d.upfirdn2d(x=y, f=fu, up=up, padding=[px0, px1, py0, py1], gain=up**2, flip_filter=flip_filter) # Upsample. + so = _plugin.filtered_lrelu_act_(y, si, sx, sy, gain, slope, clamp, write_signs) # Activation function and sign handling. Modifies y in-place. + y = upfirdn2d.upfirdn2d(x=y, f=fd, down=down, flip_filter=flip_filter) # Downsample. + + # Prepare for gradient computation. 
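+            # Only the filters and the bit-packed sign tensor are kept for
+            # backward; the signs record the per-pixel lrelu/clamp decisions
+            # at 2 bits per upsampled pixel, so the full pre-activation
+            # tensor never needs to be retained.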
+            ctx.save_for_backward(fu, fd, (si if si.numel() else so))
+            ctx.x_shape = x.shape
+            ctx.y_shape = y.shape
+            ctx.s_ofs = sx, sy
+            return y
+
+        @staticmethod
+        def backward(ctx, dy): # pylint: disable=arguments-differ
+            fu, fd, si = ctx.saved_tensors
+            _, _, xh, xw = ctx.x_shape
+            _, _, yh, yw = ctx.y_shape
+            sx, sy = ctx.s_ofs
+            dx  = None # 0
+            dfu = None; assert not ctx.needs_input_grad[1]
+            dfd = None; assert not ctx.needs_input_grad[2]
+            db  = None # 3
+            dsi = None; assert not ctx.needs_input_grad[4]
+            dsx = None; assert not ctx.needs_input_grad[5]
+            dsy = None; assert not ctx.needs_input_grad[6]
+
+            if ctx.needs_input_grad[0] or ctx.needs_input_grad[3]:
+                pp = [
+                    (fu.shape[-1] - 1) + (fd.shape[-1] - 1) - px0,
+                    xw * up - yw * down + px0 - (up - 1),
+                    (fu.shape[0] - 1) + (fd.shape[0] - 1) - py0,
+                    xh * up - yh * down + py0 - (up - 1),
+                ]
+                gg = gain * (up ** 2) / (down ** 2)
+                ff = (not flip_filter)
+                sx = sx - (fu.shape[-1] - 1) + px0
+                sy = sy - (fu.shape[0] - 1) + py0
+                dx = _filtered_lrelu_cuda(up=down, down=up, padding=pp, gain=gg, slope=slope, clamp=None, flip_filter=ff).apply(dy, fd, fu, None, si, sx, sy)
+
+            if ctx.needs_input_grad[3]:
+                db = dx.sum([0, 2, 3])
+
+            return dx, dfu, dfd, db, dsi, dsx, dsy
+
+    # Add to cache.
+    _filtered_lrelu_cuda_cache[key] = FilteredLReluCuda
+    return FilteredLReluCuda
+
+#----------------------------------------------------------------------------
+
+# pylint: enable=line-too-long
+# pylint: enable=missing-class-docstring
+# pylint: enable=global-statement
+# pylint: enable=multiple-statements
+# pylint: enable=inconsistent-quotes
diff --git a/third_party/stylegan3_official_ops/filtered_lrelu_ns.cu b/third_party/stylegan3_official_ops/filtered_lrelu_ns.cu
new file mode 100644
index 0000000000000000000000000000000000000000..ef5d948c4fdf9cb0fe8a42f6268c61aeef6b2000
--- /dev/null
+++ b/third_party/stylegan3_official_ops/filtered_lrelu_ns.cu
@@ -0,0 +1,27 @@
+// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+//
+// NVIDIA CORPORATION and its licensors retain all intellectual property
+// and proprietary rights in and to this software, related documentation
+// and any modifications thereto. Any use, reproduction, disclosure or
+// distribution of this software and related documentation without an express
+// license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+#include "filtered_lrelu.cu"
+
+// Template/kernel specializations for no signs mode (no gradients required).
+
+// Full op, 32-bit indexing.
+template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int32_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
+template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float,     int32_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
+
+// Full op, 64-bit indexing.
+template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int64_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
+template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float,     int64_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
+
+// Activation/signs only for generic variant. 64-bit indexing.
+template void* choose_filtered_lrelu_act_kernel<c10::Half, false, false>(void);
+template void* choose_filtered_lrelu_act_kernel<float,     false, false>(void);
+template void* choose_filtered_lrelu_act_kernel<double,    false, false>(void);
+
+// Copy filters to constant memory.
+template cudaError_t copy_filters<false, false>(cudaStream_t stream);
diff --git a/third_party/stylegan3_official_ops/filtered_lrelu_rd.cu b/third_party/stylegan3_official_ops/filtered_lrelu_rd.cu
new file mode 100644
index 0000000000000000000000000000000000000000..968347882e9aebd36204f67e201cd16226dd9132
--- /dev/null
+++ b/third_party/stylegan3_official_ops/filtered_lrelu_rd.cu
@@ -0,0 +1,27 @@
+// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+//
+// NVIDIA CORPORATION and its licensors retain all intellectual property
+// and proprietary rights in and to this software, related documentation
+// and any modifications thereto. Any use, reproduction, disclosure or
+// distribution of this software and related documentation without an express
+// license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+#include "filtered_lrelu.cu"
+
+// Template/kernel specializations for sign read mode.
+
+// Full op, 32-bit indexing.
+template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int32_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
+template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float,     int32_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
+
+// Full op, 64-bit indexing.
+template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int64_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
+template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float,     int64_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
+
+// Activation/signs only for generic variant. 64-bit indexing.
+template void* choose_filtered_lrelu_act_kernel<c10::Half, false, true>(void);
+template void* choose_filtered_lrelu_act_kernel<float,     false, true>(void);
+template void* choose_filtered_lrelu_act_kernel<double,    false, true>(void);
+
+// Copy filters to constant memory.
+template cudaError_t copy_filters<false, true>(cudaStream_t stream);
diff --git a/third_party/stylegan3_official_ops/filtered_lrelu_wr.cu b/third_party/stylegan3_official_ops/filtered_lrelu_wr.cu
new file mode 100644
index 0000000000000000000000000000000000000000..a4c6a24aae908bc07248f7ff710cbd1a11a38bb1
--- /dev/null
+++ b/third_party/stylegan3_official_ops/filtered_lrelu_wr.cu
@@ -0,0 +1,27 @@
+// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+//
+// NVIDIA CORPORATION and its licensors retain all intellectual property
+// and proprietary rights in and to this software, related documentation
+// and any modifications thereto. Any use, reproduction, disclosure or
+// distribution of this software and related documentation without an express
+// license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+#include "filtered_lrelu.cu"
+
+// Template/kernel specializations for sign write mode.
+
+// Full op, 32-bit indexing.
+template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int32_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
+template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float,     int32_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
+
+// Full op, 64-bit indexing.
+template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int64_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
+template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float,     int64_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
+
+// Activation/signs only for generic variant. 64-bit indexing.
+template void* choose_filtered_lrelu_act_kernel<c10::Half, true, false>(void);
+template void* choose_filtered_lrelu_act_kernel<float,     true, false>(void);
+template void* choose_filtered_lrelu_act_kernel<double,    true, false>(void);
+
+// Copy filters to constant memory.
+template cudaError_t copy_filters<true, false>(cudaStream_t stream);
diff --git a/third_party/stylegan3_official_ops/fma.py b/third_party/stylegan3_official_ops/fma.py
new file mode 100644
index 0000000000000000000000000000000000000000..26195fdb5d4e0329703b7d6e5578f4d17ec57cde
--- /dev/null
+++ b/third_party/stylegan3_official_ops/fma.py
@@ -0,0 +1,73 @@
+# python3.7
+
+# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Fused multiply-add, with slightly faster gradients than `torch.addcmul()`.
+
+Please refer to https://github.com/NVlabs/stylegan3
+"""
+
+# pylint: disable=line-too-long
+# pylint: disable=missing-function-docstring
+
+import torch
+
+#----------------------------------------------------------------------------
+
+def fma(a, b, c, impl='cuda'): # => a * b + c
+    if impl == 'cuda':
+        return _FusedMultiplyAdd.apply(a, b, c)
+    return torch.addcmul(c, a, b)
+
+#----------------------------------------------------------------------------
+
+class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c
+    @staticmethod
+    def forward(ctx, a, b, c): # pylint: disable=arguments-differ
+        out = torch.addcmul(c, a, b)
+        ctx.save_for_backward(a, b)
+        ctx.c_shape = c.shape
+        return out
+
+    @staticmethod
+    def backward(ctx, dout): # pylint: disable=arguments-differ
+        a, b = ctx.saved_tensors
+        c_shape = ctx.c_shape
+        da = None
+        db = None
+        dc = None
+
+        if ctx.needs_input_grad[0]:
+            da = _unbroadcast(dout * b, a.shape)
+
+        if ctx.needs_input_grad[1]:
+            db = _unbroadcast(dout * a, b.shape)
+
+        if ctx.needs_input_grad[2]:
+            dc = _unbroadcast(dout, c_shape)
+
+        return da, db, dc
+
+#----------------------------------------------------------------------------
+
+def _unbroadcast(x, shape):
+    extra_dims = x.ndim - len(shape)
+    assert extra_dims >= 0
+    dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)]
+    if len(dim):
+        x = x.sum(dim=dim, keepdim=True)
+        if extra_dims:
+            x = x.reshape(-1, *x.shape[extra_dims+1:])
+    assert x.shape == shape
+    return x
+
+#----------------------------------------------------------------------------
+
+# pylint: enable=line-too-long
+# pylint: enable=missing-function-docstring
diff --git a/third_party/stylegan3_official_ops/grid_sample_gradfix.py b/third_party/stylegan3_official_ops/grid_sample_gradfix.py
new file mode 100644
index 0000000000000000000000000000000000000000..c3d9cd591a13e146eeeedddbef28871d7c3a0742
--- /dev/null
+++ b/third_party/stylegan3_official_ops/grid_sample_gradfix.py
@@ -0,0 +1,92 @@
+# python3.7
+
+# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Custom replacement for `torch.nn.functional.grid_sample`.
+
+This is useful for differentiable augmentation.
This customized operator +supports arbitrarily high order gradients between the input and output. Only +works on 2D images and assumes `mode=bilinear`, `padding_mode=zeros`, and +`align_corners=False`. + +Please refer to https://github.com/NVlabs/stylegan3 +""" + +# pylint: disable=redefined-builtin +# pylint: disable=arguments-differ +# pylint: disable=protected-access +# pylint: disable=line-too-long +# pylint: disable=missing-function-docstring + +import torch + +#---------------------------------------------------------------------------- + +enabled = True # Enable the custom op by setting this to true. + +#---------------------------------------------------------------------------- + +def grid_sample(input, grid, impl='cuda'): + if impl == 'cuda' and _should_use_custom_op(): + return _GridSample2dForward.apply(input, grid) + return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False) + +#---------------------------------------------------------------------------- + +def _should_use_custom_op(): + return enabled + +#---------------------------------------------------------------------------- + +class _GridSample2dForward(torch.autograd.Function): + @staticmethod + def forward(ctx, input, grid): + assert input.ndim == 4 + assert grid.ndim == 4 + output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False) + ctx.save_for_backward(input, grid) + return output + + @staticmethod + def backward(ctx, grad_output): + input, grid = ctx.saved_tensors + grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid) + return grad_input, grad_grid + +#---------------------------------------------------------------------------- + +class _GridSample2dBackward(torch.autograd.Function): + @staticmethod + def forward(ctx, grad_output, input, grid): + op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward') + grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False) + ctx.save_for_backward(grid) + return grad_input, grad_grid + + @staticmethod + def backward(ctx, grad2_grad_input, grad2_grad_grid): + _ = grad2_grad_grid # unused + grid, = ctx.saved_tensors + grad2_grad_output = None + grad2_input = None + grad2_grid = None + + if ctx.needs_input_grad[0]: + grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid) + + assert not ctx.needs_input_grad[2] + return grad2_grad_output, grad2_input, grad2_grid + +#---------------------------------------------------------------------------- + +# pylint: enable=redefined-builtin +# pylint: enable=arguments-differ +# pylint: enable=protected-access +# pylint: enable=line-too-long +# pylint: enable=missing-function-docstring diff --git a/third_party/stylegan3_official_ops/misc.py b/third_party/stylegan3_official_ops/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..1acfb7ea16904c07e362aeaae7337920d06fe5ca --- /dev/null +++ b/third_party/stylegan3_official_ops/misc.py @@ -0,0 +1,283 @@ +# python3.7 + +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. 
+ +"""Misc functions for customized operations. + +Please refer to https://github.com/NVlabs/stylegan3 +""" + +# pylint: disable=line-too-long +# pylint: disable=missing-class-docstring +# pylint: disable=missing-function-docstring +# pylint: disable=use-maxsplit-arg + +import re +import contextlib +import warnings +from easydict import EasyDict +import numpy as np +import torch + +#---------------------------------------------------------------------------- +# Cached construction of constant tensors. Avoids CPU=>GPU copy when the +# same constant is used multiple times. + +_constant_cache = dict() + +def constant(value, shape=None, dtype=None, device=None, memory_format=None): + value = np.asarray(value) + if shape is not None: + shape = tuple(shape) + if dtype is None: + dtype = torch.get_default_dtype() + if device is None: + device = torch.device('cpu') + if memory_format is None: + memory_format = torch.contiguous_format + + key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format) + tensor = _constant_cache.get(key, None) + if tensor is None: + tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device) + if shape is not None: + tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape)) + tensor = tensor.contiguous(memory_format=memory_format) + _constant_cache[key] = tensor + return tensor + +#---------------------------------------------------------------------------- +# Replace NaN/Inf with specified numerical values. + +try: + nan_to_num = torch.nan_to_num # 1.8.0a0 +except AttributeError: + def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin + assert isinstance(input, torch.Tensor) + if posinf is None: + posinf = torch.finfo(input.dtype).max + if neginf is None: + neginf = torch.finfo(input.dtype).min + assert nan == 0 + return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out) + +#---------------------------------------------------------------------------- +# Symbolic assert. + +try: + symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access +except AttributeError: + symbolic_assert = torch.Assert # 1.7.0 + +#---------------------------------------------------------------------------- +# Context manager to temporarily suppress known warnings in torch.jit.trace(). +# Note: Cannot use catch_warnings because of https://bugs.python.org/issue29672 + +@contextlib.contextmanager +def suppress_tracer_warnings(): + flt = ('ignore', None, torch.jit.TracerWarning, None, 0) + warnings.filters.insert(0, flt) + yield + warnings.filters.remove(flt) + +#---------------------------------------------------------------------------- +# Assert that the shape of a tensor matches the given list of integers. +# None indicates that the size of a dimension is allowed to vary. +# Performs symbolic assertion when used in torch.jit.trace(). 
+ +def assert_shape(tensor, ref_shape): + if tensor.ndim != len(ref_shape): + raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}') + for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)): + if ref_size is None: + pass + elif isinstance(ref_size, torch.Tensor): + with suppress_tracer_warnings(): # as_tensor results are registered as constants + symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}') + elif isinstance(size, torch.Tensor): + with suppress_tracer_warnings(): # as_tensor results are registered as constants + symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}') + elif size != ref_size: + raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}') + +#---------------------------------------------------------------------------- +# Function decorator that calls torch.autograd.profiler.record_function(). + +def profiled_function(fn): + def decorator(*args, **kwargs): + with torch.autograd.profiler.record_function(fn.__name__): + return fn(*args, **kwargs) + decorator.__name__ = fn.__name__ + return decorator + +#---------------------------------------------------------------------------- +# Sampler for torch.utils.data.DataLoader that loops over the dataset +# indefinitely, shuffling items as it goes. + +class InfiniteSampler(torch.utils.data.Sampler): + def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5): + assert len(dataset) > 0 + assert num_replicas > 0 + assert 0 <= rank < num_replicas + assert 0 <= window_size <= 1 + super().__init__(dataset) + self.dataset = dataset + self.rank = rank + self.num_replicas = num_replicas + self.shuffle = shuffle + self.seed = seed + self.window_size = window_size + + def __iter__(self): + order = np.arange(len(self.dataset)) + rnd = None + window = 0 + if self.shuffle: + rnd = np.random.RandomState(self.seed) + rnd.shuffle(order) + window = int(np.rint(order.size * self.window_size)) + + idx = 0 + while True: + i = idx % order.size + if idx % self.num_replicas == self.rank: + yield order[i] + if window >= 2: + j = (i - rnd.randint(window)) % order.size + order[i], order[j] = order[j], order[i] + idx += 1 + +#---------------------------------------------------------------------------- +# Utilities for operating with torch.nn.Module parameters and buffers. + +def params_and_buffers(module): + assert isinstance(module, torch.nn.Module) + return list(module.parameters()) + list(module.buffers()) + +def named_params_and_buffers(module): + assert isinstance(module, torch.nn.Module) + return list(module.named_parameters()) + list(module.named_buffers()) + +def copy_params_and_buffers(src_module, dst_module, require_all=False): + assert isinstance(src_module, torch.nn.Module) + assert isinstance(dst_module, torch.nn.Module) + src_tensors = dict(named_params_and_buffers(src_module)) + for name, tensor in named_params_and_buffers(dst_module): + assert (name in src_tensors) or (not require_all) + if name in src_tensors: + tensor.copy_(src_tensors[name].detach()).requires_grad_(tensor.requires_grad) + +#---------------------------------------------------------------------------- +# Context manager for easily enabling/disabling DistributedDataParallel +# synchronization. 
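+# Example (illustrative): accumulate gradients locally and synchronize only
+# on the last micro-batch of an accumulation loop:
+#     with ddp_sync(ddp_module, sync=(micro_idx == num_micro - 1)):
+#         loss.backward()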
+ +@contextlib.contextmanager +def ddp_sync(module, sync): + assert isinstance(module, torch.nn.Module) + if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel): + yield + else: + with module.no_sync(): + yield + +#---------------------------------------------------------------------------- +# Check DistributedDataParallel consistency across processes. + +def check_ddp_consistency(module, ignore_regex=None): + assert isinstance(module, torch.nn.Module) + for name, tensor in named_params_and_buffers(module): + fullname = type(module).__name__ + '.' + name + if ignore_regex is not None and re.fullmatch(ignore_regex, fullname): + continue + tensor = tensor.detach() + if tensor.is_floating_point(): + tensor = nan_to_num(tensor) + other = tensor.clone() + torch.distributed.broadcast(tensor=other, src=0) + assert (tensor == other).all(), fullname + +#---------------------------------------------------------------------------- +# Print summary table of module hierarchy. + +def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True): + assert isinstance(module, torch.nn.Module) + assert not isinstance(module, torch.jit.ScriptModule) + assert isinstance(inputs, (tuple, list)) + + # Register hooks. + entries = [] + nesting = [0] + def pre_hook(_mod, _inputs): + nesting[0] += 1 + def post_hook(mod, _inputs, outputs): + nesting[0] -= 1 + if nesting[0] <= max_nesting: + outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs] + outputs = [t for t in outputs if isinstance(t, torch.Tensor)] + entries.append(EasyDict(mod=mod, outputs=outputs)) + hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()] + hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()] + + # Run module. + outputs = module(*inputs) + for hook in hooks: + hook.remove() + + # Identify unique outputs, parameters, and buffers. + tensors_seen = set() + for e in entries: + e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen] + e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen] + e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen] + tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs} + + # Filter out redundant entries. + if skip_redundant: + entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)] + + # Construct table. + rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']] + rows += [['---'] * len(rows[0])] + param_total = 0 + buffer_total = 0 + submodule_names = {mod: name for name, mod in module.named_modules()} + for e in entries: + name = '' if e.mod is module else submodule_names[e.mod] + param_size = sum(t.numel() for t in e.unique_params) + buffer_size = sum(t.numel() for t in e.unique_buffers) + output_shapes = [str(list(t.shape)) for t in e.outputs] + output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs] + rows += [[ + name + (':0' if len(e.outputs) >= 2 else ''), + str(param_size) if param_size else '-', + str(buffer_size) if buffer_size else '-', + (output_shapes + ['-'])[0], + (output_dtypes + ['-'])[0], + ]] + for idx in range(1, len(e.outputs)): + rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]] + param_total += param_size + buffer_total += buffer_size + rows += [['---'] * len(rows[0])] + rows += [['Total', str(param_total), str(buffer_total), '-', '-']] + + # Print table. 
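+    # Left-align every cell, padding it to the widest entry in its column.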
+ widths = [max(len(cell) for cell in column) for column in zip(*rows)] + print() + for row in rows: + print(' '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths))) + print() + return outputs + +#---------------------------------------------------------------------------- + +# pylint: enable=line-too-long +# pylint: enable=missing-class-docstring +# pylint: enable=missing-function-docstring +# pylint: enable=use-maxsplit-arg diff --git a/third_party/stylegan3_official_ops/upfirdn2d.cpp b/third_party/stylegan3_official_ops/upfirdn2d.cpp new file mode 100644 index 0000000000000000000000000000000000000000..44fa337d8d4c34dfa010a59cd27d86857db671aa --- /dev/null +++ b/third_party/stylegan3_official_ops/upfirdn2d.cpp @@ -0,0 +1,107 @@ +// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#include +#include +#include +#include "upfirdn2d.h" + +//------------------------------------------------------------------------ + +static torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain) +{ + // Validate arguments. + TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); + TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x"); + TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32"); + TORCH_CHECK(x.numel() <= INT_MAX, "x is too large"); + TORCH_CHECK(f.numel() <= INT_MAX, "f is too large"); + TORCH_CHECK(x.numel() > 0, "x has zero size"); + TORCH_CHECK(f.numel() > 0, "f has zero size"); + TORCH_CHECK(x.dim() == 4, "x must be rank 4"); + TORCH_CHECK(f.dim() == 2, "f must be rank 2"); + TORCH_CHECK((x.size(0)-1)*x.stride(0) + (x.size(1)-1)*x.stride(1) + (x.size(2)-1)*x.stride(2) + (x.size(3)-1)*x.stride(3) <= INT_MAX, "x memory footprint is too large"); + TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1"); + TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1"); + TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1"); + + // Create output tensor. + const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); + int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx; + int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy; + TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1"); + torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format()); + TORCH_CHECK(y.numel() <= INT_MAX, "output is too large"); + TORCH_CHECK((y.size(0)-1)*y.stride(0) + (y.size(1)-1)*y.stride(1) + (y.size(2)-1)*y.stride(2) + (y.size(3)-1)*y.stride(3) <= INT_MAX, "output memory footprint is too large"); + + // Initialize CUDA kernel parameters. + upfirdn2d_kernel_params p; + p.x = x.data_ptr(); + p.f = f.data_ptr(); + p.y = y.data_ptr(); + p.up = make_int2(upx, upy); + p.down = make_int2(downx, downy); + p.pad0 = make_int2(padx0, pady0); + p.flip = (flip) ? 
1 : 0; + p.gain = gain; + p.inSize = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0)); + p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0)); + p.filterSize = make_int2((int)f.size(1), (int)f.size(0)); + p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0)); + p.outSize = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0)); + p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0)); + p.sizeMajor = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z; + p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1; + + // Choose CUDA kernel. + upfirdn2d_kernel_spec spec; + AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] + { + spec = choose_upfirdn2d_kernel(p); + }); + + // Set looping options. + p.loopMajor = (p.sizeMajor - 1) / 16384 + 1; + p.loopMinor = spec.loopMinor; + p.loopX = spec.loopX; + p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1; + p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1; + + // Compute grid size. + dim3 blockSize, gridSize; + if (spec.tileOutW < 0) // large + { + blockSize = dim3(4, 32, 1); + gridSize = dim3( + ((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor, + (p.outSize.x - 1) / (blockSize.y * p.loopX) + 1, + p.launchMajor); + } + else // small + { + blockSize = dim3(256, 1, 1); + gridSize = dim3( + ((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor, + (p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1, + p.launchMajor); + } + + // Launch CUDA kernel. + void* args[] = {&p}; + AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream())); + return y; +} + +//------------------------------------------------------------------------ + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + m.def("upfirdn2d", &upfirdn2d); +} + +//------------------------------------------------------------------------ diff --git a/third_party/stylegan3_official_ops/upfirdn2d.cu b/third_party/stylegan3_official_ops/upfirdn2d.cu new file mode 100644 index 0000000000000000000000000000000000000000..3a33e31bbb1bbc1cd02ee7d2ede3943917f3906e --- /dev/null +++ b/third_party/stylegan3_official_ops/upfirdn2d.cu @@ -0,0 +1,384 @@ +// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#include +#include "upfirdn2d.h" + +//------------------------------------------------------------------------ +// Helpers. + +template struct InternalType; +template <> struct InternalType { typedef double scalar_t; }; +template <> struct InternalType { typedef float scalar_t; }; +template <> struct InternalType { typedef float scalar_t; }; + +static __device__ __forceinline__ int floor_div(int a, int b) +{ + int t = 1 - a / b; + return (a + t * b) / b - t; +} + +//------------------------------------------------------------------------ +// Generic CUDA implementation for large filters. + +template static __global__ void upfirdn2d_kernel_large(upfirdn2d_kernel_params p) +{ + typedef typename InternalType::scalar_t scalar_t; + + // Calculate thread index. 
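+    // Note: the x grid dimension packs both the output row and the minor
+    // (channel) index; the division below recovers outY and minorBase.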
+ int minorBase = blockIdx.x * blockDim.x + threadIdx.x; + int outY = minorBase / p.launchMinor; + minorBase -= outY * p.launchMinor; + int outXBase = blockIdx.y * p.loopX * blockDim.y + threadIdx.y; + int majorBase = blockIdx.z * p.loopMajor; + if (outXBase >= p.outSize.x | outY >= p.outSize.y | majorBase >= p.sizeMajor) + return; + + // Setup Y receptive field. + int midY = outY * p.down.y + p.up.y - 1 - p.pad0.y; + int inY = min(max(floor_div(midY, p.up.y), 0), p.inSize.y); + int h = min(max(floor_div(midY + p.filterSize.y, p.up.y), 0), p.inSize.y) - inY; + int filterY = midY + p.filterSize.y - (inY + 1) * p.up.y; + if (p.flip) + filterY = p.filterSize.y - 1 - filterY; + + // Loop over major, minor, and X. + for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++) + for (int minorIdx = 0, minor = minorBase; minorIdx < p.loopMinor & minor < p.sizeMinor; minorIdx++, minor += p.launchMinor) + { + int nc = major * p.sizeMinor + minor; + int n = nc / p.inSize.z; + int c = nc - n * p.inSize.z; + for (int loopX = 0, outX = outXBase; loopX < p.loopX & outX < p.outSize.x; loopX++, outX += blockDim.y) + { + // Setup X receptive field. + int midX = outX * p.down.x + p.up.x - 1 - p.pad0.x; + int inX = min(max(floor_div(midX, p.up.x), 0), p.inSize.x); + int w = min(max(floor_div(midX + p.filterSize.x, p.up.x), 0), p.inSize.x) - inX; + int filterX = midX + p.filterSize.x - (inX + 1) * p.up.x; + if (p.flip) + filterX = p.filterSize.x - 1 - filterX; + + // Initialize pointers. + const T* xp = &((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w]; + const float* fp = &p.f[filterX * p.filterStride.x + filterY * p.filterStride.y]; + int filterStepX = ((p.flip) ? p.up.x : -p.up.x) * p.filterStride.x; + int filterStepY = ((p.flip) ? p.up.y : -p.up.y) * p.filterStride.y; + + // Inner loop. + scalar_t v = 0; + for (int y = 0; y < h; y++) + { + for (int x = 0; x < w; x++) + { + v += (scalar_t)(*xp) * (scalar_t)(*fp); + xp += p.inStride.x; + fp += filterStepX; + } + xp += p.inStride.y - w * p.inStride.x; + fp += filterStepY - w * filterStepX; + } + + // Store result. + v *= p.gain; + ((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v; + } + } +} + +//------------------------------------------------------------------------ +// Specialized CUDA implementation for small filters. + +template +static __global__ void upfirdn2d_kernel_small(upfirdn2d_kernel_params p) +{ + typedef typename InternalType::scalar_t scalar_t; + const int tileInW = ((tileOutW - 1) * downx + filterW - 1) / upx + 1; + const int tileInH = ((tileOutH - 1) * downy + filterH - 1) / upy + 1; + __shared__ volatile scalar_t sf[filterH][filterW]; + __shared__ volatile scalar_t sx[tileInH][tileInW][loopMinor]; + + // Calculate tile index. + int minorBase = blockIdx.x; + int tileOutY = minorBase / p.launchMinor; + minorBase -= tileOutY * p.launchMinor; + minorBase *= loopMinor; + tileOutY *= tileOutH; + int tileOutXBase = blockIdx.y * p.loopX * tileOutW; + int majorBase = blockIdx.z * p.loopMajor; + if (tileOutXBase >= p.outSize.x | tileOutY >= p.outSize.y | majorBase >= p.sizeMajor) + return; + + // Load filter (flipped). + for (int tapIdx = threadIdx.x; tapIdx < filterH * filterW; tapIdx += blockDim.x) + { + int fy = tapIdx / filterW; + int fx = tapIdx - fy * filterW; + scalar_t v = 0; + if (fx < p.filterSize.x & fy < p.filterSize.y) + { + int ffx = (p.flip) ? 
fx : p.filterSize.x - 1 - fx; + int ffy = (p.flip) ? fy : p.filterSize.y - 1 - fy; + v = (scalar_t)p.f[ffx * p.filterStride.x + ffy * p.filterStride.y]; + } + sf[fy][fx] = v; + } + + // Loop over major and X. + for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++) + { + int baseNC = major * p.sizeMinor + minorBase; + int n = baseNC / p.inSize.z; + int baseC = baseNC - n * p.inSize.z; + for (int loopX = 0, tileOutX = tileOutXBase; loopX < p.loopX & tileOutX < p.outSize.x; loopX++, tileOutX += tileOutW) + { + // Load input pixels. + int tileMidX = tileOutX * downx + upx - 1 - p.pad0.x; + int tileMidY = tileOutY * downy + upy - 1 - p.pad0.y; + int tileInX = floor_div(tileMidX, upx); + int tileInY = floor_div(tileMidY, upy); + __syncthreads(); + for (int inIdx = threadIdx.x; inIdx < tileInH * tileInW * loopMinor; inIdx += blockDim.x) + { + int relC = inIdx; + int relInX = relC / loopMinor; + int relInY = relInX / tileInW; + relC -= relInX * loopMinor; + relInX -= relInY * tileInW; + int c = baseC + relC; + int inX = tileInX + relInX; + int inY = tileInY + relInY; + scalar_t v = 0; + if (inX >= 0 & inY >= 0 & inX < p.inSize.x & inY < p.inSize.y & c < p.inSize.z) + v = (scalar_t)((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w]; + sx[relInY][relInX][relC] = v; + } + + // Loop over output pixels. + __syncthreads(); + for (int outIdx = threadIdx.x; outIdx < tileOutH * tileOutW * loopMinor; outIdx += blockDim.x) + { + int relC = outIdx; + int relOutX = relC / loopMinor; + int relOutY = relOutX / tileOutW; + relC -= relOutX * loopMinor; + relOutX -= relOutY * tileOutW; + int c = baseC + relC; + int outX = tileOutX + relOutX; + int outY = tileOutY + relOutY; + + // Setup receptive field. + int midX = tileMidX + relOutX * downx; + int midY = tileMidY + relOutY * downy; + int inX = floor_div(midX, upx); + int inY = floor_div(midY, upy); + int relInX = inX - tileInX; + int relInY = inY - tileInY; + int filterX = (inX + 1) * upx - midX - 1; // flipped + int filterY = (inY + 1) * upy - midY - 1; // flipped + + // Inner loop. + if (outX < p.outSize.x & outY < p.outSize.y & c < p.outSize.z) + { + scalar_t v = 0; + #pragma unroll + for (int y = 0; y < filterH / upy; y++) + #pragma unroll + for (int x = 0; x < filterW / upx; x++) + v += sx[relInY + y][relInX + x][relC] * sf[filterY + y * upy][filterX + x * upx]; + v *= p.gain; + ((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v; + } + } + } + } +} + +//------------------------------------------------------------------------ +// CUDA kernel selection. + +template upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p) +{ + int s = p.inStride.z, fx = p.filterSize.x, fy = p.filterSize.y; + upfirdn2d_kernel_spec spec = {(void*)upfirdn2d_kernel_large, -1,-1,1, 4}; // contiguous + if (s == 1) spec = {(void*)upfirdn2d_kernel_large, -1,-1,4, 1}; // channels_last + + // No up/downsampling. 
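+    // Within each branch below, candidate specs are ordered from the largest
+    // filter bound to the smallest, so the last matching (tightest) tile
+    // specialization wins.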
+ if (p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) + { + // contiguous + if (s != 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 64,32,1, 1}; + if (s != 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 64,32,1, 1}; + if (s != 1 && fx <= 7 && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; + if (s != 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; + if (s != 1 && fx <= 5 && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; + if (s != 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; + if (s != 1 && fx <= 3 && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; + if (s != 1 && fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; + if (s != 1 && fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; + if (s != 1 && fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; + if (s != 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + if (s != 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + if (s != 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + // channels_last + if (s == 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + if (s == 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + if (s == 1 && fx <= 7 && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; + if (s == 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; + if (s == 1 && fx <= 5 && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; + if (s == 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; + if (s == 1 && fx <= 3 && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; + if (s == 1 && fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; + if (s == 1 && fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; + if (s == 1 && fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; + if (s == 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; + if (s == 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; + if (s == 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; + } + + // 2x upsampling. 
+ if (p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) + { + // contiguous + if (s != 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 64,32,1, 1}; + if (s != 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 64,32,1, 1}; + if (s != 1 && fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; + if (s != 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; + if (s != 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; + if (s != 1 && fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; + // channels_last + if (s == 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + if (s == 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + if (s == 1 && fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; + if (s == 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; + if (s == 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; + if (s == 1 && fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; + } + if (p.up.x == 2 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) + { + // contiguous + if (s != 1 && fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; + if (s != 1 && fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; + if (s != 1 && fx <= 8 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; + // channels_last + if (s == 1 && fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; + if (s == 1 && fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; + if (s == 1 && fx <= 8 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; + } + if (p.up.x == 1 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) + { + // contiguous + if (s != 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + if (s != 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + if (s != 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + // channels_last + if (s == 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; + if (s == 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; + if (s == 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; + } + + // 2x downsampling. 
+ if (p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 2) + { + // contiguous + if (s != 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 32,16,1, 1}; + if (s != 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 32,16,1, 1}; + if (s != 1 && fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 32,8,1, 1}; + if (s != 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 32,8,1, 1}; + if (s != 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 32,8,1, 1}; + if (s != 1 && fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small, 32,8,1, 1}; + // channels_last + if (s == 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 16,16,1, 1}; + if (s == 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 16,16,1, 1}; + if (s == 1 && fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 8,8,8, 1}; + if (s == 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 8,8,8, 1}; + if (s == 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 8,8,8, 1}; + if (s == 1 && fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small, 8,8,8, 1}; + } + if (p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 1) + { + // contiguous + if (s != 1 && fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 64,8,1, 1}; + if (s != 1 && fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 64,8,1, 1}; + if (s != 1 && fx <= 8 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 64,8,1, 1}; + // channels_last + if (s == 1 && fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 64,1,8, 1}; + if (s == 1 && fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 64,1,8, 1}; + if (s == 1 && fx <= 8 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 64,1,8, 1}; + } + if (p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 2) + { + // contiguous + if (s != 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 32,16,1, 1}; + if (s != 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 32,16,1, 1}; + if (s != 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 32,16,1, 1}; + // channels_last + if (s == 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 1,64,8, 1}; + if (s == 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 1,64,8, 1}; + if (s == 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 1,64,8, 1}; + } + + // 4x upsampling. 
+ if (p.up.x == 4 && p.up.y == 4 && p.down.x == 1 && p.down.y == 1) + { + // contiguous + if (s != 1 && fx <= 48 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small, 64,32,1, 1}; + if (s != 1 && fx <= 32 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small, 64,32,1, 1}; + // channels_last + if (s == 1 && fx <= 48 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + if (s == 1 && fx <= 32 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + } + if (p.up.x == 4 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) + { + // contiguous + if (s != 1 && fx <= 48 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; + if (s != 1 && fx <= 32 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; + // channels_last + if (s == 1 && fx <= 48 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; + if (s == 1 && fx <= 32 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; + } + if (p.up.x == 1 && p.up.y == 4 && p.down.x == 1 && p.down.y == 1) + { + // contiguous + if (s != 1 && fx <= 1 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + if (s != 1 && fx <= 1 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; + // channels_last + if (s == 1 && fx <= 1 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; + if (s == 1 && fx <= 1 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; + } + + // 4x downsampling (inefficient). + if (p.up.x == 1 && p.up.y == 1 && p.down.x == 4 && p.down.y == 1) + { + // contiguous + if (s != 1 && fx <= 48 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 32,8,1, 1}; + if (s != 1 && fx <= 32 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 32,8,1, 1}; + // channels_last + if (s == 1 && fx <= 48 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 32,1,8, 1}; + if (s == 1 && fx <= 32 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 32,1,8, 1}; + } + if (p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 4) + { + // contiguous + if (s != 1 && fx <= 1 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small, 32,8,1, 1}; + if (s != 1 && fx <= 1 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small, 32,8,1, 1}; + // channels_last + if (s == 1 && fx <= 1 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small, 1,32,8, 1}; + if (s == 1 && fx <= 1 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small, 1,32,8, 1}; + } + return spec; +} + +//------------------------------------------------------------------------ +// Template specializations. + +template upfirdn2d_kernel_spec choose_upfirdn2d_kernel (const upfirdn2d_kernel_params& p); +template upfirdn2d_kernel_spec choose_upfirdn2d_kernel (const upfirdn2d_kernel_params& p); +template upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p); + +//------------------------------------------------------------------------ diff --git a/third_party/stylegan3_official_ops/upfirdn2d.h b/third_party/stylegan3_official_ops/upfirdn2d.h new file mode 100644 index 0000000000000000000000000000000000000000..2793daf874492af01e8634a7863c036e17b6731f --- /dev/null +++ b/third_party/stylegan3_official_ops/upfirdn2d.h @@ -0,0 +1,59 @@ +// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. 
Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#include + +//------------------------------------------------------------------------ +// CUDA kernel parameters. + +struct upfirdn2d_kernel_params +{ + const void* x; + const float* f; + void* y; + + int2 up; + int2 down; + int2 pad0; + int flip; + float gain; + + int4 inSize; // [width, height, channel, batch] + int4 inStride; + int2 filterSize; // [width, height] + int2 filterStride; + int4 outSize; // [width, height, channel, batch] + int4 outStride; + int sizeMinor; + int sizeMajor; + + int loopMinor; + int loopMajor; + int loopX; + int launchMinor; + int launchMajor; +}; + +//------------------------------------------------------------------------ +// CUDA kernel specialization. + +struct upfirdn2d_kernel_spec +{ + void* kernel; + int tileOutW; + int tileOutH; + int loopMinor; + int loopX; +}; + +//------------------------------------------------------------------------ +// CUDA kernel selection. + +template upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p); + +//------------------------------------------------------------------------ diff --git a/third_party/stylegan3_official_ops/upfirdn2d.py b/third_party/stylegan3_official_ops/upfirdn2d.py new file mode 100644 index 0000000000000000000000000000000000000000..b4cf0bb8fc299e66997b28cd517b8252619d3f26 --- /dev/null +++ b/third_party/stylegan3_official_ops/upfirdn2d.py @@ -0,0 +1,404 @@ +# python3.7 + +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +"""Custom operators for efficient resampling of 2D images. + +`upfirdn` means executing upsampling, FIR filtering, downsampling in sequence. + +Please refer to https://github.com/NVlabs/stylegan3 +""" + +# pylint: disable=line-too-long +# pylint: disable=missing-class-docstring +# pylint: disable=global-statement + +import os +import numpy as np +import torch + +from . import custom_ops +from . import misc +from . 
import conv2d_gradfix + +#---------------------------------------------------------------------------- + +_plugin = None + +def _init(): + global _plugin + if _plugin is None: + _plugin = custom_ops.get_plugin( + module_name='upfirdn2d_plugin', + sources=['upfirdn2d.cpp', 'upfirdn2d.cu'], + headers=['upfirdn2d.h'], + source_dir=os.path.dirname(__file__), + extra_cuda_cflags=['--use_fast_math'], + ) + return True + +def _parse_scaling(scaling): + if isinstance(scaling, int): + scaling = [scaling, scaling] + assert isinstance(scaling, (list, tuple)) + assert all(isinstance(x, int) for x in scaling) + sx, sy = scaling + assert sx >= 1 and sy >= 1 + return sx, sy + +def _parse_padding(padding): + if isinstance(padding, int): + padding = [padding, padding] + assert isinstance(padding, (list, tuple)) + assert all(isinstance(x, int) for x in padding) + if len(padding) == 2: + padx, pady = padding + padding = [padx, padx, pady, pady] + padx0, padx1, pady0, pady1 = padding + return padx0, padx1, pady0, pady1 + +def _get_filter_size(f): + if f is None: + return 1, 1 + assert isinstance(f, torch.Tensor) and f.ndim in [1, 2] + fw = f.shape[-1] + fh = f.shape[0] + with misc.suppress_tracer_warnings(): + fw = int(fw) + fh = int(fh) + misc.assert_shape(f, [fh, fw][:f.ndim]) + assert fw >= 1 and fh >= 1 + return fw, fh + +#---------------------------------------------------------------------------- + +def setup_filter(f, device=torch.device('cpu'), normalize=True, flip_filter=False, gain=1, separable=None): + r"""Convenience function to setup 2D FIR filter for `upfirdn2d()`. + + Args: + f: Torch tensor, numpy array, or python list of the shape + `[filter_height, filter_width]` (non-separable), + `[filter_taps]` (separable), + `[]` (impulse), or + `None` (identity). + device: Result device (default: cpu). + normalize: Normalize the filter so that it retains the magnitude + for constant input signal (DC)? (default: True). + flip_filter: Flip the filter? (default: False). + gain: Overall scaling factor for signal magnitude (default: 1). + separable: Return a separable filter? (default: select automatically). + + Returns: + Float32 tensor of the shape + `[filter_height, filter_width]` (non-separable) or + `[filter_taps]` (separable). + """ + # Validate. + if f is None: + f = 1 + f = torch.as_tensor(f, dtype=torch.float32) + assert f.ndim in [0, 1, 2] + assert f.numel() > 0 + if f.ndim == 0: + f = f[np.newaxis] + + # Separable? + if separable is None: + separable = (f.ndim == 1 and f.numel() >= 8) + if f.ndim == 1 and not separable: + f = f.ger(f) + assert f.ndim == (1 if separable else 2) + + # Apply normalize, flip, gain, and device. + if normalize: + f /= f.sum() + if flip_filter: + f = f.flip(list(range(f.ndim))) + f = f * (gain ** (f.ndim / 2)) + f = f.to(device=device) + return f + +#---------------------------------------------------------------------------- + +def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl='cuda'): + r"""Pad, upsample, filter, and downsample a batch of 2D images. + + Performs the following sequence of operations for each channel: + + 1. Upsample the image by inserting N-1 zeros after each pixel (`up`). + + 2. Pad the image with the specified number of zeros on each side (`padding`). + Negative padding corresponds to cropping the image. + + 3. Convolve the image with the specified 2D FIR filter (`f`), shrinking it + so that the footprint of all output pixels lies within the input image. + + 4. Downsample the image by keeping every Nth pixel (`down`). 
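+
+    Each output dimension therefore has size
+    `out = (in * up + pad_before + pad_after - filter_size + down) // down`,
+    matching the shape computed by the CUDA op.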
+ + This sequence of operations bears close resemblance to scipy.signal.upfirdn(). + The fused op is considerably more efficient than performing the same calculation + using standard PyTorch ops. It supports gradients of arbitrary order. + + Args: + x: Float32/float64/float16 input tensor of the shape + `[batch_size, num_channels, in_height, in_width]`. + f: Float32 FIR filter of the shape + `[filter_height, filter_width]` (non-separable), + `[filter_taps]` (separable), or + `None` (identity). + up: Integer upsampling factor. Can be a single int or a list/tuple + `[x, y]` (default: 1). + down: Integer downsampling factor. Can be a single int or a list/tuple + `[x, y]` (default: 1). + padding: Padding with respect to the upsampled image. Can be a single number + or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` + (default: 0). + flip_filter: False = convolution, True = correlation (default: False). + gain: Overall scaling factor for signal magnitude (default: 1). + impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). + + Returns: + Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. + """ + assert isinstance(x, torch.Tensor) + assert impl in ['ref', 'cuda'] + if impl == 'cuda' and x.device.type == 'cuda' and _init(): + return _upfirdn2d_cuda(up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain).apply(x, f) + return _upfirdn2d_ref(x, f, up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain) + +#---------------------------------------------------------------------------- + +@misc.profiled_function +def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1): + """Slow reference implementation of `upfirdn2d()` using standard PyTorch ops. + """ + # Validate arguments. + assert isinstance(x, torch.Tensor) and x.ndim == 4 + if f is None: + f = torch.ones([1, 1], dtype=torch.float32, device=x.device) + assert isinstance(f, torch.Tensor) and f.ndim in [1, 2] + assert f.dtype == torch.float32 and not f.requires_grad + batch_size, num_channels, in_height, in_width = x.shape + upx, upy = _parse_scaling(up) + downx, downy = _parse_scaling(down) + padx0, padx1, pady0, pady1 = _parse_padding(padding) + + # Check that upsampled buffer is not smaller than the filter. + upW = in_width * upx + padx0 + padx1 + upH = in_height * upy + pady0 + pady1 + assert upW >= f.shape[-1] and upH >= f.shape[0] + + # Upsample by inserting zeros. + x = x.reshape([batch_size, num_channels, in_height, 1, in_width, 1]) + x = torch.nn.functional.pad(x, [0, upx - 1, 0, 0, 0, upy - 1]) + x = x.reshape([batch_size, num_channels, in_height * upy, in_width * upx]) + + # Pad or crop. + x = torch.nn.functional.pad(x, [max(padx0, 0), max(padx1, 0), max(pady0, 0), max(pady1, 0)]) + x = x[:, :, max(-pady0, 0) : x.shape[2] - max(-pady1, 0), max(-padx0, 0) : x.shape[3] - max(-padx1, 0)] + + # Setup filter. + f = f * (gain ** (f.ndim / 2)) + f = f.to(x.dtype) + if not flip_filter: + f = f.flip(list(range(f.ndim))) + + # Convolve with the filter. + f = f[np.newaxis, np.newaxis].repeat([num_channels, 1] + [1] * f.ndim) + if f.ndim == 4: + x = conv2d_gradfix.conv2d(input=x, weight=f, groups=num_channels) + else: + x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(2), groups=num_channels) + x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(3), groups=num_channels) + + # Downsample by throwing away pixels. 
+ x = x[:, :, ::downy, ::downx] + return x + +#---------------------------------------------------------------------------- + +_upfirdn2d_cuda_cache = dict() + +def _upfirdn2d_cuda(up=1, down=1, padding=0, flip_filter=False, gain=1): + """Fast CUDA implementation of `upfirdn2d()` using custom ops. + """ + # Parse arguments. + upx, upy = _parse_scaling(up) + downx, downy = _parse_scaling(down) + padx0, padx1, pady0, pady1 = _parse_padding(padding) + + # Lookup from cache. + key = (upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain) + if key in _upfirdn2d_cuda_cache: + return _upfirdn2d_cuda_cache[key] + + # Forward op. + class Upfirdn2dCuda(torch.autograd.Function): + @staticmethod + def forward(ctx, x, f): # pylint: disable=arguments-differ + assert isinstance(x, torch.Tensor) and x.ndim == 4 + if f is None: + f = torch.ones([1, 1], dtype=torch.float32, device=x.device) + if f.ndim == 1 and f.shape[0] == 1: + f = f.square().unsqueeze(0) # Convert separable-1 into full-1x1. + assert isinstance(f, torch.Tensor) and f.ndim in [1, 2] + y = x + if f.ndim == 2: + y = _plugin.upfirdn2d(y, f, upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain) + else: + y = _plugin.upfirdn2d(y, f.unsqueeze(0), upx, 1, downx, 1, padx0, padx1, 0, 0, flip_filter, 1.0) + y = _plugin.upfirdn2d(y, f.unsqueeze(1), 1, upy, 1, downy, 0, 0, pady0, pady1, flip_filter, gain) + ctx.save_for_backward(f) + ctx.x_shape = x.shape + return y + + @staticmethod + def backward(ctx, dy): # pylint: disable=arguments-differ + f, = ctx.saved_tensors + _, _, ih, iw = ctx.x_shape + _, _, oh, ow = dy.shape + fw, fh = _get_filter_size(f) + p = [ + fw - padx0 - 1, + iw * upx - ow * downx + padx0 - upx + 1, + fh - pady0 - 1, + ih * upy - oh * downy + pady0 - upy + 1, + ] + dx = None + df = None + + if ctx.needs_input_grad[0]: + dx = _upfirdn2d_cuda(up=down, down=up, padding=p, flip_filter=(not flip_filter), gain=gain).apply(dy, f) + + assert not ctx.needs_input_grad[1] + return dx, df + + # Add to cache. + _upfirdn2d_cuda_cache[key] = Upfirdn2dCuda + return Upfirdn2dCuda + +#---------------------------------------------------------------------------- + +def filter2d(x, f, padding=0, flip_filter=False, gain=1, impl='cuda'): + r"""Filter a batch of 2D images using the given 2D FIR filter. + + By default, the result is padded so that its shape matches the input. + User-specified padding is applied on top of that, with negative values + indicating cropping. Pixels outside the image are assumed to be zero. + + Args: + x: Float32/float64/float16 input tensor of the shape + `[batch_size, num_channels, in_height, in_width]`. + f: Float32 FIR filter of the shape + `[filter_height, filter_width]` (non-separable), + `[filter_taps]` (separable), or + `None` (identity). + padding: Padding with respect to the output. Can be a single number or a + list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` + (default: 0). + flip_filter: False = convolution, True = correlation (default: False). + gain: Overall scaling factor for signal magnitude (default: 1). + impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). + + Returns: + Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. 
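+
+    Example (illustrative): `filter2d(x, setup_filter([1, 2, 1]))` smooths
+    each image with a normalized 3x3 kernel while keeping its spatial size.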
+ """ + padx0, padx1, pady0, pady1 = _parse_padding(padding) + fw, fh = _get_filter_size(f) + p = [ + padx0 + fw // 2, + padx1 + (fw - 1) // 2, + pady0 + fh // 2, + pady1 + (fh - 1) // 2, + ] + return upfirdn2d(x, f, padding=p, flip_filter=flip_filter, gain=gain, impl=impl) + +#---------------------------------------------------------------------------- + +def upsample2d(x, f, up=2, padding=0, flip_filter=False, gain=1, impl='cuda'): + r"""Upsample a batch of 2D images using the given 2D FIR filter. + + By default, the result is padded so that its shape is a multiple of the input. + User-specified padding is applied on top of that, with negative values + indicating cropping. Pixels outside the image are assumed to be zero. + + Args: + x: Float32/float64/float16 input tensor of the shape + `[batch_size, num_channels, in_height, in_width]`. + f: Float32 FIR filter of the shape + `[filter_height, filter_width]` (non-separable), + `[filter_taps]` (separable), or + `None` (identity). + up: Integer upsampling factor. Can be a single int or a list/tuple + `[x, y]` (default: 1). + padding: Padding with respect to the output. Can be a single number or a + list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` + (default: 0). + flip_filter: False = convolution, True = correlation (default: False). + gain: Overall scaling factor for signal magnitude (default: 1). + impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). + + Returns: + Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. + """ + upx, upy = _parse_scaling(up) + padx0, padx1, pady0, pady1 = _parse_padding(padding) + fw, fh = _get_filter_size(f) + p = [ + padx0 + (fw + upx - 1) // 2, + padx1 + (fw - upx) // 2, + pady0 + (fh + upy - 1) // 2, + pady1 + (fh - upy) // 2, + ] + return upfirdn2d(x, f, up=up, padding=p, flip_filter=flip_filter, gain=gain*upx*upy, impl=impl) + +#---------------------------------------------------------------------------- + +def downsample2d(x, f, down=2, padding=0, flip_filter=False, gain=1, impl='cuda'): + r"""Downsample a batch of 2D images using the given 2D FIR filter. + + By default, the result is padded so that its shape is a fraction of the input. + User-specified padding is applied on top of that, with negative values + indicating cropping. Pixels outside the image are assumed to be zero. + + Args: + x: Float32/float64/float16 input tensor of the shape + `[batch_size, num_channels, in_height, in_width]`. + f: Float32 FIR filter of the shape + `[filter_height, filter_width]` (non-separable), + `[filter_taps]` (separable), or + `None` (identity). + down: Integer downsampling factor. Can be a single int or a list/tuple + `[x, y]` (default: 1). + padding: Padding with respect to the input. Can be a single number or a + list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` + (default: 0). + flip_filter: False = convolution, True = correlation (default: False). + gain: Overall scaling factor for signal magnitude (default: 1). + impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). + + Returns: + Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. 
+ """ + downx, downy = _parse_scaling(down) + padx0, padx1, pady0, pady1 = _parse_padding(padding) + fw, fh = _get_filter_size(f) + p = [ + padx0 + (fw - downx + 1) // 2, + padx1 + (fw - downx) // 2, + pady0 + (fh - downy + 1) // 2, + pady1 + (fh - downy) // 2, + ] + return upfirdn2d(x, f, down=down, padding=p, flip_filter=flip_filter, gain=gain, impl=impl) + +#---------------------------------------------------------------------------- + +# pylint: enable=line-too-long +# pylint: enable=missing-class-docstring +# pylint: enable=global-statement diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/utils/custom_utils.py b/utils/custom_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..cf0208d76adc4c5dcd7dfdbbb1b44e4e29c30f47 --- /dev/null +++ b/utils/custom_utils.py @@ -0,0 +1,86 @@ +# python3.7 +"""Utility functions for image editing.""" + +import numpy as np +import cv2 +import torch + + +__all__ = ['to_numpy', 'linear_interpolate', 'make_transform', + 'get_ind', 'mask2image'] + + +def to_numpy(data): + """Converts the input data to `numpy.ndarray`.""" + if isinstance(data, (int, float)): + return np.array(data) + if isinstance(data, np.ndarray): + return data + if isinstance(data, torch.Tensor): + return data.detach().cpu().numpy() + raise TypeError(f'Not supported data type `{type(data)}` for ' + f'converting to `numpy.ndarray`!') + + +def linear_interpolate(latent_code, + boundary, + layer_index=None, + start_distance=-10.0, + end_distance=10.0, + steps=21): + """Interpolate between the latent code and boundary.""" + assert (len(latent_code.shape) == 3 and len(boundary.shape) == 3 and + latent_code.shape[0] == 1 and boundary.shape[0] == 1 and + latent_code.shape[1] == boundary.shape[1]) + linspace = np.linspace(start_distance, end_distance, steps) + linspace = linspace.reshape([-1, 1, 1]).astype(np.float32) + inter_code = linspace * boundary + is_manipulatable = np.zeros(inter_code.shape, dtype=bool) + is_manipulatable[:, layer_index, :] = True + mani_code = np.where(is_manipulatable, latent_code+inter_code, latent_code) + return mani_code + + +def make_transform(tx, ty, angle): + """Transform the input feature maps with given + coordinates and rotation angle. 
+
+    The matrix, in homogeneous coordinates, matches the code below:
+
+        cos(theta)   sin(theta)  tx
+        -sin(theta)  cos(theta)  ty
+        0            0           1
+
+    """
+    m = np.eye(3)
+    s = np.sin(angle/360.0*np.pi*2)
+    c = np.cos(angle/360.0*np.pi*2)
+    m[0][0] = c
+    m[0][1] = s
+    m[0][2] = tx
+    m[1][0] = -s
+    m[1][1] = c
+    m[1][2] = ty
+    return m
+
+
+def get_ind(seg_mask, label):
+    """Gets the indices of the masked and unmasked regions."""
+    mask = np.where(seg_mask == label,
+                    np.ones_like(seg_mask),
+                    np.zeros_like(seg_mask))
+    f_ind = np.where(mask == 1)
+    b_ind = np.where((1 - mask) == 1)
+    return f_ind, b_ind, mask
+
+
+def mask2image(image, mask, r=3, g=255, b=118):
+    """Overlays the mask on the given (square) image."""
+    assert image.shape[0] == image.shape[1]
+    size = image.shape[0]
+    r_c = np.ones([size, size, 1]) * r
+    g_c = np.ones([size, size, 1]) * g
+    b_c = np.ones([size, size, 1]) * b
+    img1 = np.concatenate([r_c, g_c, b_c], axis=2).astype(np.uint8)
+    mask = np.expand_dims(mask, axis=2).astype(np.uint8)
+    img1 = img1 * mask
+    image = cv2.addWeighted(image, 0.4, img1, 0.6, 0)
+    mask_i = np.tile(mask, [1, 1, 3]) * 255
+    return image, mask_i
diff --git a/utils/dist_utils.py b/utils/dist_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..c792ba8b34fed30942b79c8054141b98c6273d8f
--- /dev/null
+++ b/utils/dist_utils.py
@@ -0,0 +1,67 @@
+# python3.7
+"""Contains utility functions used for distribution."""
+
+import contextlib
+import os
+import subprocess
+
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+
+__all__ = ['init_dist', 'exit_dist', 'ddp_sync', 'get_ddp_module']
+
+
+def init_dist(launcher, backend='nccl', **kwargs):
+    """Initializes distributed environment."""
+    if mp.get_start_method(allow_none=True) is None:
+        mp.set_start_method('spawn')
+    if launcher == 'pytorch':
+        rank = int(os.environ['RANK'])
+        num_gpus = torch.cuda.device_count()
+        torch.cuda.set_device(rank % num_gpus)
+        dist.init_process_group(backend=backend, **kwargs)
+    elif launcher == 'slurm':
+        proc_id = int(os.environ['SLURM_PROCID'])
+        ntasks = int(os.environ['SLURM_NTASKS'])
+        node_list = os.environ['SLURM_NODELIST']
+        num_gpus = torch.cuda.device_count()
+        torch.cuda.set_device(proc_id % num_gpus)
+        addr = subprocess.getoutput(
+            f'scontrol show hostname {node_list} | head -n1')
+        port = os.environ.get('PORT', 29500)
+        os.environ['MASTER_PORT'] = str(port)
+        os.environ['MASTER_ADDR'] = addr
+        os.environ['WORLD_SIZE'] = str(ntasks)
+        os.environ['RANK'] = str(proc_id)
+        dist.init_process_group(backend=backend)
+    else:
+        raise NotImplementedError(f'Not implemented launcher type: '
+                                  f'`{launcher}`!')
+
+
+def exit_dist():
+    """Exits the distributed environment."""
+    if dist.is_initialized():
+        dist.destroy_process_group()
+
+
+@contextlib.contextmanager
+def ddp_sync(model, sync):
+    """Controls whether the `DistributedDataParallel` model should be synced."""
+    assert isinstance(model, torch.nn.Module)
+    is_ddp = isinstance(model, torch.nn.parallel.DistributedDataParallel)
+    if sync or not is_ddp:
+        yield
+    else:
+        with model.no_sync():
+            yield
+
+
+def get_ddp_module(model):
+    """Gets the module from `DistributedDataParallel`."""
+    assert isinstance(model, torch.nn.Module)
+    is_ddp = isinstance(model, torch.nn.parallel.DistributedDataParallel)
+    if is_ddp:
+        return model.module
+    return model
diff --git a/utils/file_transmitters/__init__.py b/utils/file_transmitters/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..027642bc8f3bcd6bc992daf4aa419fd4ab464372
--- /dev/null
+++ b/utils/file_transmitters/__init__.py
@@ -0,0 +1,30 @@
+# python3.7
+"""Collects all file transmitters.""" + +from .local_file_transmitter import LocalFileTransmitter +from .dummy_file_transmitter import DummyFileTransmitter + +__all__ = ['build_file_transmitter'] + +_TRANSMITTERS = { + 'local': LocalFileTransmitter, + 'dummy': DummyFileTransmitter, +} + + +def build_file_transmitter(transmitter_type='local', **kwargs): + """Builds a file transmitter. + + Args: + transmitter_type: Type of the file transmitter_type, which is case + insensitive. (default: `normal`) + **kwargs: Additional arguments to build the file transmitter. + + Raises: + ValueError: If the `transmitter_type` is not supported. + """ + transmitter_type = transmitter_type.lower() + if transmitter_type not in _TRANSMITTERS: + raise ValueError(f'Invalid transmitter type: `{transmitter_type}`!\n' + f'Types allowed: {list(_TRANSMITTERS)}.') + return _TRANSMITTERS[transmitter_type](**kwargs) diff --git a/utils/file_transmitters/base_file_transmitter.py b/utils/file_transmitters/base_file_transmitter.py new file mode 100644 index 0000000000000000000000000000000000000000..5e93d2e85581378d01a8227feca29219f5c51417 --- /dev/null +++ b/utils/file_transmitters/base_file_transmitter.py @@ -0,0 +1,92 @@ +# python3.7 +"""Contains the base class to transmit files across file systems. + +Basically, a file transmitter connects the local file system, on which the +programme runs, to a remote file system. This is particularly used for +(1) pulling files that are required by the programme from remote, and +(2) pushing results that are produced by the programme to remote. In this way, +the programme can focus on local file system only. + +NOTE: The remote file system can be the same as the local file system, since +users may want to transmit files across directories. +""" + +import warnings + +__all__ = ['BaseFileTransmitter'] + + +class BaseFileTransmitter(object): + """Defines the base file transmitter. + + A transmitter should have the following functions: + + (1) pull(): The function to pull a file/directory from remote to local. + (2) push(): The function to push a file/directory from local to remote. + (3) remove(): The function to remove a file/directory. + (4) make_remote_dir(): Make directory remotely. + + + To simplify, each derived class just need to implement the following helper + functions: + + (1) download_hard(): Hard download a file/directory from remote to local. + (2) download_soft(): Soft download a file/directory from remote to local. + This is especially used to save space (e.g., soft link). + (3) upload(): Upload a file/directory from local to remote. + (4) delete(): Delete a file/directory according to given path. + """ + + def __init__(self): + pass + + @property + def name(self): + """Returns the class name of the file transmitter.""" + return self.__class__.__name__ + + @staticmethod + def download_hard(src, dst): + """Downloads (in hard mode) a file/directory from remote to local.""" + raise NotImplementedError('Should be implemented in derived class!') + + @staticmethod + def download_soft(src, dst): + """Downloads (in soft mode) a file/directory from local to remote.""" + raise NotImplementedError('Should be implemented in derived class!') + + @staticmethod + def upload(src, dst): + """Uploads a file/directory from local to remote.""" + raise NotImplementedError('Should be implemented in derived class!') + + @staticmethod + def delete(path): + """Deletes the given path.""" + # TODO: should we secure the path to avoid mis-removing / attacks? 
+        raise NotImplementedError('Should be implemented in derived class!')
+
+    def pull(self, src, dst, hard=False):
+        """Pulls a file/directory from remote to local.
+
+        The argument `hard` controls the download mode (hard or soft). For
+        example, the hard mode may physically copy the file, while the soft
+        mode may just create a symbolic link to it.
+        """
+        if hard:
+            self.download_hard(src, dst)
+        else:
+            self.download_soft(src, dst)
+
+    def push(self, src, dst):
+        """Pushes a file/directory from local to remote."""
+        self.upload(src, dst)
+
+    def remove(self, path):
+        """Removes the given path."""
+        warnings.warn(f'`{path}` will be removed!')
+        self.delete(path)
+
+    def make_remote_dir(self, directory):
+        """Makes a directory on the remote system."""
+        raise NotImplementedError('Should be implemented in derived class!')
diff --git a/utils/file_transmitters/dummy_file_transmitter.py b/utils/file_transmitters/dummy_file_transmitter.py
new file mode 100644
index 0000000000000000000000000000000000000000..c553f4082061da9e6d8194dbbc2ce16f7a122554
--- /dev/null
+++ b/utils/file_transmitters/dummy_file_transmitter.py
@@ -0,0 +1,34 @@
+# python3.7
+"""Contains the class of dummy file transmitter.
+
+This file transmitter has all expected data transmission functions but behaves
+silently, which is very useful in multi-processing mode. Only the chief process
+can have the file transmitter with normal behavior.
+"""
+
+from .base_file_transmitter import BaseFileTransmitter
+
+__all__ = ['DummyFileTransmitter']
+
+
+class DummyFileTransmitter(BaseFileTransmitter):
+    """Implements a dummy transmitter which transmits nothing."""
+
+    @staticmethod
+    def download_hard(src, dst):
+        return
+
+    @staticmethod
+    def download_soft(src, dst):
+        return
+
+    @staticmethod
+    def upload(src, dst):
+        return
+
+    @staticmethod
+    def delete(path):
+        return
+
+    def make_remote_dir(self, directory):
+        return
diff --git a/utils/file_transmitters/local_file_transmitter.py b/utils/file_transmitters/local_file_transmitter.py
new file mode 100644
index 0000000000000000000000000000000000000000..562becf65ce0052559109300557c8c8de2e142b6
--- /dev/null
+++ b/utils/file_transmitters/local_file_transmitter.py
@@ -0,0 +1,35 @@
+# python3.7
+"""Contains the class of local file transmitter.
+
+This transmitter connects the local file system to itself, which can be used
+to transmit files from one directory to another. Consequently, `remote` in
+this file also means `local`.
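+
+For instance (an illustrative call, not from the original code),
+`LocalFileTransmitter().pull('/data/a.txt', './a.txt', hard=False)` creates
+the symbolic link `./a.txt` pointing to `/data/a.txt` via `ln -s`, while
+`hard=True` makes a physical copy via `cp`.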
+""" + +from utils.misc import print_and_execute +from .base_file_transmitter import BaseFileTransmitter + +__all__ = ['LocalFileTransmitter'] + + +class LocalFileTransmitter(BaseFileTransmitter): + """Implements the transmitter connecting local file system to itself.""" + + @staticmethod + def download_hard(src, dst): + print_and_execute(f'cp {src} {dst}') + + @staticmethod + def download_soft(src, dst): + print_and_execute(f'ln -s {src} {dst}') + + @staticmethod + def upload(src, dst): + print_and_execute(f'cp {src} {dst}') + + @staticmethod + def delete(path): + print_and_execute(f'rm -r {path}') + + def make_remote_dir(self, directory): + print_and_execute(f'mkdir -p {directory}') diff --git a/utils/formatting_utils.py b/utils/formatting_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..20f9f14050da889b7b9be0867e9373ff54ebe42d --- /dev/null +++ b/utils/formatting_utils.py @@ -0,0 +1,178 @@ +# python3.7 +"""Contains utility functions used for formatting.""" + +import cv2 +import numpy as np + +__all__ = [ + 'format_time', 'format_range', 'format_image_size', 'format_image', + 'raw_label_to_one_hot', 'one_hot_to_raw_label' +] + + +def format_time(seconds): + """Formats seconds to readable time string. + + Args: + seconds: Number of seconds to format. + + Returns: + The formatted time string. + + Raises: + ValueError: If the input `seconds` is less than 0. + """ + if seconds < 0: + raise ValueError(f'Input `seconds` should be greater than or equal to ' + f'0, but `{seconds}` is received!') + + # Returns seconds as float if less than 1 minute. + if seconds < 10: + return f'{seconds:7.3f} s' + if seconds < 60: + return f'{seconds:7.2f} s' + + seconds = int(seconds + 0.5) + days, seconds = divmod(seconds, 86400) + hours, seconds = divmod(seconds, 3600) + minutes, seconds = divmod(seconds, 60) + if days: + return f'{days:2d} d {hours:02d} h' + if hours: + return f'{hours:2d} h {minutes:02d} m' + return f'{minutes:2d} m {seconds:02d} s' + + +def format_range(obj, min_val=None, max_val=None): + """Formats the given object to a valid range. + + If `min_val` or `max_val` is provided, both the starting value and the end + value will be clamped to range `[min_val, max_val]`. + + NOTE: (a, b) is regarded as a valid range if and only if `a <= b`. + + Args: + obj: The input object to format. + min_val: The minimum value to cut off the input range. If not provided, + the default minimum value is negative infinity. (default: None) + max_val: The maximum value to cut off the input range. If not provided, + the default maximum value is infinity. (default: None) + + Returns: + A two-elements tuple, indicating the start and the end of the range. + + Raises: + ValueError: If the input object is an invalid range. + """ + if not isinstance(obj, (tuple, list)): + raise ValueError(f'Input object must be a tuple or a list, ' + f'but `{type(obj)}` received!') + if len(obj) != 2: + raise ValueError(f'Input object is expected to contain two elements, ' + f'but `{len(obj)}` received!') + if obj[0] > obj[1]: + raise ValueError(f'The second element is expected to be equal to or ' + f'greater than the first one, ' + f'but `({obj[0]}, {obj[1]})` received!') + + obj = list(obj) + if min_val is not None: + obj[0] = max(obj[0], min_val) + obj[1] = max(obj[1], min_val) + if max_val is not None: + obj[0] = min(obj[0], max_val) + obj[1] = min(obj[1], max_val) + return tuple(obj) + + +def format_image_size(size): + """Formats the given image size to a two-element tuple. 
+
+    A valid image size can be an integer, indicating both the height and the
+    width, OR can be a two-element list or tuple. Both height and width are
+    assumed to be non-negative integers.
+
+    Args:
+        size: The input size to format.
+
+    Returns:
+        A two-element tuple, indicating the height and the width,
+        respectively.
+
+    Raises:
+        ValueError: If the input size is invalid.
+    """
+    if not isinstance(size, (int, tuple, list)):
+        raise ValueError(f'Input size must be an integer, a tuple, or a list, '
+                         f'but `{type(size)}` received!')
+    if isinstance(size, int):
+        size = (size, size)
+    else:
+        if len(size) == 1:
+            size = (size[0], size[0])
+        if not len(size) == 2:
+            raise ValueError(f'Input size is expected to have two numbers at '
+                             f'most, but `{len(size)}` numbers received!')
+    if not isinstance(size[0], int) or size[0] < 0:
+        raise ValueError(f'The height is expected to be a non-negative '
+                         f'integer, but `{size[0]}` received!')
+    if not isinstance(size[1], int) or size[1] < 0:
+        raise ValueError(f'The width is expected to be a non-negative '
+                         f'integer, but `{size[1]}` received!')
+    return tuple(size)
+
+
+def format_image(image):
+    """Formats an image read from `cv2`.
+
+    NOTE: This function will always return a 3-dimensional image (i.e., with
+    shape [H, W, C]) in pixel range [0, 255]. For color images, the channel
+    order of the input is expected to be `BGR` or `BGRA`, which is the raw
+    image decoded by `cv2`, while the channel order of the output is set to
+    `RGB` or `RGBA`.
+
+    Args:
+        image: `np.ndarray`, an image read by `cv2.imread()` or
+            `cv2.imdecode()`.
+
+    Returns:
+        An image with shape [H, W, C] (where `C = 1` for grayscale image).
+    """
+    if image.ndim == 2:  # add additional axis if given a grayscale image
+        image = image[:, :, np.newaxis]
+
+    assert isinstance(image, np.ndarray)
+    assert image.dtype == np.uint8
+    assert image.ndim == 3 and image.shape[2] in [1, 3, 4]
+
+    if image.shape[2] == 3:  # BGR image
+        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+    if image.shape[2] == 4:  # BGRA image
+        return cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA)
+    return image
+
+
+def raw_label_to_one_hot(raw_label, num_classes):
+    """Converts a single label into one-hot vector.
+
+    Args:
+        raw_label: The raw label.
+        num_classes: Total number of classes.
+
+    Returns:
+        One-hot vector of the given raw label.
+    """
+    one_hot = np.zeros(num_classes, dtype=np.float32)
+    one_hot[raw_label] = 1.0
+    return one_hot
+
+
+def one_hot_to_raw_label(one_hot):
+    """Converts a one-hot vector to a single-value label.
+
+    Args:
+        one_hot: `np.ndarray`, a one-hot encoded vector.
+
+    Returns:
+        A single integer representing the category.
+    """
+    return np.argmax(one_hot)
diff --git a/utils/image_utils.py b/utils/image_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..c640ac5ef977e3a7824dabbc43e5d56e733d0d76
--- /dev/null
+++ b/utils/image_utils.py
@@ -0,0 +1,332 @@
+# python3.7
+"""Contains utility functions for image processing.
+
+The module is primarily built on `cv2`. Differently from `cv2`, however, we
+assume all colorful images to be with `RGB` channel order by default. Also,
+we assume all gray-scale images to be with shape [height, width, 1].
+""" + +import os +import cv2 +import numpy as np + +from .misc import IMAGE_EXTENSIONS +from .misc import check_file_ext + +__all__ = [ + 'get_blank_image', 'load_image', 'save_image', 'resize_image', + 'add_text_to_image', 'preprocess_image', 'postprocess_image', + 'parse_image_size', 'get_grid_shape', 'list_images_from_dir' +] + + +def _check_2d_image(image): + """Checks whether a given image is valid. + + A valid image is expected to be with dtype `uint8`. Also, it should have + shape like: + + (1) (height, width, 1) # gray-scale image. + (2) (height, width, 3) # colorful image. + (3) (height, width, 4) # colorful image with transparency (RGBA) + """ + assert isinstance(image, np.ndarray) + assert image.dtype == np.uint8 + assert image.ndim == 3 and image.shape[2] in [1, 3, 4] + + +def get_blank_image(height, width, channels=3, use_black=True): + """Gets a blank image, either white of black. + + NOTE: This function will always return an image with `RGB` channel order for + color image and pixel range [0, 255]. + + Args: + height: Height of the returned image. + width: Width of the returned image. + channels: Number of channels. (default: 3) + use_black: Whether to return a black image. (default: True) + """ + shape = (height, width, channels) + if use_black: + return np.zeros(shape, dtype=np.uint8) + return np.ones(shape, dtype=np.uint8) * 255 + + +def load_image(path): + """Loads an image from disk. + + NOTE: This function will always return an image with `RGB` channel order for + color image and pixel range [0, 255]. + + Args: + path: Path to load the image from. + + Returns: + An image with dtype `np.ndarray`, or `None` if `path` does not exist. + """ + image = cv2.imread(path, cv2.IMREAD_UNCHANGED) + if image is None: + return None + + if image.ndim == 2: + image = image[:, :, np.newaxis] + _check_2d_image(image) + if image.shape[2] == 3: + return cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + if image.shape[2] == 4: + return cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA) + return image + + +def save_image(path, image): + """Saves an image to disk. + + NOTE: The input image (if colorful) is assumed to be with `RGB` channel + order and pixel range [0, 255]. + + Args: + path: Path to save the image to. + image: Image to save. + """ + if image is None: + return + + _check_2d_image(image) + if image.shape[2] == 1: + cv2.imwrite(path, image) + elif image.shape[2] == 3: + cv2.imwrite(path, cv2.cvtColor(image, cv2.COLOR_RGB2BGR)) + elif image.shape[2] == 4: + cv2.imwrite(path, cv2.cvtColor(image, cv2.COLOR_RGBA2BGRA)) + + +def resize_image(image, *args, **kwargs): + """Resizes image. + + This is a wrap of `cv2.resize()`. + + NOTE: The channel order of the input image will not be changed. + + Args: + image: Image to resize. + *args: Additional positional arguments. + **kwargs: Additional keyword arguments. + + Returns: + An image with dtype `np.ndarray`, or `None` if `image` is empty. + """ + if image is None: + return None + + _check_2d_image(image) + if image.shape[2] == 1: # Re-expand the squeezed dim of gray-scale image. + return cv2.resize(image, *args, **kwargs)[:, :, np.newaxis] + return cv2.resize(image, *args, **kwargs) + + +def add_text_to_image(image, + text='', + position=None, + font=cv2.FONT_HERSHEY_TRIPLEX, + font_size=1.0, + line_type=cv2.LINE_8, + line_width=1, + color=(255, 255, 255)): + """Overlays text on given image. + + NOTE: The input image is assumed to be with `RGB` channel order. + + Args: + image: The image to overlay text on. + text: Text content to overlay on the image. 
+        position: Target position (bottom-left corner) to add text. If not
+            set, center of the image will be used by default. (default: None)
+        font: Font of the text added. (default: cv2.FONT_HERSHEY_TRIPLEX)
+        font_size: Font size of the text added. (default: 1.0)
+        line_type: Line type used to depict the text. (default: cv2.LINE_8)
+        line_width: Line width used to depict the text. (default: 1)
+        color: Color of the text added in `RGB` channel order. (default:
+            (255, 255, 255))
+
+    Returns:
+        An image with the target text overlaid.
+    """
+    if image is None or not text:
+        return image
+
+    _check_2d_image(image)
+    cv2.putText(img=image,
+                text=text,
+                org=position,
+                fontFace=font,
+                fontScale=font_size,
+                color=color,
+                thickness=line_width,
+                lineType=line_type,
+                bottomLeftOrigin=False)
+    return image
+
+
+def preprocess_image(image, min_val=-1.0, max_val=1.0):
+    """Pre-processes an image by adjusting its pixel range and dtype.
+
+    This function is particularly used to convert an image or a batch of
+    images to `NCHW` format, which matches the data layout commonly used in
+    deep models.
+
+    NOTE: The input image is assumed to be with pixel range [0, 255] and with
+    format `HWC` or `NHWC`. The returned image will always be with format
+    `NCHW` and dtype `float64` (matching the cast performed below).
+
+    Args:
+        image: The input image for pre-processing.
+        min_val: Minimum value of the output image.
+        max_val: Maximum value of the output image.
+
+    Returns:
+        The pre-processed image.
+    """
+    assert isinstance(image, np.ndarray)
+
+    image = image.astype(np.float64)
+    image = image / 255.0 * (max_val - min_val) + min_val
+
+    if image.ndim == 3:
+        image = image[np.newaxis]
+    assert image.ndim == 4 and image.shape[3] in [1, 3, 4]
+    return image.transpose(0, 3, 1, 2)
+
+
+def postprocess_image(image, min_val=-1.0, max_val=1.0):
+    """Post-processes an image to pixel range [0, 255] with dtype `uint8`.
+
+    This function is particularly used to handle the results produced by deep
+    models.
+
+    NOTE: The input image is assumed to be with format `NCHW`, and the
+    returned image will always be with format `NHWC`.
+
+    Args:
+        image: The input image for post-processing.
+        min_val: Expected minimum value of the input image.
+        max_val: Expected maximum value of the input image.
+
+    Returns:
+        The post-processed image.
+    """
+    assert isinstance(image, np.ndarray)
+
+    image = image.astype(np.float64)
+    image = (image - min_val) / (max_val - min_val) * 255
+    image = np.clip(image + 0.5, 0, 255).astype(np.uint8)
+
+    assert image.ndim == 4 and image.shape[1] in [1, 3, 4]
+    return image.transpose(0, 2, 3, 1)
+
+
+def parse_image_size(obj):
+    """Parses an object to a pair of image size, i.e., (height, width).
+
+    Args:
+        obj: The input object to parse image size from.
+
+    Returns:
+        A two-element tuple, indicating image height and width respectively.
+
+    Raises:
+        ValueError: If the input is invalid, i.e., not an integer, list,
+            tuple, string, or `np.ndarray`.
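+
+    For example, `'512,256'` parses to `(512, 256)`, `384` to `(384, 384)`,
+    and `None` to `(0, 0)`.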
+    """
+    if obj is None or obj == '':
+        height = 0
+        width = 0
+    elif isinstance(obj, int):
+        height = obj
+        width = obj
+    elif isinstance(obj, (list, tuple, str, np.ndarray)):
+        if isinstance(obj, str):
+            splits = obj.replace(' ', '').split(',')
+            numbers = tuple(map(int, splits))
+        else:
+            numbers = tuple(obj)
+        if len(numbers) == 0:
+            height = 0
+            width = 0
+        elif len(numbers) == 1:
+            height = int(numbers[0])
+            width = int(numbers[0])
+        elif len(numbers) == 2:
+            height = int(numbers[0])
+            width = int(numbers[1])
+        else:
+            raise ValueError('At most two elements for image size.')
+    else:
+        raise ValueError(f'Invalid type of input: `{type(obj)}`!')
+
+    return (max(0, height), max(0, width))
+
+
+def get_grid_shape(size, height=0, width=0, is_portrait=False):
+    """Gets the shape of a grid based on the size.
+
+    This function makes the greatest effort to make the output grid square if
+    neither `height` nor `width` is set. If `is_portrait` is set as `False`,
+    the height will always be equal to or smaller than the width. For
+    example, if input `size = 16`, output shape will be `(4, 4)`; if input
+    `size = 15`, output shape will be (3, 5). Otherwise, the height will
+    always be equal to or larger than the width.
+
+    Args:
+        size: Size (height * width) of the target grid.
+        height: Expected height. If `size % height != 0`, this field will be
+            ignored. (default: 0)
+        width: Expected width. If `size % width != 0`, this field will be
+            ignored. (default: 0)
+        is_portrait: Whether to return a portrait shape or a landscape shape.
+            (default: False)
+
+    Returns:
+        A two-element tuple, representing height and width respectively.
+    """
+    assert isinstance(size, int)
+    assert isinstance(height, int)
+    assert isinstance(width, int)
+    if size <= 0:
+        return (0, 0)
+
+    if height > 0 and width > 0 and height * width != size:
+        height = 0
+        width = 0
+
+    if height > 0 and width > 0 and height * width == size:
+        return (height, width)
+    if height > 0 and size % height == 0:
+        return (height, size // height)
+    if width > 0 and size % width == 0:
+        return (size // width, width)
+
+    height = int(np.sqrt(size))
+    while height > 0:
+        if size % height == 0:
+            width = size // height
+            break
+        height = height - 1
+
+    return (width, height) if is_portrait else (height, width)
+
+
+def list_images_from_dir(directory):
+    """Lists all images from the given directory.
+
+    NOTE: This function does NOT find images recursively.
+
+    Args:
+        directory: The directory to find images from.
+
+    Returns:
+        A list of sorted filenames, with the directory as prefix.
+    """
+    image_list = []
+    for filename in os.listdir(directory):
+        if check_file_ext(filename, *IMAGE_EXTENSIONS):
+            image_list.append(os.path.join(directory, filename))
+    return sorted(image_list)
diff --git a/utils/loggers/__init__.py b/utils/loggers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..665fd01dc34ae7a520dadfe4581c97e59dd6affe
--- /dev/null
+++ b/utils/loggers/__init__.py
@@ -0,0 +1,32 @@
+# python3.7
+"""Collects all loggers."""
+
+from .normal_logger import NormalLogger
+from .rich_logger import RichLogger
+from .dummy_logger import DummyLogger
+
+__all__ = ['build_logger']
+
+_LOGGERS = {
+    'normal': NormalLogger,
+    'rich': RichLogger,
+    'dummy': DummyLogger
+}
+
+
+def build_logger(logger_type='normal', **kwargs):
+    """Builds a logger.
+
+    Args:
+        logger_type: Type of logger, which is case insensitive.
+            (default: `normal`)
+        **kwargs: Additional arguments to build the logger.
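+
+    Example (illustrative): `build_logger('rich', logger_name='train')`
+    creates a `RichLogger` that only logs to the screen, since no `logfile`
+    is given.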
+ + Raises: + ValueError: If the `logger_type` is not supported. + """ + logger_type = logger_type.lower() + if logger_type not in _LOGGERS: + raise ValueError(f'Invalid logger type: `{logger_type}`!\n' + f'Types allowed: {list(_LOGGERS)}.') + return _LOGGERS[logger_type](**kwargs) diff --git a/utils/loggers/base_logger.py b/utils/loggers/base_logger.py new file mode 100644 index 0000000000000000000000000000000000000000..c08fa7fec115a54fd6fff578af5cf0229e5395e8 --- /dev/null +++ b/utils/loggers/base_logger.py @@ -0,0 +1,258 @@ +# python3.7 +"""Contains the base class for logging. + +Basically, this is an interface bridging the program and the local file system. +A logger is able to log wrapped message onto the screen and a log file. +""" + +import logging + +__all__ = ['BaseLogger'] + + +class BaseLogger(object): + """Defines the base logger. + + A logger should have the following members: + + (1) logger: The logger to record message. + (2) pbar: The progressive bar (shown on the screen only). + (3) pbar_kwargs: The arguments for the progressive bar. + (4) file_stream: The stream to log messages into if needed. + + A logger should have the following functions: + + (1) log(): The base function to log message. + (2) debug(): The function to log message with `DEBUG` level. + (3) info(): The function to log message with `INFO` level. + (4) warning(): The function to log message with `WARNING` level. + (5) warn(): Same as function `warning()`. + (6) error(): The function to log message with `ERROR` level. + (7) exception(): The function to log message with exception information. + (8) critical(): The function to log message with `CRITICAL` level. + (9) fatal(): Same as function `critical()`. + (10) print(): The function to print the message without any decoration. + (11) init_pbar(): The function to initialize the progressive bar. + (12) add_pbar_task(): The function to add a task to the progressive bar. + (13) update_pbar(): The function to update the progressive bar. + (14) close_pbar(): The function to close the progressive bar. + + The logger will record log message both on screen and to file. + + Args: + logger_name: Unique name for the logger. (default: `logger`) + logfile: Path to the log file. If set as `None`, the file stream + will be skipped. (default: `None`) + screen_level: Minimum level of message to log onto screen. + (default: `logging.INFO`) + file_level: Minimum level of message to log into file. + (default: `logging.DEBUG`) + indent_space: Number of spaces between two adjacent indent levels. + (default: 4) + verbose_log: Whether to log verbose message. (default: False) + """ + + def __init__(self, + logger_name='logger', + logfile=None, + screen_level=logging.INFO, + file_level=logging.DEBUG, + indent_space=4, + verbose_log=False): + self.logger_name = logger_name + self.logfile = logfile + self.screen_level = screen_level + self.file_level = file_level + self.indent_space = indent_space + self.verbose_log = verbose_log + + self.logger = None + self.pbar = None + self.pbar_kwargs = None + self.file_stream = None + + self.warn = self.warning + self.fatal = self.critical + + def __del__(self): + self.close() + + def close(self): + """Closes the logger.""" + if self.file_stream is not None: + self.file_stream.close() + + @property + def name(self): + """Returns the class name of the logger.""" + return self.__class__.__name__ + + # Log message. 
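+    # With the default `indent_space=4`, e.g., `wrap_message('hi',
+    # indent_level=2)` returns `'        hi'` (eight leading spaces).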
+ def wrap_message(self, message, indent_level=0): + """Wraps the message with indent.""" + if message is None: + message = '' + assert isinstance(message, str) + assert isinstance(indent_level, int) and indent_level >= 0 + if message == '': + return '' + return ' ' * (indent_level * self.indent_space) + message + + def _log(self, message, **kwargs): + """Logs wrapped message.""" + raise NotImplementedError('Should be implemented in derived class!') + + def _debug(self, message, **kwargs): + """Logs wrapped message with `DEBUG` level.""" + raise NotImplementedError('Should be implemented in derived class!') + + def _info(self, message, **kwargs): + """Logs wrapped message with `INFO` level.""" + raise NotImplementedError('Should be implemented in derived class!') + + def _warning(self, message, **kwargs): + """Logs wrapped message with `WARNING` level.""" + raise NotImplementedError('Should be implemented in derived class!') + + def _error(self, message, **kwargs): + """Logs wrapped message with `ERROR` level.""" + raise NotImplementedError('Should be implemented in derived class!') + + def _exception(self, message, **kwargs): + """Logs wrapped message with exception information.""" + raise NotImplementedError('Should be implemented in derived class!') + + def _critical(self, message, **kwargs): + """Logs wrapped message with `CRITICAL` level.""" + raise NotImplementedError('Should be implemented in derived class!') + + def _print(self, *messages, **kwargs): + """Prints wrapped message without any decoration.""" + raise NotImplementedError('Should be implemented in derived class!') + + def log(self, message, indent_level=0, is_verbose=False, **kwargs): + """Logs message. + + The message is wrapped with indent, and will be disabled if `is_verbose` + is set as `True`. + """ + if is_verbose and not self.verbose_log: + return + message = self.wrap_message(message, indent_level=indent_level) + self._log(message, **kwargs) + + def debug(self, message, indent_level=0, is_verbose=False, **kwargs): + """Logs message with `DEBUG` level. + + The message is wrapped with indent, and will be disabled if `is_verbose` + is set as `True`. + """ + if is_verbose and not self.verbose_log: + return + message = self.wrap_message(message, indent_level=indent_level) + self._debug(message, **kwargs) + + def info(self, message, indent_level=0, is_verbose=False, **kwargs): + """Logs message with `INFO` level. + + The message is wrapped with indent, and will be disabled if `is_verbose` + is set as `True`. + """ + if is_verbose and not self.verbose_log: + return + message = self.wrap_message(message, indent_level=indent_level) + self._info(message, **kwargs) + + def warning(self, message, indent_level=0, is_verbose=False, **kwargs): + """Logs message with `WARNING` level. + + The message is wrapped with indent, and will be disabled if `is_verbose` + is set as `True`. + """ + if is_verbose and not self.verbose_log: + return + message = self.wrap_message(message, indent_level=indent_level) + self._warning(message, **kwargs) + + def error(self, message, indent_level=0, is_verbose=False, **kwargs): + """Logs message with `ERROR` level. + + The message is wrapped with indent, and will be disabled if `is_verbose` + is set as `True`. + """ + if is_verbose and not self.verbose_log: + return + message = self.wrap_message(message, indent_level=indent_level) + self._error(message, **kwargs) + + def exception(self, message, indent_level=0, is_verbose=False, **kwargs): + """Logs message with exception information. 
+ + The message is wrapped with indent, and will be disabled if `is_verbose` + is set as `True`. + """ + if is_verbose and not self.verbose_log: + return + message = self.wrap_message(message, indent_level=indent_level) + self._exception(message, **kwargs) + + def critical(self, message, indent_level=0, is_verbose=False, **kwargs): + """Logs message with `CRITICAL` level. + + The message is wrapped with indent, and will be disabled if `is_verbose` + is set as `True`. + """ + if is_verbose and not self.verbose_log: + return + message = self.wrap_message(message, indent_level=indent_level) + self._critical(message, **kwargs) + + def print(self, *messages, indent_level=0, is_verbose=False, **kwargs): + """Prints message without any decoration. + + The message is wrapped with indent, and will be disabled if `is_verbose` + is set as `True`. + """ + if is_verbose and not self.verbose_log: + return + new_messages = [] + for message in messages: + new_messages.append( + self.wrap_message(message, indent_level=indent_level)) + self._print(*new_messages, **kwargs) + + # Progressive bar. + def init_pbar(self, leave=False): + """Initializes the progressive bar. + + Args: + leave: Whether to leave the trace of the progressive bar. + (default: False) + """ + raise NotImplementedError('Should be implemented in derived class!') + + def add_pbar_task(self, name, total, **kwargs): + """Adds a task to the progressive bar. + + Args: + name: Name of the added task. + total: Total number of steps (samples) contained in the task. + **kwargs: Additional arguments. + + Returns: + Task ID. + """ + raise NotImplementedError('Should be implemented in derived class!') + + def update_pbar(self, task_id, advance=1): + """Updates the progressive bar. + + Args: + task_id: ID of the task to update. + advance: Number of steps advanced onto the target task. (default: 1) + """ + raise NotImplementedError('Should be implemented in derived class!') + + def close_pbar(self): + """Closes the progress bar.""" + raise NotImplementedError('Should be implemented in derived class!') diff --git a/utils/loggers/dummy_logger.py b/utils/loggers/dummy_logger.py new file mode 100644 index 0000000000000000000000000000000000000000..fb6220e6757c6ce4516834f5102cd0957f8669df --- /dev/null +++ b/utils/loggers/dummy_logger.py @@ -0,0 +1,65 @@ +# python3.7 +"""Contains the class of dummy logger. + +This logger has all expected logging functions but behaves silently, which is +very useful in multi-processing mode. Only the chief process can have the logger +with normal behavior. 
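+
+For instance, in a DDP job one might build a `NormalLogger` on rank 0 and a
+`DummyLogger` on all other ranks, so that only one process writes output.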
+""" + +from .base_logger import BaseLogger + +__all__ = ['DummyLogger'] + + +class DummyLogger(BaseLogger): + """Implements a dummy logger which logs nothing.""" + + def __init__(self, + logger_name='logger', + logfile=None, + screen_level=None, + file_level=None, + indent_space=4, + verbose_log=False): + super().__init__(logger_name=logger_name, + logfile=logfile, + screen_level=screen_level, + file_level=file_level, + indent_space=indent_space, + verbose_log=verbose_log) + + def _log(self, message, **kwargs): + return + + def _debug(self, message, **kwargs): + return + + def _info(self, message, **kwargs): + return + + def _warning(self, message, **kwargs): + return + + def _error(self, message, **kwargs): + return + + def _exception(self, message, **kwargs): + return + + def _critical(self, message, **kwargs): + return + + def _print(self, *messages, **kwargs): + return + + def init_pbar(self, leave=False): + return + + def add_pbar_task(self, name, total, **kwargs): + return -1 + + def update_pbar(self, task_id, advance=1): + return + + def close_pbar(self): + return diff --git a/utils/loggers/normal_logger.py b/utils/loggers/normal_logger.py new file mode 100644 index 0000000000000000000000000000000000000000..57f66d21afce89d6f4dcb1bfd7fd7ddc3b441e31 --- /dev/null +++ b/utils/loggers/normal_logger.py @@ -0,0 +1,124 @@ +# python3.7 +"""Contains the class of normal logger. + +This class is built based on the built-in function `print()`, the module +`logging` and the module `tqdm` for progressive bar. +""" + +import sys +import logging +from copy import deepcopy +from tqdm import tqdm + +from .base_logger import BaseLogger + +__all__ = ['NormalLogger'] + + +class NormalLogger(BaseLogger): + """Implements the logger based on `logging` module.""" + + def __init__(self, + logger_name='logger', + logfile=None, + screen_level=logging.INFO, + file_level=logging.DEBUG, + indent_space=4, + verbose_log=False): + super().__init__(logger_name=logger_name, + logfile=logfile, + screen_level=screen_level, + file_level=file_level, + indent_space=indent_space, + verbose_log=verbose_log) + + # Get logger and check whether the logger has already been created. + self.logger = logging.getLogger(self.logger_name) + self.logger.propagate = False + if self.logger.hasHandlers(): # Already existed + raise SystemExit(f'Logger `{self.logger_name}` has already ' + f'existed!\n' + f'Please use another name, or otherwise the ' + f'messages may be mixed up.') + + # Set format. + self.logger.setLevel(logging.DEBUG) + formatter = logging.Formatter( + '[%(asctime)s][%(levelname)s] %(message)s', + datefmt='%Y-%m-%d %H:%M:%S') + + # Print log message onto the screen. + terminal_handler = logging.StreamHandler(stream=sys.stdout) + terminal_handler.setLevel(self.screen_level) + terminal_handler.setFormatter(formatter) + self.logger.addHandler(terminal_handler) + + # Save log message into log file if needed. + if self.logfile: + # File will be closed when the logger is closed in `self.close()`. 
+            self.file_stream = open(self.logfile, 'a')  # pylint: disable=consider-using-with
+            file_handler = logging.StreamHandler(stream=self.file_stream)
+            file_handler.setLevel(self.file_level)
+            file_handler.setFormatter(formatter)
+            self.logger.addHandler(file_handler)
+
+        self.pbar = []
+        self.pbar_kwargs = {}
+
+    def _log(self, message, **kwargs):
+        # `logging.Logger.log()` requires an explicit level as its first
+        # argument, so default to `INFO` here.
+        self.logger.log(kwargs.pop('level', logging.INFO), message, **kwargs)
+
+    def _debug(self, message, **kwargs):
+        self.logger.debug(message, **kwargs)
+
+    def _info(self, message, **kwargs):
+        self.logger.info(message, **kwargs)
+
+    def _warning(self, message, **kwargs):
+        self.logger.warning(message, **kwargs)
+
+    def _error(self, message, **kwargs):
+        self.logger.error(message, **kwargs)
+
+    def _exception(self, message, **kwargs):
+        self.logger.exception(message, **kwargs)
+
+    def _critical(self, message, **kwargs):
+        self.logger.critical(message, **kwargs)
+
+    def _print(self, *messages, **kwargs):
+        for handler in self.logger.handlers:
+            print(*messages, file=handler.stream)
+
+    def init_pbar(self, leave=False):
+        columns = [
+            '{desc}',
+            '{bar}',
+            ' {percentage:5.1f}%',
+            '[{elapsed}<{remaining}, {rate_fmt}{postfix}]',
+        ]
+        self.pbar_kwargs = dict(
+            leave=leave,
+            bar_format=' '.join(columns),
+            unit='',
+        )
+
+    def add_pbar_task(self, name, total, **kwargs):
+        assert isinstance(self.pbar_kwargs, dict)
+        pbar_kwargs = deepcopy(self.pbar_kwargs)
+        pbar_kwargs.update(**kwargs)
+        self.pbar.append(tqdm(desc=name, total=total, **pbar_kwargs))
+        return len(self.pbar) - 1
+
+    def update_pbar(self, task_id, advance=1):
+        assert len(self.pbar) > task_id and isinstance(self.pbar[task_id], tqdm)
+        if self.pbar[task_id].n < self.pbar[task_id].total:
+            self.pbar[task_id].update(advance)
+            if self.pbar[task_id].n >= self.pbar[task_id].total:
+                self.pbar[task_id].refresh()
+
+    def close_pbar(self):
+        for pbar in self.pbar[::-1]:
+            pbar.close()
+        self.pbar = []
+        self.pbar_kwargs = {}
diff --git a/utils/loggers/rich_logger.py b/utils/loggers/rich_logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..38c09d2b0c898aff394415a639742a623b07178a
--- /dev/null
+++ b/utils/loggers/rich_logger.py
@@ -0,0 +1,177 @@
+# python3.7
+"""Contains the class of rich logger.
+
+This class is based on the module `rich`. Please refer to
+https://github.com/Textualize/rich for more details.
+"""
+
+import sys
+import logging
+from copy import deepcopy
+from rich.console import Console
+from rich.logging import RichHandler
+from rich.progress import Progress
+from rich.progress import ProgressColumn
+from rich.progress import TextColumn
+from rich.progress import BarColumn
+from rich.text import Text
+
+from .base_logger import BaseLogger
+
+__all__ = ['RichLogger']
+
+
+def _format_time(seconds):
+    """Formats seconds to readable time string.
+
+    This function is used to display time in the progress bar.
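+
+    For example, 75 seconds renders as `01:15`, and 3725 seconds as
+    `1:02:05`.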
+    """
+    if not seconds:
+        return '--:--'
+
+    seconds = int(seconds)
+    hours, seconds = divmod(seconds, 3600)
+    minutes, seconds = divmod(seconds, 60)
+    if hours:
+        return f'{hours}:{minutes:02d}:{seconds:02d}'
+    return f'{minutes:02d}:{seconds:02d}'
+
+
+class TimeColumn(ProgressColumn):
+    """Renders total time, ETA, and speed in progress bar."""
+
+    max_refresh = 0.5  # Only refresh twice a second to prevent jitter
+
+    def render(self, task):
+        elapsed_time = _format_time(task.elapsed)
+        eta = _format_time(task.time_remaining)
+        speed = f'{task.speed:.2f}/s' if task.speed else '?/s'
+        return Text(f'[{elapsed_time}<{eta}, {speed}]',
+                    style='progress.remaining')
+
+
+class RichLogger(BaseLogger):
+    """Implements the logger based on `rich` module."""
+
+    def __init__(self,
+                 logger_name='logger',
+                 logfile=None,
+                 screen_level=logging.INFO,
+                 file_level=logging.DEBUG,
+                 indent_space=4,
+                 verbose_log=False):
+        super().__init__(logger_name=logger_name,
+                         logfile=logfile,
+                         screen_level=screen_level,
+                         file_level=file_level,
+                         indent_space=indent_space,
+                         verbose_log=verbose_log)
+
+        # Get logger and check whether the logger has already been created.
+        self.logger = logging.getLogger(self.logger_name)
+        self.logger.propagate = False
+        if self.logger.hasHandlers():  # Already existed
+            raise SystemExit(f'Logger `{self.logger_name}` has already '
+                             f'existed!\n'
+                             f'Please use another name, or otherwise the '
+                             f'messages may be mixed up.')
+
+        # Set format.
+        self.logger.setLevel(logging.DEBUG)
+
+        # Print log message onto the screen.
+        terminal_console = Console(
+            file=sys.stdout, log_time=False, log_path=False)
+        terminal_handler = RichHandler(
+            level=self.screen_level,
+            console=terminal_console,
+            show_time=True,
+            show_level=True,
+            show_path=False,
+            log_time_format='[%Y-%m-%d %H:%M:%S] ')
+        terminal_handler.setFormatter(logging.Formatter('%(message)s'))
+        self.logger.addHandler(terminal_handler)
+
+        # Save log message into log file if needed.
+        if self.logfile:
+            # File will be closed when the logger is closed in `self.close()`.
+            self.file_stream = open(self.logfile, 'a')  # pylint: disable=consider-using-with
+            file_console = Console(
+                file=self.file_stream, log_time=False, log_path=False)
+            file_handler = RichHandler(
+                level=self.file_level,
+                console=file_console,
+                show_time=True,
+                show_level=True,
+                show_path=False,
+                log_time_format='[%Y-%m-%d %H:%M:%S] ')
+            file_handler.setFormatter(logging.Formatter('%(message)s'))
+            self.logger.addHandler(file_handler)
+
+        self.pbar = None
+        self.pbar_kwargs = {}
+
+    def _log(self, message, **kwargs):
+        # `logging.Logger.log()` requires an explicit level as its first
+        # argument, so default to `INFO` here.
+        self.logger.log(kwargs.pop('level', logging.INFO), message, **kwargs)
+
+    def _debug(self, message, **kwargs):
+        self.logger.debug(message, **kwargs)
+
+    def _info(self, message, **kwargs):
+        self.logger.info(message, **kwargs)
+
+    def _warning(self, message, **kwargs):
+        self.logger.warning(message, **kwargs)
+
+    def _error(self, message, **kwargs):
+        self.logger.error(message, **kwargs)
+
+    def _exception(self, message, **kwargs):
+        self.logger.exception(message, **kwargs)
+
+    def _critical(self, message, **kwargs):
+        self.logger.critical(message, **kwargs)
+
+    def _print(self, *messages, **kwargs):
+        for handler in self.logger.handlers:
+            handler.console.print(*messages, **kwargs)
+
+    def init_pbar(self, leave=False):
+        assert self.pbar is None
+
+        # Columns shown in the progress bar.
+        columns = (
+            TextColumn('[progress.description]{task.description}'),
+            BarColumn(bar_width=None),
+            TextColumn('[progress.percentage]{task.percentage:>5.1f}%'),
+            TimeColumn(),
+        )
+
+        self.pbar = Progress(*columns,
+                             console=self.logger.handlers[0].console,
+                             transient=not leave,
+                             auto_refresh=True,
+                             refresh_per_second=10)
+        self.pbar.start()
+
+    def add_pbar_task(self, name, total, **kwargs):
+        assert isinstance(self.pbar, Progress)
+        assert isinstance(self.pbar_kwargs, dict)
+        pbar_kwargs = deepcopy(self.pbar_kwargs)
+        pbar_kwargs.update(**kwargs)
+        task_id = self.pbar.add_task(name, total=total, **pbar_kwargs)
+        return task_id
+
+    def update_pbar(self, task_id, advance=1):
+        assert isinstance(self.pbar, Progress)
+        if self.pbar.tasks[task_id].finished:
+            if self.pbar.tasks[task_id].stop_time is None:
+                self.pbar.stop_task(task_id)
+        else:
+            self.pbar.update(task_id, advance=advance)
+
+    def close_pbar(self):
+        assert isinstance(self.pbar, Progress)
+        self.pbar.stop()
+        self.pbar = None
+        self.pbar_kwargs = {}
diff --git a/utils/loggers/test.py b/utils/loggers/test.py
new file mode 100644
index 0000000000000000000000000000000000000000..096f7fd9b32458ac88b551f7acaf676cdc13be4f
--- /dev/null
+++ b/utils/loggers/test.py
@@ -0,0 +1,63 @@
+# python3.7
+"""Unit test for logger."""
+
+import os
+import time
+
+from . import build_logger
+
+__all__ = ['test_logger']
+
+_TEST_DIR = 'logger_test'
+
+
+def test_logger(test_dir=_TEST_DIR):
+    """Tests loggers."""
+    print('========== Start Logger Test ==========')
+
+    os.makedirs(test_dir, exist_ok=True)
+
+    for logger_type in ['normal', 'rich', 'dummy']:
+        for indent_space in [2, 4]:
+            for verbose_log in [False, True]:
+                if logger_type == 'normal':
+                    class_name = 'NormalLogger'
+                elif logger_type == 'rich':
+                    class_name = 'RichLogger'
+                elif logger_type == 'dummy':
+                    class_name = 'DummyLogger'
+
+                print(f'===== '
+                      f'Testing `utils.loggers.{class_name}` '
+                      f' (indent: {indent_space}, verbose: {verbose_log}) '
+                      f'=====')
+                logger_name = (f'{logger_type}_logger_'
+                               f'indent_{indent_space}_'
+                               f'verbose_{verbose_log}')
+                logger = build_logger(
+                    logger_type,
+                    logger_name=logger_name,
+                    logfile=os.path.join(test_dir, f'test_{logger_name}.log'),
+                    verbose_log=verbose_log,
+                    indent_space=indent_space)
+                logger.print('print log')
+                logger.print('print log,', 'log 2')
+                logger.print('print log (indent level 0)', indent_level=0)
+                logger.print('print log (indent level 1)', indent_level=1)
+                logger.print('print log (indent level 2)', indent_level=2)
+                logger.print('print log (verbose `False`)', is_verbose=False)
+                logger.print('print log (verbose `True`)', is_verbose=True)
+                logger.debug('debug log')
+                logger.info('info log')
+                logger.warning('warning log')
+                logger.init_pbar()
+                task_1 = logger.add_pbar_task('Task 1', 500)
+                task_2 = logger.add_pbar_task('Task 2', 1000)
+                for _ in range(1000):
+                    logger.update_pbar(task_1, 1)
+                    logger.update_pbar(task_2, 1)
+                    time.sleep(0.002)
+                logger.close_pbar()
+                print('Success!')
+
+    print('========== Finish Logger Test ==========')
diff --git a/utils/misc.py b/utils/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..36198dff3a4b3e1f7b5e6a21a17418d0e04eb6f3
--- /dev/null
+++ b/utils/misc.py
@@ -0,0 +1,227 @@
+# python3.7
+"""Misc utility functions."""
+
+import os
+import hashlib
+
+from torch.hub import download_url_to_file
+
+__all__ = [
+    'REPO_NAME', 'Infix', 'print_and_execute', 'check_file_ext',
+    'IMAGE_EXTENSIONS', 'VIDEO_EXTENSIONS', 'MEDIA_EXTENSIONS',
+    'parse_file_format',
+    'set_cache_dir', 'get_cache_dir', 'download_url'
+]
+
+REPO_NAME = 'Hammer'  # Name of the repository (project).
+
+
+class Infix(object):
+    """Helper class to create custom infix operators.
+
+    When using it, make sure to put the operator between `<<` and `>>`.
+    `<< INFIX_OP_NAME >>` should be considered as a whole operator.
+
+    Examples:
+
+    # Use `Infix` to create infix operators directly.
+    add = Infix(lambda a, b: a + b)
+    1 << add >> 2  # gives 3
+    1 << add >> 2 << add >> 3  # gives 6
+
+    # Use `Infix` as a decorator.
+    @Infix
+    def mul(a, b):
+        return a * b
+    2 << mul >> 4  # gives 8
+    2 << mul >> 3 << mul >> 7  # gives 42
+    """
+
+    def __init__(self, function):
+        self.function = function
+        self.left_value = None
+
+    def __rlshift__(self, left_value):  # override `<<` before `Infix` instance
+        assert self.left_value is None  # make sure left is only called once
+        self.left_value = left_value
+        return self
+
+    def __rshift__(self, right_value):  # override `>>` after `Infix` instance
+        result = self.function(self.left_value, right_value)
+        self.left_value = None  # reset to None
+        return result
+
+
+def print_and_execute(cmd):
+    """Prints and executes a system command.
+
+    Args:
+        cmd: Command to be executed.
+    """
+    print(cmd)
+    os.system(cmd)
+
+
+def check_file_ext(filename, *ext_list):
+    """Checks whether the given filename is with target extension(s).
+
+    NOTE: If `ext_list` is empty, this function will always return `False`.
+
+    Args:
+        filename: Filename to check.
+        *ext_list: A list of extensions.
+
+    Returns:
+        `True` if the filename is with one of extensions in `ext_list`,
+        otherwise `False`.
+    """
+    if len(ext_list) == 0:
+        return False
+    ext_list = [ext if ext.startswith('.') else '.' + ext for ext in ext_list]
+    ext_list = [ext.lower() for ext in ext_list]
+    basename = os.path.basename(filename)
+    ext = os.path.splitext(basename)[1].lower()
+    return ext in ext_list
+
+
+# File extensions regarding images (not including GIFs).
+IMAGE_EXTENSIONS = (
+    '.bmp', '.ppm', '.pgm', '.jpeg', '.jpg', '.jpe', '.jp2', '.png', '.webp',
+    '.tiff', '.tif'
+)
+# File extensions regarding videos.
+VIDEO_EXTENSIONS = (
+    '.avi', '.mkv', '.mp4', '.m4v', '.mov', '.webm', '.flv', '.rmvb', '.rm',
+    '.3gp'
+)
+# File extensions regarding media, i.e., images, videos, GIFs.
+MEDIA_EXTENSIONS = ('.gif', *IMAGE_EXTENSIONS, *VIDEO_EXTENSIONS)
+
+
+def parse_file_format(path):
+    """Parses the file format of a given path.
+
+    This function basically parses the file format according to its extension.
+    It will also return `dir` if the given path is a directory.
+
+    Parsable file formats:
+
+    - zip: with `.zip` extension.
+    - tar: with `.tar` / `.tgz` / `.tar.gz` extension.
+    - lmdb: a folder ending with `lmdb`.
+    - txt: with `.txt` / `.text` extension, OR without extension (e.g.
+      LICENSE).
+    - json: with `.json` extension.
+    - jpg: with `.jpeg` / `.jpg` / `.jpe` extension.
+    - png: with `.png` extension.
+
+    Args:
+        path: The path to the file to parse format from.
+
+    Returns:
+        A lower-case string indicating the file format, or `None` if the
+        format cannot be successfully parsed.
+    """
+    # Handle directory.
+    if os.path.isdir(path) or path.endswith('/'):
+        if path.rstrip('/').lower().endswith('lmdb'):
+            return 'lmdb'
+        return 'dir'
+    # Handle file.
+    if os.path.isfile(path) and os.path.splitext(path)[1] == '':
+        return 'txt'
+    path = path.lower()
+    if path.endswith('.tar.gz'):  # Cannot parse accurate extension.
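+        # `os.path.splitext()` would only report `.gz` here, hence this
+        # special case ahead of the generic extension checks below.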
+        return 'tar'
+    ext = os.path.splitext(path)[1]
+    if ext == '.zip':
+        return 'zip'
+    if ext in ['.tar', '.tgz']:
+        return 'tar'
+    if ext in ['.txt', '.text']:
+        return 'txt'
+    if ext == '.json':
+        return 'json'
+    if ext in ['.jpeg', '.jpg', '.jpe']:
+        return 'jpg'
+    if ext == '.png':
+        return 'png'
+    # Unparsable.
+    return None
+
+
+_cache_dir = None
+
+
+def set_cache_dir(directory=None):
+    """Sets the global cache directory.
+
+    The cache directory can be used to save some files that will be shared
+    across jobs. The default cache directory is set as
+    `~/.cache/${REPO_NAME}/`. This function can be used to redirect the cache
+    directory. Or, users can use `None` to reset the cache directory back to
+    default.
+
+    Args:
+        directory: The target directory used to cache files. If set as
+            `None`, the cache directory will be reset back to default.
+            (default: None)
+    """
+    assert directory is None or isinstance(directory, str), 'Invalid directory!'
+    global _cache_dir  # pylint: disable=global-statement
+    _cache_dir = directory
+
+
+def get_cache_dir():
+    """Gets the global cache directory.
+
+    The global cache directory is primarily set as `~/.cache/${REPO_NAME}/`
+    by default, and can be redirected with `set_cache_dir()`.
+
+    Returns:
+        A string, representing the global cache directory.
+    """
+    if _cache_dir is None:
+        home = os.path.expanduser('~')
+        return os.path.join(home, '.cache', REPO_NAME)
+    return _cache_dir
+
+
+def download_url(url, path=None, filename=None, sha256=None):
+    """Downloads a file from the given URL.
+
+    This function downloads a file from the given URL, and executes a hash
+    check if needed.
+
+    Args:
+        url: The URL to download the file from.
+        path: Path (directory) to save the downloaded file. If set as `None`,
+            the cache directory will be used. Please see `get_cache_dir()`
+            for more details. (default: None)
+        filename: The name to save the file. If set as `None`, this name will
+            be automatically parsed from the given URL. (default: None)
+        sha256: The expected sha256 of the downloaded file. If set as `None`,
+            the hash check will be skipped. Otherwise, this function will
+            check whether the sha256 of the downloaded file matches this
+            field. (default: None)
+
+    Returns:
+        A two-element tuple, where the first term is the full path of the
+        downloaded file, and the second term indicates the hash check result:
+        `True` means the hash check passes, `False` means it fails, while
+        `None` means no hash check was executed.
+    """
+    # Handle file path.
+    if path is None:
+        path = get_cache_dir()
+    if filename is None:
+        filename = os.path.basename(url)
+    save_path = os.path.join(path, filename)
+    # Download file if needed.
+    if not os.path.exists(save_path):
+        print(f'Downloading URL `{url}` to path `{save_path}` ...')
+        os.makedirs(path, exist_ok=True)
+        download_url_to_file(url, save_path, hash_prefix=None, progress=True)
+    # Check hash if needed.
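+    # NOTE: The whole file is read into memory before hashing; a chunked
+    # read would scale better for very large files.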
+ check_result = None + if sha256 is not None: + with open(save_path, 'rb') as f: + file_hash = hashlib.sha256(f.read()) + check_result = (file_hash.hexdigest() == sha256) + + return save_path, check_result diff --git a/utils/parsing_utils.py b/utils/parsing_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..50d8b20da87413a2730d12e712ef8d5dddbdad65 --- /dev/null +++ b/utils/parsing_utils.py @@ -0,0 +1,213 @@ +# python3.7 +"""Contains the utility functions for parsing arguments.""" + +import json +import argparse +import click + +__all__ = [ + 'parse_int', 'parse_float', 'parse_bool', 'parse_index', 'parse_json', + 'IntegerParamType', 'FloatParamType', 'BooleanParamType', 'IndexParamType', + 'JsonParamType', 'DictAction' +] + + +def parse_int(arg): + """Parses an argument to integer. + + Support converting string `none` and `null` to `None`. + """ + if arg is None: + return None + if isinstance(arg, str) and arg.lower() in ['none', 'null']: + return None + return int(arg) + + +def parse_float(arg): + """Parses an argument to float number. + + Support converting string `none` and `null` to `None`. + """ + if arg is None: + return None + if isinstance(arg, str) and arg.lower() in ['none', 'null']: + return None + return float(arg) + + +def parse_bool(arg): + """Parses an argument to boolean. + + `None` will be converted to `False`. + """ + if isinstance(arg, bool): + return arg + if arg is None: + return False + if arg.lower() in ['1', 'true', 't', 'yes', 'y']: + return True + if arg.lower() in ['0', 'false', 'f', 'no', 'n', 'none', 'null']: + return False + raise ValueError(f'`{arg}` cannot be converted to boolean!') + + +def parse_index(arg, min_val=None, max_val=None): + """Parses indices. + + If the input is a list or tuple, this function has no effect. + + If the input is a string, it can be either a comma separated list of numbers + `1, 3, 5`, or a dash separated range `3 - 10`. Spaces in the string will be + ignored. + + Args: + arg: The input argument to parse indices from. + min_val: If not `None`, this function will check that all indices are + equal to or larger than this value. (default: None) + max_val: If not `None`, this function will check that all indices are + equal to or smaller than this field. (default: None) + + Returns: + A list of integers. + + Raises: + ValueError: If the input is invalid, i.e., neither a list or tuple, nor + a string. + """ + if arg is None or arg == '': + indices = [] + elif isinstance(arg, int): + indices = [arg] + elif isinstance(arg, (list, tuple)): + indices = list(arg) + elif isinstance(arg, str): + indices = [] + if arg.lower() not in ['none', 'null']: + splits = arg.replace(' ', '').split(',') + for split in splits: + numbers = list(map(int, split.split('-'))) + if len(numbers) == 1: + indices.append(numbers[0]) + elif len(numbers) == 2: + indices.extend(list(range(numbers[0], numbers[1] + 1))) + else: + raise ValueError(f'Invalid type of input: `{type(arg)}`!') + + assert isinstance(indices, list) + indices = sorted(list(set(indices))) + for idx in indices: + assert isinstance(idx, int) + if min_val is not None: + assert idx >= min_val, f'{idx} is smaller than min val `{min_val}`!' + if max_val is not None: + assert idx <= max_val, f'{idx} is larger than max val `{max_val}`!' + + return indices + + +def parse_json(arg): + """Parses a string-like argument following JSON format. + + `None` arguments will be kept. 
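+
+    For example, `parse_json('{"a": 1}')` returns `{'a': 1}`, while a plain
+    string such as `parse_json('hello')` is returned unchanged.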
+ """ + if arg is None: + return None + try: + return json.loads(arg) + except json.decoder.JSONDecodeError: + return arg + + +class IntegerParamType(click.ParamType): + """Defines a `click.ParamType` to parse integer arguments.""" + + name = 'int' + + def convert(self, value, param, ctx): # pylint: disable=inconsistent-return-statements + try: + return parse_int(value) + except ValueError: + self.fail(f'`{value}` cannot be parsed as an integer!', param, ctx) + + +class FloatParamType(click.ParamType): + """Defines a `click.ParamType` to parse float arguments.""" + + name = 'float' + + def convert(self, value, param, ctx): # pylint: disable=inconsistent-return-statements + try: + return parse_float(value) + except ValueError: + self.fail(f'`{value}` cannot be parsed as a float!', param, ctx) + + +class BooleanParamType(click.ParamType): + """Defines a `click.ParamType` to parse boolean arguments.""" + + name = 'bool' + + def convert(self, value, param, ctx): # pylint: disable=inconsistent-return-statements + try: + return parse_bool(value) + except ValueError: + self.fail(f'`{value}` cannot be parsed as a boolean!', param, ctx) + + +class IndexParamType(click.ParamType): + """Defines a `click.ParamType` to parse indices arguments.""" + + name = 'index' + + def __init__(self, min_val=None, max_val=None): + self.min_val = min_val + self.max_val = max_val + + def convert(self, value, param, ctx): # pylint: disable=inconsistent-return-statements + try: + return parse_index(value, self.min_val, self.max_val) + except ValueError: + self.fail( + f'`{value}` cannot be parsed as a list of indices!', param, ctx) + + +class JsonParamType(click.ParamType): + """Defines a `click.ParamType` to parse arguments following JSON format.""" + + name = 'json' + + def convert(self, value, param, ctx): + return parse_json(value) + + +class DictAction(argparse.Action): + """Argparse action to split each argument into (key, value) pair. + + Each argument should be with `key=value` format, where `value` should be a + string with JSON format. + + For example, with an argparse: + + parser.add_argument('--options', nargs='+', action=DictAction) + + , you can use following arguments in the command line: + + --options \ + a=1 \ + b=1.5 + c=true \ + d=null \ + e=[1,2,3,4,5] \ + f='{"x":1,"y":2,"z":3}' \ + + NOTE: No space is allowed in each argument. Also, the dictionary-type + argument should be quoted with single quotation marks `'`. + """ + + def __call__(self, parser, namespace, values, option_string=None): + options = {} + for argument in values: + key, val = argument.split('=', maxsplit=1) + options[key] = parse_json(val) + setattr(namespace, self.dest, options) diff --git a/utils/tf_utils.py b/utils/tf_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..80e48ed06e614571d920125d4b64fbfefbf804c0 --- /dev/null +++ b/utils/tf_utils.py @@ -0,0 +1,47 @@ +# python3.7 +"""Contains the utility functions to handle import TensorFlow modules. + +Basically, TensorFlow may not be supported in the current environment, or may +cause some warnings. This file provides functions to help ease TensorFlow +related imports, such as TensorBoard. +""" + +import warnings + +__all__ = ['import_tf', 'import_tb_writer'] + + +def import_tf(): + """Imports TensorFlow module if possible. + + If `ImportError` is raised, `None` will be returned. Otherwise, the module + `tensorflow` will be returned. 
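+
+    A typical guarded use (illustrative):
+
+        tf = import_tf()
+        if tf is not None:
+            print(tf.__version__)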
+ """ + warnings.filterwarnings('ignore', category=FutureWarning) + try: + import tensorflow as tf # pylint: disable=import-outside-toplevel + tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) + module = tf + except ImportError: + module = None + warnings.filterwarnings('default', category=FutureWarning) + return module + + +def import_tb_writer(): + """Imports the SummaryWriter of TensorBoard. + + If `ImportError` is raised, `None` will be returned. Otherwise, the class + `SummaryWriter` will be returned. + + NOTE: This function attempts to import `SummaryWriter` from + `torch.utils.tensorboard`. But it does not necessarily mean the import + always succeeds because installing TensorBoard is not a duty of `PyTorch`. + """ + warnings.filterwarnings('ignore', category=FutureWarning) + try: + from torch.utils.tensorboard import SummaryWriter # pylint: disable=import-outside-toplevel + except ImportError: # In case TensorBoard is not supported. + SummaryWriter = None + warnings.filterwarnings('default', category=FutureWarning) + return SummaryWriter diff --git a/utils/visualizers/__init__.py b/utils/visualizers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..df9fbaca361c3802c0f8221c053b61bf66a2456f --- /dev/null +++ b/utils/visualizers/__init__.py @@ -0,0 +1,14 @@ +# python3.7 +"""Collects all visualizers.""" + +from .grid_visualizer import GridVisualizer +from .gif_visualizer import GifVisualizer +from .html_visualizer import HtmlVisualizer +from .html_visualizer import HtmlReader +from .video_visualizer import VideoVisualizer +from .video_visualizer import VideoReader + +__all__ = [ + 'GridVisualizer', 'GifVisualizer', 'HtmlVisualizer', 'HtmlReader', + 'VideoVisualizer', 'VideoReader' +] diff --git a/utils/visualizers/gif_visualizer.py b/utils/visualizers/gif_visualizer.py new file mode 100644 index 0000000000000000000000000000000000000000..a5528e8af79fda2e3840c67cbf60ea87ff273f4c --- /dev/null +++ b/utils/visualizers/gif_visualizer.py @@ -0,0 +1,79 @@ +# python3.7 +"""Contains the visualizer to visualize images as a GIF.""" + +from PIL import Image + +from ..image_utils import parse_image_size +from ..image_utils import load_image +from ..image_utils import resize_image +from ..image_utils import list_images_from_dir + +__all__ = ['GifVisualizer'] + + +class GifVisualizer(object): + """Defines the visualizer that visualizes an image collection as GIF.""" + + def __init__(self, image_size=None, duration=100, loop=0): + """Initializes the GIF visualizer. + + Args: + image_size: Size for image visualization. (default: None) + duration: Duration between two frames, in milliseconds. + (default: 100) + loop: How many times to loop the GIF. `0` means infinite. + (default: 0) + """ + self.set_image_size(image_size) + self.set_duration(duration) + self.set_loop(loop) + + def set_image_size(self, image_size=None): + """Sets the image size of the GIF.""" + height, width = parse_image_size(image_size) + self.image_height = height + self.image_width = width + + def set_duration(self, duration=100): + """Sets the GIF duration.""" + self.duration = duration + + def set_loop(self, loop=0): + """Sets how many times the GIF will be looped. 
`0` means infinite."""
+        self.loop = loop
+
+    def visualize_collection(self, images, save_path):
+        """Visualizes a collection of images one by one."""
+        height, width = images[0].shape[0:2]
+        height = self.image_height or height
+        width = self.image_width or width
+        pil_images = []
+        for image in images:
+            if image.shape[0:2] != (height, width):
+                image = resize_image(image, (width, height))
+            pil_images.append(Image.fromarray(image))
+        pil_images[0].save(save_path, format='GIF', save_all=True,
+                           append_images=pil_images[1:],
+                           duration=self.duration,
+                           loop=self.loop)
+
+    def visualize_list(self, image_list, save_path):
+        """Visualizes a list of image files."""
+        height, width = load_image(image_list[0]).shape[0:2]
+        height = self.image_height or height
+        width = self.image_width or width
+        pil_images = []
+        for filename in image_list:
+            image = load_image(filename)
+            if image.shape[0:2] != (height, width):
+                image = resize_image(image, (width, height))
+            pil_images.append(Image.fromarray(image))
+        pil_images[0].save(save_path, format='GIF', save_all=True,
+                           append_images=pil_images[1:],
+                           duration=self.duration,
+                           loop=self.loop)
+
+    def visualize_directory(self, directory, save_path):
+        """Visualizes all images under a directory."""
+        image_list = list_images_from_dir(directory)
+        self.visualize_list(image_list, save_path)
diff --git a/utils/visualizers/grid_visualizer.py b/utils/visualizers/grid_visualizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..291e5fee45816a9775242c3a138ebd0f55f1df20
--- /dev/null
+++ b/utils/visualizers/grid_visualizer.py
@@ -0,0 +1,232 @@
+# python3.7
+"""Contains the visualizer to visualize images by composing them as a grid."""
+
+from ..image_utils import get_blank_image
+from ..image_utils import get_grid_shape
+from ..image_utils import parse_image_size
+from ..image_utils import load_image
+from ..image_utils import save_image
+from ..image_utils import resize_image
+from ..image_utils import list_images_from_dir
+
+__all__ = ['GridVisualizer']
+
+
+class GridVisualizer(object):
+    """Defines the visualizer that visualizes images as a grid.
+
+    Basically, given a collection of images, this visualizer stitches them one
+    by one. Notably, this class also supports adding spaces between images,
+    adding borders around images, and using a white or black background.
+
+    Example:
+
+    grid = GridVisualizer(num_rows=num_rows, num_cols=num_cols)
+    for i in range(num_rows):
+        for j in range(num_cols):
+            grid.add(i, j, image)
+    grid.save('visualize.jpg')
+    """
+
+    def __init__(self,
+                 grid_size=0,
+                 num_rows=0,
+                 num_cols=0,
+                 is_portrait=False,
+                 image_size=None,
+                 image_channels=0,
+                 row_spacing=0,
+                 col_spacing=0,
+                 border_left=0,
+                 border_right=0,
+                 border_top=0,
+                 border_bottom=0,
+                 use_black_background=True):
+        """Initializes the grid visualizer.
+
+        Args:
+            grid_size: Total number of cells, i.e., height * width.
+                (default: 0)
+            num_rows: Number of rows. (default: 0)
+            num_cols: Number of columns. (default: 0)
+            is_portrait: Whether the grid should be portrait or landscape.
+                This is only used when `num_rows` and `num_cols` need to be
+                computed automatically. See function `get_grid_shape()` in
+                `utils/image_utils.py` for details. (default: False)
+            image_size: Size to visualize each image. (default: None)
+            image_channels: Number of image channels. (default: 0)
+            row_spacing: Spacing between rows. (default: 0)
+            col_spacing: Spacing between columns. (default: 0)
+            border_left: Width of left border. (default: 0)
+            border_right: Width of right border.
(default: 0) + border_top: Width of top border. (default: 0) + border_bottom: Width of bottom border. (default: 0) + use_black_background: Whether to use black background. + (default: True) + """ + self.reset(grid_size, num_rows, num_cols, is_portrait) + self.set_image_size(image_size) + self.set_image_channels(image_channels) + self.set_row_spacing(row_spacing) + self.set_col_spacing(col_spacing) + self.set_border_left(border_left) + self.set_border_right(border_right) + self.set_border_top(border_top) + self.set_border_bottom(border_bottom) + self.set_background(use_black_background) + self.grid = None + + def reset(self, + grid_size=0, + num_rows=0, + num_cols=0, + is_portrait=False): + """Resets the grid shape, i.e., number of rows/columns.""" + if grid_size > 0: + num_rows, num_cols = get_grid_shape(grid_size, + height=num_rows, + width=num_cols, + is_portrait=is_portrait) + self.grid_size = num_rows * num_cols + self.num_rows = num_rows + self.num_cols = num_cols + self.grid = None + + def set_image_size(self, image_size=None): + """Sets the image size of each cell in the grid.""" + height, width = parse_image_size(image_size) + self.image_height = height + self.image_width = width + + def set_image_channels(self, image_channels=0): + """Sets the number of channels of the grid.""" + self.image_channels = image_channels + + def set_row_spacing(self, row_spacing=0): + """Sets the spacing between grid rows.""" + self.row_spacing = row_spacing + + def set_col_spacing(self, col_spacing=0): + """Sets the spacing between grid columns.""" + self.col_spacing = col_spacing + + def set_border_left(self, border_left=0): + """Sets the width of the left border of the grid.""" + self.border_left = border_left + + def set_border_right(self, border_right=0): + """Sets the width of the right border of the grid.""" + self.border_right = border_right + + def set_border_top(self, border_top=0): + """Sets the width of the top border of the grid.""" + self.border_top = border_top + + def set_border_bottom(self, border_bottom=0): + """Sets the width of the bottom border of the grid.""" + self.border_bottom = border_bottom + + def set_background(self, use_black=True): + """Sets the grid background.""" + self.use_black_background = use_black + + def init_grid(self): + """Initializes the grid with a blank image.""" + assert self.num_rows > 0 + assert self.num_cols > 0 + assert self.image_height > 0 + assert self.image_width > 0 + assert self.image_channels > 0 + grid_height = (self.image_height * self.num_rows + + self.row_spacing * (self.num_rows - 1) + + self.border_top + self.border_bottom) + grid_width = (self.image_width * self.num_cols + + self.col_spacing * (self.num_cols - 1) + + self.border_left + self.border_right) + self.grid = get_blank_image(grid_height, grid_width, + channels=self.image_channels, + use_black=self.use_black_background) + + def add(self, i, j, image): + """Adds an image into the grid. + + NOTE: The input image is assumed to be with `RGB` channel order. 
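+
+        Example (illustrative, assuming `image` is an `RGB` `uint8` array):
+
+            grid = GridVisualizer(num_rows=2, num_cols=3)
+            grid.add(0, 0, image)  # Top-left cell; the grid buffer is lazily
+                                   # initialized on the first `add()` call.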
+ """ + if self.grid is None: + height, width = image.shape[0:2] + channels = 1 if image.ndim == 2 else image.shape[2] + height = self.image_height or height + width = self.image_width or width + channels = self.image_channels or channels + self.set_image_size((height, width)) + self.set_image_channels(channels) + self.init_grid() + if image.shape[0:2] != (self.image_height, self.image_width): + image = resize_image(image, (self.image_width, self.image_height)) + y = self.border_top + i * (self.image_height + self.row_spacing) + x = self.border_left + j * (self.image_width + self.col_spacing) + self.grid[y:y + self.image_height, x:x + self.image_width] = image + + def visualize_collection(self, + images, + save_path=None, + num_rows=0, + num_cols=0, + is_portrait=False, + is_row_major=True): + """Visualizes a collection of images one by one.""" + self.grid = None + self.reset(grid_size=len(images), + num_rows=num_rows, + num_cols=num_cols, + is_portrait=is_portrait) + for idx, image in enumerate(images): + if is_row_major: + row_idx, col_idx = divmod(idx, self.num_cols) + else: + col_idx, row_idx = divmod(idx, self.num_rows) + self.add(row_idx, col_idx, image) + if save_path: + self.save(save_path) + + def visualize_list(self, + image_list, + save_path=None, + num_rows=0, + num_cols=0, + is_portrait=False, + is_row_major=True): + """Visualizes a list of image files.""" + self.grid = None + self.reset(grid_size=len(image_list), + num_rows=num_rows, + num_cols=num_cols, + is_portrait=is_portrait) + for idx, filename in enumerate(image_list): + image = load_image(filename) + if is_row_major: + row_idx, col_idx = divmod(idx, self.num_cols) + else: + col_idx, row_idx = divmod(idx, self.num_rows) + self.add(row_idx, col_idx, image) + if save_path: + self.save(save_path) + + def visualize_directory(self, + directory, + save_path=None, + num_rows=0, + num_cols=0, + is_portrait=False, + is_row_major=True): + """Visualizes all images under a directory.""" + image_list = list_images_from_dir(directory) + self.visualize_list(image_list=image_list, + save_path=save_path, + num_rows=num_rows, + num_cols=num_cols, + is_portrait=is_portrait, + is_row_major=is_row_major) + + def save(self, path): + """Saves the grid.""" + save_image(path, self.grid) diff --git a/utils/visualizers/html_visualizer.py b/utils/visualizers/html_visualizer.py new file mode 100644 index 0000000000000000000000000000000000000000..7bb63385af54021febb4791bdda6534324d2d901 --- /dev/null +++ b/utils/visualizers/html_visualizer.py @@ -0,0 +1,438 @@ +# python3.7 +"""Contains the visualizer to visualize images with HTML page.""" + +import os +import base64 +import cv2 +import numpy as np +from bs4 import BeautifulSoup + +from ..image_utils import get_grid_shape +from ..image_utils import parse_image_size +from ..image_utils import load_image +from ..image_utils import resize_image +from ..image_utils import list_images_from_dir + +__all__ = ['HtmlVisualizer', 'HtmlReader'] + + +def get_sortable_html_header(column_name_list, sort_by_ascending=False): + """Gets header for sortable HTML page. + + Basically, the HTML page contains a sortable table, where user can sort the + rows by a particular column by clicking the column head. + + Example: + + column_name_list = [name_1, name_2, name_3] + header = get_sortable_html_header(column_name_list) + footer = get_sortable_html_footer() + sortable_table = ... + html_page = header + sortable_table + footer + + Args: + column_name_list: List of column header names. 
+        sort_by_ascending: Default sorting order. If set as `True`, the HTML
+            page will be sorted by ascending order when a column header is
+            clicked for the first time.
+
+    Returns:
+        A string, which represents the header for a sortable HTML page.
+    """
+    header = '\n'.join([
+        '<script type="text/javascript">',
+        'var column_idx;',
+        'var sort_by_ascending = ' + str(sort_by_ascending).lower() + ';',
+        '',
+        'function sorting(tbody, column_idx){',
+        '    this.column_idx = column_idx;',
+        '    Array.from(tbody.rows)',
+        '         .sort(compareCells)',
+        '         .forEach(function(row) { tbody.appendChild(row); })',
+        '    sort_by_ascending = !sort_by_ascending;',
+        '}',
+        '',
+        'function compareCells(row_a, row_b) {',
+        '    var val_a = row_a.cells[column_idx].innerText;',
+        '    var val_b = row_b.cells[column_idx].innerText;',
+        '    var flag = sort_by_ascending ? 1 : -1;',
+        '    return flag * (val_a > val_b ? 1 : -1);',
+        '}',
+        '</script>',
+        '',
+        '<html>',
+        '',
+        '<head>',
+        '<style>',
+        '    table {',
+        '        border-spacing: 0;',
+        '        border: 1px solid black;',
+        '    }',
+        '    th {',
+        '        cursor: pointer;',
+        '    }',
+        '    th, td {',
+        '        text-align: left;',
+        '        vertical-align: middle;',
+        '        border-collapse: collapse;',
+        '        border: 0.5px solid black;',
+        '        padding: 8px;',
+        '    }',
+        '    tr:nth-child(even) {',
+        '        background-color: #d2d2d2;',
+        '    }',
+        '</style>',
+        '</head>',
+        '',
+        '<body>',
+        '',
+        '<table>',
+        '<thead>',
+        '<tr>',
+        ''])
+    for idx, name in enumerate(column_name_list):
+        header += f'    <th onclick="sorting(tbody, {idx})">{name}</th>\n'
+    header += '</tr>\n'
+    header += '</thead>\n'
+    header += '<tbody id="tbody">\n'
+
+    return header
+
+
+def get_sortable_html_footer():
+    """Gets footer for sortable HTML page.
+
+    Check function `get_sortable_html_header()` for more details.
+    """
+    return '</tbody>\n</table>\n\n</body>\n</html>\n'
+
+
+def encode_image_to_html_str(image, image_size=None):
+    """Encodes an image to HTML language.
+
+    NOTE: Input image is always assumed to be with `RGB` channel order.
+
+    Args:
+        image: The input image to encode. Should be with `RGB` channel order.
+        image_size: This field is used to resize the image before encoding.
+            `None` disables resizing. (default: None)
+
+    Returns:
+        A string that represents the encoded image.
+    """
+    if image is None:
+        return ''
+
+    assert image.ndim == 3 and image.shape[2] in [1, 3, 4]
+    if image.shape[2] == 3:
+        image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
+    elif image.shape[2] == 4:
+        image = cv2.cvtColor(image, cv2.COLOR_RGBA2BGRA)
+
+    # Resize the image if needed.
+    height, width = parse_image_size(image_size)
+    height = height or image.shape[0]
+    width = width or image.shape[1]
+    if image.shape[0:2] != (height, width):
+        image = resize_image(image, (width, height))
+
+    # Encode the image to HTML-format string.
+    if image.shape[2] == 4:  # Use `png` to encode RGBA images.
+        encoded = cv2.imencode('.png', image)[1].tobytes()
+        encoded_base64 = base64.b64encode(encoded).decode('utf-8')
+        html_str = f'<img src="data:image/png;base64, {encoded_base64}"/>'
+    else:
+        encoded = cv2.imencode('.jpg', image)[1].tobytes()
+        encoded_base64 = base64.b64encode(encoded).decode('utf-8')
+        html_str = f'<img src="data:image/jpeg;base64, {encoded_base64}"/>'
+
+    return html_str
+
+
+def decode_html_str_to_image(html_str, image_size=None):
+    """Decodes an image from HTML string.
+
+    Args:
+        html_str: An HTML string that represents an image.
+        image_size: This field is used to resize the image after decoding.
+            `None` disables resizing. (default: None)
+
+    Returns:
+        An image with `RGB` channel order.
+    """
+    if not html_str:
+        return None
+
+    assert isinstance(html_str, str)
+    image_str = html_str.split(',')[-1].strip()
+    encoded_image = base64.b64decode(image_str)
+    encoded_image_numpy = np.frombuffer(encoded_image, dtype=np.uint8)
+    image = cv2.imdecode(encoded_image_numpy, flags=cv2.IMREAD_UNCHANGED)
+
+    if image.ndim == 2:
+        image = image[:, :, np.newaxis]
+    assert image.ndim == 3 and image.shape[2] in [1, 3, 4]
+    if image.shape[2] == 3:
+        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+    if image.shape[2] == 4:
+        image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA)
+
+    # Resize the image if needed.
+    height, width = parse_image_size(image_size)
+    height = height or image.shape[0]
+    width = width or image.shape[1]
+    if image.shape[0:2] != (height, width):
+        image = resize_image(image, (width, height))
+
+    return image
+
+
+class HtmlVisualizer(object):
+    """Defines the HTML visualizer that visualizes images on an HTML page.
+
+    This class can be used to visualize image results on an HTML page.
+    Basically, it is based on an HTML-format sortable table with helper
+    functions `get_sortable_html_header()`, `get_sortable_html_footer()`, and
+    `encode_image_to_html_str()`. To simplify the usage, specifying the
+    following fields is enough to create a visualization page:
+
+    (1) num_rows: Number of rows of the table (header-row exclusive).
+    (2) num_cols: Number of columns of the table.
+    (3) header_contents (optional): Title of each column.
+
+    NOTE: `grid_size` can be used to assign `num_rows` and `num_cols`
+    automatically.
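+
+    NOTE: Each cell image is embedded into the page by
+    `encode_image_to_html_str()` as a base64-encoded JPEG/PNG string, so the
+    saved page is fully self-contained, at the cost of file size.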
+
+    Example:
+
+    html = HtmlVisualizer(num_rows=num_rows, num_cols=num_cols)
+    html.set_headers([...])
+    for i in range(num_rows):
+        for j in range(num_cols):
+            html.set_cell(i, j, text=..., image=..., highlight=False)
+    html.save('visualize.html')
+    """
+
+    def __init__(self,
+                 grid_size=0,
+                 num_rows=0,
+                 num_cols=0,
+                 is_portrait=True,
+                 image_size=None):
+        """Initializes the HTML visualizer.
+
+        Args:
+            grid_size: Total number of cells, i.e., height * width.
+                (default: 0)
+            num_rows: Number of rows. (default: 0)
+            num_cols: Number of columns. (default: 0)
+            is_portrait: Whether the HTML page should be portrait or
+                landscape. This is only used when `num_rows` and `num_cols`
+                need to be computed automatically. See function
+                `get_grid_shape()` in `utils/image_utils.py` for details.
+                (default: True)
+            image_size: Size to visualize each image. (default: None)
+        """
+        self.reset(grid_size, num_rows, num_cols, is_portrait)
+        self.set_image_size(image_size)
+
+    def reset(self,
+              grid_size=0,
+              num_rows=0,
+              num_cols=0,
+              is_portrait=True):
+        """Resets the HTML page with new numbers of rows and columns."""
+        if grid_size > 0:
+            num_rows, num_cols = get_grid_shape(grid_size,
+                                                height=num_rows,
+                                                width=num_cols,
+                                                is_portrait=is_portrait)
+        self.grid_size = num_rows * num_cols
+        self.num_rows = num_rows
+        self.num_cols = num_cols
+        self.headers = ['' for _ in range(self.num_cols)]
+        self.cells = [[{
+            'text': '',
+            'image': '',
+            'highlight': False,
+        } for _ in range(self.num_cols)] for _ in range(self.num_rows)]
+
+    def set_image_size(self, image_size=None):
+        """Sets the image size of each cell in the HTML page."""
+        self.image_size = image_size
+
+    def set_header(self, col_idx, content):
+        """Sets the content of a particular header by column index."""
+        self.headers[col_idx] = content
+
+    def set_headers(self, contents):
+        """Sets the contents of all headers."""
+        assert isinstance(contents, (list, tuple))
+        assert len(contents) == self.num_cols
+        for col_idx, content in enumerate(contents):
+            self.set_header(col_idx, content)
+
+    def set_cell(self, row_idx, col_idx, text='', image=None, highlight=False):
+        """Sets the content of a particular cell.
+
+        Basically, a cell contains some text as well as an image. Both text
+        and image can be empty.
+
+        NOTE: The image is assumed to be with `RGB` channel order.
+
+        Args:
+            row_idx: Row index of the cell to edit.
+            col_idx: Column index of the cell to edit.
+            text: Text to add into the target cell. (default: '')
+            image: Image to show in the target cell. Should be with `RGB`
+                channel order. (default: None)
+            highlight: Whether to highlight this cell. (default: False)
+        """
+        self.cells[row_idx][col_idx]['text'] = text
+        self.cells[row_idx][col_idx]['image'] = encode_image_to_html_str(
+            image, self.image_size)
+        self.cells[row_idx][col_idx]['highlight'] = bool(highlight)
+
+    def visualize_collection(self,
+                             images,
+                             save_path=None,
+                             num_rows=0,
+                             num_cols=0,
+                             is_portrait=True,
+                             is_row_major=True):
+        """Visualizes a collection of images one by one."""
+        self.reset(grid_size=len(images),
+                   num_rows=num_rows,
+                   num_cols=num_cols,
+                   is_portrait=is_portrait)
+        for idx, image in enumerate(images):
+            if is_row_major:
+                row_idx, col_idx = divmod(idx, self.num_cols)
+            else:
+                col_idx, row_idx = divmod(idx, self.num_rows)
+            self.set_cell(row_idx, col_idx, text=f'Index {idx:03d}',
+                          image=image)
+        if save_path:
+            self.save(save_path)
+
+    def visualize_list(self,
+                       image_list,
+                       save_path=None,
+                       num_rows=0,
+                       num_cols=0,
+                       is_portrait=True,
+                       is_row_major=True):
+        """Visualizes a list of image files."""
+        self.reset(grid_size=len(image_list),
+                   num_rows=num_rows,
+                   num_cols=num_cols,
+                   is_portrait=is_portrait)
+        for idx, filename in enumerate(image_list):
+            basename = os.path.basename(filename)
+            image = load_image(filename)
+            if is_row_major:
+                row_idx, col_idx = divmod(idx, self.num_cols)
+            else:
+                col_idx, row_idx = divmod(idx, self.num_rows)
+            self.set_cell(row_idx, col_idx,
+                          text=f'{basename} (index {idx:03d})', image=image)
+        if save_path:
+            self.save(save_path)
+
+    def visualize_directory(self,
+                            directory,
+                            save_path=None,
+                            num_rows=0,
+                            num_cols=0,
+                            is_portrait=True,
+                            is_row_major=True):
+        """Visualizes all images under a directory."""
+        image_list = list_images_from_dir(directory)
+        self.visualize_list(image_list=image_list,
+                            save_path=save_path,
+                            num_rows=num_rows,
+                            num_cols=num_cols,
+                            is_portrait=is_portrait,
+                            is_row_major=is_row_major)
+
+    def save(self, path):
+        """Saves the HTML page."""
+        html = ''
+        for i in range(self.num_rows):
+            html += '<tr>\n'
+            for j in range(self.num_cols):
+                text = self.cells[i][j]['text']
+                image = self.cells[i][j]['image']
+                if self.cells[i][j]['highlight']:
+                    color = ' bgcolor="#FF8888"'
+                else:
+                    color = ''
+                if text:
+                    html += f'    <td{color}>{text}<br><br>{image}</td>\n'
+                else:
+                    html += f'    <td{color}>{image}</td>\n'
+            html += '</tr>\n'
+
+        header = get_sortable_html_header(self.headers)
+        footer = get_sortable_html_footer()
+
+        with open(path, 'w') as f:
+            f.write(header + html + footer)
+
+
+class HtmlReader(object):
+    """Defines the HTML page reader.
+
+    This class can be used to parse results from the visualization page
+    generated by `HtmlVisualizer`.
+
+    Example:
+
+    html = HtmlReader(html_path)
+    for j in range(html.num_cols):
+        header = html.get_header(j)
+    for i in range(html.num_rows):
+        for j in range(html.num_cols):
+            text = html.get_text(i, j)
+            image = html.get_image(i, j, image_size=None)
+    """
+
+    def __init__(self, path):
+        """Initializes by loading the content from file."""
+        self.path = path
+
+        # Load content.
+        with open(path, 'r') as f:
+            self.html = BeautifulSoup(f, 'html.parser')
+
+        # Parse headers.
+        thead = self.html.find('thead')
+        headers = thead.findAll('th')
+        self.headers = []
+        for header in headers:
+            self.headers.append(header.text)
+        self.num_cols = len(self.headers)
+
+        # Parse cells.
+        tbody = self.html.find('tbody')
+        rows = tbody.findAll('tr')
+        self.cells = []
+        for row in rows:
+            cells = row.findAll('td')
+            self.cells.append([])
+            for cell in cells:
+                self.cells[-1].append({
+                    'text': cell.text,
+                    'image': cell.find('img')['src'],
+                })
+            assert len(self.cells[-1]) == self.num_cols
+        self.num_rows = len(self.cells)
+
+    def get_header(self, j):
+        """Gets header for a particular column."""
+        return self.headers[j]
+
+    def get_text(self, i, j):
+        """Gets text from a particular cell."""
+        return self.cells[i][j]['text']
+
+    def get_image(self, i, j, image_size=None):
+        """Gets image from a particular cell."""
+        return decode_html_str_to_image(self.cells[i][j]['image'], image_size)
diff --git a/utils/visualizers/test.py b/utils/visualizers/test.py
new file mode 100644
index 0000000000000000000000000000000000000000..765ebf9c721b0792fb373ecb515ebf188f728df0
--- /dev/null
+++ b/utils/visualizers/test.py
@@ -0,0 +1,97 @@
+# python3.7
+"""Unit test for visualizers."""
+
+import os
+import skvideo.datasets
+
+from ..image_utils import save_image
+from . import GridVisualizer
+from . import HtmlVisualizer
+from . import HtmlReader
+from . import GifVisualizer
+from . import VideoVisualizer
+from . import VideoReader
+
+__all__ = ['test_visualizer']
+
+_TEST_DIR = 'visualizer_test'
+
+
+def test_visualizer(test_dir=_TEST_DIR):
+    """Tests visualizers."""
+    print('========== Start Visualizer Test ==========')
+
+    frame_dir = os.path.join(test_dir, 'test_frames')
+    os.makedirs(frame_dir, exist_ok=True)
+
+    print('===== Testing `VideoReader` =====')
+    # Total 132 frames, with size (720, 1080).
+    video_reader = VideoReader(skvideo.datasets.bigbuckbunny())
+    frame_height = video_reader.frame_height
+    frame_width = video_reader.frame_width
+    frame_size = (frame_height, frame_width)
+    half_size = (frame_height // 2, frame_width // 2)
+    # Save frames as the test set.
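+    # 80 frames are enough for the checks below while staying within the
+    # 132 frames available in the source video.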
+    for idx in range(80):
+        frame = video_reader.read()
+        save_image(os.path.join(frame_dir, f'{idx:02d}.png'), frame)
+
+    print('===== Testing `GridVisualizer` =====')
+    grid_visualizer = GridVisualizer()
+    grid_visualizer.set_row_spacing(30)
+    grid_visualizer.set_col_spacing(30)
+    grid_visualizer.set_background(use_black=True)
+    path = os.path.join(test_dir, 'portrait_row_major_ori_space30_black.png')
+    grid_visualizer.visualize_directory(frame_dir, path,
+                                        is_portrait=True, is_row_major=True)
+    path = os.path.join(
+        test_dir, 'landscape_col_major_downsample_space15_white.png')
+    grid_visualizer.set_image_size(half_size)
+    grid_visualizer.set_row_spacing(15)
+    grid_visualizer.set_col_spacing(15)
+    grid_visualizer.set_background(use_black=False)
+    grid_visualizer.visualize_directory(frame_dir, path,
+                                        is_portrait=False, is_row_major=False)
+
+    print('===== Testing `HtmlVisualizer` =====')
+    html_visualizer = HtmlVisualizer()
+    path = os.path.join(test_dir, 'portrait_col_major_ori.html')
+    html_visualizer.visualize_directory(frame_dir, path,
+                                        is_portrait=True, is_row_major=False)
+    path = os.path.join(test_dir, 'landscape_row_major_downsample.html')
+    html_visualizer.set_image_size(half_size)
+    html_visualizer.visualize_directory(frame_dir, path,
+                                        is_portrait=False, is_row_major=True)
+
+    print('===== Testing `HtmlReader` =====')
+    path = os.path.join(test_dir, 'landscape_row_major_downsample.html')
+    html_reader = HtmlReader(path)
+    for j in range(html_reader.num_cols):
+        assert html_reader.get_header(j) == ''
+    parsed_dir = os.path.join(test_dir, 'parsed_frames')
+    os.makedirs(parsed_dir, exist_ok=True)
+    for i in range(html_reader.num_rows):
+        for j in range(html_reader.num_cols):
+            idx = i * html_reader.num_cols + j
+            assert html_reader.get_text(i, j).endswith(f'(index {idx:03d})')
+            image = html_reader.get_image(i, j, image_size=frame_size)
+            assert image.shape[0:2] == frame_size
+            save_image(os.path.join(parsed_dir, f'{idx:02d}.png'), image)
+
+    print('===== Testing `GifVisualizer` =====')
+    gif_visualizer = GifVisualizer()
+    path = os.path.join(test_dir, 'gif_ori.gif')
+    gif_visualizer.visualize_directory(frame_dir, path)
+    gif_visualizer.set_image_size(half_size)
+    path = os.path.join(test_dir, 'gif_downsample.gif')
+    gif_visualizer.visualize_directory(frame_dir, path)
+
+    print('===== Testing `VideoVisualizer` =====')
+    video_visualizer = VideoVisualizer()
+    path = os.path.join(test_dir, 'video_ori.mp4')
+    video_visualizer.visualize_directory(frame_dir, path)
+    path = os.path.join(test_dir, 'video_downsample.mp4')
+    video_visualizer.set_frame_size(half_size)
+    video_visualizer.visualize_directory(frame_dir, path)
+
+    print('========== Finish Visualizer Test ==========')
diff --git a/utils/visualizers/video_visualizer.py b/utils/visualizers/video_visualizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..b4c3a5934224edc5d557ad1b1458d4f776148a75
--- /dev/null
+++ b/utils/visualizers/video_visualizer.py
@@ -0,0 +1,173 @@
+# python3.7
+"""Contains the visualizer to visualize images as a video.
+
+This file relies on `FFmpeg`. Use `sudo apt-get install ffmpeg` and
+`brew install ffmpeg` to install it on Ubuntu and macOS, respectively.
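+
+NOTE: If FFmpeg is missing, `skvideo.io.FFmpegWriter` and `FFmpegReader`
+fail at construction time, so FFmpeg must be installed before this module
+is used.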
+"""
+
+import os.path
+from skvideo.io import FFmpegWriter
+from skvideo.io import FFmpegReader
+
+from ..image_utils import parse_image_size
+from ..image_utils import load_image
+from ..image_utils import resize_image
+from ..image_utils import list_images_from_dir
+
+__all__ = ['VideoVisualizer', 'VideoReader']
+
+
+class VideoVisualizer(object):
+    """Defines the video visualizer that presents images as a video."""
+
+    def __init__(self,
+                 path=None,
+                 frame_size=None,
+                 fps=25.0,
+                 codec='libx264',
+                 pix_fmt='yuv420p',
+                 crf=1):
+        """Initializes the video visualizer.
+
+        Args:
+            path: Path to write the video. (default: None)
+            frame_size: Frame size, i.e., (height, width). (default: None)
+            fps: Frames per second. (default: 25.0)
+            codec: Codec. (default: `libx264`)
+            pix_fmt: Pixel format. (default: `yuv420p`)
+            crf: Constant rate factor, which controls the compression. The
+                larger this field is, the higher the compression and the
+                lower the quality. `0` means no compression and consequently
+                the highest quality. To enable QuickTime playing (which
+                requires 4:2:0 YUV, while `crf = 0` results in 4:4:4 YUV),
+                please set this field to at least 1. (default: 1)
+        """
+        self.set_path(path)
+        self.set_frame_size(frame_size)
+        self.set_fps(fps)
+        self.set_codec(codec)
+        self.set_pix_fmt(pix_fmt)
+        self.set_crf(crf)
+        self.video = None
+
+    def set_path(self, path=None):
+        """Sets the path to save the video."""
+        self.path = path
+
+    def set_frame_size(self, frame_size=None):
+        """Sets the video frame size."""
+        height, width = parse_image_size(frame_size)
+        self.frame_height = height
+        self.frame_width = width
+
+    def set_fps(self, fps=25.0):
+        """Sets the FPS (frames per second) of the video."""
+        self.fps = fps
+
+    def set_codec(self, codec='libx264'):
+        """Sets the video codec."""
+        self.codec = codec
+
+    def set_pix_fmt(self, pix_fmt='yuv420p'):
+        """Sets the video pixel format."""
+        self.pix_fmt = pix_fmt
+
+    def set_crf(self, crf=1):
+        """Sets the CRF (constant rate factor) of the video."""
+        self.crf = crf
+
+    def init_video(self):
+        """Initializes an empty video with the expected settings."""
+        assert not os.path.exists(self.path), \
+            f'Video `{self.path}` already exists!'
+        assert self.frame_height > 0
+        assert self.frame_width > 0
+
+        video_setting = {
+            '-r': f'{self.fps:.2f}',
+            '-s': f'{self.frame_width}x{self.frame_height}',
+            '-vcodec': f'{self.codec}',
+            '-crf': f'{self.crf}',
+            '-pix_fmt': f'{self.pix_fmt}',
+        }
+        self.video = FFmpegWriter(self.path, outputdict=video_setting)
+
+    def add(self, frame):
+        """Adds a frame into the video visualizer.
+
+        NOTE: The input frame is assumed to be with `RGB` channel order.
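+
+        Example (illustrative):
+
+            visualizer = VideoVisualizer('demo.mp4', fps=25.0)
+            for frame in frames:  # Each frame: `uint8` array, shape (H, W, 3).
+                visualizer.add(frame)
+            visualizer.save()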
+ """ + if self.video is None: + height, width = frame.shape[0:2] + height = self.frame_height or height + width = self.frame_width or width + self.set_frame_size((height, width)) + self.init_video() + if frame.shape[0:2] != (self.frame_height, self.frame_width): + frame = resize_image(frame, (self.frame_width, self.frame_height)) + self.video.writeFrame(frame) + + def visualize_collection(self, images, save_path=None): + """Visualizes a collection of images one by one.""" + if save_path is not None and save_path != self.path: + self.save() + self.set_path(save_path) + for image in images: + self.add(image) + self.save() + + def visualize_list(self, image_list, save_path=None): + """Visualizes a list of image files.""" + if save_path is not None and save_path != self.path: + self.save() + self.set_path(save_path) + for filename in image_list: + image = load_image(filename) + self.add(image) + self.save() + + def visualize_directory(self, directory, save_path=None): + """Visualizes all images under a directory.""" + image_list = list_images_from_dir(directory) + self.visualize_list(image_list, save_path) + + def save(self): + """Saves the video by closing the file.""" + if self.video is not None: + self.video.close() + self.video = None + self.set_path(None) + + +class VideoReader(object): + """Defines the video reader. + + This class can be used to read frames from a given video. + + NOTE: Each frame can be read only once. + TODO: Fix this? + """ + + def __init__(self, path, inputdict=None): + """Initializes the video reader by loading the video from disk.""" + self.path = path + self.video = FFmpegReader(path, inputdict=inputdict) + + self.length = self.video.inputframenum + self.frame_height = self.video.inputheight + self.frame_width = self.video.inputwidth + self.fps = self.video.inputfps + self.pix_fmt = self.video.pix_fmt + + def __del__(self): + """Releases the opened video.""" + self.video.close() + + def read(self, image_size=None): + """Reads the next frame.""" + frame = next(self.video.nextFrame()) + height, width = parse_image_size(image_size) + height = height or frame.shape[0] + width = width or frame.shape[1] + if frame.shape[0:2] != (height, width): + frame = resize_image(frame, (width, height)) + return frame
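+
+
+# Usage sketch (illustrative; `demo.mp4` and `frame_dir` are placeholder
+# paths rather than files shipped with this repository):
+#
+#     visualizer = VideoVisualizer(frame_size=(720, 1080), fps=25.0)
+#     visualizer.visualize_directory('frame_dir', 'demo.mp4')
+#
+#     reader = VideoReader('demo.mp4')
+#     frame_0 = reader.read()            # Full resolution.
+#     frame_1 = reader.read((360, 540))  # Resized to (height, width) on read.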