diff --git a/README.md b/README.md
index 44644c0def69245cf1e7d20e921c928e36e0f6d9..8264706fc779a73624707592d59f3b8257ea32be 100644
--- a/README.md
+++ b/README.md
@@ -1,13 +1,122 @@
----
-title: Sat3density
-emoji: 🏆
-colorFrom: green
-colorTo: blue
-sdk: gradio
-sdk_version: 3.41.2
-app_file: app.py
-pinned: false
-license: other
----
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+# Sat2Density: Faithful Density Learning from Satellite-Ground Image Pairs
+
+> [Ming Qian](https://qianmingduowan.github.io/), Jincheng Xiong, [Gui-Song Xia](http://www.captain-whu.com/xia_En.html), [Nan Xue](https://xuenan.net)
+>
+> IEEE/CVF International Conference on Computer Vision (ICCV), 2023
+>
+> [Project](https://sat2density.github.io/) | [Paper](https://arxiv.org/abs/2303.14672) | [Data]() | [Install.md](docs/INSTALL.md)
+
+## Checkpoints Downloading
+> Two checkpoints, trained on CVACT and CVUSA respectively, are available from the [releases page](https://github.com/sat2density/checkpoints/releases). You can also run the following command to download them:
+```bash
+bash scripts/download_weights.sh
+```
+
+## QuickStart Demo
+### Video Synthesis
+#### Example Usage
+```bash
+python test.py --yaml=sat2density_cvact \
+--test_ckpt_path=2u87bj8w \
+--task=test_vid \
+--demo_img=demo_img/case1/satview-input.png \
+--sty_img=demo_img/case1/groundview.image.png \
+--save_dir=results/case1
+```
+
+### Illumination Interpolation
+
+```bash
+python test.py --task=test_interpolation \
+--yaml=sat2density_cvact \
+--test_ckpt_path=2u87bj8w \
+--sty_img1=demo_img/case9/groundview.image.png \
+--sty_img2=demo_img/case7/groundview.image.png \
+--demo_img=demo_img/case3/satview-input.png \
+--save_dir=results/case2
+```
+
+## Train & Inference
+- *We trained our model on a single V100 32GB GPU; training takes about 20 hours.*
+- *For data preparation, please check out [dataset/INSTALL.md](dataset/INSTALL.md).*
+
+### Inference
+
+To test the center ground-view synthesis setting, run the commands below.
+If you want to save the results, add `--task=vis_test`.
+```bash
+# CVACT
+python offline_train_test.py --yaml=sat2density_cvact --test_ckpt_path=2u87bj8w
+# CVUSA
+python offline_train_test.py --yaml=sat2density_cvusa --test_ckpt_path=2cqv8uh4
+```
+
+To test inference under different illumination:
+```bash
+# CVACT
+bash inference/single_style_test_cvact.sh
+# CVUSA
+bash inference/single_style_test_cvusa.sh
+```
+
+To synthesize ground-view videos:
+```bash
+bash inference/synthesis_video.sh
+```
+
+### Training
+
+```bash
+# CVACT
+CUDA_VISIBLE_DEVICES=X python train.py --yaml=sat2density_cvact
+# CVUSA
+CUDA_VISIBLE_DEVICES=X python train.py --yaml=sat2density_cvusa
+```
+
+## Citation
+If you use this code for your research, please cite:
+
+```bibtex
+@inproceedings{qian2023sat2density,
+ title={Sat2Density: Faithful Density Learning from Satellite-Ground Image Pairs},
+ author={Qian, Ming and Xiong, Jincheng and Xia, Gui-Song and Xue, Nan},
+ booktitle={ICCV},
+ year={2023}
+}
+```
+
+## License
+This work is licensed under a Creative Commons Attribution-NonCommercial 4.0 International License.
+For commercial use, please contact [mingqian@whu.edu.cn](mailto:mingqian@whu.edu.cn).
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..e56f1de13f4647636798b3e51522d193342b9ea4
--- /dev/null
+++ b/app.py
@@ -0,0 +1,244 @@
+import gradio as gr
+import numpy as np
+import os
+from PIL import Image
+import torch
+import torchvision.transforms as transforms
+import options
+import test
+import importlib
+from scipy.interpolate import interp1d, splev, splprep
+import cv2
+
+
+def get_single(sat_img, style_img, x_offset, y_offset):
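+    # Render a single ground-view panorama at the (x_offset, y_offset) location on
+    # the satellite image, using the chosen style image's sky histogram as the
+    # illumination style.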
+ name = ''
+ for i in [name for name in os.listdir('demo_img') if 'case' in name]:
+ style = Image.open('demo_img/{}/groundview.image.png'.format(i)).convert('RGB')
+        style = np.array(style)
+ if (style == style_img).all():
+ name = i
+ break
+
+ input_dict = {}
+ trans = transforms.ToTensor()
+ input_dict['sat'] = trans(sat_img)
+ input_dict['pano'] = trans(style_img)
+ input_dict['paths'] = "demo.png"
+ sky = trans(Image.open('demo_img/{}/groundview.sky.png'.format(name)).convert("L"))
+ input_a = input_dict['pano']*sky
+ sky_histc = torch.cat([input_a[i].histc()[10:] for i in reversed(range(3))])
+ input_dict['sky_histc'] = sky_histc
+ input_dict['sky_mask'] = sky
+
+ for key in input_dict.keys():
+ if isinstance(input_dict[key], torch.Tensor):
+ input_dict[key] = input_dict[key].unsqueeze(0)
+
+ args = ["--yaml=sat2density_cvact", "--test_ckpt_path=wandb/run-20230219_141512-2u87bj8w/files/checkpoint/model.pth", "--task=test_vid", "--demo_img=demo_img/case1/satview-input.png",
+ "--sty_img=demo_img/case1/groundview.image.png", "--save_dir=output"]
+ opt_cmd = options.parse_arguments(args=args)
+ opt = options.set(opt_cmd=opt_cmd)
+ opt.isTrain = False
+ opt.name = opt.yaml if opt.name is None else opt.name
+ opt.batch_size = 1
+
+ m = importlib.import_module("model.{}".format(opt.model))
+ model = m.Model(opt)
+
+ # m.load_dataset(opt)
+ model.build_networks(opt)
+ ckpt = torch.load(opt.test_ckpt_path, map_location='cpu')
+ model.netG.load_state_dict(ckpt['netG'])
+ model.netG.eval()
+
+ model.set_input(input_dict)
+
+ model.style_temp = model.sky_histc
+ opt.origin_H_W = [-(y_offset*256-128)/128, (x_offset*256-128)/128] # TODO: hard code should be removed in the future
+
+ model.forward(opt)
+
+ rgb = model.out_put.pred[0].clamp(min=0,max=1.0).cpu().detach().numpy().transpose((1,2,0))
+ rgb = np.array(rgb*255, dtype=np.uint8)
+ return rgb
+
+def get_video(sat_img, style_img, positions):
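+    # Render a ground-view video along the user-drawn trajectory: deduplicate the
+    # clicked points, fit a smooth spline, render a frame at each sampled location,
+    # and write the frames to an mp4 file.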
+ name = ''
+ for i in [name for name in os.listdir('demo_img') if 'case' in name]:
+ style = Image.open('demo_img/{}/groundview.image.png'.format(i)).convert('RGB')
+        style = np.array(style)
+ if (style == style_img).all():
+ name = i
+ break
+
+ input_dict = {}
+ trans = transforms.ToTensor()
+ input_dict['sat'] = trans(sat_img)
+ input_dict['pano'] = trans(style_img)
+ input_dict['paths'] = "demo.png"
+ sky = trans(Image.open('demo_img/{}/groundview.sky.png'.format(name)).convert("L"))
+ input_a = input_dict['pano']*sky
+ sky_histc = torch.cat([input_a[i].histc()[10:] for i in reversed(range(3))])
+ input_dict['sky_histc'] = sky_histc
+ input_dict['sky_mask'] = sky
+
+ for key in input_dict.keys():
+ if isinstance(input_dict[key], torch.Tensor):
+ input_dict[key] = input_dict[key].unsqueeze(0)
+
+ args = ["--yaml=sat2density_cvact", "--test_ckpt_path=wandb/run-20230219_141512-2u87bj8w/files/checkpoint/model.pth", "--task=test_vid", "--demo_img=demo_img/case1/satview-input.png",
+ "--sty_img=demo_img/case1/groundview.image.png", "--save_dir=output"]
+ opt_cmd = options.parse_arguments(args=args)
+ opt = options.set(opt_cmd=opt_cmd)
+ opt.isTrain = False
+ opt.name = opt.yaml if opt.name is None else opt.name
+ opt.batch_size = 1
+
+ m = importlib.import_module("model.{}".format(opt.model))
+ model = m.Model(opt)
+
+ # m.load_dataset(opt)
+ model.build_networks(opt)
+ ckpt = torch.load(opt.test_ckpt_path, map_location='cpu')
+ model.netG.load_state_dict(ckpt['netG'])
+ model.netG.eval()
+
+ model.set_input(input_dict)
+
+ model.style_temp = model.sky_histc
+
+    pixels = np.array(list(dict.fromkeys(positions)))
+ tck, u = splprep(pixels.T, s=25, per=0)
+ u_new = np.linspace(u.min(), u.max(), 80)
+ x_new, y_new = splev(u_new, tck)
+ smooth_path = np.array([x_new,y_new]).T
+
+ rendered_image_list = []
+ rendered_depth_list = []
+
+
+ for i, (x,y) in enumerate(smooth_path):
+ opt.origin_H_W = [(y-128)/128, (x-128)/128] # TODO: hard code should be removed in the future
+ print('Rendering at ({}, {})'.format(x,y))
+ model.forward(opt)
+
+ rgb = model.out_put.pred[0].clamp(min=0,max=1.0).cpu().detach().numpy().transpose((1,2,0))
+ rgb = np.array(rgb*255, dtype=np.uint8)
+ rendered_image_list.append(rgb)
+
+ rendered_depth_list.append(
+ model.out_put.depth[0,0].cpu().detach().numpy()
+ )
+
+ output_video_path = 'output_video.mp4'
+
+    # Set the video frame rate and frame size
+ frame_rate = 15
+ frame_width = 512
+ frame_height = 128
+
+    # Create an OpenCV video writer using the mp4v codec
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+ out = cv2.VideoWriter(output_video_path, fourcc, frame_rate, (frame_width, frame_height))
+
+    # Write the rendered frames to the video (OpenCV expects BGR channel order)
+    for image_np in rendered_image_list:
+        image_np = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
+ out.write(image_np)
+
+    # Release the video writer
+ out.release()
+
+ return "output_video.mp4"
+
+def copy_image(image):
+ return image
+
+def show_image_and_point(image, x, y):
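+    # Return the satellite image with a circular mask marking the selected render
+    # point, in the (image, annotations) format expected by gr.AnnotatedImage.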
+ x = int(x*image.shape[1])
+ y = image.shape[0]-int(y*image.shape[0])
+ mask = np.zeros(image.shape[:2])
+ radius = min(image.shape[0], image.shape[1])//60
+ for i in range(x-radius-2, x+radius+2):
+ for j in range(y-radius-2, y+radius+2):
+ if (i-x)**2+(j-y)**2<=radius**2:
+ mask[j, i] = 1
+ return (image, [(mask, 'render point')])
+
+def add_select_point(image, evt: gr.SelectData, state1):
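+    # Record the clicked point in the trajectory state and draw a black dot on the image.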
+    if state1 is None:
+ state1 = []
+ x, y = evt.index
+ state1.append((x, y))
+ print(state1)
+ radius = min(image.shape[0], image.shape[1])//60
+ for i in range(x-radius-2, x+radius+2):
+ for j in range(y-radius-2, y+radius+2):
+ if (i-x)**2+(j-y)**2<=radius**2:
+ image[j, i, :] = 0
+ return image, state1
+
+def reset_select_points(image):
+ return image, []
+
+
+
+
+
+
+with gr.Blocks() as demo:
+ gr.Markdown("# Sat2Density Demos")
+ gr.Markdown("### select/upload the satllite image and select the style image")
+ with gr.Row():
+ with gr.Column():
+ sat_img = gr.Image(source='upload', shape=[256, 256], interactive=True)
+ img_examples = gr.Examples(examples=['demo_img/{}/satview-input.png'.format(i) for i in os.listdir('demo_img') if 'case' in i],
+ inputs=sat_img, outputs=None, examples_per_page=20)
+ with gr.Column():
+ style_img = gr.Image()
+ style_examples = gr.Examples(examples=['demo_img/{}/groundview.image.png'.format(i) for i in os.listdir('demo_img') if 'case' in i],
+ inputs=style_img, outputs=None, examples_per_page=20)
+
+
+ gr.Markdown("### select a certain point to generate single groundview image")
+ with gr.Row():
+ with gr.Column():
+ with gr.Row():
+ with gr.Column():
+ slider_x = gr.Slider(0.2, 0.8, 0.5, label="x-axis position")
+ slider_y = gr.Slider(0.2, 0.8, 0.5, label="y-axis position")
+                btn_single = gr.Button(value="demo1")
+
+ annotation_image = gr.AnnotatedImage()
+
+ out_single = gr.Image()
+
+ gr.Markdown("### draw a trajectory on the map to generate video")
+ state_select_points = gr.State()
+ with gr.Row():
+ with gr.Column():
+ draw_img = gr.Image(shape=[256, 256], interactive=True)
+ with gr.Column():
+ out_video = gr.Video()
+            reset_btn = gr.Button(value="Reset")
+            btn_video = gr.Button(value="demo1")
+
+ sat_img.change(copy_image, inputs = sat_img, outputs=draw_img)
+
+ draw_img.select(add_select_point, [draw_img, state_select_points], [draw_img, state_select_points])
+ sat_img.change(show_image_and_point, inputs = [sat_img, slider_x, slider_y], outputs = annotation_image)
+ slider_x.change(show_image_and_point, inputs = [sat_img, slider_x, slider_y], outputs = annotation_image, show_progress='hidden')
+ slider_y.change(show_image_and_point, inputs = [sat_img, slider_x, slider_y], outputs = annotation_image, show_progress='hidden')
+ btn_single.click(get_single, inputs = [sat_img, style_img, slider_x, slider_y], outputs=out_single)
+ reset_btn.click(reset_select_points, [sat_img], [draw_img, state_select_points])
+    btn_video.click(get_video, inputs=[sat_img, style_img, state_select_points], outputs=out_video)  # triggers video generation
+
+
+demo.launch()
\ No newline at end of file
diff --git a/data/CVACT_Shi.py b/data/CVACT_Shi.py
new file mode 100644
index 0000000000000000000000000000000000000000..3b4588ccbe61427e4bec55e7d32bb59428c1228f
--- /dev/null
+++ b/data/CVACT_Shi.py
@@ -0,0 +1,119 @@
+import torch,os
+from torch.utils.data.dataset import Dataset
+from PIL import Image
+import scipy.io as sio
+import torchvision.transforms as transforms
+
+def data_list(img_root,mode):
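+    # Build the list of ground panorama paths for the requested split, keeping
+    # only samples whose satellite and ground images both exist on disk.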
+ exist_aer_list = os.listdir(os.path.join(img_root , 'satview_correct'))
+ exist_grd_list = os.listdir(os.path.join(img_root , 'streetview'))
+ allDataList = os.path.join(img_root, 'ACT_data.mat')
+ anuData = sio.loadmat(allDataList)
+
+ all_data_list = []
+ for i in range(0, len(anuData['panoIds'])):
+ grd_id_align = anuData['panoIds'][i] + '_grdView.png'
+ sat_id_ori = anuData['panoIds'][i] + '_satView_polish.png'
+ all_data_list.append([grd_id_align, sat_id_ori])
+
+ data_list = []
+
+ if mode=='train':
+ training_inds = anuData['trainSet']['trainInd'][0][0] - 1
+ trainNum = len(training_inds)
+ for k in range(trainNum):
+ data_list.append(all_data_list[training_inds[k][0]])
+ else:
+ val_inds = anuData['valSet']['valInd'][0][0] - 1
+ valNum = len(val_inds)
+ for k in range(valNum):
+ data_list.append(all_data_list[val_inds[k][0]])
+
+
+ pano_list = [img_root + 'streetview/' + item[0] for item in data_list if item[0] in exist_grd_list and item[1] in exist_aer_list]
+
+ return pano_list
+
+def img_read(img,size=None,datatype='RGB'):
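+    # Read an image, optionally resize it (bicubic for RGB, nearest for masks),
+    # and convert it to a [0, 1] tensor.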
+ img = Image.open(img).convert('RGB' if datatype=='RGB' else "L")
+ if size:
+ if type(size) is int:
+ size = (size,size)
+ img = img.resize(size = size,resample=Image.BICUBIC if datatype=='RGB' else Image.NEAREST)
+ img = transforms.ToTensor()(img)
+ return img
+
+
+class Dataset(Dataset):
+ def __init__(self, opt,split='train',sub=None,sty_img=None):
+ if sty_img:
+ assert sty_img.endswith('grdView.png')
+ demo_img_path = os.path.join(opt.data.root,'streetview',sty_img)
+ self.pano_list = [demo_img_path]
+
+ elif opt.task in ['test_vid','test_interpolation'] :
+ demo_img_path = os.path.join(opt.data.root,'streetview',opt.demo_img.replace('satView_polish.png','grdView.png'))
+ self.pano_list = [demo_img_path]
+
+ else:
+ self.pano_list = data_list(img_root=opt.data.root,mode=split)
+ if sub:
+ self.pano_list = self.pano_list[:sub]
+
+        # Select some ground images to test the influence of different skies;
+        # different skies lead to different illumination intensities, colors, etc.
+ if opt.task == 'test_sty':
+ demo_name = [
+ 'dataset/CVACT/streetview/pPfo7qQ1fP_24rXrJ2Uxog_grdView.png',
+ 'dataset/CVACT/streetview/YL81FiK9PucIvAkr1FHkpA_grdView.png',
+ 'dataset/CVACT/streetview/Tzis1jBKHjbXiVB2oRYwAQ_grdView.png',
+ 'dataset/CVACT/streetview/eqGgeBLGXRhSj6c-0h0KoQ_grdView.png',
+ 'dataset/CVACT/streetview/pdZmLHYEhe2PHj_8-WHMhw_grdView.png',
+ 'dataset/CVACT/streetview/ehsu9Q3iTin5t52DM-MwyQ_grdView.png',
+ 'dataset/CVACT/streetview/agLEcuq3_-qFj7wwGbktVg_grdView.png',
+ 'dataset/CVACT/streetview/HwQIDdMI3GfHyPGtCSo6aA_grdView.png',
+ 'dataset/CVACT/streetview/hV8svb3ZVXcQ0AtTRFE1dQ_grdView.png',
+ 'dataset/CVACT/streetview/fzq2mBfKP3UIczAd9KpMMg_grdView.png',
+ 'dataset/CVACT/streetview/acRP98sACUIlwl2ZIsEyiQ_grdView.png',
+ 'dataset/CVACT/streetview/WSh9tNVryLdupUlU0ri2tQ_grdView.png',
+ 'dataset/CVACT/streetview/FhEuB9NA5o08VJ_TBCbHjw_grdView.png',
+ 'dataset/CVACT/streetview/YHfpn2Mgu1lqgT2OUeBpOg_grdView.png',
+ 'dataset/CVACT/streetview/vNhv7ZP1dUkJ93UwFXagJw_grdView.png',
+ ]
+ self.pano_list = demo_name
+
+ self.opt = opt
+
+ def __len__(self):
+ return len(self.pano_list)
+
+ def __getitem__(self, index):
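+        # Load the satellite/panorama pair (plus sky mask if enabled) and compute
+        # the color histogram of the sky region, used as the illumination style cue.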
+ pano = self.pano_list[index]
+ aer = pano.replace('streetview','satview_correct').replace('_grdView','_satView_polish')
+ if self.opt.data.sky_mask:
+ sky = pano.replace('streetview','pano_sky_mask')
+ name = pano
+ aer = img_read(aer, size = self.opt.data.sat_size)
+ pano = img_read(pano,size = self.opt.data.pano_size)
+ if self.opt.data.sky_mask:
+ sky = img_read(sky,size=self.opt.data.pano_size,datatype='L')
+
+ input = {}
+ input['sat']=aer
+ input['pano']=pano
+ input['paths']=name
+ if self.opt.data.sky_mask:
+ input['sky_mask']=sky
+ black_ground = torch.zeros_like(pano)
+ if self.opt.data.histo_mode =='grey':
+ input['sky_histc'] = (pano*sky+black_ground*(1-sky)).histc()[10:]
+ elif self.opt.data.histo_mode in ['rgb','RGB']:
+ input_a = (pano*sky+black_ground*(1-sky))
+ for idx in range(len(input_a)):
+ if idx == 0:
+ sky_histc = input_a[idx].histc()[10:]
+ else:
+ sky_histc = torch.cat([input_a[idx].histc()[10:],sky_histc],dim=0)
+ input['sky_histc'] = sky_histc
+ return input
+
diff --git a/data/CVUSA.py b/data/CVUSA.py
new file mode 100644
index 0000000000000000000000000000000000000000..2179427bf165335df33ab0761ea28975f7672a45
--- /dev/null
+++ b/data/CVUSA.py
@@ -0,0 +1,86 @@
+import torch,os
+from torch.utils.data.dataset import Dataset
+from PIL import Image
+import torchvision.transforms as transforms
+import re
+from easydict import EasyDict as edict
+
+def data_list(img_root,mode):
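+    # Parse the CVUSA split csv for the requested mode and return the list of
+    # panorama image paths.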
+ data_list=[]
+ if mode=='train':
+ split_file=os.path.join(img_root, 'splits/train-19zl.csv')
+ with open(split_file) as f:
+ list = f.readlines()
+ for i in list:
+ aerial_name=re.split(r',', re.split('\n', i)[0])[0]
+ panorama_name = re.split(r',', re.split('\n', i)[0])[1]
+ data_list.append([aerial_name, panorama_name])
+ else:
+ split_file=os.path.join(img_root+'splits/val-19zl.csv')
+ with open(split_file) as f:
+ list = f.readlines()
+ for i in list:
+ aerial_name=re.split(r',', re.split('\n', i)[0])[0]
+ panorama_name = re.split(r',', re.split('\n', i)[0])[1]
+ data_list.append([aerial_name, panorama_name])
+ print('length of dataset is: ', len(data_list))
+ return [os.path.join(img_root, i[1]) for i in data_list]
+
+def img_read(img,size=None,datatype='RGB'):
+ img = Image.open(img).convert('RGB' if datatype=='RGB' else "L")
+ if size:
+ if type(size) is int:
+ size = (size,size)
+ img = img.resize(size = size,resample=Image.BICUBIC if datatype=='RGB' else Image.NEAREST)
+ img = transforms.ToTensor()(img)
+ return img
+
+
+class Dataset(Dataset):
+ def __init__(self, opt,split='train',sub=None,sty_img=None):
+ self.pano_list = data_list(img_root=opt.data.root,mode=split)
+ if sub:
+ self.pano_list = self.pano_list[:sub]
+ if opt.task == 'test_vid':
+ demo_img_path = os.path.join(opt.data.root, 'streetview/panos', opt.demo_img)
+ self.pano_list = [demo_img_path]
+ if sty_img:
+ assert opt.sty_img.split('.')[-1] == 'jpg'
+ demo_img_path = os.path.join(opt.data.root, 'streetview/panos', opt.sty_img)
+ self.pano_list = [demo_img_path]
+
+ self.opt = opt
+
+ def __len__(self):
+ return len(self.pano_list)
+
+ def __getitem__(self, index):
+ pano = self.pano_list[index]
+ aer = pano.replace('streetview/panos', 'bingmap/19')
+ if self.opt.data.sky_mask:
+ sky = pano.replace('streetview/panos','sky_mask').replace('jpg', 'png')
+ name = pano
+ aer = img_read(aer, size = self.opt.data.sat_size)
+ pano = img_read(pano,size = self.opt.data.pano_size)
+ if self.opt.data.sky_mask:
+ sky = img_read(sky,size=self.opt.data.pano_size,datatype='L')
+
+ input = {}
+ input['sat']=aer
+ input['pano']=pano
+ input['paths']=name
+ if self.opt.data.sky_mask:
+ input['sky_mask']=sky
+ black_ground = torch.zeros_like(pano)
+ if self.opt.data.histo_mode =='grey':
+ input['sky_histc'] = (pano*sky+black_ground*(1-sky)).histc()[10:]
+ elif self.opt.data.histo_mode in ['rgb','RGB']:
+ input_a = (pano*sky+black_ground*(1-sky))
+ for idx in range(len(input_a)):
+ if idx == 0:
+ sky_histc = input_a[idx].histc()[10:]
+ else:
+ sky_histc = torch.cat([input_a[idx].histc()[10:],sky_histc],dim=0)
+ input['sky_histc'] = sky_histc
+ return input
+
diff --git a/dataset/INSTALL.md b/dataset/INSTALL.md
new file mode 100644
index 0000000000000000000000000000000000000000..1d0784b008b078f757c910b16b08edc08a7ed6f8
--- /dev/null
+++ b/dataset/INSTALL.md
@@ -0,0 +1,32 @@
+To reproduce our paper, you should first download these four zip files from [here](https://anu365-my.sharepoint.com/:f:/g/personal/u6293587_anu_edu_au/EuOBUDUQNClJvCpQ8bD1hnoBjdRBWxsHOVp946YVahiMGg?e=F4yRAC) (the project page is [Sat2StrPanoramaSynthesis](https://github.com/shiyujiao/Sat2StrPanoramaSynthesis)):
+
+- `CVACT/satview_correct.zip`
+- `CVACT/streetview.zip`
+- `CVUSA/bingmap/19.zip`
+- `CVUSA/streetview/panos.zip`
+
+Then download the sky masks from [here](https://drive.google.com/drive/folders/1pfzwONg4P-Mzvxvzb2HoCpuZFynElPCk?usp=sharing).
+
+Finally, organize the dataset as follows (a sketch of the unpacking commands is given after the layout):
+```
+├dataset
+├── CVACT
+│ ├── streetview
+│ ├── satview_correct
+│ ├── pano_sky_mask
+│ ├── ACT_data.mat
+└── CVUSA
+│ ├── bingmap
+│ │ ├── 19
+│ └── streetview
+│ │ ├── panos
+│ ├── sky_mask
+│ ├── splits
+```
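+
+A minimal, non-authoritative sketch of the unpacking step; the sky-mask archive names (`pano_sky_mask.zip`, `sky_mask.zip`) are assumptions, so adjust archive names and destination paths to match the layout above:
+
+```bash
+cd dataset
+mkdir -p CVACT CVUSA/bingmap CVUSA/streetview
+
+# CVACT: satellite views, street views, sky masks
+unzip satview_correct.zip -d CVACT/
+unzip streetview.zip -d CVACT/
+unzip pano_sky_mask.zip -d CVACT/   # assumed archive name for the CVACT sky masks
+
+# CVUSA: aerial tiles, panoramas, sky masks
+unzip 19.zip -d CVUSA/bingmap/
+unzip panos.zip -d CVUSA/streetview/
+unzip sky_mask.zip -d CVUSA/        # assumed archive name for the CVUSA sky masks
+```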
+
+Tip: The sky masks are processed with [Trans4PASS](https://github.com/jamycheung/Trans4PASS).
diff --git a/demo_img/case1/groundview.image.png b/demo_img/case1/groundview.image.png
new file mode 100644
index 0000000000000000000000000000000000000000..66f6050ec9e635a04e58242826af667f894adde0
Binary files /dev/null and b/demo_img/case1/groundview.image.png differ
diff --git a/demo_img/case1/groundview.sky.png b/demo_img/case1/groundview.sky.png
new file mode 100644
index 0000000000000000000000000000000000000000..413a82e27fd781030a1314513e859610ab48f652
Binary files /dev/null and b/demo_img/case1/groundview.sky.png differ
diff --git a/demo_img/case1/satview-input.png b/demo_img/case1/satview-input.png
new file mode 100644
index 0000000000000000000000000000000000000000..375575a45a1d1b33a87c0133e9237d0b89af9f3e
Binary files /dev/null and b/demo_img/case1/satview-input.png differ
diff --git a/demo_img/case10/groundview.image.png b/demo_img/case10/groundview.image.png
new file mode 100644
index 0000000000000000000000000000000000000000..66f6050ec9e635a04e58242826af667f894adde0
Binary files /dev/null and b/demo_img/case10/groundview.image.png differ
diff --git a/demo_img/case10/groundview.sky.png b/demo_img/case10/groundview.sky.png
new file mode 100644
index 0000000000000000000000000000000000000000..5252275f364573f59f2f3b57a177327c297f523a
Binary files /dev/null and b/demo_img/case10/groundview.sky.png differ
diff --git a/demo_img/case10/satview-input.png b/demo_img/case10/satview-input.png
new file mode 100644
index 0000000000000000000000000000000000000000..efe5929d06fb9be3eeda5deb84733614ce571b80
Binary files /dev/null and b/demo_img/case10/satview-input.png differ
diff --git a/demo_img/case11/groundview.image.png b/demo_img/case11/groundview.image.png
new file mode 100644
index 0000000000000000000000000000000000000000..a4c7ad7b5281dd57dc161173699a570592978f6a
Binary files /dev/null and b/demo_img/case11/groundview.image.png differ
diff --git a/demo_img/case11/groundview.sky.png b/demo_img/case11/groundview.sky.png
new file mode 100644
index 0000000000000000000000000000000000000000..31d697c17321a98ad5838420419a2c8028da20b4
Binary files /dev/null and b/demo_img/case11/groundview.sky.png differ
diff --git a/demo_img/case11/satview-input.png b/demo_img/case11/satview-input.png
new file mode 100644
index 0000000000000000000000000000000000000000..78c725dc36987de0b11c1f044294deeb86d7e20c
Binary files /dev/null and b/demo_img/case11/satview-input.png differ
diff --git a/demo_img/case12/groundview.image.png b/demo_img/case12/groundview.image.png
new file mode 100644
index 0000000000000000000000000000000000000000..3e74053765745892dc5847955d14fc8210efcc94
Binary files /dev/null and b/demo_img/case12/groundview.image.png differ
diff --git a/demo_img/case12/groundview.sky.png b/demo_img/case12/groundview.sky.png
new file mode 100644
index 0000000000000000000000000000000000000000..70c7652da7c488dccc1393b3abf7f916e00662bb
Binary files /dev/null and b/demo_img/case12/groundview.sky.png differ
diff --git a/demo_img/case12/satview-input.png b/demo_img/case12/satview-input.png
new file mode 100644
index 0000000000000000000000000000000000000000..28858ab35e3c67162cf717e297867e44f545c2ab
Binary files /dev/null and b/demo_img/case12/satview-input.png differ
diff --git a/demo_img/case13/groundview.image.png b/demo_img/case13/groundview.image.png
new file mode 100644
index 0000000000000000000000000000000000000000..9fd2f5da5ba5e7947074df32bb723d2631115cf5
Binary files /dev/null and b/demo_img/case13/groundview.image.png differ
diff --git a/demo_img/case13/groundview.sky.png b/demo_img/case13/groundview.sky.png
new file mode 100644
index 0000000000000000000000000000000000000000..b0c9f795448f2c9a4cb4253c4036a31f66356ae2
Binary files /dev/null and b/demo_img/case13/groundview.sky.png differ
diff --git a/demo_img/case13/satview-input.png b/demo_img/case13/satview-input.png
new file mode 100644
index 0000000000000000000000000000000000000000..53b78fd4a1252d9c5f730ce27bd485810ced4db1
Binary files /dev/null and b/demo_img/case13/satview-input.png differ
diff --git a/demo_img/case2/groundview.image.png b/demo_img/case2/groundview.image.png
new file mode 100644
index 0000000000000000000000000000000000000000..9fd2f5da5ba5e7947074df32bb723d2631115cf5
Binary files /dev/null and b/demo_img/case2/groundview.image.png differ
diff --git a/demo_img/case2/groundview.sky.png b/demo_img/case2/groundview.sky.png
new file mode 100644
index 0000000000000000000000000000000000000000..b0c9f795448f2c9a4cb4253c4036a31f66356ae2
Binary files /dev/null and b/demo_img/case2/groundview.sky.png differ
diff --git a/demo_img/case2/satview-input.png b/demo_img/case2/satview-input.png
new file mode 100644
index 0000000000000000000000000000000000000000..53b78fd4a1252d9c5f730ce27bd485810ced4db1
Binary files /dev/null and b/demo_img/case2/satview-input.png differ
diff --git a/demo_img/case3/groundview.image.png b/demo_img/case3/groundview.image.png
new file mode 100644
index 0000000000000000000000000000000000000000..ee07e66c811c49961b5b4cf83e2bde17c1c4ab96
Binary files /dev/null and b/demo_img/case3/groundview.image.png differ
diff --git a/demo_img/case3/groundview.sky.png b/demo_img/case3/groundview.sky.png
new file mode 100644
index 0000000000000000000000000000000000000000..bf1580c092c0e7b80f8199d845f71acc8bf23ee3
Binary files /dev/null and b/demo_img/case3/groundview.sky.png differ
diff --git a/demo_img/case3/satview-input.png b/demo_img/case3/satview-input.png
new file mode 100644
index 0000000000000000000000000000000000000000..d270d4f7ae9a135b0df74e35c7f5ba5440bd07cd
Binary files /dev/null and b/demo_img/case3/satview-input.png differ
diff --git a/demo_img/case4/groundview.image.png b/demo_img/case4/groundview.image.png
new file mode 100644
index 0000000000000000000000000000000000000000..bdd86213b5efdf40e457a5fbead6d2592a40dc0d
Binary files /dev/null and b/demo_img/case4/groundview.image.png differ
diff --git a/demo_img/case4/groundview.sky.png b/demo_img/case4/groundview.sky.png
new file mode 100644
index 0000000000000000000000000000000000000000..996b7120fbb55c5987c9b58ba1ccf28d16747960
Binary files /dev/null and b/demo_img/case4/groundview.sky.png differ
diff --git a/demo_img/case4/satview-input.png b/demo_img/case4/satview-input.png
new file mode 100644
index 0000000000000000000000000000000000000000..e16179fe548be5115eb5906ae7408c56fdae62f8
Binary files /dev/null and b/demo_img/case4/satview-input.png differ
diff --git a/demo_img/case5/groundview.image.png b/demo_img/case5/groundview.image.png
new file mode 100644
index 0000000000000000000000000000000000000000..a9266ebdaead7935d8ea0f834cf33e351da0dd5a
Binary files /dev/null and b/demo_img/case5/groundview.image.png differ
diff --git a/demo_img/case5/groundview.sky.png b/demo_img/case5/groundview.sky.png
new file mode 100644
index 0000000000000000000000000000000000000000..8598b9e878bcbc54d1b5eaa90e729d2dd77308b0
Binary files /dev/null and b/demo_img/case5/groundview.sky.png differ
diff --git a/demo_img/case5/satview-input.png b/demo_img/case5/satview-input.png
new file mode 100644
index 0000000000000000000000000000000000000000..4c12242298ef0f2721ec629e780dc75d3af8e810
Binary files /dev/null and b/demo_img/case5/satview-input.png differ
diff --git a/demo_img/case6/groundview.image.png b/demo_img/case6/groundview.image.png
new file mode 100644
index 0000000000000000000000000000000000000000..71aab5b4aeca63f04b8d0f4eb8e29f5075612c48
Binary files /dev/null and b/demo_img/case6/groundview.image.png differ
diff --git a/demo_img/case6/groundview.sky.png b/demo_img/case6/groundview.sky.png
new file mode 100644
index 0000000000000000000000000000000000000000..6d00e485d6cf4ce46aba8274c527598c54e9b970
Binary files /dev/null and b/demo_img/case6/groundview.sky.png differ
diff --git a/demo_img/case6/satview-input.png b/demo_img/case6/satview-input.png
new file mode 100644
index 0000000000000000000000000000000000000000..f0837bce0790d1795f0079f8c0c76b53a3fa36b9
Binary files /dev/null and b/demo_img/case6/satview-input.png differ
diff --git a/demo_img/case7/groundview.image.png b/demo_img/case7/groundview.image.png
new file mode 100644
index 0000000000000000000000000000000000000000..fcbfd8e23a6411d0cbbac43c15efcea4d6a6dd20
Binary files /dev/null and b/demo_img/case7/groundview.image.png differ
diff --git a/demo_img/case7/groundview.sky.png b/demo_img/case7/groundview.sky.png
new file mode 100644
index 0000000000000000000000000000000000000000..60ce35217b3faf8894d9cbd3e8ce58c77c29d28b
Binary files /dev/null and b/demo_img/case7/groundview.sky.png differ
diff --git a/demo_img/case7/satview-input.png b/demo_img/case7/satview-input.png
new file mode 100644
index 0000000000000000000000000000000000000000..708349179ddfdbeff7fae9ce516a5e49323bc2f8
Binary files /dev/null and b/demo_img/case7/satview-input.png differ
diff --git a/demo_img/case8/groundview.image.png b/demo_img/case8/groundview.image.png
new file mode 100644
index 0000000000000000000000000000000000000000..18d52022a5becfb48728b8ea23679dffeee870ea
Binary files /dev/null and b/demo_img/case8/groundview.image.png differ
diff --git a/demo_img/case8/groundview.sky.png b/demo_img/case8/groundview.sky.png
new file mode 100644
index 0000000000000000000000000000000000000000..94731c15c751e9df883806ac0d43fbfedb270ab7
Binary files /dev/null and b/demo_img/case8/groundview.sky.png differ
diff --git a/demo_img/case8/satview-input.png b/demo_img/case8/satview-input.png
new file mode 100644
index 0000000000000000000000000000000000000000..981ce9f97d7d182ff441a3f925fb1d03598f9ae2
Binary files /dev/null and b/demo_img/case8/satview-input.png differ
diff --git a/demo_img/case9/groundview.image.png b/demo_img/case9/groundview.image.png
new file mode 100644
index 0000000000000000000000000000000000000000..0084c60d4618d184db5e0b383403a20ce1644555
Binary files /dev/null and b/demo_img/case9/groundview.image.png differ
diff --git a/demo_img/case9/groundview.sky.png b/demo_img/case9/groundview.sky.png
new file mode 100644
index 0000000000000000000000000000000000000000..649db36fe97c73ea18f7712941de6c6b0f7bad19
Binary files /dev/null and b/demo_img/case9/groundview.sky.png differ
diff --git a/demo_img/case9/satview-input.png b/demo_img/case9/satview-input.png
new file mode 100644
index 0000000000000000000000000000000000000000..18d0484df8b32f566532c6345abb98eec6e9195e
Binary files /dev/null and b/demo_img/case9/satview-input.png differ
diff --git a/demo_img/runall.sh b/demo_img/runall.sh
new file mode 100644
index 0000000000000000000000000000000000000000..975924bab8dab46fb3e42c0454985aa93afbdb1e
--- /dev/null
+++ b/demo_img/runall.sh
@@ -0,0 +1,30 @@
+# for case in `ls -d demo_img/case*`
+for case_id in 1 2 3 4
+do
+ case=demo_img/case$case_id
+ echo $case
+ python test.py --yaml=sat2density_cvact \
+ --test_ckpt_path=2u87bj8w \
+ --task=test_vid \
+ --demo_img=$case/satview-input.png \
+ --sty_img=$case/groundview.image.png \
+ --save_dir=results/$case
+ # ffmpeg -framerate 10 -i results/$case/rendered_images+depths/%5d.png results/$case/render.gif
+ ffmpeg -framerate 10 -i results/$case/rendered_images+depths/%5d.png -vf "palettegen" results/$case-palette.png
+ ffmpeg -framerate 10 -i results/$case/rendered_images+depths/%5d.png -i results/$case-palette.png -filter_complex "paletteuse" results/$case/render.gif
+
+ ffmpeg -framerate 10 -i results/$case/sat_images/%5d.png -vf "palettegen" results/$case-palette.png
+ ffmpeg -framerate 10 -i results/$case/sat_images/%5d.png -i results/$case-palette.png -filter_complex "paletteuse" results/$case/sat.gif
+ # ffmpeg -framerate 10 -i results/$case/sat_images/%5d.png results/$case/sat.gif
+done
+
+# for case in `ls -d demo_img/case*`
+for case_id in 1 2 3 4
+do
+ case=demo_img/case$case_id
+ sat_gif=results/$case/sat.gif
+ render_gif=results/$case/render.gif
+ # echo $sat_gif
+ cp $sat_gif docs/figures/demo/case$case_id.sat.gif
+ cp $render_gif docs/figures/demo/case$case_id.render.gif
+done
\ No newline at end of file
diff --git a/imaginaire/__init__.py b/imaginaire/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..13acefe2181136b1629ec31f9d122fb46bf26780
--- /dev/null
+++ b/imaginaire/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
diff --git a/imaginaire/config.py b/imaginaire/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..3a728a5aaee8d040288ff9ffd17a4fa83a7e2ca7
--- /dev/null
+++ b/imaginaire/config.py
@@ -0,0 +1,238 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+"""Config utilities for yml file."""
+
+import collections
+import functools
+import os
+import re
+
+import yaml
+from imaginaire.utils.distributed import master_only_print as print
+
+DEBUG = False
+USE_JIT = False
+
+
+class AttrDict(dict):
+ """Dict as attribute trick."""
+
+ def __init__(self, *args, **kwargs):
+ super(AttrDict, self).__init__(*args, **kwargs)
+ self.__dict__ = self
+ for key, value in self.__dict__.items():
+ if isinstance(value, dict):
+ self.__dict__[key] = AttrDict(value)
+ elif isinstance(value, (list, tuple)):
+ if isinstance(value[0], dict):
+ self.__dict__[key] = [AttrDict(item) for item in value]
+ else:
+ self.__dict__[key] = value
+
+ def yaml(self):
+ """Convert object to yaml dict and return."""
+ yaml_dict = {}
+ for key, value in self.__dict__.items():
+ if isinstance(value, AttrDict):
+ yaml_dict[key] = value.yaml()
+ elif isinstance(value, list):
+ if isinstance(value[0], AttrDict):
+ new_l = []
+ for item in value:
+ new_l.append(item.yaml())
+ yaml_dict[key] = new_l
+ else:
+ yaml_dict[key] = value
+ else:
+ yaml_dict[key] = value
+ return yaml_dict
+
+ def __repr__(self):
+ """Print all variables."""
+ ret_str = []
+ for key, value in self.__dict__.items():
+ if isinstance(value, AttrDict):
+ ret_str.append('{}:'.format(key))
+ child_ret_str = value.__repr__().split('\n')
+ for item in child_ret_str:
+ ret_str.append(' ' + item)
+ elif isinstance(value, list):
+ if isinstance(value[0], AttrDict):
+ ret_str.append('{}:'.format(key))
+ for item in value:
+ # Treat as AttrDict above.
+ child_ret_str = item.__repr__().split('\n')
+ for item in child_ret_str:
+ ret_str.append(' ' + item)
+ else:
+ ret_str.append('{}: {}'.format(key, value))
+ else:
+ ret_str.append('{}: {}'.format(key, value))
+ return '\n'.join(ret_str)
+
+
+class Config(AttrDict):
+ r"""Configuration class. This should include every human specifiable
+ hyperparameter values for your training."""
+
+ def __init__(self, filename=None, verbose=False):
+ super(Config, self).__init__()
+ self.source_filename = filename
+ # Set default parameters.
+ # Logging.
+ large_number = 1000000000
+ self.snapshot_save_iter = large_number
+ self.snapshot_save_epoch = large_number
+ self.metrics_iter = None
+ self.metrics_epoch = None
+ self.snapshot_save_start_iter = 0
+ self.snapshot_save_start_epoch = 0
+ self.image_save_iter = large_number
+ self.image_display_iter = large_number
+ self.max_epoch = large_number
+ self.max_iter = large_number
+ self.logging_iter = 100
+ self.speed_benchmark = False
+
+ # Trainer.
+ self.trainer = AttrDict(
+ model_average_config=AttrDict(enabled=False,
+ beta=0.9999,
+ start_iteration=1000,
+ num_batch_norm_estimation_iterations=30,
+ remove_sn=True),
+ # model_average=False,
+ # model_average_beta=0.9999,
+ # model_average_start_iteration=1000,
+ # model_average_batch_norm_estimation_iteration=30,
+ # model_average_remove_sn=True,
+ image_to_tensorboard=False,
+ hparam_to_tensorboard=False,
+ distributed_data_parallel='pytorch',
+ distributed_data_parallel_params=AttrDict(
+ find_unused_parameters=False),
+ delay_allreduce=True,
+ gan_relativistic=False,
+ gen_step=1,
+ dis_step=1,
+ gan_decay_k=1.,
+ gan_min_k=1.,
+ gan_separate_topk=False,
+ aug_policy='',
+ channels_last=False,
+ strict_resume=True,
+ amp_gp=False,
+ amp_config=AttrDict(init_scale=65536.0,
+ growth_factor=2.0,
+ backoff_factor=0.5,
+ growth_interval=2000,
+ enabled=False))
+
+ # Networks.
+ self.gen = AttrDict(type='imaginaire.generators.dummy')
+ self.dis = AttrDict(type='imaginaire.discriminators.dummy')
+
+ # Optimizers.
+ self.gen_opt = AttrDict(type='adam',
+ fused_opt=False,
+ lr=0.0001,
+ adam_beta1=0.0,
+ adam_beta2=0.999,
+ eps=1e-8,
+ lr_policy=AttrDict(iteration_mode=False,
+ type='step',
+ step_size=large_number,
+ gamma=1))
+ self.dis_opt = AttrDict(type='adam',
+ fused_opt=False,
+ lr=0.0001,
+ adam_beta1=0.0,
+ adam_beta2=0.999,
+ eps=1e-8,
+ lr_policy=AttrDict(iteration_mode=False,
+ type='step',
+ step_size=large_number,
+ gamma=1))
+ # Data.
+ self.data = AttrDict(name='dummy',
+ type='imaginaire.datasets.images',
+ num_workers=0)
+ self.test_data = AttrDict(name='dummy',
+ type='imaginaire.datasets.images',
+ num_workers=0,
+ test=AttrDict(is_lmdb=False,
+ roots='',
+ batch_size=1))
+
+
+        # Cudnn.
+ self.cudnn = AttrDict(deterministic=False,
+ benchmark=True)
+
+ # Others.
+ self.pretrained_weight = ''
+ self.inference_args = AttrDict()
+
+ # Update with given configurations.
+ assert os.path.exists(filename), 'File {} not exist.'.format(filename)
+ loader = yaml.SafeLoader
+ loader.add_implicit_resolver(
+ u'tag:yaml.org,2002:float',
+ re.compile(u'''^(?:
+ [-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)?
+ |[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
+ |\\.[0-9_]+(?:[eE][-+][0-9]+)?
+ |[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]*
+ |[-+]?\\.(?:inf|Inf|INF)
+ |\\.(?:nan|NaN|NAN))$''', re.X),
+ list(u'-+0123456789.'))
+ try:
+ with open(filename, 'r') as f:
+ cfg_dict = yaml.load(f, Loader=loader)
+ except EnvironmentError:
+            print('Please check the file with name of "{}"'.format(filename))
+ recursive_update(self, cfg_dict)
+
+ # Put common opts in both gen and dis.
+ if 'common' in cfg_dict:
+ self.common = AttrDict(**cfg_dict['common'])
+ self.gen.common = self.common
+ self.dis.common = self.common
+
+ if verbose:
+ print(' imaginaire config '.center(80, '-'))
+ print(self.__repr__())
+ print(''.center(80, '-'))
+
+
+def rsetattr(obj, attr, val):
+ """Recursively find object and set value"""
+ pre, _, post = attr.rpartition('.')
+ return setattr(rgetattr(obj, pre) if pre else obj, post, val)
+
+
+def rgetattr(obj, attr, *args):
+ """Recursively find object and return value"""
+
+ def _getattr(obj, attr):
+ r"""Get attribute."""
+ return getattr(obj, attr, *args)
+
+ return functools.reduce(_getattr, [obj] + attr.split('.'))
+
+
+def recursive_update(d, u):
+ """Recursively update AttrDict d with AttrDict u"""
+ for key, value in u.items():
+ if isinstance(value, collections.abc.Mapping):
+ d.__dict__[key] = recursive_update(d.get(key, AttrDict({})), value)
+ elif isinstance(value, (list, tuple)):
+ if isinstance(value[0], dict):
+ d.__dict__[key] = [AttrDict(item) for item in value]
+ else:
+ d.__dict__[key] = value
+ else:
+ d.__dict__[key] = value
+ return d
diff --git a/imaginaire/datasets/__init__.py b/imaginaire/datasets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..13acefe2181136b1629ec31f9d122fb46bf26780
--- /dev/null
+++ b/imaginaire/datasets/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
diff --git a/imaginaire/datasets/base.py b/imaginaire/datasets/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..c9192f78d4c3cfb65ac73751b632f897893d8288
--- /dev/null
+++ b/imaginaire/datasets/base.py
@@ -0,0 +1,596 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+"""All datasets are inherited from this class."""
+
+import importlib
+import json
+import os
+import pickle
+from collections import OrderedDict
+from functools import partial
+from inspect import signature
+
+import numpy as np
+import torch
+import torch.utils.data as data
+import torchvision.transforms as transforms
+
+from imaginaire.datasets.folder import FolderDataset
+from imaginaire.datasets.lmdb import \
+ IMG_EXTENSIONS, HDR_IMG_EXTENSIONS, LMDBDataset
+from imaginaire.datasets.object_store import ObjectStoreDataset
+from imaginaire.utils.data import \
+ (VIDEO_EXTENSIONS, Augmentor,
+ load_from_folder, load_from_lmdb, load_from_object_store)
+from imaginaire.utils.lmdb import create_metadata
+
+
+DATASET_TYPES = ['lmdb', 'folder', 'object_store']
+
+
+class BaseDataset(data.Dataset):
+ r"""Base class for image/video datasets.
+
+ Args:
+ cfg (Config object): Input config.
+ is_inference (bool): Training if False, else validation.
+ is_test (bool): Final test set after training and validation.
+ """
+
+ def __init__(self, cfg, is_inference, is_test):
+ super(BaseDataset, self).__init__()
+
+ self.cfg = cfg
+ self.is_inference = is_inference
+ self.is_test = is_test
+ if self.is_test:
+ self.cfgdata = self.cfg.test_data
+ data_info = self.cfgdata.test
+ else:
+ self.cfgdata = self.cfg.data
+ if self.is_inference:
+ data_info = self.cfgdata.val
+ else:
+ data_info = self.cfgdata.train
+ self.name = self.cfgdata.name
+ self.lmdb_roots = data_info.roots
+ self.dataset_type = getattr(data_info, 'dataset_type', None)
+ self.cache = getattr(self.cfgdata, 'cache', None)
+ self.interpolator = getattr(self.cfgdata, 'interpolator', "INTER_LINEAR")
+
+ # Get AWS secret keys.
+ if self.dataset_type == 'object_store':
+ assert hasattr(cfg, 'aws_credentials_file')
+ self.aws_credentials_file = cfg.aws_credentials_file
+
+ # Legacy lmdb/folder only support.
+ if self.dataset_type is None:
+ self.dataset_is_lmdb = getattr(data_info, 'is_lmdb', False)
+ if self.dataset_is_lmdb:
+ self.dataset_type = 'lmdb'
+ else:
+ self.dataset_type = 'folder'
+ # Legacy support ends.
+
+ assert self.dataset_type in DATASET_TYPES
+ if self.dataset_type == 'lmdb':
+ # Add handle to function to load data from LMDB.
+ self.load_from_dataset = load_from_lmdb
+ elif self.dataset_type == 'folder':
+ # For some unpaired experiments, we would like the dataset to be presented in a paired way
+
+ if hasattr(self.cfgdata, 'paired') is False:
+ self.cfgdata.paired = self.paired
+ # Add handle to function to load data from folder.
+ self.load_from_dataset = load_from_folder
+ # Create metadata for folders.
+ print('Creating metadata')
+ all_filenames, all_metadata = [], []
+ if self.is_test:
+ cfg.data_backup = cfg.data
+ cfg.data = cfg.test_data
+ for root in self.lmdb_roots:
+ filenames, metadata = create_metadata(
+ data_root=root, cfg=cfg, paired=self.cfgdata['paired'])
+ all_filenames.append(filenames)
+ all_metadata.append(metadata)
+ if self.is_test:
+ cfg.data = cfg.data_backup
+ elif self.dataset_type == 'object_store':
+ # Add handle to function to load data from AWS S3.
+ self.load_from_dataset = load_from_object_store
+
+ # Get the types of data stored in dataset, and their extensions.
+ self.data_types = [] # Names of data types.
+ self.dataset_data_types = [] # These data types are in the dataset.
+ self.image_data_types = [] # These types are images.
+ self.hdr_image_data_types = [] # These types are HDR images.
+ self.normalize = {} # Does this data type need normalization?
+ self.extensions = {} # What is this data type's file extension.
+ self.is_mask = {} # Whether this data type is discrete masks?
+ self.num_channels = {} # How many channels does this data type have?
+ self.pre_aug_ops = {} # Ops on data type before augmentation.
+ self.post_aug_ops = {} # Ops on data type after augmentation.
+
+ # Extract info from data types.
+ for data_type in self.cfgdata.input_types:
+ name = list(data_type.keys())
+ assert len(name) == 1
+ name = name[0]
+ info = data_type[name]
+
+ if 'ext' not in info:
+ info['ext'] = None
+ if 'normalize' not in info:
+ info['normalize'] = False
+ if 'is_mask' not in info:
+ info['is_mask'] = False
+ if 'pre_aug_ops' not in info:
+ info['pre_aug_ops'] = 'None'
+ if 'post_aug_ops' not in info:
+ info['post_aug_ops'] = 'None'
+ if 'computed_on_the_fly' not in info:
+ info['computed_on_the_fly'] = False
+ if 'num_channels' not in info:
+ info['num_channels'] = None
+
+ self.data_types.append(name)
+ if not info['computed_on_the_fly']:
+ self.dataset_data_types.append(name)
+
+ self.extensions[name] = info['ext']
+ self.normalize[name] = info['normalize']
+ self.num_channels[name] = info['num_channels']
+ self.pre_aug_ops[name] = [op.strip() for op in
+ info['pre_aug_ops'].split(',')]
+ self.post_aug_ops[name] = [op.strip() for op in
+ info['post_aug_ops'].split(',')]
+ self.is_mask[name] = info['is_mask']
+ if info['ext'] is not None and (info['ext'] in IMG_EXTENSIONS or info['ext'] in VIDEO_EXTENSIONS):
+ self.image_data_types.append(name)
+ if info['ext'] is not None and info['ext'] in HDR_IMG_EXTENSIONS:
+ self.hdr_image_data_types.append(name)
+
+ # Add some info into cfgdata for legacy support.
+ self.cfgdata.data_types = self.data_types
+ self.cfgdata.num_channels = [self.num_channels[name]
+ for name in self.data_types]
+
+ # Augmentations which need full dict.
+ self.full_data_post_aug_ops, self.full_data_ops = [], []
+ if hasattr(self.cfgdata, 'full_data_ops'):
+ ops = self.cfgdata.full_data_ops
+ self.full_data_ops.extend([op.strip() for op in ops.split(',')])
+ if hasattr(self.cfgdata, 'full_data_post_aug_ops'):
+ ops = self.cfgdata.full_data_post_aug_ops
+ self.full_data_post_aug_ops.extend(
+ [op.strip() for op in ops.split(',')])
+
+ # These are the labels which will be concatenated for generator input.
+ self.input_labels = []
+ if hasattr(self.cfgdata, 'input_labels'):
+ self.input_labels = self.cfgdata.input_labels
+
+ # These are the keypoints which also need to be augmented.
+ self.keypoint_data_types = []
+ if hasattr(self.cfgdata, 'keypoint_data_types'):
+ self.keypoint_data_types = self.cfgdata.keypoint_data_types
+
+ # Create augmentation operations.
+ aug_list = data_info.augmentations
+ individual_video_frame_aug_list = getattr(data_info, 'individual_video_frame_augmentations', dict())
+ self.augmentor = Augmentor(
+ aug_list, individual_video_frame_aug_list, self.image_data_types, self.is_mask,
+ self.keypoint_data_types, self.interpolator)
+ self.augmentable_types = self.image_data_types + \
+ self.keypoint_data_types
+
+ # Create torch transformations.
+ self.transform = {}
+ for data_type in self.image_data_types:
+ normalize = self.normalize[data_type]
+ self.transform[data_type] = self._get_transform(
+ normalize, self.num_channels[data_type])
+
+ # Create torch transformations for HDR images.
+ for data_type in self.hdr_image_data_types:
+ normalize = self.normalize[data_type]
+ self.transform[data_type] = self._get_transform(
+ normalize, self.num_channels[data_type])
+
+ # Initialize handles.
+ self.sequence_lists = [] # List of sequences per dataset root.
+ self.lmdbs = {} # Dict for list of lmdb handles per data type.
+ for data_type in self.dataset_data_types:
+ self.lmdbs[data_type] = []
+ self.dataset_probability = None
+ self.additional_lists = []
+
+ # Load each dataset.
+ for idx, root in enumerate(self.lmdb_roots):
+ if self.dataset_type == 'lmdb':
+ self._add_dataset(root)
+ elif self.dataset_type == 'folder':
+ self._add_dataset(root, filenames=all_filenames[idx],
+ metadata=all_metadata[idx])
+ elif self.dataset_type == 'object_store':
+ self._add_dataset(
+ root, aws_credentials_file=self.aws_credentials_file)
+
+ # Compute dataset statistics and create whatever self.variables required
+ # for the specific dataloader.
+ self._compute_dataset_stats()
+
+ # Build index of data to sample.
+ self.mapping, self.epoch_length = self._create_mapping()
+
+ def _create_mapping(self):
+ r"""Creates mapping from data sample idx to actual LMDB keys.
+ All children need to implement their own.
+
+ Returns:
+ self.mapping (list): List of LMDB keys.
+ """
+ raise NotImplementedError
+
+ def _compute_dataset_stats(self):
+ r"""Computes required statistics about dataset.
+ All children need to implement their own.
+ """
+ pass
+
+ def __getitem__(self, index):
+ r"""Entry function for dataset."""
+ raise NotImplementedError
+
+ def _get_transform(self, normalize, num_channels):
+ r"""Convert numpy to torch tensor.
+
+ Args:
+ normalize (bool): Normalize image i.e. (x - 0.5) * 2.
+ Goes from [0, 1] -> [-1, 1].
+ Returns:
+ Composed list of torch transforms.
+ """
+ transform_list = [transforms.ToTensor()]
+ if normalize:
+ transform_list.append(
+ transforms.Normalize((0.5, ) * num_channels,
+ (0.5, ) * num_channels, inplace=True))
+ return transforms.Compose(transform_list)
+
+ def _add_dataset(self, root, filenames=None, metadata=None,
+ aws_credentials_file=None):
+ r"""Adds an LMDB dataset to a list of datasets.
+
+ Args:
+ root (str): Path to LMDB or folder dataset.
+ filenames: List of filenames for folder dataset.
+ metadata: Metadata for folder dataset.
+ aws_credentials_file: Path to file containing AWS credentials.
+ """
+ if aws_credentials_file and self.dataset_type == 'object_store':
+ object_store_dataset = ObjectStoreDataset(
+ root, aws_credentials_file, cache=self.cache)
+ sequence_list = object_store_dataset.sequence_list
+ else:
+ # Get sequences associated with this dataset.
+ if filenames is None:
+ list_path = 'all_filenames.json'
+ with open(os.path.join(root, list_path)) as fin:
+ sequence_list = OrderedDict(json.load(fin))
+ else:
+ sequence_list = filenames
+
+ additional_path = 'all_indices.json'
+ if os.path.exists(os.path.join(root, additional_path)):
+ print('Using additional list for object indices.')
+ with open(os.path.join(root, additional_path)) as fin:
+ additional_list = OrderedDict(json.load(fin))
+ self.additional_lists.append(additional_list)
+ self.sequence_lists.append(sequence_list)
+
+ # Get LMDB dataset handles.
+ for data_type in self.dataset_data_types:
+ if self.dataset_type == 'lmdb':
+ self.lmdbs[data_type].append(
+ LMDBDataset(os.path.join(root, data_type)))
+ elif self.dataset_type == 'folder':
+ self.lmdbs[data_type].append(
+ FolderDataset(os.path.join(root, data_type), metadata))
+ elif self.dataset_type == 'object_store':
+ # All data types use the same handle.
+ self.lmdbs[data_type].append(object_store_dataset)
+
+ def perform_individual_video_frame(self, data, augment_ops):
+ r"""Perform data augmentation on images only.
+
+ Args:
+ data (dict): Keys are from data types. Values can be numpy.ndarray
+ or list of numpy.ndarray (image or list of images).
+ augment_ops (list): The augmentation operations for individual frames.
+ Returns:
+ (tuple):
+ - data (dict): Augmented data, with same keys as input data.
+ - is_flipped (bool): Flag which tells if images have been
+ left-right flipped.
+ """
+ if augment_ops:
+ all_data = dict()
+ for ix, key in enumerate(data.keys()):
+ if ix == 0:
+ num = len(data[key])
+ for j in range(num):
+ all_data['%d' % j] = dict()
+ for j in range(num):
+ all_data['%d' % j][key] = data[key][j:(j+1)]
+ for j in range(num):
+ all_data['%d' % j], _ = self.perform_augmentation(
+ all_data['%d' % j], paired=True, augment_ops=augment_ops)
+ for key in data.keys():
+ tmp = []
+ for j in range(num):
+ tmp += all_data['%d' % j][key]
+ data[key] = tmp
+ return data
+
+ def perform_augmentation(self, data, paired, augment_ops=None):
+ r"""Perform data augmentation on images only.
+
+ Args:
+ data (dict): Keys are from data types. Values can be numpy.ndarray
+ or list of numpy.ndarray (image or list of images).
+ paired (bool): Apply same augmentation to all input keys?
+ augment_ops (list): The augmentation operations.
+ Returns:
+ (tuple):
+ - data (dict): Augmented data, with same keys as input data.
+ - is_flipped (bool): Flag which tells if images have been
+ left-right flipped.
+ """
+ aug_inputs = {}
+ for data_type in self.augmentable_types:
+ aug_inputs[data_type] = data[data_type]
+
+ augmented, is_flipped = self.augmentor.perform_augmentation(
+ aug_inputs, paired=paired, augment_ops=augment_ops)
+
+ for data_type in self.augmentable_types:
+ data[data_type] = augmented[data_type]
+
+ return data, is_flipped
+
+ def flip_hdr(self, data, is_flipped=False):
+ r"""Flip hdr images.
+
+ Args:
+ data (dict): Keys are from data types. Values can be numpy.ndarray
+ or list of numpy.ndarray (image or list of images).
+ is_flipped (bool): Applying left-right flip to the hdr images
+ Returns:
+ (tuple):
+ - data (dict): Augmented data, with same keys as input data.
+ """
+ if is_flipped is False:
+ return data
+
+ for data_type in self.hdr_image_data_types:
+ # print('Length of data: {}'.format(len(data[data_type])))
+ data[data_type][0] = data[data_type][0][:, ::-1, :].copy()
+ return data
+
+ def to_tensor(self, data):
+ r"""Convert all images to tensor.
+
+ Args:
+ data (dict): Dict containing data_type as key, with each value
+ as a list of numpy.ndarrays.
+ Returns:
+ data (dict): Dict containing data_type as key, with each value
+ as a list of torch.Tensors.
+ """
+ for data_type in self.image_data_types:
+ for idx in range(len(data[data_type])):
+ if data[data_type][idx].dtype == np.uint16:
+ data[data_type][idx] = data[data_type][idx].astype(
+ np.float32)
+ data[data_type][idx] = self.transform[data_type](
+ data[data_type][idx])
+ for data_type in self.hdr_image_data_types:
+ for idx in range(len(data[data_type])):
+ data[data_type][idx] = self.transform[data_type](
+ data[data_type][idx])
+ return data
+
+ def apply_ops(self, data, op_dict, full_data=False):
+ r"""Apply any ops from op_dict to data types.
+
+ Args:
+ data (dict): Dict containing data_type as key, with each value
+ as a list of numpy.ndarrays.
+ op_dict (dict): Dict containing data_type as key, with each value
+ containing string of operations to apply.
+ full_data (bool): Do these ops require access to the full data?
+ Returns:
+ data (dict): Dict containing data_type as key, with each value
+ modified by the op if any.
+ """
+ if full_data:
+ # op needs entire data dict.
+ for op in op_dict:
+ if op == 'None':
+ continue
+ op, op_type = self.get_op(op)
+ assert op_type == 'full_data'
+ data = op(data)
+ else:
+ # op per data type.
+ if not op_dict:
+ return data
+ for data_type in data:
+ for op in op_dict[data_type]:
+ if op == 'None':
+ continue
+ op, op_type = self.get_op(op)
+ data[data_type] = op(data[data_type])
+
+ if op_type == 'vis':
+ # We have converted this data type to an image. Enter it
+ # in self.image_data_types and give it a torch
+ # transform.
+ if data_type not in self.image_data_types:
+ self.image_data_types.append(data_type)
+ normalize = self.normalize[data_type]
+ num_channels = self.num_channels[data_type]
+ self.transform[data_type] = \
+ self._get_transform(normalize, num_channels)
+ elif op_type == 'convert':
+ continue
+ elif op_type is None:
+ continue
+ else:
+ raise NotImplementedError
+ return data
+
+ def get_op(self, op):
+ r"""Get function to apply for specific op.
+
+ Args:
+ op (str): Name of the op.
+ Returns:
+ function handle.
+ """
+ def list_to_tensor(data):
+ r"""Convert list of numeric values to tensor."""
+ assert isinstance(data, list)
+ return torch.from_numpy(np.array(data, dtype=np.float32))
+
+ def decode_json_list(data):
+ r"""Decode list of strings in json to objects."""
+ assert isinstance(data, list)
+ return [json.loads(item) for item in data]
+
+ def decode_pkl_list(data):
+ r"""Decode list of pickled strings to objects."""
+ assert isinstance(data, list)
+ return [pickle.loads(item) for item in data]
+
+ def list_to_numpy(data):
+ r"""Convert list of numeric values to numpy array."""
+ assert isinstance(data, list)
+ return np.array(data)
+
+ def l2_normalize(data):
+ r"""L2 normalization."""
+ assert isinstance(data, torch.Tensor)
+ import torch.nn.functional as F
+ return F.normalize(data, dim=1)
+
+ if op == 'to_tensor':
+ return list_to_tensor, None
+ elif op == 'decode_json':
+ return decode_json_list, None
+ elif op == 'decode_pkl':
+ return decode_pkl_list, None
+ elif op == 'to_numpy':
+ return list_to_numpy, None
+ elif op == 'l2_norm':
+ return l2_normalize, None
+ elif '::' in op:
+ parts = op.split('::')
+ if len(parts) == 2:
+ module, function = parts
+ module = importlib.import_module(module)
+ function = getattr(module, function)
+ sig = signature(function)
+ num_params = len(sig.parameters)
+ assert num_params in [3, 4], \
+ 'Full data functions take in (cfgdata, is_inference, ' \
+ 'full_data) or (cfgdata, is_inference, self, full_data) ' \
+ 'as input.'
+ if num_params == 3:
+ function = partial(
+ function, self.cfgdata, self.is_inference)
+ elif num_params == 4:
+ function = partial(
+ function, self.cfgdata, self.is_inference, self)
+ function_type = 'full_data'
+ elif len(parts) == 3:
+ function_type, module, function = parts
+ module = importlib.import_module(module)
+
+ # Get function inputs, if provided.
+ partial_fn = False
+ if '(' in function and ')' in function:
+ partial_fn = True
+ function, params = self._get_fn_params(function)
+
+ function = getattr(module, function)
+
+ # Create partial function.
+ if partial_fn:
+ function = partial(function, **params)
+
+ # Get function signature.
+ sig = signature(function)
+ num_params = 0
+ for param in sig.parameters.values():
+ if param.kind == param.POSITIONAL_OR_KEYWORD:
+ num_params += 1
+
+ if function_type == 'vis':
+ if num_params != 9:
+ raise ValueError(
+ 'vis function type needs to take ' +
+ '(resize_h, resize_w, crop_h, crop_w, ' +
+ 'original_h, original_w, is_flipped, cfgdata, ' +
+ 'data) as input.')
+ function = partial(function,
+ self.augmentor.resize_h,
+ self.augmentor.resize_w,
+ self.augmentor.crop_h,
+ self.augmentor.crop_w,
+ self.augmentor.original_h,
+ self.augmentor.original_w,
+ self.augmentor.is_flipped,
+ self.cfgdata)
+ elif function_type == 'convert':
+ if num_params != 1:
+ raise ValueError(
+ 'convert function type needs to take ' +
+ '(data) as input.')
+ else:
+ raise ValueError('Unknown op: %s' % (op))
+ else:
+ raise ValueError('Unknown op: %s' % (op))
+ return function, function_type
+ else:
+ raise ValueError('Unknown op: %s' % (op))
+
+ def _get_fn_params(self, function_string):
+ r"""Find key-value inputs to function from string definition.
+
+ Args:
+ function_string (str): String with the function name and args, where
+ args are separated by ':', e.g. my_function(a=10:b=20).
+ Returns:
+ function (str): Name of function.
+ params (dict): Key-value params for function.
+ """
+ start = function_string.find('(')
+ end = function_string.find(')')
+ function = function_string[:start]
+ params_str = function_string[start+1:end]
+ params = {}
+ for item in params_str.split(':'):
+ key, value = item.split('=')
+ try:
+ params[key] = float(value)
+ except: # noqa
+ params[key] = value
+ return function, params
+
+ def __len__(self):
+ return self.epoch_length
diff --git a/imaginaire/datasets/cache.py b/imaginaire/datasets/cache.py
new file mode 100644
index 0000000000000000000000000000000000000000..6c28752dc290c9cbf15ffcf6ca2093415082f93e
--- /dev/null
+++ b/imaginaire/datasets/cache.py
@@ -0,0 +1,40 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import diskcache
+
+"""
+INFO:
+Cache objects are thread-safe and may be shared between threads.
+Two Cache objects may also reference the same directory from separate
+threads or processes. In this way, they are also process-safe and support
+cross-process communication.
+"""
+
+
+class Cache(object):
+ r"""This creates an on disk cache, which saves files as bytes.
+ Args:
+ root (str): Path to the cache dir.
+ size_MB (float): Size of cache in MB.
+ """
+
+ def __init__(self, root, size_GB):
+ self.root = root
+ self.size_limit_B = size_GB * 1024 * 1024 * 1024
+ self.cache = diskcache.Cache(root, size_limit=self.size_limit_B)
+ print('Created cache of max size %d GB at %s' %
+ (size_GB, self.cache.directory))
+
+ def read(self, key):
+ if key in self.cache:
+ return self.cache[key]
+ return False
+
+ def write(self, key, value):
+ try:
+ self.cache[key] = value
+ return True
+ except Exception as e: # noqa
+ print(e)
+ return False
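+
+ # Minimal usage sketch (paths and sizes are illustrative):
+ #   cache = Cache('/tmp/imaginaire_cache', size_GB=1)
+ #   cache.write(b'some/key', b'raw bytes')
+ #   value = cache.read(b'some/key')  # returns False if the key is missing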
diff --git a/imaginaire/datasets/dummy.py b/imaginaire/datasets/dummy.py
new file mode 100644
index 0000000000000000000000000000000000000000..9783eb3f38007652b84dde33f2da9202491686e1
--- /dev/null
+++ b/imaginaire/datasets/dummy.py
@@ -0,0 +1,18 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import torch
+
+
+class Dataset(torch.utils.data.Dataset):
+ r"""Dummy dataset, returns nothing."""
+
+ def __init__(self, cfg, is_inference=False, is_test=False):
+ super(Dataset, self).__init__()
+
+ def __getitem__(self, index):
+ return {}
+
+ def __len__(self):
+ return 65535
diff --git a/imaginaire/datasets/folder.py b/imaginaire/datasets/folder.py
new file mode 100644
index 0000000000000000000000000000000000000000..fd3fcc679834044ac9bf11afe81a9e9fe8697aa8
--- /dev/null
+++ b/imaginaire/datasets/folder.py
@@ -0,0 +1,86 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import os
+
+import cv2
+import numpy as np
+import torch.utils.data as data
+from PIL import Image
+
+from imaginaire.utils.data import IMG_EXTENSIONS, HDR_IMG_EXTENSIONS
+import imageio
+
+
+class FolderDataset(data.Dataset):
+ r"""This deals with opening, and reading from an Folder dataset.
+
+ Args:
+ root (str): Path to the folder.
+ metadata (dict): Containing extensions.
+ """
+
+ def __init__(self, root, metadata):
+ self.root = os.path.expanduser(root)
+ self.extensions = metadata
+
+ print('Folder at %s opened.' % (root))
+
+ def getitem_by_path(self, path, data_type):
+ r"""Load data item stored for key = path.
+
+ Args:
+ path (str): Key into Folder dataset.
+ data_type (str): Key into self.extensions e.g. data/data_segmaps/...
+ Returns:
+ img (PIL.Image) or buf (str): Contents of file for this key.
+ """
+ # Figure out decoding params.
+ ext = self.extensions[data_type]
+ is_image = False
+ is_hdr = False
+ if ext in IMG_EXTENSIONS:
+ is_image = True
+ if 'tif' in ext:
+ dtype, mode = np.uint16, -1
+ elif 'JPEG' in ext or 'JPG' in ext \
+ or 'jpeg' in ext or 'jpg' in ext:
+ dtype, mode = np.uint8, 3
+ else:
+ dtype, mode = np.uint8, -1
+ elif ext in HDR_IMG_EXTENSIONS:
+ is_hdr = True
+ else:
+ is_image = False
+
+ # Get value from key.
+ filepath = os.path.join(self.root, path.decode() + '.' + ext)
+ assert os.path.exists(filepath), '%s does not exist' % (filepath)
+ with open(filepath, 'rb') as f:
+ buf = f.read()
+
+ # Decode and return.
+ if is_image:
+ try:
+ img = cv2.imdecode(np.frombuffer(buf, dtype=dtype), mode)
+ except Exception:
+ print(path)
+ # BGR to RGB if 3 channels.
+ if img.ndim == 3 and img.shape[-1] == 3:
+ img = img[:, :, ::-1]
+ img = Image.fromarray(img)
+ return img
+ elif is_hdr:
+ try:
+ imageio.plugins.freeimage.download()
+ img = imageio.imread(buf)
+ except Exception:
+ print(path)
+ return img # Return a numpy array
+ else:
+ return buf
+
+ def __len__(self):
+ r"""Return number of keys in Folder dataset."""
+ return self.length
diff --git a/imaginaire/datasets/images.py b/imaginaire/datasets/images.py
new file mode 100644
index 0000000000000000000000000000000000000000..943752be11d823025f33956dcdcaf6a35c1fb899
--- /dev/null
+++ b/imaginaire/datasets/images.py
@@ -0,0 +1,168 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import random
+
+from imaginaire.datasets.base import BaseDataset
+
+
+class Dataset(BaseDataset):
+ r"""Image dataset for use in class conditional GAN.
+
+ Args:
+ cfg (Config): Loaded config object.
+ is_inference (bool): In train or inference mode?
+ """
+
+ def __init__(self, cfg, is_inference=False, is_test=False):
+ self.paired = False
+ super(Dataset, self).__init__(cfg, is_inference, is_test)
+ self.num_classes = len(self.class_name_to_idx['images'])
+ self.sample_class_idx = None
+
+ def set_sample_class_idx(self, class_idx):
+ r"""Set sample class idx. This is not used in this class...
+
+ Args:
+ class_idx (int): Which class idx to sample from.
+ """
+ self.sample_class_idx = class_idx
+ self.epoch_length = \
+ max([len(lmdb_keys) for _, lmdb_keys in self.mapping.items()])
+
+ def _create_mapping(self):
+ r"""Creates mapping from idx to key in LMDB.
+
+ Returns:
+ (tuple):
+ - self.mapping (dict): Dict with data type as key mapping idx to
+ LMDB key.
+ - self.epoch_length (int): Number of samples in an epoch.
+ """
+ idx_to_key, class_names = {}, {}
+ for lmdb_idx, sequence_list in enumerate(self.sequence_lists):
+ for data_type, data_type_sequence_list in sequence_list.items():
+ if data_type not in class_names:
+ class_names[data_type] = []
+ if data_type not in idx_to_key:
+ idx_to_key[data_type] = []
+ for sequence_name, filenames in data_type_sequence_list.items():
+ class_name = sequence_name.split('/')[0]
+ for filename in filenames:
+ idx_to_key[data_type].append({
+ 'lmdb_root': self.lmdb_roots[lmdb_idx],
+ 'lmdb_idx': lmdb_idx,
+ 'sequence_name': sequence_name,
+ 'filename': filename,
+ 'class_name': class_name
+ })
+ class_names[data_type].append(class_name)
+ self.mapping = idx_to_key
+ self.epoch_length = max([len(lmdb_keys)
+ for _, lmdb_keys in self.mapping.items()])
+
+ # Create mapping from class name to class idx.
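+ # Indices are assigned alphabetically over the unique class names, where the
+ # class name is the first component of each sequence name.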
+ self.class_name_to_idx = {}
+ for data_type, class_names_data_type in class_names.items():
+ self.class_name_to_idx[data_type] = {}
+ class_names_data_type = sorted(list(set(class_names_data_type)))
+ for class_idx, class_name in enumerate(class_names_data_type):
+ self.class_name_to_idx[data_type][class_name] = class_idx
+
+ # Add class idx to mapping.
+ for data_type in self.mapping:
+ for key in self.mapping[data_type]:
+ key['class_idx'] = \
+ self.class_name_to_idx[data_type][key['class_name']]
+
+ # Create a mapping from index to lmdb key for each class.
+ idx_to_key_class = {}
+ for data_type in self.mapping:
+ idx_to_key_class[data_type] = {}
+ for class_idx, class_name in enumerate(class_names[data_type]):
+ idx_to_key_class[data_type][class_idx] = []
+ for key in self.mapping[data_type]:
+ idx_to_key_class[data_type][key['class_idx']].append(key)
+ self.mapping_class = idx_to_key_class
+
+ return self.mapping, self.epoch_length
+
+ def _sample_keys(self, index):
+ r"""Gets files to load for this sample.
+
+ Args:
+ index (int): Index in [0, len(dataset)].
+ Returns:
+ keys (dict): Each key of this dict is a data type.
+ - lmdb_key (dict):
+ - lmdb_idx (int): Chosen LMDB dataset root.
+ - sequence_name (str): Chosen sequence in chosen dataset.
+ - filename (str): Chosen filename in chosen sequence.
+ """
+
+ keys = {}
+ if self.is_inference: # evaluation mode
+ lmdb_keys = self.mapping['images']
+ keys['images'] = lmdb_keys[index % len(lmdb_keys)]
+ else:
+ lmdb_keys = self.mapping['images']
+ keys['images'] = random.choice(lmdb_keys)
+ return keys
+
+ def __getitem__(self, index):
+ r"""Gets selected files.
+
+ Args:
+ index (int): Index into dataset.
+ Returns:
+ data (dict): Dict with all chosen data_types.
+ """
+ # Select a sample from the available data.
+ keys_per_data_type = self._sample_keys(index)
+
+ # Get class idx into a list.
+ class_idxs = []
+ for data_type in keys_per_data_type:
+ class_idxs.append(keys_per_data_type[data_type]['class_idx'])
+
+ # Get keys and lmdbs.
+ keys, lmdbs = {}, {}
+ for data_type in self.dataset_data_types:
+ # Unpack keys.
+ lmdb_idx = keys_per_data_type[data_type]['lmdb_idx']
+ sequence_name = keys_per_data_type[data_type]['sequence_name']
+ filename = keys_per_data_type[data_type]['filename']
+ keys[data_type] = '%s/%s' % (sequence_name, filename)
+ lmdbs[data_type] = self.lmdbs[data_type][lmdb_idx]
+
+ # Load all data for this index.
+ data = self.load_from_dataset(keys, lmdbs)
+
+ # Apply ops pre augmentation.
+ data = self.apply_ops(data, self.pre_aug_ops)
+
+ # Do augmentations for images.
+ data, is_flipped = self.perform_augmentation(data, paired=False, augment_ops=self.augmentor.augment_ops)
+
+ # Apply ops post augmentation.
+ data = self.apply_ops(data, self.post_aug_ops)
+ data = self.apply_ops(data, self.full_data_post_aug_ops,
+ full_data=True)
+
+ # Convert images to tensor.
+ for data_type in self.image_data_types:
+ for idx in range(len(data[data_type])):
+ data[data_type][idx] = \
+ data[data_type][idx][:, :, :self.num_channels[data_type]]
+ data = self.to_tensor(data)
+
+ # Remove any extra dimensions.
+ for data_type in self.image_data_types:
+ data[data_type] = data[data_type][0]
+
+ # Package output.
+ data['is_flipped'] = is_flipped
+ data['key'] = keys_per_data_type
+ data['labels'] = class_idxs[0]
+ return data
diff --git a/imaginaire/datasets/lmdb.py b/imaginaire/datasets/lmdb.py
new file mode 100644
index 0000000000000000000000000000000000000000..136642c1e624b886b05aaffe46f14694b0eaa29a
--- /dev/null
+++ b/imaginaire/datasets/lmdb.py
@@ -0,0 +1,92 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import json
+import os
+
+import cv2
+import lmdb
+import numpy as np
+import torch.utils.data as data
+from PIL import Image
+
+from imaginaire.utils.data import IMG_EXTENSIONS, HDR_IMG_EXTENSIONS
+from imaginaire.utils.distributed import master_only_print as print
+import imageio
+
+
+class LMDBDataset(data.Dataset):
+ r"""This deals with opening, and reading from an LMDB dataset.
+ Args:
+ root (str): Path to the LMDB file.
+ """
+
+ def __init__(self, root):
+ self.root = os.path.expanduser(root)
+ self.env = lmdb.open(root, max_readers=126, readonly=True, lock=False,
+ readahead=False, meminit=False)
+ with self.env.begin(write=False) as txn:
+ self.length = txn.stat()['entries']
+
+ # Read metadata.
+ with open(os.path.join(self.root, '..', 'metadata.json')) as fin:
+ self.extensions = json.load(fin)
+
+ print('LMDB file at %s opened.' % (root))
+
+ def getitem_by_path(self, path, data_type):
+ r"""Load data item stored for key = path.
+
+ Args:
+ path (str): Key into LMDB dataset.
+ data_type (str): Key into self.extensions e.g. data/data_segmaps/...
+ Returns:
+ img (PIL.Image) or buf (str): Contents of LMDB value for this key.
+ """
+ # Figure out decoding params.
+ ext = self.extensions[data_type]
+ is_image = False
+ is_hdr = False
+ if ext in IMG_EXTENSIONS:
+ is_image = True
+ if 'tif' in ext:
+ dtype, mode = np.uint16, -1
+ elif 'JPEG' in ext or 'JPG' in ext \
+ or 'jpeg' in ext or 'jpg' in ext:
+ dtype, mode = np.uint8, 3
+ else:
+ dtype, mode = np.uint8, -1
+ elif ext in HDR_IMG_EXTENSIONS:
+ is_hdr = True
+ else:
+ is_image = False
+
+ # Get value from key.
+ with self.env.begin(write=False) as txn:
+ buf = txn.get(path)
+
+ # Decode and return.
+ if is_image:
+ try:
+ img = cv2.imdecode(np.frombuffer(buf, dtype=dtype), mode)
+ except Exception:
+ print(path)
+ # BGR to RGB if 3 channels.
+ if img.ndim == 3 and img.shape[-1] == 3:
+ img = img[:, :, ::-1]
+ img = Image.fromarray(img)
+ return img
+ elif is_hdr:
+ try:
+ imageio.plugins.freeimage.download()
+ img = imageio.imread(buf)
+ except Exception:
+ print(path)
+ return img # Return a numpy array
+ else:
+ return buf
+
+ def __len__(self):
+ r"""Return number of keys in LMDB dataset."""
+ return self.length
diff --git a/imaginaire/datasets/object_store.py b/imaginaire/datasets/object_store.py
new file mode 100644
index 0000000000000000000000000000000000000000..2dd4f2d765def17ed52d5a442cc3bb87d31b9ded
--- /dev/null
+++ b/imaginaire/datasets/object_store.py
@@ -0,0 +1,142 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import io
+import json
+
+# import cv2
+import boto3
+from botocore.config import Config
+import numpy as np
+import torch.utils.data as data
+from PIL import Image
+import imageio
+from botocore.exceptions import ClientError
+
+from imaginaire.datasets.cache import Cache
+from imaginaire.utils.data import IMG_EXTENSIONS, HDR_IMG_EXTENSIONS
+
+Image.MAX_IMAGE_PIXELS = None
+
+
+class ObjectStoreDataset(data.Dataset):
+ r"""This deals with opening, and reading from an AWS S3 bucket.
+ Args:
+
+ root (str): Path to the AWS S3 bucket.
+ aws_credentials_file (str): Path to file containing AWS credentials.
+ data_type (str): Which data type should this dataset load?
+ """
+
+ def __init__(self, root, aws_credentials_file, data_type='', cache=None):
+ # Cache.
+ self.cache = False
+ if cache is not None:
+ # raise NotImplementedError
+ self.cache = Cache(cache.root, cache.size_GB)
+
+ # Get bucket info, and keys to info about dataset.
+ with open(aws_credentials_file) as fin:
+ self.credentials = json.load(fin)
+
+ parts = root.split('/')
+ self.bucket = parts[0]
+ self.all_filenames_key = '/'.join(parts[1:]) + '/all_filenames.json'
+ self.metadata_key = '/'.join(parts[1:]) + '/metadata.json'
+
+ # Get list of filenames.
+ filename_info = self._get_object(self.all_filenames_key)
+ self.sequence_list = json.loads(filename_info.decode('utf-8'))
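+ # all_filenames.json is expected to map each sequence name to its list of
+ # filenames; the dataset length below is the total number of files.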
+
+ # Get length.
+ length = 0
+ for _, value in self.sequence_list.items():
+ length += len(value)
+ self.length = length
+
+ # Read metadata.
+ metadata_info = self._get_object(self.metadata_key)
+ self.extensions = json.loads(metadata_info.decode('utf-8'))
+ self.data_type = data_type
+
+ print('AWS S3 bucket at %s opened.' % (root + '/' + self.data_type))
+
+ def _get_object(self, key):
+ r"""Download object from bucket.
+
+ Args:
+ key (str): Key inside bucket.
+ """
+ # Look up value in cache.
+ object_content = self.cache.read(key) if self.cache else False
+ if not object_content:
+ # Either no cache used or key not found in cache.
+ config = Config(connect_timeout=30,
+ signature_version="s3",
+ retries={"max_attempts": 999999})
+ s3 = boto3.client('s3', **self.credentials, config=config)
+ try:
+ s3_response_object = s3.get_object(Bucket=self.bucket, Key=key)
+ object_content = s3_response_object['Body'].read()
+ except Exception as e:
+ print('%s not found' % (key))
+ print(e)
+ # Save content to cache.
+ if self.cache:
+ self.cache.write(key, object_content)
+ return object_content
+
+ def getitem_by_path(self, path, data_type):
+ r"""Load data item stored for key = path.
+
+ Args:
+ path (str): Path into AWS S3 bucket, without data_type prefix.
+ data_type (str): Key into self.extensions e.g. data/data_segmaps/...
+ Returns:
+ img (PIL.Image) or buf (str): Contents of the S3 object for this key.
+ """
+ # Figure out decoding params.
+ ext = self.extensions[data_type]
+ is_image = False
+ is_hdr = False
+ parts = path.split('/')
+ key = parts[0] + '/' + data_type + '/' + '/'.join(parts[1:]) + '.' + ext
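+ # e.g. (hypothetical names) path 'seq_001/frame_0001' with data_type 'images'
+ # and ext 'jpg' maps to the S3 key 'seq_001/images/frame_0001.jpg'.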
+ if ext in IMG_EXTENSIONS:
+ is_image = True
+ if 'tif' in ext:
+ _, mode = np.uint16, -1
+ elif 'JPEG' in ext or 'JPG' in ext \
+ or 'jpeg' in ext or 'jpg' in ext:
+ _, mode = np.uint8, 3
+ else:
+ _, mode = np.uint8, -1
+ elif ext in HDR_IMG_EXTENSIONS:
+ is_hdr = True
+ else:
+ is_image = False
+
+ # Get value from key.
+ buf = self._get_object(key)
+
+ # Decode and return.
+ if is_image:
+ # This is totally a hack.
+ # We should have a better way to handle grayscale images.
+ img = Image.open(io.BytesIO(buf))
+ if mode == 3:
+ img = img.convert('RGB')
+ return img
+ elif is_hdr:
+ try:
+ imageio.plugins.freeimage.download()
+ img = imageio.imread(buf)
+ except Exception:
+ print(path)
+ return img # Return a numpy array
+ else:
+ return buf
+
+ def __len__(self):
+ r"""Return number of keys in LMDB dataset."""
+ return self.length
diff --git a/imaginaire/datasets/paired_few_shot_videos.py b/imaginaire/datasets/paired_few_shot_videos.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0b311bc36fbc05aceaf552a47d3d89b118728be
--- /dev/null
+++ b/imaginaire/datasets/paired_few_shot_videos.py
@@ -0,0 +1,308 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import copy
+import random
+import torch
+
+from imaginaire.datasets.paired_videos import Dataset as VideoDataset
+from imaginaire.model_utils.fs_vid2vid import select_object
+from imaginaire.utils.distributed import master_only_print as print
+
+
+class Dataset(VideoDataset):
+ r"""Paired video dataset for use in few-shot vid2vid.
+
+ Args:
+ cfg (Config): Loaded config object.
+ is_inference (bool): In train or inference mode?
+ sequence_length (int): What sequence of images to provide?
+ few_shot_K (int): How many images to provide for few-shot?
+ """
+
+ def __init__(self, cfg, is_inference=False, sequence_length=None,
+ few_shot_K=None, is_test=False):
+ self.paired = True
+ # Get initial few shot K.
+ if few_shot_K is None:
+ self.few_shot_K = cfg.data.initial_few_shot_K
+ else:
+ self.few_shot_K = few_shot_K
+ # Initialize.
+ super(Dataset, self).__init__(
+ cfg, is_inference, sequence_length=sequence_length, is_test=is_test)
+
+ def set_inference_sequence_idx(self, index, k_shot_index,
+ k_shot_frame_index):
+ r"""Get frames from this sequence during inference.
+
+ Args:
+ index (int): Index of inference sequence.
+ k_shot_index (int): Index of sequence from which k_shot is sampled.
+ k_shot_frame_index (int): Index of frame to sample.
+ """
+ assert self.is_inference
+ assert index < len(self.mapping)
+ assert k_shot_index < len(self.mapping)
+ assert k_shot_frame_index < len(self.mapping[k_shot_index])
+
+ self.inference_sequence_idx = index
+ self.inference_k_shot_sequence_index = k_shot_index
+ self.inference_k_shot_frame_index = k_shot_frame_index
+ self.epoch_length = len(
+ self.mapping[self.inference_sequence_idx]['filenames'])
+
+ def set_sequence_length(self, sequence_length, few_shot_K=None):
+ r"""Set the length of sequence you want as output from dataloader.
+
+ Args:
+ sequence_length (int): Length of output sequences.
+ few_shot_K (int): Number of few-shot frames.
+ """
+ if few_shot_K is None:
+ few_shot_K = self.few_shot_K
+ assert isinstance(sequence_length, int)
+ assert isinstance(few_shot_K, int)
+ if (sequence_length + few_shot_K) > self.sequence_length_max:
+ error_message = \
+ 'Requested sequence length (%d) ' % (sequence_length) + \
+ '+ few shot K (%d) > ' % (few_shot_K) + \
+ 'max sequence length (%d). ' % (self.sequence_length_max)
+ print(error_message)
+ sequence_length = self.sequence_length_max - few_shot_K
+ print('Reduced sequence length to %s' % (sequence_length))
+ self.sequence_length = sequence_length
+ self.few_shot_K = few_shot_K
+ # Recalculate mapping as some sequences might no longer be useful.
+ self.mapping, self.epoch_length = self._create_mapping()
+ print('Epoch length:', self.epoch_length)
+
+ def _create_mapping(self):
+ r"""Creates mapping from idx to key in LMDB.
+
+ Returns:
+ (tuple):
+ - self.mapping (dict): Dict of seq_len to list of sequences.
+ - self.epoch_length (int): Number of samples in an epoch.
+ """
+ # Create dict mapping length to sequence.
+ length_to_key, num_selected_seq = {}, 0
+ has_additional_lists = len(self.additional_lists) > 0
+ for lmdb_idx, sequence_list in enumerate(self.sequence_lists):
+ for sequence_name, filenames in sequence_list.items():
+ if len(filenames) >= (self.sequence_length + self.few_shot_K):
+ if len(filenames) not in length_to_key:
+ length_to_key[len(filenames)] = []
+ if has_additional_lists:
+ obj_indices = self.additional_lists[lmdb_idx][
+ sequence_name]
+ else:
+ obj_indices = [0 for _ in range(len(filenames))]
+ length_to_key[len(filenames)].append({
+ 'lmdb_root': self.lmdb_roots[lmdb_idx],
+ 'lmdb_idx': lmdb_idx,
+ 'sequence_name': sequence_name,
+ 'filenames': filenames,
+ 'obj_indices': obj_indices,
+ })
+ num_selected_seq += 1
+ self.mapping = length_to_key
+ self.epoch_length = num_selected_seq
+
+ # At inference time, we want to use all sequences,
+ # irrespective of length.
+ if self.is_inference:
+ sequence_list = []
+ for key, sequences in self.mapping.items():
+ sequence_list.extend(sequences)
+ self.mapping = sequence_list
+
+ return self.mapping, self.epoch_length
+
+ def _sample_keys(self, index):
+ r"""Gets files to load for this sample.
+
+ Args:
+ index (int): Index in [0, len(dataset)].
+ Returns:
+ key (dict):
+ - lmdb_idx (int): Chosen LMDB dataset root.
+ - sequence_name (str): Chosen sequence in chosen dataset.
+ - filenames (list of str): Chosen filenames in chosen sequence.
+ """
+ if self.is_inference:
+ assert index < self.epoch_length
+ chosen_sequence = self.mapping[self.inference_sequence_idx]
+ chosen_filenames = [chosen_sequence['filenames'][index]]
+ chosen_obj_indices = [chosen_sequence['obj_indices'][index]]
+ k_shot_chosen_sequence = self.mapping[
+ self.inference_k_shot_sequence_index]
+ k_shot_chosen_filenames = [k_shot_chosen_sequence['filenames'][
+ self.inference_k_shot_frame_index]]
+ k_shot_chosen_obj_indices = [k_shot_chosen_sequence['obj_indices'][
+ self.inference_k_shot_frame_index]]
+ # Prepare few shot key.
+ few_shot_key = copy.deepcopy(k_shot_chosen_sequence)
+ few_shot_key['filenames'] = k_shot_chosen_filenames
+ few_shot_key['obj_indices'] = k_shot_chosen_obj_indices
+ else:
+ # Pick a time step for temporal augmentation.
+ time_step = random.randint(1, self.augmentor.max_time_step)
+ required_sequence_length = 1 + \
+ (self.sequence_length - 1) * time_step
+
+ # If step is too large, default to step size of 1.
+ if required_sequence_length + self.few_shot_K > \
+ self.sequence_length_max:
+ required_sequence_length = self.sequence_length
+ time_step = 1
+
+ # Find valid sequences.
+ valid_sequences = []
+ for sequence_length, sequences in self.mapping.items():
+ if sequence_length >= required_sequence_length + \
+ self.few_shot_K:
+ valid_sequences.extend(sequences)
+
+ # Pick a sequence.
+ chosen_sequence = random.choice(valid_sequences)
+
+ # Choose filenames.
+ max_start_idx = len(chosen_sequence['filenames']) - \
+ required_sequence_length
+ start_idx = random.randint(0, max_start_idx)
+ end_idx = start_idx + required_sequence_length
+ chosen_filenames = chosen_sequence['filenames'][
+ start_idx:end_idx:time_step]
+ chosen_obj_indices = chosen_sequence['obj_indices'][
+ start_idx:end_idx:time_step]
+
+ # Find the K few shot filenames.
+ valid_range = list(range(start_idx)) + \
+ list(range(end_idx, len(chosen_sequence['filenames'])))
+ k_shot_chosen = sorted(random.sample(valid_range, self.few_shot_K))
+ k_shot_chosen_filenames = [chosen_sequence['filenames'][idx]
+ for idx in k_shot_chosen]
+ k_shot_chosen_obj_indices = [chosen_sequence['obj_indices'][idx]
+ for idx in k_shot_chosen]
+ assert not (set(chosen_filenames) & set(k_shot_chosen_filenames))
+
+ assert len(chosen_filenames) == self.sequence_length
+ assert len(k_shot_chosen_filenames) == self.few_shot_K
+
+ # Prepare few shot key.
+ few_shot_key = copy.deepcopy(chosen_sequence)
+ few_shot_key['filenames'] = k_shot_chosen_filenames
+ few_shot_key['obj_indices'] = k_shot_chosen_obj_indices
+
+ # Prepare output key.
+ key = copy.deepcopy(chosen_sequence)
+ key['filenames'] = chosen_filenames
+ key['obj_indices'] = chosen_obj_indices
+ return key, few_shot_key
+
+ def _prepare_data(self, keys):
+ r"""Load data and perform augmentation.
+
+ Args:
+ keys (dict): Key into LMDB/folder dataset for this item.
+ Returns:
+ data (dict): Dict with all chosen data_types.
+ """
+ # Unpack keys.
+ lmdb_idx = keys['lmdb_idx']
+ sequence_name = keys['sequence_name']
+ filenames = keys['filenames']
+ obj_indices = keys['obj_indices']
+
+ # Get key and lmdbs.
+ keys, lmdbs = {}, {}
+ for data_type in self.dataset_data_types:
+ keys[data_type] = self._create_sequence_keys(
+ sequence_name, filenames)
+ lmdbs[data_type] = self.lmdbs[data_type][lmdb_idx]
+
+ # Load all data for this index.
+ data = self.load_from_dataset(keys, lmdbs)
+
+ # Apply ops pre augmentation.
+ data = self.apply_ops(data, self.pre_aug_ops)
+
+ # Select the object in data using the object indices.
+ data = select_object(data, obj_indices)
+
+ # Do augmentations for images.
+ data, is_flipped = self.perform_augmentation(data, paired=True, augment_ops=self.augmentor.augment_ops)
+
+ # Create copy of keypoint data types before post aug.
+ kp_data = {}
+ for data_type in self.keypoint_data_types:
+ new_key = data_type + '_xy'
+ kp_data[new_key] = copy.deepcopy(data[data_type])
+
+ # Apply ops post augmentation.
+ data = self.apply_ops(data, self.post_aug_ops)
+
+ data = self.apply_ops(data, self.full_data_post_aug_ops, full_data=True)
+
+ # Convert images to tensor.
+ data = self.to_tensor(data)
+
+ # Pack the sequence of images.
+ for data_type in self.image_data_types:
+ for idx in range(len(data[data_type])):
+ data[data_type][idx] = data[data_type][idx].unsqueeze(0)
+ data[data_type] = torch.cat(data[data_type], dim=0)
+
+ # Add keypoint xy to data.
+ data.update(kp_data)
+
+ data['is_flipped'] = is_flipped
+ data['key'] = keys
+
+ return data
+
+ def _getitem(self, index):
+ r"""Gets selected files.
+
+ Args:
+ index (int): Index into dataset.
+ Returns:
+ data (dict): Dict with all chosen data_types.
+ """
+ # Select a sample from the available data.
+ keys, few_shot_keys = self._sample_keys(index)
+
+ data = self._prepare_data(keys)
+ few_shot_data = self._prepare_data(few_shot_keys)
+
+ # Add few shot data into data.
+ for key, value in few_shot_data.items():
+ data['few_shot_' + key] = value
+
+ # Apply full data ops.
+ if self.is_inference:
+ if index == 0:
+ pass
+ elif index < self.cfg.data.num_workers:
+ data_0 = self._getitem(0)
+ if 'common_attr' in data_0:
+ self.common_attr = data['common_attr'] = \
+ data_0['common_attr']
+ else:
+ if hasattr(self, 'common_attr'):
+ data['common_attr'] = self.common_attr
+
+ data = self.apply_ops(data, self.full_data_ops, full_data=True)
+
+ if self.is_inference and index == 0 and 'common_attr' in data:
+ self.common_attr = data['common_attr']
+
+ return data
diff --git a/imaginaire/datasets/paired_few_shot_videos_native.py b/imaginaire/datasets/paired_few_shot_videos_native.py
new file mode 100644
index 0000000000000000000000000000000000000000..17bd23d8b1bd7b8d007f9a9bcf9f13ae6a16bdf0
--- /dev/null
+++ b/imaginaire/datasets/paired_few_shot_videos_native.py
@@ -0,0 +1,233 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import random
+import tempfile
+from collections import OrderedDict
+import warnings
+import numpy as np
+import torch
+# import torchvision.io as io
+import cv2
+from PIL import Image
+
+from imaginaire.datasets.base import BaseDataset
+
+
+class Dataset(BaseDataset):
+ r"""Dataset for paired few shot videos.
+
+ Args:
+ cfg (Config): Loaded config object.
+ is_inference (bool): In train or inference mode?
+ """
+
+ def __init__(self, cfg, is_inference=False, is_test=False):
+ self.paired = True
+ super(Dataset, self).__init__(cfg, is_inference, is_test)
+ self.is_video_dataset = True
+ self.few_shot_K = 1
+ self.first_last_only = getattr(cfg.data, 'first_last_only', False)
+ self.sample_far_frames_more = getattr(cfg.data, 'sample_far_frames_more', False)
+
+ def get_label_lengths(self):
+ r"""Get num channels of all labels to be concated.
+
+ Returns:
+ label_lengths (OrderedDict): Dict mapping image data_type to num
+ channels.
+ """
+ label_lengths = OrderedDict()
+ for data_type in self.input_labels:
+ data_cfg = self.cfgdata
+ if hasattr(data_cfg, 'one_hot_num_classes') and \
+ data_type in data_cfg.one_hot_num_classes:
+ label_lengths[data_type] = data_cfg.one_hot_num_classes[data_type]
+ if getattr(data_cfg, 'use_dont_care', False):
+ label_lengths[data_type] += 1
+ else:
+ label_lengths[data_type] = self.num_channels[data_type]
+ return label_lengths
+
+ def num_inference_sequences(self):
+ r"""Number of sequences available for inference.
+
+ Returns:
+ (int)
+ """
+ assert self.is_inference
+ return len(self.mapping)
+
+ def _create_mapping(self):
+ r"""Creates mapping from idx to key in LMDB.
+
+ Returns:
+ (tuple):
+ - self.mapping (dict): Dict of seq_len to list of sequences.
+ - self.epoch_length (int): Number of samples in an epoch.
+ """
+ # Create dict mapping length to sequence.
+ mapping = []
+ for lmdb_idx, sequence_list in enumerate(self.sequence_lists):
+ for sequence_name, filenames in sequence_list.items():
+ for filename in filenames:
+ # This file is corrupt.
+ if filename == 'z-KziTO_5so_0019_start0_end85_h596_w596':
+ continue
+ mapping.append({
+ 'lmdb_root': self.lmdb_roots[lmdb_idx],
+ 'lmdb_idx': lmdb_idx,
+ 'sequence_name': sequence_name,
+ 'filenames': [filename],
+ })
+ self.mapping = mapping
+ self.epoch_length = len(mapping)
+
+ return self.mapping, self.epoch_length
+
+ def _sample_keys(self, index):
+ r"""Gets files to load for this sample.
+
+ Args:
+ index (int): Index in [0, len(dataset)].
+ Returns:
+ (tuple):
+ - key (dict):
+ - lmdb_idx (int): Chosen LMDB dataset root.
+ - sequence_name (str): Chosen sequence in chosen dataset.
+ - filenames (list of str): Chosen filenames in chosen sequence.
+ """
+ if self.is_inference:
+ assert index < self.epoch_length
+ raise NotImplementedError
+ else:
+ # Select a video at random.
+ key = random.choice(self.mapping)
+ return key
+
+ def _create_sequence_keys(self, sequence_name, filenames):
+ r"""Create the LMDB key for this piece of information.
+
+ Args:
+ sequence_name (str): Which sequence from the chosen dataset.
+ filenames (list of str): List of filenames in this sequence.
+ Returns:
+ keys (list): List of full keys.
+ """
+ assert isinstance(filenames, list), 'Filenames should be a list.'
+ keys = []
+ for filename in filenames:
+ keys.append('%s/%s' % (sequence_name, filename))
+ return keys
+
+ def _getitem(self, index):
+ r"""Gets selected files.
+
+ Args:
+ index (int): Index into dataset.
+ Returns:
+ data (dict): Dict with all chosen data_types.
+ """
+ # Select a sample from the available data.
+ keys = self._sample_keys(index)
+
+ # Unpack keys.
+ lmdb_idx = keys['lmdb_idx']
+ sequence_name = keys['sequence_name']
+ filenames = keys['filenames']
+
+ # Get key and lmdbs.
+ keys, lmdbs = {}, {}
+ for data_type in self.dataset_data_types:
+ keys[data_type] = self._create_sequence_keys(
+ sequence_name, filenames)
+ lmdbs[data_type] = self.lmdbs[data_type][lmdb_idx]
+
+ # Load all data for this index.
+ data = self.load_from_dataset(keys, lmdbs)
+
+ # Get frames from video.
+ try:
+ temp = tempfile.NamedTemporaryFile()
+ temp.write(data['videos'][0])
+ temp.seek(0)
+
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ # frames, _, info = io.read_video(temp)
+ # num_frames = frames.size(0)
+ cap = cv2.VideoCapture(temp.name)
+ num_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+ if self.first_last_only:
+ chosen_idxs = [0, num_frames - 1]
+ else:
+ # chosen_idxs = random.sample(range(frames.size(0)), 2)
+
+ chosen_idx = random.sample(range(num_frames), 1)[0]
+ few_shot_choose_range = list(range(chosen_idx)) + list(range(chosen_idx + 1, num_frames))
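+ # Reference (source) frames are sampled from the remaining frames. With
+ # sample_far_frames_more, the sampling weight grows linearly with the
+ # distance from the driving frame (nearest frames get weight 0); otherwise
+ # frames are sampled uniformly.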
+ if self.sample_far_frames_more:
+ choose_weight = list(reversed(range(chosen_idx))) + list(range(num_frames - chosen_idx - 1))
+ few_shot_idx = random.choices(few_shot_choose_range, choose_weight, k=self.few_shot_K)
+ else:
+ few_shot_idx = random.sample(few_shot_choose_range, k=self.few_shot_K)
+ chosen_idxs = few_shot_idx + [chosen_idx]
+
+ chosen_images = []
+ for idx in chosen_idxs:
+ # chosen_images.append(Image.fromarray(frames[idx].numpy()))
+ cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
+ _, frame = cap.read()
+ chosen_images.append(Image.fromarray(frame[:, :, ::-1]))
+ except Exception:
+ print('Issue with file:', sequence_name, filenames)
+ blank = np.zeros((512, 512, 3), dtype=np.uint8)
+ chosen_images = [Image.fromarray(blank), Image.fromarray(blank)]
+
+ data['videos'] = chosen_images
+
+ # Apply ops pre augmentation.
+ data = self.apply_ops(data, self.pre_aug_ops)
+
+ # Do augmentations for images.
+ data, is_flipped = self.perform_augmentation(
+ data, paired=True, augment_ops=self.augmentor.augment_ops)
+ # Individual video frame augmentation is used in face-vid2vid.
+ data = self.perform_individual_video_frame(
+ data, self.augmentor.individual_video_frame_augmentation_ops)
+
+ # Apply ops post augmentation.
+ data = self.apply_ops(data, self.post_aug_ops)
+
+ # Convert images to tensor.
+ data = self.to_tensor(data)
+
+ # Pack the sequence of images.
+ for data_type in self.image_data_types:
+ for idx in range(len(data[data_type])):
+ data[data_type][idx] = data[data_type][idx].unsqueeze(0)
+ data[data_type] = torch.cat(data[data_type], dim=0)
+
+ if not self.is_video_dataset:
+ # Remove any extra dimensions.
+ for data_type in self.image_data_types:
+ if data_type in data:
+ data[data_type] = data[data_type].squeeze(0)
+
+ # Prepare output.
+ data['driving_images'] = data['videos'][self.few_shot_K:]
+ data['source_images'] = data['videos'][:self.few_shot_K]
+ data.pop('videos')
+ data['is_flipped'] = is_flipped
+ data['key'] = keys
+ data['original_h_w'] = torch.IntTensor([
+ self.augmentor.original_h, self.augmentor.original_w])
+
+ # Apply full data ops.
+ data = self.apply_ops(data, self.full_data_ops, full_data=True)
+
+ return data
+
+ def __getitem__(self, index):
+ return self._getitem(index)
diff --git a/imaginaire/datasets/paired_images.py b/imaginaire/datasets/paired_images.py
new file mode 100644
index 0000000000000000000000000000000000000000..b3112a68de57e7c85d031ee690b62ac1fbc4f96d
--- /dev/null
+++ b/imaginaire/datasets/paired_images.py
@@ -0,0 +1,87 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+
+from imaginaire.datasets.paired_videos import Dataset as VideoDataset
+
+
+class Dataset(VideoDataset):
+ r"""Paired image dataset for use in pix2pixHD, SPADE.
+
+ Args:
+ cfg (Config): Loaded config object.
+ is_inference (bool): In train or inference mode?
+ """
+
+ def __init__(self, cfg, is_inference=False, is_test=False):
+ self.paired = True
+ super(Dataset, self).__init__(cfg, is_inference,
+ sequence_length=1,
+ is_test=is_test)
+ self.is_video_dataset = False
+
+ def _create_mapping(self):
+ r"""Creates mapping from idx to key in LMDB.
+
+ Returns:
+ (tuple):
+ - self.mapping (list): List mapping idx to key.
+ - self.epoch_length (int): Number of samples in an epoch.
+ """
+ idx_to_key = []
+ for lmdb_idx, sequence_list in enumerate(self.sequence_lists):
+ for sequence_name, filenames in sequence_list.items():
+ for filename in filenames:
+ idx_to_key.append({
+ 'lmdb_root': self.lmdb_roots[lmdb_idx],
+ 'lmdb_idx': lmdb_idx,
+ 'sequence_name': sequence_name,
+ 'filenames': [filename],
+ })
+ self.mapping = idx_to_key
+ self.epoch_length = len(self.mapping)
+ return self.mapping, self.epoch_length
+
+ def _sample_keys(self, index):
+ r"""Gets files to load for this sample.
+
+ Args:
+ index (int): Index in [0, len(dataset)].
+ Returns:
+ key (dict):
+ - lmdb_idx (int): Chosen LMDB dataset root.
+ - sequence_name (str): Chosen sequence in chosen dataset.
+ - filenames (list of str): Chosen filenames in chosen sequence.
+ """
+ assert self.sequence_length == 1, \
+ 'Image dataset can only have sequence length = 1, not %d' % (
+ self.sequence_length)
+ return self.mapping[index]
+
+ def set_sequence_length(self, sequence_length):
+ r"""Set the length of sequence you want as output from dataloader.
+ Ignore this as this is an image loader.
+
+ Args:
+ sequence_length (int): Length of output sequences.
+ """
+ pass
+
+ def set_inference_sequence_idx(self, index):
+ r"""Get frames from this sequence during inference.
+ Overridden from super as this is not applicable for images.
+
+ Args:
+ index (int): Index of inference sequence.
+ """
+ raise RuntimeError('Image dataset does not have sequences.')
+
+ def num_inference_sequences(self):
+ r"""Number of sequences available for inference.
+ Overridden from super as this is not applicable for images.
+
+ Returns:
+ (int)
+ """
+ raise RuntimeError('Image dataset does not have sequences.')
diff --git a/imaginaire/datasets/paired_videos.py b/imaginaire/datasets/paired_videos.py
new file mode 100644
index 0000000000000000000000000000000000000000..38e5c645f9de00184fa6cf7cc0b6b910c5815417
--- /dev/null
+++ b/imaginaire/datasets/paired_videos.py
@@ -0,0 +1,288 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import copy
+import random
+from collections import OrderedDict
+
+import torch
+
+from imaginaire.datasets.base import BaseDataset
+from imaginaire.model_utils.fs_vid2vid import select_object
+from imaginaire.utils.distributed import master_only_print as print
+
+
+class Dataset(BaseDataset):
+ r"""Paired video dataset for use in vid2vid, wc_vid2vid.
+
+ Args:
+ cfg (Config): Loaded config object.
+ is_inference (bool): In train or inference mode?
+ sequence_length (int): What sequence of images to provide?
+ """
+
+ def __init__(self, cfg,
+ is_inference=False,
+ sequence_length=None,
+ is_test=False):
+ self.paired = True
+ # Get initial sequence length.
+ if sequence_length is None and not is_inference:
+ self.sequence_length = cfg.data.train.initial_sequence_length
+ elif sequence_length is None and is_inference:
+ self.sequence_length = 2
+ else:
+ self.sequence_length = sequence_length
+ super(Dataset, self).__init__(cfg, is_inference, is_test)
+ self.set_sequence_length(self.sequence_length)
+ self.is_video_dataset = True
+
+ def get_label_lengths(self):
+ r"""Get num channels of all labels to be concated.
+
+ Returns:
+ label_lengths (OrderedDict): Dict mapping image data_type to num
+ channels.
+ """
+ label_lengths = OrderedDict()
+ for data_type in self.input_labels:
+ data_cfg = self.cfgdata
+ if hasattr(data_cfg, 'one_hot_num_classes') and data_type in data_cfg.one_hot_num_classes:
+ label_lengths[data_type] = data_cfg.one_hot_num_classes[data_type]
+ if getattr(data_cfg, 'use_dont_care', False):
+ label_lengths[data_type] += 1
+ else:
+ label_lengths[data_type] = self.num_channels[data_type]
+ return label_lengths
+
+ def num_inference_sequences(self):
+ r"""Number of sequences available for inference.
+
+ Returns:
+ (int)
+ """
+ assert self.is_inference
+ return len(self.mapping)
+
+ def set_inference_sequence_idx(self, index):
+ r"""Get frames from this sequence during inference.
+
+ Args:
+ index (int): Index of inference sequence.
+ """
+ assert self.is_inference
+ assert index < len(self.mapping)
+ self.inference_sequence_idx = index
+ self.epoch_length = len(
+ self.mapping[self.inference_sequence_idx]['filenames'])
+
+ def set_sequence_length(self, sequence_length):
+ r"""Set the length of sequence you want as output from dataloader.
+
+ Args:
+ sequence_length (int): Length of output sequences.
+ """
+ assert isinstance(sequence_length, int)
+ if sequence_length > self.sequence_length_max:
+ print('Requested sequence length (%d) > ' % (sequence_length) +
+ 'max sequence length (%d). ' % (self.sequence_length_max) +
+ 'Limiting sequence length to max sequence length.')
+ sequence_length = self.sequence_length_max
+ self.sequence_length = sequence_length
+ # Recalculate mapping as some sequences might no longer be useful.
+ self.mapping, self.epoch_length = self._create_mapping()
+ print('Epoch length:', self.epoch_length)
+
+ def _compute_dataset_stats(self):
+ r"""Compute statistics of video sequence dataset.
+
+ Returns:
+ sequence_length_max (int): Maximum sequence length.
+ """
+ print('Num datasets:', len(self.sequence_lists))
+
+ if self.sequence_length >= 1:
+ num_sequences, sequence_length_max = 0, 0
+ for sequence in self.sequence_lists:
+ for _, filenames in sequence.items():
+ sequence_length_max = max(
+ sequence_length_max, len(filenames))
+ num_sequences += 1
+ print('Num sequences:', num_sequences)
+ print('Max sequence length:', sequence_length_max)
+ self.sequence_length_max = sequence_length_max
+
+ def _create_mapping(self):
+ r"""Creates mapping from idx to key in LMDB.
+
+ Returns:
+ (tuple):
+ - self.mapping (dict): Dict of seq_len to list of sequences.
+ - self.epoch_length (int): Number of samples in an epoch.
+ """
+ # Create dict mapping length to sequence.
+ length_to_key, num_selected_seq = {}, 0
+ total_num_of_frames = 0
+ for lmdb_idx, sequence_list in enumerate(self.sequence_lists):
+ for sequence_name, filenames in sequence_list.items():
+ if len(filenames) >= self.sequence_length:
+ total_num_of_frames += len(filenames)
+ if len(filenames) not in length_to_key:
+ length_to_key[len(filenames)] = []
+ length_to_key[len(filenames)].append({
+ 'lmdb_root': self.lmdb_roots[lmdb_idx],
+ 'lmdb_idx': lmdb_idx,
+ 'sequence_name': sequence_name,
+ 'filenames': filenames,
+ })
+ num_selected_seq += 1
+ self.mapping = length_to_key
+ self.epoch_length = num_selected_seq
+ if not self.is_inference and self.epoch_length < \
+ self.cfgdata.train.batch_size * 8:
+ self.epoch_length = total_num_of_frames
+
+ # At inference time, we want to use all sequences,
+ # irrespective of length.
+ if self.is_inference:
+ sequence_list = []
+ for key, sequences in self.mapping.items():
+ sequence_list.extend(sequences)
+ self.mapping = sequence_list
+
+ return self.mapping, self.epoch_length
+
+ def _sample_keys(self, index):
+ r"""Gets files to load for this sample.
+
+ Args:
+ index (int): Index in [0, len(dataset)].
+ Returns:
+ key (dict):
+ - lmdb_idx (int): Chosen LMDB dataset root.
+ - sequence_name (str): Chosen sequence in chosen dataset.
+ - filenames (list of str): Chosen filenames in chosen sequence.
+ """
+ if self.is_inference:
+ assert index < self.epoch_length
+ chosen_sequence = self.mapping[self.inference_sequence_idx]
+ chosen_filenames = [chosen_sequence['filenames'][index]]
+ else:
+ # Pick a time step for temporal augmentation.
+ time_step = random.randint(1, self.augmentor.max_time_step)
+ required_sequence_length = 1 + \
+ (self.sequence_length - 1) * time_step
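+ # e.g. sequence_length=3 with time_step=2 needs 1 + 2*2 = 5 consecutive
+ # frames, from which every 2nd frame is taken below.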
+
+ # If step is too large, default to step size of 1.
+ if required_sequence_length > self.sequence_length_max:
+ required_sequence_length = self.sequence_length
+ time_step = 1
+
+ # Find valid sequences.
+ valid_sequences = []
+ for sequence_length, sequences in self.mapping.items():
+ if sequence_length >= required_sequence_length:
+ valid_sequences.extend(sequences)
+
+ # Pick a sequence.
+ chosen_sequence = random.choice(valid_sequences)
+
+ # Choose filenames.
+ max_start_idx = len(chosen_sequence['filenames']) - \
+ required_sequence_length
+ start_idx = random.randint(0, max_start_idx)
+
+ chosen_filenames = chosen_sequence['filenames'][
+ start_idx:start_idx + required_sequence_length:time_step]
+ assert len(chosen_filenames) == self.sequence_length
+
+ # Prepare output key.
+ key = copy.deepcopy(chosen_sequence)
+ key['filenames'] = chosen_filenames
+ return key
+
+ def _create_sequence_keys(self, sequence_name, filenames):
+ r"""Create the LMDB key for this piece of information.
+
+ Args:
+ sequence_name (str): Which sequence from the chosen dataset.
+ filenames (list of str): List of filenames in this sequence.
+ Returns:
+ keys (list): List of full keys.
+ """
+ assert isinstance(filenames, list), 'Filenames should be a list.'
+ keys = []
+ if sequence_name.endswith('___') and sequence_name[-9:-6] == '___':
+ sequence_name = sequence_name[:-9]
+ for filename in filenames:
+ keys.append('%s/%s' % (sequence_name, filename))
+ return keys
+
+ def _getitem(self, index):
+ r"""Gets selected files.
+
+ Args:
+ index (int): Index into dataset.
+ Returns:
+ data (dict): Dict with all chosen data_types.
+ """
+ # Select a sample from the available data.
+ keys = self._sample_keys(index)
+
+ # Unpack keys.
+ lmdb_idx = keys['lmdb_idx']
+ sequence_name = keys['sequence_name']
+ filenames = keys['filenames']
+
+ # Get key and lmdbs.
+ keys, lmdbs = {}, {}
+ for data_type in self.dataset_data_types:
+ keys[data_type] = self._create_sequence_keys(
+ sequence_name, filenames)
+ lmdbs[data_type] = self.lmdbs[data_type][lmdb_idx]
+
+ # Load all data for this index.
+ data = self.load_from_dataset(keys, lmdbs)
+
+ # Apply ops pre augmentation.
+ data = self.apply_ops(data, self.pre_aug_ops)
+
+ # If multiple subjects exist in the data, only pick one to synthesize.
+ data = select_object(data, obj_indices=None)
+
+ # Do augmentations for images.
+ data, is_flipped = self.perform_augmentation(data, paired=True, augment_ops=self.augmentor.augment_ops)
+
+ # Apply ops post augmentation.
+ data = self.apply_ops(data, self.post_aug_ops)
+ data = self.apply_ops(data, self.full_data_post_aug_ops, full_data=True)
+
+ # Convert images to tensor.
+ data = self.to_tensor(data)
+
+ # Pack the sequence of images.
+ for data_type in self.image_data_types + self.hdr_image_data_types:
+ for idx in range(len(data[data_type])):
+ data[data_type][idx] = data[data_type][idx].unsqueeze(0)
+ data[data_type] = torch.cat(data[data_type], dim=0)
+
+ if not self.is_video_dataset:
+ # Remove any extra dimensions.
+ for data_type in self.data_types:
+ if data_type in data:
+ data[data_type] = data[data_type].squeeze(0)
+
+ data['is_flipped'] = is_flipped
+ data['key'] = keys
+ data['original_h_w'] = torch.IntTensor([
+ self.augmentor.original_h, self.augmentor.original_w])
+
+ # Apply full data ops.
+ data = self.apply_ops(data, self.full_data_ops, full_data=True)
+
+ return data
+
+ def __getitem__(self, index):
+ return self._getitem(index)
diff --git a/imaginaire/datasets/unpaired_few_shot_images.py b/imaginaire/datasets/unpaired_few_shot_images.py
new file mode 100644
index 0000000000000000000000000000000000000000..e8fa1e5ec2cd528c72effd065338fa693d1319c8
--- /dev/null
+++ b/imaginaire/datasets/unpaired_few_shot_images.py
@@ -0,0 +1,182 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import random
+
+from imaginaire.datasets.base import BaseDataset
+
+
+class Dataset(BaseDataset):
+ r"""Image dataset for use in FUNIT.
+
+ Args:
+ cfg (Config): Loaded config object.
+ is_inference (bool): In train or inference mode?
+ """
+
+ def __init__(self, cfg, is_inference=False, is_test=False):
+ self.paired = False
+ super(Dataset, self).__init__(cfg, is_inference, is_test)
+ self.num_content_classes = len(self.class_name_to_idx['images_content'])
+ self.num_style_classes = len(self.class_name_to_idx['images_style'])
+ self.sample_class_idx = None
+ self.content_offset = 8888
+ self.content_interval = 100
+
+ def set_sample_class_idx(self, class_idx=None):
+ r"""Set sample class idx.
+
+ Args:
+ class_idx (int): Which class idx to sample from.
+ """
+ self.sample_class_idx = class_idx
+ if class_idx is None:
+ self.epoch_length = \
+ max([len(lmdb_keys) for _, lmdb_keys in self.mapping.items()])
+ else:
+ self.epoch_length = \
+ len(self.mapping_class['images_style'][class_idx])
+
+ def _create_mapping(self):
+ r"""Creates mapping from idx to key in LMDB.
+
+ Returns:
+ (tuple):
+ - self.mapping (dict): Dict with data type as key mapping idx to
+ LMDB key.
+ - self.epoch_length (int): Number of samples in an epoch.
+ """
+ idx_to_key, class_names = {}, {}
+ for lmdb_idx, sequence_list in enumerate(self.sequence_lists):
+ for data_type, data_type_sequence_list in sequence_list.items():
+ if data_type not in class_names:
+ class_names[data_type] = []
+ if data_type not in idx_to_key:
+ idx_to_key[data_type] = []
+ for sequence_name, filenames in data_type_sequence_list.items():
+ class_name = sequence_name.split('/')[0]
+ for filename in filenames:
+ idx_to_key[data_type].append({
+ 'lmdb_root': self.lmdb_roots[lmdb_idx],
+ 'lmdb_idx': lmdb_idx,
+ 'sequence_name': sequence_name,
+ 'filename': filename,
+ 'class_name': class_name
+ })
+ class_names[data_type].append(class_name)
+ self.mapping = idx_to_key
+ self.epoch_length = max([len(lmdb_keys)
+ for _, lmdb_keys in self.mapping.items()])
+
+ # Create mapping from class name to class idx.
+ self.class_name_to_idx = {}
+ for data_type, class_names_data_type in class_names.items():
+ self.class_name_to_idx[data_type] = {}
+ class_names_data_type = sorted(list(set(class_names_data_type)))
+ for class_idx, class_name in enumerate(class_names_data_type):
+ self.class_name_to_idx[data_type][class_name] = class_idx
+
+ # Add class idx to mapping.
+ for data_type in self.mapping:
+ for key in self.mapping[data_type]:
+ key['class_idx'] = \
+ self.class_name_to_idx[data_type][key['class_name']]
+
+ # Create a mapping from index to lmdb key for each class.
+ idx_to_key_class = {}
+ for data_type in self.mapping:
+ idx_to_key_class[data_type] = {}
+ for class_idx, class_name in enumerate(class_names[data_type]):
+ idx_to_key_class[data_type][class_idx] = []
+ for key in self.mapping[data_type]:
+ idx_to_key_class[data_type][key['class_idx']].append(key)
+ self.mapping_class = idx_to_key_class
+
+ return self.mapping, self.epoch_length
+
+ def _sample_keys(self, index):
+ r"""Gets files to load for this sample.
+
+ Args:
+ index (int): Index in [0, len(dataset)].
+ Returns:
+ (tuple):
+ - keys (dict): Each key of this dict is a data type.
+ - lmdb_key (dict):
+ - lmdb_idx (int): Chosen LMDB dataset root.
+ - sequence_name (str): Chosen sequence in chosen dataset.
+ - filename (str): Chosen filename in chosen sequence.
+ """
+
+ keys = {}
+ if self.is_inference: # evaluation mode
+ lmdb_keys_content = self.mapping['images_content']
+ keys['images_content'] = \
+ lmdb_keys_content[
+ ((index + self.content_offset * self.sample_class_idx) *
+ self.content_interval) % len(lmdb_keys_content)]
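+ # The content_offset/content_interval constants deterministically spread the
+ # content picks, with a different starting point per style class, so that
+ # evaluation pairs are reproducible.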
+
+ lmdb_keys_style = \
+ self.mapping_class['images_style'][self.sample_class_idx]
+ keys['images_style'] = lmdb_keys_style[index]
+ else:
+ lmdb_keys_content = self.mapping['images_content']
+ lmdb_keys_style = self.mapping['images_style']
+ keys['images_content'] = random.choice(lmdb_keys_content)
+ keys['images_style'] = random.choice(lmdb_keys_style)
+ return keys
+
+ def __getitem__(self, index):
+ r"""Gets selected files.
+
+ Args:
+ index (int): Index into dataset.
+ Returns:
+ data (dict): Dict with all chosen data_types.
+ """
+ # Select a sample from the available data.
+ keys_per_data_type = self._sample_keys(index)
+
+ # Get class idx into a list.
+ class_idxs = []
+ for data_type in keys_per_data_type:
+ class_idxs.append(keys_per_data_type[data_type]['class_idx'])
+
+ # Get keys and lmdbs.
+ keys, lmdbs = {}, {}
+ for data_type in self.dataset_data_types:
+ # Unpack keys.
+ lmdb_idx = keys_per_data_type[data_type]['lmdb_idx']
+ sequence_name = keys_per_data_type[data_type]['sequence_name']
+ filename = keys_per_data_type[data_type]['filename']
+ keys[data_type] = '%s/%s' % (sequence_name, filename)
+ lmdbs[data_type] = self.lmdbs[data_type][lmdb_idx]
+
+ # Load all data for this index.
+ data = self.load_from_dataset(keys, lmdbs)
+
+ # Apply ops pre augmentation.
+ data = self.apply_ops(data, self.pre_aug_ops)
+
+ # Do augmentations for images.
+ data, is_flipped = self.perform_augmentation(data, paired=False, augment_ops=self.augmentor.augment_ops)
+
+ # Apply ops post augmentation.
+ data = self.apply_ops(data, self.post_aug_ops)
+ data = self.apply_ops(data, self.full_data_post_aug_ops, full_data=True)
+
+ # Convert images to tensor.
+ data = self.to_tensor(data)
+
+ # Remove any extra dimensions.
+ for data_type in self.image_data_types:
+ data[data_type] = data[data_type][0]
+
+ # Package output.
+ data['is_flipped'] = is_flipped
+ data['key'] = keys_per_data_type
+ data['labels_content'] = class_idxs[0]
+ data['labels_style'] = class_idxs[1]
+
+ return data
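The class bookkeeping above derives a class name from the first path component of each sequence, de-duplicates and sorts the names, and enumerates them to obtain stable class indices. A minimal standalone sketch of that mapping (the sequence names below are made up):

```python
# Standalone sketch of the class-index assignment above (hypothetical names).
sequence_names = ['dog/seq1', 'cat/seq3', 'dog/seq2', 'bird/seq1']

# Class name = first path component; de-duplicate, sort, enumerate.
class_names = sorted({name.split('/')[0] for name in sequence_names})
class_name_to_idx = {name: idx for idx, name in enumerate(class_names)}
print(class_name_to_idx)  # {'bird': 0, 'cat': 1, 'dog': 2}
```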
diff --git a/imaginaire/datasets/unpaired_images.py b/imaginaire/datasets/unpaired_images.py
new file mode 100644
index 0000000000000000000000000000000000000000..68a49a876705771a2d7a3f836bdbee2cdd328c10
--- /dev/null
+++ b/imaginaire/datasets/unpaired_images.py
@@ -0,0 +1,118 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import random
+
+from imaginaire.datasets.base import BaseDataset
+
+
+class Dataset(BaseDataset):
+ r"""Unpaired image dataset for use in MUNIT.
+
+ Args:
+ cfg (Config): Loaded config object.
+ is_inference (bool): In train or inference mode?
+ """
+
+ def __init__(self, cfg, is_inference=False, is_test=False):
+ self.paired = False
+ super(Dataset, self).__init__(cfg, is_inference, is_test)
+
+ def _create_mapping(self):
+ r"""Creates mapping from idx to key in LMDB.
+
+ Returns:
+ (tuple):
+ - self.mapping (dict): Dict with data type as key mapping idx to
+ LMDB key.
+ - self.epoch_length (int): Number of samples in an epoch.
+ """
+ idx_to_key = {}
+ for lmdb_idx, sequence_list in enumerate(self.sequence_lists):
+ for data_type, data_type_sequence_list in sequence_list.items():
+ if data_type not in idx_to_key:
+ idx_to_key[data_type] = []
+ for sequence_name, filenames in data_type_sequence_list.items():
+ for filename in filenames:
+ idx_to_key[data_type].append({
+ 'lmdb_root': self.lmdb_roots[lmdb_idx],
+ 'lmdb_idx': lmdb_idx,
+ 'sequence_name': sequence_name,
+ 'filename': filename,
+ })
+ self.mapping = idx_to_key
+ self.epoch_length = max([len(lmdb_keys)
+ for _, lmdb_keys in self.mapping.items()])
+ return self.mapping, self.epoch_length
+
+ def _sample_keys(self, index):
+ r"""Gets files to load for this sample.
+
+ Args:
+ index (int): Index in [0, len(dataset)].
+ Returns:
+ keys (dict): Each key of this dict is a data type.
+ lmdb_key (dict):
+ lmdb_idx (int): Chosen LMDB dataset root.
+ sequence_name (str): Chosen sequence in chosen dataset.
+ filename (str): Chosen filename in chosen sequence.
+ """
+ keys = {}
+ for data_type in self.data_types:
+ lmdb_keys = self.mapping[data_type]
+ if self.is_inference:
+ # Modulo ensures valid indexing in case A and B have different
+ # number of files.
+ keys[data_type] = lmdb_keys[index % len(lmdb_keys)]
+ else:
+ keys[data_type] = random.choice(lmdb_keys)
+ return keys
+
+ def __getitem__(self, index):
+ r"""Gets selected files.
+
+ Args:
+ index (int): Index into dataset.
+ Returns:
+ data (dict): Dict with all chosen data_types.
+ """
+ # Select a sample from the available data.
+ keys_per_data_type = self._sample_keys(index)
+
+ # Get keys and lmdbs.
+ keys, lmdbs = {}, {}
+ for data_type in self.dataset_data_types:
+ # Unpack keys.
+ lmdb_idx = keys_per_data_type[data_type]['lmdb_idx']
+ sequence_name = keys_per_data_type[data_type]['sequence_name']
+ filename = keys_per_data_type[data_type]['filename']
+ keys[data_type] = '%s/%s' % (sequence_name, filename)
+ lmdbs[data_type] = self.lmdbs[data_type][lmdb_idx]
+
+ # Load all data for this index.
+ data = self.load_from_dataset(keys, lmdbs)
+
+ # Apply ops pre augmentation.
+ data = self.apply_ops(data, self.pre_aug_ops)
+
+ # Do augmentations for images.
+ data, is_flipped = self.perform_augmentation(data, paired=False, augment_ops=self.augmentor.augment_ops)
+
+ # Apply ops post augmentation.
+ data = self.apply_ops(data, self.post_aug_ops)
+ data = self.apply_ops(data, self.full_data_post_aug_ops, full_data=True)
+
+ # Convert images to tensor.
+ data = self.to_tensor(data)
+
+ # Remove any extra dimensions.
+ for data_type in self.image_data_types:
+ data[data_type] = data[data_type][0]
+
+ # Package output.
+ data['is_flipped'] = is_flipped
+ data['key'] = keys_per_data_type
+
+ return data
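During inference the unpaired dataset walks both domains in lockstep and wraps the shorter one with a modulo, so one epoch covers the larger domain exactly once. A small sketch of that indexing; the data-type names and file counts are made up:

```python
# Sketch of the inference-time indexing in _sample_keys (hypothetical sizes):
# domain A has 4 files and domain B has 6, so epoch_length = max(4, 6) = 6
# and the shorter domain wraps around via the modulo.
keys_a = [f'a/{i:03d}.jpg' for i in range(4)]
keys_b = [f'b/{i:03d}.jpg' for i in range(6)]

epoch_length = max(len(keys_a), len(keys_b))
for index in range(epoch_length):
    sample = {'images_a': keys_a[index % len(keys_a)],
              'images_b': keys_b[index % len(keys_b)]}
    print(index, sample)
```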
diff --git a/imaginaire/discriminators/__init__.py b/imaginaire/discriminators/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..13acefe2181136b1629ec31f9d122fb46bf26780
--- /dev/null
+++ b/imaginaire/discriminators/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
diff --git a/imaginaire/discriminators/dummy.py b/imaginaire/discriminators/dummy.py
new file mode 100644
index 0000000000000000000000000000000000000000..a345806f6f844fbcf0c9da1915f4db5b2fa3d587
--- /dev/null
+++ b/imaginaire/discriminators/dummy.py
@@ -0,0 +1,29 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import torch.nn as nn
+
+from imaginaire.layers import LinearBlock
+
+
+class Discriminator(nn.Module):
+ """Dummy Discriminator constructor.
+
+ Args:
+ dis_cfg (obj): Discriminator definition part of the yaml config file.
+ data_cfg (obj): Data definition part of the yaml config file
+ """
+
+ def __init__(self, dis_cfg, data_cfg):
+ super(Discriminator, self).__init__()
+ self.dummy_layer = LinearBlock(1, 1)
+ pass
+
+ def forward(self, data):
+ """Dummy discriminator forward.
+
+ Args:
+ data (dict):
+ """
+ return
diff --git a/imaginaire/discriminators/fpse.py b/imaginaire/discriminators/fpse.py
new file mode 100644
index 0000000000000000000000000000000000000000..231b666bfd3970c760df7ce5d1174193ad9a7708
--- /dev/null
+++ b/imaginaire/discriminators/fpse.py
@@ -0,0 +1,132 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import functools
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from imaginaire.layers import Conv2dBlock
+
+
+class FPSEDiscriminator(nn.Module):
+ r"""Feature-Pyramid Semantics Embedding Discriminator. This is a copy
+ of the discriminator in https://arxiv.org/pdf/1910.06809.pdf.
+ """
+
+ def __init__(self,
+ num_input_channels,
+ num_labels,
+ num_filters,
+ kernel_size,
+ weight_norm_type,
+ activation_norm_type):
+ super().__init__()
+ padding = int(np.ceil((kernel_size - 1.0) / 2))
+ nonlinearity = 'leakyrelu'
+ stride1_conv2d_block = \
+ functools.partial(Conv2dBlock,
+ kernel_size=kernel_size,
+ stride=1,
+ padding=padding,
+ weight_norm_type=weight_norm_type,
+ activation_norm_type=activation_norm_type,
+ nonlinearity=nonlinearity,
+ # inplace_nonlinearity=True,
+ order='CNA')
+ down_conv2d_block = \
+ functools.partial(Conv2dBlock,
+ kernel_size=kernel_size,
+ stride=2,
+ padding=padding,
+ weight_norm_type=weight_norm_type,
+ activation_norm_type=activation_norm_type,
+ nonlinearity=nonlinearity,
+ # inplace_nonlinearity=True,
+ order='CNA')
+ latent_conv2d_block = \
+ functools.partial(Conv2dBlock,
+ kernel_size=1,
+ stride=1,
+ weight_norm_type=weight_norm_type,
+ activation_norm_type=activation_norm_type,
+ nonlinearity=nonlinearity,
+ # inplace_nonlinearity=True,
+ order='CNA')
+ # bottom-up pathway
+
+ self.enc1 = down_conv2d_block(num_input_channels, num_filters)
+ self.enc2 = down_conv2d_block(1 * num_filters, 2 * num_filters)
+ self.enc3 = down_conv2d_block(2 * num_filters, 4 * num_filters)
+ self.enc4 = down_conv2d_block(4 * num_filters, 8 * num_filters)
+ self.enc5 = down_conv2d_block(8 * num_filters, 8 * num_filters)
+
+ # top-down pathway
+ self.lat2 = latent_conv2d_block(2 * num_filters, 4 * num_filters)
+ self.lat3 = latent_conv2d_block(4 * num_filters, 4 * num_filters)
+ self.lat4 = latent_conv2d_block(8 * num_filters, 4 * num_filters)
+ self.lat5 = latent_conv2d_block(8 * num_filters, 4 * num_filters)
+
+ # upsampling
+ self.upsample2x = nn.Upsample(scale_factor=2, mode='bilinear',
+ align_corners=False)
+
+ # final layers
+ self.final2 = stride1_conv2d_block(4 * num_filters, 2 * num_filters)
+ self.final3 = stride1_conv2d_block(4 * num_filters, 2 * num_filters)
+ self.final4 = stride1_conv2d_block(4 * num_filters, 2 * num_filters)
+
+ # true/false prediction and semantic alignment prediction
+ self.output = Conv2dBlock(num_filters * 2, 1, kernel_size=1)
+ self.seg = Conv2dBlock(num_filters * 2, num_filters * 2, kernel_size=1)
+ self.embedding = Conv2dBlock(num_labels, num_filters * 2, kernel_size=1)
+
+ def forward(self, images, segmaps):
+ r"""FPSE discriminator forward.
+
+ Args:
+ images: image tensors.
+ segmaps: segmentation map tensors.
+ """
+ # bottom-up pathway
+ feat11 = self.enc1(images)
+ feat12 = self.enc2(feat11)
+ feat13 = self.enc3(feat12)
+ feat14 = self.enc4(feat13)
+ feat15 = self.enc5(feat14)
+ # top-down pathway and lateral connections
+ feat25 = self.lat5(feat15)
+ feat24 = self.upsample2x(feat25) + self.lat4(feat14)
+ feat23 = self.upsample2x(feat24) + self.lat3(feat13)
+ feat22 = self.upsample2x(feat23) + self.lat2(feat12)
+ # final prediction layers
+ feat32 = self.final2(feat22)
+ feat33 = self.final3(feat23)
+ feat34 = self.final4(feat24)
+ # Patch-based True/False prediction
+ pred2 = self.output(feat32)
+ pred3 = self.output(feat33)
+ pred4 = self.output(feat34)
+ seg2 = self.seg(feat32)
+ seg3 = self.seg(feat33)
+ seg4 = self.seg(feat34)
+
+ # # segmentation map embedding
+ segembs = self.embedding(segmaps)
+ segembs = F.avg_pool2d(segembs, kernel_size=2, stride=2)
+ segembs2 = F.avg_pool2d(segembs, kernel_size=2, stride=2)
+ segembs3 = F.avg_pool2d(segembs2, kernel_size=2, stride=2)
+ segembs4 = F.avg_pool2d(segembs3, kernel_size=2, stride=2)
+
+ # semantics embedding discriminator score
+ pred2 += torch.mul(segembs2, seg2).sum(dim=1, keepdim=True)
+ pred3 += torch.mul(segembs3, seg3).sum(dim=1, keepdim=True)
+ pred4 += torch.mul(segembs4, seg4).sum(dim=1, keepdim=True)
+
+ # concat results from multiple resolutions
+ # results = [pred2, pred3, pred4]
+
+ return pred2, pred3, pred4
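The FPSE discriminator above emits patch predictions at three pyramid levels (strides 4, 8 and 16 of the input) and adds a projection of the segmentation embedding to each. A shape check, assuming `imaginaire` is importable; the channel counts and image size are only illustrative:

```python
import torch
from imaginaire.discriminators.fpse import FPSEDiscriminator

# Illustrative sizes: RGB images, 10 semantic classes, 256x256 inputs.
net_D = FPSEDiscriminator(num_input_channels=3, num_labels=10, num_filters=64,
                          kernel_size=3, weight_norm_type='spectral',
                          activation_norm_type='none')

images = torch.randn(2, 3, 256, 256)
segmaps = torch.randn(2, 10, 256, 256)  # (soft) one-hot label maps
pred2, pred3, pred4 = net_D(images, segmaps)
print(pred2.shape, pred3.shape, pred4.shape)
# -> (2, 1, 64, 64), (2, 1, 32, 32), (2, 1, 16, 16) for a 256x256 input
```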
diff --git a/imaginaire/discriminators/fs_vid2vid.py b/imaginaire/discriminators/fs_vid2vid.py
new file mode 100644
index 0000000000000000000000000000000000000000..78e29f64a301864f03238db7bd1ad444bee9773e
--- /dev/null
+++ b/imaginaire/discriminators/fs_vid2vid.py
@@ -0,0 +1,318 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import importlib
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from imaginaire.discriminators.multires_patch import NLayerPatchDiscriminator
+from imaginaire.model_utils.fs_vid2vid import get_fg_mask, pick_image
+from imaginaire.utils.data import (get_paired_input_image_channel_number,
+ get_paired_input_label_channel_number)
+from imaginaire.utils.misc import get_nested_attr
+
+
+class Discriminator(nn.Module):
+ r"""Image and video discriminator constructor.
+
+ Args:
+ dis_cfg (obj): Discriminator part of the yaml config file.
+ data_cfg (obj): Data definition part of the yaml config file
+ """
+
+ def __init__(self, dis_cfg, data_cfg):
+ super().__init__()
+ self.data_cfg = data_cfg
+ num_input_channels = get_paired_input_label_channel_number(data_cfg)
+ if num_input_channels == 0:
+ num_input_channels = getattr(data_cfg, 'label_channels', 1)
+ num_img_channels = get_paired_input_image_channel_number(data_cfg)
+ self.num_frames_D = data_cfg.num_frames_D
+ self.num_scales = get_nested_attr(dis_cfg, 'temporal.num_scales', 0)
+ num_netD_input_channels = (num_input_channels + num_img_channels)
+ self.use_few_shot = 'few_shot' in data_cfg.type
+ if self.use_few_shot:
+ num_netD_input_channels *= 2
+ self.net_D = MultiPatchDiscriminator(dis_cfg.image,
+ num_netD_input_channels)
+
+ self.add_dis_cfg = getattr(dis_cfg, 'additional_discriminators', None)
+ if self.add_dis_cfg is not None:
+ for name in self.add_dis_cfg:
+ add_dis_cfg = self.add_dis_cfg[name]
+ num_ch = num_img_channels * (2 if self.use_few_shot else 1)
+ setattr(self, 'net_D_' + name,
+ MultiPatchDiscriminator(add_dis_cfg, num_ch))
+
+ # Temporal discriminator.
+ self.num_netDT_input_channels = num_img_channels * self.num_frames_D
+ for n in range(self.num_scales):
+ setattr(self, 'net_DT%d' % n,
+ MultiPatchDiscriminator(dis_cfg.temporal,
+ self.num_netDT_input_channels))
+ self.has_fg = getattr(data_cfg, 'has_foreground', False)
+
+ def forward(self, data, net_G_output, past_frames):
+ r"""Discriminator forward.
+
+ Args:
+ data (dict): Input data.
+ net_G_output (dict): Generator output.
+ past_frames (list of tensors): Past real frames / generator outputs.
+ Returns:
+ (tuple):
+ - output (dict): Discriminator output.
+ - past_frames (list of tensors): New past frames by adding
+ current outputs.
+ """
+ label, real_image = data['label'], data['image']
+ # Only operate on the latest output frame.
+ if label.dim() == 5:
+ label = label[:, -1]
+ if self.use_few_shot:
+ # Pick only one reference image to concat with.
+ ref_idx = net_G_output['ref_idx'] \
+ if 'ref_idx' in net_G_output else 0
+ ref_label = pick_image(data['ref_labels'], ref_idx)
+ ref_image = pick_image(data['ref_images'], ref_idx)
+ # Concat references with label map as discriminator input.
+ label = torch.cat([label, ref_label, ref_image], dim=1)
+ fake_image = net_G_output['fake_images']
+ output = dict()
+
+ # Individual frame loss.
+ pred_real, pred_fake = self.discrminate_image(self.net_D, label,
+ real_image, fake_image)
+ output['indv'] = dict()
+ output['indv']['pred_real'] = pred_real
+ output['indv']['pred_fake'] = pred_fake
+
+ if 'fake_raw_images' in net_G_output and \
+ net_G_output['fake_raw_images'] is not None:
+ # Raw generator output loss.
+ fake_raw_image = net_G_output['fake_raw_images']
+ fg_mask = get_fg_mask(data['label'], self.has_fg)
+ pred_real, pred_fake = self.discrminate_image(
+ self.net_D, label,
+ real_image * fg_mask,
+ fake_raw_image * fg_mask)
+ output['raw'] = dict()
+ output['raw']['pred_real'] = pred_real
+ output['raw']['pred_fake'] = pred_fake
+
+ # Additional GAN loss on specific regions.
+ if self.add_dis_cfg is not None:
+ for name in self.add_dis_cfg:
+ # Crop corresponding regions in the image according to the
+ # crop function.
+ add_dis_cfg = self.add_dis_cfg[name]
+ file, crop_func = add_dis_cfg.crop_func.split('::')
+ file = importlib.import_module(file)
+ crop_func = getattr(file, crop_func)
+
+ real_crop = crop_func(self.data_cfg, real_image, label)
+ fake_crop = crop_func(self.data_cfg, fake_image, label)
+ if self.use_few_shot:
+ ref_crop = crop_func(self.data_cfg, ref_image, label)
+ if ref_crop is not None:
+ real_crop = torch.cat([real_crop, ref_crop], dim=1)
+ fake_crop = torch.cat([fake_crop, ref_crop], dim=1)
+
+ # Feed the crops to specific discriminator.
+ if fake_crop is not None:
+ net_D = getattr(self, 'net_D_' + name)
+ pred_real, pred_fake = \
+ self.discrminate_image(net_D, None,
+ real_crop, fake_crop)
+ else:
+ pred_real = pred_fake = None
+ output[name] = dict()
+ output[name]['pred_real'] = pred_real
+ output[name]['pred_fake'] = pred_fake
+
+ # Temporal loss.
+ past_frames, skipped_frames = \
+ get_all_skipped_frames(past_frames, [real_image, fake_image],
+ self.num_scales, self.num_frames_D)
+
+ for scale in range(self.num_scales):
+ real_image, fake_image = \
+ [skipped_frame[scale] for skipped_frame in skipped_frames]
+ pred_real, pred_fake = self.discriminate_video(real_image,
+ fake_image, scale)
+ output['temporal_%d' % scale] = dict()
+ output['temporal_%d' % scale]['pred_real'] = pred_real
+ output['temporal_%d' % scale]['pred_fake'] = pred_fake
+
+ return output, past_frames
+
+ def discrminate_image(self, net_D, real_A, real_B, fake_B):
+ r"""Discriminate individual images.
+
+ Args:
+ net_D (obj): Discriminator network.
+ real_A (NxC1xHxW tensor): Input label map.
+ real_B (NxC2xHxW tensor): Real image.
+ fake_B (NxC2xHxW tensor): Fake image.
+ Returns:
+ (tuple):
+ - pred_real (NxC3xH2xW2 tensor): Output of net_D for real images.
+ - pred_fake (NxC3xH2xW2 tensor): Output of net_D for fake images.
+ """
+ if real_A is not None:
+ real_AB = torch.cat([real_A, real_B], dim=1)
+ fake_AB = torch.cat([real_A, fake_B], dim=1)
+ else:
+ real_AB, fake_AB = real_B, fake_B
+
+ pred_real = net_D.forward(real_AB)
+ pred_fake = net_D.forward(fake_AB)
+ return pred_real, pred_fake
+
+ def discriminate_video(self, real_B, fake_B, scale):
+ r"""Discriminate a sequence of images.
+
+ Args:
+ real_B (NxCxHxW tensor): Real image.
+ fake_B (NxCxHxW tensor): Fake image.
+ scale (int): Temporal scale.
+ Returns:
+ (tuple):
+ - pred_real (NxC2xH2xW2 tensor): Output of net_D for real images.
+ - pred_fake (NxC2xH2xW2 tensor): Output of net_D for fake images.
+ """
+ if real_B is None:
+ return None, None
+ net_DT = getattr(self, 'net_DT%d' % scale)
+ height, width = real_B.shape[-2:]
+ real_B = real_B.view(-1, self.num_netDT_input_channels, height, width)
+ fake_B = fake_B.view(-1, self.num_netDT_input_channels, height, width)
+
+ pred_real = net_DT.forward(real_B)
+ pred_fake = net_DT.forward(fake_B)
+ return pred_real, pred_fake
+
+
+def get_all_skipped_frames(past_frames, new_frames, t_scales, tD):
+ r"""Get temporally skipped frames from the input frames.
+
+ Args:
+ past_frames (list of tensors): Past real frames / generator outputs.
+ new_frames (list of tensors): Current real frame / generated output.
+ t_scales (int): Temporal scale.
+ tD (int): Number of frames as input to the temporal discriminator.
+ Returns:
+ (tuple):
+ - new_past_frames (list of tensors): Past + current frames.
+ - skipped_frames (list of tensors): Temporally skipped frames using
+ the given t_scales.
+ """
+ new_past_frames, skipped_frames = [], []
+ for past_frame, new_frame in zip(past_frames, new_frames):
+ skipped_frame = None
+ if t_scales > 0:
+ past_frame, skipped_frame = \
+ get_skipped_frames(past_frame, new_frame.unsqueeze(1),
+ t_scales, tD)
+ new_past_frames.append(past_frame)
+ skipped_frames.append(skipped_frame)
+ return new_past_frames, skipped_frames
+
+
+def get_skipped_frames(all_frames, frame, t_scales, tD):
+ r"""Get temporally skipped frames from the input frames.
+
+ Args:
+ all_frames (NxTxCxHxW tensor): All past frames.
+ frame (Nx1xCxHxW tensor): Current frame.
+ t_scales (int): Temporal scale.
+ tD (int): Number of frames as input to the temporal discriminator.
+ Returns:
+ (tuple):
+ - all_frames (NxTxCxHxW tensor): Past + current frames.
+ - skipped_frames (list of NxTxCxHxW tensors): Temporally skipped
+ frames.
+ """
+ all_frames = torch.cat([all_frames.detach(), frame], dim=1) \
+ if all_frames is not None else frame
+ skipped_frames = [None] * t_scales
+ for s in range(t_scales):
+ # Number of skipped frames between neighboring frames (e.g. 1, 3, 9,...)
+ t_step = tD ** s
+ # Number of frames the final triplet frames span before skipping
+ # (e.g., 2, 6, 18, ...).
+ t_span = t_step * (tD-1)
+ if all_frames.size(1) > t_span:
+ skipped_frames[s] = all_frames[:, -(t_span+1)::t_step].contiguous()
+
+ # Maximum number of past frames we need to keep track of.
+ max_num_prev_frames = (tD ** (t_scales-1)) * (tD-1)
+ # Remove past frames that are older than this number.
+ if all_frames.size()[1] > max_num_prev_frames:
+ all_frames = all_frames[:, -max_num_prev_frames:]
+ return all_frames, skipped_frames
+
+
+class MultiPatchDiscriminator(nn.Module):
+ r"""Multi-resolution patch discriminator.
+
+ Args:
+ dis_cfg (obj): Discriminator part of the yaml config file.
+ num_input_channels (int): Number of input channels.
+ """
+
+ def __init__(self, dis_cfg, num_input_channels):
+ super(MultiPatchDiscriminator, self).__init__()
+ kernel_size = getattr(dis_cfg, 'kernel_size', 4)
+ num_filters = getattr(dis_cfg, 'num_filters', 64)
+ max_num_filters = getattr(dis_cfg, 'max_num_filters', 512)
+ num_discriminators = getattr(dis_cfg, 'num_discriminators', 3)
+ num_layers = getattr(dis_cfg, 'num_layers', 3)
+ activation_norm_type = getattr(dis_cfg, 'activation_norm_type', 'none')
+ weight_norm_type = getattr(dis_cfg, 'weight_norm_type',
+ 'spectral_norm')
+ self.nets_discriminator = []
+ for i in range(num_discriminators):
+ net_discriminator = NLayerPatchDiscriminator(
+ kernel_size,
+ num_input_channels,
+ num_filters,
+ num_layers,
+ max_num_filters,
+ activation_norm_type,
+ weight_norm_type)
+ self.add_module('discriminator_%d' % i, net_discriminator)
+ self.nets_discriminator.append(net_discriminator)
+
+ def forward(self, input_x):
+ r"""Multi-resolution patch discriminator forward.
+
+ Args:
+ input_x (N x C x H x W tensor) : Concatenation of images and
+ semantic representations.
+ Returns:
+ (dict):
+ - output (list): list of output tensors produced by individual
+ patch discriminators.
+ - features (list): list of lists of features produced by
+ individual patch discriminators.
+ """
+ output_list = []
+ features_list = []
+ input_downsampled = input_x
+ for name, net_discriminator in self.named_children():
+ if not name.startswith('discriminator_'):
+ continue
+ output, features = net_discriminator(input_downsampled)
+ output_list.append(output)
+ features_list.append(features)
+ input_downsampled = F.interpolate(
+ input_downsampled, scale_factor=0.5, mode='bilinear',
+ align_corners=True, recompute_scale_factor=True)
+ output_x = dict()
+ output_x['output'] = output_list
+ output_x['features'] = features_list
+ return output_x
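`get_skipped_frames` keeps a rolling buffer of past frames and, at temporal scale `s`, extracts `tD` frames spaced `tD**s` steps apart, so higher scales judge longer-range temporal consistency. A small sketch with `tD=3` and `t_scales=2` (the tensor sizes are arbitrary):

```python
import torch
from imaginaire.discriminators.fs_vid2vid import get_skipped_frames

tD, t_scales = 3, 2
all_frames = None
for t in range(10):
    frame = torch.full((1, 1, 3, 4, 4), float(t))  # N x 1 x C x H x W
    all_frames, skipped = get_skipped_frames(all_frames, frame, t_scales, tD)

# Scale 0: the last 3 consecutive frames; scale 1: every 3rd frame over the
# last 7. The buffer itself is capped at tD**(t_scales-1) * (tD-1) = 6 frames.
print([s[0, :, 0, 0, 0].tolist() for s in skipped])
# -> [[7.0, 8.0, 9.0], [3.0, 6.0, 9.0]]
```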
diff --git a/imaginaire/discriminators/funit.py b/imaginaire/discriminators/funit.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e74238bf723058dfb680989a8839ce54bf98e61
--- /dev/null
+++ b/imaginaire/discriminators/funit.py
@@ -0,0 +1,117 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import warnings
+
+import torch
+from torch import nn
+
+from imaginaire.layers import Conv2dBlock, Res2dBlock
+
+
+class Discriminator(nn.Module):
+ r"""Discriminator in the improved FUNIT baseline in the COCO-FUNIT paper.
+
+ Args:
+ dis_cfg (obj): Discriminator definition part of the yaml config file.
+ data_cfg (obj): Data definition part of the yaml config file.
+ """
+
+ def __init__(self, dis_cfg, data_cfg):
+ super().__init__()
+ self.model = ResDiscriminator(**vars(dis_cfg))
+
+ def forward(self, data, net_G_output, recon=True):
+ r"""Improved FUNIT discriminator forward function.
+
+ Args:
+ data (dict): Training data at the current iteration.
+ net_G_output (dict): Fake data generated at the current iteration.
+ recon (bool): If ``True``, also classifies reconstructed images.
+ """
+ source_labels = data['labels_content']
+ target_labels = data['labels_style']
+ fake_out_trans, fake_features_trans = \
+ self.model(net_G_output['images_trans'], target_labels)
+ output = dict(fake_out_trans=fake_out_trans,
+ fake_features_trans=fake_features_trans)
+
+ real_out_style, real_features_style = \
+ self.model(data['images_style'], target_labels)
+ output.update(dict(real_out_style=real_out_style,
+ real_features_style=real_features_style))
+ if recon:
+ fake_out_recon, fake_features_recon = \
+ self.model(net_G_output['images_recon'], source_labels)
+ output.update(dict(fake_out_recon=fake_out_recon,
+ fake_features_recon=fake_features_recon))
+ return output
+
+
+class ResDiscriminator(nn.Module):
+ r"""Residual discriminator architecture used in the FUNIT paper."""
+
+ def __init__(self,
+ image_channels=3,
+ num_classes=119,
+ num_filters=64,
+ max_num_filters=1024,
+ num_layers=6,
+ padding_mode='reflect',
+ weight_norm_type='',
+ **kwargs):
+ super().__init__()
+ for key in kwargs:
+ if key != 'type':
+ warnings.warn(
+ "Discriminator argument {} is not used".format(key))
+
+ conv_params = dict(padding_mode=padding_mode,
+ activation_norm_type='none',
+ weight_norm_type=weight_norm_type,
+ bias=[True, True, True],
+ nonlinearity='leakyrelu',
+ order='NACNAC')
+
+ first_kernel_size = 7
+ first_padding = (first_kernel_size - 1) // 2
+ model = [Conv2dBlock(image_channels, num_filters,
+ first_kernel_size, 1, first_padding,
+ padding_mode=padding_mode,
+ weight_norm_type=weight_norm_type)]
+ for i in range(num_layers):
+ num_filters_prev = num_filters
+ num_filters = min(num_filters * 2, max_num_filters)
+ model += [Res2dBlock(num_filters_prev, num_filters_prev,
+ **conv_params),
+ Res2dBlock(num_filters_prev, num_filters,
+ **conv_params)]
+ if i != num_layers - 1:
+ model += [nn.ReflectionPad2d(1),
+ nn.AvgPool2d(3, stride=2)]
+ self.model = nn.Sequential(*model)
+ self.classifier = Conv2dBlock(num_filters, 1, 1, 1, 0,
+ nonlinearity='leakyrelu',
+ weight_norm_type=weight_norm_type,
+ order='NACNAC')
+
+ self.embedder = nn.Embedding(num_classes, num_filters)
+
+ def forward(self, images, labels=None):
+ r"""Forward function of the projection discriminator.
+
+ Args:
+ images (image tensor): Images inputted to the discriminator.
+ labels (long int tensor): Class labels of the images.
+ """
+ assert labels is None or images.size(0) == labels.size(0)
+ features = self.model(images)
+ outputs = self.classifier(features)
+ features_1x1 = features.mean(3).mean(2)
+ if labels is None:
+ return features_1x1
+ embeddings = self.embedder(labels)
+ outputs += torch.sum(embeddings * features_1x1, dim=1,
+ keepdim=True).view(images.size(0), 1, 1, 1)
+ return outputs, features_1x1
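`ResDiscriminator` is a projection discriminator: the class label enters only through the dot product between its embedding and the globally pooled feature map, which is then added to the patch scores. A standalone sketch of that projection step (the sizes are illustrative):

```python
import torch
import torch.nn as nn

num_classes, num_filters = 119, 1024
embedder = nn.Embedding(num_classes, num_filters)

features = torch.randn(4, num_filters, 8, 8)   # backbone output
patch_scores = torch.randn(4, 1, 8, 8)          # classifier output
labels = torch.randint(0, num_classes, (4,))

features_1x1 = features.mean(3).mean(2)          # global average pool -> N x C
proj = torch.sum(embedder(labels) * features_1x1, dim=1, keepdim=True)
outputs = patch_scores + proj.view(4, 1, 1, 1)   # broadcast over spatial dims
print(outputs.shape)  # torch.Size([4, 1, 8, 8])
```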
diff --git a/imaginaire/discriminators/gancraft.py b/imaginaire/discriminators/gancraft.py
new file mode 100644
index 0000000000000000000000000000000000000000..0bc070cb46ac5c6ae287231ddd0144bedd6d55a2
--- /dev/null
+++ b/imaginaire/discriminators/gancraft.py
@@ -0,0 +1,278 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import functools
+from imaginaire.layers import Conv2dBlock
+
+from imaginaire.utils.data import get_paired_input_label_channel_number, get_paired_input_image_channel_number
+from imaginaire.utils.distributed import master_only_print as print
+
+
+class Discriminator(nn.Module):
+ r"""Multi-resolution patch discriminator. Based on FPSE discriminator but with N+1 labels.
+
+ Args:
+ dis_cfg (obj): Discriminator definition part of the yaml config file.
+ data_cfg (obj): Data definition part of the yaml config file.
+ """
+
+ def __init__(self, dis_cfg, data_cfg):
+ super(Discriminator, self).__init__()
+ # We assume the first datum is the ground truth image.
+ image_channels = get_paired_input_image_channel_number(data_cfg)
+ # Calculate number of channels in the input label.
+ num_labels = get_paired_input_label_channel_number(data_cfg)
+
+ self.use_label = getattr(dis_cfg, 'use_label', True)
+ # Override number of input channels
+ if hasattr(dis_cfg, 'image_channels'):
+ image_channels = dis_cfg.image_channels
+ if hasattr(dis_cfg, 'num_labels'):
+ num_labels = dis_cfg.num_labels
+ else:
+ # We assume the first datum is the ground truth image.
+ image_channels = get_paired_input_image_channel_number(data_cfg)
+ # Calculate number of channels in the input label.
+ num_labels = get_paired_input_label_channel_number(data_cfg)
+
+ if not self.use_label:
+ num_labels = 2 # ignore + true
+
+ # Build the discriminator.
+ num_filters = getattr(dis_cfg, 'num_filters', 128)
+ weight_norm_type = getattr(dis_cfg, 'weight_norm_type', 'spectral')
+
+ fpse_kernel_size = getattr(dis_cfg, 'fpse_kernel_size', 3)
+ fpse_activation_norm_type = getattr(dis_cfg,
+ 'fpse_activation_norm_type',
+ 'none')
+ do_multiscale = getattr(dis_cfg, 'do_multiscale', False)
+ smooth_resample = getattr(dis_cfg, 'smooth_resample', False)
+ no_label_except_largest_scale = getattr(dis_cfg, 'no_label_except_largest_scale', False)
+
+ self.fpse_discriminator = FPSEDiscriminator(
+ image_channels,
+ num_labels,
+ num_filters,
+ fpse_kernel_size,
+ weight_norm_type,
+ fpse_activation_norm_type,
+ do_multiscale,
+ smooth_resample,
+ no_label_except_largest_scale)
+
+ def _single_forward(self, input_label, input_image, weights):
+ output_list, features_list = self.fpse_discriminator(input_image, input_label, weights)
+ return output_list, [features_list]
+
+ def forward(self, data, net_G_output, weights=None, incl_real=False, incl_pseudo_real=False):
+ r"""GANcraft discriminator forward.
+
+ Args:
+ data (dict):
+ - images (N x C1 x H x W tensor) : Ground truth images.
+ - label (N x C2 x H x W tensor) : Semantic representations.
+ - z (N x style_dims tensor): Gaussian random noise.
+ net_G_output (dict):
+ - fake_images (N x C1 x H x W tensor) : Fake images.
+ Returns:
+ output_x (dict):
+ - real_outputs (list): list of output tensors produced by
+ individual patch discriminators for real images.
+ - real_features (list): list of lists of features produced by
+ individual patch discriminators for real images.
+ - fake_outputs (list): list of output tensors produced by
+ individual patch discriminators for fake images.
+ - fake_features (list): list of lists of features produced by
+ individual patch discriminators for fake images.
+ """
+ output_x = dict()
+
+ # Fake.
+ fake_images = net_G_output['fake_images']
+ if self.use_label:
+ fake_labels = data['fake_masks']
+ else:
+ fake_labels = torch.zeros([fake_images.size(0), 2, fake_images.size(
+ 2), fake_images.size(3)], device=fake_images.device, dtype=fake_images.dtype)
+ fake_labels[:, 1, :, :] = 1
+ output_x['fake_outputs'], output_x['fake_features'] = \
+ self._single_forward(fake_labels, fake_images, None)
+
+ # Real.
+ if incl_real:
+ real_images = data['images']
+ if self.use_label:
+ real_labels = data['real_masks']
+ else:
+ real_labels = torch.zeros([real_images.size(0), 2, real_images.size(
+ 2), real_images.size(3)], device=real_images.device, dtype=real_images.dtype)
+ real_labels[:, 1, :, :] = 1
+ output_x['real_outputs'], output_x['real_features'] = \
+ self._single_forward(real_labels, real_images, None)
+
+ # pseudo-Real.
+ if incl_pseudo_real:
+ preal_images = data['pseudo_real_img']
+ preal_labels = data['fake_masks']
+ if not self.use_label:
+ preal_labels = torch.zeros([preal_images.size(0), 2, preal_images.size(
+ 2), preal_images.size(3)], device=preal_images.device, dtype=preal_images.dtype)
+ preal_labels[:, 1, :, :] = 1
+ output_x['pseudo_real_outputs'], output_x['pseudo_real_features'] = \
+ self._single_forward(preal_labels, preal_images, None)
+
+ return output_x
+
+
+class FPSEDiscriminator(nn.Module):
+ def __init__(self,
+ num_input_channels,
+ num_labels,
+ num_filters,
+ kernel_size,
+ weight_norm_type,
+ activation_norm_type,
+ do_multiscale,
+ smooth_resample,
+ no_label_except_largest_scale):
+ super().__init__()
+
+ self.do_multiscale = do_multiscale
+ self.no_label_except_largest_scale = no_label_except_largest_scale
+
+ padding = int(np.ceil((kernel_size - 1.0) / 2))
+ nonlinearity = 'leakyrelu'
+ stride1_conv2d_block = \
+ functools.partial(Conv2dBlock,
+ kernel_size=kernel_size,
+ stride=1,
+ padding=padding,
+ weight_norm_type=weight_norm_type,
+ activation_norm_type=activation_norm_type,
+ nonlinearity=nonlinearity,
+ # inplace_nonlinearity=True,
+ order='CNA')
+ down_conv2d_block = \
+ functools.partial(Conv2dBlock,
+ kernel_size=kernel_size,
+ stride=2,
+ padding=padding,
+ weight_norm_type=weight_norm_type,
+ activation_norm_type=activation_norm_type,
+ nonlinearity=nonlinearity,
+ # inplace_nonlinearity=True,
+ order='CNA')
+ latent_conv2d_block = \
+ functools.partial(Conv2dBlock,
+ kernel_size=1,
+ stride=1,
+ weight_norm_type=weight_norm_type,
+ activation_norm_type=activation_norm_type,
+ nonlinearity=nonlinearity,
+ # inplace_nonlinearity=True,
+ order='CNA')
+ # bottom-up pathway
+ self.enc1 = down_conv2d_block(num_input_channels, num_filters) # 3
+ self.enc2 = down_conv2d_block(1 * num_filters, 2 * num_filters) # 7
+ self.enc3 = down_conv2d_block(2 * num_filters, 4 * num_filters) # 15
+ self.enc4 = down_conv2d_block(4 * num_filters, 8 * num_filters) # 31
+ self.enc5 = down_conv2d_block(8 * num_filters, 8 * num_filters) # 63
+
+ # top-down pathway
+ # self.lat1 = latent_conv2d_block(num_filters, 2 * num_filters) # Zekun
+ self.lat2 = latent_conv2d_block(2 * num_filters, 4 * num_filters)
+ self.lat3 = latent_conv2d_block(4 * num_filters, 4 * num_filters)
+ self.lat4 = latent_conv2d_block(8 * num_filters, 4 * num_filters)
+ self.lat5 = latent_conv2d_block(8 * num_filters, 4 * num_filters)
+
+ # upsampling
+ self.upsample2x = nn.Upsample(scale_factor=2, mode='bilinear',
+ align_corners=False)
+
+ # final layers
+ self.final2 = stride1_conv2d_block(4 * num_filters, 2 * num_filters)
+ self.output = Conv2dBlock(num_filters * 2, num_labels+1, kernel_size=1)
+
+ if self.do_multiscale:
+ self.final3 = stride1_conv2d_block(4 * num_filters, 2 * num_filters)
+ self.final4 = stride1_conv2d_block(4 * num_filters, 2 * num_filters)
+ if self.no_label_except_largest_scale:
+ self.output3 = Conv2dBlock(num_filters * 2, 2, kernel_size=1)
+ self.output4 = Conv2dBlock(num_filters * 2, 2, kernel_size=1)
+ else:
+ self.output3 = Conv2dBlock(num_filters * 2, num_labels+1, kernel_size=1)
+ self.output4 = Conv2dBlock(num_filters * 2, num_labels+1, kernel_size=1)
+
+ self.interpolator = functools.partial(F.interpolate, mode='nearest')
+ if smooth_resample:
+ self.interpolator = self.smooth_interp
+
+ @staticmethod
+ def smooth_interp(x, size):
+ r"""Smooth interpolation of segmentation maps.
+
+ Args:
+ x (4D tensor): Segmentation maps.
+ size(2D list): Target size (H, W).
+ """
+ x = F.interpolate(x, size=size, mode='area')
+ onehot_idx = torch.argmax(x, dim=-3, keepdim=True)
+ x.fill_(0.0)
+ x.scatter_(1, onehot_idx, 1.0)
+ return x
+
+ # Weights: [N C]
+ def forward(self, images, segmaps, weights=None):
+ # Assume images 256x256
+ # bottom-up pathway
+ feat11 = self.enc1(images) # 128
+ feat12 = self.enc2(feat11) # 64
+ feat13 = self.enc3(feat12) # 32
+ feat14 = self.enc4(feat13) # 16
+ feat15 = self.enc5(feat14) # 8
+ # top-down pathway and lateral connections
+ feat25 = self.lat5(feat15) # 8
+ feat24 = self.upsample2x(feat25) + self.lat4(feat14) # 16
+ feat23 = self.upsample2x(feat24) + self.lat3(feat13) # 32
+ feat22 = self.upsample2x(feat23) + self.lat2(feat12) # 64
+
+ # final prediction layers
+ feat32 = self.final2(feat22)
+
+ results = []
+ label_map = self.interpolator(segmaps, size=feat32.size()[2:])
+ pred2 = self.output(feat32) # N, num_labels+1, H//4, W//4
+
+ features = [feat11, feat12, feat13, feat14, feat15, feat25, feat24, feat23, feat22]
+ if weights is not None:
+ label_map = label_map * weights[..., None, None]
+ results.append({'pred': pred2, 'label': label_map})
+
+ if self.do_multiscale:
+ feat33 = self.final3(feat23)
+ pred3 = self.output3(feat33)
+
+ feat34 = self.final4(feat24)
+ pred4 = self.output4(feat34)
+
+ if self.no_label_except_largest_scale:
+ label_map3 = torch.ones([pred3.size(0), 1, pred3.size(2), pred3.size(3)], device=pred3.device)
+ label_map4 = torch.ones([pred4.size(0), 1, pred4.size(2), pred4.size(3)], device=pred4.device)
+ else:
+ label_map3 = self.interpolator(segmaps, size=pred3.size()[2:])
+ label_map4 = self.interpolator(segmaps, size=pred4.size()[2:])
+
+ if weights is not None:
+ label_map3 = label_map3 * weights[..., None, None]
+ label_map4 = label_map4 * weights[..., None, None]
+
+ results.append({'pred': pred3, 'label': label_map3})
+ results.append({'pred': pred4, 'label': label_map4})
+
+ return results, features
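`smooth_interp` above downsamples a label map with area interpolation (so each output pixel holds class fractions) and then re-binarizes it by scattering a one-hot at the per-pixel argmax, avoiding the aliasing of nearest-neighbour resampling. The same steps as a standalone sketch:

```python
import torch
import torch.nn.functional as F

segmap = torch.zeros(1, 5, 64, 64)
segmap[:, 2, :32] = 1.0  # class 2 in the top half
segmap[:, 4, 32:] = 1.0  # class 4 in the bottom half

x = F.interpolate(segmap, size=(16, 16), mode='area')  # soft class fractions
onehot_idx = torch.argmax(x, dim=1, keepdim=True)       # winning class per pixel
x = torch.zeros_like(x).scatter_(1, onehot_idx, 1.0)     # back to a hard one-hot
print(x.sum(dim=1).unique())  # tensor([1.]) -- exactly one class per pixel
```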
diff --git a/imaginaire/discriminators/mlp_multiclass.py b/imaginaire/discriminators/mlp_multiclass.py
new file mode 100644
index 0000000000000000000000000000000000000000..9f7d1d27e2d21a7b4fd5f23646b545a44a47a783
--- /dev/null
+++ b/imaginaire/discriminators/mlp_multiclass.py
@@ -0,0 +1,63 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import functools
+
+import numpy as np
+import torch.nn as nn
+
+from imaginaire.layers import LinearBlock
+
+
+class Discriminator(nn.Module):
+ r"""Multi-layer Perceptron Classifier constructor.
+
+ Args:
+ dis_cfg (obj): Discriminator definition part of the yaml config file.
+ data_cfg (obj): Data definition part of the yaml config file
+ """
+
+ def __init__(self, dis_cfg, data_cfg):
+ super(Discriminator, self).__init__()
+ num_input_channels = dis_cfg.input_dims
+ num_labels = dis_cfg.num_labels
+ num_layers = getattr(dis_cfg, 'num_layers', 5)
+ num_filters = getattr(dis_cfg, 'num_filters', 512)
+ activation_norm_type = getattr(dis_cfg,
+ 'activation_norm_type',
+ 'batch_norm')
+ nonlinearity = getattr(dis_cfg, 'nonlinearity', 'leakyrelu')
+ base_linear_block = \
+ functools.partial(LinearBlock,
+ activation_norm_type=activation_norm_type,
+ nonlinearity=nonlinearity,
+ order='CNA')
+ dropout_ratio = 0.1
+ layers = [base_linear_block(num_input_channels, num_filters),
+ nn.Dropout(dropout_ratio)]
+ for n in range(num_layers):
+ dropout_ratio *= 1.5
+ dropout_ratio = np.min([dropout_ratio, 0.5])
+ layers += [base_linear_block(num_filters, num_filters),
+ nn.Dropout(dropout_ratio)]
+ layers += [LinearBlock(num_filters, num_labels)]
+ self.model = nn.Sequential(*layers)
+
+ def forward(self, data):
+ r"""Patch Discriminator forward.
+
+ Args:
+ data (dict):
+ - data (N x -1 tensor): We will reshape the tensor to this format.
+ Returns:
+ (dict):
+ - results (N x C tensor): Output scores before softmax.
+ """
+ input_x = data['data']
+ bs = input_x.size()[0]
+ input_x = input_x.view(bs, -1)
+ pre_softmax_scores = self.model(input_x)
+ outputs = dict()
+ outputs['results'] = pre_softmax_scores
+ return outputs
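The MLP classifier flattens whatever arrives under `data['data']` and returns pre-softmax scores, so a cross-entropy loss is applied outside. A usage sketch; the config stub and its values are made up, and only the fields read by the constructor are set:

```python
import torch
import torch.nn.functional as F
from types import SimpleNamespace
from imaginaire.discriminators.mlp_multiclass import Discriminator

# Hypothetical config: 2048-d input features classified into 10 classes.
dis_cfg = SimpleNamespace(input_dims=2048, num_labels=10,
                          num_layers=3, num_filters=256)
net_D = Discriminator(dis_cfg, data_cfg=None)  # data_cfg is unused here

data = {'data': torch.randn(8, 2048)}
logits = net_D(data)['results']          # N x num_labels, before softmax
labels = torch.randint(0, 10, (8,))
loss = F.cross_entropy(logits, labels)
print(logits.shape, float(loss))
```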
diff --git a/imaginaire/discriminators/multires_patch.py b/imaginaire/discriminators/multires_patch.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a31d25c689a3d5b26a3b92216715d3ebd62dc91
--- /dev/null
+++ b/imaginaire/discriminators/multires_patch.py
@@ -0,0 +1,313 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+# Copyright (C) 2020 NVIDIA Corporation. All rights reserved
+import functools
+import warnings
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+from imaginaire.layers import Conv2dBlock
+from imaginaire.utils.data import (get_paired_input_image_channel_number,
+ get_paired_input_label_channel_number)
+from imaginaire.utils.distributed import master_only_print as print
+
+
+class Discriminator(nn.Module):
+ r"""Multi-resolution patch discriminator.
+
+ Args:
+ dis_cfg (obj): Discriminator definition part of the yaml config
+ file.
+ data_cfg (obj): Data definition part of the yaml config file.
+ """
+
+ def __init__(self, dis_cfg, data_cfg):
+ super(Discriminator, self).__init__()
+ print('Multi-resolution patch discriminator initialization.')
+ # We assume the first datum is the ground truth image.
+ image_channels = get_paired_input_image_channel_number(data_cfg)
+ # Calculate number of channels in the input label.
+ num_labels = get_paired_input_label_channel_number(data_cfg)
+
+ # Build the discriminator.
+ kernel_size = getattr(dis_cfg, 'kernel_size', 3)
+ num_filters = getattr(dis_cfg, 'num_filters', 128)
+ max_num_filters = getattr(dis_cfg, 'max_num_filters', 512)
+ num_discriminators = getattr(dis_cfg, 'num_discriminators', 2)
+ num_layers = getattr(dis_cfg, 'num_layers', 5)
+ activation_norm_type = getattr(dis_cfg, 'activation_norm_type', 'none')
+ weight_norm_type = getattr(dis_cfg, 'weight_norm_type', 'spectral')
+ print('\tBase filter number: %d' % num_filters)
+ print('\tNumber of discriminators: %d' % num_discriminators)
+ print('\tNumber of layers in a discriminator: %d' % num_layers)
+ print('\tWeight norm type: %s' % weight_norm_type)
+ num_input_channels = image_channels + num_labels
+ self.model = MultiResPatchDiscriminator(num_discriminators,
+ kernel_size,
+ num_input_channels,
+ num_filters,
+ num_layers,
+ max_num_filters,
+ activation_norm_type,
+ weight_norm_type)
+ print('Done with the Multi-resolution patch '
+ 'discriminator initialization.')
+
+ def forward(self, data, net_G_output, real=True):
+ r"""Multi-resolution patch discriminator forward.
+
+ Args:
+ data (dict):
+ - images (N x C1 x H x W tensor) : Ground truth images.
+ - label (N x C2 x H x W tensor) : Semantic representations.
+ - z (N x style_dims tensor): Gaussian random noise.
+ net_G_output (dict):
+ fake_images (N x C1 x H x W tensor) : Fake images.
+ real (bool): If ``True``, also classifies real images. Otherwise it
+ only classifies generated images to save computation during the
+ generator update.
+ Returns:
+ (tuple):
+ - real_outputs (list): list of output tensors produced by
+ individual patch discriminators for real images.
+ - real_features (list): list of lists of features produced by
+ individual patch discriminators for real images.
+ - fake_outputs (list): list of output tensors produced by
+ individual patch discriminators for fake images.
+ - fake_features (list): list of lists of features produced by
+ individual patch discriminators for fake images.
+ """
+ output_x = dict()
+ if 'label' in data:
+ fake_input_x = torch.cat(
+ (data['label'], net_G_output['fake_images']), 1)
+ else:
+ fake_input_x = net_G_output['fake_images']
+ output_x['fake_outputs'], output_x['fake_features'], _ = \
+ self.model.forward(fake_input_x)
+ if real:
+ if 'label' in data:
+ real_input_x = torch.cat(
+ (data['label'], data['images']), 1)
+ else:
+ real_input_x = data['images']
+ output_x['real_outputs'], output_x['real_features'], _ = \
+ self.model.forward(real_input_x)
+ return output_x
+
+
+class MultiResPatchDiscriminator(nn.Module):
+ r"""Multi-resolution patch discriminator.
+
+ Args:
+ num_discriminators (int): Num. of discriminators (one per scale).
+ kernel_size (int): Convolution kernel size.
+ num_image_channels (int): Num. of channels in the real/fake image.
+ num_filters (int): Num. of base filters in a layer.
+ num_layers (int): Num. of layers for the patch discriminator.
+ max_num_filters (int): Maximum num. of filters in a layer.
+ activation_norm_type (str): batch_norm/instance_norm/none/....
+ weight_norm_type (str): none/spectral_norm/weight_norm
+ """
+
+ def __init__(self,
+ num_discriminators=3,
+ kernel_size=3,
+ num_image_channels=3,
+ num_filters=64,
+ num_layers=4,
+ max_num_filters=512,
+ activation_norm_type='',
+ weight_norm_type='',
+ **kwargs):
+ super().__init__()
+ for key in kwargs:
+ if key != 'type' and key != 'patch_wise':
+ warnings.warn(
+ "Discriminator argument {} is not used".format(key))
+
+ self.discriminators = nn.ModuleList()
+ for i in range(num_discriminators):
+ net_discriminator = NLayerPatchDiscriminator(
+ kernel_size,
+ num_image_channels,
+ num_filters,
+ num_layers,
+ max_num_filters,
+ activation_norm_type,
+ weight_norm_type)
+ self.discriminators.append(net_discriminator)
+ print('Done with the Multi-resolution patch '
+ 'discriminator initialization.')
+
+ def forward(self, input_x):
+ r"""Multi-resolution patch discriminator forward.
+
+ Args:
+ input_x (tensor) : Input images.
+ Returns:
+ (tuple):
+ - output_list (list): list of output tensors produced by
+ individual patch discriminators.
+ - features_list (list): list of lists of features produced by
+ individual patch discriminators.
+ - input_list (list): list of downsampled input images.
+ """
+ input_list = []
+ output_list = []
+ features_list = []
+ input_downsampled = input_x
+ for net_discriminator in self.discriminators:
+ input_list.append(input_downsampled)
+ output, features = net_discriminator(input_downsampled)
+ output_list.append(output)
+ features_list.append(features)
+ input_downsampled = nn.functional.interpolate(
+ input_downsampled, scale_factor=0.5, mode='bilinear',
+ align_corners=True, recompute_scale_factor=True)
+ return output_list, features_list, input_list
+
+
+class WeightSharedMultiResPatchDiscriminator(nn.Module):
+ r"""Multi-resolution patch discriminator with shared weights.
+
+ Args:
+ num_discriminators (int): Num. of discriminators (one per scale).
+ kernel_size (int): Convolution kernel size.
+ num_image_channels (int): Num. of channels in the real/fake image.
+ num_filters (int): Num. of base filters in a layer.
+ num_layers (int): Num. of layers for the patch discriminator.
+ max_num_filters (int): Maximum num. of filters in a layer.
+ activation_norm_type (str): batch_norm/instance_norm/none/....
+ weight_norm_type (str): none/spectral_norm/weight_norm
+ """
+
+ def __init__(self,
+ num_discriminators=3,
+ kernel_size=3,
+ num_image_channels=3,
+ num_filters=64,
+ num_layers=4,
+ max_num_filters=512,
+ activation_norm_type='',
+ weight_norm_type='',
+ **kwargs):
+ super().__init__()
+ for key in kwargs:
+ if key != 'type' and key != 'patch_wise':
+ warnings.warn(
+ "Discriminator argument {} is not used".format(key))
+ self.num_discriminators = num_discriminators
+ self.discriminator = NLayerPatchDiscriminator(
+ kernel_size,
+ num_image_channels,
+ num_filters,
+ num_layers,
+ max_num_filters,
+ activation_norm_type,
+ weight_norm_type)
+ print('Done with the Weight-Shared Multi-resolution patch '
+ 'discriminator initialization.')
+
+ def forward(self, input_x):
+ r"""Multi-resolution patch discriminator forward.
+
+ Args:
+ input_x (tensor) : Input images.
+ Returns:
+ (tuple):
+ - output_list (list): list of output tensors produced by
+ individual patch discriminators.
+ - features_list (list): list of lists of features produced by
+ individual patch discriminators.
+ - input_list (list): list of downsampled input images.
+ """
+ input_list = []
+ output_list = []
+ features_list = []
+ input_downsampled = input_x
+ for i in range(self.num_discriminators):
+ input_list.append(input_downsampled)
+ output, features = self.discriminator(input_downsampled)
+ output_list.append(output)
+ features_list.append(features)
+ input_downsampled = nn.functional.interpolate(
+ input_downsampled, scale_factor=0.5, mode='bilinear',
+ align_corners=True)
+ return output_list, features_list, input_list
+
+
+class NLayerPatchDiscriminator(nn.Module):
+ r"""Patch Discriminator constructor.
+
+ Args:
+ kernel_size (int): Convolution kernel size.
+ num_input_channels (int): Num. of channels in the real/fake image.
+ num_filters (int): Num. of base filters in a layer.
+ num_layers (int): Num. of layers for the patch discriminator.
+ max_num_filters (int): Maximum num. of filters in a layer.
+ activation_norm_type (str): batch_norm/instance_norm/none/....
+ weight_norm_type (str): none/spectral_norm/weight_norm
+ """
+
+ def __init__(self,
+ kernel_size,
+ num_input_channels,
+ num_filters,
+ num_layers,
+ max_num_filters,
+ activation_norm_type,
+ weight_norm_type):
+ super(NLayerPatchDiscriminator, self).__init__()
+ self.num_layers = num_layers
+ padding = int(np.floor((kernel_size - 1.0) / 2))
+ nonlinearity = 'leakyrelu'
+ base_conv2d_block = \
+ functools.partial(Conv2dBlock,
+ kernel_size=kernel_size,
+ padding=padding,
+ weight_norm_type=weight_norm_type,
+ activation_norm_type=activation_norm_type,
+ nonlinearity=nonlinearity,
+ # inplace_nonlinearity=True,
+ order='CNA')
+ layers = [[base_conv2d_block(
+ num_input_channels, num_filters, stride=2)]]
+ for n in range(num_layers):
+ num_filters_prev = num_filters
+ num_filters = min(num_filters * 2, max_num_filters)
+ stride = 2 if n < (num_layers - 1) else 1
+ layers += [[base_conv2d_block(num_filters_prev, num_filters,
+ stride=stride)]]
+ layers += [[Conv2dBlock(num_filters, 1,
+ 3, 1,
+ padding,
+ weight_norm_type=weight_norm_type)]]
+ for n in range(len(layers)):
+ setattr(self, 'layer' + str(n), nn.Sequential(*layers[n]))
+
+ def forward(self, input_x):
+ r"""Patch Discriminator forward.
+
+ Args:
+ input_x (N x C x H1 x W1 tensor): Concatenation of images and
+ semantic representations.
+ Returns:
+ (tuple):
+ - output (N x 1 x H2 x W2 tensor): Discriminator output value.
+ Before the sigmoid when using NSGAN.
+ - features (list): lists of tensors of the intermediate
+ activations.
+ """
+ res = [input_x]
+ for n in range(self.num_layers + 2):
+ layer = getattr(self, 'layer' + str(n))
+ x = res[-1]
+ res.append(layer(x))
+ output = res[-1]
+ features = res[1:-1]
+ return output, features
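Each `NLayerPatchDiscriminator` above downsamples by a factor of `2**num_layers` in total, and `MultiResPatchDiscriminator` runs one copy per level of a 2x image pyramid. A shape check, assuming `imaginaire` is importable; the settings are illustrative:

```python
import torch
from imaginaire.discriminators.multires_patch import MultiResPatchDiscriminator

net_D = MultiResPatchDiscriminator(num_discriminators=3, kernel_size=3,
                                   num_image_channels=3, num_filters=64,
                                   num_layers=4, max_num_filters=512,
                                   activation_norm_type='none',
                                   weight_norm_type='spectral')

images = torch.randn(1, 3, 256, 256)
outputs, features, inputs = net_D(images)
print([tuple(x.shape[-2:]) for x in inputs])   # [(256, 256), (128, 128), (64, 64)]
print([tuple(o.shape[-2:]) for o in outputs])  # [(16, 16), (8, 8), (4, 4)]
print(len(features[0]))                        # 5 intermediate feature maps
```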
diff --git a/imaginaire/discriminators/multires_patch_pano.py b/imaginaire/discriminators/multires_patch_pano.py
new file mode 100644
index 0000000000000000000000000000000000000000..97763da3bbfc990be1755bbb62383869fd6708da
--- /dev/null
+++ b/imaginaire/discriminators/multires_patch_pano.py
@@ -0,0 +1,247 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+# Copyright (C) 2020 NVIDIA Corporation. All rights reserved
+import functools
+import warnings
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+from imaginaire.layers import Conv2dBlock
+from imaginaire.utils.data import (get_paired_input_image_channel_number,
+ get_paired_input_label_channel_number)
+from imaginaire.utils.distributed import master_only_print as print
+from model.sample import Equirectangular
+
+class Discriminator(nn.Module):
+ r"""Multi-resolution patch discriminator.
+
+ Args:
+ dis_cfg (obj): Discriminator definition part of the yaml config
+ file.
+ """
+
+ def __init__(self, dis_cfg):
+ super(Discriminator, self).__init__()
+ print('Multi-resolution patch discriminator initialization.')
+ # Number of channels in the discriminator input (RGB panorama by default).
+ num_input_channels = getattr(dis_cfg, 'input_channels', 3)
+
+ # Build the discriminator.
+ kernel_size = getattr(dis_cfg, 'kernel_size', 3)
+ num_filters = getattr(dis_cfg, 'num_filters', 128)
+ max_num_filters = getattr(dis_cfg, 'max_num_filters', 512)
+ num_discriminators = getattr(dis_cfg, 'num_discriminators', 2)
+ num_layers = getattr(dis_cfg, 'num_layers', 5)
+ activation_norm_type = getattr(dis_cfg, 'activation_norm_type', 'none')
+ weight_norm_type = getattr(dis_cfg, 'weight_norm_type', 'spectral')
+ print('\tBase filter number: %d' % num_filters)
+ print('\tNumber of discriminators: %d' % num_discriminators)
+ print('\tNumber of layers in a discriminator: %d' % num_layers)
+ print('\tWeight norm type: %s' % weight_norm_type)
+ self.condition = getattr(dis_cfg, 'condition', None)
+ # self.condition = dis_cfg.condition
+ self.model = MultiResPatchDiscriminator(num_discriminators,
+ kernel_size,
+ num_input_channels,
+ num_filters,
+ num_layers,
+ max_num_filters,
+ activation_norm_type,
+ weight_norm_type)
+ print('Done with the Multi-resolution patch '
+ 'discriminator initialization.')
+
+ def forward(self, data, net_G_output, real=True):
+ r"""Multi-resolution patch discriminator forward.
+
+ Args:
+ data (N x C1 x H x W tensor) : Ground truth images.
+ net_G_output (dict):
+ pred (N x C1 x H x W tensor) : Fake images produced by the generator.
+ real (bool): If ``True``, also classifies real images. Otherwise it
+ only classifies generated images to save computation during the
+ generator update.
+ Returns:
+ (tuple):
+ - real_outputs (list): list of output tensors produced by
+ individual patch discriminators for real images.
+ - real_features (list): list of lists of features produced by
+ individual patch discriminators for real images.
+ - fake_outputs (list): list of output tensors produced by
+ individual patch discriminators for fake images.
+ - fake_features (list): list of lists of features produced by
+ individual patch discriminators for fake images.
+ """
+ output_x = dict()
+ if self.condition:
+ fake_input_x = torch.cat([net_G_output['pred'], net_G_output['generator_inputs']], dim=1)
+ else:
+ fake_input_x = net_G_output['pred']
+ output_x['fake_outputs'], output_x['fake_features'], _ = \
+ self.model.forward(fake_input_x)
+ if real:
+ if self.condition:
+ real_input_x = torch.cat([data, net_G_output['generator_inputs']], dim=1)
+ else:
+ real_input_x = data
+ output_x['real_outputs'], output_x['real_features'], _ = \
+ self.model.forward(real_input_x)
+ return output_x
+
+
+class MultiResPatchDiscriminator(nn.Module):
+ r"""Multi-resolution patch discriminator.
+
+ Args:
+ num_discriminators (int): Num. of discriminators (one per scale).
+ kernel_size (int): Convolution kernel size.
+ num_image_channels (int): Num. of channels in the real/fake image.
+ num_filters (int): Num. of base filters in a layer.
+ num_layers (int): Num. of layers for the patch discriminator.
+ max_num_filters (int): Maximum num. of filters in a layer.
+ activation_norm_type (str): batch_norm/instance_norm/none/....
+ weight_norm_type (str): none/spectral_norm/weight_norm
+ """
+
+ def __init__(self,
+ num_discriminators=3,
+ kernel_size=3,
+ num_image_channels=3,
+ num_filters=64,
+ num_layers=4,
+ max_num_filters=512,
+ activation_norm_type='',
+ weight_norm_type='',
+ **kwargs):
+ super().__init__()
+ for key in kwargs:
+ if key != 'type' and key != 'patch_wise':
+ warnings.warn(
+ "Discriminator argument {} is not used".format(key))
+
+ self.discriminators = nn.ModuleList()
+ for i in range(num_discriminators):
+ net_discriminator = NLayerPatchDiscriminator(
+ kernel_size,
+ num_image_channels,
+ num_filters,
+ num_layers,
+ max_num_filters,
+ activation_norm_type,
+ weight_norm_type)
+ self.discriminators.append(net_discriminator)
+ print('Done with the Multi-resolution patch '
+ 'discriminator initialization.')
+ self.e = Equirectangular(theta=[-40., 40.], width=128, height=128, FovX=100)
+
+ def forward(self, input_x):
+ r"""Multi-resolution patch discriminator forward.
+
+ Args:
+ input_x (tensor) : Input images.
+ Returns:
+ (tuple):
+ - output_list (list): list of output tensors produced by
+ individual patch discriminators.
+ - features_list (list): list of lists of features produced by
+ individual patch discriminators.
+ - input_list (list): list of downsampled input images.
+ """
+ input_list = []
+ output_list = []
+ features_list = []
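+ # Scale 0 sees the half-resolution panorama; scale 1 sees a perspective
+ # (equirectangular) crop of the full-resolution panorama; scale 2, if
+ # enabled, sees that crop at half resolution.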
+ input_N = nn.functional.interpolate(
+ input_x, scale_factor=0.5, mode='bilinear',
+ align_corners=True, recompute_scale_factor=True)
+ equ = self.e(input_x)
+ for i, net_discriminator in enumerate(self.discriminators):
+ input_list.append(input_N)
+ output, features = net_discriminator(input_N)
+ output_list.append(output)
+ features_list.append(features)
+ if i == 0:
+ input_N = torch.nn.functional.grid_sample(input_x, equ.float(), align_corners=True) * 0.99
+ elif i == 1:
+ input_N = nn.functional.interpolate(
+ input_N, scale_factor=0.5, mode='bilinear',
+ align_corners=True, recompute_scale_factor=True)
+
+ return output_list, features_list, input_list
+
+class NLayerPatchDiscriminator(nn.Module):
+ r"""Patch Discriminator constructor.
+
+ Args:
+ kernel_size (int): Convolution kernel size.
+ num_input_channels (int): Num. of channels in the real/fake image.
+ num_filters (int): Num. of base filters in a layer.
+ num_layers (int): Num. of layers for the patch discriminator.
+ max_num_filters (int): Maximum num. of filters in a layer.
+ activation_norm_type (str): batch_norm/instance_norm/none/....
+ weight_norm_type (str): none/spectral_norm/weight_norm
+ """
+
+ def __init__(self,
+ kernel_size,
+ num_input_channels,
+ num_filters,
+ num_layers,
+ max_num_filters,
+ activation_norm_type,
+ weight_norm_type):
+ super(NLayerPatchDiscriminator, self).__init__()
+ self.num_layers = num_layers
+ padding = int(np.floor((kernel_size - 1.0) / 2))
+ nonlinearity = 'leakyrelu'
+ base_conv2d_block = \
+ functools.partial(Conv2dBlock,
+ kernel_size=kernel_size,
+ padding=padding,
+ weight_norm_type=weight_norm_type,
+ activation_norm_type=activation_norm_type,
+ nonlinearity=nonlinearity,
+ # inplace_nonlinearity=True,
+ order='CNA')
+ layers = [[base_conv2d_block(
+ num_input_channels, num_filters, stride=2)]]
+ for n in range(num_layers):
+ num_filters_prev = num_filters
+ num_filters = min(num_filters * 2, max_num_filters)
+ stride = 2 if n < (num_layers - 1) else 1
+ layers += [[base_conv2d_block(num_filters_prev, num_filters,
+ stride=stride)]]
+ layers += [[Conv2dBlock(num_filters, 1,
+ 3, 1,
+ padding,
+ weight_norm_type=weight_norm_type)]]
+ for n in range(len(layers)):
+ setattr(self, 'layer' + str(n), nn.Sequential(*layers[n]))
+
+
+ def forward(self, input_x):
+ r"""Patch Discriminator forward.
+
+ Args:
+ input_x (N x C x H1 x W1 tensor): Concatenation of images and
+ semantic representations.
+ Returns:
+ (tuple):
+ - output (N x 1 x H2 x W2 tensor): Discriminator output value.
+ Before the sigmoid when using NSGAN.
+ - features (list): lists of tensors of the intermediate
+ activations.
+ """
+ res = [input_x]
+ for n in range(self.num_layers + 2):
+ layer = getattr(self, 'layer' + str(n))
+ x = res[-1]
+ res.append(layer(x))
+ output = res[-1]
+ features = res[1:-1]
+ return output, features
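+
+
+ if __name__ == '__main__':
+     # Minimal smoke-test sketch (illustrative, not part of the upstream
+     # Imaginaire sources); it assumes imaginaire.layers.Conv2dBlock accepts
+     # the 'none'/'spectral' norm settings used elsewhere in this repository.
+     disc = NLayerPatchDiscriminator(
+         kernel_size=3,
+         num_input_channels=3,
+         num_filters=64,
+         num_layers=4,
+         max_num_filters=512,
+         activation_norm_type='none',
+         weight_norm_type='spectral')
+     logits, feats = disc(torch.randn(2, 3, 256, 256))
+     # `logits` is an N x 1 x H' x W' patch map; `feats` holds the activations
+     # of every layer except the final 1-channel convolution.
+     print(logits.shape, len(feats))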
diff --git a/imaginaire/discriminators/munit.py b/imaginaire/discriminators/munit.py
new file mode 100644
index 0000000000000000000000000000000000000000..4e407569764dced0962e66c56a1a0b2e9106c683
--- /dev/null
+++ b/imaginaire/discriminators/munit.py
@@ -0,0 +1,99 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+from torch import nn
+
+from imaginaire.discriminators.multires_patch import MultiResPatchDiscriminator
+from imaginaire.discriminators.residual import ResDiscriminator
+
+
+class Discriminator(nn.Module):
+ r"""MUNIT discriminator. It can be either a multi-resolution patch
+ discriminator like in the original implementation, or a
+ global residual discriminator.
+
+ Args:
+ dis_cfg (obj): Discriminator definition part of the yaml config file.
+ data_cfg (obj): Data definition part of the yaml config file
+ """
+
+ def __init__(self, dis_cfg, data_cfg):
+ super().__init__()
+ if getattr(dis_cfg, 'patch_wise', True):
+ # Use the multi-resolution patch discriminator. It works better for
+ # scene images and when you want to preserve pixel-wise
+ # correspondence during translation.
+ self.discriminator_a = \
+ MultiResPatchDiscriminator(**vars(dis_cfg))
+ self.discriminator_b = \
+ MultiResPatchDiscriminator(**vars(dis_cfg))
+ else:
+ # Use the global residual discriminator. It works better if images
+ # have a single centered object (e.g., animal faces, shoes).
+ self.discriminator_a = ResDiscriminator(**vars(dis_cfg))
+ self.discriminator_b = ResDiscriminator(**vars(dis_cfg))
+
+ def forward(self, data, net_G_output, gan_recon=False, real=True):
+ r"""Returns the output of the discriminator.
+
+ Args:
+ data (dict):
+ - images_a (tensor) : Images in domain A.
+ - images_b (tensor) : Images in domain B.
+ net_G_output (dict):
+ - images_ab (tensor) : Images translated from domain A to B by
+ the generator.
+ - images_ba (tensor) : Images translated from domain B to A by
+ the generator.
+ - images_aa (tensor) : Reconstructed images in domain A.
+ - images_bb (tensor) : Reconstructed images in domain B.
+ gan_recon (bool): If ``True``, also classifies reconstructed images.
+ real (bool): If ``True``, also classifies real images. Otherwise it
+ only classifies generated images to save computation during the
+ generator update.
+
+ Returns:
+ (dict):
+ - out_ab (tensor): Output of the discriminator for images
+ translated from domain A to B by the generator.
+ - out_ba (tensor): Output of the discriminator for images
+ translated from domain B to A by the generator.
+ - fea_ab (tensor): Intermediate features of the discriminator
+ for images translated from domain A to B by the generator.
+ - fea_ba (tensor): Intermediate features of the discriminator
+ for images translated from domain B to A by the generator.
+
+ - out_a (tensor): Output of the discriminator for images
+ in domain A.
+ - out_b (tensor): Output of the discriminator for images
+ in domain B.
+ - fea_a (tensor): Intermediate features of the discriminator
+ for images in domain A.
+ - fea_b (tensor): Intermediate features of the discriminator
+ for images in domain B.
+
+ - out_aa (tensor): Output of the discriminator for
+ reconstructed images in domain A.
+ - out_bb (tensor): Output of the discriminator for
+ reconstructed images in domain B.
+ - fea_aa (tensor): Intermediate features of the discriminator
+ for reconstructed images in domain A.
+ - fea_bb (tensor): Intermediate features of the discriminator
+ for reconstructed images in domain B.
+ """
+ out_ab, fea_ab, _ = self.discriminator_b(net_G_output['images_ab'])
+ out_ba, fea_ba, _ = self.discriminator_a(net_G_output['images_ba'])
+ output = dict(out_ba=out_ba, out_ab=out_ab,
+ fea_ba=fea_ba, fea_ab=fea_ab)
+ if real:
+ out_a, fea_a, _ = self.discriminator_a(data['images_a'])
+ out_b, fea_b, _ = self.discriminator_b(data['images_b'])
+ output.update(dict(out_a=out_a, out_b=out_b,
+ fea_a=fea_a, fea_b=fea_b))
+ if gan_recon:
+ out_aa, fea_aa, _ = self.discriminator_a(net_G_output['images_aa'])
+ out_bb, fea_bb, _ = self.discriminator_b(net_G_output['images_bb'])
+ output.update(dict(out_aa=out_aa, out_bb=out_bb,
+ fea_aa=fea_aa, fea_bb=fea_bb))
+ return output
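+
+
+ # Configuration sketch (illustrative, not part of the upstream code). The
+ # constructor only needs attribute access plus vars() on `dis_cfg`, so a
+ # SimpleNamespace is enough; the field values below are assumptions.
+ #
+ #   from types import SimpleNamespace
+ #   dis_cfg = SimpleNamespace(patch_wise=True, num_discriminators=3,
+ #                             num_filters=48, num_layers=4)
+ #   net_D = Discriminator(dis_cfg, data_cfg=None)
+ #   # With patch_wise=False the same call builds two ResDiscriminators instead.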
diff --git a/imaginaire/discriminators/residual.py b/imaginaire/discriminators/residual.py
new file mode 100644
index 0000000000000000000000000000000000000000..f65b41df96f8ea25e82d18adf33cabef746f863c
--- /dev/null
+++ b/imaginaire/discriminators/residual.py
@@ -0,0 +1,96 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import warnings
+
+import torch
+import torch.nn as nn
+
+from imaginaire.layers import Conv2dBlock, Res2dBlock
+from imaginaire.third_party.upfirdn2d import BlurDownsample
+
+
+class ResDiscriminator(nn.Module):
+ r"""Global residual discriminator.
+
+ Args:
+ image_channels (int): Num. of channels in the real/fake image.
+ num_filters (int): Num. of base filters in a layer.
+ max_num_filters (int): Maximum num. of filters in a layer.
+ first_kernel_size (int): Kernel size in the first layer.
+ num_layers (int): Num. of layers in discriminator.
+ padding_mode (str): Padding mode.
+ activation_norm_type (str): Type of activation normalization.
+ ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``.
+ weight_norm_type (str): Type of weight normalization.
+ ``'none'``, ``'spectral'``, or ``'weight'``.
+ aggregation (str): Method to aggregate features across different
+ locations in the final layer. ``'conv'``, or ``'pool'``.
+ order (str): Order of operations in the residual link.
+ anti_aliased (bool): If ``True``, uses anti-aliased pooling.
+ """
+
+ def __init__(self,
+ image_channels=3,
+ num_filters=64,
+ max_num_filters=512,
+ first_kernel_size=1,
+ num_layers=4,
+ padding_mode='zeros',
+ activation_norm_type='',
+ weight_norm_type='',
+ aggregation='conv',
+ order='pre_act',
+ anti_aliased=False,
+ **kwargs):
+ super().__init__()
+ for key in kwargs:
+ if key != 'type' and key != 'patch_wise':
+ warnings.warn(
+ "Discriminator argument {} is not used".format(key))
+
+ conv_params = dict(padding_mode=padding_mode,
+ activation_norm_type=activation_norm_type,
+ weight_norm_type=weight_norm_type,
+ nonlinearity='leakyrelu')
+
+ first_padding = (first_kernel_size - 1) // 2
+ model = [Conv2dBlock(image_channels, num_filters,
+ first_kernel_size, 1, first_padding,
+ **conv_params)]
+ for _ in range(num_layers):
+ num_filters_prev = num_filters
+ num_filters = min(num_filters * 2, max_num_filters)
+ model.append(Res2dBlock(num_filters_prev, num_filters, order=order,
+ **conv_params))
+ if anti_aliased:
+ model.append(BlurDownsample())
+ else:
+ model.append(nn.AvgPool2d(2, stride=2))
+ if aggregation == 'pool':
+ model += [torch.nn.AdaptiveAvgPool2d(1)]
+ elif aggregation == 'conv':
+ model += [Conv2dBlock(num_filters, num_filters, 4, 1, 0,
+ nonlinearity='leakyrelu')]
+ else:
+ raise ValueError('The aggregation mode %s is not recognized.'
+ % aggregation)
+ self.model = nn.Sequential(*model)
+ self.classifier = nn.Linear(num_filters, 1)
+
+ def forward(self, images):
+ r"""Multi-resolution patch discriminator forward.
+
+ Args:
+ images (tensor) : Input images.
+ Returns:
+ (tuple):
+ - outputs (tensor): Output of the discriminator.
+ - features (tensor): Intermediate features of the discriminator.
+ - images (tensor): Input images.
+ """
+ batch_size = images.size(0)
+ features = self.model(images)
+ outputs = self.classifier(features.view(batch_size, -1))
+ return outputs, features, images
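+
+
+ if __name__ == '__main__':
+     # Minimal smoke-test sketch (illustrative, not part of the upstream
+     # sources). With the default 'conv' aggregation the trailing Linear needs
+     # the 4x4 aggregation conv to reduce the feature map to 1x1, so 'pool' is
+     # used here to stay input-size agnostic.
+     disc = ResDiscriminator(image_channels=3, num_filters=64, num_layers=4,
+                             aggregation='pool')
+     outputs, features, images = disc(torch.randn(2, 3, 256, 256))
+     print(outputs.shape)  # expected: torch.Size([2, 1])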
diff --git a/imaginaire/discriminators/spade.py b/imaginaire/discriminators/spade.py
new file mode 100644
index 0000000000000000000000000000000000000000..d85d1c5926ce92e948501c7afaf1d1324e1ea38c
--- /dev/null
+++ b/imaginaire/discriminators/spade.py
@@ -0,0 +1,119 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import torch
+import torch.nn as nn
+
+from imaginaire.discriminators.fpse import FPSEDiscriminator
+from imaginaire.discriminators.multires_patch import NLayerPatchDiscriminator
+from imaginaire.utils.data import (get_paired_input_image_channel_number,
+ get_paired_input_label_channel_number)
+from imaginaire.utils.distributed import master_only_print as print
+
+
+class Discriminator(nn.Module):
+ r"""Multi-resolution patch discriminator.
+
+ Args:
+ dis_cfg (obj): Discriminator definition part of the yaml config file.
+ data_cfg (obj): Data definition part of the yaml config file.
+ """
+
+ def __init__(self, dis_cfg, data_cfg):
+ super(Discriminator, self).__init__()
+ print('Multi-resolution patch discriminator initialization.')
+ image_channels = getattr(dis_cfg, 'image_channels', None)
+ if image_channels is None:
+ image_channels = get_paired_input_image_channel_number(data_cfg)
+ num_labels = getattr(dis_cfg, 'num_labels', None)
+ if num_labels is None:
+ # Calculate number of channels in the input label when not specified.
+ num_labels = get_paired_input_label_channel_number(data_cfg)
+
+ # Build the discriminator.
+ kernel_size = getattr(dis_cfg, 'kernel_size', 3)
+ num_filters = getattr(dis_cfg, 'num_filters', 128)
+ max_num_filters = getattr(dis_cfg, 'max_num_filters', 512)
+ num_discriminators = getattr(dis_cfg, 'num_discriminators', 2)
+ num_layers = getattr(dis_cfg, 'num_layers', 5)
+ activation_norm_type = getattr(dis_cfg, 'activation_norm_type', 'none')
+ weight_norm_type = getattr(dis_cfg, 'weight_norm_type', 'spectral')
+ print('\tBase filter number: %d' % num_filters)
+ print('\tNumber of discriminators: %d' % num_discriminators)
+ print('\tNumber of layers in a discriminator: %d' % num_layers)
+ print('\tWeight norm type: %s' % weight_norm_type)
+ num_input_channels = image_channels + num_labels
+ self.discriminators = nn.ModuleList()
+ for i in range(num_discriminators):
+ net_discriminator = NLayerPatchDiscriminator(
+ kernel_size,
+ num_input_channels,
+ num_filters,
+ num_layers,
+ max_num_filters,
+ activation_norm_type,
+ weight_norm_type)
+ self.discriminators.append(net_discriminator)
+ print('Done with the Multi-resolution patch discriminator initialization.')
+ self.use_fpse = getattr(dis_cfg, 'use_fpse', True)
+ if self.use_fpse:
+ fpse_kernel_size = getattr(dis_cfg, 'fpse_kernel_size', 3)
+ fpse_activation_norm_type = getattr(dis_cfg,
+ 'fpse_activation_norm_type',
+ 'none')
+ self.fpse_discriminator = FPSEDiscriminator(
+ image_channels,
+ num_labels,
+ num_filters,
+ fpse_kernel_size,
+ weight_norm_type,
+ fpse_activation_norm_type)
+
+ def _single_forward(self, input_label, input_image):
+ # Compute discriminator outputs and intermediate features from input
+ # images and semantic labels.
+ input_x = torch.cat(
+ (input_label, input_image), 1)
+ output_list = []
+ features_list = []
+ if self.use_fpse:
+ pred2, pred3, pred4 = self.fpse_discriminator(input_image, input_label)
+ output_list = [pred2, pred3, pred4]
+ input_downsampled = input_x
+ for net_discriminator in self.discriminators:
+ output, features = net_discriminator(input_downsampled)
+ output_list.append(output)
+ features_list.append(features)
+ input_downsampled = nn.functional.interpolate(
+ input_downsampled, scale_factor=0.5, mode='bilinear',
+ align_corners=True)
+ return output_list, features_list
+
+ def forward(self, data, net_G_output):
+ r"""SPADE discriminator forward.
+
+ Args:
+ data (dict):
+ - images (N x C1 x H x W tensor) : Ground truth images.
+ - label (N x C2 x H x W tensor) : Semantic representations.
+ - z (N x style_dims tensor): Gaussian random noise.
+ net_G_output (dict):
+ fake_images (N x C1 x H x W tensor) : Fake images.
+ Returns:
+ (dict):
+ - real_outputs (list): list of output tensors produced by
+ individual patch discriminators for real images.
+ - real_features (list): list of lists of features produced by
+ individual patch discriminators for real images.
+ - fake_outputs (list): list of output tensors produced by
+ individual patch discriminators for fake images.
+ - fake_features (list): list of lists of features produced by
+ individual patch discriminators for fake images.
+ """
+ output_x = dict()
+ output_x['real_outputs'], output_x['real_features'] = \
+ self._single_forward(data['label'], data['images'])
+ output_x['fake_outputs'], output_x['fake_features'] = \
+ self._single_forward(data['label'], net_G_output['fake_images'])
+ return output_x
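+
+
+ # Configuration sketch (illustrative): the field names mirror the getattr()
+ # defaults above; the concrete values and the YAML layout are assumptions.
+ #
+ #   dis:
+ #       type: imaginaire.discriminators.spade
+ #       kernel_size: 3
+ #       num_filters: 128
+ #       num_discriminators: 2
+ #       num_layers: 5
+ #       activation_norm_type: none
+ #       weight_norm_type: spectral
+ #       use_fpse: True
+ #   # image_channels and num_labels are inferred from data_cfg when omitted.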
diff --git a/imaginaire/discriminators/unit.py b/imaginaire/discriminators/unit.py
new file mode 100644
index 0000000000000000000000000000000000000000..fb537fe47ea17051a012f7e647c5ac6b6ea1b7d9
--- /dev/null
+++ b/imaginaire/discriminators/unit.py
@@ -0,0 +1,99 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+from torch import nn
+
+from imaginaire.discriminators.multires_patch import \
+ WeightSharedMultiResPatchDiscriminator
+from imaginaire.discriminators.residual import ResDiscriminator
+
+
+class Discriminator(nn.Module):
+ r"""UNIT discriminator. It can be either a multi-resolution patch
+ discriminator like in the original implementation, or a
+ global residual discriminator.
+
+ Args:
+ dis_cfg (obj): Discriminator definition part of the yaml config file.
+ data_cfg (obj): Data definition part of the yaml config file
+ """
+
+ def __init__(self, dis_cfg, data_cfg):
+ super().__init__()
+ if getattr(dis_cfg, 'patch_dis', True):
+ # Use the multi-resolution patch discriminator. It works better for
+ # scene images and when you want to preserve pixel-wise
+ # correspondence during translation.
+ self.discriminator_a = \
+ WeightSharedMultiResPatchDiscriminator(**vars(dis_cfg))
+ self.discriminator_b = \
+ WeightSharedMultiResPatchDiscriminator(**vars(dis_cfg))
+ else:
+ # Use the global residual discriminator. It works better if images
+ # have a single centered object (e.g., animal faces, shoes).
+ self.discriminator_a = ResDiscriminator(**vars(dis_cfg))
+ self.discriminator_b = ResDiscriminator(**vars(dis_cfg))
+
+ def forward(self, data, net_G_output, gan_recon=False, real=True):
+ r"""Returns the output of the discriminator.
+
+ Args:
+ data (dict):
+ - images_a (tensor) : Images in domain A.
+ - images_b (tensor) : Images in domain B.
+ net_G_output (dict):
+ - images_ab (tensor) : Images translated from domain A to B by
+ the generator.
+ - images_ba (tensor) : Images translated from domain B to A by
+ the generator.
+ - images_aa (tensor) : Reconstructed images in domain A.
+ - images_bb (tensor) : Reconstructed images in domain B.
+ gan_recon (bool): If ``True``, also classifies reconstructed images.
+ real (bool): If ``True``, also classifies real images. Otherwise it
+ only classifies generated images to save computation during the
+ generator update.
+ Returns:
+ (dict):
+ - out_ab (tensor): Output of the discriminator for images
+ translated from domain A to B by the generator.
+ - out_ba (tensor): Output of the discriminator for images
+ translated from domain B to A by the generator.
+ - fea_ab (tensor): Intermediate features of the discriminator
+ for images translated from domain A to B by the generator.
+ - fea_ba (tensor): Intermediate features of the discriminator
+ for images translated from domain B to A by the generator.
+
+ - out_a (tensor): Output of the discriminator for images
+ in domain A.
+ - out_b (tensor): Output of the discriminator for images
+ in domain B.
+ - fea_a (tensor): Intermediate features of the discriminator
+ for images in domain A.
+ - fea_b (tensor): Intermediate features of the discriminator
+ for images in domain B.
+
+ - out_aa (tensor): Output of the discriminator for
+ reconstructed images in domain A.
+ - out_bb (tensor): Output of the discriminator for
+ reconstructed images in domain B.
+ - fea_aa (tensor): Intermediate features of the discriminator
+ for reconstructed images in domain A.
+ - fea_bb (tensor): Intermediate features of the discriminator
+ for reconstructed images in domain B.
+ """
+ out_ab, fea_ab, _ = self.discriminator_b(net_G_output['images_ab'])
+ out_ba, fea_ba, _ = self.discriminator_a(net_G_output['images_ba'])
+ output = dict(out_ba=out_ba, out_ab=out_ab,
+ fea_ba=fea_ba, fea_ab=fea_ab)
+ if real:
+ out_a, fea_a, _ = self.discriminator_a(data['images_a'])
+ out_b, fea_b, _ = self.discriminator_b(data['images_b'])
+ output.update(dict(out_a=out_a, out_b=out_b,
+ fea_a=fea_a, fea_b=fea_b))
+ if gan_recon:
+ out_aa, fea_aa, _ = self.discriminator_a(net_G_output['images_aa'])
+ out_bb, fea_bb, _ = self.discriminator_b(net_G_output['images_bb'])
+ output.update(dict(out_aa=out_aa, out_bb=out_bb,
+ fea_aa=fea_aa, fea_bb=fea_bb))
+ return output
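+
+
+ # Note (added for clarity): unlike the MUNIT discriminator, which reads
+ # dis_cfg.patch_wise, this class switches on dis_cfg.patch_dis; both default
+ # to the multi-resolution patch discriminator when the flag is absent.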
diff --git a/imaginaire/evaluation/__init__.py b/imaginaire/evaluation/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8a379a7be63550921dcd5802e7621196df9de9e1
--- /dev/null
+++ b/imaginaire/evaluation/__init__.py
@@ -0,0 +1,11 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+from .fid import compute_fid, compute_fid_data
+from .kid import compute_kid, compute_kid_data
+from .prdc import compute_prdc
+from .common import compute_all_metrics, compute_all_metrics_data
+
+__all__ = ['compute_fid', 'compute_fid_data', 'compute_kid', 'compute_kid_data',
+ 'compute_prdc', 'compute_all_metrics', 'compute_all_metrics_data']
diff --git a/imaginaire/evaluation/caption/__init__.py b/imaginaire/evaluation/caption/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3197a51d0318c393a6659495d97914b1192f587a
--- /dev/null
+++ b/imaginaire/evaluation/caption/__init__.py
@@ -0,0 +1,8 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+from .r_precision import get_r_precision
+from .common import get_image_encoder
+
+__all__ = ['get_image_encoder', 'get_r_precision']
diff --git a/imaginaire/evaluation/caption/clip.py b/imaginaire/evaluation/caption/clip.py
new file mode 100644
index 0000000000000000000000000000000000000000..1cdcd94e80afcb310564d9d670c3f73f5e065707
--- /dev/null
+++ b/imaginaire/evaluation/caption/clip.py
@@ -0,0 +1,576 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+# flake8: noqa
+# https://github.com/openai/CLIP
+import hashlib
+import os
+import urllib
+import warnings
+from time import sleep
+from typing import Union, List
+
+
+from collections import OrderedDict
+from typing import Tuple, Union
+
+import torch
+import numpy as np
+import torch.nn.functional as F
+from torch import nn
+from PIL import Image
+from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, \
+ Normalize
+from tqdm import tqdm
+
+__all__ = ["available_models", "load", 'build_model']
+
+from imaginaire.utils.io import download_file_from_google_drive
+
+_MODELS = {
+ "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
+ "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
+ "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
+ "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
+}
+
+
+def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")):
+ os.makedirs(root, exist_ok=True)
+ filename = os.path.basename(url)
+
+ expected_sha256 = url.split("/")[-2]
+ download_target = os.path.join(root, filename)
+
+ if os.path.exists(download_target) and not os.path.isfile(download_target):
+ raise RuntimeError(
+ f"{download_target} exists and is not a regular file")
+
+ if os.path.isfile(download_target):
+ if hashlib.sha256(open(download_target,
+ "rb").read()).hexdigest() == expected_sha256:
+ return download_target
+ else:
+ warnings.warn(
+ f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
+
+ with urllib.request.urlopen(url) as source, open(download_target,
+ "wb") as output:
+ with tqdm(total=int(source.info().get("Content-Length")), ncols=80,
+ unit='iB', unit_scale=True) as loop:
+ while True:
+ buffer = source.read(8192)
+ if not buffer:
+ break
+
+ output.write(buffer)
+ loop.update(len(buffer))
+
+ if hashlib.sha256(
+ open(download_target, "rb").read()).hexdigest() != expected_sha256:
+ raise RuntimeError(
+ f"Model has been downloaded but the SHA256 checksum does not not match")
+
+ return download_target
+
+
+def _transform(n_px):
+ return Compose([
+ Resize(n_px, interpolation=Image.BICUBIC),
+ CenterCrop(n_px),
+ lambda image: image.convert("RGB"),
+ ToTensor(),
+ Normalize((0.48145466, 0.4578275, 0.40821073),
+ (0.26862954, 0.26130258, 0.27577711)),
+ ])
+
+
+def available_models() -> List[str]:
+ """Returns the names of available CLIP models"""
+ return list(_MODELS.keys())
+
+
+def load(model_path):
+ if not os.path.exists(model_path):
+ downloaded = False
+ while not downloaded:
+ try:
+ download_file_from_google_drive("1Ri5APYM34A_IjG4F3Admutsf2oUwDjfW", model_path)
+ downloaded = True
+ except Exception as e:
+ print(e)
+ sleep(30)
+ continue
+ model = torch.load(model_path, map_location='cpu')
+ model = build_model(model).cuda()
+ return model, _transform(model.visual.input_resolution)
+
+
+class Bottleneck(nn.Module):
+ expansion = 4
+
+ def __init__(self, inplanes, planes, stride=1):
+ super().__init__()
+
+ # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
+ self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
+ self.bn1 = nn.BatchNorm2d(planes)
+
+ self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
+ self.bn2 = nn.BatchNorm2d(planes)
+
+ self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
+
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
+
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = None
+ self.stride = stride
+
+ if stride > 1 or inplanes != planes * Bottleneck.expansion:
+ # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
+ self.downsample = nn.Sequential(OrderedDict([
+ ("-1", nn.AvgPool2d(stride)),
+ ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1,
+ bias=False)),
+ ("1", nn.BatchNorm2d(planes * self.expansion))
+ ]))
+
+ def forward(self, x: torch.Tensor):
+ identity = x
+
+ out = self.relu(self.bn1(self.conv1(x)))
+ out = self.relu(self.bn2(self.conv2(out)))
+ out = self.avgpool(out)
+ out = self.bn3(self.conv3(out))
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out += identity
+ out = self.relu(out)
+ return out
+
+
+class AttentionPool2d(nn.Module):
+ def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int,
+ output_dim: int = None):
+ super().__init__()
+ self.positional_embedding = nn.Parameter(
+ torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
+ self.k_proj = nn.Linear(embed_dim, embed_dim)
+ self.q_proj = nn.Linear(embed_dim, embed_dim)
+ self.v_proj = nn.Linear(embed_dim, embed_dim)
+ self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
+ self.num_heads = num_heads
+
+ def forward(self, x):
+ x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(
+ 2, 0, 1) # NCHW -> (HW)NC
+ x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
+ x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
+ x, _ = F.multi_head_attention_forward(
+ query=x, key=x, value=x,
+ embed_dim_to_check=x.shape[-1],
+ num_heads=self.num_heads,
+ q_proj_weight=self.q_proj.weight,
+ k_proj_weight=self.k_proj.weight,
+ v_proj_weight=self.v_proj.weight,
+ in_proj_weight=None,
+ in_proj_bias=torch.cat(
+ [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
+ bias_k=None,
+ bias_v=None,
+ add_zero_attn=False,
+ dropout_p=0,
+ out_proj_weight=self.c_proj.weight,
+ out_proj_bias=self.c_proj.bias,
+ use_separate_proj_weight=True,
+ training=self.training,
+ need_weights=False
+ )
+
+ return x[0]
+
+
+class ModifiedResNet(nn.Module):
+ """
+ A ResNet class that is similar to torchvision's but contains the following changes:
+ - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
+ - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
+ - The final pooling layer is a QKV attention instead of an average pool
+ """
+
+ def __init__(self, layers, output_dim, heads, input_resolution=224,
+ width=64):
+ super().__init__()
+ self.output_dim = output_dim
+ self.input_resolution = input_resolution
+
+ # the 3-layer stem
+ self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2,
+ padding=1, bias=False)
+ self.bn1 = nn.BatchNorm2d(width // 2)
+ self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1,
+ bias=False)
+ self.bn2 = nn.BatchNorm2d(width // 2)
+ self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1,
+ bias=False)
+ self.bn3 = nn.BatchNorm2d(width)
+ self.avgpool = nn.AvgPool2d(2)
+ self.relu = nn.ReLU(inplace=True)
+
+ # residual layers
+ self._inplanes = width # this is a *mutable* variable used during construction
+ self.layer1 = self._make_layer(width, layers[0])
+ self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
+ self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
+ self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
+
+ embed_dim = width * 32 # the ResNet feature dimension
+ self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim,
+ heads, output_dim)
+
+ def _make_layer(self, planes, blocks, stride=1):
+ layers = [Bottleneck(self._inplanes, planes, stride)]
+
+ self._inplanes = planes * Bottleneck.expansion
+ for _ in range(1, blocks):
+ layers.append(Bottleneck(self._inplanes, planes))
+
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ def stem(x):
+ for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2),
+ (self.conv3, self.bn3)]:
+ x = self.relu(bn(conv(x)))
+ x = self.avgpool(x)
+ return x
+
+ x = x.type(self.conv1.weight.dtype)
+ x = stem(x)
+ x = self.layer1(x)
+ x = self.layer2(x)
+ x = self.layer3(x)
+ x = self.layer4(x)
+ x = self.attnpool(x)
+
+ return x
+
+
+class LayerNorm(nn.LayerNorm):
+ """Subclass torch's LayerNorm to handle fp16."""
+
+ def forward(self, x: torch.Tensor):
+ orig_type = x.dtype
+ ret = super().forward(x.type(torch.float32))
+ return ret.type(orig_type)
+
+
+class QuickGELU(nn.Module):
+ def forward(self, x: torch.Tensor):
+ return x * torch.sigmoid(1.702 * x)
+
+
+class ResidualAttentionBlock(nn.Module):
+ def __init__(self, d_model: int, n_head: int,
+ attn_mask: torch.Tensor = None):
+ super().__init__()
+
+ self.attn = nn.MultiheadAttention(d_model, n_head)
+ self.ln_1 = LayerNorm(d_model)
+ self.mlp = nn.Sequential(OrderedDict([
+ ("c_fc", nn.Linear(d_model, d_model * 4)),
+ ("gelu", QuickGELU()),
+ ("c_proj", nn.Linear(d_model * 4, d_model))
+ ]))
+ self.ln_2 = LayerNorm(d_model)
+ self.attn_mask = attn_mask
+
+ def attention(self, x: torch.Tensor):
+ self.attn_mask = (self.attn_mask.to(dtype=x.dtype, device=x.device)
+ if self.attn_mask is not None else None)
+ return self.attn(x, x, x, need_weights=False,
+ attn_mask=self.attn_mask)[0]
+
+ def forward(self, x: torch.Tensor):
+ x = x + self.attention(self.ln_1(x))
+ x = x + self.mlp(self.ln_2(x))
+ return x
+
+
+class Transformer(nn.Module):
+ def __init__(self, width: int, layers: int, heads: int,
+ attn_mask: torch.Tensor = None):
+ super().__init__()
+ self.width = width
+ self.layers = layers
+ self.resblocks = nn.Sequential(
+ *[ResidualAttentionBlock(width, heads, attn_mask) for _ in
+ range(layers)])
+
+ def forward(self, x: torch.Tensor):
+ return self.resblocks(x)
+
+
+class VisualTransformer(nn.Module):
+ def __init__(self, input_resolution: int, patch_size: int, width: int,
+ layers: int, heads: int, output_dim: int):
+ super().__init__()
+ self.input_resolution = input_resolution
+ self.output_dim = output_dim
+ self.conv1 = nn.Conv2d(in_channels=3, out_channels=width,
+ kernel_size=patch_size, stride=patch_size,
+ bias=False)
+
+ scale = width ** -0.5
+ self.class_embedding = nn.Parameter(scale * torch.randn(width))
+ self.positional_embedding = nn.Parameter(
+ scale * torch.randn((input_resolution // patch_size) ** 2 + 1,
+ width))
+ self.ln_pre = LayerNorm(width)
+
+ self.transformer = Transformer(width, layers, heads)
+
+ self.ln_post = LayerNorm(width)
+ self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
+
+ def forward(self, x: torch.Tensor):
+ x = self.conv1(x) # shape = [*, width, grid, grid]
+ x = x.reshape(x.shape[0], x.shape[1],
+ -1) # shape = [*, width, grid ** 2]
+ x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
+ x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(
+ x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x],
+ dim=1) # shape = [*, grid ** 2 + 1, width]
+ x = x + self.positional_embedding.to(x.dtype)
+ x = self.ln_pre(x)
+
+ x = x.permute(1, 0, 2) # NLD -> LND
+ x = self.transformer(x)
+ x = x.permute(1, 0, 2) # LND -> NLD
+
+ x = self.ln_post(x[:, 0, :])
+
+ if self.proj is not None:
+ x = x @ self.proj
+
+ return x
+
+
+class CLIP(nn.Module):
+ def __init__(self,
+ embed_dim: int,
+ # vision
+ image_resolution: int,
+ vision_layers: Union[Tuple[int, int, int, int], int],
+ vision_width: int,
+ vision_patch_size: int,
+ # text
+ context_length: int,
+ vocab_size: int,
+ transformer_width: int,
+ transformer_heads: int,
+ transformer_layers: int
+ ):
+ super().__init__()
+
+ self.context_length = context_length
+
+ if isinstance(vision_layers, (tuple, list)):
+ vision_heads = vision_width * 32 // 64
+ self.visual = ModifiedResNet(
+ layers=vision_layers,
+ output_dim=embed_dim,
+ heads=vision_heads,
+ input_resolution=image_resolution,
+ width=vision_width
+ )
+ else:
+ vision_heads = vision_width // 64
+ self.visual = VisualTransformer(
+ input_resolution=image_resolution,
+ patch_size=vision_patch_size,
+ width=vision_width,
+ layers=vision_layers,
+ heads=vision_heads,
+ output_dim=embed_dim
+ )
+
+ self.transformer = Transformer(
+ width=transformer_width,
+ layers=transformer_layers,
+ heads=transformer_heads,
+ attn_mask=self.build_attention_mask()
+ )
+
+ self.vocab_size = vocab_size
+ self.token_embedding = nn.Embedding(vocab_size, transformer_width)
+ self.positional_embedding = nn.Parameter(
+ torch.empty(self.context_length, transformer_width))
+ self.ln_final = LayerNorm(transformer_width)
+
+ self.text_projection = nn.Parameter(
+ torch.empty(transformer_width, embed_dim))
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
+
+ self.initialize_parameters()
+
+ def initialize_parameters(self):
+ nn.init.normal_(self.token_embedding.weight, std=0.02)
+ nn.init.normal_(self.positional_embedding, std=0.01)
+
+ if isinstance(self.visual, ModifiedResNet):
+ if self.visual.attnpool is not None:
+ std = self.visual.attnpool.c_proj.in_features ** -0.5
+ nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
+ nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
+ nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
+ nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
+
+ for resnet_block in [self.visual.layer1, self.visual.layer2,
+ self.visual.layer3, self.visual.layer4]:
+ for name, param in resnet_block.named_parameters():
+ if name.endswith("bn3.weight"):
+ nn.init.zeros_(param)
+
+ proj_std = (self.transformer.width ** -0.5) * (
+ (2 * self.transformer.layers) ** -0.5)
+ attn_std = self.transformer.width ** -0.5
+ fc_std = (2 * self.transformer.width) ** -0.5
+ for block in self.transformer.resblocks:
+ nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
+ nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
+ nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
+ nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
+
+ if self.text_projection is not None:
+ nn.init.normal_(self.text_projection,
+ std=self.transformer.width ** -0.5)
+
+ def build_attention_mask(self):
+ # lazily create causal attention mask, with full attention between the vision tokens
+ # pytorch uses additive attention mask; fill with -inf
+ mask = torch.empty(self.context_length, self.context_length)
+ mask.fill_(float("-inf"))
+ mask.triu_(1) # zero out the lower diagonal
+ return mask
+
+ @property
+ def dtype(self):
+ return self.visual.conv1.weight.dtype
+
+ def encode_image(self, image):
+ return self.visual(image.type(self.dtype))
+
+ def encode_text(self, text):
+ x = self.token_embedding(text).type(
+ self.dtype) # [batch_size, n_ctx, d_model]
+
+ x = x + self.positional_embedding.type(self.dtype)
+ x = x.permute(1, 0, 2) # NLD -> LND
+ x = self.transformer(x)
+ x = x.permute(1, 0, 2) # LND -> NLD
+ x = self.ln_final(x).type(self.dtype)
+
+ # x.shape = [batch_size, n_ctx, transformer.width]
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
+ x = x[torch.arange(x.shape[0]), text.argmax(
+ dim=-1)] @ self.text_projection
+
+ return x
+
+ def forward(self, image, text):
+ image_features = self.encode_image(image)
+ text_features = self.encode_text(text)
+
+ # normalized features
+ image_features = image_features / image_features.norm(dim=-1,
+ keepdim=True)
+ text_features = text_features / text_features.norm(dim=-1, keepdim=True)
+
+ # cosine similarity as logits
+ logit_scale = self.logit_scale.exp()
+ logits_per_image = logit_scale * image_features @ text_features.t()
+ logits_per_text = logit_scale * text_features @ image_features.t()
+
+ # shape = [global_batch_size, global_batch_size]
+ return logits_per_image, logits_per_text
+
+
+def convert_weights(model: nn.Module):
+ """Convert applicable model parameters to fp16"""
+
+ def _convert_weights_to_fp16(l):
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
+ l.weight.data = l.weight.data.half()
+ if l.bias is not None:
+ l.bias.data = l.bias.data.half()
+
+ if isinstance(l, nn.MultiheadAttention):
+ for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]],
+ "in_proj_bias", "bias_k", "bias_v"]:
+ tensor = getattr(l, attr)
+ if tensor is not None:
+ tensor.data = tensor.data.half()
+
+ for name in ["text_projection", "proj"]:
+ if hasattr(l, name):
+ attr = getattr(l, name)
+ if attr is not None:
+ attr.data = attr.data.half()
+
+ model.apply(_convert_weights_to_fp16)
+
+
+def build_model(state_dict: dict):
+ vit = "visual.proj" in state_dict
+
+ if vit:
+ vision_width = state_dict["visual.conv1.weight"].shape[0]
+ vision_layers = len([k for k in state_dict.keys() if
+ k.startswith("visual.") and k.endswith(
+ ".attn.in_proj_weight")])
+ vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
+ grid_size = round(
+ (state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
+ image_resolution = vision_patch_size * grid_size
+ else:
+ counts: list = [len(set(k.split(".")[2] for k in state_dict if
+ k.startswith(f"visual.layer{b}"))) for b in
+ [1, 2, 3, 4]]
+ vision_layers = tuple(counts)
+ vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
+ output_width = round(
+ (state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
+ vision_patch_size = None
+ assert output_width ** 2 + 1 == \
+ state_dict["visual.attnpool.positional_embedding"].shape[0]
+ image_resolution = output_width * 32
+
+ embed_dim = state_dict["text_projection"].shape[1]
+ context_length = state_dict["positional_embedding"].shape[0]
+ vocab_size = state_dict["token_embedding.weight"].shape[0]
+ transformer_width = state_dict["ln_final.weight"].shape[0]
+ transformer_heads = transformer_width // 64
+ transformer_layers = len(set(k.split(".")[2] for k in state_dict if
+ k.startswith(f"transformer.resblocks")))
+
+ model = CLIP(
+ embed_dim,
+ image_resolution, vision_layers, vision_width, vision_patch_size,
+ context_length, vocab_size, transformer_width, transformer_heads,
+ transformer_layers
+ )
+
+ for key in ["input_resolution", "context_length", "vocab_size"]:
+ if key in state_dict:
+ del state_dict[key]
+
+ convert_weights(model)
+ model.load_state_dict(state_dict)
+ return model.eval()
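+
+
+ # Usage sketch (illustrative, not part of the upstream code). load() fetches
+ # the ViT-B/32 checkpoint when the file is missing and requires a CUDA device;
+ # 'example.jpg' is a placeholder path.
+ #
+ #   model, preprocess = load(os.path.expanduser('~/.cache/clip/ViT-B-32.pt'))
+ #   image = preprocess(Image.open('example.jpg')).unsqueeze(0).cuda()
+ #   with torch.no_grad():
+ #       features = model.encode_image(image)  # (1, 512) for ViT-B/32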
diff --git a/imaginaire/evaluation/caption/common.py b/imaginaire/evaluation/caption/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..ed54e55335b9b9684eab2fc58cbe14e810597dd6
--- /dev/null
+++ b/imaginaire/evaluation/caption/common.py
@@ -0,0 +1,57 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import os
+
+import boto3
+import torch
+from torch import nn, distributed as dist
+from torch.nn import functional as F
+from torch.distributed import barrier
+
+from imaginaire.utils.distributed import is_local_master
+from .clip import build_model
+from ...utils.io import download_file_from_google_drive
+
+
+def get_image_encoder(aws_credentials=None):
+ if dist.is_initialized() and not is_local_master():
+ # Make sure only the first process in distributed training downloads the model, and the others use the cache.
+ barrier()
+
+ # Load the CLIP image encoder.
+ print("Loading CLIP image encoder.")
+ model_path = os.path.join(torch.hub.get_dir(), 'checkpoints', 'ViT-B-32.pt')
+ if not os.path.exists(model_path):
+ if aws_credentials is not None:
+ s3 = boto3.client('s3', **aws_credentials)
+ s3.download_file('lpi-poe', 'model_zoo/ViT-B-32.pt', model_path)
+ else:
+ download_file_from_google_drive("1Ri5APYM34A_IjG4F3Admutsf2oUwDjfW", model_path)
+ model = torch.load(model_path, map_location='cpu')
+
+ if dist.is_initialized() and is_local_master():
+ # Make sure only the first process in distributed training downloads the model, and the others use the cache.
+ barrier()
+
+ encoder = build_model(model).cuda()
+ return ImageEncoder(encoder)
+
+
+class ImageEncoder(nn.Module):
+ def __init__(self, encoder):
+ super().__init__()
+ self.model = encoder
+ self.image_size = self.model.visual.input_resolution
+ self.mean = torch.tensor([0.48145466, 0.4578275, 0.40821073], device="cuda")
+ self.std = torch.tensor([0.26862954, 0.26130258, 0.27577711], device="cuda")
+
+ @torch.no_grad()
+ def forward(self, data, fake_images, align_corners=True):
+ images = 0.5 * (1 + fake_images)
+ images = F.interpolate(images, (self.image_size, self.image_size), mode='bicubic', align_corners=align_corners)
+ images.clamp_(0, 1)
+ images = (images - self.mean[None, :, None, None]) / (self.std[None, :, None, None])
+ image_code = self.model.encode_image(images)
+ return torch.cat((image_code, data['captions-clip']), dim=1)
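+
+
+ # Note (added for clarity): forward() expects `fake_images` in [-1, 1] and
+ # `data['captions-clip']` to hold pre-computed CLIP text codes, so the
+ # concatenated output can later be split back into image/text halves by
+ # get_r_precision().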
diff --git a/imaginaire/evaluation/caption/r_precision.py b/imaginaire/evaluation/caption/r_precision.py
new file mode 100644
index 0000000000000000000000000000000000000000..ec192e67cf38d24f4f238d151661367f9a5fa753
--- /dev/null
+++ b/imaginaire/evaluation/caption/r_precision.py
@@ -0,0 +1,27 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+# flake8: noqa
+
+import torch
+import torch.nn.functional as F
+
+
+def get_r_precision(image_text_code, eps=1e-5):
+ all_image_code, all_text_code = torch.chunk(image_text_code, 2, dim=1)
+ P_rates = []
+ num_samples = len(all_image_code)
+ assert num_samples >= 100
+ for i in range(0, num_samples, 100):
+ if i + 100 <= num_samples:
+ cur_image_code = all_image_code[i:i + 100]
+ cur_text_code = all_text_code[i:i + 100]
+ cur_image_code = F.normalize(cur_image_code, dim=1, eps=eps)
+ cur_text_code = F.normalize(cur_text_code, dim=1, eps=eps)
+ cosine_similarities = cur_image_code @ cur_text_code.T
+ top1_indices = torch.topk(cosine_similarities, dim=1, k=1)[1][:, 0]
+ P_rate = torch.sum(top1_indices == torch.arange(100, device=top1_indices.device)).item()
+ P_rates.append(P_rate)
+ A_precision = sum(P_rates) * 1.0 / len(P_rates)
+ return {"caption_rprec": A_precision}
diff --git a/imaginaire/evaluation/common.py b/imaginaire/evaluation/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..046c47da6692f5bd24543cc0e72da2a86b7ef2a3
--- /dev/null
+++ b/imaginaire/evaluation/common.py
@@ -0,0 +1,651 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import math
+import os
+from functools import partial
+import torch
+import torch.distributed as dist
+from torch import nn
+from torch.nn import functional as F
+from torchvision.models import inception_v3
+from cleanfid.features import feature_extractor
+from cleanfid.resize import build_resizer
+
+from imaginaire.evaluation.lpips import get_lpips_model
+from imaginaire.evaluation.segmentation import get_segmentation_hist_model, get_miou
+from imaginaire.evaluation.caption import get_image_encoder, get_r_precision
+from imaginaire.evaluation.pretrained import TFInceptionV3, InceptionV3, Vgg16, SwAV
+from imaginaire.utils.distributed import (dist_all_gather_tensor, get_rank,
+ get_world_size, is_master,
+ is_local_master)
+from imaginaire.utils.distributed import master_only_print
+from imaginaire.utils.misc import apply_imagenet_normalization, to_cuda
+
+
+@torch.no_grad()
+def compute_all_metrics(act_dir,
+ data_loader,
+ net_G,
+ key_real='images',
+ key_fake='fake_images',
+ sample_size=None,
+ preprocess=None,
+ is_video=False,
+ few_shot_video=False,
+ kid_num_subsets=1,
+ kid_subset_size=None,
+ key_prefix='',
+ prdc_k=5,
+ metrics=None,
+ dataset_name='',
+ aws_credentials=None,
+ **kwargs):
+ r"""
+ Args:
+ act_dir (string): Path to a directory to temporarily save feature activations.
+ data_loader (obj): PyTorch dataloader object.
+ net_G (obj): The generator module.
+ key_real (str): Dictionary key value for the real data.
+ key_fake (str): Dictionary key value for the fake data.
+ sample_size (int or None): How many samples to use for FID.
+ preprocess (func or None): Pre-processing function to use.
+ is_video (bool): Whether we are handling video sequences.
+ few_shot_video (bool): If ``True``, uses few-shot video synthesis.
+ kid_num_subsets (int): Number of subsets for KID evaluation.
+ kid_subset_size (int or None): The number of samples in each subset for KID evaluation.
+ key_prefix (string): Add this string before all keys of the output dictionary.
+ prdc_k (int): The K used for computing K-NN when evaluating precision/recall/density/coverage.
+ metrics (list of strings): Which metrics we want to evaluate.
+ dataset_name (string): The name of the dataset, currently only used to determine which segmentation network to
+ use for segmentation evaluation.
+ Returns:
+ (dict): All of the requested metrics, keyed by ``key_prefix`` + metric
+ name. Only the master process gets the populated dictionary.
+ """
+
+ from imaginaire.evaluation.fid import _calculate_frechet_distance
+ from imaginaire.evaluation.kid import _polynomial_mmd_averages
+ from imaginaire.evaluation.prdc import _get_prdc
+ from imaginaire.evaluation.msid import _get_msid
+ from imaginaire.evaluation.knn import _get_1nn_acc
+ if metrics is None:
+ metrics = []
+ act_path = os.path.join(act_dir, 'activations_real.pt')
+
+ # Get feature activations and other outputs computed from fake images.
+ output_module_dict = nn.ModuleDict()
+ if "seg_mIOU" in metrics:
+ output_module_dict["seg_mIOU"] = get_segmentation_hist_model(dataset_name, aws_credentials)
+ if "caption_rprec" in metrics:
+ output_module_dict["caption_rprec"] = get_image_encoder(aws_credentials)
+ if "LPIPS" in metrics:
+ output_module_dict["LPIPS"] = get_lpips_model()
+
+ fake_outputs = get_outputs(
+ data_loader, key_real, key_fake, net_G, sample_size, preprocess,
+ output_module_dict=output_module_dict, **kwargs
+ )
+ fake_act = fake_outputs["activations"]
+
+ # Get feature activations computed from real images.
+ real_act = load_or_compute_activations(
+ act_path, data_loader, key_real, key_fake, None,
+ sample_size, preprocess, is_video=is_video,
+ few_shot_video=few_shot_video, **kwargs
+ )
+
+ metrics_from_activations = {
+ "1NN": _get_1nn_acc,
+ "MSID": _get_msid,
+ "FID": _calculate_frechet_distance,
+ "KID": partial(_polynomial_mmd_averages,
+ n_subsets=kid_num_subsets,
+ subset_size=kid_subset_size,
+ ret_var=True),
+ "PRDC": partial(_get_prdc, nearest_k=prdc_k)
+ }
+
+ other_metrics = {
+ "seg_mIOU": get_miou,
+ "caption_rprec": get_r_precision,
+ "LPIPS": lambda x: {"LPIPS": torch.mean(x).item()}
+ }
+
+ all_metrics = {}
+ if is_master():
+ for metric in metrics:
+ if metric in metrics_from_activations:
+ metric_function = metrics_from_activations[metric]
+ metric_dict = metric_function(real_act, fake_act)
+ elif metric in other_metrics:
+ metric_function = other_metrics[metric]
+ if fake_outputs[metric] is not None:
+ metric_dict = metric_function(fake_outputs[metric])
+ else:
+ print(f"{metric} is not implemented!")
+ raise NotImplementedError
+ for k, v in metric_dict.items():
+ all_metrics.update({key_prefix + k: v})
+ if dist.is_initialized():
+ dist.barrier()
+ return all_metrics
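+
+
+ # Call sketch (illustrative): only the keyword names follow the signature
+ # above; the concrete values and the `val_loader` / `net_G` objects are
+ # assumptions.
+ #
+ #   metrics = compute_all_metrics(act_dir='/tmp/acts', data_loader=val_loader,
+ #                                 net_G=net_G, sample_size=2048,
+ #                                 metrics=['FID'], key_prefix='val_')
+ #   # -> e.g. {'val_FID': ...} on the master process, {} elsewhere.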
+
+
+@torch.no_grad()
+def compute_all_metrics_data(data_loader_a,
+ data_loader_b,
+ key_a='images',
+ key_b='images',
+ sample_size=None,
+ preprocess=None,
+ kid_num_subsets=1,
+ kid_subset_size=None,
+ key_prefix='',
+ prdc_k=5,
+ metrics=None,
+ dataset_name='',
+ aws_credentials=None,
+ **kwargs):
+ r"""
+ Args:
+ data_loader_a (obj): PyTorch dataloader object for dataset A.
+ data_loader_b (obj): PyTorch dataloader object for dataset B.
+ key_a (str): Dictionary key value for images in dataset A.
+ key_b (str): Dictionary key value for images in dataset B.
+ sample_size (int or None): How many samples to use for FID.
+ preprocess (func or None): Pre-processing function to use.
+ is_video (bool): Whether we are handling video sequences.
+ few_shot_video (bool): If ``True``, uses few-shot video synthesis.
+ kid_num_subsets (int): Number of subsets for KID evaluation.
+ kid_subset_size (int or None): The number of samples in each subset for KID evaluation.
+ key_prefix (string): Add this string before all keys of the output dictionary.
+ prdc_k (int): The K used for computing K-NN when evaluating precision/recall/density/coverage.
+ metrics (list of strings): Which metrics we want to evaluate.
+ dataset_name (string): The name of the dataset, currently only used to determine which segmentation network to
+ use for segmentation evaluation.
+ Returns:
+ (dict): All of the requested metrics, keyed by ``key_prefix`` + metric
+ name. Only the master process gets the populated dictionary.
+ """
+
+ from imaginaire.evaluation.fid import _calculate_frechet_distance
+ from imaginaire.evaluation.kid import _polynomial_mmd_averages
+ from imaginaire.evaluation.prdc import _get_prdc
+ from imaginaire.evaluation.msid import _get_msid
+ from imaginaire.evaluation.knn import _get_1nn_acc
+ if metrics is None:
+ metrics = []
+
+ min_data_size = min(len(data_loader_a.dataset),
+ len(data_loader_b.dataset))
+ if sample_size is None:
+ sample_size = min_data_size
+ else:
+ sample_size = min(sample_size, min_data_size)
+
+ # Get feature activations and other outputs computed from fake images.
+ output_module_dict = nn.ModuleDict()
+ if "seg_mIOU" in metrics:
+ output_module_dict["seg_mIOU"] = get_segmentation_hist_model(dataset_name, aws_credentials)
+ if "caption_rprec" in metrics:
+ output_module_dict["caption_rprec"] = get_image_encoder(aws_credentials)
+ if "LPIPS" in metrics:
+ output_module_dict["LPIPS"] = get_lpips_model()
+
+ fake_outputs = get_outputs(
+ data_loader_b, key_a, key_b, None, sample_size, preprocess,
+ output_module_dict=output_module_dict, **kwargs
+ )
+ act_b = fake_outputs["activations"]
+
+ act_a = load_or_compute_activations(
+ None, data_loader_a, key_a, key_b, None, sample_size, preprocess,
+ output_module_dict=output_module_dict, **kwargs
+ )
+
+ # act_b = load_or_compute_activations(
+ # None, data_loader_b, key_a, key_b, None, sample_size, preprocess,
+ # output_module_dict=output_module_dict, generate_twice=generate_twice, **kwargs
+ # )
+
+ metrics_from_activations = {
+ "1NN": _get_1nn_acc,
+ "MSID": _get_msid,
+ "FID": _calculate_frechet_distance,
+ "KID": partial(_polynomial_mmd_averages,
+ n_subsets=kid_num_subsets,
+ subset_size=kid_subset_size,
+ ret_var=True),
+ "PRDC": partial(_get_prdc, nearest_k=prdc_k)
+ }
+
+ other_metrics = {
+ "seg_mIOU": get_miou,
+ "caption_rprec": get_r_precision,
+ "LPIPS": lambda x: {"LPIPS": torch.mean(x).item()}
+ }
+
+ all_metrics = {}
+ if is_master():
+ for metric in metrics:
+ if metric in metrics_from_activations:
+ metric_function = metrics_from_activations[metric]
+ metric_dict = metric_function(act_a, act_b)
+ elif metric in other_metrics:
+ metric_function = other_metrics[metric]
+ if fake_outputs[metric] is not None:
+ metric_dict = metric_function(fake_outputs[metric])
+ else:
+ print(f"{metric} is not implemented!")
+ raise NotImplementedError
+ for k, v in metric_dict.items():
+ all_metrics.update({key_prefix + k: v})
+ if dist.is_initialized():
+ dist.barrier()
+ return all_metrics
+
+
+@torch.no_grad()
+def get_activations(data_loader, key_real, key_fake,
+ generator=None, sample_size=None, preprocess=None,
+ align_corners=True, network='inception', **kwargs):
+ r"""Compute activation values and pack them in a list.
+
+ Args:
+ data_loader (obj): PyTorch dataloader object.
+ key_real (str): Dictionary key value for the real data.
+ key_fake (str): Dictionary key value for the fake data.
+ generator (obj): PyTorch trainer network.
+ sample_size (int): How many samples to use for FID.
+ preprocess (func): Pre-processing function to use.
+ align_corners (bool): The ``'align_corners'`` parameter to be used for
+ `torch.nn.functional.interpolate`.
+ Returns:
+ batch_y (tensor): Inception features of the current batch. Note that
+ only the master gpu will get it.
+ """
+ if dist.is_initialized() and not is_local_master():
+ # Make sure only the first process in distributed training downloads
+ # the model, and the others will use the cache
+ # noinspection PyUnresolvedReferences
+ torch.distributed.barrier()
+
+ if network == 'tf_inception':
+ model = TFInceptionV3()
+ elif network == 'inception':
+ model = InceptionV3()
+ elif network == 'vgg16':
+ model = Vgg16()
+ elif network == 'swav':
+ model = SwAV()
+ elif network == 'clean_inception':
+ model = CleanInceptionV3()
+ else:
+ raise NotImplementedError(f'Network "{network}" is not supported!')
+
+ if dist.is_initialized() and is_local_master():
+ # Make sure only the first process in distributed training downloads
+ # the model, and the others will use the cache
+ # noinspection PyUnresolvedReferences
+ dist.barrier()
+
+ model = model.to('cuda').eval()
+ world_size = get_world_size()
+ batch_y = []
+
+ # Iterate through the dataset to compute the activation.
+ for it, data in enumerate(data_loader):
+ data = to_cuda(data)
+ # Preprocess the data.
+ if preprocess is not None:
+ data = preprocess(data)
+ # Load real data if the generator is not specified.
+ if generator is None:
+ images = data[key_real]
+ else:
+ # Compute the generated image.
+ net_G_output = generator(data, **kwargs)
+ images = net_G_output[key_fake]
+ # Clamp the image for models that do not set the output to between
+ # -1, 1. For models that employ tanh, this has no effect.
+ images.clamp_(-1, 1)
+ y = model(images, align_corners=align_corners)
+ batch_y.append(y)
+ if sample_size is not None and \
+ data_loader.batch_size * world_size * (it + 1) >= sample_size:
+ # Reach the number of samples we need.
+ break
+
+ batch_y = torch.cat(dist_all_gather_tensor(torch.cat(batch_y)))
+ if sample_size is not None:
+ batch_y = batch_y[:sample_size]
+ print(f"Computed feature activations of size {batch_y.shape}")
+ return batch_y
+
+
+class CleanInceptionV3(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.model = feature_extractor(name="torchscript_inception", resize_inside=False)
+
+ def forward(self, img_batch, transform=True, **_kwargs):
+ if transform:
+ # Assume the input is (-1, 1). We transform it to (0, 255) and round it to the closest integer.
+ img_batch = torch.round(255 * (0.5 * img_batch + 0.5))
+ resized_batch = clean_resize(img_batch)
+ return self.model(resized_batch)
+
+
+def clean_resize(img_batch):
+ # Resize images from arbitrary resolutions to 299x299.
+ batch_size = img_batch.size(0)
+ img_batch = img_batch.cpu().numpy()
+ fn_resize = build_resizer('clean')
+ resized_batch = torch.zeros(batch_size, 3, 299, 299, device='cuda')
+ for idx in range(batch_size):
+ curr_img = img_batch[idx]
+ img_np = curr_img.transpose((1, 2, 0))
+ img_resize = fn_resize(img_np)
+ resized_batch[idx] = torch.tensor(img_resize.transpose((2, 0, 1)), device='cuda')
+ resized_batch = resized_batch.cuda()
+ return resized_batch
+
+
+@torch.no_grad()
+def get_outputs(data_loader, key_real, key_fake,
+ generator=None, sample_size=None, preprocess=None,
+ align_corners=True, network='inception',
+ output_module_dict=None, **kwargs):
+ r"""Compute activation values and pack them in a list.
+
+ Args:
+ data_loader (obj): PyTorch dataloader object.
+ key_real (str): Dictionary key value for the real data.
+ key_fake (str): Dictionary key value for the fake data.
+ generator (obj): PyTorch trainer network.
+ sample_size (int): How many samples to use for FID.
+ preprocess (func): Pre-processing function to use.
+ align_corners (bool): The ``'align_corners'`` parameter to be used for `torch.nn.functional.interpolate`.
+ Returns:
+ batch_y (tensor): Inception features of the current batch. Note that
+ only the master gpu will get it.
+ """
+ if output_module_dict is None:
+ output_module_dict = nn.ModuleDict()
+ if dist.is_initialized() and not is_local_master():
+ # Make sure only the first process in distributed training downloads
+ # the model, and the others will use the cache
+ # noinspection PyUnresolvedReferences
+ torch.distributed.barrier()
+
+ if network == 'tf_inception':
+ model = TFInceptionV3()
+ elif network == 'inception':
+ model = InceptionV3()
+ elif network == 'vgg16':
+ model = Vgg16()
+ elif network == 'swav':
+ model = SwAV()
+ elif network == 'clean_inception':
+ model = CleanInceptionV3()
+ else:
+ raise NotImplementedError(f'Network "{network}" is not supported!')
+
+ if dist.is_initialized() and is_local_master():
+ # Make sure only the first process in distributed training downloads
+ # the model, and the others will use the cache
+ # noinspection PyUnresolvedReferences
+ dist.barrier()
+
+ model = model.to('cuda').eval()
+ world_size = get_world_size()
+ output = {}
+ for k in output_module_dict.keys():
+ output[k] = []
+ output["activations"] = []
+
+ # Iterate through the dataset to compute the activation.
+ for it, data in enumerate(data_loader):
+ data = to_cuda(data)
+ # Preprocess the data.
+ if preprocess is not None:
+ data = preprocess(data)
+ # Load real data if the generator is not specified.
+ if generator is None:
+ images = data[key_real]
+ else:
+ # Compute the generated image.
+ net_G_output = generator(data, **kwargs)
+ images = net_G_output[key_fake]
+ for metric_name, metric_module in output_module_dict.items():
+ if metric_module is not None:
+ if metric_name == 'LPIPS':
+ assert generator is not None
+ net_G_output_another = generator(data, **kwargs)
+ images_another = net_G_output_another[key_fake]
+ output[metric_name].append(metric_module(images, images_another))
+ else:
+ output[metric_name].append(metric_module(data, images, align_corners=align_corners))
+        # Clamp the images for models that do not constrain the output to
+        # [-1, 1]. For models that use a tanh output, this is a no-op.
+ images.clamp_(-1, 1)
+ y = model(images, align_corners=align_corners)
+ output["activations"].append(y)
+ if sample_size is not None and data_loader.batch_size * world_size * (it + 1) >= sample_size:
+ # Reach the number of samples we need.
+ break
+
+ for k, v in output.items():
+ if len(v) > 0:
+ output[k] = torch.cat(dist_all_gather_tensor(torch.cat(v)))[:sample_size]
+ else:
+ output[k] = None
+ return output
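+
+
+# Hypothetical usage sketch (not part of the original file): collect Inception
+# activations and per-image LPIPS scores in a single pass over the loader. It
+# assumes the loader yields dicts with an 'images' key, the generator returns
+# a dict with a 'fake_images' key, and a CUDA device is available (as in the
+# rest of this module).
+def _example_get_outputs_usage(data_loader, generator):
+    from imaginaire.evaluation.lpips import get_lpips_model
+    modules = nn.ModuleDict({'LPIPS': get_lpips_model()})
+    return get_outputs(data_loader, 'images', 'fake_images',
+                       generator=generator, sample_size=1000,
+                       network='inception', output_module_dict=modules)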
+
+
+@torch.no_grad()
+def get_video_activations(data_loader, key_real, key_fake, trainer=None,
+ sample_size=None, preprocess=None, few_shot=False):
+    r"""Compute activation values and pack them in a list. We do not perform
+    an all-reduce here.
+
+ Args:
+ data_loader (obj): PyTorch dataloader object.
+ key_real (str): Dictionary key value for the real data.
+ key_fake (str): Dictionary key value for the fake data.
+ trainer (obj): Trainer. Video generation is more involved, we rely on
+ the "reset" and "test" function to conduct the evaluation.
+        sample_size (tuple or None): Tuple of
+            ``(num_videos, num_frames_per_video)`` used for computing the
+            video activations; if ``None``, defaults to 10 videos and 5
+            frames per video.
+ preprocess (func): The preprocess function to be applied to the data.
+ few_shot (bool): If ``True``, uses the few-shot setting.
+ Returns:
+ batch_y (tensor): Inception features of the current batch. Note that
+ only the master gpu will get it.
+ """
+ inception = inception_init()
+ batch_y = []
+
+ # We divide video sequences to different GPUs for testing.
+ num_sequences = data_loader.dataset.num_inference_sequences()
+ if sample_size is None:
+ num_videos_to_test = 10
+ num_frames_per_video = 5
+ else:
+ num_videos_to_test, num_frames_per_video = sample_size
+ if num_videos_to_test == -1:
+ num_videos_to_test = num_sequences
+ else:
+ num_videos_to_test = min(num_videos_to_test, num_sequences)
+ master_only_print('Number of videos used for evaluation: {}'.format(num_videos_to_test))
+ master_only_print('Number of frames per video used for evaluation: {}'.format(num_frames_per_video))
+
+ world_size = get_world_size()
+ if num_videos_to_test < world_size:
+ seq_to_run = [get_rank() % num_videos_to_test]
+ else:
+ num_videos_to_test = num_videos_to_test // world_size * world_size
+ seq_to_run = range(get_rank(), num_videos_to_test, world_size)
+
+ for sequence_idx in seq_to_run:
+ data_loader = set_sequence_idx(few_shot, data_loader, sequence_idx)
+ if trainer is not None:
+ trainer.reset()
+ for it, data in enumerate(data_loader):
+ if few_shot and it == 0:
+ continue
+ if it >= num_frames_per_video:
+ break
+
+            # Preprocess the data if a preprocess function is provided.
+ if trainer is not None:
+ data = trainer.pre_process(data)
+ elif preprocess is not None:
+ data = preprocess(data)
+ data = to_cuda(data)
+
+ if trainer is None:
+ images = data[key_real][:, -1]
+ else:
+ net_G_output = trainer.test_single(data)
+ images = net_G_output[key_fake]
+ y = inception_forward(inception, images)
+ batch_y += [y]
+
+ batch_y = torch.cat(batch_y)
+ batch_y = dist_all_gather_tensor(batch_y)
+ if is_local_master():
+ batch_y = torch.cat(batch_y)
+ return batch_y
+
+
+def inception_init():
+ inception = inception_v3(pretrained=True, transform_input=False)
+ inception = inception.to('cuda')
+ inception.eval()
+ inception.fc = torch.nn.Sequential()
+ return inception
+
+
+def inception_forward(inception, images):
+ images.clamp_(-1, 1)
+ images = apply_imagenet_normalization(images)
+ images = F.interpolate(images, size=(299, 299),
+ mode='bicubic', align_corners=True)
+ return inception(images)
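+
+
+# Hypothetical usage sketch (not part of the original file): `frames` is a
+# CUDA batch of images in [-1, 1]; since `inception_init` replaces the final
+# fc layer with an identity, the call returns the 2048-dim pool features
+# collected by `get_video_activations`.
+def _example_frame_features(frames):
+    inception = inception_init()
+    return inception_forward(inception, frames)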
+
+
+def gather_tensors(batch_y):
+ batch_y = torch.cat(batch_y)
+ batch_y = dist_all_gather_tensor(batch_y)
+ if is_local_master():
+ batch_y = torch.cat(batch_y)
+ return batch_y
+
+
+def set_sequence_idx(few_shot, data_loader, sequence_idx):
+    r"""Set the inference sequence index on the dataset.
+
+ Args:
+ few_shot (bool): If ``True``, uses the few-shot setting.
+ data_loader: dataloader object
+ sequence_idx (int): which sequence to use.
+ """
+ if few_shot:
+ data_loader.dataset.set_inference_sequence_idx(sequence_idx,
+ sequence_idx,
+ 0)
+ else:
+ data_loader.dataset.set_inference_sequence_idx(sequence_idx)
+ return data_loader
+
+
+def load_or_compute_activations(act_path, data_loader, key_real, key_fake,
+ generator=None, sample_size=None,
+ preprocess=None,
+ is_video=False, few_shot_video=False,
+ **kwargs):
+    r"""Load feature activations from the saved file if it exists. Otherwise,
+    compute the activations and optionally save them.
+
+ Args:
+ act_path (str or None): Location for the numpy file to store or to load
+ the activations.
+ data_loader (obj): PyTorch dataloader object.
+ key_real (str): Dictionary key value for the real data.
+ key_fake (str): Dictionary key value for the fake data.
+ generator (obj): PyTorch trainer network.
+        sample_size (int): How many samples to be used for computing the
+            activations.
+ preprocess (func): The preprocess function to be applied to the data.
+ is_video (bool): Whether we are handling video sequences.
+ few_shot_video (bool): If ``True``, uses few-shot video synthesis.
+ Returns:
+ (torch.Tensor) Feature activations.
+ """
+ if act_path is not None and os.path.exists(act_path):
+ # Loading precomputed activations.
+ print('Load activations from {}'.format(act_path))
+ act = torch.load(act_path, map_location='cpu').cuda()
+ else:
+ # Compute activations.
+ if is_video:
+ act = get_video_activations(
+ data_loader, key_real, key_fake, generator,
+ sample_size, preprocess, few_shot_video, **kwargs
+ )
+ else:
+ act = get_activations(
+ data_loader, key_real, key_fake, generator,
+ sample_size, preprocess, **kwargs
+ )
+ if act_path is not None and is_local_master():
+ print('Save activations to {}'.format(act_path))
+ if not os.path.exists(os.path.dirname(act_path)):
+ os.makedirs(os.path.dirname(act_path), exist_ok=True)
+ torch.save(act, act_path)
+ return act
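+
+
+# Hypothetical usage sketch (not part of the original file): compute the
+# activations of the real images once and cache them on disk so later runs can
+# reload them. The path 'results/activations_real.pt' is an arbitrary example.
+def _example_cached_real_activations(data_loader):
+    return load_or_compute_activations('results/activations_real.pt',
+                                       data_loader, 'images', 'fake_images',
+                                       generator=None, sample_size=2000)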
+
+
+def compute_pairwise_distance(data_x, data_y=None, num_splits=10):
+    r"""Compute pairwise distances between two sets of features in chunks.
+
+    Args:
+        data_x (torch.Tensor): Features of shape [N, feature_dim].
+        data_y (torch.Tensor or None): Features of shape [N, feature_dim].
+            Defaults to ``data_x`` when ``None``.
+        num_splits (int): Number of chunks used to bound peak memory.
+    Returns:
+        (torch.Tensor): Pairwise distances of shape [N, N] on the CPU.
+    """
+ if data_y is None:
+ data_y = data_x
+ num_samples = data_x.shape[0]
+ assert data_x.shape[0] == data_y.shape[0]
+ dists = []
+ for i in range(num_splits):
+ batch_size = math.ceil(num_samples / num_splits)
+ start_idx = i * batch_size
+ end_idx = min((i + 1) * batch_size, num_samples)
+ dists.append(torch.cdist(data_x[start_idx:end_idx],
+ data_y).cpu())
+ dists = torch.cat(dists, dim=0)
+ return dists
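+
+
+# Illustrative sketch (not part of the original file): the chunked computation
+# should match a single torch.cdist call on small random inputs.
+def _example_pairwise_distance_check():
+    x = torch.randn(100, 8)
+    y = torch.randn(100, 8)
+    dists = compute_pairwise_distance(x, y, num_splits=4)
+    assert torch.allclose(dists, torch.cdist(x, y).cpu(), atol=1e-5)
+    return dists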
+
+
+def compute_nn(input_features, k, num_splits=50):
+ num_samples = input_features.shape[0]
+ all_indices = []
+ all_values = []
+ for i in range(num_splits):
+ batch_size = math.ceil(num_samples / num_splits)
+ start_idx = i * batch_size
+ end_idx = min((i + 1) * batch_size, num_samples)
+ dist = torch.cdist(input_features[start_idx:end_idx],
+ input_features)
+ dist[:, start_idx:end_idx] += torch.diag(
+ float('inf') * torch.ones(dist.size(0), device=dist.device)
+ )
+ k_smallests, indices = torch.topk(dist, k, dim=-1, largest=False)
+ all_indices.append(indices)
+ all_values.append(k_smallests)
+ return torch.cat(all_values, dim=0), torch.cat(all_indices, dim=0)
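+
+
+# Illustrative sketch (not part of the original file): because self-distances
+# are pushed to +inf before the top-k, no point is returned as its own
+# neighbour.
+def _example_nearest_neighbours():
+    feats = torch.randn(64, 8)
+    values, indices = compute_nn(feats, k=3)
+    assert values.shape == (64, 3) and indices.shape == (64, 3)
+    assert (indices != torch.arange(64).unsqueeze(1)).all()
+    return values, indices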
diff --git a/imaginaire/evaluation/fid.py b/imaginaire/evaluation/fid.py
new file mode 100644
index 0000000000000000000000000000000000000000..793f4b77dd50381abffc82d03fee8bc36e746dab
--- /dev/null
+++ b/imaginaire/evaluation/fid.py
@@ -0,0 +1,143 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import os
+import numpy as np
+import torch
+from scipy import linalg
+
+from imaginaire.evaluation.common import load_or_compute_activations
+from imaginaire.utils.distributed import is_master
+from imaginaire.utils.distributed import master_only_print as print
+
+
+@torch.no_grad()
+def compute_fid(fid_path, data_loader, net_G,
+ key_real='images', key_fake='fake_images',
+ sample_size=None, preprocess=None, return_act=False,
+ is_video=False, few_shot_video=False, **kwargs):
+ r"""Compute the fid score.
+
+ Args:
+ fid_path (str): Location for the numpy file to store or to load the
+ statistics.
+ data_loader (obj): PyTorch dataloader object.
+        net_G (obj): For image generation models, net_G is the generator network.
+ For video generation models, net_G is the trainer.
+ key_real (str): Dictionary key value for the real data.
+ key_fake (str): Dictionary key value for the fake data.
+ sample_size (int or tuple): How many samples to be used.
+ preprocess (func): The preprocess function to be applied to the data.
+ return_act (bool): If ``True``, also returns feature activations of
+ real and fake data.
+ is_video (bool): Whether we are handling video sequences.
+ few_shot_video (bool): If ``True``, uses few-shot video synthesis.
+ Returns:
+ (float): FID value.
+ """
+ print('Computing FID.')
+ act_path = os.path.join(os.path.dirname(fid_path),
+ 'activations_real.npy')
+    # Get the fake activations.
+ fake_act = load_or_compute_activations(
+ None, data_loader, key_real, key_fake, net_G,
+ sample_size, preprocess, is_video=is_video,
+ few_shot_video=few_shot_video, **kwargs
+ )
+
+    # Get the ground truth activations.
+ real_act = load_or_compute_activations(
+ act_path, data_loader, key_real, key_fake, None,
+ sample_size, preprocess, is_video=is_video,
+ few_shot_video=few_shot_video, **kwargs
+ )
+
+ if is_master():
+ fid = _calculate_frechet_distance(
+ fake_act, real_act)["FID"]
+ if return_act:
+ return fid, real_act, fake_act
+ else:
+ return fid
+ elif return_act:
+ return None, None, None
+ else:
+ return None
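+
+
+# Hypothetical usage sketch (not part of the original file): the path
+# 'results/fid/fid.npy' only determines where the cached real activations are
+# stored. Assumes the loader yields dicts with an 'images' key, the generator
+# returns 'fake_images', and a CUDA device is available.
+def _example_compute_fid_usage(data_loader, net_G):
+    return compute_fid('results/fid/fid.npy', data_loader, net_G,
+                       key_real='images', key_fake='fake_images',
+                       sample_size=2000)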
+
+
+@torch.no_grad()
+def compute_fid_data(fid_path, data_loader_a, data_loader_b,
+ key_a='images', key_b='images', sample_size=None,
+ is_video=False, few_shot_video=False, **kwargs):
+ r"""Compute the fid score between two datasets.
+
+ Args:
+ fid_path (str): Location for the numpy file to store or to load the
+ statistics.
+ data_loader_a (obj): PyTorch dataloader object for dataset a.
+ data_loader_b (obj): PyTorch dataloader object for dataset b.
+ key_a (str): Dictionary key value for images in the dataset a.
+ key_b (str): Dictionary key value for images in the dataset b.
+ sample_size (int): How many samples to be used for computing the FID.
+ is_video (bool): Whether we are handling video sequences.
+ few_shot_video (bool): If ``True``, uses few-shot video synthesis.
+ Returns:
+ (float): FID value.
+ """
+ print('Computing FID.')
+ path_a = os.path.join(os.path.dirname(fid_path),
+ 'activations_a.npy')
+ min_data_size = min(len(data_loader_a.dataset),
+ len(data_loader_b.dataset))
+ if sample_size is None:
+ sample_size = min_data_size
+ else:
+ sample_size = min(sample_size, min_data_size)
+
+ act_a = load_or_compute_activations(
+ path_a, data_loader_a, key_a, key_b, None,
+ sample_size=sample_size, is_video=is_video,
+ few_shot_video=few_shot_video, **kwargs
+ )
+ act_b = load_or_compute_activations(
+ None, data_loader_b, key_a, key_b, None,
+ sample_size=sample_size, is_video=is_video,
+ few_shot_video=few_shot_video, **kwargs
+ )
+
+ if is_master():
+ return _calculate_frechet_distance(act_a, act_b)["FID"]
+
+
+def _calculate_frechet_distance(act_1, act_2, eps=1e-6):
+ mu1 = np.mean(act_1.cpu().numpy(), axis=0)
+ sigma1 = np.cov(act_1.cpu().numpy(), rowvar=False)
+ mu2 = np.mean(act_2.cpu().numpy(), axis=0)
+ sigma2 = np.cov(act_2.cpu().numpy(), rowvar=False)
+ mu1 = np.atleast_1d(mu1)
+ mu2 = np.atleast_1d(mu2)
+ sigma1 = np.atleast_2d(sigma1)
+ sigma2 = np.atleast_2d(sigma2)
+ assert mu1.shape == mu2.shape, 'Training and test mean vectors have different lengths'
+ assert sigma1.shape == sigma2.shape, 'Training and test covariances have different dimensions'
+ diff = mu1 - mu2
+ # Product might be almost singular
+ covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
+ if not np.isfinite(covmean).all():
+ msg = ('fid calculation produces singular product; '
+ 'adding %s to diagonal of cov estimates') % eps
+ print(msg)
+ offset = np.eye(sigma1.shape[0]) * eps
+ covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
+
+ # Numerical error might give slight imaginary component
+ if np.iscomplexobj(covmean):
+ if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
+ m = np.max(np.abs(covmean.imag))
+ print('Imaginary component {}'.format(m))
+ # raise ValueError('Imaginary component {}'.format(m))
+ covmean = covmean.real
+ tr_covmean = np.trace(covmean)
+ return {"FID": (diff.dot(diff) + np.trace(sigma1) + np.trace(
+ sigma2) - 2 * tr_covmean)}
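+
+
+# Illustrative sanity check (not part of the original file): for two identical
+# activation sets the means and covariances coincide, so the Frechet distance
+# should be numerically close to zero.
+def _example_frechet_distance_sanity_check():
+    act = torch.randn(256, 8)
+    fid = _calculate_frechet_distance(act, act)["FID"]
+    assert abs(fid) < 1e-2
+    return fid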
diff --git a/imaginaire/evaluation/kid.py b/imaginaire/evaluation/kid.py
new file mode 100644
index 0000000000000000000000000000000000000000..675b93015b03a8c1f4557e5687de0887b1f5a0a4
--- /dev/null
+++ b/imaginaire/evaluation/kid.py
@@ -0,0 +1,317 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+
+"""
+Modified from https://github.com/abdulfatir/gan-metrics-pytorch
+Copyright 2018 Institute of Bioinformatics, JKU Linz
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+ http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+import os
+import warnings
+
+import numpy as np
+import torch
+
+from imaginaire.evaluation.common import get_activations, \
+ load_or_compute_activations
+from imaginaire.utils.distributed import is_master
+from imaginaire.utils.distributed import master_only_print as print
+
+
+@torch.no_grad()
+def compute_kid(kid_path, data_loader, net_G,
+ key_real='images', key_fake='fake_images',
+ real_act=None, fake_act=None,
+ sample_size=None, preprocess=None, is_video=False,
+ save_act=True, num_subsets=1, subset_size=None, **kwargs):
+ r"""Compute the kid score.
+
+ Args:
+        kid_path (str): Location for storing feature activations.
+ data_loader (obj): PyTorch dataloader object.
+        net_G (obj): For image generation models, net_G is the generator
+            network. For video generation models, net_G is the trainer
+            because video generation requires more complicated processing.
+ key_real (str): Dictionary key value for the real data.
+ key_fake (str): Dictionary key value for the fake data.
+ real_act (torch.Tensor or None): Feature activations of real data.
+ fake_act (torch.Tensor or None): Feature activations of fake data.
+ sample_size (int): How many samples to be used for computing feature
+ activations.
+ preprocess (func): The preprocess function to be applied to the data.
+ is_video (bool): Whether we are handling video sequences.
+ save_act (bool): If ``True``, saves real activations to the disk and
+ reload them in the future. It might save some computation but will
+ cost storage.
+ num_subsets (int): Number of subsets to sample from all the samples.
+ subset_size (int): Number of samples in each subset.
+ Returns:
+ kid (float): KID value.
+ """
+ print('Computing KID.')
+ act_path = os.path.join(
+ os.path.dirname(kid_path), 'activations_real.npy'
+ ) if save_act else None
+
+ # Get the fake activations.
+ if fake_act is None:
+ fake_act = load_or_compute_activations(None,
+ data_loader,
+ key_real, key_fake, net_G,
+ sample_size, preprocess,
+ is_video=is_video, **kwargs)
+ else:
+ print(f"Using precomputed activations of size {fake_act.shape}.")
+
+ # Get the ground truth activations.
+ if real_act is None:
+ real_act = load_or_compute_activations(act_path,
+ data_loader,
+ key_real, key_fake, None,
+ sample_size, preprocess,
+ is_video=is_video, **kwargs)
+ else:
+ print(f"Using precomputed activations of size {real_act.shape}.")
+
+ if is_master():
+ return _polynomial_mmd_averages(fake_act, real_act,
+ num_subsets,
+ subset_size,
+ ret_var=True)["KID"]
+
+
+@torch.no_grad()
+def compute_kid_data(kid_path, data_loader_a, data_loader_b,
+ key_a='images', key_b='images', sample_size=None,
+ is_video=False, num_subsets=1, subset_size=None,
+ **kwargs):
+ r"""Compute the kid score between two datasets.
+
+ Args:
+        kid_path (str): Location for storing feature activations.
+ data_loader_a (obj): PyTorch dataloader object for dataset a.
+ data_loader_b (obj): PyTorch dataloader object for dataset b.
+ key_a (str): Dictionary key value for images in the dataset a.
+ key_b (str): Dictionary key value for images in the dataset b.
+ sample_size (int): How many samples to be used for computing the KID.
+ is_video (bool): Whether we are handling video sequences.
+ num_subsets (int): Number of subsets to sample from the whole data.
+ subset_size (int): Number of samples in each subset.
+ Returns:
+ kid (float): KID value.
+ """
+ min_data_size = min(len(data_loader_a.dataset),
+ len(data_loader_b.dataset))
+ if sample_size is None:
+ sample_size = min_data_size
+ else:
+ sample_size = min(sample_size, min_data_size)
+ print('Computing KID using {} images from both distributions.'.
+ format(sample_size))
+ path_a = os.path.join(os.path.dirname(kid_path),
+ 'activations_a.npy')
+ act_a = load_or_compute_activations(path_a, data_loader_a,
+ key_a, key_a,
+ sample_size=sample_size,
+ is_video=is_video, **kwargs)
+ act_b = get_activations(data_loader_b, key_b, key_b,
+ None, sample_size, None, **kwargs)
+
+ if is_master():
+ return _polynomial_mmd_averages(act_a, act_b,
+ num_subsets,
+ subset_size,
+ ret_var=True)["KID"]
+
+
+def _polynomial_mmd_averages(codes_g, codes_r, n_subsets, subset_size,
+ ret_var=True, **kernel_args):
+ r"""Computes MMD between two sets of features using polynomial kernels. It
+ performs a number of repetitions of subset sampling without replacement.
+
+ Args:
+ codes_g (Tensor): Feature activations of generated images.
+ codes_r (Tensor): Feature activations of real images.
+ n_subsets (int): The number of subsets.
+ subset_size (int): The number of samples in each subset.
+ ret_var (bool): If ``True``, returns both mean and variance of MMDs,
+ otherwise only returns the mean.
+    Returns:
+        (dict): Dictionary with key ``'KID'`` holding the mean of the MMD
+            estimates over the sampled subsets.
+    """
+ mmds = np.zeros(n_subsets)
+ if ret_var:
+ mmd_vars = np.zeros(n_subsets)
+ choice = np.random.choice
+
+    if subset_size is None:
+        subset_size = min(len(codes_g), len(codes_r))
+        print("Subset size not provided, "
+              "setting it to the data size ({}).".format(subset_size))
+    if subset_size > len(codes_g) or subset_size > len(codes_r):
+        subset_size = min(len(codes_g), len(codes_r))
+        warnings.warn(
+            "Subset size is larger than the actual data size, "
+            "setting it to the data size ({}).".format(subset_size))
+
+ for i in range(n_subsets):
+ g = codes_g[choice(len(codes_g), subset_size, replace=False)]
+ r = codes_r[choice(len(codes_r), subset_size, replace=False)]
+ o = _polynomial_mmd(g, r, **kernel_args, ret_var=ret_var)
+ if ret_var:
+ # noinspection PyUnboundLocalVariable
+ mmds[i], mmd_vars[i] = o
+ else:
+ mmds[i] = o
+ return {'KID': mmds.mean()}
+
+
+def _polynomial_kernel(X, Y=None, degree=3, gamma=None, coef0=1.):
+ r"""Compute the polynomial kernel between X and Y"""
+ if gamma is None:
+ gamma = 1.0 / X.shape[1]
+ if Y is None:
+ Y = X
+
+ # K = safe_sparse_dot(X, Y.T, dense_output=True)
+ K = torch.matmul(X, Y.t())
+ K *= gamma
+ K += coef0
+ K = K ** degree
+ return K
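+
+
+# Illustrative sketch (not part of the original file): checks that the kernel
+# equals (gamma * <x, y> + coef0) ** degree elementwise, with gamma defaulting
+# to 1 / feature_dim.
+def _example_polynomial_kernel_check():
+    X = torch.randn(4, 16)
+    Y = torch.randn(5, 16)
+    K = _polynomial_kernel(X, Y, degree=3, coef0=1.)
+    expected = (X @ Y.t() / 16 + 1.) ** 3
+    assert torch.allclose(K, expected, atol=1e-5)
+    return K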
+
+
+def _polynomial_mmd(codes_g, codes_r, degree=3, gamma=None, coef0=1,
+ ret_var=True):
+    r"""Computes the MMD between two sets of features using a polynomial
+    kernel (no subset sampling is done here; see _polynomial_mmd_averages).
+
+ Args:
+ codes_g (torch.Tensor): Feature activations of generated images.
+ codes_r (torch.Tensor): Feature activations of real images.
+ degree (int): The degree of the polynomial kernel.
+ gamma (float or None): Scale of the polynomial kernel.
+ coef0 (float or None): Bias of the polynomial kernel.
+ ret_var (bool): If ``True``, returns both mean and variance of MMDs,
+ otherwise only returns the mean.
+ Returns:
+ (tuple):
+ - mmds (torch.Tensor): Mean of MMDs.
+ - mmd_vars (torch.Tensor): Variance of MMDs.
+ """
+ # use k(x, y) = (gamma + coef0)^degree
+ # default gamma is 1 / dim
+ X = codes_g
+ Y = codes_r
+
+ # with warnings.catch_warnings():
+ # warnings.simplefilter('ignore')
+ K_XX = _polynomial_kernel(X, degree=degree, gamma=gamma, coef0=coef0)
+ K_YY = _polynomial_kernel(Y, degree=degree, gamma=gamma, coef0=coef0)
+ K_XY = _polynomial_kernel(X, Y, degree=degree, gamma=gamma, coef0=coef0)
+
+ return _mmd2_and_variance(K_XX, K_XY, K_YY, ret_var=ret_var)
+
+
+def _mmd2_and_variance(K_XX, K_XY, K_YY, unit_diagonal=False,
+ mmd_est='unbiased', ret_var=True):
+ r"""Based on
+ https://github.com/dougalsutherland/opt-mmd/blob/master/two_sample/mmd.py
+ but changed to not compute the full kernel matrix at once
+ """
+
+ m = K_XX.shape[0]
+ assert K_XX.shape == (m, m)
+ assert K_XY.shape == (m, m)
+ assert K_YY.shape == (m, m)
+ var_at_m = m
+
+ # Get the various sums of kernels that we'll use
+ # Kts drop the diagonal, but we don't need to compute them explicitly
+ if unit_diagonal:
+ diag_X = diag_Y = 1
+ sum_diag_X = sum_diag_Y = m
+ sum_diag2_X = sum_diag2_Y = m
+ else:
+ diag_X = torch.diagonal(K_XX)
+ diag_Y = torch.diagonal(K_YY)
+
+ sum_diag_X = diag_X.sum()
+ sum_diag_Y = diag_Y.sum()
+
+ sum_diag2_X = _sqn(diag_X)
+ sum_diag2_Y = _sqn(diag_Y)
+
+ Kt_XX_sums = K_XX.sum(dim=1) - diag_X
+ Kt_YY_sums = K_YY.sum(dim=1) - diag_Y
+ K_XY_sums_0 = K_XY.sum(dim=0)
+ K_XY_sums_1 = K_XY.sum(dim=1)
+
+ Kt_XX_sum = Kt_XX_sums.sum()
+ Kt_YY_sum = Kt_YY_sums.sum()
+ K_XY_sum = K_XY_sums_0.sum()
+
+ if mmd_est == 'biased':
+ mmd2 = ((Kt_XX_sum + sum_diag_X) / (m * m)
+ + (Kt_YY_sum + sum_diag_Y) / (m * m)
+ - 2 * K_XY_sum / (m * m))
+ else:
+ assert mmd_est in {'unbiased', 'u-statistic'}
+ mmd2 = (Kt_XX_sum + Kt_YY_sum) / (m * (m - 1))
+ if mmd_est == 'unbiased':
+ mmd2 -= 2 * K_XY_sum / (m * m)
+ else:
+ mmd2 -= 2 * (K_XY_sum - torch.trace(K_XY)) / (m * (m - 1))
+
+ if not ret_var:
+ return mmd2
+
+ Kt_XX_2_sum = _sqn(K_XX) - sum_diag2_X
+ Kt_YY_2_sum = _sqn(K_YY) - sum_diag2_Y
+ K_XY_2_sum = _sqn(K_XY)
+
+ dot_XX_XY = Kt_XX_sums.dot(K_XY_sums_1)
+ dot_YY_YX = Kt_YY_sums.dot(K_XY_sums_0)
+
+ m1 = m - 1
+ m2 = m - 2
+ zeta1_est = (
+ 1 / (m * m1 * m2) *
+ (_sqn(Kt_XX_sums) - Kt_XX_2_sum + _sqn(Kt_YY_sums) - Kt_YY_2_sum)
+ - 1 / (m * m1) ** 2 * (Kt_XX_sum ** 2 + Kt_YY_sum ** 2)
+ + 1 / (m * m * m1) * (
+ _sqn(K_XY_sums_1) + _sqn(K_XY_sums_0) - 2 * K_XY_2_sum)
+ - 2 / m ** 4 * K_XY_sum ** 2
+ - 2 / (m * m * m1) * (dot_XX_XY + dot_YY_YX)
+ + 2 / (m ** 3 * m1) * (Kt_XX_sum + Kt_YY_sum) * K_XY_sum
+ )
+ zeta2_est = (
+ 1 / (m * m1) * (Kt_XX_2_sum + Kt_YY_2_sum)
+ - 1 / (m * m1) ** 2 * (Kt_XX_sum ** 2 + Kt_YY_sum ** 2)
+ + 2 / (m * m) * K_XY_2_sum
+ - 2 / m ** 4 * K_XY_sum ** 2
+ - 4 / (m * m * m1) * (dot_XX_XY + dot_YY_YX)
+ + 4 / (m ** 3 * m1) * (Kt_XX_sum + Kt_YY_sum) * K_XY_sum
+ )
+ var_est = (4 * (var_at_m - 2) / (var_at_m * (var_at_m - 1)) * zeta1_est
+ + 2 / (var_at_m * (var_at_m - 1)) * zeta2_est)
+
+ return mmd2.cpu().numpy(), var_est.cpu().numpy()
+
+
+def _sqn(arr):
+ r"""Squared norm."""
+ flat = arr.view(-1)
+ return flat.dot(flat)
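+
+
+# Illustrative sketch (not part of the original file): KID between matched
+# random feature sets should be far smaller than KID against a clearly shifted
+# set, exercising the subset-sampling estimator above.
+def _example_kid_on_random_features():
+    real = torch.randn(200, 32)
+    fake_good = torch.randn(200, 32)
+    fake_bad = torch.randn(200, 32) + 10.
+    kid_good = _polynomial_mmd_averages(fake_good, real, n_subsets=4,
+                                        subset_size=100)['KID']
+    kid_bad = _polynomial_mmd_averages(fake_bad, real, n_subsets=4,
+                                       subset_size=100)['KID']
+    assert kid_bad > kid_good
+    return kid_good, kid_bad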
diff --git a/imaginaire/evaluation/knn.py b/imaginaire/evaluation/knn.py
new file mode 100644
index 0000000000000000000000000000000000000000..48b1cf502dc01e89e67f021572df9e94c71938ef
--- /dev/null
+++ b/imaginaire/evaluation/knn.py
@@ -0,0 +1,35 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import torch
+
+from imaginaire.evaluation.common import compute_nn
+
+
+def _get_1nn_acc(data_x, data_y, k=1):
+ device = data_x.device
+ n0 = data_x.size(0)
+ n1 = data_y.size(0)
+ data_all = torch.cat((data_x, data_y), dim=0)
+ val, idx = compute_nn(data_all, k)
+ label = torch.cat((torch.ones(n0, device=device),
+ torch.zeros(n1, device=device)))
+
+ count = torch.zeros(n0 + n1, device=device)
+ for i in range(0, k):
+ count = count + label.index_select(0, idx[:, i])
+ pred = torch.ge(count, (float(k) / 2) *
+ torch.ones(n0 + n1, device=device)).float()
+
+ tp = (pred * label).sum()
+ fp = (pred * (1 - label)).sum()
+ fn = ((1 - pred) * label).sum()
+ tn = ((1 - pred) * (1 - label)).sum()
+ acc_r = (tp / (tp + fn)).item()
+ acc_f = (tn / (tn + fp)).item()
+ acc = torch.eq(label, pred).float().mean().item()
+
+ return {'1NN_acc': acc,
+ '1NN_acc_real': acc_r,
+ '1NN_acc_fake': acc_f}
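+
+
+# Illustrative sketch (not part of the original file): with two well-separated
+# Gaussian blobs the 1-NN two-sample classifier should be close to perfect.
+def _example_1nn_acc_sanity_check():
+    real = torch.randn(128, 16)
+    fake = torch.randn(128, 16) + 10.
+    scores = _get_1nn_acc(real, fake, k=1)
+    assert scores['1NN_acc'] > 0.9
+    return scores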
diff --git a/imaginaire/evaluation/lpips.py b/imaginaire/evaluation/lpips.py
new file mode 100644
index 0000000000000000000000000000000000000000..37b10d77ad739151347ee2c45e6417145206cdbc
--- /dev/null
+++ b/imaginaire/evaluation/lpips.py
@@ -0,0 +1,153 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+from collections import namedtuple
+
+import torch
+from torch import nn, distributed as dist
+import torchvision.models as tv
+from torch.distributed import barrier
+
+from imaginaire.utils.distributed import is_local_master
+
+
+def get_lpips_model():
+ if dist.is_initialized() and not is_local_master():
+ # Make sure only the first process in distributed training downloads the model, and the others use the cache.
+ barrier()
+
+ model = LPIPSNet().cuda()
+
+ if dist.is_initialized() and is_local_master():
+ # Make sure only the first process in distributed training downloads the model, and the others use the cache.
+ barrier()
+ return model
+
+
+# Learned perceptual network, modified from https://github.com/richzhang/PerceptualSimilarity
+
+def normalize_tensor(in_feat, eps=1e-5):
+ norm_factor = torch.sqrt(torch.sum(in_feat ** 2, dim=1, keepdim=True) + eps)
+ return in_feat / (norm_factor + eps)
+
+
+class NetLinLayer(nn.Module):
+ """ A single linear layer used as placeholder for LPIPS learnt weights """
+
+ def __init__(self, dim):
+ super(NetLinLayer, self).__init__()
+ self.weight = nn.Parameter(torch.zeros(1, dim, 1, 1))
+
+ def forward(self, inp):
+ out = self.weight * inp
+ return out
+
+
+class ScalingLayer(nn.Module):
+ # For rescaling the input to vgg16
+ def __init__(self):
+ super(ScalingLayer, self).__init__()
+ self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
+ self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None])
+
+ def forward(self, inp):
+ return (inp - self.shift) / self.scale
+
+
+class LPIPSNet(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.model = LPNet()
+
+ @torch.no_grad()
+ def forward(self, fake_images, fake_images_another, align_corners=True):
+ features, shape = self._forward_single(fake_images)
+ features_another, _ = self._forward_single(fake_images_another)
+ result = 0
+ for i, g_feat in enumerate(features):
+ cur_diff = torch.sum((g_feat - features_another[i]) ** 2, dim=1) / (shape[i] ** 2)
+ result += cur_diff
+ return result
+
+ def _forward_single(self, images):
+ return self.model(torch.clamp(images, 0, 1))
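+
+
+# Hypothetical usage sketch (not part of the original file): both inputs are
+# image batches of the same shape with values in [0, 1]; the returned tensor
+# holds one perceptual distance per image. Constructing LPIPSNet downloads the
+# VGG-16 weights and the linear calibration weights.
+def _example_lpips_distance(images_a, images_b):
+    lpips = LPIPSNet().eval()
+    return lpips(images_a, images_b)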
+
+
+class LPNet(nn.Module):
+ def __init__(self):
+ super(LPNet, self).__init__()
+
+ self.scaling_layer = ScalingLayer()
+ self.net = vgg16(pretrained=True, requires_grad=False)
+ self.L = 5
+ dims = [64, 128, 256, 512, 512]
+ self.lins = nn.ModuleList([NetLinLayer(dims[i]) for i in range(self.L)])
+
+ weights = torch.hub.load_state_dict_from_url(
+ 'https://github.com/niopeng/CAM-Net/raw/main/code/models/weights/v0.1/vgg.pth'
+ )
+ for i in range(self.L):
+ self.lins[i].weight.data = torch.sqrt(weights["lin%d.model.1.weight" % i])
+
+ def forward(self, in0, avg=False):
+ in0 = 2 * in0 - 1
+ in0_input = self.scaling_layer(in0)
+ outs0 = self.net.forward(in0_input)
+ feats0 = {}
+ shapes = []
+ res = []
+
+ for kk in range(self.L):
+ feats0[kk] = normalize_tensor(outs0[kk])
+
+ if avg:
+ res = [self.lins[kk](feats0[kk]).mean([2, 3], keepdim=False) for kk in range(self.L)]
+ else:
+ for kk in range(self.L):
+ cur_res = self.lins[kk](feats0[kk])
+ shapes.append(cur_res.shape[-1])
+ res.append(cur_res.reshape(cur_res.shape[0], -1))
+
+ return res, shapes
+
+
+class vgg16(torch.nn.Module):
+ def __init__(self, requires_grad=False, pretrained=True):
+ super(vgg16, self).__init__()
+ vgg_pretrained_features = tv.vgg16(pretrained=pretrained).features
+ self.slice1 = torch.nn.Sequential()
+ self.slice2 = torch.nn.Sequential()
+ self.slice3 = torch.nn.Sequential()
+ self.slice4 = torch.nn.Sequential()
+ self.slice5 = torch.nn.Sequential()
+ self.N_slices = 5
+ for x in range(4):
+ self.slice1.add_module(str(x), vgg_pretrained_features[x])
+ for x in range(4, 9):
+ self.slice2.add_module(str(x), vgg_pretrained_features[x])
+ for x in range(9, 16):
+ self.slice3.add_module(str(x), vgg_pretrained_features[x])
+ for x in range(16, 23):
+ self.slice4.add_module(str(x), vgg_pretrained_features[x])
+ for x in range(23, 30):
+ self.slice5.add_module(str(x), vgg_pretrained_features[x])
+ if not requires_grad:
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def forward(self, x):
+ h = self.slice1(x)
+ h_relu1_2 = h
+ h = self.slice2(h)
+ h_relu2_2 = h
+ h = self.slice3(h)
+ h_relu3_3 = h
+ h = self.slice4(h)
+ h_relu4_3 = h
+ h = self.slice5(h)
+ h_relu5_3 = h
+ vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
+ out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
+
+ return out
diff --git a/imaginaire/evaluation/msid.py b/imaginaire/evaluation/msid.py
new file mode 100644
index 0000000000000000000000000000000000000000..c59a39dd38fe29d8d463de8348eb3749189c6b2b
--- /dev/null
+++ b/imaginaire/evaluation/msid.py
@@ -0,0 +1,375 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+
+# flake8: noqa
+from scipy.sparse import lil_matrix, diags, eye
+import math
+
+import numpy as np
+import torch
+
+EPSILON = 1e-6
+NORMALIZATION = 1e6
+
+
+def _get_msid(x, y, ts=np.logspace(-1, 1, 256), k=5, m=10, niters=100,
+ rademacher=False, graph_builder='full',
+ msid_mode='max', normalized_laplacian=True, normalize='empty'):
+ """
+ Compute the msid score between two samples, x and y
+ Arguments:
+ x: x samples
+ y: y samples
+ ts: temperature values
+ k: number of neighbours for graph construction
+ m: Lanczos steps in SLQ
+ niters: number of starting random vectors for SLQ
+ rademacher: if True, sample random vectors from Rademacher
+ distributions, else sample standard normal distribution
+        graph_builder: graph construction method; 'full' (dense, default),
+            'sparse', or 'kgraph' (faster, but requires pykgraph)
+        msid_mode: 'l2' to compute the l2 norm of the distance between `msid1`
+            and `msid2`; 'max' to find the maximum absolute difference between
+            the two descriptors over temperature
+ normalized_laplacian: if True, use normalized Laplacian
+ normalize: 'empty' for average heat kernel (corresponds to the empty
+ graph normalization of NetLSD), 'complete' for the complete, 'er'
+ for erdos-renyi normalization, 'none' for no normalization
+ Returns:
+        msid_score: the scalar value of the distance between descriptors
+ """
+ normed_msidx = msid_descriptor(x, ts, k, m, niters, rademacher,
+ graph_builder, normalized_laplacian,
+ normalize)
+ normed_msidy = msid_descriptor(y, ts, k, m, niters, rademacher,
+ graph_builder, normalized_laplacian,
+ normalize)
+
+ c = np.exp(-2 * (ts + 1 / ts))
+
+ if msid_mode == 'l2':
+ score = np.linalg.norm(normed_msidx - normed_msidy)
+ elif msid_mode == 'max':
+ score = np.amax(c * np.abs(normed_msidx - normed_msidy))
+ else:
+ raise Exception('Use either l2 or max mode.')
+
+ return {'IMD': score}
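+
+
+# Illustrative sketch (not part of the original file): IMD between two small
+# random feature sets, using the default dense ('full') graph builder so no
+# extra graph libraries are needed. A small `niters` keeps the sketch fast.
+def _example_msid_on_random_features():
+    x = torch.randn(64, 8)
+    y = torch.randn(64, 8)
+    return _get_msid(x, y, k=5, m=10, niters=10)['IMD']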
+
+
+def msid_descriptor(x, ts=np.logspace(-1, 1, 256), k=5, m=10, niters=100,
+ rademacher=False, graph_builder='full',
+ normalized_laplacian=True, normalize='empty'):
+ """
+ Compute the msid descriptor for a single sample x
+ Arguments:
+ x: x samples
+ ts: temperature values
+ k: number of neighbours for graph construction
+ m: Lanczos steps in SLQ
+ niters: number of starting random vectors for SLQ
+ rademacher: if True, sample random vectors from Rademacher
+ distributions, else sample standard normal distribution
+        graph_builder: graph construction method; 'full' (dense, default),
+            'sparse', or 'kgraph' (faster, but requires pykgraph)
+ normalized_laplacian: if True, use normalized Laplacian
+ normalize: 'empty' for average heat kernel (corresponds to the empty
+ graph normalization of NetLSD), 'complete' for the complete, 'er'
+ for erdos-renyi normalization, 'none' for no normalization
+ Returns:
+ normed_msidx: normalized msid descriptor
+ """
+ Lx = _build_graph(x, k, graph_builder, normalized_laplacian)
+
+ nx = Lx.shape[0]
+ msidx = slq_red_var(Lx, m, niters, ts, rademacher)
+
+ normed_msidx = _normalize_msid(msidx, normalize, nx, k, ts) * NORMALIZATION
+
+ return normed_msidx
+
+
+def _build_graph(data, k=5, graph_builder='full', normalized=True):
+ """
+ Return Laplacian from data or load preconstructed from path
+
+ Arguments:
+ data: samples
+ k: number of neighbours for graph construction
+ graph_builder: if 'kgraph', use faster graph construction
+        normalized: if True, use the normalized Laplacian
+ Returns:
+ L: Laplacian of the graph constructed with data
+ """
+ if graph_builder == 'sparse':
+ A = construct_graph_sparse(data.cpu().numpy(), k)
+ elif graph_builder == 'kgraph':
+ A = construct_graph_kgraph(data.cpu().numpy(), k)
+ elif graph_builder == 'full':
+ A = lil_matrix(construct_graph(data, k).cpu().numpy()).tocsr()
+ else:
+ raise Exception('Please specify graph builder: sparse or kgraph.')
+ A = (A + A.T) / 2
+ A.data = np.ones(A.data.shape)
+ L = _laplacian_sparse(A, normalized)
+ return L
+
+
+def _normalize_msid(msid, normalization, n, k, ts):
+ normed_msid = msid.copy()
+ if normalization == 'empty':
+ normed_msid /= n
+ elif normalization == 'complete':
+ normed_msid /= (1 + (n - 1) * np.exp(-(1 + 1 / (n - 1)) * ts))
+ elif normalization == 'er':
+ xs = np.linspace(0, 1, n)
+ er_spectrum = 4 / np.sqrt(k) * xs + 1 - 2 / np.sqrt(k)
+ er_msid = np.exp(-np.outer(ts, er_spectrum)).sum(-1)
+ normed_msid = normed_msid / (er_msid + EPSILON)
+ elif normalization == 'none' or normalization is None:
+ pass
+ else:
+ raise ValueError('Unknown normalization parameter!')
+ return normed_msid
+
+
+def _lanczos_m(A, m, nv, rademacher, SV=None):
+ """
+ Lanczos algorithm computes symmetric m x m tridiagonal matrix T and
+ matrix V with orthogonal rows constituting the basis of the Krylov
+ subspace K_m(A, x), where x is an arbitrary starting unit vector. This
+ implementation parallelizes `nv` starting vectors.
+
+ Arguments:
+ m: number of Lanczos steps
+ nv: number of random vectors
+ rademacher: True to use Rademacher distribution,
+ False - standard normal for random vectors
+ SV: specified starting vectors
+
+    Returns:
+        T: a nv x m x m tensor; T[i, :, :] is the i-th symmetric tridiagonal
+            matrix
+        V: a n x m x nv tensor; V[:, :, i] is the i-th matrix with orthogonal
+            rows
+ """
+ orthtol = 1e-5
+    if not isinstance(SV, np.ndarray):
+ if rademacher:
+ SV = np.sign(np.random.randn(A.shape[0], nv))
+ else:
+ SV = np.random.randn(A.shape[0],
+ nv) # init random vectors in columns: n x nv
+ V = np.zeros((SV.shape[0], m, nv))
+ T = np.zeros((nv, m, m))
+
+ np.divide(SV, np.linalg.norm(SV, axis=0), out=SV) # normalize each column
+ V[:, 0, :] = SV
+
+ w = A.dot(SV)
+ alpha = np.einsum('ij,ij->j', w, SV)
+ w -= alpha[None, :] * SV
+ beta = np.einsum('ij,ij->j', w, w)
+ np.sqrt(beta, beta)
+
+ T[:, 0, 0] = alpha
+ T[:, 0, 1] = beta
+ T[:, 1, 0] = beta
+
+ np.divide(w, beta[None, :], out=w)
+ V[:, 1, :] = w
+ t = np.zeros((m, nv))
+
+ for i in range(1, m):
+ SVold = V[:, i - 1, :]
+ SV = V[:, i, :]
+
+ w = A.dot(SV) # sparse @ dense
+ w -= beta[None, :] * SVold # n x nv
+ np.einsum('ij,ij->j', w, SV, out=alpha)
+
+ T[:, i, i] = alpha
+
+ if i < m - 1:
+ w -= alpha[None, :] * SV # n x nv
+ # reortho
+ np.einsum('ijk,ik->jk', V, w, out=t)
+ w -= np.einsum('ijk,jk->ik', V, t)
+ np.einsum('ij,ij->j', w, w, out=beta)
+ np.sqrt(beta, beta)
+ np.divide(w, beta[None, :], out=w)
+
+ T[:, i, i + 1] = beta
+ T[:, i + 1, i] = beta
+
+ # more reotho
+ innerprod = np.einsum('ijk,ik->jk', V, w)
+ reortho = False
+ for _ in range(100):
+ if not (innerprod > orthtol).sum():
+ reortho = True
+ break
+ np.einsum('ijk,ik->jk', V, w, out=t)
+ w -= np.einsum('ijk,jk->ik', V, t)
+ np.divide(w, np.linalg.norm(w, axis=0)[None, :], out=w)
+ innerprod = np.einsum('ijk,ik->jk', V, w)
+
+ V[:, i + 1, :] = w
+
+ if (np.abs(beta) > 1e-6).sum() == 0 or not reortho:
+ break
+ return T, V
+
+
+def _slq(A, m, niters, rademacher):
+ """
+ Compute the trace of matrix exponential
+
+ Arguments:
+ A: square matrix in trace(exp(A))
+ m: number of Lanczos steps
+ niters: number of quadratures (also, the number of random vectors in the
+ hutchinson trace estimator)
+ rademacher: True to use Rademacher distribution, False - standard normal
+ for random vectors in Hutchinson
+    Returns:
+        trace: estimate of the trace of the matrix exponential
+ """
+ T, _ = _lanczos_m(A, m, niters, rademacher)
+ eigvals, eigvecs = np.linalg.eigh(T)
+ expeig = np.exp(eigvals)
+ sqeigv1 = np.power(eigvecs[:, 0, :], 2)
+ trace = A.shape[-1] * (expeig * sqeigv1).sum() / niters
+ return trace
+
+
+def _slq_ts(A, m, niters, ts, rademacher):
+ """
+ Compute the trace of matrix exponential
+
+ Arguments:
+ A: square matrix in trace(exp(-t*A)), where t is temperature
+ m: number of Lanczos steps
+ niters: number of quadratures (also, the number of random vectors in the
+ hutchinson trace estimator)
+ ts: an array with temperatures
+ rademacher: True to use Rademacher distribution, False - standard normal
+ for random vectors in Hutchinson
+ Returns:
+ trace: estimate of trace of matrix exponential across temperatures `ts`
+ """
+ T, _ = _lanczos_m(A, m, niters, rademacher)
+ eigvals, eigvecs = np.linalg.eigh(T)
+ expeig = np.exp(-np.outer(ts, eigvals)).reshape(ts.shape[0], niters, m)
+ sqeigv1 = np.power(eigvecs[:, 0, :], 2)
+ traces = A.shape[-1] * (expeig * sqeigv1).sum(-1).mean(-1)
+ return traces
+
+
+def _slq_ts_fs(A, m, niters, ts, rademacher, fs):
+ """
+ Compute the trace of matrix functions
+
+ Arguments:
+ A: square matrix in trace(exp(-t*A)), where t is temperature
+ m: number of Lanczos steps
+ niters: number of quadratures (also, the number of random vectors in the
+ hutchinson trace estimator)
+ ts: an array with temperatures
+ rademacher: True to use Rademacher distribution, else - standard normal
+ for random vectors in Hutchinson
+ fs: a list of functions
+ Returns:
+ traces: estimate of traces for each of the functions in fs
+ """
+ T, _ = _lanczos_m(A, m, niters, rademacher)
+ eigvals, eigvecs = np.linalg.eigh(T)
+ traces = np.zeros((len(fs), len(ts)))
+ for i, f in enumerate(fs):
+ expeig = f(-np.outer(ts, eigvals)).reshape(ts.shape[0], niters, m)
+ sqeigv1 = np.power(eigvecs[:, 0, :], 2)
+ traces[i, :] = A.shape[-1] * (expeig * sqeigv1).sum(-1).mean(-1)
+ return traces
+
+
+def slq_red_var(A, m, niters, ts, rademacher):
+ """
+ Compute the trace of matrix exponential with reduced variance
+
+ Arguments:
+ A: square matrix in trace(exp(-t*A)), where t is temperature
+ m: number of Lanczos steps
+ niters: number of quadratures (also, the number of random vectors in the
+ hutchinson trace estimator)
+ ts: an array with temperatures
+ Returns:
+ traces: estimate of trace for each temperature value in `ts`
+ """
+ fs = [np.exp, lambda x: x]
+
+ traces = _slq_ts_fs(A, m, niters, ts, rademacher, fs)
+ subee = traces[0, :] - traces[1, :] / np.exp(ts)
+ sub = - ts * A.shape[0] / np.exp(ts)
+ return subee + sub
+
+
+def np_euc_cdist(data):
+ dd = np.sum(data * data, axis=1)
+ dist = -2 * np.dot(data, data.T)
+ dist += dd + dd[:, np.newaxis]
+ np.fill_diagonal(dist, 0)
+ np.sqrt(dist, dist)
+ return dist
+
+
+def construct_graph_sparse(data, k):
+ n = len(data)
+ spmat = lil_matrix((n, n))
+ dd = np.sum(data * data, axis=1)
+
+ for i in range(n):
+ dists = dd - 2 * data[i, :].dot(data.T)
+ inds = np.argpartition(dists, k + 1)[:k + 1]
+ inds = inds[inds != i]
+ spmat[i, inds] = 1
+
+ return spmat.tocsr()
+
+
+def construct_graph_kgraph(data, k):
+ raise NotImplementedError
+ #
+ # n = len(data)
+ # spmat = lil_matrix((n, n))
+ # index = pykgraph.KGraph(data, 'euclidean')
+ # index.build(reverse=0, K=2 * k + 1, L=2 * k + 50)
+ # result = index.search(data, K=k + 1)[:, 1:]
+ # spmat[np.repeat(np.arange(n), k, 0), result.ravel()] = 1
+ # return spmat.tocsr()
+
+
+def construct_graph(input_features, k, num_splits=10):
+ num_samples = input_features.shape[0]
+ indices = []
+ for i in range(num_splits):
+ batch_size = math.ceil(num_samples / num_splits)
+ start_idx = i * batch_size
+ end_idx = min((i + 1) * batch_size, num_samples)
+ dist = torch.cdist(input_features[start_idx:end_idx],
+ input_features)
+ indices.append(torch.topk(dist, k + 1, dim=-1, largest=False)[1].cpu())
+ indices = torch.cat(indices, dim=0)
+ graph = torch.zeros(num_samples, num_samples, device=indices.device)
+ graph.scatter_(1, indices, 1)
+ graph -= torch.eye(num_samples, device=graph.device)
+ return graph
+
+
+def _laplacian_sparse(A, normalized=True):
+ D = A.sum(1).A1
+ if normalized:
+ Dsqrt = diags(1 / np.sqrt(D))
+ L = eye(A.shape[0]) - Dsqrt.dot(A).dot(Dsqrt)
+ else:
+ L = diags(D) - A
+ return L
diff --git a/imaginaire/evaluation/prdc.py b/imaginaire/evaluation/prdc.py
new file mode 100644
index 0000000000000000000000000000000000000000..3ba2abba01f6e4b8e49b4b979ebe0aa93e4f9b00
--- /dev/null
+++ b/imaginaire/evaluation/prdc.py
@@ -0,0 +1,124 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+"""
+Modified from https://github.com/clovaai/generative-evaluation-prdc
+Copyright (c) 2020-present NAVER Corp.
+MIT license
+"""
+import os
+
+import torch
+
+from imaginaire.utils.distributed import is_master
+from imaginaire.utils.distributed import master_only_print as print
+
+from .common import load_or_compute_activations, compute_pairwise_distance, \
+ compute_nn
+
+
+@torch.no_grad()
+def compute_prdc(prdc_path, data_loader, net_G,
+ key_real='images', key_fake='fake_images',
+ real_act=None, fake_act=None,
+ sample_size=None, save_act=True, k=10, **kwargs):
+    r"""Compute precision, recall, density, and coverage (PRDC).
+
+    Args:
+        prdc_path (str): Location for storing feature activations.
+        data_loader (obj): PyTorch dataloader object.
+        net_G (obj): The generator network.
+        real_act, fake_act (torch.Tensor or None): Precomputed activations.
+        sample_size (int): How many samples to be used.
+        save_act (bool): If ``True``, saves real activations to the disk.
+        k (int): Number of nearest neighbours used by PRDC.
+    Returns:
+        (tuple): Precision, recall, density, and coverage.
+    """
+ print('Computing PRDC.')
+ act_path = os.path.join(
+ os.path.dirname(prdc_path), 'activations_real.npy'
+ ) if save_act else None
+
+ # Get the fake activations.
+ if fake_act is None:
+ fake_act = load_or_compute_activations(None,
+ data_loader,
+ key_real, key_fake, net_G,
+ sample_size=sample_size,
+ **kwargs)
+ else:
+ print(f"Using precomputed activations of size {fake_act.shape}.")
+
+ # Get the ground truth activations.
+ if real_act is None:
+ real_act = load_or_compute_activations(act_path,
+ data_loader,
+ key_real, key_fake, None,
+ sample_size=sample_size,
+ **kwargs)
+ else:
+ print(f"Using precomputed activations of size {real_act.shape}.")
+
+ if is_master():
+ prdc_data = _get_prdc(real_act, fake_act, k)
+ return \
+ prdc_data['precision'], prdc_data['recall'], \
+ prdc_data['density'], prdc_data['coverage']
+ else:
+ return None, None, None, None
+
+
+def get_kth_value(unsorted, k, dim=-1):
+ r"""
+
+ Args:
+        unsorted (torch.Tensor): Tensor of any dimensionality.
+        k (int): The k in k-th smallest value.
+ Returns:
+ kth values along the designated axis.
+ """
+ indices = torch.topk(unsorted, k, dim=dim, largest=False)[1]
+ k_smallests = torch.gather(unsorted, dim=dim, index=indices)
+ kth_values = k_smallests.max(dim=dim)[0]
+ return kth_values
+
+
+def _get_prdc(real_features, fake_features, nearest_k):
+ r"""
+ Computes precision, recall, density, and coverage given two manifolds.
+
+ Args:
+        real_features (torch.Tensor): Real features of shape [N, feature_dim].
+        fake_features (torch.Tensor): Fake features of shape [N, feature_dim].
+        nearest_k (int): Number of nearest neighbours.
+ Returns:
+ dict of precision, recall, density, and coverage.
+ """
+ real_nearest_neighbour_distances, _ = compute_nn(
+ real_features, nearest_k)
+ real_nearest_neighbour_distances = \
+ real_nearest_neighbour_distances.max(dim=-1)[0].cpu()
+ fake_nearest_neighbour_distances, _ = compute_nn(
+ fake_features, nearest_k)
+ fake_nearest_neighbour_distances = \
+ fake_nearest_neighbour_distances.max(dim=-1)[0].cpu()
+ distance_real_fake = compute_pairwise_distance(
+ real_features, fake_features)
+
+ precision = (
+ distance_real_fake <
+ torch.unsqueeze(real_nearest_neighbour_distances, dim=1)
+ ).any(dim=0).float().mean().item()
+
+ recall = (
+ distance_real_fake <
+ torch.unsqueeze(fake_nearest_neighbour_distances, dim=0)
+ ).any(dim=1).float().mean().item()
+
+ density = (1. / float(nearest_k)) * (
+ distance_real_fake <
+ torch.unsqueeze(real_nearest_neighbour_distances, dim=1)
+ ).sum(dim=0).float().mean().item()
+
+ # noinspection PyUnresolvedReferences
+ coverage = (
+ distance_real_fake.min(dim=1)[0] <
+ real_nearest_neighbour_distances
+ ).float().mean().item()
+
+ return dict(precision=precision, recall=recall,
+ density=density, coverage=coverage)
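+
+
+# Illustrative sketch (not part of the original file): for two feature sets
+# drawn from the same distribution, precision, recall, density, and coverage
+# should all come out reasonably high; no GPU is required for this toy call.
+def _example_prdc_on_random_features():
+    real = torch.randn(256, 16)
+    fake = torch.randn(256, 16)
+    return _get_prdc(real, fake, nearest_k=5)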
diff --git a/imaginaire/evaluation/pretrained.py b/imaginaire/evaluation/pretrained.py
new file mode 100644
index 0000000000000000000000000000000000000000..b99c8f19cbb7687ac5e88b28939e83475e458971
--- /dev/null
+++ b/imaginaire/evaluation/pretrained.py
@@ -0,0 +1,232 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+
+"""
+Modified from
+https://github.com/mseitzer/pytorch-fid
+
+Code adapted from https://github.com/bioinf-jku/TTUR to use PyTorch instead
+of Tensorflow
+Copyright 2018 Institute of Bioinformatics, JKU Linz
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+ http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+try:
+ from torchvision.models.utils import load_state_dict_from_url
+except ImportError:
+ from torch.utils.model_zoo import load_url as load_state_dict_from_url
+
+from torchvision.models import inception, inception_v3, vgg16
+
+# Inception weights ported to Pytorch from
+# http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
+FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases' \
+ '/download/fid_weights/pt_inception-2015-12-05-6726825d.pth'
+
+
+class SwAV(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.model = torch.hub.load('facebookresearch/swav', 'resnet50',
+ pretrained=True)
+ self.model.fc = torch.nn.Sequential()
+
+ def forward(self, x, align_corners=True):
+ y = self.model(F.interpolate(
+ x, size=(224, 224), mode='bicubic', align_corners=align_corners))
+ return y
+
+
+class Vgg16(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.model = vgg16(pretrained=True, init_weights=False)
+ self.model.classifier = torch.nn.Sequential(
+ *[self.model.classifier[i] for i in range(4)]
+ )
+
+ def forward(self, x, align_corners=True):
+ y = self.model(F.interpolate(
+ x, size=(224, 224), mode='bicubic', align_corners=align_corners))
+ return y
+
+
+class InceptionV3(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.model = inception_v3(transform_input=False,
+ pretrained=True,
+ init_weights=False)
+ self.model.fc = torch.nn.Sequential()
+
+ def forward(self, x, align_corners=True):
+ y = self.model(F.interpolate(
+ x, size=(299, 299), mode='bicubic', align_corners=align_corners))
+ return y
+
+
+class TFInceptionV3(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.model = inception_v3(transform_input=False,
+ num_classes=1008,
+ aux_logits=False,
+ pretrained=False,
+ init_weights=False)
+ self.model.Mixed_5b = FIDInceptionA(192, pool_features=32)
+ self.model.Mixed_5c = FIDInceptionA(256, pool_features=64)
+ self.model.Mixed_5d = FIDInceptionA(288, pool_features=64)
+ self.model.Mixed_6b = FIDInceptionC(768, channels_7x7=128)
+ self.model.Mixed_6c = FIDInceptionC(768, channels_7x7=160)
+ self.model.Mixed_6d = FIDInceptionC(768, channels_7x7=160)
+ self.model.Mixed_6e = FIDInceptionC(768, channels_7x7=192)
+ self.model.Mixed_7b = FIDInceptionE_1(1280)
+ self.model.Mixed_7c = FIDInceptionE_2(2048)
+
+ state_dict = load_state_dict_from_url(
+ FID_WEIGHTS_URL, progress=True, map_location='cpu'
+ )
+ self.model.load_state_dict(state_dict)
+ self.model.fc = torch.nn.Sequential()
+
+ def forward(self, x, align_corners=True):
+ y = self.model(F.interpolate(
+ x, size=(299, 299), mode='bicubic', align_corners=align_corners))
+ return y
+
+
+class FIDInceptionA(inception.InceptionA):
+ """InceptionA block patched for FID computation"""
+
+ def __init__(self, in_channels, pool_features):
+ super(FIDInceptionA, self).__init__(in_channels, pool_features)
+
+ def forward(self, x):
+ branch1x1 = self.branch1x1(x)
+
+ branch5x5 = self.branch5x5_1(x)
+ branch5x5 = self.branch5x5_2(branch5x5)
+
+ branch3x3dbl = self.branch3x3dbl_1(x)
+ branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
+ branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
+
+ # Patch: Tensorflow's average pool does not use the padded zero's in
+ # its average calculation
+ branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
+ count_include_pad=False)
+ branch_pool = self.branch_pool(branch_pool)
+
+ outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
+ return torch.cat(outputs, 1)
+
+
+class FIDInceptionC(inception.InceptionC):
+ """InceptionC block patched for FID computation"""
+
+ def __init__(self, in_channels, channels_7x7):
+ super(FIDInceptionC, self).__init__(in_channels, channels_7x7)
+
+ def forward(self, x):
+ branch1x1 = self.branch1x1(x)
+
+ branch7x7 = self.branch7x7_1(x)
+ branch7x7 = self.branch7x7_2(branch7x7)
+ branch7x7 = self.branch7x7_3(branch7x7)
+
+ branch7x7dbl = self.branch7x7dbl_1(x)
+ branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
+ branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
+ branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
+ branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)
+
+ # Patch: Tensorflow's average pool does not use the padded zero's in
+ # its average calculation
+ branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
+ count_include_pad=False)
+ branch_pool = self.branch_pool(branch_pool)
+
+ outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
+ return torch.cat(outputs, 1)
+
+
+class FIDInceptionE_1(inception.InceptionE):
+ """First InceptionE block patched for FID computation"""
+
+ def __init__(self, in_channels):
+ super(FIDInceptionE_1, self).__init__(in_channels)
+
+ def forward(self, x):
+ branch1x1 = self.branch1x1(x)
+
+ branch3x3 = self.branch3x3_1(x)
+ branch3x3 = [
+ self.branch3x3_2a(branch3x3),
+ self.branch3x3_2b(branch3x3),
+ ]
+ branch3x3 = torch.cat(branch3x3, 1)
+
+ branch3x3dbl = self.branch3x3dbl_1(x)
+ branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
+ branch3x3dbl = [
+ self.branch3x3dbl_3a(branch3x3dbl),
+ self.branch3x3dbl_3b(branch3x3dbl),
+ ]
+ branch3x3dbl = torch.cat(branch3x3dbl, 1)
+
+ # Patch: Tensorflow's average pool does not use the padded zero's in
+ # its average calculation
+ branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
+ count_include_pad=False)
+ branch_pool = self.branch_pool(branch_pool)
+
+ outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
+ return torch.cat(outputs, 1)
+
+
+class FIDInceptionE_2(inception.InceptionE):
+ """Second InceptionE block patched for FID computation"""
+
+ def __init__(self, in_channels):
+ super(FIDInceptionE_2, self).__init__(in_channels)
+
+ def forward(self, x):
+ branch1x1 = self.branch1x1(x)
+
+ branch3x3 = self.branch3x3_1(x)
+ branch3x3 = [
+ self.branch3x3_2a(branch3x3),
+ self.branch3x3_2b(branch3x3),
+ ]
+ branch3x3 = torch.cat(branch3x3, 1)
+
+ branch3x3dbl = self.branch3x3dbl_1(x)
+ branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
+ branch3x3dbl = [
+ self.branch3x3dbl_3a(branch3x3dbl),
+ self.branch3x3dbl_3b(branch3x3dbl),
+ ]
+ branch3x3dbl = torch.cat(branch3x3dbl, 1)
+
+ # Patch: The FID Inception model uses max pooling instead of average
+ # pooling. This is likely an error in this specific Inception
+ # implementation, as other Inception models use average pooling here
+ # (which matches the description in the paper).
+ branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
+ branch_pool = self.branch_pool(branch_pool)
+
+ outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
+ return torch.cat(outputs, 1)
diff --git a/imaginaire/evaluation/segmentation/__init__.py b/imaginaire/evaluation/segmentation/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..63a71a20ebb0ac2940c091f064de5a029126af42
--- /dev/null
+++ b/imaginaire/evaluation/segmentation/__init__.py
@@ -0,0 +1,7 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+from .common import get_segmentation_hist_model, get_miou, compute_hist
+
+__all__ = ['get_segmentation_hist_model', 'get_miou', 'compute_hist']
diff --git a/imaginaire/evaluation/segmentation/celebamask_hq.py b/imaginaire/evaluation/segmentation/celebamask_hq.py
new file mode 100644
index 0000000000000000000000000000000000000000..ab95f325eb681bb0d80bac2c626494ee9e81b51b
--- /dev/null
+++ b/imaginaire/evaluation/segmentation/celebamask_hq.py
@@ -0,0 +1,130 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+# https://github.com/switchablenorms/CelebAMask-HQ/tree/master/face_parsing
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+
+class Unet(nn.Module):
+ def __init__(
+ self,
+ feature_scale=4,
+ n_classes=19,
+ is_deconv=True,
+ in_channels=3,
+ is_batchnorm=True,
+ image_size=512,
+ use_dont_care=False
+ ):
+ super(Unet, self).__init__()
+ self.is_deconv = is_deconv
+ self.in_channels = in_channels
+ self.is_batchnorm = is_batchnorm
+ self.feature_scale = feature_scale
+ self.image_size = image_size
+ self.n_classes = n_classes
+ self.use_dont_care = use_dont_care
+
+ filters = [64, 128, 256, 512, 1024]
+ filters = [int(x / self.feature_scale) for x in filters]
+
+ # downsampling
+ self.conv1 = unetConv2(self.in_channels, filters[0], self.is_batchnorm)
+ self.maxpool1 = nn.MaxPool2d(kernel_size=2)
+
+ self.conv2 = unetConv2(filters[0], filters[1], self.is_batchnorm)
+ self.maxpool2 = nn.MaxPool2d(kernel_size=2)
+
+ self.conv3 = unetConv2(filters[1], filters[2], self.is_batchnorm)
+ self.maxpool3 = nn.MaxPool2d(kernel_size=2)
+
+ self.conv4 = unetConv2(filters[2], filters[3], self.is_batchnorm)
+ self.maxpool4 = nn.MaxPool2d(kernel_size=2)
+
+ self.center = unetConv2(filters[3], filters[4], self.is_batchnorm)
+
+ # upsampling
+ self.up_concat4 = unetUp(
+ filters[4], filters[3], self.is_deconv, self.is_batchnorm)
+ self.up_concat3 = unetUp(
+ filters[3], filters[2], self.is_deconv, self.is_batchnorm)
+ self.up_concat2 = unetUp(
+ filters[2], filters[1], self.is_deconv, self.is_batchnorm)
+ self.up_concat1 = unetUp(
+ filters[1], filters[0], self.is_deconv, self.is_batchnorm)
+
+ # final conv (without any concat)
+ self.final = nn.Conv2d(filters[0], n_classes, 1)
+
+ def forward(self, images, align_corners=True):
+ images = F.interpolate(
+ images, size=(self.image_size, self.image_size), mode='bicubic',
+ align_corners=align_corners
+ )
+
+ conv1 = self.conv1(images)
+ maxpool1 = self.maxpool1(conv1)
+ conv2 = self.conv2(maxpool1)
+ maxpool2 = self.maxpool2(conv2)
+ conv3 = self.conv3(maxpool2)
+ maxpool3 = self.maxpool3(conv3)
+ conv4 = self.conv4(maxpool3)
+ maxpool4 = self.maxpool4(conv4)
+ center = self.center(maxpool4)
+ up4 = self.up_concat4(conv4, center)
+ up3 = self.up_concat3(conv3, up4)
+ up2 = self.up_concat2(conv2, up3)
+ up1 = self.up_concat1(conv1, up2)
+ probs = self.final(up1)
+ pred = torch.argmax(probs, dim=1)
+ return pred
+
+
+class unetConv2(nn.Module):
+ def __init__(self, in_size, out_size, is_batchnorm):
+ super(unetConv2, self).__init__()
+
+ if is_batchnorm:
+ self.conv1 = nn.Sequential(
+ nn.Conv2d(in_size, out_size, 3, 1, 1),
+ nn.BatchNorm2d(out_size),
+ nn.ReLU(),
+ )
+ self.conv2 = nn.Sequential(
+ nn.Conv2d(out_size, out_size, 3, 1, 1),
+ nn.BatchNorm2d(out_size),
+ nn.ReLU(),
+ )
+ else:
+ self.conv1 = nn.Sequential(nn.Conv2d(in_size, out_size, 3, 1, 1),
+ nn.ReLU())
+ self.conv2 = nn.Sequential(
+ nn.Conv2d(out_size, out_size, 3, 1, 1), nn.ReLU()
+ )
+
+ def forward(self, inputs):
+ outputs = self.conv1(inputs)
+ outputs = self.conv2(outputs)
+ return outputs
+
+
+class unetUp(nn.Module):
+ def __init__(self, in_size, out_size, is_deconv, is_batchnorm):
+ super(unetUp, self).__init__()
+ self.conv = unetConv2(in_size, out_size, is_batchnorm)
+ if is_deconv:
+ self.up = nn.ConvTranspose2d(
+ in_size, out_size, kernel_size=2, stride=2)
+ else:
+ self.up = nn.UpsamplingBilinear2d(scale_factor=2)
+
+ def forward(self, inputs1, inputs2):
+ outputs2 = self.up(inputs2)
+ offset = outputs2.size()[2] - inputs1.size()[2]
+ padding = 2 * [offset // 2, offset // 2]
+ outputs1 = F.pad(inputs1, padding)
+
+ return self.conv(torch.cat([outputs1, outputs2], 1))
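
A minimal smoke test for the face-parsing U-Net above (a sketch with random weights; the actual evaluation loads the `celebamask_hq.pt` checkpoint through `get_segmentation_hist_model` in `common.py`). Inputs of any resolution are resized internally to `image_size`, and the forward pass returns per-pixel class indices rather than logits:

```python
import torch
from imaginaire.evaluation.segmentation.celebamask_hq import Unet

net = Unet(feature_scale=4, n_classes=19, image_size=512).eval()
with torch.no_grad():
    fake_images = torch.randn(2, 3, 256, 256)   # any H x W; forward() resizes to 512 x 512
    pred = net(fake_images)
print(pred.shape, pred.dtype)                   # torch.Size([2, 512, 512]) torch.int64
```
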
diff --git a/imaginaire/evaluation/segmentation/cocostuff.py b/imaginaire/evaluation/segmentation/cocostuff.py
new file mode 100644
index 0000000000000000000000000000000000000000..601f0c28a3811f994df9a9375bd6c5fc08509d9d
--- /dev/null
+++ b/imaginaire/evaluation/segmentation/cocostuff.py
@@ -0,0 +1,48 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+from torch import nn
+from torch.nn import functional as F
+import torch.hub
+
+
+class DeepLabV2(nn.Module):
+ def __init__(self, n_classes=182, image_size=512, use_dont_care=True):
+ super(DeepLabV2, self).__init__()
+ self.model = torch.hub.load(
+ "kazuto1011/deeplab-pytorch", "deeplabv2_resnet101",
+ pretrained=False, n_classes=182
+ )
+ state_dict = torch.hub.load_state_dict_from_url(
+ 'https://github.com/kazuto1011/deeplab-pytorch/releases/download/'
+ 'v1.0/deeplabv2_resnet101_msc-cocostuff164k-100000.pth',
+ map_location="cpu"
+ )
+ self.model.load_state_dict(state_dict)
+
+ self.image_size = image_size
+        # Per-channel means in BGR order, matching the Caffe-trained DeepLab weights.
+        self.mean = torch.tensor([104.008, 116.669, 122.675], device="cuda")
+ self.n_classes = n_classes
+ self.use_dont_care = use_dont_care
+
+ def forward(self, images, align_corners=True):
+ scale = self.image_size / max(images.shape[2:])
+ images = F.interpolate(
+ images, scale_factor=scale, mode='bilinear',
+ align_corners=align_corners
+ )
+ images = 255 * 0.5 * (images + 1) # (-1, 1) -> (0, 255)
+ images = images.flip(1) # RGB to BGR
+ images -= self.mean[None, :, None, None]
+ _, _, H, W = images.shape
+
+ logits = self.model(images)
+ logits = F.interpolate(
+ logits, size=(H, W), mode="bilinear",
+ align_corners=align_corners
+ )
+ probs = F.softmax(logits, dim=1)
+ pred = torch.argmax(probs, dim=1)
+ return pred
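
The forward pass above assumes generator outputs in `(-1, 1)` and converts them to the Caffe-style input expected by the `kazuto1011/deeplab-pytorch` weights. A standalone sketch of just that preprocessing (with the mean kept on CPU for illustration):

```python
import torch

mean_bgr = torch.tensor([104.008, 116.669, 122.675])   # per-channel means, BGR order
images = torch.rand(1, 3, 256, 256) * 2 - 1             # fake generator output in (-1, 1)
images = 255 * 0.5 * (images + 1)                       # rescale to (0, 255)
images = images.flip(1)                                 # RGB -> BGR channel order
images = images - mean_bgr[None, :, None, None]         # subtract the dataset mean
print(images.shape)                                     # torch.Size([1, 3, 256, 256])
```
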
diff --git a/imaginaire/evaluation/segmentation/common.py b/imaginaire/evaluation/segmentation/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..78d90ae26db4747e5be859a8c1237a288e3146b7
--- /dev/null
+++ b/imaginaire/evaluation/segmentation/common.py
@@ -0,0 +1,92 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import os
+
+import boto3
+import torch
+from torch import nn, distributed as dist
+from torch.nn import functional as F
+
+from imaginaire.utils.distributed import is_local_master
+from imaginaire.utils.io import download_file_from_google_drive
+
+
+def get_segmentation_hist_model(dataset_name, aws_credentials=None):
+ if dist.is_initialized() and not is_local_master():
+ # Make sure only the first process in distributed training downloads
+ # the model, and the others will use the cache
+ # noinspection PyUnresolvedReferences
+ torch.distributed.barrier()
+
+ # Load the segmentation network.
+ if dataset_name == "celebamask_hq":
+ from imaginaire.evaluation.segmentation.celebamask_hq import Unet
+ seg_network = Unet()
+ os.makedirs(os.path.join(torch.hub.get_dir(), 'checkpoints'), exist_ok=True)
+ model_path = os.path.join(os.path.join(torch.hub.get_dir(), 'checkpoints'), "celebamask_hq.pt")
+ if not os.path.exists(model_path):
+ if aws_credentials is not None:
+ s3 = boto3.client('s3', **aws_credentials)
+ s3.download_file('lpi-poe', 'model_zoo/celebamask_hq.pt', model_path)
+ else:
+ download_file_from_google_drive("1o1m-eT38zNCIFldcRaoWcLvvBtY8S4W3", model_path)
+ state_dict = torch.load(model_path, map_location='cpu')
+ seg_network.load_state_dict(state_dict)
+ elif dataset_name == "cocostuff" or dataset_name == "getty":
+ from imaginaire.evaluation.segmentation.cocostuff import DeepLabV2
+ seg_network = DeepLabV2()
+ else:
+ print(f"No segmentation network for {dataset_name} was found.")
+ return None
+
+ if dist.is_initialized() and is_local_master():
+ # Make sure only the first process in distributed training downloads
+ # the model, and the others will use the cache
+ # noinspection PyUnresolvedReferences
+ torch.distributed.barrier()
+
+ if seg_network is not None:
+ seg_network = seg_network.to('cuda').eval()
+
+ return SegmentationHistModel(seg_network)
+
+
+class SegmentationHistModel(nn.Module):
+ def __init__(self, seg_network):
+ super().__init__()
+ self.seg_network = seg_network
+
+ def forward(self, data, fake_images, align_corners=True):
+ pred = self.seg_network(fake_images, align_corners=align_corners)
+ gt = data["segmaps"]
+ gt = gt * 255.0
+ gt = gt.long()
+ return compute_hist(pred, gt, self.seg_network.n_classes, self.seg_network.use_dont_care)
+
+
+def compute_hist(pred, gt, n_classes, use_dont_care):
+ _, H, W = pred.size()
+ gt = F.interpolate(gt.float(), (H, W), mode="nearest").long().squeeze(1)
+ ignore_idx = n_classes if use_dont_care else -1
+ all_hist = []
+ for cur_pred, cur_gt in zip(pred, gt):
+ keep = torch.logical_not(cur_gt == ignore_idx)
+ merge = cur_pred[keep] * n_classes + cur_gt[keep]
+ hist = torch.bincount(merge, minlength=n_classes ** 2)
+ hist = hist.view((n_classes, n_classes))
+ all_hist.append(hist)
+ all_hist = torch.stack(all_hist)
+ return all_hist
+
+
+def get_miou(hist, eps=1e-8):
+ hist = hist.sum(0)
+ IOUs = torch.diag(hist) / (
+ torch.sum(hist, dim=0, keepdim=False) + torch.sum(hist, dim=1, keepdim=False) - torch.diag(hist) + eps)
+ mIOU = 100 * torch.mean(IOUs).item()
+ return {"seg_mIOU": mIOU}
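
A toy example of the confusion-matrix bookkeeping above (a sketch; it assumes the module's dependencies such as `boto3` are importable, since `common.py` pulls them in at import time). `compute_hist` builds one `n_classes x n_classes` matrix per image by encoding `pred * n_classes + gt`, and `get_miou` sums the matrices and reads the per-class IoU off the diagonal:

```python
import torch
from imaginaire.evaluation.segmentation.common import compute_hist, get_miou

n_classes = 3
pred = torch.randint(0, n_classes, (2, 8, 8))     # (N, H, W) predicted class indices
gt = torch.randint(0, n_classes, (2, 1, 8, 8))    # (N, 1, H, W) ground-truth indices
hist = compute_hist(pred, gt, n_classes, use_dont_care=False)
print(hist.shape)                                 # torch.Size([2, 3, 3]): one matrix per image
print(get_miou(hist))                             # {'seg_mIOU': ...} mean IoU in percent
```
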
diff --git a/imaginaire/generators/__init__.py b/imaginaire/generators/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/imaginaire/generators/__pycache__/__init__.cpython-38.pyc b/imaginaire/generators/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c29051fe61ab456c65828f39f1e1130d389be734
Binary files /dev/null and b/imaginaire/generators/__pycache__/__init__.cpython-38.pyc differ
diff --git a/imaginaire/generators/__pycache__/craft_2stage_add_style.cpython-38.pyc b/imaginaire/generators/__pycache__/craft_2stage_add_style.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5fa5fc572212d24f822b7bcdc1d936b848ea4adb
Binary files /dev/null and b/imaginaire/generators/__pycache__/craft_2stage_add_style.cpython-38.pyc differ
diff --git a/imaginaire/generators/__pycache__/craft_base.cpython-38.pyc b/imaginaire/generators/__pycache__/craft_base.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..205291ec52f52992a82b459886ef1b3455ee39f0
Binary files /dev/null and b/imaginaire/generators/__pycache__/craft_base.cpython-38.pyc differ
diff --git a/imaginaire/generators/coco_funit.py b/imaginaire/generators/coco_funit.py
new file mode 100644
index 0000000000000000000000000000000000000000..e1f555fb91c9d8580ccf69e1e785c5b6c5a54aef
--- /dev/null
+++ b/imaginaire/generators/coco_funit.py
@@ -0,0 +1,194 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import torch
+from torch import nn
+
+from imaginaire.generators.funit import (MLP, ContentEncoder, Decoder,
+ StyleEncoder)
+
+
+class Generator(nn.Module):
+ r"""COCO-FUNIT Generator.
+ """
+
+ def __init__(self, gen_cfg, data_cfg):
+ r"""COCO-FUNIT Generator constructor.
+
+ Args:
+ gen_cfg (obj): Generator definition part of the yaml config file.
+ data_cfg (obj): Data definition part of the yaml config file.
+ """
+ super().__init__()
+ self.generator = COCOFUNITTranslator(**vars(gen_cfg))
+
+ def forward(self, data):
+        r"""In the FUNIT forward pass, the network generates a content embedding and
+ a style code from the content image, and a style code from the style
+ image. By mixing the content code and the style code from the content
+ image, we reconstruct the input image. By mixing the content code and
+ the style code from the style image, we have a translation output.
+
+ Args:
+ data (dict): Training data at the current iteration.
+ """
+ content_a = self.generator.content_encoder(data['images_content'])
+ style_a = self.generator.style_encoder(data['images_content'])
+ style_b = self.generator.style_encoder(data['images_style'])
+ images_trans = self.generator.decode(content_a, style_b)
+ images_recon = self.generator.decode(content_a, style_a)
+
+ net_G_output = dict(images_trans=images_trans,
+ images_recon=images_recon)
+ return net_G_output
+
+ def inference(self, data, keep_original_size=True):
+ r"""COCO-FUNIT inference.
+
+ Args:
+ data (dict): Training data at the current iteration.
+ - images_content (tensor): Content images.
+ - images_style (tensor): Style images.
+ keep_original_size (bool): If ``True``, output image is resized
+ to the input content image size.
+ """
+ content_a = self.generator.content_encoder(data['images_content'])
+ style_b = self.generator.style_encoder(data['images_style'])
+ output_images = self.generator.decode(content_a, style_b)
+ if keep_original_size:
+ height = data['original_h_w'][0][0]
+ width = data['original_h_w'][0][1]
+ # print('( H, W) = ( %d, %d)' % (height, width))
+ output_images = torch.nn.functional.interpolate(
+ output_images, size=[height, width])
+ file_names = data['key']['images_content'][0]
+ return output_images, file_names
+
+
+class COCOFUNITTranslator(nn.Module):
+ r"""COCO-FUNIT Generator architecture.
+
+ Args:
+ num_filters (int): Base filter numbers.
+ num_filters_mlp (int): Base filter number in the MLP module.
+ style_dims (int): Dimension of the style code.
+ usb_dims (int): Dimension of the universal style bias code.
+ num_res_blocks (int): Number of residual blocks at the end of the
+ content encoder.
+ num_mlp_blocks (int): Number of layers in the MLP module.
+ num_downsamples_content (int): Number of times we reduce
+ resolution by 2x2 for the content image.
+ num_downsamples_style (int): Number of times we reduce
+ resolution by 2x2 for the style image.
+ num_image_channels (int): Number of input image channels.
+ weight_norm_type (str): Type of weight normalization.
+ ``'none'``, ``'spectral'``, or ``'weight'``.
+ """
+
+ def __init__(self,
+ num_filters=64,
+ num_filters_mlp=256,
+ style_dims=64,
+ usb_dims=1024,
+ num_res_blocks=2,
+ num_mlp_blocks=3,
+ num_downsamples_style=4,
+ num_downsamples_content=2,
+ num_image_channels=3,
+ weight_norm_type='',
+ **kwargs):
+ super().__init__()
+
+ self.style_encoder = StyleEncoder(num_downsamples_style,
+ num_image_channels,
+ num_filters,
+ style_dims,
+ 'reflect',
+ 'none',
+ weight_norm_type,
+ 'relu')
+
+ self.content_encoder = ContentEncoder(num_downsamples_content,
+ num_res_blocks,
+ num_image_channels,
+ num_filters,
+ 'reflect',
+ 'instance',
+ weight_norm_type,
+ 'relu')
+
+ self.decoder = Decoder(self.content_encoder.output_dim,
+ num_filters_mlp,
+ num_image_channels,
+ num_downsamples_content,
+ 'reflect',
+ weight_norm_type,
+ 'relu')
+
+ self.usb = torch.nn.Parameter(torch.randn(1, usb_dims))
+
+ self.mlp = MLP(style_dims,
+ num_filters_mlp,
+ num_filters_mlp,
+ num_mlp_blocks,
+ 'none',
+ 'relu')
+
+ num_content_mlp_blocks = 2
+ num_style_mlp_blocks = 2
+ self.mlp_content = MLP(self.content_encoder.output_dim,
+ style_dims,
+ num_filters_mlp,
+ num_content_mlp_blocks,
+ 'none',
+ 'relu')
+
+ self.mlp_style = MLP(style_dims + usb_dims,
+ style_dims,
+ num_filters_mlp,
+ num_style_mlp_blocks,
+ 'none',
+ 'relu')
+
+ def forward(self, images):
+        r"""Reconstruct the input image by combining the computed content and
+        style codes.
+
+ Args:
+ images (tensor): Input image tensor.
+ """
+ # reconstruct an image
+ content, style = self.encode(images)
+ images_recon = self.decode(content, style)
+ return images_recon
+
+ def encode(self, images):
+        r"""Encode images to get their content and style codes.
+
+ Args:
+ images (tensor): Input image tensor.
+ """
+ style = self.style_encoder(images)
+ content = self.content_encoder(images)
+ return content, style
+
+ def decode(self, content, style):
+ r"""Generate images by combining their content and style codes.
+
+ Args:
+ content (tensor): Content code tensor.
+ style (tensor): Style code tensor.
+ """
+ content_style_code = content.mean(3).mean(2)
+ content_style_code = self.mlp_content(content_style_code)
+ batch_size = style.size(0)
+ usb = self.usb.repeat(batch_size, 1)
+ style = style.view(batch_size, -1)
+ style_in = self.mlp_style(torch.cat([style, usb], 1))
+ coco_style = style_in * content_style_code
+ coco_style = self.mlp(coco_style)
+ images = self.decoder(content, coco_style)
+ return images
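
`decode` above implements COCO-FUNIT's content-conditioned style code: a pooled content code and a (style + universal style bias) code are multiplied elementwise before the final MLP. The sketch below reproduces only that combination step, using plain `nn.Linear` layers as stand-ins for the repo's `MLP` blocks; the dimensions are illustrative, not the config defaults.

```python
import torch
from torch import nn

style_dims, usb_dims, mlp_dims, content_dims = 64, 1024, 256, 512    # illustrative sizes
mlp_content = nn.Linear(content_dims, style_dims)          # stand-in for MLP(content -> style_dims)
mlp_style = nn.Linear(style_dims + usb_dims, style_dims)   # stand-in for MLP(style + usb -> style_dims)
mlp = nn.Linear(style_dims, mlp_dims)                       # stand-in for the final MLP

usb = torch.randn(1, usb_dims)                              # universal style bias parameter
content = torch.randn(4, content_dims, 16, 16)              # content encoder feature map
style = torch.randn(4, style_dims)                          # flattened style encoder output

content_code = mlp_content(content.mean(3).mean(2))         # global average pool, then project
style_in = mlp_style(torch.cat([style, usb.repeat(4, 1)], dim=1))
coco_style = mlp(style_in * content_code)                   # elementwise gating, then project
print(coco_style.shape)                                     # torch.Size([4, 256])
```
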
diff --git a/imaginaire/generators/craft_2stage.py b/imaginaire/generators/craft_2stage.py
new file mode 100644
index 0000000000000000000000000000000000000000..ee9c8bd42feb3fae32a96625f885a6752e5254d1
--- /dev/null
+++ b/imaginaire/generators/craft_2stage.py
@@ -0,0 +1,65 @@
+import torch
+from torch import nn
+import functools
+import torch.nn.functional as F
+
+import sys
+sys.path.append(".")
+from model import geometry_transform
+from imaginaire.utils.distributed import master_only_print as print
+from model.graphs.decoder import DeepLab
+from imaginaire.generators.craft_base import *
+
+
+
+class Generator(nn.Module):
+ def __init__(self,opt):
+ super(Generator, self).__init__()
+ gen_cfg = opt.arch.gen
+ data_cfg = opt.data
+ # self.gen_model = gen_model
+ self.gen_cfg = opt.arch.gen
+ if gen_cfg.transform_mode in ['project_RGB','volum_rendering','proj_like_radus']:
+ self.pano_direction = torch.from_numpy(geometry_transform.get_original_coord(opt)).unsqueeze(0).to(opt.device)
+ if gen_cfg.transform_mode == 'volum_rendering':
+ last_act = 'relu'
+ else:
+ last_act = 'softmax'
+ self.depth_model = inner_Generator(gen_cfg,gen_cfg.depth_arch,data_cfg,num_input_channels=3,last_act=last_act)
+ render_input_channel = 3
+ if gen_cfg.cat_opa:
+ render_input_channel = render_input_channel+1
+ self.denoise_model = inner_Generator(gen_cfg,gen_cfg.render_arch,data_cfg,render_input_channel,last_act='sigmoid')
+
+ self.PE = None
+
+
+
+ def forward(self, inputs, style_img=None,opt=None):
+ estimated_height = self.depth_model(inputs)
+
+ if self.gen_cfg.transform_mode in ['project_RGB','volum_rendering','proj_like_radus']:
+ geo_outputs = geometry_transform.render(opt,inputs,estimated_height,self.pano_direction,PE=self.PE)
+ generator_inputs,opacity,depth = geo_outputs['rgb'],geo_outputs['opacity'],geo_outputs['depth']
+            # None if the renderer does not return a voxel grid (avoids a NameError below).
+            voxel = geo_outputs.get('voxel')
+ # mu, logvar, z = self.style_encode(style_img)
+ # z = self.style_model(z)
+ if self.gen_cfg.cat_opa:
+ generator_inputs = torch.cat((generator_inputs,opacity),dim=1)
+ output_RGB = self.denoise_model(generator_inputs)
+ out_put = {
+ 'pred': output_RGB,
+ # 'inter_RGB': generator_inputs, ### out_feature not for show
+ # 'mu' :mu,
+ # 'logvar' : logvar,
+ }
+ if self.gen_cfg.transform_mode in ['volum_rendering']:
+ out_put['opacity'] = opacity
+ if self.gen_cfg.transform_mode:
+ out_put['estimated_height'] = estimated_height
+ out_put['generator_inputs'] = generator_inputs
+ out_put['voxel'] = voxel
+ out_put['depth'] = depth
+ return out_put
+
diff --git a/imaginaire/generators/craft_2stage_add_style.py b/imaginaire/generators/craft_2stage_add_style.py
new file mode 100644
index 0000000000000000000000000000000000000000..7920abdf110783ea2921e5f8faa98c3a7c284ec9
--- /dev/null
+++ b/imaginaire/generators/craft_2stage_add_style.py
@@ -0,0 +1,75 @@
+import torch
+from torch import nn
+import sys
+sys.path.append(".")
+from model import geometry_transform
+from imaginaire.generators.craft_base import *
+
+
+
+class Generator(nn.Module):
+ def __init__(self,opt):
+ super(Generator, self).__init__()
+ gen_cfg = opt.arch.gen
+ data_cfg = opt.data
+ style_enc_cfg = opt.arch.gen.style_enc_cfg
+ # self.gen_model = gen_model
+ self.style_inject = getattr(gen_cfg, 'style_inject',
+ None)
+ self.gen_cfg = opt.arch.gen
+ self.pano_direction = torch.from_numpy(geometry_transform.get_original_coord(opt)).unsqueeze(0).to(opt.device)
+ last_act = 'relu'
+ self.depth_model = inner_Generator_split(gen_cfg,gen_cfg.depth_arch,data_cfg,num_input_channels=3,last_act=last_act)
+
+
+ render_input_channel = 3
+ if gen_cfg.cat_opa:
+ render_input_channel +=1
+ if gen_cfg.cat_depth:
+ render_input_channel +=1
+
+ self.denoise_model = inner_Generator_split(gen_cfg,gen_cfg.render_arch,data_cfg,render_input_channel,last_act='sigmoid')
+ if self.style_inject:
+ if self.style_inject=='histo':
+ self.style_encode = histo_process(style_enc_cfg)
+ elif self.style_inject=='perspective':
+ self.style_encode = StyleEncoder(style_enc_cfg)
+ else:
+ raise Exception('Unknown style inject')
+ self.style_model = StyleMLP(style_dim=style_enc_cfg.style_dims, out_dim=style_enc_cfg.interm_style_dims, hidden_channels=style_enc_cfg.hidden_channel, leaky_relu=True, num_layers=5, normalize_input=True,
+ output_act=True)
+
+ self.PE = geometry_transform.position_produce(opt) if gen_cfg.cat_PE else None
+
+
+
+ def forward(self, inputs, style_img=None,opt=None):
+ # predicted height of satellite images
+ estimated_height = self.depth_model(inputs)
+ geo_outputs = geometry_transform.render(opt,inputs,estimated_height,self.pano_direction,PE=self.PE)
+ generator_inputs,opacity,depth = geo_outputs['rgb'],geo_outputs['opacity'],geo_outputs['depth']
+        # None if the renderer does not return a voxel grid (avoids a NameError below).
+        voxel = geo_outputs.get('voxel')
+
+ if self.gen_cfg.cat_opa:
+ generator_inputs = torch.cat((generator_inputs,opacity),dim=1)
+ if self.gen_cfg.cat_depth:
+ generator_inputs = torch.cat((generator_inputs,depth),dim=1)
+ if self.style_inject:
+ mu, logvar, z = self.style_encode(style_img)
+ z = self.style_model(z)
+ else:
+ z = None
+ # merge multiple sources(rgb,opacity,depth and sky) and denoise redundancy
+ output_RGB = self.denoise_model(generator_inputs,z)
+ out_put = {'pred': output_RGB}
+ if self.style_inject:
+ out_put['mu'] = mu
+ out_put['logvar'] = logvar
+ out_put['estimated_height'] = estimated_height
+ out_put['generator_inputs'] = generator_inputs
+ out_put['voxel'] = voxel
+ out_put['depth'] = depth
+ out_put['opacity'] = opacity
+ return out_put
+
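
To summarize the data flow of the generator above: the satellite image is mapped to an estimated height/density volume, `geometry_transform.render` projects it into panorama space, the opacity (and optionally depth) channels are concatenated to the projected RGB, and a style code from `histo_process`/`StyleEncoder` plus `StyleMLP` modulates the render network. The renderer and config objects are project-specific, so the sketch below only mimics the channel bookkeeping with random tensors.

```python
import torch

# Hypothetical panorama-space outputs of the geometric projection step.
rgb = torch.rand(1, 3, 128, 512)
opacity = torch.rand(1, 1, 128, 512)
depth = torch.rand(1, 1, 128, 512)

x = torch.cat((rgb, opacity), dim=1)   # cat_opa=True   -> 4 input channels
x = torch.cat((x, depth), dim=1)       # cat_depth=True -> 5 input channels
print(x.shape)                         # torch.Size([1, 5, 128, 512])
# This matches render_input_channel = 3 + 1 + 1 in the constructor; the tensor is then
# fed to the render ("denoise") network together with the style code z.
```
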
diff --git a/imaginaire/generators/craft_base.py b/imaginaire/generators/craft_base.py
new file mode 100644
index 0000000000000000000000000000000000000000..f6f283c701269ad8eaa656036ade597d37393ebc
--- /dev/null
+++ b/imaginaire/generators/craft_base.py
@@ -0,0 +1,483 @@
+import numpy as np
+import torch
+import torch.nn as nn
+from torch.nn import Upsample as NearestUpsample
+import torch.nn.functional as F
+from functools import partial
+
+import sys
+sys.path.append(".")
+from imaginaire.layers import Conv2dBlock, LinearBlock, Res2dBlock
+
+
+class StyleMLP(nn.Module):
+ r"""MLP converting style code to intermediate style representation."""
+
+ def __init__(self, style_dim, out_dim, hidden_channels=256, leaky_relu=True, num_layers=5, normalize_input=True,
+ output_act=True):
+ super(StyleMLP, self).__init__()
+
+ self.normalize_input = normalize_input
+ self.output_act = output_act
+ fc_layers = []
+ fc_layers.append(nn.Linear(style_dim, hidden_channels, bias=True))
+ for i in range(num_layers-1):
+ fc_layers.append(nn.Linear(hidden_channels, hidden_channels, bias=True))
+ self.fc_layers = nn.ModuleList(fc_layers)
+
+ self.fc_out = nn.Linear(hidden_channels, out_dim, bias=True)
+
+ if leaky_relu:
+ self.act = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+ else:
+ self.act = partial(F.relu, inplace=True)
+
+ def forward(self, z):
+ r""" Forward network
+
+ Args:
+ z (N x style_dim tensor): Style codes.
+ """
+ if self.normalize_input:
+ z = F.normalize(z, p=2, dim=-1,eps=1e-6)
+ for fc_layer in self.fc_layers:
+ z = self.act(fc_layer(z))
+ z = self.fc_out(z)
+ if self.output_act:
+ z = self.act(z)
+ return z
+
+class histo_process(nn.Module):
+    r"""Histogram-based style encoder, used in place of the image StyleEncoder
+    when the style is injected as an RGB histogram.
+
+    Args:
+        style_enc_cfg (obj): Style encoder definition file.
+    """
+ def __init__(self,style_enc_cfg):
+ super().__init__()
+        # 270-dimensional histogram input (3 x 90-bin histograms for RGB).
+        input_channel = 270
+ style_dims = style_enc_cfg.style_dims
+ self.no_vae = getattr(style_enc_cfg, 'no_vae', False)
+ num_filters = getattr(style_enc_cfg, 'num_filters', 180)
+ self.process_model = nn.ModuleList()
+ self.layer1 = LinearBlock(input_channel,num_filters)
+ self.layer4 = LinearBlock(num_filters, num_filters)
+ self.fc_mu = LinearBlock(num_filters, style_dims,nonlinearity='tanh')
+ if not self.no_vae:
+ self.fc_var = LinearBlock(num_filters, style_dims,nonlinearity='tanh')
+
+
+ def forward(self,histo):
+ x = self.layer1(histo)
+ x = self.layer4(x)
+ mu = self.fc_mu(x) #[-1,1]
+ if not self.no_vae:
+ logvar = self.fc_var(x) # [-1,1]
+ std = torch.exp(0.5 * logvar) # [0.607,1.624]
+ eps = torch.randn_like(std)
+ z = eps.mul(std) + mu
+ else:
+ z = mu
+ logvar = torch.zeros_like(mu)
+ return mu, logvar, z
+
+
+
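
Both `histo_process` above and `StyleEncoder` below use the same VAE-style reparameterization: `tanh`-bounded `mu` and `logvar` are predicted, and the sampled code is `z = mu + eps * exp(0.5 * logvar)`. A minimal numeric illustration, independent of the config objects:

```python
import torch

mu = torch.zeros(2, 4)                 # predicted mean (tanh-bounded in the encoders)
logvar = torch.full((2, 4), -1.0)      # predicted log-variance, also tanh-bounded
std = torch.exp(0.5 * logvar)          # ~0.6065 for logvar = -1
eps = torch.randn_like(std)            # unit-Gaussian noise
z = eps.mul(std) + mu                  # reparameterized style sample
print(round(std[0, 0].item(), 4), z.shape)   # 0.6065 torch.Size([2, 4])
```
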
+class StyleEncoder(nn.Module):
+ r"""Style Encoder constructor.
+
+ Args:
+ style_enc_cfg (obj): Style encoder definition file.
+ """
+
+ def __init__(self, style_enc_cfg):
+ super(StyleEncoder, self).__init__()
+ input_image_channels = style_enc_cfg.input_image_channels
+ num_filters = style_enc_cfg.num_filters
+ kernel_size = style_enc_cfg.kernel_size
+ padding = int(np.ceil((kernel_size - 1.0) / 2))
+ style_dims = style_enc_cfg.style_dims
+ weight_norm_type = style_enc_cfg.weight_norm_type
+ self.no_vae = getattr(style_enc_cfg, 'no_vae', False)
+ activation_norm_type = 'none'
+ nonlinearity = 'leakyrelu'
+ base_conv2d_block = \
+ partial(Conv2dBlock,
+ kernel_size=kernel_size,
+ stride=2,
+ padding=padding,
+ weight_norm_type=weight_norm_type,
+ activation_norm_type=activation_norm_type,
+ # inplace_nonlinearity=True,
+ nonlinearity=nonlinearity)
+ self.layer1 = base_conv2d_block(input_image_channels, num_filters)
+ self.layer2 = base_conv2d_block(num_filters * 1, num_filters * 2)
+ self.layer3 = base_conv2d_block(num_filters * 2, num_filters * 4)
+ self.layer4 = base_conv2d_block(num_filters * 4, num_filters * 8)
+ self.layer5 = base_conv2d_block(num_filters * 8, num_filters * 8)
+ self.layer6 = base_conv2d_block(num_filters * 8, num_filters * 8)
+ self.fc_mu = LinearBlock(num_filters * 8 * 4 * 4, style_dims,nonlinearity='tanh')
+ if not self.no_vae:
+ self.fc_var = LinearBlock(num_filters * 8 * 4 * 4, style_dims,nonlinearity='tanh')
+
+ def forward(self, input_x):
+ r"""SPADE Style Encoder forward.
+
+ Args:
+ input_x (N x 3 x H x W tensor): input images.
+ Returns:
+ mu (N x C tensor): Mean vectors.
+ logvar (N x C tensor): Log-variance vectors.
+ z (N x C tensor): Style code vectors.
+ """
+ if input_x.size(2) != 256 or input_x.size(3) != 256:
+ input_x = F.interpolate(input_x, size=(256, 256), mode='bilinear')
+ x = self.layer1(input_x)
+ x = self.layer2(x)
+ x = self.layer3(x)
+ x = self.layer4(x)
+ x = self.layer5(x)
+ x = self.layer6(x)
+ x = x.view(x.size(0), -1)
+ mu = self.fc_mu(x)
+ if not self.no_vae:
+ logvar = self.fc_var(x)
+ std = torch.exp(0.5 * logvar)
+ eps = torch.randn_like(std)
+ z = eps.mul(std) + mu
+ else:
+ z = mu
+ logvar = torch.zeros_like(mu)
+ return mu, logvar, z
+
+
+class RenderCNN(nn.Module):
+ r"""CNN converting intermediate feature map to final image."""
+
+ def __init__(self, in_channels, style_dim, hidden_channels=256,
+ leaky_relu=True):
+ super(RenderCNN, self).__init__()
+ self.fc_z_cond = nn.Linear(style_dim, 2 * 2 * hidden_channels)
+
+ self.conv1 = nn.Conv2d(in_channels, hidden_channels, 1, stride=1, padding=0)
+ self.conv2a = nn.Conv2d(hidden_channels, hidden_channels, 3, stride=1, padding=1)
+ self.conv2b = nn.Conv2d(hidden_channels, hidden_channels, 3, stride=1, padding=1, bias=False)
+
+ self.conv3a = nn.Conv2d(hidden_channels, hidden_channels, 3, stride=1, padding=1)
+ self.conv3b = nn.Conv2d(hidden_channels, hidden_channels, 3, stride=1, padding=1, bias=False)
+
+ self.conv4a = nn.Conv2d(hidden_channels, hidden_channels, 1, stride=1, padding=0)
+ self.conv4b = nn.Conv2d(hidden_channels, hidden_channels, 1, stride=1, padding=0)
+
+ self.conv4 = nn.Conv2d(hidden_channels, 3, 1, stride=1, padding=0)
+
+ if leaky_relu:
+ self.act = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+ else:
+ self.act = partial(F.relu, inplace=True)
+
+ def modulate(self, x, w_, b_):
+ w_ = w_[..., None, None]
+ b_ = b_[..., None, None]
+ return x * (w_+1) + b_ +1e-9
+
+ def forward(self, x, z):
+ r"""Forward network.
+
+ Args:
+ x (N x in_channels x H x W tensor): Intermediate feature map
+ z (N x style_dim tensor): Style codes.
+ """
+ z = self.fc_z_cond(z)
+ adapt = torch.chunk(z, 2 * 2, dim=-1)
+ y = self.act(self.conv1(x))
+
+ y = y + self.conv2b(self.act(self.conv2a(y)))
+ y = self.act(self.modulate(y, adapt[0], adapt[1]))
+
+ y = y + self.conv3b(self.act(self.conv3a(y)))
+ y = self.act(self.modulate(y, adapt[2], adapt[3]))
+
+ y = y + self.conv4b(self.act(self.conv4a(y)))
+ y = self.act(y)
+
+ y = self.conv4(y)
+ y = torch.sigmoid(y)
+ return y
+
+
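
`RenderCNN.modulate` above (and the similar `modulate` in `inner_Generator_split` further down) applies a FiLM-style affine modulation, `x * (w + 1) + b`, with the scale and shift chunked out of a single style projection. A tiny self-contained check of the shapes and broadcasting:

```python
import torch

hidden = 8
x = torch.randn(2, hidden, 4, 4)              # intermediate feature map
z = torch.randn(2, 2 * 2 * hidden)            # style projection, as produced by fc_z_cond
w0, b0, w1, b1 = torch.chunk(z, 4, dim=-1)    # four (N, hidden) chunks

y = x * (w0[..., None, None] + 1) + b0[..., None, None]   # first scale-and-shift
print(y.shape)                                # torch.Size([2, 8, 4, 4])
```
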
+class inner_Generator(nn.Module):
+ r"""Pix2pixHD coarse-to-fine generator constructor.
+
+ Args:
+ gen_cfg (obj): Generator definition part of the yaml config file.
+ data_cfg (obj): Data definition part of the yaml config file.
+        last_act (str): ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
+            ``'tanh'``, ``'sigmoid'`` or ``'softmax'``. Default: ``'relu'``.
+ """
+
+ def __init__(self, gen_cfg,inner_cfg, data_cfg,num_input_channels=3,last_act='relu'):
+ super().__init__()
+ assert last_act in ['none', 'relu', 'leakyrelu', 'prelu',
+ 'tanh' , 'sigmoid' , 'softmax']
+ # pix2pixHD has a global generator.
+ global_gen_cfg = inner_cfg
+ # By default, pix2pixHD using instance normalization.
+ activation_norm_type = getattr(gen_cfg, 'activation_norm_type',
+ 'instance')
+ activation_norm_params = getattr(gen_cfg, 'activation_norm_params',
+ None)
+ weight_norm_type = getattr(gen_cfg, 'weight_norm_type', '')
+ padding_mode = getattr(gen_cfg, 'padding_mode', 'reflect')
+ base_conv_block = partial(Conv2dBlock,
+ padding_mode=padding_mode,
+ weight_norm_type=weight_norm_type,
+ activation_norm_type=activation_norm_type,
+ activation_norm_params=activation_norm_params,
+ nonlinearity='relu')
+ base_res_block = partial(Res2dBlock,
+ padding_mode=padding_mode,
+ weight_norm_type=weight_norm_type,
+ activation_norm_type=activation_norm_type,
+ activation_norm_params=activation_norm_params,
+ nonlinearity='relu', order='CNACN')
+ # Know what is the number of available segmentation labels.
+
+ # Global generator model.
+ global_model = GlobalGenerator(global_gen_cfg, data_cfg,
+ num_input_channels, padding_mode,
+ base_conv_block, base_res_block,last_act=last_act)
+ self.global_model = global_model
+
+
+    def forward(self, input, random_style=False):
+        r"""Run the wrapped global generator.
+
+        Args:
+            input (4D tensor): Input tensor (e.g. a satellite image or the
+                projected panorama features).
+            random_style (bool): Unused; kept for interface compatibility.
+        Returns:
+            output (4D tensor): Generated feature map / image.
+        """
+        return self.global_model(input)
+
+
+
+ def load_pretrained_network(self, pretrained_dict):
+ r"""Load a pretrained network."""
+ # print(pretrained_dict.keys())
+ model_dict = self.state_dict()
+ print('Pretrained network has fewer layers; The following are '
+ 'not initialized:')
+
+ not_initialized = set()
+ for k, v in model_dict.items():
+ kp = 'module.' + k.replace('global_model.', 'global_model.model.')
+ if kp in pretrained_dict and v.size() == pretrained_dict[kp].size():
+ model_dict[k] = pretrained_dict[kp]
+ else:
+ not_initialized.add('.'.join(k.split('.')[:2]))
+ print(sorted(not_initialized))
+ self.load_state_dict(model_dict)
+
+ def inference(self, data, **kwargs):
+ r"""Generator inference.
+
+ Args:
+ data (dict) : Dictionary of input data.
+ Returns:
+ fake_images (tensor): Output fake images.
+ file_names (str): Data file name.
+ """
+ output = self.forward(data, **kwargs)
+ return output['fake_images'], data['key']['seg_maps'][0]
+
+
+class GlobalGenerator(nn.Module):
+ r"""Coarse generator constructor. This is the main generator in the
+ pix2pixHD architecture.
+
+ Args:
+ gen_cfg (obj): Generator definition part of the yaml config file.
+ data_cfg (obj): Data definition part of the yaml config file.
+ num_input_channels (int): Number of segmentation labels.
+ padding_mode (str): zero | reflect | ...
+ base_conv_block (obj): Conv block with preset attributes.
+ base_res_block (obj): Residual block with preset attributes.
+ last_act (str, optional, default='relu'):
+ Type of nonlinear activation function.
+ ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
+ ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``.
+ """
+
+ def __init__(self, gen_cfg, data_cfg, num_input_channels, padding_mode,
+ base_conv_block, base_res_block,last_act='relu'):
+ super(GlobalGenerator, self).__init__()
+
+ # num_img_channels = get_paired_input_image_channel_number(data_cfg)
+ num_out_put_channels = getattr(gen_cfg, 'output_nc', 64)
+ num_filters = getattr(gen_cfg, 'num_filters', 64)
+ num_downsamples = getattr(gen_cfg, 'num_downsamples', 4)
+ num_res_blocks = getattr(gen_cfg, 'num_res_blocks', 9)
+ # First layer.
+ model = [base_conv_block(num_input_channels, num_filters,
+ kernel_size=7, padding=3)]
+ # Downsample.
+ for i in range(num_downsamples):
+ ch = num_filters * (2 ** i)
+ model += [base_conv_block(ch, ch * 2, 3, padding=1, stride=2)]
+ # ResNet blocks.
+ ch = num_filters * (2 ** num_downsamples)
+ for i in range(num_res_blocks):
+ model += [base_res_block(ch, ch, 3, padding=1)]
+ # Upsample.
+ num_upsamples = num_downsamples
+ for i in reversed(range(num_upsamples)):
+ ch = num_filters * (2 ** i)
+ model += \
+ [NearestUpsample(scale_factor=2),
+ base_conv_block(ch * 2, ch, 3, padding=1)]
+ model += [Conv2dBlock(num_filters, num_out_put_channels, 7, padding=3,
+ padding_mode=padding_mode, nonlinearity=last_act)]
+
+ self.model = nn.Sequential(*model)
+
+ def forward(self, input):
+ r"""Coarse-to-fine generator forward.
+
+ Args:
+ input (4D tensor) : Input semantic representations.
+ Returns:
+ output (4D tensor) : Synthesized image by generator.
+ """
+ return self.model(input)
+
+class inner_Generator_split(nn.Module):
+ r"""Pix2pixHD coarse-to-fine generator constructor.
+
+ Args:
+ gen_cfg (obj): Generator definition part of the yaml config file.
+ data_cfg (obj): Data definition part of the yaml config file.
+        last_act (str): ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
+            ``'tanh'``, ``'sigmoid'`` or ``'softmax'``. Default: ``'relu'``.
+ """
+
+ def __init__(self, gen_cfg,inner_cfg, data_cfg,num_input_channels=3,last_act='relu'):
+ super().__init__()
+ assert last_act in ['none', 'relu', 'leakyrelu', 'prelu',
+ 'tanh' , 'sigmoid' , 'softmax']
+ # pix2pixHD has a global generator.
+ # By default, pix2pixHD using instance normalization.
+ print(inner_cfg)
+ style_dim = gen_cfg.style_enc_cfg.interm_style_dims
+ activation_norm_type = getattr(gen_cfg, 'activation_norm_type',
+ 'instance')
+ activation_norm_params = getattr(gen_cfg, 'activation_norm_params',
+ None)
+ weight_norm_type = getattr(gen_cfg, 'weight_norm_type', '')
+ padding_mode = getattr(gen_cfg, 'padding_mode', 'reflect')
+ # num_input_channels = get_paired_input_label_channel_number(data_cfg)
+ # num_input_channels = 3
+ base_conv_block = partial(Conv2dBlock,
+ padding_mode=padding_mode,
+ weight_norm_type=weight_norm_type,
+ activation_norm_type=activation_norm_type,
+ activation_norm_params=activation_norm_params,
+ )
+ base_res_block = partial(Res2dBlock,
+ padding_mode=padding_mode,
+ weight_norm_type=weight_norm_type,
+ activation_norm_type=activation_norm_type,
+ activation_norm_params=activation_norm_params,
+ nonlinearity='relu', order='CNACN')
+ # Know what is the number of available segmentation labels.
+
+ # Global generator model.
+
+ num_out_put_channels = getattr(inner_cfg, 'output_nc', 64)
+ num_filters = getattr(inner_cfg, 'num_filters', 64)
+ num_downsamples = 4
+ num_res_blocks = getattr(inner_cfg, 'num_res_blocks', 9)
+ # First layer.
+ model = [base_conv_block(num_input_channels, num_filters,
+ kernel_size=7, padding=3)]
+ model += [nn.PReLU()]
+ # Downsample.
+ for i in range(num_downsamples):
+ ch = num_filters * (2 ** i)
+ model += [base_conv_block(ch, ch * 2, 3, padding=1, stride=2)]
+ model += [nn.PReLU()]
+ # ResNet blocks.
+ ch = num_filters * (2 ** num_downsamples)
+ for i in range(num_res_blocks):
+ model += [base_res_block(ch, ch, 3, padding=1)]
+ self.model = nn.Sequential(*model)
+ # Upsample.
+ assert num_downsamples == 4
+        # Channel schedule for the four upsampling stages (as multiples of num_filters).
+        if not (inner_cfg.name == 'render' and gen_cfg.style_inject):
+            up_chs = [16, 8, 4, 2]
+        else:
+            up_chs = [16, 6, 6, 6]
+
+        self.up0_a = NearestUpsample(scale_factor=2)
+        self.up0_b = base_conv_block(num_filters * up_chs[0], num_filters * up_chs[1], 3, padding=1)
+        self.up1_a = NearestUpsample(scale_factor=2)
+        self.up1_b = base_conv_block(num_filters * up_chs[1], num_filters * up_chs[2], 3, padding=1)
+        self.up2_a = NearestUpsample(scale_factor=2)
+        self.up2_b = base_conv_block(num_filters * up_chs[2], num_filters * up_chs[3], 3, padding=1)
+        self.up3_a = NearestUpsample(scale_factor=2)
+        self.up3_b = base_conv_block(num_filters * up_chs[3], num_filters, 3, padding=1)
+        self.up_end = Conv2dBlock(num_filters, num_out_put_channels, 7, padding=3,
+                                  padding_mode=padding_mode, nonlinearity=last_act)
+        if inner_cfg.name == 'render' and gen_cfg.style_inject:
+            self.fc_z_cond = nn.Linear(style_dim, 4 * up_chs[-1] * num_filters)
+
+ def modulate(self, x, w, b):
+ w = w[..., None, None]
+ b = b[..., None, None]
+ return x * (w+1) + b
+
+ def forward(self, input,z=None):
+ r"""Coarse-to-fine generator forward.
+
+ Args:
+ input (4D tensor) : Input semantic representations.
+ Returns:
+ output (4D tensor) : Synthesized image by generator.
+ """
+ if z is not None:
+ z = self.fc_z_cond(z)
+ adapt = torch.chunk(z, 2 * 2, dim=-1)
+ input = self.model(input)
+ input = self.up0_a(input)
+ input = self.up0_b(input)
+ input = F.leaky_relu(input,negative_slope=0.2, inplace=True)
+ input = self.up1_a(input)
+ input = self.up1_b(input)
+ if z is not None:
+ input = self.modulate(input, adapt[0], adapt[1])
+ input = F.leaky_relu(input,negative_slope=0.2, inplace=True)
+
+ input = self.up2_a(input)
+ input = self.up2_b(input)
+ if z is not None:
+ input = self.modulate(input, adapt[2], adapt[3])
+ input = F.leaky_relu(input,negative_slope=0.2, inplace=True)
+
+ input = self.up3_a(input)
+ input = self.up3_b(input)
+ input = F.leaky_relu(input,negative_slope=0.2, inplace=True)
+
+ input = self.up_end(input)
+
+ return input
+
+if __name__ == '__main__':
+    # Minimal smoke test for histo_process; the style_dims / num_filters values are arbitrary.
+    from easydict import EasyDict as edict
+    style_enc_cfg = edict()
+    style_enc_cfg.style_dims = 128
+    style_enc_cfg.num_filters = 180
+    model = histo_process(style_enc_cfg)
+    mu, logvar, z = model(torch.rand(2, 270))
+    print(z.shape)
\ No newline at end of file
diff --git a/imaginaire/generators/dummy.py b/imaginaire/generators/dummy.py
new file mode 100644
index 0000000000000000000000000000000000000000..5b9a2f1edec286be8751d6a188ebc4f47875c437
--- /dev/null
+++ b/imaginaire/generators/dummy.py
@@ -0,0 +1,29 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import torch.nn as nn
+
+from imaginaire.layers import LinearBlock
+
+
+class Generator(nn.Module):
+ r"""Dummy generator.
+
+ Args:
+ gen_cfg (obj): Generator definition part of the yaml config file.
+ data_cfg (obj): Data definition part of the yaml config file.
+ """
+
+ def __init__(self, gen_cfg, data_cfg):
+ super(Generator, self).__init__()
+ self.dummy_layer = LinearBlock(1, 1)
+ pass
+
+ def forward(self, data):
+ r"""Dummy Generator forward.
+
+ Args:
+ data (dict):
+ """
+ return
diff --git a/imaginaire/generators/fs_vid2vid.py b/imaginaire/generators/fs_vid2vid.py
new file mode 100644
index 0000000000000000000000000000000000000000..93c2c20048d0e371d9393302cc858fe738a4d53b
--- /dev/null
+++ b/imaginaire/generators/fs_vid2vid.py
@@ -0,0 +1,1176 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import copy
+from functools import partial
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from imaginaire.layers import (Conv2dBlock, HyperConv2dBlock, HyperRes2dBlock,
+ LinearBlock, Res2dBlock)
+from imaginaire.model_utils.fs_vid2vid import (extract_valid_pose_labels,
+ pick_image, resample)
+from imaginaire.utils.data import (get_paired_input_image_channel_number,
+ get_paired_input_label_channel_number)
+from imaginaire.utils.distributed import master_only_print as print
+from imaginaire.utils.init_weight import weights_init
+from imaginaire.utils.misc import get_and_setattr, get_nested_attr
+
+
+class Generator(nn.Module):
+ r"""Few-shot vid2vid generator constructor.
+
+ Args:
+ gen_cfg (obj): Generator definition part of the yaml config file.
+ data_cfg (obj): Data definition part of the yaml config file.
+ """
+
+ def __init__(self, gen_cfg, data_cfg):
+ super().__init__()
+ self.gen_cfg = gen_cfg
+ self.data_cfg = data_cfg
+ self.num_frames_G = data_cfg.num_frames_G
+ self.flow_cfg = flow_cfg = gen_cfg.flow
+
+ # For pose dataset.
+ self.is_pose_data = hasattr(data_cfg, 'for_pose_dataset')
+ if self.is_pose_data:
+ pose_cfg = data_cfg.for_pose_dataset
+ self.pose_type = getattr(pose_cfg, 'pose_type', 'both')
+ self.remove_face_labels = getattr(pose_cfg, 'remove_face_labels',
+ False)
+
+ num_img_channels = get_paired_input_image_channel_number(data_cfg)
+ self.num_downsamples = num_downsamples = \
+ get_and_setattr(gen_cfg, 'num_downsamples', 5)
+ conv_kernel_size = get_and_setattr(gen_cfg, 'kernel_size', 3)
+ num_filters = get_and_setattr(gen_cfg, 'num_filters', 32)
+
+ max_num_filters = getattr(gen_cfg, 'max_num_filters', 1024)
+ self.max_num_filters = gen_cfg.max_num_filters = \
+ min(max_num_filters, num_filters * (2 ** num_downsamples))
+ # Get number of filters at each layer in the main branch.
+ num_filters_each_layer = [min(self.max_num_filters,
+ num_filters * (2 ** i))
+ for i in range(num_downsamples + 2)]
+
+ # Hyper normalization / convolution.
+ hyper_cfg = gen_cfg.hyper
+ # Use adaptive weight generation for SPADE.
+ self.use_hyper_spade = hyper_cfg.is_hyper_spade
+ # Use adaptive for convolutional layers in the main branch.
+ self.use_hyper_conv = hyper_cfg.is_hyper_conv
+ # Number of hyper layers.
+ self.num_hyper_layers = getattr(hyper_cfg, 'num_hyper_layers', 4)
+ if self.num_hyper_layers == -1:
+ self.num_hyper_layers = num_downsamples
+ gen_cfg.hyper.num_hyper_layers = self.num_hyper_layers
+ # Network weight generator.
+ self.weight_generator = WeightGenerator(gen_cfg, data_cfg)
+
+ # Number of layers to perform multi-spade combine.
+ self.num_multi_spade_layers = getattr(flow_cfg.multi_spade_combine,
+ 'num_layers', 3)
+ # Whether to generate raw output for additional losses.
+ self.generate_raw_output = getattr(flow_cfg, 'generate_raw_output',
+ False)
+
+ # Main branch image generation.
+ padding = conv_kernel_size // 2
+ activation_norm_type = get_and_setattr(gen_cfg, 'activation_norm_type',
+ 'sync_batch')
+ weight_norm_type = get_and_setattr(gen_cfg, 'weight_norm_type',
+ 'spectral')
+ activation_norm_params = get_and_setattr(gen_cfg,
+ 'activation_norm_params',
+ None)
+ spade_in_channels = [] # Input channel size in SPADE module.
+ for i in range(num_downsamples + 1):
+ spade_in_channels += [[num_filters_each_layer[i]]] \
+ if i >= self.num_multi_spade_layers \
+ else [[num_filters_each_layer[i]] * 3]
+
+ order = getattr(gen_cfg.hyper, 'hyper_block_order', 'NAC')
+ for i in reversed(range(num_downsamples + 1)):
+ activation_norm_params.cond_dims = spade_in_channels[i]
+ is_hyper_conv = self.use_hyper_conv and i < self.num_hyper_layers
+ is_hyper_norm = self.use_hyper_spade and i < self.num_hyper_layers
+ setattr(self, 'up_%d' % i, HyperRes2dBlock(
+ num_filters_each_layer[i + 1], num_filters_each_layer[i],
+ conv_kernel_size, padding=padding,
+ weight_norm_type=weight_norm_type,
+ activation_norm_type=activation_norm_type,
+ activation_norm_params=activation_norm_params,
+ order=order * 2,
+ is_hyper_conv=is_hyper_conv, is_hyper_norm=is_hyper_norm))
+
+ self.conv_img = Conv2dBlock(num_filters, num_img_channels,
+ conv_kernel_size, padding=padding,
+ nonlinearity='leakyrelu', order='AC')
+ self.upsample = partial(F.interpolate, scale_factor=2)
+
+ # Flow estimation module.
+ # Whether to warp reference image and combine with the synthesized.
+ self.warp_ref = getattr(flow_cfg, 'warp_ref', True)
+ if self.warp_ref:
+ self.flow_network_ref = FlowGenerator(flow_cfg, data_cfg, 2)
+ self.ref_image_embedding = \
+ LabelEmbedder(flow_cfg.multi_spade_combine.embed,
+ num_img_channels + 1)
+ # At beginning of training, only train an image generator.
+ self.temporal_initialized = False
+ if getattr(gen_cfg, 'init_temporal', True):
+ self.init_temporal_network()
+
+ def forward(self, data):
+ r"""few-shot vid2vid generator forward.
+
+ Args:
+ data (dict) : Dictionary of input data.
+ Returns:
+ output (dict) : Dictionary of output data.
+ """
+ label = data['label']
+ ref_labels, ref_images = data['ref_labels'], data['ref_images']
+ prev_labels, prev_images = data['prev_labels'], data['prev_images']
+ is_first_frame = prev_labels is None
+
+ if self.is_pose_data:
+ label, prev_labels = extract_valid_pose_labels(
+ [label, prev_labels], self.pose_type, self.remove_face_labels)
+ ref_labels = extract_valid_pose_labels(
+ ref_labels, self.pose_type, self.remove_face_labels,
+ do_remove=False)
+
+ # Weight generation.
+ x, encoded_label, conv_weights, norm_weights, atn, atn_vis, ref_idx = \
+ self.weight_generator(ref_images, ref_labels, label, is_first_frame)
+
+ # Flow estimation.
+ flow, flow_mask, img_warp, cond_inputs = \
+ self.flow_generation(label, ref_labels, ref_images,
+ prev_labels, prev_images, ref_idx)
+
+ for i in range(len(encoded_label)):
+ encoded_label[i] = [encoded_label[i]]
+ if self.generate_raw_output:
+ encoded_label_raw = [encoded_label[i] for i in
+ range(self.num_multi_spade_layers)]
+ x_raw = None
+ encoded_label = self.SPADE_combine(encoded_label, cond_inputs)
+
+ # Main branch image generation.
+ for i in range(self.num_downsamples, -1, -1):
+ conv_weight = norm_weight = [None] * 3
+ if self.use_hyper_conv and i < self.num_hyper_layers:
+ conv_weight = conv_weights[i]
+ if self.use_hyper_spade and i < self.num_hyper_layers:
+ norm_weight = norm_weights[i]
+
+ # Main branch residual blocks.
+ x = self.one_up_conv_layer(x, encoded_label,
+ conv_weight, norm_weight, i)
+
+ # For raw output generation.
+ if self.generate_raw_output and i < self.num_multi_spade_layers:
+ x_raw = self.one_up_conv_layer(x_raw, encoded_label_raw,
+ conv_weight, norm_weight, i)
+ else:
+ x_raw = x
+
+ # Final conv layer.
+ if self.generate_raw_output:
+ img_raw = torch.tanh(self.conv_img(x_raw))
+ else:
+ img_raw = None
+ img_final = torch.tanh(self.conv_img(x))
+
+ output = dict()
+ output['fake_images'] = img_final
+ output['fake_flow_maps'] = flow
+ output['fake_occlusion_masks'] = flow_mask
+ output['fake_raw_images'] = img_raw
+ output['warped_images'] = img_warp
+ output['attention_visualization'] = atn_vis
+ output['ref_idx'] = ref_idx
+ return output
+
+ def one_up_conv_layer(self, x, encoded_label, conv_weight, norm_weight, i):
+ r"""One residual block layer in the main branch.
+
+ Args:
+ x (4D tensor) : Current feature map.
+ encoded_label (list of tensors) : Encoded input label maps.
+ conv_weight (list of tensors) : Hyper conv weights.
+ norm_weight (list of tensors) : Hyper norm weights.
+ i (int) : Layer index.
+ Returns:
+ x (4D tensor) : Output feature map.
+ """
+ layer = getattr(self, 'up_' + str(i))
+ x = layer(x, *encoded_label[i], conv_weights=conv_weight,
+ norm_weights=norm_weight)
+ if i != 0:
+ x = self.upsample(x)
+ return x
+
+ def init_temporal_network(self, cfg_init=None):
+ r"""When starting training multiple frames, initialize the flow network.
+
+ Args:
+ cfg_init (dict) : Weight initialization config.
+ """
+ flow_cfg = self.flow_cfg
+ emb_cfg = flow_cfg.multi_spade_combine.embed
+ num_frames_G = self.num_frames_G
+ self.temporal_initialized = True
+
+ self.sep_prev_flownet = flow_cfg.sep_prev_flow or (num_frames_G != 2) \
+ or not flow_cfg.warp_ref
+ if self.sep_prev_flownet:
+ self.flow_network_temp = FlowGenerator(flow_cfg, self.data_cfg,
+ num_frames_G)
+ if cfg_init is not None:
+ self.flow_network_temp.apply(weights_init(cfg_init.type,
+ cfg_init.gain))
+ else:
+ self.flow_network_temp = self.flow_network_ref
+
+ self.sep_prev_embedding = emb_cfg.sep_warp_embed or \
+ not flow_cfg.warp_ref
+ if self.sep_prev_embedding:
+ num_img_channels = get_paired_input_image_channel_number(
+ self.data_cfg)
+ self.prev_image_embedding = \
+ LabelEmbedder(emb_cfg, num_img_channels + 1)
+ if cfg_init is not None:
+ self.prev_image_embedding.apply(
+ weights_init(cfg_init.type, cfg_init.gain))
+ else:
+ self.prev_image_embedding = self.ref_image_embedding
+
+ if self.warp_ref:
+ if self.sep_prev_flownet:
+ self.init_network_weights(self.flow_network_ref,
+ self.flow_network_temp)
+ print('Initialized temporal flow network with the reference '
+ 'one.')
+ if self.sep_prev_embedding:
+ self.init_network_weights(self.ref_image_embedding,
+ self.prev_image_embedding)
+ print('Initialized temporal embedding network with the '
+ 'reference one.')
+ self.flow_temp_is_initalized = True
+
+ def init_network_weights(self, net_src, net_dst):
+ r"""Initialize weights in net_dst with those in net_src."""
+ source_weights = net_src.state_dict()
+ target_weights = net_dst.state_dict()
+
+ for k, v in source_weights.items():
+ if k in target_weights and target_weights[k].size() == v.size():
+ target_weights[k] = v
+ net_dst.load_state_dict(target_weights)
+
+ def load_pretrained_network(self, pretrained_dict, prefix='module.'):
+ r"""Load the pretrained network into self network.
+
+ Args:
+ pretrained_dict (dict): Pretrained network weights.
+ prefix (str): Prefix to the network weights name.
+ """
+ # print(pretrained_dict.keys())
+ model_dict = self.state_dict()
+ print('Pretrained network has fewer layers; The following are '
+ 'not initialized:')
+
+ not_initialized = set()
+ for k, v in model_dict.items():
+ kp = prefix + k
+ if kp in pretrained_dict and v.size() == pretrained_dict[kp].size():
+ model_dict[k] = pretrained_dict[kp]
+ else:
+ not_initialized.add('.'.join(k.split('.')[:2]))
+ print(sorted(not_initialized))
+ self.load_state_dict(model_dict)
+
+ def reset(self):
+ r"""Reset the network at the beginning of a sequence."""
+ self.weight_generator.reset()
+
+ def flow_generation(self, label, ref_labels, ref_images, prev_labels,
+ prev_images, ref_idx):
+ r"""Generates flows and masks for warping reference / previous images.
+
+ Args:
+ label (NxCxHxW tensor): Target label map.
+ ref_labels (NxKxCxHxW tensor): Reference label maps.
+ ref_images (NxKx3xHxW tensor): Reference images.
+ prev_labels (NxTxCxHxW tensor): Previous label maps.
+ prev_images (NxTx3xHxW tensor): Previous images.
+ ref_idx (Nx1 tensor): Index for which image to use from the
+ reference images.
+ Returns:
+ (tuple):
+ - flow (list of Nx2xHxW tensor): Optical flows.
+ - occ_mask (list of Nx1xHxW tensor): Occlusion masks.
+ - img_warp (list of Nx3xHxW tensor): Warped reference / previous
+ images.
+ - cond_inputs (list of Nx4xHxW tensor): Conditional inputs for
+ SPADE combination.
+ """
+ # Pick an image in the reference images using ref_idx.
+ ref_label, ref_image = pick_image([ref_labels, ref_images], ref_idx)
+ # Only start using prev frames when enough prev frames are generated.
+ has_prev = prev_labels is not None and \
+ prev_labels.shape[1] == (self.num_frames_G - 1)
+ flow, occ_mask, img_warp, cond_inputs = [None] * 2, [None] * 2, \
+ [None] * 2, [None] * 2
+ if self.warp_ref:
+ # Generate flows/masks for warping the reference image.
+ flow_ref, occ_mask_ref = \
+ self.flow_network_ref(label, ref_label, ref_image)
+ ref_image_warp = resample(ref_image, flow_ref)
+ flow[0], occ_mask[0], img_warp[0] = \
+ flow_ref, occ_mask_ref, ref_image_warp[:, :3]
+ # Concat warped image and occlusion mask to form the conditional
+ # input.
+ cond_inputs[0] = torch.cat([img_warp[0], occ_mask[0]], dim=1)
+
+ if self.temporal_initialized and has_prev:
+ # Generate flows/masks for warping the previous image.
+ b, t, c, h, w = prev_labels.shape
+ prev_labels_concat = prev_labels.view(b, -1, h, w)
+ prev_images_concat = prev_images.view(b, -1, h, w)
+ flow_prev, occ_mask_prev = \
+ self.flow_network_temp(label, prev_labels_concat,
+ prev_images_concat)
+ img_prev_warp = resample(prev_images[:, -1], flow_prev)
+ flow[1], occ_mask[1], img_warp[1] = \
+ flow_prev, occ_mask_prev, img_prev_warp
+ cond_inputs[1] = torch.cat([img_warp[1], occ_mask[1]], dim=1)
+
+ return flow, occ_mask, img_warp, cond_inputs
+
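
`resample` (imported from `imaginaire.model_utils.fs_vid2vid`) warps an image with an optical-flow field. As a rough, self-contained stand-in, the sketch below does the same thing directly with `F.grid_sample`; the repo's implementation may differ in normalization and border handling.

```python
import torch
import torch.nn.functional as F

def warp_with_flow(image, flow):
    """Warp `image` (N, C, H, W) by `flow` (N, 2, H, W) given in pixels as (dx, dy)."""
    n, _, h, w = image.shape
    ys, xs = torch.meshgrid(torch.arange(h), torch.arange(w), indexing='ij')
    grid = torch.stack((xs, ys), dim=-1).float()[None].expand(n, -1, -1, -1)  # (N, H, W, 2)
    grid = grid + flow.permute(0, 2, 3, 1)                   # displace the sampling positions
    grid[..., 0] = 2.0 * grid[..., 0] / max(w - 1, 1) - 1.0  # normalize x to [-1, 1]
    grid[..., 1] = 2.0 * grid[..., 1] / max(h - 1, 1) - 1.0  # normalize y to [-1, 1]
    return F.grid_sample(image, grid, mode='bilinear', align_corners=True)

img = torch.rand(1, 3, 64, 64)
zero_flow = torch.zeros(1, 2, 64, 64)
print(torch.allclose(warp_with_flow(img, zero_flow), img, atol=1e-5))   # True: zero flow is identity
```
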
+ def SPADE_combine(self, encoded_label, cond_inputs):
+ r"""Using Multi-SPADE to combine raw synthesized image with warped
+ images.
+
+ Args:
+ encoded_label (list of tensors): Original label map embeddings.
+ cond_inputs (list of tensors): New SPADE conditional inputs from the
+ warped images.
+ Returns:
+ encoded_label (list of tensors): Combined conditional inputs.
+ """
+ # Generate the conditional embeddings from inputs.
+ embedded_img_feat = [None, None]
+ if cond_inputs[0] is not None:
+ embedded_img_feat[0] = self.ref_image_embedding(cond_inputs[0])
+ if cond_inputs[1] is not None:
+ embedded_img_feat[1] = self.prev_image_embedding(cond_inputs[1])
+
+ # Combine the original encoded label maps with new conditional
+ # embeddings.
+ for i in range(self.num_multi_spade_layers):
+ encoded_label[i] += [w[i] if w is not None else None
+ for w in embedded_img_feat]
+ return encoded_label
+
+ def custom_init(self):
+ r"""This function is for dealing with the numerical issue that might
+ occur when doing mixed precision training.
+ """
+ print('Use custom initialization for the generator.')
+ for k, m in self.named_modules():
+ if 'weight_generator.ref_label_' in k and 'norm' in k:
+ m.eps = 1e-1
+
+
+class WeightGenerator(nn.Module):
+ r"""Weight generator constructor.
+
+ Args:
+ gen_cfg (obj): Generator definition part of the yaml config file.
+ data_cfg (obj): Data definition part of the yaml config file
+ """
+
+ def __init__(self, gen_cfg, data_cfg):
+ super().__init__()
+ self.data_cfg = data_cfg
+ self.embed_cfg = embed_cfg = gen_cfg.embed
+ self.embed_arch = embed_cfg.arch
+
+ num_filters = gen_cfg.num_filters
+ self.max_num_filters = gen_cfg.max_num_filters
+ self.num_downsamples = num_downsamples = gen_cfg.num_downsamples
+ self.num_filters_each_layer = num_filters_each_layer = \
+ [min(self.max_num_filters, num_filters * (2 ** i))
+ for i in range(num_downsamples + 2)]
+ if getattr(embed_cfg, 'num_filters', 32) != num_filters:
+ raise ValueError('Embedding network must have the same number of '
+ 'filters as generator.')
+
+ # Normalization params.
+ hyper_cfg = gen_cfg.hyper
+ kernel_size = getattr(hyper_cfg, 'kernel_size', 3)
+ activation_norm_type = getattr(hyper_cfg, 'activation_norm_type',
+ 'sync_batch')
+ weight_norm_type = getattr(hyper_cfg, 'weight_norm_type', 'spectral')
+ # Conv kernel size in main branch.
+ self.conv_kernel_size = conv_kernel_size = gen_cfg.kernel_size
+ # Conv kernel size in embedding network.
+ self.embed_kernel_size = embed_kernel_size = \
+ getattr(gen_cfg.embed, 'kernel_size', 3)
+ # Conv kernel size in SPADE.
+ self.kernel_size = kernel_size = \
+ getattr(gen_cfg.activation_norm_params, 'kernel_size', 1)
+ # Input channel size in SPADE module.
+ self.spade_in_channels = []
+ for i in range(num_downsamples + 1):
+ self.spade_in_channels += [num_filters_each_layer[i]]
+
+ # Hyper normalization / convolution.
+ # Use adaptive weight generation for SPADE.
+ self.use_hyper_spade = hyper_cfg.is_hyper_spade
+ # Use adaptive for the label embedding network.
+ self.use_hyper_embed = hyper_cfg.is_hyper_embed
+ # Use adaptive for convolutional layers in the main branch.
+ self.use_hyper_conv = hyper_cfg.is_hyper_conv
+ # Number of hyper layers.
+ self.num_hyper_layers = hyper_cfg.num_hyper_layers
+ # Order of operations in the conv block.
+ order = getattr(gen_cfg.hyper, 'hyper_block_order', 'NAC')
+ self.conv_before_norm = order.find('C') < order.find('N')
+
+ # For reference image encoding.
+ # How to utilize the reference label map: concat | mul.
+ self.concat_ref_label = 'concat' in hyper_cfg.method_to_use_ref_labels
+ self.mul_ref_label = 'mul' in hyper_cfg.method_to_use_ref_labels
+ # Output spatial size for adaptive pooling layer.
+ self.sh_fix = self.sw_fix = 32
+ # Number of fc layers in weight generation.
+ self.num_fc_layers = getattr(hyper_cfg, 'num_fc_layers', 2)
+
+ # Reference image encoding network.
+ num_input_channels = get_paired_input_label_channel_number(data_cfg)
+ if num_input_channels == 0:
+ num_input_channels = getattr(data_cfg, 'label_channels', 1)
+ elif get_nested_attr(data_cfg, 'for_pose_dataset.pose_type',
+ 'both') == 'open':
+ num_input_channels -= 3
+ data_cfg.num_input_channels = num_input_channels
+ num_img_channels = get_paired_input_image_channel_number(data_cfg)
+ num_ref_channels = num_img_channels + (num_input_channels
+ if self.concat_ref_label else 0)
+ conv_2d_block = partial(
+ Conv2dBlock, kernel_size=kernel_size,
+ padding=(kernel_size // 2), weight_norm_type=weight_norm_type,
+ activation_norm_type=activation_norm_type,
+ nonlinearity='leakyrelu')
+
+ self.ref_img_first = conv_2d_block(num_ref_channels, num_filters)
+ if self.mul_ref_label:
+ self.ref_label_first = conv_2d_block(num_input_channels,
+ num_filters)
+
+ for i in range(num_downsamples):
+ in_ch, out_ch = num_filters_each_layer[i], \
+ num_filters_each_layer[i + 1]
+ setattr(self, 'ref_img_down_%d' % i,
+ conv_2d_block(in_ch, out_ch, stride=2))
+ setattr(self, 'ref_img_up_%d' % i, conv_2d_block(out_ch, in_ch))
+ if self.mul_ref_label:
+ setattr(self, 'ref_label_down_%d' % i,
+ conv_2d_block(in_ch, out_ch, stride=2))
+ setattr(self, 'ref_label_up_%d' % i,
+ conv_2d_block(out_ch, in_ch))
+
+ # Normalization / main branch conv weight generation.
+ if self.use_hyper_spade or self.use_hyper_conv:
+ for i in range(self.num_hyper_layers):
+ ch_in, ch_out = num_filters_each_layer[i], \
+ num_filters_each_layer[i + 1]
+ conv_ks2 = conv_kernel_size ** 2
+ embed_ks2 = embed_kernel_size ** 2
+ spade_ks2 = kernel_size ** 2
+ spade_in_ch = self.spade_in_channels[i]
+
+ fc_names, fc_ins, fc_outs = [], [], []
+ if self.use_hyper_spade:
+ fc0_out = fcs_out = (spade_in_ch * spade_ks2 + 1) * (
+ 1 if self.conv_before_norm else 2)
+ fc1_out = (spade_in_ch * spade_ks2 + 1) * (
+ 1 if ch_in != ch_out else 2)
+ fc_names += ['fc_spade_0', 'fc_spade_1', 'fc_spade_s']
+ fc_ins += [ch_out] * 3
+ fc_outs += [fc0_out, fc1_out, fcs_out]
+ if self.use_hyper_embed:
+ fc_names += ['fc_spade_e']
+ fc_ins += [ch_out]
+ fc_outs += [ch_in * embed_ks2 + 1]
+ if self.use_hyper_conv:
+ fc0_out = ch_out * conv_ks2 + 1
+ fc1_out = ch_in * conv_ks2 + 1
+ fcs_out = ch_out + 1
+ fc_names += ['fc_conv_0', 'fc_conv_1', 'fc_conv_s']
+ fc_ins += [ch_in] * 3
+ fc_outs += [fc0_out, fc1_out, fcs_out]
+
+ linear_block = partial(LinearBlock,
+ weight_norm_type='spectral',
+ nonlinearity='leakyrelu')
+ for n, l in enumerate(fc_names):
+ fc_in = fc_ins[n] if self.mul_ref_label \
+ else self.sh_fix * self.sw_fix
+ fc_layer = [linear_block(fc_in, ch_out)]
+ for k in range(1, self.num_fc_layers):
+ fc_layer += [linear_block(ch_out, ch_out)]
+ fc_layer += [LinearBlock(ch_out, fc_outs[n],
+ weight_norm_type='spectral')]
+ setattr(self, '%s_%d' % (l, i), nn.Sequential(*fc_layer))
+
+ # Label embedding network.
+ num_hyper_layers = self.num_hyper_layers if self.use_hyper_embed else 0
+ self.label_embedding = LabelEmbedder(self.embed_cfg,
+ num_input_channels,
+ num_hyper_layers=num_hyper_layers)
+
+ # For multiple reference images.
+ if hasattr(hyper_cfg, 'attention'):
+ self.num_downsample_atn = get_and_setattr(hyper_cfg.attention,
+ 'num_downsamples', 2)
+ if data_cfg.initial_few_shot_K > 1:
+ self.attention_module = AttentionModule(hyper_cfg, data_cfg,
+ conv_2d_block,
+ num_filters_each_layer)
+ else:
+ self.num_downsample_atn = 0
+
+ def forward(self, ref_image, ref_label, label, is_first_frame):
+ r"""Generate network weights based on the reference images.
+
+ Args:
+ ref_image (NxKx3xHxW tensor): Reference images.
+ ref_label (NxKxCxHxW tensor): Reference labels.
+ label (NxCxHxW tensor): Target label.
+ is_first_frame (bool): Whether the current frame is the first frame.
+
+ Returns:
+ (tuple):
+ - x (NxC2xH2xW2 tensor): Encoded features from reference images
+ for the main branch (as input to the decoder).
+ - encoded_label (list of tensors): Encoded target label map for
+ SPADE.
+ - conv_weights (list of tensors): Network weights for conv
+ layers in the main network.
+ - norm_weights (list of tensors): Network weights for SPADE
+ layers in the main network.
+ - attention (Nx(KxH1xW1)x(H1xW1) tensor): Attention maps.
+ - atn_vis (1x1xH1xW1 tensor): Visualization for attention
+ scores.
+ - ref_idx (Nx1 tensor): Index for which image to use from the
+ reference images.
+ """
+ b, k, c, h, w = ref_image.size()
+ ref_image = ref_image.view(b * k, -1, h, w)
+ if ref_label is not None:
+ ref_label = ref_label.view(b * k, -1, h, w)
+
+ # Encode the reference images to get the features.
+ x, encoded_ref, atn, atn_vis, ref_idx = \
+ self.encode_reference(ref_image, ref_label, label, k)
+
+ # If the reference image has changed, recompute the network weights.
+ if self.training or is_first_frame or k > 1:
+ embedding_weights, norm_weights, conv_weights = [], [], []
+ for i in range(self.num_hyper_layers):
+ if self.use_hyper_spade:
+ feat = encoded_ref[min(len(encoded_ref) - 1, i + 1)]
+ embedding_weight, norm_weight = \
+ self.get_norm_weights(feat, i)
+ embedding_weights.append(embedding_weight)
+ norm_weights.append(norm_weight)
+ if self.use_hyper_conv:
+ feat = encoded_ref[min(len(encoded_ref) - 1, i)]
+ conv_weights.append(self.get_conv_weights(feat, i))
+
+ if not self.training:
+ self.embedding_weights, self.conv_weights, self.norm_weights \
+ = embedding_weights, conv_weights, norm_weights
+ else:
+ # print('Reusing network weights.')
+ embedding_weights, conv_weights, norm_weights \
+ = self.embedding_weights, self.conv_weights, self.norm_weights
+
+ # Encode the target label to get the encoded features.
+ encoded_label = self.label_embedding(label, weights=(
+ embedding_weights if self.use_hyper_embed else None))
+
+ return x, encoded_label, conv_weights, norm_weights, \
+ atn, atn_vis, ref_idx
+
+ def encode_reference(self, ref_image, ref_label, label, k):
+ r"""Encode the reference image to get features for weight generation.
+
+ Args:
+ ref_image ((NxK)x3xHxW tensor): Reference images.
+ ref_label ((NxK)xCxHxW tensor): Reference labels.
+ label (NxCxHxW tensor): Target label.
+ k (int): Number of reference images.
+ Returns:
+ (tuple):
+ - x (NxC2xH2xW2 tensor): Encoded features from reference images
+ for the main branch (as input to the decoder).
+ - encoded_ref (list of tensors): Encoded features from reference
+ images for the weight generation branch.
+ - attention (Nx(KxH1xW1)x(H1xW1) tensor): Attention maps.
+ - atn_vis (1x1xH1xW1 tensor): Visualization for attention scores.
+ - ref_idx (Nx1 tensor): Index for which image to use from the
+ reference images.
+ """
+ if self.concat_ref_label:
+ # Concat reference label map and image together for encoding.
+ concat_ref = torch.cat([ref_image, ref_label], dim=1)
+ x = self.ref_img_first(concat_ref)
+ elif self.mul_ref_label:
+ # Apply conv to both reference label and image, then multiply them
+ # together for encoding.
+ x = self.ref_img_first(ref_image)
+ x_label = self.ref_label_first(ref_label)
+ else:
+ x = self.ref_img_first(ref_image)
+
+ # Attention map and the index of the most similar reference image.
+ atn = atn_vis = ref_idx = None
+ for i in range(self.num_downsamples):
+ x = getattr(self, 'ref_img_down_' + str(i))(x)
+ if self.mul_ref_label:
+ x_label = getattr(self, 'ref_label_down_' + str(i))(x_label)
+
+ # Combine different reference images at a particular layer.
+ if k > 1 and i == self.num_downsample_atn - 1:
+ x, atn, atn_vis = self.attention_module(x, label, ref_label)
+ if self.mul_ref_label:
+ x_label, _, _ = self.attention_module(x_label, None, None,
+ atn)
+
+ atn_sum = atn.view(label.shape[0], k, -1).sum(2)
+ ref_idx = torch.argmax(atn_sum, dim=1)
+
+ # Collect the encoder outputs at all resolutions; they are used to
+ # generate weights for the corresponding layers of the main branch.
+ encoded_image_ref = [x]
+ if self.mul_ref_label:
+ encoded_ref_label = [x_label]
+
+ for i in reversed(range(self.num_downsamples)):
+ conv = getattr(self, 'ref_img_up_' + str(i))(
+ encoded_image_ref[-1])
+ encoded_image_ref.append(conv)
+ if self.mul_ref_label:
+ conv_label = getattr(self, 'ref_label_up_' + str(i))(
+ encoded_ref_label[-1])
+ encoded_ref_label.append(conv_label)
+
+ if self.mul_ref_label:
+ encoded_ref = []
+ for i in range(len(encoded_image_ref)):
+ conv, conv_label \
+ = encoded_image_ref[i], encoded_ref_label[i]
+ b, c, h, w = conv.size()
+ conv_label = nn.Softmax(dim=1)(conv_label)
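+ # The outer product of the image features with the softmax-normalized
+ # label features, summed over the H*W locations, gives conv_prod of
+ # shape (B, C, C, 1), which is later fed to the weight-generation fc layers.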
+ conv_prod = (conv.view(b, c, 1, h * w) *
+ conv_label.view(b, 1, c,
+ h * w)).sum(3, keepdim=True)
+ encoded_ref.append(conv_prod)
+ else:
+ encoded_ref = encoded_image_ref
+ encoded_ref = encoded_ref[::-1]
+
+ return x, encoded_ref, atn, atn_vis, ref_idx
+
+ def get_norm_weights(self, x, i):
+ r"""Adaptively generate weights for SPADE in layer i of generator.
+
+ Args:
+ x (NxCxHxW tensor): Input features.
+ i (int): Layer index.
+ Returns:
+ (tuple):
+ - embedding_weights (list of tensors): Weights for the label
+ embedding network.
+ - norm_weights (list of tensors): Weights for the SPADE layers.
+ """
+ if not self.mul_ref_label:
+ # Get fixed output size for fc layers.
+ x = nn.AdaptiveAvgPool2d((self.sh_fix, self.sw_fix))(x)
+
+ in_ch = self.num_filters_each_layer[i]
+ out_ch = self.num_filters_each_layer[i + 1]
+ spade_ch = self.spade_in_channels[i]
+ eks, sks = self.embed_kernel_size, self.kernel_size
+
+ b = x.size(0)
+ weight_reshaper = WeightReshaper()
+ x = weight_reshaper.reshape_embed_input(x)
+
+ # Weights for the label embedding network.
+ embedding_weights = None
+ if self.use_hyper_embed:
+ fc_e = getattr(self, 'fc_spade_e_' + str(i))(x).view(b, -1)
+ if 'decoder' in self.embed_arch:
+ weight_shape = [in_ch, out_ch, eks, eks]
+ fc_e = fc_e[:, :-in_ch]
+ else:
+ weight_shape = [out_ch, in_ch, eks, eks]
+ embedding_weights = weight_reshaper.reshape_weight(fc_e,
+ weight_shape)
+
+ # Weights for the 3 layers in SPADE module: conv_0, conv_1,
+ # and shortcut.
+ fc_0 = getattr(self, 'fc_spade_0_' + str(i))(x).view(b, -1)
+ fc_1 = getattr(self, 'fc_spade_1_' + str(i))(x).view(b, -1)
+ fc_s = getattr(self, 'fc_spade_s_' + str(i))(x).view(b, -1)
+ if self.conv_before_norm:
+ out_ch = in_ch
+ weight_0 = weight_reshaper.reshape_weight(fc_0, [out_ch * 2, spade_ch,
+ sks, sks])
+ weight_1 = weight_reshaper.reshape_weight(fc_1, [in_ch * 2, spade_ch,
+ sks, sks])
+ weight_s = weight_reshaper.reshape_weight(fc_s, [out_ch * 2, spade_ch,
+ sks, sks])
+ norm_weights = [weight_0, weight_1, weight_s]
+
+ return embedding_weights, norm_weights
+
+ def get_conv_weights(self, x, i):
+ r"""Adaptively generate weights for layer i in main branch convolutions.
+
+ Args:
+ x (NxCxHxW tensor): Input features.
+ i (int): Layer index.
+ Returns:
+ (tuple):
+ - conv_weights (list of tensors): Weights for the conv layers in
+ the main branch.
+ """
+ if not self.mul_ref_label:
+ x = nn.AdaptiveAvgPool2d((self.sh_fix, self.sw_fix))(x)
+ in_ch = self.num_filters_each_layer[i]
+ out_ch = self.num_filters_each_layer[i + 1]
+ cks = self.conv_kernel_size
+ b = x.size()[0]
+ weight_reshaper = WeightReshaper()
+ x = weight_reshaper.reshape_embed_input(x)
+
+ fc_0 = getattr(self, 'fc_conv_0_' + str(i))(x).view(b, -1)
+ fc_1 = getattr(self, 'fc_conv_1_' + str(i))(x).view(b, -1)
+ fc_s = getattr(self, 'fc_conv_s_' + str(i))(x).view(b, -1)
+ weight_0 = weight_reshaper.reshape_weight(fc_0, [in_ch, out_ch,
+ cks, cks])
+ weight_1 = weight_reshaper.reshape_weight(fc_1, [in_ch, in_ch,
+ cks, cks])
+ weight_s = weight_reshaper.reshape_weight(fc_s, [in_ch, out_ch, 1, 1])
+ return [weight_0, weight_1, weight_s]
+
+ def reset(self):
+ r"""Reset the network at the beginning of a sequence."""
+ self.embedding_weights = self.conv_weights = self.norm_weights = None
+
+
+class WeightReshaper():
+ r"""Handles all weight reshape related tasks."""
+ def reshape_weight(self, x, weight_shape):
+ r"""Reshape input x to the desired weight shape.
+
+ Args:
+ x (tensor or list of tensors): Input features.
+ weight_shape (list of int): Desired shape of the weight.
+ Returns:
+ (tuple):
+ - weight (tensor): Network weights
+ - bias (tensor): Network bias.
+ """
+ # If desired shape is a list, first divide x into the target list of
+ # features.
+ if type(weight_shape[0]) == list and type(x) != list:
+ x = self.split_weights(x, self.sum_mul(weight_shape))
+
+ if type(x) == list:
+ return [self.reshape_weight(xi, wi)
+ for xi, wi in zip(x, weight_shape)]
+
+ # Get output shape, and divide x into either weight + bias or
+ # just weight.
+ weight_shape = [x.size(0)] + weight_shape
+ bias_size = weight_shape[1]
+ try:
+ weight = x[:, :-bias_size].view(weight_shape)
+ bias = x[:, -bias_size:]
+ except Exception:
+ weight = x.view(weight_shape)
+ bias = None
+ return [weight, bias]
+
+ def split_weights(self, weight, sizes):
+ r"""When the desired shape is a list, first divide the input to each
+ corresponding weight shape in the list.
+
+ Args:
+ weight (tensor): Input weight.
+ sizes (int or list of int): Target sizes.
+ Returns:
+ weight (list of tensors): Divided weights.
+ """
+ if isinstance(sizes, list):
+ weights = []
+ cur_size = 0
+ for i in range(len(sizes)):
+ # For each target size in sizes, get the number of elements
+ # needed.
+ next_size = cur_size + self.sum(sizes[i])
+ # Recursively divide the weights.
+ weights.append(self.split_weights(
+ weight[:, cur_size:next_size], sizes[i]))
+ cur_size = next_size
+ assert (next_size == weight.size(1))
+ return weights
+ return weight
+
+ def reshape_embed_input(self, x):
+ r"""Reshape input to be (B x C) X H X W.
+
+ Args:
+ x (tensor or list of tensors): Input features.
+ Returns:
+ x (tensor or list of tensors): Reshaped features.
+ """
+ if isinstance(x, list):
+ return [self.reshape_embed_input(xi) for xi in x]
+ b, c, _, _ = x.size()
+ x = x.view(b * c, -1)
+ return x
+
+ def sum(self, x):
+ r"""Sum all elements recursively in a nested list.
+
+ Args:
+ x (nested list of int): Input list of elements.
+ Returns:
+ out (int): Sum of all elements.
+ """
+ if type(x) != list:
+ return x
+ return sum([self.sum(xi) for xi in x])
+
+ def sum_mul(self, x):
+ r"""Given a weight shape, compute the number of elements needed for
+ weight + bias. If input is a list of shapes, sum all the elements.
+
+ Args:
+ x (list of int): Input list of elements.
+ Returns:
+ out (int or list of int): Summed number of elements.
+ """
+ assert (type(x) == list)
+ if type(x[0]) != list:
+ return np.prod(x) + x[0] # x[0] accounts for bias.
+ return [self.sum_mul(xi) for xi in x]
+
+
+class AttentionModule(nn.Module):
+ r"""Attention module constructor.
+
+ Args:
+ atn_cfg (obj): Generator definition part of the yaml config file.
+ data_cfg (obj): Data definition part of the yaml config file
+ conv_2d_block: Conv2DBlock constructor.
+ num_filters_each_layer (int): The number of filters in each layer.
+ """
+
+ def __init__(self, atn_cfg, data_cfg, conv_2d_block,
+ num_filters_each_layer):
+ super().__init__()
+ self.initial_few_shot_K = data_cfg.initial_few_shot_K
+ num_input_channels = data_cfg.num_input_channels
+ num_filters = getattr(atn_cfg, 'num_filters', 32)
+
+ self.num_downsample_atn = getattr(atn_cfg, 'num_downsamples', 2)
+ self.atn_query_first = conv_2d_block(num_input_channels, num_filters)
+ self.atn_key_first = conv_2d_block(num_input_channels, num_filters)
+ for i in range(self.num_downsample_atn):
+ f_in, f_out = num_filters_each_layer[i], \
+ num_filters_each_layer[i + 1]
+ setattr(self, 'atn_key_%d' % i,
+ conv_2d_block(f_in, f_out, stride=2))
+ setattr(self, 'atn_query_%d' % i,
+ conv_2d_block(f_in, f_out, stride=2))
+
+ def forward(self, in_features, label, ref_label, attention=None):
+ r"""Get the attention map to combine multiple image features in the
+ case of multiple reference images.
+
+ Args:
+ in_features ((NxK)xC1xH1xW1 tensor): Input features.
+ label (NxC2xH2xW2 tensor): Target label.
+ ref_label (NxC2xH2xW2 tensor): Reference label.
+ attention (Nx(KxH1xW1)x(H1xW1) tensor): Attention maps.
+ Returns:
+ (tuple):
+ - out_features (NxC1xH1xW1 tensor): Attention-combined features.
+ - attention (Nx(KxH1xW1)x(H1xW1) tensor): Attention maps.
+ - atn_vis (1x1xH1xW1 tensor): Visualization for attention scores.
+ """
+ b, c, h, w = in_features.size()
+ k = self.initial_few_shot_K
+ b = b // k
+
+ if attention is None:
+ # Compute the attention map by encoding ref_label and label as
+ # key and query. The map represents how much energy for the k-th
+ # map at location (h_i, w_j) can contribute to the final map at
+ # location (h_i2, w_j2).
+ atn_key = self.attention_encode(ref_label, 'atn_key')
+ atn_query = self.attention_encode(label, 'atn_query')
+
+ atn_key = atn_key.view(b, k, c, -1).permute(
+ 0, 1, 3, 2).contiguous().view(b, -1, c) # B X KHW X C
+ atn_query = atn_query.view(b, c, -1) # B X C X HW
+ energy = torch.bmm(atn_key, atn_query) # B X KHW X HW
+ attention = nn.Softmax(dim=1)(energy)
+
+ # Combine the K features from different ref images into one by using
+ # the attention map.
+ in_features = in_features.view(b, k, c, h * w).permute(
+ 0, 2, 1, 3).contiguous().view(b, c, -1) # B X C X KHW
+ out_features = torch.bmm(in_features, attention).view(b, c, h, w)
+
+ # Get a slice of the attention map for visualization.
+ atn_vis = attention.view(b, k, h * w, h * w).sum(2).view(b, k, h, w)
+ return out_features, attention, atn_vis[-1:, 0:1]
+
+ def attention_encode(self, img, net_name):
+ r"""Encode the input image to get the attention map.
+
+ Args:
+ img (NxCxHxW tensor): Input image.
+ net_name (str): Name for attention network.
+ Returns:
+ x (NxC2xH2xW2 tensor): Encoded feature.
+ """
+ x = getattr(self, net_name + '_first')(img)
+ for i in range(self.num_downsample_atn):
+ x = getattr(self, net_name + '_' + str(i))(x)
+ return x
+
+
+class FlowGenerator(nn.Module):
+ r"""flow generator constructor.
+
+ Args:
+ flow_cfg (obj): Flow definition part of the yaml config file.
+ data_cfg (obj): Data definition part of the yaml config file.
+ num_frames (int): Number of input frames.
+ """
+
+ def __init__(self, flow_cfg, data_cfg, num_frames):
+ super().__init__()
+ num_input_channels = data_cfg.num_input_channels
+ if num_input_channels == 0:
+ num_input_channels = 1
+ num_prev_img_channels = get_paired_input_image_channel_number(data_cfg)
+ num_downsamples = getattr(flow_cfg, 'num_downsamples', 3)
+ kernel_size = getattr(flow_cfg, 'kernel_size', 3)
+ padding = kernel_size // 2
+ num_blocks = getattr(flow_cfg, 'num_blocks', 6)
+ num_filters = getattr(flow_cfg, 'num_filters', 32)
+ max_num_filters = getattr(flow_cfg, 'max_num_filters', 1024)
+ num_filters_each_layer = [min(max_num_filters, num_filters * (2 ** i))
+ for i in range(num_downsamples + 1)]
+
+ self.flow_output_multiplier = getattr(flow_cfg,
+ 'flow_output_multiplier', 20)
+ self.sep_up_mask = getattr(flow_cfg, 'sep_up_mask', False)
+ activation_norm_type = getattr(flow_cfg, 'activation_norm_type',
+ 'sync_batch')
+ weight_norm_type = getattr(flow_cfg, 'weight_norm_type', 'spectral')
+
+ base_conv_block = partial(Conv2dBlock, kernel_size=kernel_size,
+ padding=padding,
+ weight_norm_type=weight_norm_type,
+ activation_norm_type=activation_norm_type,
+ nonlinearity='leakyrelu')
+
+ num_input_channels = num_input_channels * num_frames + \
+ num_prev_img_channels * (num_frames - 1)
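+ # Illustrative channel count: assuming a C-channel label map, 3-channel RGB
+ # frames and num_frames=3, the flow network input has 3*C + 3*2 channels
+ # (labels for every frame plus the two previous images).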
+ # First layer.
+ down_flow = [base_conv_block(num_input_channels, num_filters)]
+
+ # Downsamples.
+ for i in range(num_downsamples):
+ down_flow += [base_conv_block(num_filters_each_layer[i],
+ num_filters_each_layer[i + 1],
+ stride=2)]
+
+ # Resnet blocks.
+ res_flow = []
+ ch = num_filters_each_layer[num_downsamples]
+ for i in range(num_blocks):
+ res_flow += [
+ Res2dBlock(ch, ch, kernel_size, padding=padding,
+ weight_norm_type=weight_norm_type,
+ activation_norm_type=activation_norm_type,
+ order='NACNAC')]
+
+ # Upsamples.
+ up_flow = []
+ for i in reversed(range(num_downsamples)):
+ up_flow += [nn.Upsample(scale_factor=2),
+ base_conv_block(num_filters_each_layer[i + 1],
+ num_filters_each_layer[i])]
+
+ conv_flow = [Conv2dBlock(num_filters, 2, kernel_size, padding=padding)]
+ conv_mask = [Conv2dBlock(num_filters, 1, kernel_size, padding=padding,
+ nonlinearity='sigmoid')]
+
+ self.down_flow = nn.Sequential(*down_flow)
+ self.res_flow = nn.Sequential(*res_flow)
+ self.up_flow = nn.Sequential(*up_flow)
+ if self.sep_up_mask:
+ self.up_mask = nn.Sequential(*copy.deepcopy(up_flow))
+ self.conv_flow = nn.Sequential(*conv_flow)
+ self.conv_mask = nn.Sequential(*conv_mask)
+
+ def forward(self, label, ref_label, ref_image):
+ r"""Flow generator forward.
+
+ Args:
+ label (4D tensor) : Input label tensor.
+ ref_label (4D tensor) : Reference label tensors.
+ ref_image (4D tensor) : Reference image tensors.
+ Returns:
+ (tuple):
+ - flow (4D tensor) : Generated flow map.
+ - mask (4D tensor) : Generated occlusion mask.
+ """
+ label_concat = torch.cat([label, ref_label, ref_image], dim=1)
+ downsample = self.down_flow(label_concat)
+ res = self.res_flow(downsample)
+ flow_feat = self.up_flow(res)
+ flow = self.conv_flow(flow_feat) * self.flow_output_multiplier
+
+ mask_feat = self.up_mask(res) if self.sep_up_mask else flow_feat
+ mask = self.conv_mask(mask_feat)
+ return flow, mask
+
+
+class LabelEmbedder(nn.Module):
+ r"""Embed the input label map to get embedded features.
+
+ Args:
+ emb_cfg (obj): Embed network configuration.
+ num_input_channels (int): Number of input channels.
+ num_hyper_layers (int): Number of hyper layers.
+ """
+
+ def __init__(self, emb_cfg, num_input_channels, num_hyper_layers=0):
+ super().__init__()
+ num_filters = getattr(emb_cfg, 'num_filters', 32)
+ max_num_filters = getattr(emb_cfg, 'max_num_filters', 1024)
+ self.arch = getattr(emb_cfg, 'arch', 'encoderdecoder')
+ self.num_downsamples = num_downsamples = \
+ getattr(emb_cfg, 'num_downsamples', 5)
+ kernel_size = getattr(emb_cfg, 'kernel_size', 3)
+ weight_norm_type = getattr(emb_cfg, 'weight_norm_type', 'spectral')
+ activation_norm_type = getattr(emb_cfg, 'activation_norm_type', 'none')
+
+ self.unet = 'unet' in self.arch
+ self.has_decoder = 'decoder' in self.arch or self.unet
+ self.num_hyper_layers = num_hyper_layers \
+ if num_hyper_layers != -1 else num_downsamples
+
+ base_conv_block = partial(HyperConv2dBlock, kernel_size=kernel_size,
+ padding=(kernel_size // 2),
+ weight_norm_type=weight_norm_type,
+ activation_norm_type=activation_norm_type,
+ nonlinearity='leakyrelu')
+
+ ch = [min(max_num_filters, num_filters * (2 ** i))
+ for i in range(num_downsamples + 1)]
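+ # With the defaults above (num_filters=32, num_downsamples=5,
+ # max_num_filters=1024) this gives ch = [32, 64, 128, 256, 512, 1024].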
+
+ self.conv_first = base_conv_block(num_input_channels, num_filters,
+ activation_norm_type='none')
+
+ # Downsample.
+ for i in range(num_downsamples):
+ is_hyper_conv = (i < num_hyper_layers) and not self.has_decoder
+ setattr(self, 'down_%d' % i,
+ base_conv_block(ch[i], ch[i + 1], stride=2,
+ is_hyper_conv=is_hyper_conv))
+
+ # Upsample.
+ if self.has_decoder:
+ self.upsample = nn.Upsample(scale_factor=2)
+ for i in reversed(range(num_downsamples)):
+ ch_i = ch[i + 1] * (
+ 2 if self.unet and i != num_downsamples - 1 else 1)
+ setattr(self, 'up_%d' % i,
+ base_conv_block(ch_i, ch[i],
+ is_hyper_conv=(i < num_hyper_layers)))
+
+ def forward(self, input, weights=None):
+ r"""Embedding network forward.
+
+ Args:
+ input (NxCxHxW tensor): Network input.
+ weights (list of tensors): Conv weights if using hyper network.
+ Returns:
+ output (list of tensors): Network outputs at different layers.
+ """
+ if input is None:
+ return None
+ output = [self.conv_first(input)]
+
+ for i in range(self.num_downsamples):
+ layer = getattr(self, 'down_%d' % i)
+ # For hyper networks, the hyper layers are at the last few layers
+ # of decoder (if the network has a decoder). Otherwise, the hyper
+ # layers will be at the first few layers of the network.
+ if i >= self.num_hyper_layers or self.has_decoder:
+ conv = layer(output[-1])
+ else:
+ conv = layer(output[-1], conv_weights=weights[i])
+ # We will use outputs from different layers as input to different
+ # SPADE layers in the main branch.
+ output.append(conv)
+
+ if not self.has_decoder:
+ return output
+
+ # If the network has a decoder, will use outputs from the decoder
+ # layers instead of the encoding layers.
+ if not self.unet:
+ output = [output[-1]]
+
+ for i in reversed(range(self.num_downsamples)):
+ input_i = output[-1]
+ if self.unet and i != self.num_downsamples - 1:
+ input_i = torch.cat([input_i, output[i + 1]], dim=1)
+
+ input_i = self.upsample(input_i)
+ layer = getattr(self, 'up_%d' % i)
+ # The last few layers will be hyper layers if necessary.
+ if i >= self.num_hyper_layers:
+ conv = layer(input_i)
+ else:
+ conv = layer(input_i, conv_weights=weights[i])
+ output.append(conv)
+
+ if self.unet:
+ output = output[self.num_downsamples:]
+ return output[::-1]
diff --git a/imaginaire/generators/funit.py b/imaginaire/generators/funit.py
new file mode 100644
index 0000000000000000000000000000000000000000..6520166a4f906afe3c5cd2fda09a6fbc11502213
--- /dev/null
+++ b/imaginaire/generators/funit.py
@@ -0,0 +1,400 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+from functools import partial
+from types import SimpleNamespace
+
+import torch
+from torch import nn
+
+from imaginaire.layers import \
+ (Conv2dBlock, LinearBlock, Res2dBlock, UpRes2dBlock)
+
+
+class Generator(nn.Module):
+ r"""Generator of the improved FUNIT baseline in the COCO-FUNIT paper.
+ """
+
+ def __init__(self, gen_cfg, data_cfg):
+ super().__init__()
+ self.generator = FUNITTranslator(**vars(gen_cfg))
+
+ def forward(self, data):
+ r"""In the FUNIT's forward pass, it generates a content embedding and
+ a style code from the content image, and a style code from the style
+ image. By mixing the content code and the style code from the content
+ image, we reconstruct the input image. By mixing the content code and
+ the style code from the style image, we have a translation output.
+
+ Args:
+ data (dict): Training data at the current iteration.
+ """
+ content_a = self.generator.content_encoder(data['images_content'])
+ style_a = self.generator.style_encoder(data['images_content'])
+ style_b = self.generator.style_encoder(data['images_style'])
+ images_trans = self.generator.decode(content_a, style_b)
+ images_recon = self.generator.decode(content_a, style_a)
+
+ net_G_output = dict(images_trans=images_trans,
+ images_recon=images_recon)
+ return net_G_output
+
+ def inference(self, data, keep_original_size=True):
+ r"""COCO-FUNIT inference.
+
+ Args:
+ data (dict): Training data at the current iteration.
+ - images_content (tensor): Content images.
+ - images_style (tensor): Style images.
+ keep_original_size (bool): If ``True``, output image is resized
+ to the input content image size.
+ """
+ content_a = self.generator.content_encoder(data['images_content'])
+ style_b = self.generator.style_encoder(data['images_style'])
+ output_images = self.generator.decode(content_a, style_b)
+ if keep_original_size:
+ height = data['original_h_w'][0][0]
+ width = data['original_h_w'][0][1]
+ # print('( H, W) = ( %d, %d)' % (height, width))
+ output_images = torch.nn.functional.interpolate(
+ output_images, size=[height, width])
+ file_names = data['key']['images_content'][0]
+ return output_images, file_names
+
+
+class FUNITTranslator(nn.Module):
+ r"""
+
+ Args:
+ num_filters (int): Base filter numbers.
+ num_filters_mlp (int): Base filter number in the MLP module.
+ style_dims (int): Dimension of the style code.
+ num_res_blocks (int): Number of residual blocks at the end of the
+ content encoder.
+ num_mlp_blocks (int): Number of layers in the MLP module.
+ num_downsamples_content (int): Number of times we reduce
+ resolution by 2x2 for the content image.
+ num_downsamples_style (int): Number of times we reduce
+ resolution by 2x2 for the style image.
+ num_image_channels (int): Number of input image channels.
+ weight_norm_type (str): Type of weight normalization.
+ ``'none'``, ``'spectral'``, or ``'weight'``.
+ """
+
+ def __init__(self,
+ num_filters=64,
+ num_filters_mlp=256,
+ style_dims=64,
+ num_res_blocks=2,
+ num_mlp_blocks=3,
+ num_downsamples_style=4,
+ num_downsamples_content=2,
+ num_image_channels=3,
+ weight_norm_type='',
+ **kwargs):
+ super().__init__()
+
+ self.style_encoder = StyleEncoder(num_downsamples_style,
+ num_image_channels,
+ num_filters,
+ style_dims,
+ 'reflect',
+ 'none',
+ weight_norm_type,
+ 'relu')
+
+ self.content_encoder = ContentEncoder(num_downsamples_content,
+ num_res_blocks,
+ num_image_channels,
+ num_filters,
+ 'reflect',
+ 'instance',
+ weight_norm_type,
+ 'relu')
+
+ self.decoder = Decoder(self.content_encoder.output_dim,
+ num_filters_mlp,
+ num_image_channels,
+ num_downsamples_content,
+ 'reflect',
+ weight_norm_type,
+ 'relu')
+
+ self.mlp = MLP(style_dims,
+ num_filters_mlp,
+ num_filters_mlp,
+ num_mlp_blocks,
+ 'none',
+ 'relu')
+
+ def forward(self, images):
+ r"""Reconstruct the input image by combining the computer content and
+ style code.
+
+ Args:
+ images (tensor): Input image tensor.
+ """
+ # reconstruct an image
+ content, style = self.encode(images)
+ images_recon = self.decode(content, style)
+ return images_recon
+
+ def encode(self, images):
+ r"""Encoder images to get their content and style codes.
+
+ Args:
+ images (tensor): Input image tensor.
+ """
+ style = self.style_encoder(images)
+ content = self.content_encoder(images)
+ return content, style
+
+ def decode(self, content, style):
+ r"""Generate images by combining their content and style codes.
+
+ Args:
+ content (tensor): Content code tensor.
+ style (tensor): Style code tensor.
+ """
+ style = self.mlp(style)
+ images = self.decoder(content, style)
+ return images
+
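+# Minimal usage sketch (illustrative; assumes 128x128 RGB inputs and the default
+# constructor arguments):
+#   translator = FUNITTranslator()
+#   content, style = translator.encode(torch.randn(1, 3, 128, 128))
+#   reconstruction = translator.decode(content, style)  # same spatial size as the input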
+
+class Decoder(nn.Module):
+ r"""Improved FUNIT decoder.
+
+ Args:
+ num_enc_output_channels (int): Number of content feature channels.
+ style_channels (int): Dimension of the style code.
+ num_image_channels (int): Number of image channels.
+ num_upsamples (int): How many times we are going to apply
+ upsample residual block.
+ """
+
+ def __init__(self,
+ num_enc_output_channels,
+ style_channels,
+ num_image_channels=3,
+ num_upsamples=4,
+ padding_type='reflect',
+ weight_norm_type='none',
+ nonlinearity='relu'):
+ super(Decoder, self).__init__()
+ adain_params = SimpleNamespace(
+ activation_norm_type='instance',
+ activation_norm_params=SimpleNamespace(affine=False),
+ cond_dims=style_channels)
+
+ base_res_block = partial(Res2dBlock,
+ kernel_size=3,
+ padding=1,
+ padding_mode=padding_type,
+ nonlinearity=nonlinearity,
+ activation_norm_type='adaptive',
+ activation_norm_params=adain_params,
+ weight_norm_type=weight_norm_type,
+ learn_shortcut=False)
+
+ base_up_res_block = partial(UpRes2dBlock,
+ kernel_size=5,
+ padding=2,
+ padding_mode=padding_type,
+ weight_norm_type=weight_norm_type,
+ activation_norm_type='adaptive',
+ activation_norm_params=adain_params,
+ skip_activation_norm='instance',
+ skip_nonlinearity=nonlinearity,
+ nonlinearity=nonlinearity,
+ hidden_channels_equal_out_channels=True,
+ learn_shortcut=True)
+
+ dims = num_enc_output_channels
+
+ # Residual blocks with AdaIN.
+ self.decoder = nn.ModuleList()
+ self.decoder += [base_res_block(dims, dims)]
+ self.decoder += [base_res_block(dims, dims)]
+ for _ in range(num_upsamples):
+ self.decoder += [base_up_res_block(dims, dims // 2)]
+ dims = dims // 2
+ self.decoder += [Conv2dBlock(dims,
+ num_image_channels,
+ kernel_size=7,
+ stride=1,
+ padding=3,
+ padding_mode='reflect',
+ nonlinearity='tanh')]
+
+ def forward(self, x, style):
+ r"""
+
+ Args:
+ x (tensor): Content embedding of the content image.
+ style (tensor): Style embedding of the style image.
+ """
+ for block in self.decoder:
+ if getattr(block, 'conditional', False):
+ x = block(x, style)
+ else:
+ x = block(x)
+ return x
+
+
+class StyleEncoder(nn.Module):
+ r"""Improved FUNIT Style Encoder. This is basically the same as the
+ original FUNIT Style Encoder.
+
+ Args:
+ num_downsamples (int): Number of times we reduce resolution by
+ 2x2.
+ image_channels (int): Number of input image channels.
+ num_filters (int): Base filter number.
+ style_channels (int): Style code dimension.
+ padding_mode (str): Padding mode.
+ activation_norm_type (str): Type of activation normalization.
+ weight_norm_type (str): Type of weight normalization.
+ ``'none'``, ``'spectral'``, or ``'weight'``.
+ nonlinearity (str): Nonlinearity.
+ """
+
+ def __init__(self,
+ num_downsamples,
+ image_channels,
+ num_filters,
+ style_channels,
+ padding_mode,
+ activation_norm_type,
+ weight_norm_type,
+ nonlinearity):
+ super().__init__()
+ conv_params = dict(padding_mode=padding_mode,
+ activation_norm_type=activation_norm_type,
+ weight_norm_type=weight_norm_type,
+ nonlinearity=nonlinearity,
+ inplace_nonlinearity=True)
+ model = []
+ model += [Conv2dBlock(image_channels, num_filters, 7, 1, 3,
+ **conv_params)]
+ for i in range(2):
+ model += [Conv2dBlock(num_filters, 2 * num_filters, 4, 2, 1,
+ **conv_params)]
+ num_filters *= 2
+ for i in range(num_downsamples - 2):
+ model += [Conv2dBlock(num_filters, num_filters, 4, 2, 1,
+ **conv_params)]
+ model += [nn.AdaptiveAvgPool2d(1)]
+ model += [nn.Conv2d(num_filters, style_channels, 1, 1, 0)]
+ self.model = nn.Sequential(*model)
+ self.output_dim = num_filters
+
+ def forward(self, x):
+ r"""
+
+ Args:
+ x (tensor): Input image.
+ """
+ return self.model(x)
+
+
+class ContentEncoder(nn.Module):
+ r"""Improved FUNIT Content Encoder. This is basically the same as the
+ original FUNIT content encoder.
+
+ Args:
+ num_downsamples (int): Number of times we reduce resolution by
+ 2x2.
+ num_res_blocks (int): Number of times we append residual block
+ after all the downsampling modules.
+ image_channels (int): Number of input image channels.
+ num_filters (int): Base filter number.
+ padding_mode (str): Padding mode
+ activation_norm_type (str): Type of activation normalization.
+ weight_norm_type (str): Type of weight normalization.
+ ``'none'``, ``'spectral'``, or ``'weight'``.
+ nonlinearity (str): Nonlinearity.
+ """
+
+ def __init__(self,
+ num_downsamples,
+ num_res_blocks,
+ image_channels,
+ num_filters,
+ padding_mode,
+ activation_norm_type,
+ weight_norm_type,
+ nonlinearity):
+ super().__init__()
+ conv_params = dict(padding_mode=padding_mode,
+ activation_norm_type=activation_norm_type,
+ weight_norm_type=weight_norm_type,
+ nonlinearity=nonlinearity,
+ inplace_nonlinearity=True,
+ order='CNACNA')
+ model = []
+ model += [Conv2dBlock(image_channels, num_filters, 7, 1, 3,
+ **conv_params)]
+ dims = num_filters
+ for i in range(num_downsamples):
+ model += [Conv2dBlock(dims, dims * 2, 4, 2, 1, **conv_params)]
+ dims *= 2
+
+ for _ in range(num_res_blocks):
+ model += [Res2dBlock(dims, dims, learn_shortcut=False, **conv_params)]
+ self.model = nn.Sequential(*model)
+ self.output_dim = dims
+
+ def forward(self, x):
+ r"""
+
+ Args:
+ x (tensor): Input image.
+ """
+ return self.model(x)
+
+
+class MLP(nn.Module):
+ r"""Improved FUNIT style decoder.
+
+ Args:
+ input_dim (int): Input dimension (style code dimension).
+ output_dim (int): Output dimension (to be fed into the AdaIN
+ layer).
+ latent_dim (int): Latent dimension.
+ num_layers (int): Number of layers in the MLP.
+ activation_norm_type (str): Activation type.
+ nonlinearity (str): Nonlinearity type.
+ """
+
+ def __init__(self,
+ input_dim,
+ output_dim,
+ latent_dim,
+ num_layers,
+ activation_norm_type,
+ nonlinearity):
+ super().__init__()
+ model = []
+ model += [LinearBlock(input_dim, latent_dim,
+ activation_norm_type=activation_norm_type,
+ nonlinearity=nonlinearity)]
+ # changed from num_layers - 2 to num_layers - 3.
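+ # Together with the first and last blocks this yields num_layers - 1
+ # LinearBlocks in total.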
+ for i in range(num_layers - 3):
+ model += [LinearBlock(latent_dim, latent_dim,
+ activation_norm_type=activation_norm_type,
+ nonlinearity=nonlinearity)]
+ model += [LinearBlock(latent_dim, output_dim,
+ activation_norm_type=activation_norm_type,
+ nonlinearity=nonlinearity)]
+ self.model = nn.Sequential(*model)
+
+ def forward(self, x):
+ r"""
+
+ Args:
+ x (tensor): Input tensor.
+ """
+ return self.model(x.view(x.size(0), -1))
diff --git a/imaginaire/generators/gancraft.py b/imaginaire/generators/gancraft.py
new file mode 100644
index 0000000000000000000000000000000000000000..94fc34bee88e31fcdcf48f715f4d17f3de5bc37b
--- /dev/null
+++ b/imaginaire/generators/gancraft.py
@@ -0,0 +1,538 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import os
+
+import cv2
+import imageio
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+import imaginaire.model_utils.gancraft.camctl as camctl
+import imaginaire.model_utils.gancraft.mc_utils as mc_utils
+import imaginaire.model_utils.gancraft.voxlib as voxlib
+from imaginaire.utils.distributed import master_only_print as print
+from imaginaire.generators.gancraft_base import Base3DGenerator, RenderMLP # noqa
+
+
+class Generator(Base3DGenerator):
+ r"""GANcraft generator constructor.
+
+ Args:
+ gen_cfg (obj): Generator definition part of the yaml config file.
+ data_cfg (obj): Data definition part of the yaml config file.
+ """
+
+ def __init__(self, gen_cfg, data_cfg):
+ super(Generator, self).__init__(gen_cfg, data_cfg)
+ print('GANcraft generator initialization.')
+
+ # Load voxels of the input world.
+ # The loaded voxel tensor has a shape of [X, Y, Z], dtype==torch.int32
+ # 0 means empty (air).
+ print('[Generator] Loading voxel world: ', gen_cfg.voxel_path)
+ if gen_cfg.voxel_path.endswith('.npy'):
+ voxel_t = np.load(gen_cfg.voxel_path)
+ voxel_t = torch.from_numpy(voxel_t.astype(np.int32))
+ else:
+ voxel_t = mc_utils.load_voxel_new(gen_cfg.voxel_path, shape=gen_cfg.voxel_shape)
+ print('[Generator] Loaded voxel world.')
+ self.voxel = mc_utils.McVoxel(voxel_t, preproc_ver=gen_cfg.voxel_preproc_ver)
+ blk_feats = torch.empty([self.voxel.nfilledvox, gen_cfg.blk_feat_dim], requires_grad=True)
+ self.blk_feats = nn.Parameter(blk_feats) # Feature per voxel corner.
+
+ # Minecraft -> SPADE label translator.
+ self.label_trans = mc_utils.MCLabelTranslator()
+ self.num_reduced_labels = self.label_trans.get_num_reduced_lbls()
+ self.reduced_label_set = getattr(gen_cfg, 'reduced_label_set', False)
+ self.use_label_smooth = getattr(gen_cfg, 'use_label_smooth', False)
+ self.use_label_smooth_real = getattr(gen_cfg, 'use_label_smooth_real', self.use_label_smooth)
+ self.use_label_smooth_pgt = getattr(gen_cfg, 'use_label_smooth_pgt', False)
+ self.label_smooth_dia = getattr(gen_cfg, 'label_smooth_dia', 11)
+
+ # Load MLP model.
+ self.render_net = globals()[gen_cfg.mlp_model](
+ self.input_dim, viewdir_dim=self.input_dim_viewdir, style_dim=self.interm_style_dims,
+ mask_dim=self.num_reduced_labels, out_channels_s=1, out_channels_c=self.final_feat_dim,
+ **self.mlp_model_kwargs)
+
+ # Camera sampler.
+ self.camera_sampler_type = getattr(gen_cfg, 'camera_sampler_type', "random")
+ assert self.camera_sampler_type in ['random', 'traditional']
+ self.camera_min_entropy = getattr(gen_cfg, 'camera_min_entropy', -1)
+ self.camera_rej_avg_depth = getattr(gen_cfg, 'camera_rej_avg_depth', -1)
+ self.cam_res = gen_cfg.cam_res
+ self.crop_size = gen_cfg.crop_size
+
+ print('Done with the GANcraft generator initialization.')
+
+ def custom_init(self):
+ r"""Weight initialization of GANcraft components."""
+
+ self.blk_feats.data.uniform_(-1, 1)
+
+ def init_func(m):
+ if hasattr(m, 'weight'):
+ nn.init.kaiming_normal_(m.weight.data, a=0.2, nonlinearity='leaky_relu')
+ m.weight.data *= 0.5
+ if hasattr(m, 'bias') and m.bias is not None:
+ m.bias.data.fill_(0.0)
+ self.apply(init_func)
+
+ def _get_batch(self, batch_size, device):
+ r"""Sample camera poses and perform ray-voxel intersection.
+
+ Args:
+ batch_size (int): Expected batch size of the current batch
+ device (torch.device): Device on which the tensors should be stored
+ """
+ with torch.no_grad():
+ voxel_id_batch = []
+ depth2_batch = []
+ raydirs_batch = []
+ cam_ori_t_batch = []
+ for b in range(batch_size):
+ while True: # Rejection sampling.
+ # Sample camera pose.
+ if self.camera_sampler_type == 'random':
+ cam_res = self.cam_res
+ cam_ori_t, cam_dir_t, cam_up_t = camctl.rand_camera_pose_thridperson2(self.voxel)
+ # ~24mm fov horizontal.
+ cam_f = 0.5/np.tan(np.deg2rad(73/2) * (np.random.rand(1)*0.5+0.5)) * (cam_res[1]-1)
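+ # Focal length in pixels: 0.5 * (cam_res[1] - 1) / tan(half FOV). The random
+ # factor draws the half-FOV from [0.5, 1.0] * 36.5 deg, i.e. a horizontal FOV
+ # between roughly 36.5 and 73 degrees (73 deg is roughly a 24mm full-frame lens).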
+ cam_c = [(cam_res[0]-1)/2, (cam_res[1]-1)/2]
+ cam_res_crop = [self.crop_size[0] + self.pad, self.crop_size[1] + self.pad]
+ cam_c = mc_utils.rand_crop(cam_c, cam_res, cam_res_crop)
+ elif self.camera_sampler_type == 'traditional':
+ cam_res = self.cam_res
+ cam_c = [(cam_res[0]-1)/2, (cam_res[1]-1)/2]
+ dice = torch.rand(1).item()
+ if dice > 0.5:
+ cam_ori_t, cam_dir_t, cam_up_t, cam_f = \
+ camctl.rand_camera_pose_tour(self.voxel)
+ cam_f = cam_f * (cam_res[1]-1)
+ else:
+ cam_ori_t, cam_dir_t, cam_up_t = \
+ camctl.rand_camera_pose_thridperson2(self.voxel)
+ # ~24mm fov horizontal.
+ cam_f = 0.5 / np.tan(np.deg2rad(73/2) * (np.random.rand(1)*0.5+0.5)) * (cam_res[1]-1)
+
+ cam_res_crop = [self.crop_size[0] + self.pad, self.crop_size[1] + self.pad]
+ cam_c = mc_utils.rand_crop(cam_c, cam_res, cam_res_crop)
+ else:
+ raise NotImplementedError(
+ 'Unknown self.camera_sampler_type: {}'.format(self.camera_sampler_type))
+ # Run ray-voxel intersection test
+ r"""Ray-voxel intersection CUDA kernel.
+ Note: voxel_id = 0 and depth2 = NaN if there is no intersection along the ray
+
+ Args:
+ voxel_t (Y x 512 x 512 tensor, int32): Full 3D voxel of MC block IDs.
+ cam_ori_t (3 tensor): Camera origin.
+ cam_dir_t (3 tensor): Camera direction.
+ cam_up_t (3 tensor): Camera up vector.
+ cam_f (float): Camera focal length (in pixels).
+ cam_c (list of 2 floats [x, y]): Camera optical center.
+ img_dims (list of 2 ints [H, W]): Camera resolution.
+ max_samples (int): Maximum number of blocks intersected along the ray before stopping.
+ Returns:
+ voxel_id ( img_dims[0] x img_dims[1] x max_samples x 1 tensor): IDs of intersected tensors
+ along each ray
+ depth2 (2 x img_dims[0] x img_dims[1] x max_samples x 1 tensor): Depths of entrance and exit
+ points for each ray-voxel intersection.
+ raydirs ( img_dims[0] x img_dims[1] x 1 x 3 tensor): The direction of each ray.
+
+ """
+ voxel_id, depth2, raydirs = voxlib.ray_voxel_intersection_perspective(
+ self.voxel.voxel_t, cam_ori_t, cam_dir_t, cam_up_t, cam_f, cam_c, cam_res_crop,
+ self.num_blocks_early_stop)
+
+ if self.camera_rej_avg_depth > 0:
+ depth_map = depth2[0, :, :, 0, :]
+ avg_depth = torch.mean(depth_map[~torch.isnan(depth_map)])
+ if avg_depth < self.camera_rej_avg_depth:
+ continue
+
+ # Reject low entropy.
+ if self.camera_min_entropy > 0:
+ # Check entropy.
+ maskcnt = torch.bincount(
+ torch.flatten(voxel_id[:, :, 0, 0]), weights=None, minlength=680).float() / \
+ (voxel_id.size(0)*voxel_id.size(1))
+ maskentropy = -torch.sum(maskcnt * torch.log(maskcnt+1e-10))
+ if maskentropy < self.camera_min_entropy:
+ continue
+ break
+
+ voxel_id_batch.append(voxel_id)
+ depth2_batch.append(depth2)
+ raydirs_batch.append(raydirs)
+ cam_ori_t_batch.append(cam_ori_t)
+ voxel_id = torch.stack(voxel_id_batch, dim=0)
+ depth2 = torch.stack(depth2_batch, dim=0)
+ raydirs = torch.stack(raydirs_batch, dim=0)
+ cam_ori_t = torch.stack(cam_ori_t_batch, dim=0).to(device)
+ cam_poses = None
+ return voxel_id, depth2, raydirs, cam_ori_t, cam_poses
+
+ def get_pseudo_gt(self, pseudo_gen, voxel_id, z=None, style_img=None, resize_512=True, deterministic=False):
+ r"""Evaluating img2img network to obtain pseudo-ground truth images.
+
+ Args:
+ pseudo_gen (callable): Function converting mask to image using img2img network.
+ voxel_id (N x img_dims[0] x img_dims[1] x max_samples x 1 tensor): IDs of intersected tensors along
+ each ray.
+ z (N x C tensor): Optional style code passed to pseudo_gen.
+ style_img (N x 3 x H x W tensor): Optional style image passed to pseudo_gen.
+ resize_512 (bool): If True, evaluate pseudo_gen at 512x512 regardless of input resolution.
+ deterministic (bool): If True, disable stochastic label mapping.
+ """
+ with torch.no_grad():
+ mc_mask = voxel_id[:, :, :, 0, :].permute(0, 3, 1, 2).long()
+ coco_mask = self.label_trans.mc2coco(mc_mask) - 1
+ coco_mask[coco_mask < 0] = 183
+
+ if not deterministic:
+ # Stochastic mapping
+ dice = torch.rand(1).item()
+ if dice > 0.5 and dice < 0.9:
+ coco_mask[coco_mask == self.label_trans.gglbl2ggid('sky')] = self.label_trans.gglbl2ggid('clouds')
+ elif dice >= 0.9:
+ coco_mask[coco_mask == self.label_trans.gglbl2ggid('sky')] = self.label_trans.gglbl2ggid('fog')
+ dice = torch.rand(1).item()
+ if dice > 0.33 and dice < 0.66:
+ coco_mask[coco_mask == self.label_trans.gglbl2ggid('water')] = self.label_trans.gglbl2ggid('sea')
+ elif dice >= 0.66:
+ coco_mask[coco_mask == self.label_trans.gglbl2ggid('water')] = self.label_trans.gglbl2ggid('river')
+
+ fake_masks = torch.zeros([coco_mask.size(0), 185, coco_mask.size(2), coco_mask.size(3)],
+ dtype=torch.half, device=voxel_id.device)
+ fake_masks.scatter_(1, coco_mask, 1.0)
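+ # scatter_ writes 1.0 into the channel selected by coco_mask at every pixel,
+ # turning the index map into a one-hot (N x 185 x H x W) label map for the
+ # img2img (GauGAN) network below.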
+
+ if self.use_label_smooth_pgt:
+ fake_masks = mc_utils.segmask_smooth(fake_masks, kernel_size=self.label_smooth_dia)
+ if self.pad > 0:
+ fake_masks = fake_masks[:, :, self.pad//2:-self.pad//2, self.pad//2:-self.pad//2]
+
+ # Generate pseudo GT using GauGAN.
+ if resize_512:
+ fake_masks_512 = F.interpolate(fake_masks, size=[512, 512], mode='nearest')
+ else:
+ fake_masks_512 = fake_masks
+ pseudo_real_img = pseudo_gen(fake_masks_512, z=z, style_img=style_img)
+
+ # NaN/Inf guard. NaN can occur on Volta GPUs.
+ nan_mask = torch.isnan(pseudo_real_img)
+ inf_mask = torch.isinf(pseudo_real_img)
+ pseudo_real_img[nan_mask | inf_mask] = 0.0
+ if resize_512:
+ pseudo_real_img = F.interpolate(
+ pseudo_real_img, size=[fake_masks.size(2), fake_masks.size(3)], mode='area')
+ pseudo_real_img = torch.clamp(pseudo_real_img, -1, 1)
+
+ return pseudo_real_img, fake_masks
+
+ def sample_camera(self, data, pseudo_gen):
+ r"""Sample camera randomly and precompute everything used by both Gen and Dis.
+
+ Args:
+ data (dict):
+ images (N x 3 x H x W tensor) : Real images
+ label (N x C2 x H x W tensor) : Segmentation map
+ pseudo_gen (callable): Function converting mask to image using img2img network.
+ Returns:
+ ret (dict):
+ voxel_id (N x H x W x max_samples x 1 tensor): IDs of intersected tensors along each ray.
+ depth2 (N x 2 x H x W x max_samples x 1 tensor): Depths of entrance and exit points for each ray-voxel
+ intersection.
+ raydirs (N x H x W x 1 x 3 tensor): The direction of each ray.
+ cam_ori_t (N x 3 tensor): Camera origins.
+ pseudo_real_img (N x 3 x H x W tensor): Pseudo-ground truth image.
+ real_masks (N x C3 x H x W tensor): One-hot segmentation map for real images, with translated labels.
+ fake_masks (N x C3 x H x W tensor): One-hot segmentation map for sampled camera views.
+ """
+ device = torch.device('cuda')
+ batch_size = data['images'].size(0)
+ # ================ Assemble a batch ==================
+ # Requires: voxel_id, depth2, raydirs, cam_ori_t.
+ voxel_id, depth2, raydirs, cam_ori_t, _ = self._get_batch(batch_size, device)
+ ret = {'voxel_id': voxel_id, 'depth2': depth2, 'raydirs': raydirs, 'cam_ori_t': cam_ori_t}
+
+ if pseudo_gen is not None:
+ pseudo_real_img, _ = self.get_pseudo_gt(pseudo_gen, voxel_id)
+ ret['pseudo_real_img'] = pseudo_real_img.float()
+
+ # =============== Mask translation ================
+ real_masks = data['label']
+ if self.reduced_label_set:
+ # Translate fake mask (directly from mcid).
+ # convert unrecognized labels to 'dirt'.
+ # N C H W [1 1 80 80]
+ reduce_fake_mask = self.label_trans.mc2reduced(
+ voxel_id[:, :, :, 0, :].permute(0, 3, 1, 2).long(), ign2dirt=True)
+ reduce_fake_mask_onehot = torch.zeros([
+ reduce_fake_mask.size(0), self.num_reduced_labels, reduce_fake_mask.size(2), reduce_fake_mask.size(3)],
+ dtype=torch.float, device=device)
+ reduce_fake_mask_onehot.scatter_(1, reduce_fake_mask, 1.0)
+ fake_masks = reduce_fake_mask_onehot
+ if self.pad != 0:
+ fake_masks = fake_masks[:, :, self.pad//2:-self.pad//2, self.pad//2:-self.pad//2]
+
+ # Translate real mask (data['label']), which is onehot.
+ real_masks_idx = torch.argmax(real_masks, dim=1, keepdim=True)
+ real_masks_idx[real_masks_idx > 182] = 182
+
+ reduced_real_mask = self.label_trans.coco2reduced(real_masks_idx)
+ reduced_real_mask_onehot = torch.zeros([
+ reduced_real_mask.size(0), self.num_reduced_labels, reduced_real_mask.size(2),
+ reduced_real_mask.size(3)], dtype=torch.float, device=device)
+ reduced_real_mask_onehot.scatter_(1, reduced_real_mask, 1.0)
+ real_masks = reduced_real_mask_onehot
+
+ # Mask smoothing.
+ if self.use_label_smooth:
+ fake_masks = mc_utils.segmask_smooth(fake_masks, kernel_size=self.label_smooth_dia)
+ if self.use_label_smooth_real:
+ real_masks = mc_utils.segmask_smooth(real_masks, kernel_size=self.label_smooth_dia)
+
+ ret['real_masks'] = real_masks
+ ret['fake_masks'] = fake_masks
+
+ return ret
+
+ def forward(self, data, random_style=False):
+ r"""GANcraft Generator forward.
+
+ Args:
+ data (dict):
+ images (N x 3 x H x W tensor) : Real images
+ voxel_id (N x H x W x max_samples x 1 tensor): IDs of intersected tensors along each ray.
+ depth2 (N x 2 x H x W x max_samples x 1 tensor): Depths of entrance and exit points for each ray-voxel
+ intersection.
+ raydirs (N x H x W x 1 x 3 tensor): The direction of each ray.
+ cam_ori_t (N x 3 tensor): Camera origins.
+ random_style (bool): Whether to sample a random style vector.
+ Returns:
+ output (dict):
+ fake_images (N x 3 x H x W tensor): fake images
+ mu (N x C1 tensor): mean vectors
+ logvar (N x C1 tensor): log-variance vectors
+ """
+ device = torch.device('cuda')
+ batch_size = data['images'].size(0)
+
+ # ================ Assemble a batch ==================
+ # Requires: voxel_id, depth2, raydirs, cam_ori_t.
+ voxel_id, depth2, raydirs, cam_ori_t = data['voxel_id'], data['depth2'], data['raydirs'], data['cam_ori_t']
+ if 'pseudo_real_img' in data:
+ pseudo_real_img = data['pseudo_real_img']
+
+ z, mu, logvar = None, None, None
+ if random_style:
+ if self.style_dims > 0:
+ z = torch.randn(batch_size, self.style_dims, dtype=torch.float32, device=device)
+ else:
+ if self.style_encoder is None:
+ # ================ Get Style Code =================
+ if self.style_dims > 0:
+ z = torch.randn(batch_size, self.style_dims, dtype=torch.float32, device=device)
+ else:
+ mu, logvar, z = self.style_encoder(pseudo_real_img)
+
+ # ================ Network Forward ================
+ # Forward StyleNet
+ if self.style_net is not None:
+ z = self.style_net(z)
+
+ # Forward per-pixel net.
+ net_out, new_dists, weights, total_weights_raw, rand_depth, net_out_s, net_out_c, skynet_out_c, nosky_mask, \
+ sky_mask, sky_only_mask, new_idx = self._forward_perpix(
+ self.blk_feats, voxel_id, depth2, raydirs, cam_ori_t, z)
+
+ # Forward global net.
+ fake_images, fake_images_raw = self._forward_global(net_out, z)
+ if self.pad != 0:
+ fake_images = fake_images[:, :, self.pad//2:-self.pad//2, self.pad//2:-self.pad//2]
+
+ # =============== Arrange Return Values ================
+ output = {}
+ output['fake_images'] = fake_images
+ output['mu'] = mu
+ output['logvar'] = logvar
+ return output
+
+ def inference(self,
+ output_dir,
+ camera_mode,
+ style_img_path=None,
+ seed=1,
+ pad=30,
+ num_samples=40,
+ num_blocks_early_stop=6,
+ sample_depth=3,
+ tile_size=128,
+ resolution_hw=[540, 960],
+ cam_ang=72,
+ cam_maxstep=10):
+ r"""Compute result images according to the provided camera trajectory and save the results in the specified
+ folder. The full image is evaluated in multiple tiles to save memory.
+
+ Args:
+ output_dir (str): Where should the results be stored.
+ camera_mode (int): Which camera trajectory to use.
+ style_img_path (str): Path to the style-conditioning image.
+ seed (int): Random seed (controls style when style_img_path is not specified).
+ pad (int): Pixels to remove from the image tiles before stitching. Should be equal to or larger than
+ the receptive field of the CNN to avoid border artifacts.
+ num_samples (int): Number of samples per ray (different from training).
+ num_blocks_early_stop (int): Max number of intersected boxes per ray before stopping
+ (different from training).
+ sample_depth (float): Max distance traveled through boxes before stopping (different from training).
+ tile_size (int): Max size of a tile in pixels.
+ resolution_hw (list [H, W]): Resolution of the output image.
+ cam_ang (float): Horizontal FOV of the camera (may be adjusted by the camera controller).
+ cam_maxstep (int): Number of frames sampled from the camera trajectory.
+ """
+
+ def write_img(path, img, rgb_input=False):
+ img = ((img*0.5+0.5)*255).detach().cpu().numpy().astype(np.uint8)
+ img = img[0].transpose(1, 2, 0)
+ if rgb_input:
+ img = img[..., [2, 1, 0]]
+ cv2.imwrite(path, img, [cv2.IMWRITE_PNG_COMPRESSION, 4])
+ return img[..., ::-1]
+
+ def read_img(path):
+ img = cv2.imread(path).astype(np.float32)[..., [2, 1, 0]].transpose(2, 0, 1) / 255
+ img = img * 2 - 1
+ img = torch.from_numpy(img)
+ return img
+
+ print('Saving to', output_dir)
+
+ # Use provided random seed.
+ device = torch.device('cuda')
+ rng_cuda = torch.Generator(device=device)
+ rng_cuda = rng_cuda.manual_seed(seed)
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed(seed)
+
+ self.pad = pad
+ self.num_samples = num_samples
+ self.num_blocks_early_stop = num_blocks_early_stop
+ self.sample_depth = sample_depth
+
+ self.coarse_deterministic_sampling = True
+ self.crop_size = resolution_hw
+ self.cam_res = [self.crop_size[0]+self.pad, self.crop_size[1]+self.pad]
+ self.use_label_smooth_pgt = False
+
+ # Make output dirs.
+ gancraft_outputs_dir = os.path.join(output_dir, 'gancraft_outputs')
+ os.makedirs(gancraft_outputs_dir, exist_ok=True)
+ vis_masks_dir = os.path.join(output_dir, 'vis_masks')
+ os.makedirs(vis_masks_dir, exist_ok=True)
+ fout = imageio.get_writer(gancraft_outputs_dir + '.mp4', fps=1)
+ fout_cat = imageio.get_writer(gancraft_outputs_dir + '-vis_masks.mp4', fps=1)
+
+ evalcamctl = camctl.EvalCameraController(
+ self.voxel, maxstep=cam_maxstep, pattern=camera_mode, cam_ang=cam_ang,
+ smooth_decay_multiplier=150/cam_maxstep)
+
+ # Get output style.
+ if style_img_path is None:
+ z = torch.empty(1, self.style_dims, dtype=torch.float32, device=device)
+ z.normal_(generator=rng_cuda)
+ else:
+ style_img = read_img(style_img_path)
+ style_img = style_img.to(device).unsqueeze(0)
+ mu, logvar, z = self.style_encoder(style_img)
+ z = self.style_net(z)
+
+ # Generate required output images.
+ for id, (cam_ori_t, cam_dir_t, cam_up_t, cam_f) in enumerate(evalcamctl):
+ print('Rendering frame', id)
+ cam_f = cam_f * (self.crop_size[1]-1)  # So that the field of view does not depend on the padding.
+ cam_c = [(self.cam_res[0]-1)/2, (self.cam_res[1]-1)/2]
+
+ voxel_id, depth2, raydirs = voxlib.ray_voxel_intersection_perspective(
+ self.voxel.voxel_t, cam_ori_t, cam_dir_t, cam_up_t, cam_f, cam_c, self.cam_res,
+ self.num_blocks_early_stop)
+
+ voxel_id = voxel_id.unsqueeze(0)
+ depth2 = depth2.unsqueeze(0)
+ raydirs = raydirs.unsqueeze(0)
+ cam_ori_t = cam_ori_t.unsqueeze(0).to(device)
+
+ # Save 3D voxel rendering.
+ mc_rgb = self.label_trans.mc_color(voxel_id[0, :, :, 0, 0].cpu().numpy())
+ # Diffused shading, co-located light.
+ first_intersection_depth = depth2[:, 0, :, :, 0, None, :] # [1, 542, 542, 1, 1].
+ first_intersection_point = raydirs * first_intersection_depth + cam_ori_t[:, None, None, None, :]
+ fip_local_coords = torch.remainder(first_intersection_point, 1.0)
+ fip_wall_proximity = torch.minimum(fip_local_coords, 1.0-fip_local_coords)
+ fip_wall_orientation = torch.argmin(fip_wall_proximity, dim=-1, keepdim=False)
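+ # The axis along which the intersection point is closest to a voxel wall
+ # identifies the face that was hit; the lut below turns that axis index
+ # into a unit surface normal used for the co-located-light diffuse shading.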
+ # 0: [1,0,0]; 1: [0,1,0]; 2: [0,0,1]
+ lut = torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=torch.float32,
+ device=fip_wall_orientation.device)
+ fip_normal = lut[fip_wall_orientation] # [1, 542, 542, 1, 3]
+ diffuse_shade = torch.abs(torch.sum(fip_normal * raydirs, dim=-1))
+
+ mc_rgb = (mc_rgb.astype(np.float32) / 255) ** 2.2
+ mc_rgb = mc_rgb * diffuse_shade[0, :, :, :].cpu().numpy()
+ mc_rgb = (mc_rgb ** (1/2.2)) * 255
+ mc_rgb = mc_rgb.astype(np.uint8)
+ if self.pad > 0:
+ mc_rgb = mc_rgb[self.pad//2:-self.pad//2, self.pad//2:-self.pad//2]
+ cv2.imwrite(os.path.join(vis_masks_dir, '{:05d}.png'.format(id)), mc_rgb, [cv2.IMWRITE_PNG_COMPRESSION, 4])
+
+ # Tiled eval of GANcraft.
+ voxel_id_all = voxel_id
+ depth2_all = depth2
+ raydirs_all = raydirs
+
+ # Evaluate sky in advance to get a consistent sky in the semi-transparent region.
+ if self.sky_global_avgpool:
+ sky_raydirs_in = raydirs.expand(-1, -1, -1, 1, -1).contiguous()
+ sky_raydirs_in = voxlib.positional_encoding(
+ sky_raydirs_in, self.pe_params_sky[0], -1, self.pe_params_sky[1])
+ skynet_out_c = self.sky_net(sky_raydirs_in, z)
+ sky_avg = torch.mean(skynet_out_c, dim=[1, 2], keepdim=True)
+ self.sky_avg = sky_avg
+
+ num_strips_h = (self.cam_res[0]-self.pad+tile_size-1)//tile_size
+ num_strips_w = (self.cam_res[1]-self.pad+tile_size-1)//tile_size
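+ # Ceil-divide the unpadded output (cam_res minus pad) by the tile size, e.g.
+ # a 540x960 crop with pad=30 and tile_size=128 gives 5 x 8 tiles. Each tile
+ # is rendered with the padding included and cropped afterwards, so the
+ # stitched tiles line up without seams.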
+
+ fake_images_chunks_v = []
+ # For each horizontal strip.
+ for strip_id_h in range(num_strips_h):
+ strip_begin_h = strip_id_h * tile_size
+ strip_end_h = np.minimum(strip_id_h * tile_size + tile_size + self.pad, self.cam_res[0])
+ # For each vertical strip.
+ fake_images_chunks_h = []
+ for strip_id_w in range(num_strips_w):
+ strip_begin_w = strip_id_w * tile_size
+ strip_end_w = np.minimum(strip_id_w * tile_size + tile_size + self.pad, self.cam_res[1])
+
+ voxel_id = voxel_id_all[:, strip_begin_h:strip_end_h, strip_begin_w:strip_end_w, :, :]
+ depth2 = depth2_all[:, :, strip_begin_h:strip_end_h, strip_begin_w:strip_end_w, :, :]
+ raydirs = raydirs_all[:, strip_begin_h:strip_end_h, strip_begin_w:strip_end_w, :, :]
+
+ net_out, new_dists, weights, total_weights_raw, rand_depth, net_out_s, net_out_c, skynet_out_c, \
+ nosky_mask, sky_mask, sky_only_mask, new_idx = self._forward_perpix(
+ self.blk_feats, voxel_id, depth2, raydirs, cam_ori_t, z)
+ fake_images, _ = self._forward_global(net_out, z)
+
+ if self.pad != 0:
+ fake_images = fake_images[:, :, self.pad//2:-self.pad//2, self.pad//2:-self.pad//2]
+ fake_images_chunks_h.append(fake_images)
+ fake_images_h = torch.cat(fake_images_chunks_h, dim=-1)
+ fake_images_chunks_v.append(fake_images_h)
+ fake_images = torch.cat(fake_images_chunks_v, dim=-2)
+ rgb = write_img(os.path.join(gancraft_outputs_dir,
+ '{:05d}.png'.format(id)), fake_images, rgb_input=True)
+ fout.append_data(rgb)
+ fout_cat.append_data(np.concatenate((mc_rgb[..., ::-1], rgb), axis=1))
+ fout.close()
+ fout_cat.close()
diff --git a/imaginaire/generators/gancraft_base.py b/imaginaire/generators/gancraft_base.py
new file mode 100644
index 0000000000000000000000000000000000000000..ef164b67053af5d228bda6c9aea75c261d7b114f
--- /dev/null
+++ b/imaginaire/generators/gancraft_base.py
@@ -0,0 +1,603 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import functools
+import re
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from imaginaire.layers import Conv2dBlock, LinearBlock
+from imaginaire.model_utils.gancraft.layers import AffineMod, ModLinear
+import imaginaire.model_utils.gancraft.mc_utils as mc_utils
+import imaginaire.model_utils.gancraft.voxlib as voxlib
+from imaginaire.utils.distributed import master_only_print as print
+
+
+class RenderMLP(nn.Module):
+ r""" MLP with affine modulation."""
+
+ def __init__(self, in_channels, style_dim, viewdir_dim, mask_dim=680,
+ out_channels_s=1, out_channels_c=3, hidden_channels=256,
+ use_seg=True):
+ super(RenderMLP, self).__init__()
+
+ self.use_seg = use_seg
+ if self.use_seg:
+ self.fc_m_a = nn.Linear(mask_dim, hidden_channels, bias=False)
+
+ self.fc_viewdir = None
+ if viewdir_dim > 0:
+ self.fc_viewdir = nn.Linear(viewdir_dim, hidden_channels, bias=False)
+
+ self.fc_1 = nn.Linear(in_channels, hidden_channels)
+
+ self.fc_2 = ModLinear(hidden_channels, hidden_channels, style_dim, bias=False, mod_bias=True, output_mode=True)
+ self.fc_3 = ModLinear(hidden_channels, hidden_channels, style_dim, bias=False, mod_bias=True, output_mode=True)
+ self.fc_4 = ModLinear(hidden_channels, hidden_channels, style_dim, bias=False, mod_bias=True, output_mode=True)
+
+ self.fc_sigma = nn.Linear(hidden_channels, out_channels_s)
+
+ if viewdir_dim > 0:
+ self.fc_5 = nn.Linear(hidden_channels, hidden_channels, bias=False)
+ self.mod_5 = AffineMod(hidden_channels, style_dim, mod_bias=True)
+ else:
+ self.fc_5 = ModLinear(hidden_channels, hidden_channels, style_dim,
+ bias=False, mod_bias=True, output_mode=True)
+ self.fc_6 = ModLinear(hidden_channels, hidden_channels, style_dim, bias=False, mod_bias=True, output_mode=True)
+ self.fc_out_c = nn.Linear(hidden_channels, out_channels_c)
+
+ self.act = nn.LeakyReLU(negative_slope=0.2)
+
+ def forward(self, x, raydir, z, m):
+ r""" Forward network
+
+ Args:
+ x (N x H x W x M x in_channels tensor): Projected features.
+ raydir (N x H x W x 1 x viewdir_dim tensor): Ray directions.
+ z (N x style_dim tensor): Style codes.
+ m (N x H x W x M x mask_dim tensor): One-hot segmentation maps.
+ """
+ b, h, w, n, _ = x.size()
+ z = z[:, None, None, None, :]
+
+ f = self.fc_1(x)
+ if self.use_seg:
+ f = f + self.fc_m_a(m)
+ # Common MLP
+ f = self.act(f)
+ f = self.act(self.fc_2(f, z))
+ f = self.act(self.fc_3(f, z))
+ f = self.act(self.fc_4(f, z))
+
+ # Sigma MLP
+ sigma = self.fc_sigma(f)
+
+ # Color MLP
+ if self.fc_viewdir is not None:
+ f = self.fc_5(f)
+ f = f + self.fc_viewdir(raydir)
+ f = self.act(self.mod_5(f, z))
+ else:
+ f = self.act(self.fc_5(f, z))
+ f = self.act(self.fc_6(f, z))
+ c = self.fc_out_c(f)
+ return sigma, c
+
+
+class StyleMLP(nn.Module):
+ r"""MLP converting style code to intermediate style representation."""
+
+ def __init__(self, style_dim, out_dim, hidden_channels=256, leaky_relu=True, num_layers=5, normalize_input=True,
+ output_act=True):
+ super(StyleMLP, self).__init__()
+
+ self.normalize_input = normalize_input
+ self.output_act = output_act
+ fc_layers = []
+ fc_layers.append(nn.Linear(style_dim, hidden_channels, bias=True))
+ for i in range(num_layers-1):
+ fc_layers.append(nn.Linear(hidden_channels, hidden_channels, bias=True))
+ self.fc_layers = nn.ModuleList(fc_layers)
+
+ self.fc_out = nn.Linear(hidden_channels, out_dim, bias=True)
+
+ if leaky_relu:
+ self.act = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+ else:
+ self.act = functools.partial(F.relu, inplace=True)
+
+ def forward(self, z):
+ r""" Forward network
+
+ Args:
+ z (N x style_dim tensor): Style codes.
+ """
+ if self.normalize_input:
+ z = F.normalize(z, p=2, dim=-1)
+ for fc_layer in self.fc_layers:
+ z = self.act(fc_layer(z))
+ z = self.fc_out(z)
+ if self.output_act:
+ z = self.act(z)
+ return z
+
+
+class SKYMLP(nn.Module):
+ r"""MLP converting ray directions to sky features."""
+
+ def __init__(self, in_channels, style_dim, out_channels_c=3,
+ hidden_channels=256, leaky_relu=True):
+ super(SKYMLP, self).__init__()
+ self.fc_z_a = nn.Linear(style_dim, hidden_channels, bias=False)
+
+ self.fc1 = nn.Linear(in_channels, hidden_channels)
+ self.fc2 = nn.Linear(hidden_channels, hidden_channels)
+ self.fc3 = nn.Linear(hidden_channels, hidden_channels)
+ self.fc4 = nn.Linear(hidden_channels, hidden_channels)
+ self.fc5 = nn.Linear(hidden_channels, hidden_channels)
+
+ self.fc_out_c = nn.Linear(hidden_channels, out_channels_c)
+
+ if leaky_relu:
+ self.act = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+ else:
+ self.act = functools.partial(F.relu, inplace=True)
+
+ def forward(self, x, z):
+ r"""Forward network
+
+ Args:
+ x (... x in_channels tensor): Ray direction embeddings.
+ z (... x style_dim tensor): Style codes.
+ """
+
+ z = self.fc_z_a(z)
+ while z.dim() < x.dim():
+ z = z.unsqueeze(1)
+
+ y = self.act(self.fc1(x) + z)
+ y = self.act(self.fc2(y))
+ y = self.act(self.fc3(y))
+ y = self.act(self.fc4(y))
+ y = self.act(self.fc5(y))
+ c = self.fc_out_c(y)
+
+ return c
+
+
+class RenderCNN(nn.Module):
+ r"""CNN converting intermediate feature map to final image."""
+
+ def __init__(self, in_channels, style_dim, hidden_channels=256,
+ leaky_relu=True):
+ super(RenderCNN, self).__init__()
+ self.fc_z_cond = nn.Linear(style_dim, 2 * 2 * hidden_channels)
+
+ self.conv1 = nn.Conv2d(in_channels, hidden_channels, 1, stride=1, padding=0)
+ self.conv2a = nn.Conv2d(hidden_channels, hidden_channels, 3, stride=1, padding=1)
+ self.conv2b = nn.Conv2d(hidden_channels, hidden_channels, 3, stride=1, padding=1, bias=False)
+
+ self.conv3a = nn.Conv2d(hidden_channels, hidden_channels, 3, stride=1, padding=1)
+ self.conv3b = nn.Conv2d(hidden_channels, hidden_channels, 3, stride=1, padding=1, bias=False)
+
+ self.conv4a = nn.Conv2d(hidden_channels, hidden_channels, 1, stride=1, padding=0)
+ self.conv4b = nn.Conv2d(hidden_channels, hidden_channels, 1, stride=1, padding=0)
+
+ self.conv4 = nn.Conv2d(hidden_channels, 3, 1, stride=1, padding=0)
+
+ if leaky_relu:
+ self.act = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+ else:
+ self.act = functools.partial(F.relu, inplace=True)
+
+ def modulate(self, x, w, b):
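+        # FiLM-style channel-wise modulation: scale by (w + 1) and shift by b,
+        # where w and b are predicted from the style code by fc_z_cond.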
+ w = w[..., None, None]
+ b = b[..., None, None]
+ return x * (w+1) + b
+
+ def forward(self, x, z):
+ r"""Forward network.
+
+ Args:
+ x (N x in_channels x H x W tensor): Intermediate feature map
+ z (N x style_dim tensor): Style codes.
+ """
+ z = self.fc_z_cond(z)
+ adapt = torch.chunk(z, 2 * 2, dim=-1)
+
+ y = self.act(self.conv1(x))
+
+ y = y + self.conv2b(self.act(self.conv2a(y)))
+ y = self.act(self.modulate(y, adapt[0], adapt[1]))
+
+ y = y + self.conv3b(self.act(self.conv3a(y)))
+ y = self.act(self.modulate(y, adapt[2], adapt[3]))
+
+ y = y + self.conv4b(self.act(self.conv4a(y)))
+ y = self.act(y)
+
+ y = self.conv4(y)
+
+ return y
+
+
+class StyleEncoder(nn.Module):
+ r"""Style Encoder constructor.
+
+ Args:
+ style_enc_cfg (obj): Style encoder definition file.
+ """
+
+ def __init__(self, style_enc_cfg):
+ super(StyleEncoder, self).__init__()
+ input_image_channels = style_enc_cfg.input_image_channels
+ num_filters = style_enc_cfg.num_filters
+ kernel_size = style_enc_cfg.kernel_size
+ padding = int(np.ceil((kernel_size - 1.0) / 2))
+ style_dims = style_enc_cfg.style_dims
+ weight_norm_type = style_enc_cfg.weight_norm_type
+ self.no_vae = getattr(style_enc_cfg, 'no_vae', False)
+ activation_norm_type = 'none'
+ nonlinearity = 'leakyrelu'
+ base_conv2d_block = \
+ functools.partial(Conv2dBlock,
+ kernel_size=kernel_size,
+ stride=2,
+ padding=padding,
+ weight_norm_type=weight_norm_type,
+ activation_norm_type=activation_norm_type,
+ # inplace_nonlinearity=True,
+ nonlinearity=nonlinearity)
+ self.layer1 = base_conv2d_block(input_image_channels, num_filters)
+ self.layer2 = base_conv2d_block(num_filters * 1, num_filters * 2)
+ self.layer3 = base_conv2d_block(num_filters * 2, num_filters * 4)
+ self.layer4 = base_conv2d_block(num_filters * 4, num_filters * 8)
+ self.layer5 = base_conv2d_block(num_filters * 8, num_filters * 8)
+ self.layer6 = base_conv2d_block(num_filters * 8, num_filters * 8)
+ self.fc_mu = LinearBlock(num_filters * 8 * 4 * 4, style_dims)
+ if not self.no_vae:
+ self.fc_var = LinearBlock(num_filters * 8 * 4 * 4, style_dims)
+
+ def forward(self, input_x):
+ r"""SPADE Style Encoder forward.
+
+ Args:
+ input_x (N x 3 x H x W tensor): input images.
+ Returns:
+ mu (N x C tensor): Mean vectors.
+ logvar (N x C tensor): Log-variance vectors.
+ z (N x C tensor): Style code vectors.
+ """
+ if input_x.size(2) != 256 or input_x.size(3) != 256:
+ input_x = F.interpolate(input_x, size=(256, 256), mode='bilinear')
+ x = self.layer1(input_x)
+ x = self.layer2(x)
+ x = self.layer3(x)
+ x = self.layer4(x)
+ x = self.layer5(x)
+ x = self.layer6(x)
+ x = x.view(x.size(0), -1)
+ mu = self.fc_mu(x)
+ if not self.no_vae:
+ logvar = self.fc_var(x)
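+            # VAE reparameterization trick: z = mu + eps * std keeps the sample differentiable w.r.t. mu and logvar.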
+ std = torch.exp(0.5 * logvar)
+ eps = torch.randn_like(std)
+ z = eps.mul(std) + mu
+ else:
+ z = mu
+ logvar = torch.zeros_like(mu)
+ return mu, logvar, z
+
+
+class Base3DGenerator(nn.Module):
+ r"""Minecraft 3D generator constructor.
+
+ Args:
+ gen_cfg (obj): Generator definition part of the yaml config file.
+ data_cfg (obj): Data definition part of the yaml config file.
+ """
+
+ def __init__(self, gen_cfg, data_cfg):
+ super(Base3DGenerator, self).__init__()
+ print('Base3DGenerator initialization.')
+
+ # ---------------------- Main Network ------------------------
+ # Exclude some of the features from positional encoding
+ self.pe_no_pe_feat_dim = getattr(gen_cfg, 'pe_no_pe_feat_dim', 0)
+
+ # blk_feat passes through PE
+ input_dim = (gen_cfg.blk_feat_dim-self.pe_no_pe_feat_dim)*(gen_cfg.pe_lvl_feat*2) + self.pe_no_pe_feat_dim
+ if (gen_cfg.pe_incl_orig_feat):
+ input_dim += (gen_cfg.blk_feat_dim-self.pe_no_pe_feat_dim)
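+        # Worked example (illustrative values only, not necessarily the shipped config): with
+        # blk_feat_dim=64, pe_no_pe_feat_dim=0, pe_lvl_feat=4 and pe_incl_orig_feat=True,
+        # input_dim = 64 * (4 * 2) + 64 = 576.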
+ print('[Base3DGenerator] Expected input dimensions: ', input_dim)
+ self.input_dim = input_dim
+
+ self.mlp_model_kwargs = gen_cfg.mlp_model_kwargs
+ self.pe_lvl_localcoords = getattr(gen_cfg, 'pe_lvl_localcoords', 0)
+ if self.pe_lvl_localcoords > 0:
+ self.mlp_model_kwargs['poscode_dim'] = self.pe_lvl_localcoords * 2 * 3
+
+ # Set pe_lvl_raydir=0 and pe_incl_orig_raydir=False to disable view direction input
+ input_dim_viewdir = 3*(gen_cfg.pe_lvl_raydir*2)
+ if (gen_cfg.pe_incl_orig_raydir):
+ input_dim_viewdir += 3
+ print('[Base3DGenerator] Expected viewdir input dimensions: ', input_dim_viewdir)
+ self.input_dim_viewdir = input_dim_viewdir
+
+ self.pe_params = [gen_cfg.pe_lvl_feat, gen_cfg.pe_incl_orig_feat,
+ gen_cfg.pe_lvl_raydir, gen_cfg.pe_incl_orig_raydir]
+
+ # Style input dimension
+ style_dims = gen_cfg.style_dims
+ self.style_dims = style_dims
+ interm_style_dims = getattr(gen_cfg, 'interm_style_dims', style_dims)
+ self.interm_style_dims = interm_style_dims
+ # ---------------------- Style MLP --------------------------
+ self.style_net = globals()[gen_cfg.stylenet_model](
+ style_dims, interm_style_dims, **gen_cfg.stylenet_model_kwargs)
+
+ # number of output channels for MLP (before blending)
+ final_feat_dim = getattr(gen_cfg, 'final_feat_dim', 16)
+ self.final_feat_dim = final_feat_dim
+
+ # ----------------------- Sky Network -------------------------
+ sky_input_dim_base = 3
+ # Dedicated sky network input dimensions
+ sky_input_dim = sky_input_dim_base*(gen_cfg.pe_lvl_raydir_sky*2)
+ if (gen_cfg.pe_incl_orig_raydir_sky):
+ sky_input_dim += sky_input_dim_base
+ print('[Base3DGenerator] Expected sky input dimensions: ', sky_input_dim)
+ self.pe_params_sky = [gen_cfg.pe_lvl_raydir_sky, gen_cfg.pe_incl_orig_raydir_sky]
+ self.sky_net = SKYMLP(sky_input_dim, style_dim=interm_style_dims, out_channels_c=final_feat_dim)
+
+ # ----------------------- Style Encoder -------------------------
+ style_enc_cfg = getattr(gen_cfg, 'style_enc', None)
+ setattr(style_enc_cfg, 'input_image_channels', 3)
+ setattr(style_enc_cfg, 'style_dims', gen_cfg.style_dims)
+ self.style_encoder = StyleEncoder(style_enc_cfg)
+
+ # ---------------------- Ray Caster -------------------------
+ self.num_blocks_early_stop = gen_cfg.num_blocks_early_stop
+ self.num_samples = gen_cfg.num_samples
+ self.sample_depth = gen_cfg.sample_depth
+ self.coarse_deterministic_sampling = getattr(gen_cfg, 'coarse_deterministic_sampling', True)
+ self.sample_use_box_boundaries = getattr(gen_cfg, 'sample_use_box_boundaries', True)
+
+ # ---------------------- Blender -------------------------
+ self.raw_noise_std = getattr(gen_cfg, 'raw_noise_std', 0.0)
+ self.dists_scale = getattr(gen_cfg, 'dists_scale', 0.25)
+ self.clip_feat_map = getattr(gen_cfg, 'clip_feat_map', True)
+ self.keep_sky_out = getattr(gen_cfg, 'keep_sky_out', False)
+ self.keep_sky_out_avgpool = getattr(gen_cfg, 'keep_sky_out_avgpool', False)
+ keep_sky_out_learnbg = getattr(gen_cfg, 'keep_sky_out_learnbg', False)
+ self.sky_global_avgpool = getattr(gen_cfg, 'sky_global_avgpool', False)
+ if self.keep_sky_out:
+ self.sky_replace_color = None
+ if keep_sky_out_learnbg:
+ sky_replace_color = torch.zeros([final_feat_dim])
+ sky_replace_color.requires_grad = True
+ self.sky_replace_color = torch.nn.Parameter(sky_replace_color)
+ # ---------------------- render_cnn -------------------------
+ self.denoiser = RenderCNN(final_feat_dim, style_dim=interm_style_dims)
+ self.pad = gen_cfg.pad
+
+ def get_param_groups(self, cfg_opt):
+ print('[Generator] get_param_groups')
+
+ if hasattr(cfg_opt, 'ignore_parameters'):
+ print('[Generator::get_param_groups] [x]: ignored.')
+ optimize_parameters = []
+ for k, x in self.named_parameters():
+ match = False
+ for m in cfg_opt.ignore_parameters:
+ if re.match(m, k) is not None:
+ match = True
+ print(' [x]', k)
+ break
+ if match is False:
+ print(' [v]', k)
+ optimize_parameters.append(x)
+ else:
+ optimize_parameters = self.parameters()
+
+ param_groups = []
+ param_groups.append({'params': optimize_parameters})
+
+ if hasattr(cfg_opt, 'param_groups'):
+ optimized_param_names = []
+ all_param_names = [k for k, v in self.named_parameters()]
+ param_groups = []
+ for k, v in cfg_opt.param_groups.items():
+ print('[Generator::get_param_groups] Adding param group from config:', k, v)
+ params = getattr(self, k)
+ named_parameters = [k]
+ if issubclass(type(params), nn.Module):
+ named_parameters = [k+'.'+pname for pname, _ in params.named_parameters()]
+ params = params.parameters()
+ param_groups.append({'params': params, **v})
+ optimized_param_names.extend(named_parameters)
+
+ print('[Generator::get_param_groups] UNOPTIMIZED PARAMETERS:\n ',
+ set(all_param_names) - set(optimized_param_names))
+
+ return param_groups
+
+ def _forward_perpix_sub(self, blk_feats, worldcoord2, raydirs_in, z, mc_masks_onehot=None):
+ r"""Forwarding the MLP.
+
+ Args:
+ blk_feats (K x C1 tensor): Sparse block features.
+ worldcoord2 (N x H x W x L x 3 tensor): 3D world coordinates of sampled points.
+ raydirs_in (N x H x W x 1 x C2 tensor or None): ray direction embeddings.
+ z (N x C3 tensor): Intermediate style vectors.
+ mc_masks_onehot (N x H x W x L x C4): One-hot segmentation maps.
+ Returns:
+ net_out_s (N x H x W x L x 1 tensor): Opacities.
+ net_out_c (N x H x W x L x C5 tensor): Color embeddings.
+ """
+ proj_feature = voxlib.sparse_trilinear_interp_worldcoord(
+ blk_feats, self.voxel.corner_t, worldcoord2, ign_zero=True)
+
+ render_net_extra_kwargs = {}
+ if self.pe_lvl_localcoords > 0:
+ local_coords = torch.remainder(worldcoord2, 1.0) * 2.0
+ # Scale to [0, 2], as the positional encoding function doesn't have internal x2
+ local_coords[torch.isnan(local_coords)] = 0.0
+ local_coords = local_coords.contiguous()
+ poscode = voxlib.positional_encoding(local_coords, self.pe_lvl_localcoords, -1, False)
+ render_net_extra_kwargs['poscode'] = poscode
+
+ if self.pe_params[0] == 0 and self.pe_params[1] is True: # no PE shortcut, saves ~400MB
+ feature_in = proj_feature
+ else:
+ if self.pe_no_pe_feat_dim > 0:
+ feature_in = voxlib.positional_encoding(
+ proj_feature[..., :-self.pe_no_pe_feat_dim].contiguous(), self.pe_params[0], -1, self.pe_params[1])
+ feature_in = torch.cat([feature_in, proj_feature[..., -self.pe_no_pe_feat_dim:]], dim=-1)
+ else:
+ feature_in = voxlib.positional_encoding(
+ proj_feature.contiguous(), self.pe_params[0], -1, self.pe_params[1])
+
+ net_out_s, net_out_c = self.render_net(feature_in, raydirs_in, z, mc_masks_onehot, **render_net_extra_kwargs)
+
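+        # Optionally perturb the predicted densities with Gaussian noise, as in NeRF's raw_noise_std regularizer.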
+ if self.raw_noise_std > 0.:
+ noise = torch.randn_like(net_out_s) * self.raw_noise_std
+ net_out_s = net_out_s + noise
+
+ return net_out_s, net_out_c
+
+ def _forward_perpix(self, blk_feats, voxel_id, depth2, raydirs, cam_ori_t, z):
+ r"""Sample points along rays, forwarding the per-point MLP and aggregate pixel features
+
+ Args:
+ blk_feats (K x C1 tensor): Sparse block features.
+ voxel_id (N x H x W x M x 1 tensor): Voxel ids from ray-voxel intersection test. M: num intersected voxels
+ depth2 (N x 2 x H x W x M x 1 tensor): Depths of entrance and exit points for each ray-voxel intersection.
+ raydirs (N x H x W x 1 x 3 tensor): The direction of each ray.
+ cam_ori_t (N x 3 tensor): Camera origins.
+ z (N x C3 tensor): Intermediate style vectors.
+ """
+ # Generate sky_mask; PE transform on ray direction.
+ with torch.no_grad():
+ raydirs_in = raydirs.expand(-1, -1, -1, 1, -1).contiguous()
+ if self.pe_params[2] == 0 and self.pe_params[3] is True:
+                raydirs_in = raydirs_in  # No positional encoding; use the raw ray directions as-is.
+ elif self.pe_params[2] == 0 and self.pe_params[3] is False: # Not using raydir at all
+ raydirs_in = None
+ else:
+ raydirs_in = voxlib.positional_encoding(raydirs_in, self.pe_params[2], -1, self.pe_params[3])
+
+ # sky_mask: when True, ray finally hits sky
+ sky_mask = voxel_id[:, :, :, [-1], :] == 0
+ # sky_only_mask: when True, ray hits nothing but sky
+ sky_only_mask = voxel_id[:, :, :, [0], :] == 0
+
+ with torch.no_grad():
+            # Randomly sample points along the ray
+ num_samples = self.num_samples + 1
+ if self.sample_use_box_boundaries:
+ num_samples = self.num_samples - self.num_blocks_early_stop
+
+ # 10 samples per ray + 4 intersections - 2
+ rand_depth, new_dists, new_idx = mc_utils.sample_depth_batched(
+ depth2, num_samples, deterministic=self.coarse_deterministic_sampling,
+ use_box_boundaries=self.sample_use_box_boundaries, sample_depth=self.sample_depth)
+
+ worldcoord2 = raydirs * rand_depth + cam_ori_t[:, None, None, None, :]
+
+ # Generate per-sample segmentation label
+ voxel_id_reduced = self.label_trans.mc2reduced(voxel_id, ign2dirt=True)
+ mc_masks = torch.gather(voxel_id_reduced, -2, new_idx) # B 256 256 N 1
+ mc_masks = mc_masks.long()
+ mc_masks_onehot = torch.zeros([mc_masks.size(0), mc_masks.size(1), mc_masks.size(
+ 2), mc_masks.size(3), self.num_reduced_labels], dtype=torch.float, device=voxel_id.device)
+ # mc_masks_onehot: [B H W Nlayer 680]
+ mc_masks_onehot.scatter_(-1, mc_masks, 1.0)
+
+ net_out_s, net_out_c = self._forward_perpix_sub(blk_feats, worldcoord2, raydirs_in, z, mc_masks_onehot)
+
+ # Handle sky
+ sky_raydirs_in = raydirs.expand(-1, -1, -1, 1, -1).contiguous()
+ sky_raydirs_in = voxlib.positional_encoding(sky_raydirs_in, self.pe_params_sky[0], -1, self.pe_params_sky[1])
+ skynet_out_c = self.sky_net(sky_raydirs_in, z)
+
+ # Blending
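+        # volum_rendering_relu turns the ReLU-activated densities and sample spacings into per-sample
+        # compositing weights along each ray (standard volume-rendering alpha compositing).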
+ weights = mc_utils.volum_rendering_relu(net_out_s, new_dists * self.dists_scale, dim=-2)
+
+ # If a ray exclusively hits the sky (no intersection with the voxels), set its weight to zero.
+ weights = weights * torch.logical_not(sky_only_mask).float()
+ total_weights_raw = torch.sum(weights, dim=-2, keepdim=True) # 256 256 1 1
+ total_weights = total_weights_raw
+
+ is_gnd = worldcoord2[..., [0]] <= 1.0 # Y X Z, [256, 256, 4, 3], nan < 1.0 == False
+ is_gnd = is_gnd.any(dim=-2, keepdim=True)
+ nosky_mask = torch.logical_or(torch.logical_not(sky_mask), is_gnd)
+ nosky_mask = nosky_mask.float()
+
+ # Avoid sky leakage
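+        # sky_weight is the transmittance left over after compositing the voxel samples; it is the
+        # fraction of each pixel that gets filled with the sky color.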
+ sky_weight = 1.0-total_weights
+ if self.keep_sky_out:
+ # keep_sky_out_avgpool overrides sky_replace_color
+ if self.sky_replace_color is None or self.keep_sky_out_avgpool:
+ if self.keep_sky_out_avgpool:
+ if hasattr(self, 'sky_avg'):
+ sky_avg = self.sky_avg
+ else:
+ if self.sky_global_avgpool:
+ sky_avg = torch.mean(skynet_out_c, dim=[1, 2], keepdim=True)
+ else:
+ skynet_out_c_nchw = skynet_out_c.permute(0, 4, 1, 2, 3).squeeze(-1)
+ sky_avg = F.avg_pool2d(skynet_out_c_nchw, 31, stride=1, padding=15, count_include_pad=False)
+ sky_avg = sky_avg.permute(0, 2, 3, 1).unsqueeze(-2)
+ # print(sky_avg.shape)
+ skynet_out_c = skynet_out_c * (1.0-nosky_mask) + sky_avg*(nosky_mask)
+ else:
+ sky_weight = sky_weight * (1.0-nosky_mask)
+ else:
+ skynet_out_c = skynet_out_c * (1.0-nosky_mask) + self.sky_replace_color*(nosky_mask)
+
+ if self.clip_feat_map is True: # intermediate feature before blending & CNN
+ rgbs = torch.clamp(net_out_c, -1, 1) + 1
+ rgbs_sky = torch.clamp(skynet_out_c, -1, 1) + 1
+ net_out = torch.sum(weights*rgbs, dim=-2, keepdim=True) + sky_weight * \
+ rgbs_sky # 576, 768, 4, 3 -> 576, 768, 3
+ net_out = net_out.squeeze(-2)
+ net_out = net_out - 1
+ elif self.clip_feat_map is False:
+ rgbs = net_out_c
+ rgbs_sky = skynet_out_c
+ net_out = torch.sum(weights*rgbs, dim=-2, keepdim=True) + sky_weight * \
+ rgbs_sky # 576, 768, 4, 3 -> 576, 768, 3
+ net_out = net_out.squeeze(-2)
+ elif self.clip_feat_map == 'tanh':
+ rgbs = torch.tanh(net_out_c)
+ rgbs_sky = torch.tanh(skynet_out_c)
+ net_out = torch.sum(weights*rgbs, dim=-2, keepdim=True) + sky_weight * \
+ rgbs_sky # 576, 768, 4, 3 -> 576, 768, 3
+ net_out = net_out.squeeze(-2)
+ else:
+ raise NotImplementedError
+
+ return net_out, new_dists, weights, total_weights_raw, rand_depth, net_out_s, net_out_c, skynet_out_c, \
+ nosky_mask, sky_mask, sky_only_mask, new_idx
+
+ def _forward_global(self, net_out, z):
+ r"""Forward the CNN
+
+ Args:
+ net_out (N x C5 x H x W tensor): Intermediate feature maps.
+ z (N x C3 tensor): Intermediate style vectors.
+
+ Returns:
+ fake_images (N x 3 x H x W tensor): Output image.
+ fake_images_raw (N x 3 x H x W tensor): Output image before TanH.
+ """
+ fake_images = net_out.permute(0, 3, 1, 2)
+ fake_images_raw = self.denoiser(fake_images, z)
+ fake_images = torch.tanh(fake_images_raw)
+
+ return fake_images, fake_images_raw
diff --git a/imaginaire/generators/munit.py b/imaginaire/generators/munit.py
new file mode 100644
index 0000000000000000000000000000000000000000..55cb066bb51e6f8e96d623d2224af06af017c086
--- /dev/null
+++ b/imaginaire/generators/munit.py
@@ -0,0 +1,465 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import warnings
+from types import SimpleNamespace
+
+import torch
+from torch import nn
+from torch.nn import Upsample as NearestUpsample
+
+from imaginaire.layers import Conv2dBlock, LinearBlock, Res2dBlock
+from imaginaire.generators.unit import ContentEncoder
+
+
+class Generator(nn.Module):
+ r"""Improved MUNIT generator.
+
+ Args:
+ gen_cfg (obj): Generator definition part of the yaml config file.
+ data_cfg (obj): Data definition part of the yaml config file.
+ """
+
+ def __init__(self, gen_cfg, data_cfg):
+ super().__init__()
+ self.autoencoder_a = AutoEncoder(**vars(gen_cfg))
+ self.autoencoder_b = AutoEncoder(**vars(gen_cfg))
+
+ def forward(self, data, random_style=True, image_recon=True,
+ latent_recon=True, cycle_recon=True, within_latent_recon=False):
+ r"""In MUNIT's forward pass, it generates a content code and a style
+        code from images in both domains. It then performs a within-domain
+ reconstruction step and a cross-domain translation step.
+ In within-domain reconstruction, it reconstructs an image using the
+ content and style from the same image and optionally encodes the image
+ back to the latent space.
+        In cross-domain translation, it generates a translated image by mixing
+ the content and style from images in different domains, and optionally
+ encodes the image back to the latent space.
+
+ Args:
+ data (dict): Training data at the current iteration.
+ - images_a (tensor): Images from domain A.
+ - images_b (tensor): Images from domain B.
+ random_style (bool): If ``True``, samples the style code from the
+ prior distribution, otherwise uses the style code encoded from
+ the input images in the other domain.
+ image_recon (bool): If ``True``, also returns reconstructed images.
+ latent_recon (bool): If ``True``, also returns reconstructed latent
+ code during cross-domain translation.
+ cycle_recon (bool): If ``True``, also returns cycle
+ reconstructed images.
+ within_latent_recon (bool): If ``True``, also returns reconstructed
+ latent code during within-domain reconstruction.
+ """
+
+ images_a = data['images_a']
+ images_b = data['images_b']
+ net_G_output = dict()
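+        # Naming convention below: images_xy is decoded from domain-x content with a domain-y style, so a
+        # repeated letter (aa, bb) is a within-domain reconstruction and aba/bab are cycle reconstructions.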
+
+ # encode input images into content and style code
+ content_a, style_a = self.autoencoder_a.encode(images_a)
+ content_b, style_b = self.autoencoder_b.encode(images_b)
+
+ # decode (within domain)
+ if image_recon:
+ images_aa = self.autoencoder_a.decode(content_a, style_a)
+ images_bb = self.autoencoder_b.decode(content_b, style_b)
+ net_G_output.update(dict(images_aa=images_aa, images_bb=images_bb))
+
+ # decode (cross domain)
+ if random_style: # use randomly sampled style code
+ style_a_rand = torch.randn_like(style_a)
+ style_b_rand = torch.randn_like(style_b)
+ else: # use style code encoded from the other domain
+ style_a_rand = style_a
+ style_b_rand = style_b
+ images_ba = self.autoencoder_a.decode(content_b, style_a_rand)
+ images_ab = self.autoencoder_b.decode(content_a, style_b_rand)
+
+ # encode translated images into content and style code
+ if latent_recon or cycle_recon:
+ content_ba, style_ba = self.autoencoder_a.encode(images_ba)
+ content_ab, style_ab = self.autoencoder_b.encode(images_ab)
+ net_G_output.update(dict(content_ba=content_ba, style_ba=style_ba,
+ content_ab=content_ab, style_ab=style_ab))
+
+ # encode reconstructed images into content and style code
+ if image_recon and within_latent_recon:
+ content_aa, style_aa = self.autoencoder_a.encode(images_aa)
+ content_bb, style_bb = self.autoencoder_b.encode(images_bb)
+ net_G_output.update(dict(content_aa=content_aa, style_aa=style_aa,
+ content_bb=content_bb, style_bb=style_bb))
+
+ # cycle reconstruction
+ if cycle_recon:
+ images_aba = self.autoencoder_a.decode(content_ab, style_a)
+ images_bab = self.autoencoder_b.decode(content_ba, style_b)
+ net_G_output.update(
+ dict(images_aba=images_aba, images_bab=images_bab))
+
+ # required outputs
+ net_G_output.update(dict(content_a=content_a, content_b=content_b,
+ style_a=style_a, style_b=style_b,
+ style_a_rand=style_a_rand,
+ style_b_rand=style_b_rand,
+ images_ba=images_ba, images_ab=images_ab))
+
+ return net_G_output
+
+ def inference(self, data, a2b=True, random_style=True):
+ r"""MUNIT inference.
+
+ Args:
+ data (dict): Training data at the current iteration.
+ - images_a (tensor): Images from domain A.
+ - images_b (tensor): Images from domain B.
+ a2b (bool): If ``True``, translates images from domain A to B,
+ otherwise from B to A.
+ random_style (bool): If ``True``, samples the style code from the
+ prior distribution, otherwise uses the style code encoded from
+ the input images in the other domain.
+ """
+ if a2b:
+ input_key = 'images_a'
+ content_encode = self.autoencoder_a.content_encoder
+ style_encode = self.autoencoder_b.style_encoder
+ decode = self.autoencoder_b.decode
+ else:
+ input_key = 'images_b'
+ content_encode = self.autoencoder_b.content_encoder
+ style_encode = self.autoencoder_a.style_encoder
+ decode = self.autoencoder_a.decode
+
+ content_images = data[input_key]
+ content = content_encode(content_images)
+ if random_style:
+ style_channels = self.autoencoder_a.style_channels
+ style = torch.randn(content.size(0), style_channels, 1, 1,
+ device=torch.device('cuda'))
+ file_names = data['key'][input_key]['filename']
+ else:
+ style_key = 'images_b' if a2b else 'images_a'
+ assert style_key in data.keys(), \
+ "{} must be provided when 'random_style' " \
+ "is set to False".format(style_key)
+ style_images = data[style_key]
+ style = style_encode(style_images)
+ file_names = \
+ [content_name + '_style_' + style_name
+ for content_name, style_name in
+ zip(data['key'][input_key]['filename'],
+ data['key'][style_key]['filename'])]
+
+ output_images = decode(content, style)
+ return output_images, file_names
+
+
+class AutoEncoder(nn.Module):
+ r"""Improved MUNIT autoencoder.
+
+ Args:
+ num_filters (int): Base filter numbers.
+ max_num_filters (int): Maximum number of filters in the encoder.
+ num_filters_mlp (int): Base filter number in the MLP module.
+ latent_dim (int): Dimension of the style code.
+ num_res_blocks (int): Number of residual blocks at the end of the
+ content encoder.
+ num_mlp_blocks (int): Number of layers in the MLP module.
+ num_downsamples_style (int): Number of times we reduce
+ resolution by 2x2 for the style image.
+ num_downsamples_content (int): Number of times we reduce
+ resolution by 2x2 for the content image.
+ num_image_channels (int): Number of input image channels.
+ content_norm_type (str): Type of activation normalization in the
+ content encoder.
+ style_norm_type (str): Type of activation normalization in the
+ style encoder.
+ decoder_norm_type (str): Type of activation normalization in the
+ decoder.
+ weight_norm_type (str): Type of weight normalization.
+ decoder_norm_params (obj): Parameters of activation normalization in the
+ decoder. If not ``None``, decoder_norm_params.__dict__ will be used
+ as keyword arguments when initializing activation normalization.
+ output_nonlinearity (str): Type of nonlinearity before final output,
+ ``'tanh'`` or ``'none'``.
+ pre_act (bool): If ``True``, uses pre-activation residual blocks.
+ apply_noise (bool): If ``True``, injects Gaussian noise in the decoder.
+ """
+
+ def __init__(self,
+ num_filters=64,
+ max_num_filters=256,
+ num_filters_mlp=256,
+ latent_dim=8,
+ num_res_blocks=4,
+ num_mlp_blocks=2,
+ num_downsamples_style=4,
+ num_downsamples_content=2,
+ num_image_channels=3,
+ content_norm_type='instance',
+ style_norm_type='',
+ decoder_norm_type='instance',
+ weight_norm_type='',
+ decoder_norm_params=SimpleNamespace(affine=False),
+ output_nonlinearity='',
+ pre_act=False,
+ apply_noise=False,
+ **kwargs):
+ super().__init__()
+ for key in kwargs:
+ if key != 'type':
+ warnings.warn(
+ "Generator argument '{}' is not used.".format(key))
+ self.style_encoder = StyleEncoder(num_downsamples_style,
+ num_image_channels,
+ num_filters,
+ latent_dim,
+ 'reflect',
+ style_norm_type,
+ weight_norm_type,
+ 'relu')
+ self.content_encoder = ContentEncoder(num_downsamples_content,
+ num_res_blocks,
+ num_image_channels,
+ num_filters,
+ max_num_filters,
+ 'reflect',
+ content_norm_type,
+ weight_norm_type,
+ 'relu',
+ pre_act)
+ self.decoder = Decoder(num_downsamples_content,
+ num_res_blocks,
+ self.content_encoder.output_dim,
+ num_image_channels,
+ num_filters_mlp,
+ 'reflect',
+ decoder_norm_type,
+ decoder_norm_params,
+ weight_norm_type,
+ 'relu',
+ output_nonlinearity,
+ pre_act,
+ apply_noise)
+ self.mlp = MLP(latent_dim,
+ num_filters_mlp,
+ num_filters_mlp,
+ num_mlp_blocks,
+ 'none',
+ 'relu')
+ self.style_channels = latent_dim
+
+ def forward(self, images):
+ r"""Reconstruct an image.
+
+ Args:
+ images (Tensor): Input images.
+ Returns:
+ images_recon (Tensor): Reconstructed images.
+ """
+ content, style = self.encode(images)
+ images_recon = self.decode(content, style)
+ return images_recon
+
+ def encode(self, images):
+ r"""Encode an image to content and style code.
+
+ Args:
+ images (Tensor): Input images.
+ Returns:
+ (tuple):
+ - content (Tensor): Content code.
+ - style (Tensor): Style code.
+ """
+ style = self.style_encoder(images)
+ content = self.content_encoder(images)
+ return content, style
+
+ def decode(self, content, style):
+ r"""Decode content and style code to an image.
+
+ Args:
+ content (Tensor): Content code.
+ style (Tensor): Style code.
+ Returns:
+ images (Tensor): Output images.
+ """
+ style = self.mlp(style)
+ images = self.decoder(content, style)
+ return images
+
+
+class StyleEncoder(nn.Module):
+ r"""MUNIT style encoder.
+
+ Args:
+ num_downsamples (int): Number of times we reduce
+ resolution by 2x2.
+ num_image_channels (int): Number of input image channels.
+ num_filters (int): Base filter numbers.
+ style_channels (int): Dimension of the style code.
+ padding_mode (string): Type of padding.
+ activation_norm_type (str): Type of activation normalization.
+ weight_norm_type (str): Type of weight normalization.
+ nonlinearity (str): Type of nonlinear activation function.
+ """
+
+ def __init__(self, num_downsamples, num_image_channels, num_filters,
+ style_channels, padding_mode, activation_norm_type,
+ weight_norm_type, nonlinearity):
+ super().__init__()
+ conv_params = dict(padding_mode=padding_mode,
+ activation_norm_type=activation_norm_type,
+ weight_norm_type=weight_norm_type,
+ nonlinearity=nonlinearity,
+ inplace_nonlinearity=True)
+ model = []
+ model += [Conv2dBlock(num_image_channels, num_filters, 7, 1, 3,
+ **conv_params)]
+ for i in range(2):
+ model += [Conv2dBlock(num_filters, 2 * num_filters, 4, 2, 1,
+ **conv_params)]
+ num_filters *= 2
+ for i in range(num_downsamples - 2):
+ model += [Conv2dBlock(num_filters, num_filters, 4, 2, 1,
+ **conv_params)]
+ model += [nn.AdaptiveAvgPool2d(1)]
+ model += [nn.Conv2d(num_filters, style_channels, 1, 1, 0)]
+ self.model = nn.Sequential(*model)
+ self.output_dim = num_filters
+
+ def forward(self, x):
+ r"""
+
+ Args:
+ x (tensor): Input image.
+ """
+ return self.model(x)
+
+
+class Decoder(nn.Module):
+ r"""Improved MUNIT decoder. The network consists of
+
+ - $(num_res_blocks) residual blocks.
+    - $(num_upsamples) upsampling convolutional blocks.
+    - an output layer.
+
+ Args:
+ num_upsamples (int): Number of times we increase resolution by 2x2.
+ num_res_blocks (int): Number of residual blocks.
+ num_filters (int): Base filter numbers.
+ num_image_channels (int): Number of input image channels.
+ style_channels (int): Dimension of the style code.
+ padding_mode (string): Type of padding.
+ activation_norm_type (str): Type of activation normalization.
+ activation_norm_params (obj): Parameters of activation normalization.
+ If not ``None``, decoder_norm_params.__dict__ will be used
+ as keyword arguments when initializing activation normalization.
+ weight_norm_type (str): Type of weight normalization.
+ nonlinearity (str): Type of nonlinear activation function.
+ output_nonlinearity (str): Type of nonlinearity before final output,
+ ``'tanh'`` or ``'none'``.
+ pre_act (bool): If ``True``, uses pre-activation residual blocks.
+ apply_noise (bool): If ``True``, injects Gaussian noise.
+ """
+
+ def __init__(self,
+ num_upsamples,
+ num_res_blocks,
+ num_filters,
+ num_image_channels,
+ style_channels,
+ padding_mode,
+ activation_norm_type,
+ activation_norm_params,
+ weight_norm_type,
+ nonlinearity,
+ output_nonlinearity,
+ pre_act=False,
+ apply_noise=False):
+ super().__init__()
+ adain_params = SimpleNamespace(
+ activation_norm_type=activation_norm_type,
+ activation_norm_params=activation_norm_params,
+ cond_dims=style_channels)
+ conv_params = dict(padding_mode=padding_mode,
+ nonlinearity=nonlinearity,
+ inplace_nonlinearity=True,
+ apply_noise=apply_noise,
+ weight_norm_type=weight_norm_type,
+ activation_norm_type='adaptive',
+ activation_norm_params=adain_params)
+
+ # The order of operations in residual blocks.
+ order = 'pre_act' if pre_act else 'CNACNA'
+
+ # Residual blocks with AdaIN.
+ self.decoder = nn.ModuleList()
+ for _ in range(num_res_blocks):
+ self.decoder += [Res2dBlock(num_filters, num_filters,
+ **conv_params,
+ order=order)]
+
+ # Convolutional blocks with upsampling.
+ for i in range(num_upsamples):
+ self.decoder += [NearestUpsample(scale_factor=2)]
+ self.decoder += [Conv2dBlock(num_filters, num_filters // 2,
+ 5, 1, 2, **conv_params)]
+ num_filters //= 2
+ self.decoder += [Conv2dBlock(num_filters, num_image_channels, 7, 1, 3,
+ nonlinearity=output_nonlinearity,
+ padding_mode=padding_mode)]
+
+ def forward(self, x, style):
+ r"""
+
+ Args:
+ x (tensor): Content embedding of the content image.
+ style (tensor): Style embedding of the style image.
+ """
+ for block in self.decoder:
+ if getattr(block, 'conditional', False):
+ x = block(x, style)
+ else:
+ x = block(x)
+ return x
+
+
+class MLP(nn.Module):
+ r"""The multi-layer perceptron (MLP) that maps Gaussian style code to a
+ feature vector that is given as the conditional input to AdaIN.
+
+ Args:
+ input_dim (int): Number of channels in the input tensor.
+ output_dim (int): Number of channels in the output tensor.
+ latent_dim (int): Number of channels in the latent features.
+ num_layers (int): Number of layers in the MLP.
+ norm (str): Type of activation normalization.
+ nonlinearity (str): Type of nonlinear activation function.
+ """
+
+ def __init__(self, input_dim, output_dim, latent_dim, num_layers,
+ norm, nonlinearity):
+ super().__init__()
+ model = []
+ model += [LinearBlock(input_dim, latent_dim,
+ activation_norm_type=norm,
+ nonlinearity=nonlinearity)]
+ for i in range(num_layers - 2):
+ model += [LinearBlock(latent_dim, latent_dim,
+ activation_norm_type=norm,
+ nonlinearity=nonlinearity)]
+ model += [LinearBlock(latent_dim, output_dim,
+ activation_norm_type=norm,
+ nonlinearity=nonlinearity)]
+ self.model = nn.Sequential(*model)
+
+ def forward(self, x):
+ r"""
+
+ Args:
+            x (tensor): Input style code.
+ """
+ return self.model(x.view(x.size(0), -1))
diff --git a/imaginaire/generators/pix2pixHD.py b/imaginaire/generators/pix2pixHD.py
new file mode 100644
index 0000000000000000000000000000000000000000..cd2e36b31b2b045594d3dd1d7db7cb4ee336d6f8
--- /dev/null
+++ b/imaginaire/generators/pix2pixHD.py
@@ -0,0 +1,348 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+from functools import partial
+
+import numpy as np
+import torch
+import torch.nn as nn
+from torch.nn import Upsample as NearestUpsample
+
+from imaginaire.layers import Conv2dBlock, Res2dBlock
+from imaginaire.utils.data import (get_paired_input_image_channel_number,
+ get_paired_input_label_channel_number)
+from imaginaire.utils.distributed import master_only_print as print
+
+
+class Generator(nn.Module):
+ r"""Pix2pixHD coarse-to-fine generator constructor.
+
+ Args:
+ gen_cfg (obj): Generator definition part of the yaml config file.
+ data_cfg (obj): Data definition part of the yaml config file.
+ """
+
+ def __init__(self, gen_cfg, data_cfg):
+ super().__init__()
+ # pix2pixHD has a global generator.
+ global_gen_cfg = gen_cfg.global_generator
+ num_filters_global = getattr(global_gen_cfg, 'num_filters', 64)
+ # Optionally, it can have several local enhancers. They are useful
+ # for generating high resolution images.
+ local_gen_cfg = gen_cfg.local_enhancer
+ self.num_local_enhancers = num_local_enhancers = \
+ getattr(local_gen_cfg, 'num_enhancers', 1)
+        # By default, pix2pixHD uses instance normalization.
+ activation_norm_type = getattr(gen_cfg, 'activation_norm_type',
+ 'instance')
+ activation_norm_params = getattr(gen_cfg, 'activation_norm_params',
+ None)
+ weight_norm_type = getattr(gen_cfg, 'weight_norm_type', '')
+ padding_mode = getattr(gen_cfg, 'padding_mode', 'reflect')
+ base_conv_block = partial(Conv2dBlock,
+ padding_mode=padding_mode,
+ weight_norm_type=weight_norm_type,
+ activation_norm_type=activation_norm_type,
+ activation_norm_params=activation_norm_params,
+ nonlinearity='relu')
+ base_res_block = partial(Res2dBlock,
+ padding_mode=padding_mode,
+ weight_norm_type=weight_norm_type,
+ activation_norm_type=activation_norm_type,
+ activation_norm_params=activation_norm_params,
+ nonlinearity='relu', order='CNACN')
+        # Determine the number of available segmentation labels.
+ num_input_channels = get_paired_input_label_channel_number(data_cfg)
+ self.concat_features = False
+        # Check whether the label input contains a specific type of data
+        # (e.g. instance_maps).
+ self.contain_instance_map = False
+ if data_cfg.input_labels[-1] == 'instance_maps':
+ self.contain_instance_map = True
+ # The feature encoder is only useful when the instance map is provided.
+ if hasattr(gen_cfg, 'enc') and self.contain_instance_map:
+ num_feat_channels = getattr(gen_cfg.enc, 'num_feat_channels', 0)
+ if num_feat_channels > 0:
+ num_input_channels += num_feat_channels
+ self.concat_features = True
+ self.encoder = Encoder(gen_cfg.enc, data_cfg)
+
+ # Global generator model.
+ global_model = GlobalGenerator(global_gen_cfg, data_cfg,
+ num_input_channels, padding_mode,
+ base_conv_block, base_res_block)
+ if num_local_enhancers == 0:
+ self.global_model = global_model
+ else:
+ # Get rid of the last layer.
+ global_model = global_model.model
+ global_model = [global_model[i]
+ for i in range(len(global_model) - 1)]
+ # global_model = [global_model[i]
+ # for i in range(len(global_model) - 2)]
+ self.global_model = nn.Sequential(*global_model)
+
+ # Local enhancer model.
+ for n in range(num_local_enhancers):
+ # num_filters = num_filters_global // (2 ** n)
+ num_filters = num_filters_global // (2 ** (n + 1))
+ output_img = (n == num_local_enhancers - 1)
+ setattr(self, 'enhancer_%d' % n,
+ LocalEnhancer(local_gen_cfg, data_cfg,
+ num_input_channels, num_filters,
+ padding_mode, base_conv_block,
+ base_res_block, output_img))
+
+ self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1],
+ count_include_pad=False)
+
+ def forward(self, data, random_style=False):
+ r"""Coarse-to-fine generator forward.
+
+ Args:
+ data (dict) : Dictionary of input data.
+ random_style (bool): Always set to false for the pix2pixHD model.
+ Returns:
+ output (dict) : Dictionary of output data.
+ """
+ label = data['label']
+
+ output = dict()
+ if self.concat_features:
+ features = self.encoder(data['images'], data['instance_maps'])
+ label = torch.cat([label, features], dim=1)
+ output['feature_maps'] = features
+
+ # Create input pyramid.
+ input_downsampled = [label]
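+        # Each append halves the resolution, so input_downsampled[-1] is the coarsest level
+        # and is what gets fed to the global generator first.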
+ for i in range(self.num_local_enhancers):
+ input_downsampled.append(self.downsample(input_downsampled[-1]))
+
+ # Output at coarsest level.
+ x = self.global_model(input_downsampled[-1])
+
+ # Coarse-to-fine: build up one layer at a time.
+ for n in range(self.num_local_enhancers):
+ input_n = input_downsampled[self.num_local_enhancers - n - 1]
+ enhancer = getattr(self, 'enhancer_%d' % n)
+ x = enhancer(x, input_n)
+
+ output['fake_images'] = x
+ return output
+
+ def load_pretrained_network(self, pretrained_dict):
+ r"""Load a pretrained network."""
+ # print(pretrained_dict.keys())
+ model_dict = self.state_dict()
+        print('Pretrained network has fewer layers; the following are '
+              'not initialized:')
+
+ not_initialized = set()
+ for k, v in model_dict.items():
+ kp = 'module.' + k.replace('global_model.', 'global_model.model.')
+ if kp in pretrained_dict and v.size() == pretrained_dict[kp].size():
+ model_dict[k] = pretrained_dict[kp]
+ else:
+ not_initialized.add('.'.join(k.split('.')[:2]))
+ print(sorted(not_initialized))
+ self.load_state_dict(model_dict)
+
+ def inference(self, data, **kwargs):
+ r"""Generator inference.
+
+ Args:
+ data (dict) : Dictionary of input data.
+ Returns:
+ fake_images (tensor): Output fake images.
+ file_names (str): Data file name.
+ """
+ output = self.forward(data, **kwargs)
+ return output['fake_images'], data['key']['seg_maps'][0]
+
+
+class LocalEnhancer(nn.Module):
+ r"""Local enhancer constructor. These are sub-networks that are useful
+ when aiming to produce high-resolution outputs.
+
+ Args:
+ gen_cfg (obj): local generator definition part of the yaml config
+ file.
+ data_cfg (obj): Data definition part of the yaml config file.
+ num_input_channels (int): Number of segmentation labels.
+ num_filters (int): Number of filters for the first layer.
+ padding_mode (str): zero | reflect | ...
+ base_conv_block (obj): Conv block with preset attributes.
+ base_res_block (obj): Residual block with preset attributes.
+ output_img (bool): Output is image or feature map.
+ """
+
+ def __init__(self, gen_cfg, data_cfg, num_input_channels, num_filters,
+ padding_mode, base_conv_block, base_res_block,
+ output_img=False):
+ super(LocalEnhancer, self).__init__()
+ num_res_blocks = getattr(gen_cfg, 'num_res_blocks', 3)
+ num_img_channels = get_paired_input_image_channel_number(data_cfg)
+ # Downsample.
+ model_downsample = \
+ [base_conv_block(num_input_channels, num_filters, 7, padding=3),
+ base_conv_block(num_filters, num_filters * 2, 3, stride=2,
+ padding=1)]
+ # Residual blocks.
+ model_upsample = []
+ for i in range(num_res_blocks):
+ model_upsample += [base_res_block(num_filters * 2, num_filters * 2,
+ 3, padding=1)]
+ # Upsample.
+ model_upsample += \
+ [NearestUpsample(scale_factor=2),
+ base_conv_block(num_filters * 2, num_filters, 3, padding=1)]
+
+ # Final convolution.
+ if output_img:
+ model_upsample += [Conv2dBlock(num_filters, num_img_channels, 7,
+ padding=3, padding_mode=padding_mode,
+ nonlinearity='tanh')]
+
+ self.model_downsample = nn.Sequential(*model_downsample)
+ self.model_upsample = nn.Sequential(*model_upsample)
+
+ def forward(self, output_coarse, input_fine):
+ r"""Local enhancer forward.
+
+ Args:
+ output_coarse (4D tensor) : Coarse output from previous layer.
+ input_fine (4D tensor) : Fine input from current layer.
+ Returns:
+ output (4D tensor) : Refined output.
+ """
+ output = self.model_upsample(self.model_downsample(input_fine) + output_coarse)
+ return output
+
+
+class GlobalGenerator(nn.Module):
+ r"""Coarse generator constructor. This is the main generator in the
+ pix2pixHD architecture.
+
+ Args:
+ gen_cfg (obj): Generator definition part of the yaml config file.
+ data_cfg (obj): Data definition part of the yaml config file.
+ num_input_channels (int): Number of segmentation labels.
+ padding_mode (str): zero | reflect | ...
+ base_conv_block (obj): Conv block with preset attributes.
+ base_res_block (obj): Residual block with preset attributes.
+ """
+
+ def __init__(self, gen_cfg, data_cfg, num_input_channels, padding_mode,
+ base_conv_block, base_res_block):
+ super(GlobalGenerator, self).__init__()
+ num_img_channels = get_paired_input_image_channel_number(data_cfg)
+ num_filters = getattr(gen_cfg, 'num_filters', 64)
+ num_downsamples = getattr(gen_cfg, 'num_downsamples', 4)
+ num_res_blocks = getattr(gen_cfg, 'num_res_blocks', 9)
+ # First layer.
+ model = [base_conv_block(num_input_channels, num_filters,
+ kernel_size=7, padding=3)]
+ # Downsample.
+ for i in range(num_downsamples):
+ ch = num_filters * (2 ** i)
+ model += [base_conv_block(ch, ch * 2, 3, padding=1, stride=2)]
+ # ResNet blocks.
+ ch = num_filters * (2 ** num_downsamples)
+ for i in range(num_res_blocks):
+ model += [base_res_block(ch, ch, 3, padding=1)]
+ # Upsample.
+ num_upsamples = num_downsamples
+ for i in reversed(range(num_upsamples)):
+ ch = num_filters * (2 ** i)
+ model += \
+ [NearestUpsample(scale_factor=2),
+ base_conv_block(ch * 2, ch, 3, padding=1)]
+ model += [Conv2dBlock(num_filters, num_img_channels, 7, padding=3,
+ padding_mode=padding_mode, nonlinearity='tanh')]
+ self.model = nn.Sequential(*model)
+
+ def forward(self, input):
+ r"""Coarse-to-fine generator forward.
+
+ Args:
+ input (4D tensor) : Input semantic representations.
+ Returns:
+ output (4D tensor) : Synthesized image by generator.
+ """
+ return self.model(input)
+
+
+class Encoder(nn.Module):
+ r"""Encoder for getting region-wise features for style control.
+
+ Args:
+ enc_cfg (obj): Encoder definition part of the yaml config file.
+ data_cfg (obj): Data definition part of the yaml config file
+ """
+
+ def __init__(self, enc_cfg, data_cfg):
+ super(Encoder, self).__init__()
+ label_nc = get_paired_input_label_channel_number(data_cfg)
+ feat_nc = enc_cfg.num_feat_channels
+ n_clusters = getattr(enc_cfg, 'num_clusters', 10)
+ for i in range(label_nc):
+ dummy_arr = np.zeros((n_clusters, feat_nc), dtype=np.float32)
+ self.register_buffer('cluster_%d' % i,
+ torch.tensor(dummy_arr, dtype=torch.float32))
+ num_img_channels = get_paired_input_image_channel_number(data_cfg)
+ self.num_feat_channels = getattr(enc_cfg, 'num_feat_channels', 3)
+ num_filters = getattr(enc_cfg, 'num_filters', 64)
+ num_downsamples = getattr(enc_cfg, 'num_downsamples', 4)
+ weight_norm_type = getattr(enc_cfg, 'weight_norm_type', 'none')
+ activation_norm_type = getattr(enc_cfg, 'activation_norm_type',
+ 'instance')
+ padding_mode = getattr(enc_cfg, 'padding_mode', 'reflect')
+ base_conv_block = partial(Conv2dBlock,
+ padding_mode=padding_mode,
+ weight_norm_type=weight_norm_type,
+ activation_norm_type=activation_norm_type,
+ nonlinearity='relu')
+ model = [base_conv_block(num_img_channels, num_filters, 7, padding=3)]
+ # Downsample.
+ for i in range(num_downsamples):
+ ch = num_filters * (2**i)
+ model += [base_conv_block(ch, ch * 2, 3, stride=2, padding=1)]
+ # Upsample.
+ for i in reversed(range(num_downsamples)):
+ ch = num_filters * (2 ** i)
+ model += [NearestUpsample(scale_factor=2),
+ base_conv_block(ch * 2, ch, 3, padding=1)]
+
+ model += [Conv2dBlock(num_filters, self.num_feat_channels, 7,
+ padding=3, padding_mode=padding_mode,
+ nonlinearity='tanh')]
+ self.model = nn.Sequential(*model)
+
+ def forward(self, input, instance_map):
+ r"""Extracting region-wise features
+
+ Args:
+ input (4D tensor): Real RGB images.
+ instance_map (4D tensor): Instance label mask.
+ Returns:
+ outputs_mean (4D tensor): Instance-wise average-pooled
+ feature maps.
+ """
+ outputs = self.model(input)
+ # Instance-wise average pooling.
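+        # Every pixel of an instance is replaced by that instance's mean feature, yielding one
+        # region-wise feature vector that pix2pixHD can use for style control.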
+ outputs_mean = torch.zeros_like(outputs)
+ # Find all the unique labels in this batch.
+ inst_list = np.unique(instance_map.cpu().numpy().astype(int))
+ for i in inst_list:
+ for b in range(input.size(0)):
+                # Find the pixels in this instance map that have this instance label.
+ indices = (instance_map[b:b+1] == int(i)).nonzero() # n x 4
+ # Scan through the feature channels.
+ for j in range(self.num_feat_channels):
+ output_ins = outputs[indices[:, 0] + b, indices[:, 1] + j,
+ indices[:, 2], indices[:, 3]]
+ mean_feat = torch.mean(output_ins).expand_as(output_ins)
+ outputs_mean[indices[:, 0] + b, indices[:, 1] + j,
+ indices[:, 2], indices[:, 3]] = mean_feat
+ return outputs_mean
diff --git a/imaginaire/generators/spade.py b/imaginaire/generators/spade.py
new file mode 100644
index 0000000000000000000000000000000000000000..dc69630304ccb2ce3fab707ca2e7de5f7aeec55a
--- /dev/null
+++ b/imaginaire/generators/spade.py
@@ -0,0 +1,571 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import functools
+import math
+import types
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn import Upsample as NearestUpsample
+
+from imaginaire.layers import Conv2dBlock, LinearBlock, Res2dBlock
+from imaginaire.utils.data import (get_crop_h_w,
+ get_paired_input_image_channel_number,
+ get_paired_input_label_channel_number)
+from imaginaire.utils.distributed import master_only_print as print
+
+
+class Generator(nn.Module):
+ r"""SPADE generator constructor.
+
+ Args:
+ gen_cfg (obj): Generator definition part of the yaml config file.
+ data_cfg (obj): Data definition part of the yaml config file.
+ """
+
+ def __init__(self, gen_cfg, data_cfg):
+ super(Generator, self).__init__()
+ print('SPADE generator initialization.')
+ # We assume the first datum is the ground truth image.
+ image_channels = getattr(gen_cfg, 'image_channels', None)
+ if image_channels is None:
+ image_channels = get_paired_input_image_channel_number(data_cfg)
+ num_labels = getattr(gen_cfg, 'num_labels', None)
+ if num_labels is None:
+ # Calculate number of channels in the input label when not specified.
+ num_labels = get_paired_input_label_channel_number(data_cfg)
+ crop_h, crop_w = get_crop_h_w(data_cfg.train.augmentations)
+ # Build the generator
+ out_image_small_side_size = crop_w if crop_w < crop_h else crop_h
+ num_filters = getattr(gen_cfg, 'num_filters', 128)
+ kernel_size = getattr(gen_cfg, 'kernel_size', 3)
+ weight_norm_type = getattr(gen_cfg, 'weight_norm_type', 'spectral')
+
+ cond_dims = 0
+ # Check whether we use the style code.
+ style_dims = getattr(gen_cfg, 'style_dims', None)
+ self.style_dims = style_dims
+ if style_dims is not None:
+ print('\tStyle code dimensions: %d' % style_dims)
+ cond_dims += style_dims
+ self.use_style = True
+ else:
+ self.use_style = False
+ # Check whether we use the attribute code.
+ if hasattr(gen_cfg, 'attribute_dims'):
+ self.use_attribute = True
+ self.attribute_dims = gen_cfg.attribute_dims
+ cond_dims += gen_cfg.attribute_dims
+ else:
+ self.use_attribute = False
+
+ if not self.use_style and not self.use_attribute:
+ self.use_style_encoder = False
+ else:
+ self.use_style_encoder = True
+ print('\tBase filter number: %d' % num_filters)
+ print('\tConvolution kernel size: %d' % kernel_size)
+ print('\tWeight norm type: %s' % weight_norm_type)
+ skip_activation_norm = \
+ getattr(gen_cfg, 'skip_activation_norm', True)
+ activation_norm_params = getattr(gen_cfg, 'activation_norm_params', None)
+ if activation_norm_params is None:
+ activation_norm_params = types.SimpleNamespace()
+ if not hasattr(activation_norm_params, 'num_filters'):
+ setattr(activation_norm_params, 'num_filters', 128)
+ if not hasattr(activation_norm_params, 'kernel_size'):
+ setattr(activation_norm_params, 'kernel_size', 3)
+ if not hasattr(activation_norm_params, 'activation_norm_type'):
+ setattr(activation_norm_params, 'activation_norm_type', 'sync_batch')
+ if not hasattr(activation_norm_params, 'separate_projection'):
+ setattr(activation_norm_params, 'separate_projection', False)
+ if not hasattr(activation_norm_params, 'activation_norm_params'):
+ activation_norm_params.activation_norm_params = types.SimpleNamespace()
+ activation_norm_params.activation_norm_params.affine = True
+ setattr(activation_norm_params, 'cond_dims', num_labels)
+ if not hasattr(activation_norm_params, 'weight_norm_type'):
+ setattr(activation_norm_params, 'weight_norm_type', weight_norm_type)
+ global_adaptive_norm_type = getattr(gen_cfg, 'global_adaptive_norm_type', 'sync_batch')
+ use_posenc_in_input_layer = getattr(gen_cfg, 'use_posenc_in_input_layer', True)
+ output_multiplier = getattr(gen_cfg, 'output_multiplier', 1.0)
+ print(activation_norm_params)
+ self.spade_generator = SPADEGenerator(num_labels,
+ out_image_small_side_size,
+ image_channels,
+ num_filters,
+ kernel_size,
+ cond_dims,
+ activation_norm_params,
+ weight_norm_type,
+ global_adaptive_norm_type,
+ skip_activation_norm,
+ use_posenc_in_input_layer,
+ self.use_style_encoder,
+ output_multiplier)
+ if self.use_style:
+ # Build the encoder.
+ style_enc_cfg = getattr(gen_cfg, 'style_enc', None)
+ if style_enc_cfg is None:
+ style_enc_cfg = types.SimpleNamespace()
+ if not hasattr(style_enc_cfg, 'num_filters'):
+ setattr(style_enc_cfg, 'num_filters', 128)
+ if not hasattr(style_enc_cfg, 'kernel_size'):
+ setattr(style_enc_cfg, 'kernel_size', 3)
+ if not hasattr(style_enc_cfg, 'weight_norm_type'):
+ setattr(style_enc_cfg, 'weight_norm_type', weight_norm_type)
+ setattr(style_enc_cfg, 'input_image_channels', image_channels)
+ setattr(style_enc_cfg, 'style_dims', style_dims)
+ self.style_encoder = StyleEncoder(style_enc_cfg)
+
+ self.z = None
+ print('Done with the SPADE generator initialization.')
+
+ def forward(self, data, random_style=False):
+ r"""SPADE Generator forward.
+
+ Args:
+ data (dict):
+ - images (N x C1 x H x W tensor) : Ground truth images
+ - label (N x C2 x H x W tensor) : Semantic representations
+ - z (N x style_dims tensor): Gaussian random noise
+            random_style (bool): Whether to sample a random style vector.
+ Returns:
+ (dict):
+ - fake_images (N x 3 x H x W tensor): fake images
+ - mu (N x C1 tensor): mean vectors
+ - logvar (N x C1 tensor): log-variance vectors
+ """
+ if self.use_style_encoder:
+ if random_style:
+ bs = data['label'].size(0)
+ z = torch.randn(
+ bs, self.style_dims, dtype=torch.float32).cuda()
+                if data['label'].dtype == torch.float16:
+ z = z.half()
+ mu = None
+ logvar = None
+ else:
+ mu, logvar, z = self.style_encoder(data['images'])
+ if self.use_attribute:
+ data['z'] = torch.cat((z, data['attributes'].squeeze(1)), dim=1)
+ else:
+ data['z'] = z
+ output = self.spade_generator(data)
+ if self.use_style_encoder:
+ output['mu'] = mu
+ output['logvar'] = logvar
+ return output
+
+ def inference(self,
+ data,
+ random_style=False,
+ use_fixed_random_style=False,
+ keep_original_size=False):
+ r"""Compute results images for a batch of input data and save the
+ results in the specified folder.
+
+ Args:
+ data (dict):
+ - images (N x C1 x H x W tensor) : Ground truth images
+ - label (N x C2 x H x W tensor) : Semantic representations
+ - z (N x style_dims tensor): Gaussian random noise
+ random_style (bool): Whether to sample a random style vector.
+ use_fixed_random_style (bool): Sample random style once and use it
+ for all the remaining inference.
+ keep_original_size (bool): Keep original size of the input.
+        Returns:
+            (tuple):
+              - output_images (N x 3 x H x W tensor): Fake images, optionally
+                resized back to the original input size.
+              - file_names (list): Output file names derived from the data keys.
+ """
+ self.eval()
+ self.spade_generator.eval()
+
+ if self.use_style_encoder:
+ if random_style and self.use_style_encoder:
+ if self.z is None or not use_fixed_random_style:
+ bs = data['label'].size(0)
+ z = torch.randn(
+ bs, self.style_dims, dtype=torch.float32).to('cuda')
+                    if data['label'].dtype == torch.float16:
+ z = z.half()
+ self.z = z
+ else:
+ z = self.z
+ else:
+ mu, logvar, z = self.style_encoder(data['images'])
+ data['z'] = z
+
+ output = self.spade_generator(data)
+ output_images = output['fake_images']
+
+ if keep_original_size:
+ height = data['original_h_w'][0][0]
+ width = data['original_h_w'][0][1]
+ output_images = torch.nn.functional.interpolate(
+ output_images, size=[height, width])
+
+        # Derive output file names from the data keys; an edge-map key, if
+        # present, takes precedence over a segmentation-map key.
+        for key in data['key'].keys():
+ if 'segmaps' in key or 'seg_maps' in key:
+ file_names = data['key'][key][0]
+ break
+ for key in data['key'].keys():
+ if 'edgemaps' in key or 'edge_maps' in key:
+ file_names = data['key'][key][0]
+ break
+
+ return output_images, file_names
+
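+# Illustrative call sketch for the wrapper above (variable names are
+# hypothetical; the data dict layout follows the forward() docstring, and
+# everything else here is an assumption for illustration, not a fixed API):
+#     net_G = Generator(gen_cfg, data_cfg)
+#     out = net_G({'label': label, 'images': images})         # encoded style
+#     fake = net_G({'label': label, 'images': images},
+#                  random_style=True)['fake_images']          # sampled style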
+
+class SPADEGenerator(nn.Module):
+ r"""SPADE Image Generator constructor.
+
+ Args:
+ num_labels (int): Number of different labels.
+ out_image_small_side_size (int): min(width, height)
+ image_channels (int): Num. of channels of the output image.
+ num_filters (int): Base filter numbers.
+ kernel_size (int): Convolution kernel size.
+ style_dims (int): Dimensions of the style code.
+ activation_norm_params (obj): Spatially adaptive normalization param.
+ weight_norm_type (str): Type of weight normalization.
+ ``'none'``, ``'spectral'``, or ``'weight'``.
+ global_adaptive_norm_type (str): Type of normalization in SPADE.
+ skip_activation_norm (bool): If ``True``, applies activation norm to the
+ shortcut connection in residual blocks.
+        use_posenc_in_input_layer (bool): If ``True``, concatenates a
+            2-channel positional encoding to the input label map.
+        use_style_encoder (bool): Whether to use a global adaptive norm
+            such as conditional batch norm or adaptive instance norm.
+        output_multiplier (float): A positive number multiplied to the output
+            before the final ``tanh``.
+    """
+
+ def __init__(self,
+ num_labels,
+ out_image_small_side_size,
+ image_channels,
+ num_filters,
+ kernel_size,
+ style_dims,
+ activation_norm_params,
+ weight_norm_type,
+ global_adaptive_norm_type,
+ skip_activation_norm,
+ use_posenc_in_input_layer,
+ use_style_encoder,
+ output_multiplier):
+ super(SPADEGenerator, self).__init__()
+ self.output_multiplier = output_multiplier
+ self.use_style_encoder = use_style_encoder
+ self.use_posenc_in_input_layer = use_posenc_in_input_layer
+ self.out_image_small_side_size = out_image_small_side_size
+ self.num_filters = num_filters
+ padding = int(np.ceil((kernel_size - 1.0) / 2))
+ nonlinearity = 'leakyrelu'
+ activation_norm_type = 'spatially_adaptive'
+ base_res2d_block = \
+ functools.partial(Res2dBlock,
+ kernel_size=kernel_size,
+ padding=padding,
+ bias=[True, True, False],
+ weight_norm_type=weight_norm_type,
+ activation_norm_type=activation_norm_type,
+ activation_norm_params=activation_norm_params,
+ skip_activation_norm=skip_activation_norm,
+ nonlinearity=nonlinearity,
+ order='NACNAC')
+ if self.use_style_encoder:
+ self.fc_0 = LinearBlock(style_dims, 2 * style_dims,
+ weight_norm_type=weight_norm_type,
+ nonlinearity='relu',
+ order='CAN')
+ self.fc_1 = LinearBlock(2 * style_dims, 2 * style_dims,
+ weight_norm_type=weight_norm_type,
+ nonlinearity='relu',
+ order='CAN')
+
+ adaptive_norm_params = types.SimpleNamespace()
+ if not hasattr(adaptive_norm_params, 'cond_dims'):
+ setattr(adaptive_norm_params, 'cond_dims', 2 * style_dims)
+ if not hasattr(adaptive_norm_params, 'activation_norm_type'):
+ setattr(adaptive_norm_params, 'activation_norm_type', global_adaptive_norm_type)
+ if not hasattr(adaptive_norm_params, 'weight_norm_type'):
+ setattr(adaptive_norm_params, 'weight_norm_type', activation_norm_params.weight_norm_type)
+ if not hasattr(adaptive_norm_params, 'separate_projection'):
+ setattr(adaptive_norm_params, 'separate_projection', activation_norm_params.separate_projection)
+ adaptive_norm_params.activation_norm_params = types.SimpleNamespace()
+ setattr(adaptive_norm_params.activation_norm_params, 'affine',
+ activation_norm_params.activation_norm_params.affine)
+ base_cbn2d_block = \
+ functools.partial(Conv2dBlock,
+ kernel_size=kernel_size,
+ stride=1,
+ padding=padding,
+ bias=True,
+ weight_norm_type=weight_norm_type,
+ activation_norm_type='adaptive',
+ activation_norm_params=adaptive_norm_params,
+ nonlinearity=nonlinearity,
+ order='NAC')
+ else:
+ base_conv2d_block = \
+ functools.partial(Conv2dBlock,
+ kernel_size=kernel_size,
+ stride=1,
+ padding=padding,
+ bias=True,
+ weight_norm_type=weight_norm_type,
+ nonlinearity=nonlinearity,
+ order='NAC')
+ in_num_labels = num_labels
+ in_num_labels += 2 if self.use_posenc_in_input_layer else 0
+ self.head_0 = Conv2dBlock(in_num_labels, 8 * num_filters,
+ kernel_size=kernel_size, stride=1,
+ padding=padding,
+ weight_norm_type=weight_norm_type,
+ activation_norm_type='none',
+ nonlinearity=nonlinearity)
+ if self.use_style_encoder:
+ self.cbn_head_0 = base_cbn2d_block(
+ 8 * num_filters, 16 * num_filters)
+ else:
+ self.conv_head_0 = base_conv2d_block(
+ 8 * num_filters, 16 * num_filters)
+ self.head_1 = base_res2d_block(16 * num_filters, 16 * num_filters)
+ self.head_2 = base_res2d_block(16 * num_filters, 16 * num_filters)
+
+ self.up_0a = base_res2d_block(16 * num_filters, 8 * num_filters)
+ if self.use_style_encoder:
+ self.cbn_up_0a = base_cbn2d_block(
+ 8 * num_filters, 8 * num_filters)
+ else:
+ self.conv_up_0a = base_conv2d_block(
+ 8 * num_filters, 8 * num_filters)
+ self.up_0b = base_res2d_block(8 * num_filters, 8 * num_filters)
+
+ self.up_1a = base_res2d_block(8 * num_filters, 4 * num_filters)
+ if self.use_style_encoder:
+ self.cbn_up_1a = base_cbn2d_block(
+ 4 * num_filters, 4 * num_filters)
+ else:
+ self.conv_up_1a = base_conv2d_block(
+ 4 * num_filters, 4 * num_filters)
+ self.up_1b = base_res2d_block(4 * num_filters, 4 * num_filters)
+ self.up_2a = base_res2d_block(4 * num_filters, 4 * num_filters)
+ if self.use_style_encoder:
+ self.cbn_up_2a = base_cbn2d_block(
+ 4 * num_filters, 4 * num_filters)
+ else:
+ self.conv_up_2a = base_conv2d_block(
+ 4 * num_filters, 4 * num_filters)
+ self.up_2b = base_res2d_block(4 * num_filters, 2 * num_filters)
+ self.conv_img256 = Conv2dBlock(2 * num_filters, image_channels,
+ 5, stride=1, padding=2,
+ weight_norm_type=weight_norm_type,
+ activation_norm_type='none',
+ nonlinearity=nonlinearity,
+ order='ANC')
+ self.base = 16
+ if self.out_image_small_side_size == 512:
+ self.up_3a = base_res2d_block(2 * num_filters, 1 * num_filters)
+ self.up_3b = base_res2d_block(1 * num_filters, 1 * num_filters)
+ self.conv_img512 = Conv2dBlock(1 * num_filters, image_channels,
+ 5, stride=1, padding=2,
+ weight_norm_type=weight_norm_type,
+ activation_norm_type='none',
+ nonlinearity=nonlinearity,
+ order='ANC')
+ self.base = 32
+ if self.out_image_small_side_size == 1024:
+ self.up_3a = base_res2d_block(2 * num_filters, 1 * num_filters)
+ self.up_3b = base_res2d_block(1 * num_filters, 1 * num_filters)
+ self.conv_img512 = Conv2dBlock(1 * num_filters, image_channels,
+ 5, stride=1, padding=2,
+ weight_norm_type=weight_norm_type,
+ activation_norm_type='none',
+ nonlinearity=nonlinearity,
+ order='ANC')
+ self.up_4a = base_res2d_block(num_filters, num_filters // 2)
+ self.up_4b = base_res2d_block(num_filters // 2, num_filters // 2)
+ self.conv_img1024 = Conv2dBlock(num_filters // 2, image_channels,
+ 5, stride=1, padding=2,
+ weight_norm_type=weight_norm_type,
+ activation_norm_type='none',
+ nonlinearity=nonlinearity,
+ order='ANC')
+ self.nearest_upsample4x = NearestUpsample(scale_factor=4, mode='nearest')
+ self.base = 64
+ if self.out_image_small_side_size != 256 and self.out_image_small_side_size != 512 \
+ and self.out_image_small_side_size != 1024:
+ raise ValueError('Generation image size (%d, %d) not supported' %
+ (self.out_image_small_side_size,
+ self.out_image_small_side_size))
+ self.nearest_upsample2x = NearestUpsample(scale_factor=2, mode='nearest')
+
+ xv, yv = torch.meshgrid(
+ [torch.arange(-1, 1.1, 2. / 15), torch.arange(-1, 1.1, 2. / 15)])
+ self.xy = torch.cat((xv.unsqueeze(0), yv.unsqueeze(0)), 0).unsqueeze(0)
+ self.xy = self.xy.cuda()
+
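+    # Worked example of the resolution bookkeeping above (assuming the label
+    # map has the same spatial size as the requested output): the decoder
+    # always starts from a 16x16 bottleneck, so `self.base` is chosen such
+    # that label_size / base == 16, i.e. 256 -> base 16, 512 -> base 32,
+    # 1024 -> base 64, with the extra up_3*/up_4* blocks supplying the
+    # additional 2x upsamplings needed to reach the larger output sizes.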
+ def forward(self, data):
+ r"""SPADE Generator forward.
+
+ Args:
+ data (dict):
+              - images (N x C1 x H x W tensor) : Ground truth images.
+ - label (N x C2 x H x W tensor) : Semantic representations.
+ - z (N x style_dims tensor): Gaussian random noise.
+ Returns:
+ output (dict):
+ - fake_images (N x 3 x H x W tensor): Fake images.
+ """
+ seg = data['label']
+
+ if self.use_style_encoder:
+ z = data['z']
+ z = self.fc_0(z)
+ z = self.fc_1(z)
+
+        # Resize the label map to the bottleneck resolution (16x16 when the
+        # label map matches the output size).
+ sy = math.floor(seg.size()[2] * 1.0 / self.base)
+ sx = math.floor(seg.size()[3] * 1.0 / self.base)
+
+ in_seg = F.interpolate(seg, size=[sy, sx], mode='nearest')
+ if self.use_posenc_in_input_layer:
+ in_xy = F.interpolate(self.xy, size=[sy, sx], mode='bicubic')
+ in_seg_xy = torch.cat(
+ (in_seg, in_xy.expand(in_seg.size()[0], 2, sy, sx)), 1)
+ else:
+ in_seg_xy = in_seg
+ # 16x16
+ x = self.head_0(in_seg_xy)
+ if self.use_style_encoder:
+ x = self.cbn_head_0(x, z)
+ else:
+ x = self.conv_head_0(x)
+ x = self.head_1(x, seg)
+ x = self.head_2(x, seg)
+ x = self.nearest_upsample2x(x)
+ # 32x32
+ x = self.up_0a(x, seg)
+ if self.use_style_encoder:
+ x = self.cbn_up_0a(x, z)
+ else:
+ x = self.conv_up_0a(x)
+ x = self.up_0b(x, seg)
+ x = self.nearest_upsample2x(x)
+ # 64x64
+ x = self.up_1a(x, seg)
+ if self.use_style_encoder:
+ x = self.cbn_up_1a(x, z)
+ else:
+ x = self.conv_up_1a(x)
+ x = self.up_1b(x, seg)
+ x = self.nearest_upsample2x(x)
+ # 128x128
+ x = self.up_2a(x, seg)
+ if self.use_style_encoder:
+ x = self.cbn_up_2a(x, z)
+ else:
+ x = self.conv_up_2a(x)
+ x = self.up_2b(x, seg)
+ x = self.nearest_upsample2x(x)
+ # 256x256
+ if self.out_image_small_side_size == 256:
+ x256 = self.conv_img256(x)
+ x = torch.tanh(self.output_multiplier * x256)
+ # 512x512
+ elif self.out_image_small_side_size == 512:
+ x256 = self.conv_img256(x)
+ x256 = self.nearest_upsample2x(x256)
+ x = self.up_3a(x, seg)
+ x = self.up_3b(x, seg)
+ x = self.nearest_upsample2x(x)
+ x512 = self.conv_img512(x)
+ x = torch.tanh(self.output_multiplier * (x256 + x512))
+ # 1024x1024
+ elif self.out_image_small_side_size == 1024:
+ x256 = self.conv_img256(x)
+ x256 = self.nearest_upsample4x(x256)
+ x = self.up_3a(x, seg)
+ x = self.up_3b(x, seg)
+ x = self.nearest_upsample2x(x)
+ x512 = self.conv_img512(x)
+ x512 = self.nearest_upsample2x(x512)
+ x = self.up_4a(x, seg)
+ x = self.up_4b(x, seg)
+ x = self.nearest_upsample2x(x)
+ x1024 = self.conv_img1024(x)
+ x = torch.tanh(self.output_multiplier * (x256 + x512 + x1024))
+ output = dict()
+ output['fake_images'] = x
+ return output
+
+
+class StyleEncoder(nn.Module):
+ r"""Style Encode constructor.
+
+ Args:
+ style_enc_cfg (obj): Style encoder definition file.
+ """
+
+ def __init__(self, style_enc_cfg):
+ super(StyleEncoder, self).__init__()
+ input_image_channels = style_enc_cfg.input_image_channels
+ num_filters = style_enc_cfg.num_filters
+ kernel_size = style_enc_cfg.kernel_size
+ padding = int(np.ceil((kernel_size - 1.0) / 2))
+ style_dims = style_enc_cfg.style_dims
+ weight_norm_type = style_enc_cfg.weight_norm_type
+ activation_norm_type = 'none'
+ nonlinearity = 'leakyrelu'
+ base_conv2d_block = \
+ functools.partial(Conv2dBlock,
+ kernel_size=kernel_size,
+ stride=2,
+ padding=padding,
+ weight_norm_type=weight_norm_type,
+ activation_norm_type=activation_norm_type,
+ # inplace_nonlinearity=True,
+ nonlinearity=nonlinearity)
+ self.layer1 = base_conv2d_block(input_image_channels, num_filters)
+ self.layer2 = base_conv2d_block(num_filters * 1, num_filters * 2)
+ self.layer3 = base_conv2d_block(num_filters * 2, num_filters * 4)
+ self.layer4 = base_conv2d_block(num_filters * 4, num_filters * 8)
+ self.layer5 = base_conv2d_block(num_filters * 8, num_filters * 8)
+ self.layer6 = base_conv2d_block(num_filters * 8, num_filters * 8)
+ self.fc_mu = LinearBlock(num_filters * 8 * 4 * 4, style_dims)
+ self.fc_var = LinearBlock(num_filters * 8 * 4 * 4, style_dims)
+
+ def forward(self, input_x):
+ r"""SPADE Style Encoder forward.
+
+ Args:
+ input_x (N x 3 x H x W tensor): input images.
+ Returns:
+ (tuple):
+ - mu (N x C tensor): Mean vectors.
+ - logvar (N x C tensor): Log-variance vectors.
+ - z (N x C tensor): Style code vectors.
+ """
+ if input_x.size(2) != 256 or input_x.size(3) != 256:
+ input_x = F.interpolate(input_x, size=(256, 256), mode='bilinear')
+ x = self.layer1(input_x)
+ x = self.layer2(x)
+ x = self.layer3(x)
+ x = self.layer4(x)
+ x = self.layer5(x)
+ x = self.layer6(x)
+ x = x.view(x.size(0), -1)
+ mu = self.fc_mu(x)
+ logvar = self.fc_var(x)
+ std = torch.exp(0.5 * logvar)
+ eps = torch.randn_like(std)
+ z = eps.mul(std) + mu
+ return mu, logvar, z
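+
+
+# Minimal smoke test of the style encoder (illustrative sketch only; the
+# config values below are assumptions chosen for a quick CPU check, not the
+# published Sat2Density / imaginaire training configs).
+if __name__ == '__main__':
+    _cfg = types.SimpleNamespace(
+        input_image_channels=3, num_filters=64, kernel_size=3,
+        style_dims=256, weight_norm_type='none')
+    _enc = StyleEncoder(_cfg)
+    # Six stride-2 convolutions reduce a 256x256 input to a 4x4 feature map,
+    # which is flattened and projected to the style dimensions.
+    _mu, _logvar, _z = _enc(torch.randn(2, 3, 256, 256))
+    print(_mu.shape, _logvar.shape, _z.shape)  # each: torch.Size([2, 256])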
diff --git a/imaginaire/generators/unit.py b/imaginaire/generators/unit.py
new file mode 100644
index 0000000000000000000000000000000000000000..c09f1b050d59de4940d49c7336ddd2f5928c55c6
--- /dev/null
+++ b/imaginaire/generators/unit.py
@@ -0,0 +1,312 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import warnings
+
+from torch import nn
+from torch.nn import Upsample as NearestUpsample
+
+from imaginaire.layers import Conv2dBlock, Res2dBlock
+
+
+class Generator(nn.Module):
+ r"""Improved UNIT generator.
+
+ Args:
+ gen_cfg (obj): Generator definition part of the yaml config file.
+ data_cfg (obj): Data definition part of the yaml config file.
+ """
+
+ def __init__(self, gen_cfg, data_cfg):
+ super().__init__()
+ self.autoencoder_a = AutoEncoder(**vars(gen_cfg))
+ self.autoencoder_b = AutoEncoder(**vars(gen_cfg))
+
+ def forward(self, data, image_recon=True, cycle_recon=True):
+ r"""UNIT forward function"""
+ images_a = data['images_a']
+ images_b = data['images_b']
+ net_G_output = dict()
+
+ # encode input images into latent code
+ content_a = self.autoencoder_a.content_encoder(images_a)
+ content_b = self.autoencoder_b.content_encoder(images_b)
+
+ # decode (within domain)
+ if image_recon:
+ images_aa = self.autoencoder_a.decoder(content_a)
+ images_bb = self.autoencoder_b.decoder(content_b)
+ net_G_output.update(dict(images_aa=images_aa, images_bb=images_bb))
+
+ # decode (cross domain)
+ images_ba = self.autoencoder_a.decoder(content_b)
+ images_ab = self.autoencoder_b.decoder(content_a)
+
+ # cycle reconstruction
+ if cycle_recon:
+ content_ba = self.autoencoder_a.content_encoder(images_ba)
+ content_ab = self.autoencoder_b.content_encoder(images_ab)
+ images_aba = self.autoencoder_a.decoder(content_ab)
+ images_bab = self.autoencoder_b.decoder(content_ba)
+ net_G_output.update(
+ dict(content_ba=content_ba, content_ab=content_ab,
+ images_aba=images_aba, images_bab=images_bab))
+
+ # required outputs
+ net_G_output.update(dict(content_a=content_a, content_b=content_b,
+ images_ba=images_ba, images_ab=images_ab))
+
+ return net_G_output
+
+ def inference(self, data, a2b=True):
+ r"""UNIT inference.
+
+ Args:
+ data (dict): Training data at the current iteration.
+ - images_a (tensor): Images from domain A.
+ - images_b (tensor): Images from domain B.
+ a2b (bool): If ``True``, translates images from domain A to B,
+ otherwise from B to A.
+ """
+ if a2b:
+ input_key = 'images_a'
+ content_encode = self.autoencoder_a.content_encoder
+ decode = self.autoencoder_b.decoder
+ else:
+ input_key = 'images_b'
+ content_encode = self.autoencoder_b.content_encoder
+ decode = self.autoencoder_a.decoder
+
+ content_images = data[input_key]
+ content = content_encode(content_images)
+ output_images = decode(content)
+ filename = '%s/%s' % (
+ data['key'][input_key]['sequence_name'][0],
+ data['key'][input_key]['filename'][0])
+ filenames = [filename]
+ return output_images, filenames
+
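+# Illustrative usage sketch (variable names are hypothetical; `data` is
+# assumed to follow the paired-domain layout described in the docstring
+# below, including the nested 'key' entries used for file naming):
+#     net_G = Generator(gen_cfg, data_cfg)
+#     images_ab, names = net_G.inference(data, a2b=True)   # domain A -> B
+#     images_ba, _ = net_G.inference(data, a2b=False)      # domain B -> A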
+
+class AutoEncoder(nn.Module):
+ r"""Improved UNIT autoencoder.
+
+ Args:
+ num_filters (int): Base filter numbers.
+ max_num_filters (int): Maximum number of filters in the encoder.
+ num_res_blocks (int): Number of residual blocks at the end of the
+ content encoder.
+ num_downsamples_content (int): Number of times we reduce
+ resolution by 2x2 for the content image.
+ num_image_channels (int): Number of input image channels.
+ content_norm_type (str): Type of activation normalization in the
+ content encoder.
+ decoder_norm_type (str): Type of activation normalization in the
+ decoder.
+ weight_norm_type (str): Type of weight normalization.
+ output_nonlinearity (str): Type of nonlinearity before final output,
+ ``'tanh'`` or ``'none'``.
+ pre_act (bool): If ``True``, uses pre-activation residual blocks.
+ apply_noise (bool): If ``True``, injects Gaussian noise in the decoder.
+ """
+
+ def __init__(self,
+ num_filters=64,
+ max_num_filters=256,
+ num_res_blocks=4,
+ num_downsamples_content=2,
+ num_image_channels=3,
+ content_norm_type='instance',
+ decoder_norm_type='instance',
+ weight_norm_type='',
+ output_nonlinearity='',
+ pre_act=False,
+ apply_noise=False,
+ **kwargs):
+ super().__init__()
+ for key in kwargs:
+ if key != 'type':
+ warnings.warn(
+ "Generator argument '{}' is not used.".format(key))
+ self.content_encoder = ContentEncoder(num_downsamples_content,
+ num_res_blocks,
+ num_image_channels,
+ num_filters,
+ max_num_filters,
+ 'reflect',
+ content_norm_type,
+ weight_norm_type,
+ 'relu',
+ pre_act)
+ self.decoder = Decoder(num_downsamples_content,
+ num_res_blocks,
+ self.content_encoder.output_dim,
+ num_image_channels,
+ 'reflect',
+ decoder_norm_type,
+ weight_norm_type,
+ 'relu',
+ output_nonlinearity,
+ pre_act,
+ apply_noise)
+
+ def forward(self, images):
+ r"""Reconstruct an image.
+
+ Args:
+ images (Tensor): Input images.
+ Returns:
+ images_recon (Tensor): Reconstructed images.
+ """
+ content = self.content_encoder(images)
+ images_recon = self.decoder(content)
+ return images_recon
+
+
+class ContentEncoder(nn.Module):
+ r"""Improved UNIT encoder. The network consists of:
+
+    - an input convolutional layer
+    - $(num_downsamples) downsampling convolutional blocks
+    - $(num_res_blocks) residual blocks.
+
+ Args:
+ num_downsamples (int): Number of times we reduce
+ resolution by 2x2.
+ num_res_blocks (int): Number of residual blocks at the end of the
+ content encoder.
+ num_image_channels (int): Number of input image channels.
+ num_filters (int): Base filter numbers.
+ max_num_filters (int): Maximum number of filters in the encoder.
+ padding_mode (string): Type of padding.
+ activation_norm_type (str): Type of activation normalization.
+ weight_norm_type (str): Type of weight normalization.
+ nonlinearity (str): Type of nonlinear activation function.
+ pre_act (bool): If ``True``, uses pre-activation residual blocks.
+ """
+
+ def __init__(self,
+ num_downsamples,
+ num_res_blocks,
+ num_image_channels,
+ num_filters,
+ max_num_filters,
+ padding_mode,
+ activation_norm_type,
+ weight_norm_type,
+ nonlinearity,
+ pre_act=False):
+ super().__init__()
+ conv_params = dict(padding_mode=padding_mode,
+ activation_norm_type=activation_norm_type,
+ weight_norm_type=weight_norm_type,
+ nonlinearity=nonlinearity)
+ # Whether or not it is safe to use inplace nonlinear activation.
+ if not pre_act or (activation_norm_type != '' and
+ activation_norm_type != 'none'):
+ conv_params['inplace_nonlinearity'] = True
+
+ # The order of operations in residual blocks.
+ order = 'pre_act' if pre_act else 'CNACNA'
+
+ model = []
+ model += [Conv2dBlock(num_image_channels, num_filters, 7, 1, 3,
+ **conv_params)]
+
+ # Downsampling blocks.
+ for i in range(num_downsamples):
+ num_filters_prev = num_filters
+ num_filters = min(num_filters * 2, max_num_filters)
+ model += [Conv2dBlock(num_filters_prev, num_filters, 4, 2, 1,
+ **conv_params)]
+
+ # Residual blocks.
+ for _ in range(num_res_blocks):
+ model += [Res2dBlock(num_filters, num_filters,
+ **conv_params,
+ order=order)]
+ self.model = nn.Sequential(*model)
+ self.output_dim = num_filters
+
+ def forward(self, x):
+ r"""
+
+ Args:
+ x (tensor): Input image.
+ """
+ return self.model(x)
+
+
+class Decoder(nn.Module):
+ r"""Improved UNIT decoder. The network consists of:
+
+ - $(num_res_blocks) residual blocks.
+    - $(num_upsamples) 2x upsampling layers, each followed by a convolutional block
+ - output layer.
+
+ Args:
+ num_upsamples (int): Number of times we increase resolution by 2x2.
+ num_res_blocks (int): Number of residual blocks.
+ num_filters (int): Base filter numbers.
+ num_image_channels (int): Number of input image channels.
+ padding_mode (string): Type of padding.
+ activation_norm_type (str): Type of activation normalization.
+ weight_norm_type (str): Type of weight normalization.
+ nonlinearity (str): Type of nonlinear activation function.
+ output_nonlinearity (str): Type of nonlinearity before final output,
+ ``'tanh'`` or ``'none'``.
+ pre_act (bool): If ``True``, uses pre-activation residual blocks.
+ apply_noise (bool): If ``True``, injects Gaussian noise.
+ """
+
+ def __init__(self,
+ num_upsamples,
+ num_res_blocks,
+ num_filters,
+ num_image_channels,
+ padding_mode,
+ activation_norm_type,
+ weight_norm_type,
+ nonlinearity,
+ output_nonlinearity,
+ pre_act=False,
+ apply_noise=False):
+ super().__init__()
+
+ conv_params = dict(padding_mode=padding_mode,
+ nonlinearity=nonlinearity,
+ inplace_nonlinearity=True,
+ apply_noise=apply_noise,
+ weight_norm_type=weight_norm_type,
+ activation_norm_type=activation_norm_type)
+
+ # The order of operations in residual blocks.
+ order = 'pre_act' if pre_act else 'CNACNA'
+
+ # Residual blocks.
+ self.decoder = nn.ModuleList()
+ for _ in range(num_res_blocks):
+ self.decoder += [Res2dBlock(num_filters, num_filters,
+ **conv_params,
+ order=order)]
+
+ # Convolutional blocks with upsampling.
+ for i in range(num_upsamples):
+ self.decoder += [NearestUpsample(scale_factor=2)]
+ self.decoder += [Conv2dBlock(num_filters, num_filters // 2,
+ 5, 1, 2, **conv_params)]
+ num_filters //= 2
+ self.decoder += [Conv2dBlock(num_filters, num_image_channels, 7, 1, 3,
+ nonlinearity=output_nonlinearity,
+ padding_mode=padding_mode)]
+
+ def forward(self, x):
+ r"""
+
+ Args:
+ x (tensor): Content embedding of the content image.
+ """
+ for block in self.decoder:
+ x = block(x)
+ return x
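+
+
+# Minimal reconstruction smoke test (illustrative sketch; uses the default
+# hyper-parameters only and is not a training configuration).
+if __name__ == '__main__':
+    import torch
+
+    _ae = AutoEncoder()
+    # Two 2x downsamplings in the content encoder are mirrored by two 2x
+    # upsamplings in the decoder, so the output matches the input size.
+    _recon = _ae(torch.randn(1, 3, 64, 64))
+    print(_recon.shape)  # expected: torch.Size([1, 3, 64, 64])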
diff --git a/imaginaire/generators/vid2vid.py b/imaginaire/generators/vid2vid.py
new file mode 100644
index 0000000000000000000000000000000000000000..78262debc9d2ea6106a0816b53c4c73c5a5a6053
--- /dev/null
+++ b/imaginaire/generators/vid2vid.py
@@ -0,0 +1,481 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+from functools import partial
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from imaginaire.generators.fs_vid2vid import LabelEmbedder
+from imaginaire.layers import Conv2dBlock, LinearBlock, Res2dBlock
+from imaginaire.model_utils.fs_vid2vid import (extract_valid_pose_labels,
+ resample)
+from imaginaire.utils.data import (get_paired_input_image_channel_number,
+ get_paired_input_label_channel_number)
+from imaginaire.utils.init_weight import weights_init
+
+
+class BaseNetwork(nn.Module):
+ r"""vid2vid generator."""
+
+ def __init__(self):
+ super(BaseNetwork, self).__init__()
+
+ def get_num_filters(self, num_downsamples):
+ r"""Get the number of filters at current layer.
+
+ Args:
+ num_downsamples (int) : How many downsamples at current layer.
+ Returns:
+ output (int) : Number of filters.
+ """
+ return min(self.max_num_filters,
+ self.num_filters * (2 ** num_downsamples))
+
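+# Worked example of get_num_filters() with the defaults used below
+# (num_filters=32, max_num_filters=1024): downsample levels 0..7 give
+# 32, 64, 128, 256, 512, 1024, 1024, 1024 filters, i.e. the width doubles
+# per level until it saturates at max_num_filters.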
+
+class Generator(BaseNetwork):
+ r"""vid2vid generator constructor.
+
+ Args:
+ gen_cfg (obj): Generator definition part of the yaml config file.
+ data_cfg (obj): Data definition part of the yaml config file.
+ """
+
+ def __init__(self, gen_cfg, data_cfg):
+ super().__init__()
+ self.gen_cfg = gen_cfg
+ self.data_cfg = data_cfg
+ self.num_frames_G = data_cfg.num_frames_G
+ # Number of residual blocks in generator.
+ self.num_layers = num_layers = getattr(gen_cfg, 'num_layers', 7)
+ # Number of downsamplings for previous frame.
+ self.num_downsamples_img = getattr(gen_cfg, 'num_downsamples_img', 4)
+ # Number of filters in the first layer.
+ self.num_filters = num_filters = getattr(gen_cfg, 'num_filters', 32)
+ self.max_num_filters = getattr(gen_cfg, 'max_num_filters', 1024)
+ self.kernel_size = kernel_size = getattr(gen_cfg, 'kernel_size', 3)
+ padding = kernel_size // 2
+
+ # For pose dataset.
+ self.is_pose_data = hasattr(data_cfg, 'for_pose_dataset')
+ if self.is_pose_data:
+ pose_cfg = data_cfg.for_pose_dataset
+ self.pose_type = getattr(pose_cfg, 'pose_type', 'both')
+ self.remove_face_labels = getattr(pose_cfg, 'remove_face_labels',
+ False)
+
+ # Input data params.
+        # Stored on self for get_cond_dims() when label embedding is disabled.
+        self.num_input_channels = num_input_channels = \
+            get_paired_input_label_channel_number(data_cfg)
+ num_img_channels = get_paired_input_image_channel_number(data_cfg)
+ aug_cfg = data_cfg.val.augmentations
+ if hasattr(aug_cfg, 'center_crop_h_w'):
+ crop_h_w = aug_cfg.center_crop_h_w
+ elif hasattr(aug_cfg, 'resize_h_w'):
+ crop_h_w = aug_cfg.resize_h_w
+ else:
+ raise ValueError('Need to specify output size.')
+ crop_h, crop_w = crop_h_w.split(',')
+ crop_h, crop_w = int(crop_h), int(crop_w)
+ # Spatial size at the bottle neck of generator.
+ self.sh = crop_h // (2 ** num_layers)
+ self.sw = crop_w // (2 ** num_layers)
+
+ # Noise vector dimension.
+ self.z_dim = getattr(gen_cfg, 'style_dims', 256)
+ self.use_segmap_as_input = \
+ getattr(gen_cfg, 'use_segmap_as_input', False)
+
+ # Label / image embedding network.
+ self.emb_cfg = emb_cfg = getattr(gen_cfg, 'embed', None)
+        self.use_embed = getattr(emb_cfg, 'use_embed', True)
+ self.num_downsamples_embed = getattr(emb_cfg, 'num_downsamples', 5)
+ if self.use_embed:
+ self.label_embedding = LabelEmbedder(emb_cfg, num_input_channels)
+
+ # Flow network.
+ self.flow_cfg = flow_cfg = gen_cfg.flow
+ # Use SPADE to combine warped and hallucinated frames instead of
+ # linear combination.
+ self.spade_combine = getattr(flow_cfg, 'multi_spade_combine', True)
+ # Number of layers to perform multi-spade combine.
+ self.num_multi_spade_layers = getattr(flow_cfg.multi_spade_combine,
+ 'num_layers', 3)
+ # At beginning of training, only train an image generator.
+ self.temporal_initialized = False
+ # Whether to output hallucinated frame (when training temporal network)
+ # for additional loss.
+ self.generate_raw_output = False
+
+ # Image generation network.
+ weight_norm_type = getattr(gen_cfg, 'weight_norm_type', 'spectral')
+ activation_norm_type = gen_cfg.activation_norm_type
+ activation_norm_params = gen_cfg.activation_norm_params
+ if self.use_embed and \
+ not hasattr(activation_norm_params, 'num_filters'):
+ activation_norm_params.num_filters = 0
+ nonlinearity = 'leakyrelu'
+
+ self.base_res_block = base_res_block = partial(
+ Res2dBlock, kernel_size=kernel_size, padding=padding,
+ weight_norm_type=weight_norm_type,
+ activation_norm_type=activation_norm_type,
+ activation_norm_params=activation_norm_params,
+ nonlinearity=nonlinearity, order='NACNAC')
+
+ # Upsampling residual blocks.
+ for i in range(num_layers, -1, -1):
+ activation_norm_params.cond_dims = self.get_cond_dims(i)
+            activation_norm_params.partial = \
+                self.get_partial(i) if hasattr(self, 'get_partial') else False
+ layer = base_res_block(self.get_num_filters(i + 1),
+ self.get_num_filters(i))
+ setattr(self, 'up_%d' % i, layer)
+
+ # Final conv layer.
+ self.conv_img = Conv2dBlock(num_filters, num_img_channels,
+ kernel_size, padding=padding,
+ nonlinearity=nonlinearity, order='AC')
+
+ num_filters = min(self.max_num_filters,
+ num_filters * (2 ** (self.num_layers + 1)))
+ if self.use_segmap_as_input:
+ self.fc = Conv2dBlock(num_input_channels, num_filters,
+ kernel_size=3, padding=1)
+ else:
+ self.fc = LinearBlock(self.z_dim, num_filters * self.sh * self.sw)
+
+ # Misc.
+ self.downsample = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)
+ self.upsample = partial(F.interpolate, scale_factor=2)
+ self.init_temporal_network()
+
+ def forward(self, data):
+ r"""vid2vid generator forward.
+
+ Args:
+ data (dict) : Dictionary of input data.
+ Returns:
+ output (dict) : Dictionary of output data.
+ """
+ label = data['label']
+ label_prev, img_prev = data['prev_labels'], data['prev_images']
+ is_first_frame = img_prev is None
+        z = data.get('z', None)
+ bs, _, h, w = label.size()
+
+ if self.is_pose_data:
+ label, label_prev = extract_valid_pose_labels(
+ [label, label_prev], self.pose_type, self.remove_face_labels)
+
+ # Get SPADE conditional maps by embedding current label input.
+ cond_maps_now = self.get_cond_maps(label, self.label_embedding)
+
+ # Input to the generator will either be noise/segmentation map (for
+ # first frame) or encoded previous frame (for subsequent frames).
+ if is_first_frame:
+ # First frame in the sequence, start from scratch.
+ if self.use_segmap_as_input:
+ x_img = F.interpolate(label, size=(self.sh, self.sw))
+ x_img = self.fc(x_img)
+ else:
+ if z is None:
+ z = torch.randn(bs, self.z_dim, dtype=label.dtype,
+ device=label.get_device()).fill_(0)
+ x_img = self.fc(z).view(bs, -1, self.sh, self.sw)
+
+ # Upsampling layers.
+ for i in range(self.num_layers, self.num_downsamples_img, -1):
+ j = min(self.num_downsamples_embed, i)
+ x_img = getattr(self, 'up_' + str(i))(x_img, *cond_maps_now[j])
+ x_img = self.upsample(x_img)
+ else:
+ # Not the first frame, will encode the previous frame and feed to
+ # the generator.
+ x_img = self.down_first(img_prev[:, -1])
+
+ # Get label embedding for the previous frame.
+ cond_maps_prev = self.get_cond_maps(label_prev[:, -1],
+ self.label_embedding)
+
+ # Downsampling layers.
+ for i in range(self.num_downsamples_img + 1):
+ j = min(self.num_downsamples_embed, i)
+ x_img = getattr(self, 'down_' + str(i))(x_img,
+ *cond_maps_prev[j])
+ if i != self.num_downsamples_img:
+ x_img = self.downsample(x_img)
+
+ # Resnet blocks.
+ j = min(self.num_downsamples_embed, self.num_downsamples_img + 1)
+ for i in range(self.num_res_blocks):
+ cond_maps = cond_maps_prev[j] if i < self.num_res_blocks // 2 \
+ else cond_maps_now[j]
+ x_img = getattr(self, 'res_' + str(i))(x_img, *cond_maps)
+
+ flow = mask = img_warp = None
+
+ num_frames_G = self.num_frames_G
+ # Whether to warp the previous frame or not.
+ warp_prev = self.temporal_initialized and not is_first_frame and \
+ label_prev.shape[1] == num_frames_G - 1
+ if warp_prev:
+ # Estimate flow & mask.
+ label_concat = torch.cat([label_prev.view(bs, -1, h, w),
+ label], dim=1)
+ img_prev_concat = img_prev.view(bs, -1, h, w)
+ flow, mask = self.flow_network_temp(label_concat, img_prev_concat)
+ img_warp = resample(img_prev[:, -1], flow)
+ if self.spade_combine:
+ # if using SPADE combine, integrate the warped image (and
+ # occlusion mask) into conditional inputs for SPADE.
+ img_embed = torch.cat([img_warp, mask], dim=1)
+ cond_maps_img = self.get_cond_maps(img_embed,
+ self.img_prev_embedding)
+ x_raw_img = None
+
+ # Main image generation branch.
+ for i in range(self.num_downsamples_img, -1, -1):
+ # Get SPADE conditional inputs.
+ j = min(i, self.num_downsamples_embed)
+ cond_maps = cond_maps_now[j]
+
+ # For raw output generation.
+ if self.generate_raw_output:
+ if i >= self.num_multi_spade_layers - 1:
+ x_raw_img = x_img
+ if i < self.num_multi_spade_layers:
+ x_raw_img = self.one_up_conv_layer(x_raw_img, cond_maps, i)
+
+ # For final output.
+ if warp_prev and i < self.num_multi_spade_layers:
+ cond_maps += cond_maps_img[j]
+ x_img = self.one_up_conv_layer(x_img, cond_maps, i)
+
+ # Final conv layer.
+ img_final = torch.tanh(self.conv_img(x_img))
+
+ img_raw = None
+ if self.spade_combine and self.generate_raw_output:
+ img_raw = torch.tanh(self.conv_img(x_raw_img))
+ if warp_prev and not self.spade_combine:
+ img_raw = img_final
+ img_final = img_final * mask + img_warp * (1 - mask)
+
+ output = dict()
+ output['fake_images'] = img_final
+ output['fake_flow_maps'] = flow
+ output['fake_occlusion_masks'] = mask
+ output['fake_raw_images'] = img_raw
+ output['warped_images'] = img_warp
+ return output
+
+ def one_up_conv_layer(self, x, encoded_label, i):
+ r"""One residual block layer in the main branch.
+
+ Args:
+ x (4D tensor) : Current feature map.
+ encoded_label (list of tensors) : Encoded input label maps.
+ i (int) : Layer index.
+ Returns:
+ x (4D tensor) : Output feature map.
+ """
+ layer = getattr(self, 'up_' + str(i))
+ x = layer(x, *encoded_label)
+ if i != 0:
+ x = self.upsample(x)
+ return x
+
+ def init_temporal_network(self, cfg_init=None):
+ r"""When starting training multiple frames, initialize the
+ downsampling network and flow network.
+
+ Args:
+ cfg_init (dict) : Weight initialization config.
+ """
+ # Number of image downsamplings for the previous frame.
+ num_downsamples_img = self.num_downsamples_img
+ # Number of residual blocks for the previous frame.
+ self.num_res_blocks = int(
+ np.ceil((self.num_layers - num_downsamples_img) / 2.0) * 2)
+
+ # First conv layer.
+ num_img_channels = get_paired_input_image_channel_number(self.data_cfg)
+ self.down_first = \
+ Conv2dBlock(num_img_channels,
+ self.num_filters, self.kernel_size,
+ padding=self.kernel_size // 2)
+ if cfg_init is not None:
+ self.down_first.apply(weights_init(cfg_init.type, cfg_init.gain))
+
+ # Downsampling residual blocks.
+ activation_norm_params = self.gen_cfg.activation_norm_params
+ for i in range(num_downsamples_img + 1):
+ activation_norm_params.cond_dims = self.get_cond_dims(i)
+ layer = self.base_res_block(self.get_num_filters(i),
+ self.get_num_filters(i + 1))
+ if cfg_init is not None:
+ layer.apply(weights_init(cfg_init.type, cfg_init.gain))
+ setattr(self, 'down_%d' % i, layer)
+
+ # Additional residual blocks.
+ res_ch = self.get_num_filters(num_downsamples_img + 1)
+ activation_norm_params.cond_dims = \
+ self.get_cond_dims(num_downsamples_img + 1)
+ for i in range(self.num_res_blocks):
+ layer = self.base_res_block(res_ch, res_ch)
+ if cfg_init is not None:
+ layer.apply(weights_init(cfg_init.type, cfg_init.gain))
+ setattr(self, 'res_%d' % i, layer)
+
+ # Flow network.
+ flow_cfg = self.flow_cfg
+ self.temporal_initialized = True
+ self.generate_raw_output = getattr(flow_cfg, 'generate_raw_output',
+ False) and self.spade_combine
+ self.flow_network_temp = FlowGenerator(flow_cfg, self.data_cfg)
+ if cfg_init is not None:
+ self.flow_network_temp.apply(weights_init(cfg_init.type,
+ cfg_init.gain))
+
+ self.spade_combine = getattr(flow_cfg, 'multi_spade_combine', True)
+ if self.spade_combine:
+ emb_cfg = flow_cfg.multi_spade_combine.embed
+ num_img_channels = get_paired_input_image_channel_number(
+ self.data_cfg)
+ self.img_prev_embedding = LabelEmbedder(emb_cfg,
+ num_img_channels + 1)
+ if cfg_init is not None:
+ self.img_prev_embedding.apply(weights_init(cfg_init.type,
+ cfg_init.gain))
+
+ def get_cond_dims(self, num_downs=0):
+ r"""Get the dimensions of conditional inputs.
+
+ Args:
+ num_downs (int) : How many downsamples at current layer.
+ Returns:
+ ch (list) : List of dimensions.
+ """
+ if not self.use_embed:
+ ch = [self.num_input_channels]
+ else:
+ num_filters = getattr(self.emb_cfg, 'num_filters', 32)
+ num_downs = min(num_downs, self.num_downsamples_embed)
+ ch = [min(self.max_num_filters, num_filters * (2 ** num_downs))]
+ if (num_downs < self.num_multi_spade_layers):
+ ch = ch * 2
+ return ch
+
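+    # Example with the default settings (embedding num_filters=32,
+    # num_downsamples_embed=5, num_multi_spade_layers=3, max_num_filters=1024):
+    # get_cond_dims(0) returns [32, 32], duplicated because the first SPADE
+    # layers also receive the warped-frame embedding, while get_cond_dims(5)
+    # returns [1024].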
+ def get_cond_maps(self, label, embedder):
+ r"""Get the conditional inputs.
+
+ Args:
+ label (4D tensor) : Input label tensor.
+ embedder (obj) : Embedding network.
+ Returns:
+ cond_maps (list) : List of conditional inputs.
+ """
+ if not self.use_embed:
+ return [label] * (self.num_layers + 1)
+ embedded_label = embedder(label)
+ cond_maps = [embedded_label]
+ cond_maps = [[m[i] for m in cond_maps] for i in
+ range(len(cond_maps[0]))]
+ return cond_maps
+
+
+class FlowGenerator(BaseNetwork):
+ r"""Flow generator constructor.
+
+ Args:
+ flow_cfg (obj): Flow definition part of the yaml config file.
+ data_cfg (obj): Data definition part of the yaml config file.
+ """
+
+ def __init__(self, flow_cfg, data_cfg):
+ super().__init__()
+ num_input_channels = get_paired_input_label_channel_number(data_cfg)
+ num_prev_img_channels = get_paired_input_image_channel_number(data_cfg)
+ num_frames = data_cfg.num_frames_G # Num. of input frames.
+
+ self.num_filters = num_filters = getattr(flow_cfg, 'num_filters', 32)
+ self.max_num_filters = getattr(flow_cfg, 'max_num_filters', 1024)
+ num_downsamples = getattr(flow_cfg, 'num_downsamples', 5)
+ kernel_size = getattr(flow_cfg, 'kernel_size', 3)
+ padding = kernel_size // 2
+ self.num_res_blocks = getattr(flow_cfg, 'num_res_blocks', 6)
+ # Multiplier on the flow output.
+ self.flow_output_multiplier = getattr(flow_cfg,
+ 'flow_output_multiplier', 20)
+
+ activation_norm_type = getattr(flow_cfg, 'activation_norm_type',
+ 'sync_batch')
+ weight_norm_type = getattr(flow_cfg, 'weight_norm_type', 'spectral')
+
+ base_conv_block = partial(Conv2dBlock, kernel_size=kernel_size,
+ padding=padding,
+ weight_norm_type=weight_norm_type,
+ activation_norm_type=activation_norm_type,
+ nonlinearity='leakyrelu')
+
+ # Will downsample the labels and prev frames separately, then combine.
+ down_lbl = [base_conv_block(num_input_channels * num_frames,
+ num_filters)]
+ down_img = [base_conv_block(num_prev_img_channels * (num_frames - 1),
+ num_filters)]
+ for i in range(num_downsamples):
+ down_lbl += [base_conv_block(self.get_num_filters(i),
+ self.get_num_filters(i + 1),
+ stride=2)]
+ down_img += [base_conv_block(self.get_num_filters(i),
+ self.get_num_filters(i + 1),
+ stride=2)]
+
+ # Resnet blocks.
+ res_flow = []
+ ch = self.get_num_filters(num_downsamples)
+ for i in range(self.num_res_blocks):
+ res_flow += [
+ Res2dBlock(ch, ch, kernel_size, padding=padding,
+ weight_norm_type=weight_norm_type,
+ activation_norm_type=activation_norm_type,
+ order='CNACN')]
+
+ # Upsample.
+ up_flow = []
+ for i in reversed(range(num_downsamples)):
+ up_flow += [nn.Upsample(scale_factor=2),
+ base_conv_block(self.get_num_filters(i + 1),
+ self.get_num_filters(i))]
+
+ conv_flow = [Conv2dBlock(num_filters, 2, kernel_size, padding=padding)]
+ conv_mask = [Conv2dBlock(num_filters, 1, kernel_size, padding=padding,
+ nonlinearity='sigmoid')]
+
+ self.down_lbl = nn.Sequential(*down_lbl)
+ self.down_img = nn.Sequential(*down_img)
+ self.res_flow = nn.Sequential(*res_flow)
+ self.up_flow = nn.Sequential(*up_flow)
+ self.conv_flow = nn.Sequential(*conv_flow)
+ self.conv_mask = nn.Sequential(*conv_mask)
+
+ def forward(self, label, img_prev):
+ r"""Flow generator forward.
+
+ Args:
+ label (4D tensor) : Input label tensor.
+ img_prev (4D tensor) : Previously generated image tensors.
+ Returns:
+ (tuple):
+ - flow (4D tensor) : Generated flow map.
+ - mask (4D tensor) : Generated occlusion mask.
+ """
+ downsample = self.down_lbl(label) + self.down_img(img_prev)
+ res = self.res_flow(downsample)
+ flow_feat = self.up_flow(res)
+ flow = self.conv_flow(flow_feat) * self.flow_output_multiplier
+ mask = self.conv_mask(flow_feat)
+ return flow, mask
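+
+
+# Shape summary (illustrative, assuming H and W are divisible by
+# 2 ** num_downsamples): with num_frames_G = T, the flow network consumes
+# labels stacked as (N, C_label * T, H, W) and previous frames stacked as
+# (N, C_img * (T - 1), H, W), and returns a flow field of shape (N, 2, H, W)
+# together with a sigmoid occlusion mask of shape (N, 1, H, W).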
diff --git a/imaginaire/generators/wc_vid2vid.py b/imaginaire/generators/wc_vid2vid.py
new file mode 100644
index 0000000000000000000000000000000000000000..4d8139a3c05140a567ef202eccd1d9209ba681b4
--- /dev/null
+++ b/imaginaire/generators/wc_vid2vid.py
@@ -0,0 +1,354 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torchvision import transforms
+
+from imaginaire.config import Config
+from imaginaire.generators.vid2vid import Generator as Vid2VidGenerator
+from imaginaire.model_utils.fs_vid2vid import resample
+from imaginaire.model_utils.wc_vid2vid.render import SplatRenderer
+from imaginaire.utils.trainer import (get_model_optimizer_and_scheduler,
+ get_trainer)
+from imaginaire.utils.visualization import tensor2im
+
+
+class Generator(Vid2VidGenerator):
+ r"""world consistent vid2vid generator constructor.
+
+ Args:
+ gen_cfg (obj): Generator definition part of the yaml config file.
+ data_cfg (obj): Data definition part of the yaml config file
+ """
+
+ def __init__(self, gen_cfg, data_cfg):
+ # Guidance options.
+ self.guidance_cfg = gen_cfg.guidance
+ self.guidance_only_with_flow = getattr(
+ self.guidance_cfg, 'only_with_flow', False)
+ self.guidance_partial_conv = getattr(
+ self.guidance_cfg, 'partial_conv', False)
+
+ # Splatter for guidance.
+ self.renderer = SplatRenderer()
+ self.reset_renderer()
+
+ # Single image model.
+ self.single_image_model = None
+
+ # Initialize the rest same as vid2vid.
+ super().__init__(gen_cfg, data_cfg)
+
+ def _init_single_image_model(self, load_weights=True):
+ r"""Load single image model, if any."""
+ if self.single_image_model is None and \
+ hasattr(self.gen_cfg, 'single_image_model'):
+ print('Using single image model...')
+ single_image_cfg = Config(self.gen_cfg.single_image_model.config)
+
+ # Init model.
+ net_G, net_D, opt_G, opt_D, sch_G, sch_D = \
+ get_model_optimizer_and_scheduler(single_image_cfg)
+
+ # Init trainer and load checkpoint.
+ trainer = get_trainer(single_image_cfg, net_G, net_D,
+ opt_G, opt_D,
+ sch_G, sch_D,
+ None, None)
+ if load_weights:
+ print('Loading single image model checkpoint')
+ single_image_ckpt = self.gen_cfg.single_image_model.checkpoint
+ trainer.load_checkpoint(single_image_cfg, single_image_ckpt)
+ print('Loaded single image model checkpoint')
+
+ self.single_image_model = net_G.module
+ self.single_image_model_z = None
+
+ def reset_renderer(self, is_flipped_input=False):
+ r"""Reset the renderer.
+ Args:
+ is_flipped_input (bool): Is the input sequence left-right flipped?
+ """
+ self.renderer.reset()
+ self.is_flipped_input = is_flipped_input
+ self.renderer_num_forwards = 0
+ self.single_image_model_z = None
+
+ def renderer_update_point_cloud(self, image, point_info):
+ r"""Update the renderer's color dictionary."""
+ if point_info is None or len(point_info) == 0:
+ return
+ # print('Updating the renderer.')
+ _, _, h, w = image.size()
+
+ # Renderer expects (h, w, c) [0-255] RGB image.
+ if isinstance(image, torch.Tensor):
+ image = tensor2im(image.detach())[0]
+
+ # Flip this image to correspond to SfM camera pose.
+ if self.is_flipped_input:
+ image = np.fliplr(image).copy()
+
+ self.renderer.update_point_cloud(image, point_info)
+ self.renderer_num_forwards += 1
+
+ def get_guidance_images_and_masks(self, unprojection):
+ r"""Do stuff."""
+
+ resolution = 'w1024xh512'
+ point_info = unprojection[resolution]
+
+ w, h = resolution.split('x')
+ w, h = int(w[1:]), int(h[1:])
+
+ # This returns guidance image in [0-255] RGB.
+ # We will convert it into Tensor repr. below.
+ guidance_image, guidance_mask = self.renderer.render_image(
+ point_info, w, h, return_mask=True)
+
+ # If mask is None, there is no guidance.
+ # print(np.sum(guidance_mask), guidance_mask.size)
+ # if np.sum(guidance_mask) == 0:
+ # return None, point_info
+
+ # Flip guidance image and guidance mask if needed.
+ if self.is_flipped_input:
+ guidance_image = np.fliplr(guidance_image).copy()
+ guidance_mask = np.fliplr(guidance_mask).copy()
+
+ # Go from (h, w, c) to (1, c, h, w).
+ # Convert guidance image to Tensor.
+ guidance_image = (transforms.ToTensor()(guidance_image) - 0.5) * 2
+ guidance_mask = transforms.ToTensor()(guidance_mask)
+ guidance = torch.cat((guidance_image, guidance_mask), dim=0)
+ guidance = guidance.unsqueeze(0).cuda()
+
+ # Save guidance at all resolutions.
+ guidance_images_and_masks = guidance
+
+ return guidance_images_and_masks, point_info
+
+ def forward(self, data):
+ r"""vid2vid generator forward.
+ Args:
+ data (dict) : Dictionary of input data.
+ Returns:
+ output (dict) : Dictionary of output data.
+ """
+ self._init_single_image_model()
+
+ label = data['label']
+ unprojection = data['unprojection']
+ label_prev, img_prev = data['prev_labels'], data['prev_images']
+ is_first_frame = img_prev is None
+        z = data.get('z', None)
+ bs, _, h, w = label.size()
+
+ # Whether to warp the previous frame or not.
+ flow = mask = img_warp = None
+ warp_prev = self.temporal_initialized and not is_first_frame and \
+ label_prev.shape[1] == self.num_frames_G - 1
+
+ # Get guidance images and masks.
+ guidance_images_and_masks, point_info = None, None
+ if unprojection is not None:
+ guidance_images_and_masks, point_info = \
+ self.get_guidance_images_and_masks(unprojection)
+
+ # Get SPADE conditional maps by embedding current label input.
+ cond_maps_now = self.get_cond_maps(label, self.label_embedding)
+
+ # Use single image model, if flow features are not available.
+ # Guidance features are used whenever flow features are available.
+ if self.single_image_model is not None and not warp_prev:
+ # Get z vector for single image model.
+ if self.single_image_model_z is None:
+ bs = data['label'].size(0)
+ z = torch.randn(bs, self.single_image_model.style_dims,
+ dtype=torch.float32).cuda()
+ if data['label'].dtype == torch.float16:
+ z = z.half()
+ self.single_image_model_z = z
+
+ # Get output image.
+ data['z'] = self.single_image_model_z
+ self.single_image_model.eval()
+ with torch.no_grad():
+ output = self.single_image_model.spade_generator(data)
+ img_final = output['fake_images'].detach()
+ fake_images_source = 'pretrained'
+ else:
+ # Input to the generator will either be noise/segmentation map (for
+ # first frame) or encoded previous frame (for subsequent frames).
+ if is_first_frame:
+ # First frame in the sequence, start from scratch.
+ if self.use_segmap_as_input:
+ x_img = F.interpolate(label, size=(self.sh, self.sw))
+ x_img = self.fc(x_img)
+ else:
+ if z is None:
+ z = torch.randn(bs, self.z_dim, dtype=label.dtype,
+ device=label.get_device()).fill_(0)
+ x_img = self.fc(z).view(bs, -1, self.sh, self.sw)
+
+ # Upsampling layers.
+ for i in range(self.num_layers, self.num_downsamples_img, -1):
+ j = min(self.num_downsamples_embed, i)
+                    x_img = getattr(self, 'up_' + str(i))(
+                        x_img, *cond_maps_now[j])
+ x_img = self.upsample(x_img)
+ else:
+ # Not the first frame, will encode the previous frame and feed
+ # to the generator.
+ x_img = self.down_first(img_prev[:, -1])
+
+ # Get label embedding for the previous frame.
+ cond_maps_prev = self.get_cond_maps(label_prev[:, -1],
+ self.label_embedding)
+
+ # Downsampling layers.
+ for i in range(self.num_downsamples_img + 1):
+ j = min(self.num_downsamples_embed, i)
+ x_img = getattr(self, 'down_' + str(i))(x_img,
+ *cond_maps_prev[j])
+ if i != self.num_downsamples_img:
+ x_img = self.downsample(x_img)
+
+ # Resnet blocks.
+ j = min(self.num_downsamples_embed,
+ self.num_downsamples_img + 1)
+ for i in range(self.num_res_blocks):
+ cond_maps = cond_maps_prev[j] if \
+ i < self.num_res_blocks // 2 else cond_maps_now[j]
+ x_img = getattr(self, 'res_' + str(i))(x_img, *cond_maps)
+
+ # Optical flow warped image features.
+ if warp_prev:
+ # Estimate flow & mask.
+ label_concat = torch.cat([label_prev.view(bs, -1, h, w),
+ label], dim=1)
+ img_prev_concat = img_prev.view(bs, -1, h, w)
+ flow, mask = self.flow_network_temp(
+ label_concat, img_prev_concat)
+ img_warp = resample(img_prev[:, -1], flow)
+ if self.spade_combine:
+ # if using SPADE combine, integrate the warped image (and
+ # occlusion mask) into conditional inputs for SPADE.
+ img_embed = torch.cat([img_warp, mask], dim=1)
+ cond_maps_img = self.get_cond_maps(img_embed,
+ self.img_prev_embedding)
+ x_raw_img = None
+
+ # Main image generation branch.
+ for i in range(self.num_downsamples_img, -1, -1):
+ # Get SPADE conditional inputs.
+ j = min(i, self.num_downsamples_embed)
+ cond_maps = cond_maps_now[j]
+
+ # For raw output generation.
+ if self.generate_raw_output:
+ if i >= self.num_multi_spade_layers - 1:
+ x_raw_img = x_img
+ if i < self.num_multi_spade_layers:
+ x_raw_img = self.one_up_conv_layer(
+ x_raw_img, cond_maps, i)
+
+ # Add flow and guidance features.
+ if warp_prev:
+ if i < self.num_multi_spade_layers:
+ # Add flow.
+ cond_maps += cond_maps_img[j]
+ # Add guidance.
+ if guidance_images_and_masks is not None:
+ cond_maps += [guidance_images_and_masks]
+ elif not self.guidance_only_with_flow:
+ # Add guidance if it is to be applied to every layer.
+ if guidance_images_and_masks is not None:
+ cond_maps += [guidance_images_and_masks]
+
+ x_img = self.one_up_conv_layer(x_img, cond_maps, i)
+
+ # Final conv layer.
+ img_final = torch.tanh(self.conv_img(x_img))
+ fake_images_source = 'in_training'
+
+ # Update the point cloud color dict of renderer.
+ self.renderer_update_point_cloud(img_final, point_info)
+
+ output = dict()
+ output['fake_images'] = img_final
+ output['fake_flow_maps'] = flow
+ output['fake_occlusion_masks'] = mask
+ output['fake_raw_images'] = None
+ output['warped_images'] = img_warp
+ output['guidance_images_and_masks'] = guidance_images_and_masks
+ output['fake_images_source'] = fake_images_source
+ return output
+
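+    # Note on the three generation paths in forward() above: when flow
+    # features are unavailable (first frames, or before the temporal network
+    # is initialized) and a single-image model is configured, the frozen
+    # pretrained SPADE generator produces the frame; otherwise the first
+    # frame starts from noise or the segmentation map, and later frames
+    # encode the previous output and fuse flow-warped and guidance features
+    # through the multi-SPADE layers.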
+ def get_cond_dims(self, num_downs=0):
+ r"""Get the dimensions of conditional inputs.
+ Args:
+ num_downs (int) : How many downsamples at current layer.
+ Returns:
+ ch (list) : List of dimensions.
+ """
+ if not self.use_embed:
+ ch = [self.num_input_channels]
+ else:
+ num_filters = getattr(self.emb_cfg, 'num_filters', 32)
+ num_downs = min(num_downs, self.num_downsamples_embed)
+ ch = [min(self.max_num_filters, num_filters * (2 ** num_downs))]
+ if (num_downs < self.num_multi_spade_layers):
+ ch = ch * 2
+ # Also add guidance (RGB + mask = 4 channels, or 3 if partial).
+ if self.guidance_partial_conv:
+ ch.append(3)
+ else:
+ ch.append(4)
+ elif not self.guidance_only_with_flow:
+ if self.guidance_partial_conv:
+ ch.append(3)
+ else:
+ ch.append(4)
+ return ch
+
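+    # Example (defaults as in vid2vid, guidance_partial_conv=False): a layer
+    # that also receives flow features gets cond dims such as [32, 32, 4],
+    # i.e. label embedding, warped-frame embedding, and 4-channel RGB+mask
+    # guidance; with partial convolutions the guidance entry is 3, since the
+    # mask is presumably consumed by the partial conv rather than concatenated.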
+ def get_partial(self, num_downs=0):
+ r"""Get if convs should be partial or not.
+ Args:
+ num_downs (int) : How many downsamples at current layer.
+ Returns:
+ partial (list) : List of boolean partial or not.
+ """
+ partial = [False]
+ if (num_downs < self.num_multi_spade_layers):
+ partial = partial * 2
+ # Also add guidance (RGB + mask = 4 channels, or 3 if partial).
+ if self.guidance_partial_conv:
+ partial.append(True)
+ else:
+ partial.append(False)
+ elif not self.guidance_only_with_flow:
+ if self.guidance_partial_conv:
+ partial.append(True)
+ else:
+ partial.append(False)
+ return partial
+
+ def get_cond_maps(self, label, embedder):
+ r"""Get the conditional inputs.
+ Args:
+ label (4D tensor) : Input label tensor.
+ embedder (obj) : Embedding network.
+ Returns:
+ cond_maps (list) : List of conditional inputs.
+ """
+ if not self.use_embed:
+ return [label] * (self.num_layers + 1)
+ embedded_label = embedder(label)
+ cond_maps = [embedded_label]
+ cond_maps = [[m[i] for m in cond_maps] for i in
+ range(len(cond_maps[0]))]
+ return cond_maps
diff --git a/imaginaire/layers/__init__.py b/imaginaire/layers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9e3f93c154678b630de93ab3d6b199204c4fd8fb
--- /dev/null
+++ b/imaginaire/layers/__init__.py
@@ -0,0 +1,27 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+from .conv import LinearBlock, Conv1dBlock, Conv2dBlock, Conv3dBlock, \
+ HyperConv2dBlock, MultiOutConv2dBlock, \
+ PartialConv2dBlock, PartialConv3dBlock
+from .residual import ResLinearBlock, Res1dBlock, Res2dBlock, Res3dBlock, \
+ HyperRes2dBlock, MultiOutRes2dBlock, UpRes2dBlock, DownRes2dBlock, \
+ PartialRes2dBlock, PartialRes3dBlock
+from .non_local import NonLocal2dBlock
+
+__all__ = ['Conv1dBlock', 'Conv2dBlock', 'Conv3dBlock', 'LinearBlock',
+ 'HyperConv2dBlock', 'MultiOutConv2dBlock',
+ 'PartialConv2dBlock', 'PartialConv3dBlock',
+ 'Res1dBlock', 'Res2dBlock', 'Res3dBlock',
+ 'UpRes2dBlock', 'DownRes2dBlock',
+ 'ResLinearBlock', 'HyperRes2dBlock', 'MultiOutRes2dBlock',
+ 'PartialRes2dBlock', 'PartialRes3dBlock',
+ 'NonLocal2dBlock']
+
+try:
+ from .repvgg import RepVGG1dBlock, RepVGG2dBlock, RepVGG3dBlock
+ from .attn import MultiheadAttention
+ __all__.extend(['RepVGG1dBlock', 'RepVGG2dBlock', 'RepVGG3dBlock'])
+except: # noqa
+ pass
diff --git a/imaginaire/layers/__pycache__/__init__.cpython-38.pyc b/imaginaire/layers/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..528061b6276196a6fdf6c254450dfc5f9f9f6a89
Binary files /dev/null and b/imaginaire/layers/__pycache__/__init__.cpython-38.pyc differ
diff --git a/imaginaire/layers/__pycache__/activation_norm.cpython-38.pyc b/imaginaire/layers/__pycache__/activation_norm.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..754bbd93d5840bc721eaaede366f15c5903d286b
Binary files /dev/null and b/imaginaire/layers/__pycache__/activation_norm.cpython-38.pyc differ
diff --git a/imaginaire/layers/__pycache__/conv.cpython-38.pyc b/imaginaire/layers/__pycache__/conv.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..da61b20f0c5a954d150297691572cfcae2b75bee
Binary files /dev/null and b/imaginaire/layers/__pycache__/conv.cpython-38.pyc differ
diff --git a/imaginaire/layers/__pycache__/misc.cpython-38.pyc b/imaginaire/layers/__pycache__/misc.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e0cf730aee80241b8f869fd876e4e43102a25790
Binary files /dev/null and b/imaginaire/layers/__pycache__/misc.cpython-38.pyc differ
diff --git a/imaginaire/layers/__pycache__/non_local.cpython-38.pyc b/imaginaire/layers/__pycache__/non_local.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9333890d786503983cdaf4be1fb7458fbc215ea0
Binary files /dev/null and b/imaginaire/layers/__pycache__/non_local.cpython-38.pyc differ
diff --git a/imaginaire/layers/__pycache__/nonlinearity.cpython-38.pyc b/imaginaire/layers/__pycache__/nonlinearity.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4a3a57d74453c76df61eaed8c0a4294479611dc2
Binary files /dev/null and b/imaginaire/layers/__pycache__/nonlinearity.cpython-38.pyc differ
diff --git a/imaginaire/layers/__pycache__/residual.cpython-38.pyc b/imaginaire/layers/__pycache__/residual.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a64ea7db7e342783ae70a6348e6278f3fc030fc5
Binary files /dev/null and b/imaginaire/layers/__pycache__/residual.cpython-38.pyc differ
diff --git a/imaginaire/layers/__pycache__/weight_norm.cpython-38.pyc b/imaginaire/layers/__pycache__/weight_norm.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c4fffd5e3452b5dbedacfeada43c7c969a7026b7
Binary files /dev/null and b/imaginaire/layers/__pycache__/weight_norm.cpython-38.pyc differ
diff --git a/imaginaire/layers/activation_norm.py b/imaginaire/layers/activation_norm.py
new file mode 100644
index 0000000000000000000000000000000000000000..709a928eed76f456c6284f575ba47f2dff581b39
--- /dev/null
+++ b/imaginaire/layers/activation_norm.py
@@ -0,0 +1,629 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+# flake8: noqa E722
+from types import SimpleNamespace
+
+import torch
+
+try:
+ from torch.nn import SyncBatchNorm
+except ImportError:
+ from torch.nn import BatchNorm2d as SyncBatchNorm
+from torch import nn
+from torch.nn import functional as F
+from .conv import LinearBlock, Conv2dBlock, HyperConv2d, PartialConv2dBlock
+from .misc import PartialSequential, ApplyNoise
+
+
+class AdaptiveNorm(nn.Module):
+ r"""Adaptive normalization layer. The layer first normalizes the input, then
+ performs an affine transformation using parameters computed from the
+ conditional inputs.
+
+ Args:
+ num_features (int): Number of channels in the input tensor.
+ cond_dims (int): Number of channels in the conditional inputs.
+ weight_norm_type (str): Type of weight normalization.
+ ``'none'``, ``'spectral'``, ``'weight'``, or ``'weight_demod'``.
+ projection (bool): If ``True``, project the conditional input to gamma
+ and beta using a fully connected layer, otherwise directly use
+ the conditional input as gamma and beta.
+        projection_bias (bool): If ``True``, use bias in the fully connected
+            projection layer.
+        separate_projection (bool): If ``True``, use two separate layers to
+            compute gamma and beta; otherwise use a single layer for both.
+            This only matters when a weight norm is applied to this layer.
+ input_dim (int): Number of dimensions of the input tensor.
+ activation_norm_type (str):
+ Type of activation normalization.
+ ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
+ ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``,
+ ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
+ activation_norm_params (obj, optional, default=None):
+ Parameters of activation normalization.
+ If not ``None``, ``activation_norm_params.__dict__`` will be used as
+ keyword arguments when initializing activation normalization.
+ """
+
+ def __init__(self, num_features, cond_dims, weight_norm_type='',
+ projection=True,
+ projection_bias=True,
+ separate_projection=False,
+ input_dim=2,
+ activation_norm_type='instance',
+ activation_norm_params=None,
+ apply_noise=False,
+ add_bias=True,
+ input_scale=1.0,
+ init_gain=1.0):
+ super().__init__()
+ if activation_norm_params is None:
+ activation_norm_params = SimpleNamespace(affine=False)
+ self.norm = get_activation_norm_layer(num_features,
+ activation_norm_type,
+ input_dim,
+ **vars(activation_norm_params))
+ if apply_noise:
+ self.noise_layer = ApplyNoise()
+ else:
+ self.noise_layer = None
+
+ if projection:
+ if separate_projection:
+ self.fc_gamma = \
+ LinearBlock(cond_dims, num_features,
+ weight_norm_type=weight_norm_type,
+ bias=projection_bias)
+ self.fc_beta = \
+ LinearBlock(cond_dims, num_features,
+ weight_norm_type=weight_norm_type,
+ bias=projection_bias)
+ else:
+ self.fc = LinearBlock(cond_dims, num_features * 2,
+ weight_norm_type=weight_norm_type,
+ bias=projection_bias)
+
+ self.projection = projection
+ self.separate_projection = separate_projection
+ self.input_scale = input_scale
+ self.add_bias = add_bias
+ self.conditional = True
+ self.init_gain = init_gain
+
+ def forward(self, x, y, noise=None, **_kwargs):
+ r"""Adaptive Normalization forward.
+
+ Args:
+ x (N x C1 x * tensor): Input tensor.
+ y (N x C2 tensor): Conditional information.
+ Returns:
+ out (N x C1 x * tensor): Output tensor.
+ """
+ y = y * self.input_scale
+ if self.projection:
+ if self.separate_projection:
+ gamma = self.fc_gamma(y)
+ beta = self.fc_beta(y)
+ for _ in range(x.dim() - gamma.dim()):
+ gamma = gamma.unsqueeze(-1)
+ beta = beta.unsqueeze(-1)
+ else:
+ y = self.fc(y)
+ for _ in range(x.dim() - y.dim()):
+ y = y.unsqueeze(-1)
+ gamma, beta = y.chunk(2, 1)
+ else:
+ for _ in range(x.dim() - y.dim()):
+ y = y.unsqueeze(-1)
+ gamma, beta = y.chunk(2, 1)
+ if self.norm is not None:
+ x = self.norm(x)
+ if self.noise_layer is not None:
+ x = self.noise_layer(x, noise=noise)
+ if self.add_bias:
+ x = torch.addcmul(beta, x, 1 + gamma)
+ return x
+ else:
+ return x * (1 + gamma), beta.squeeze(3).squeeze(2)
+
+
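+# Minimal usage sketch for AdaptiveNorm (illustrative only): the helper name,
+# batch size, channel counts and spatial size below are assumptions, not taken
+# from the original sources. It shows the default path: instance-normalize x,
+# then scale/shift it with gamma/beta projected from the conditioning code y.
+def _example_adaptive_norm():
+    norm = AdaptiveNorm(num_features=256, cond_dims=128)
+    x = torch.randn(4, 256, 32, 32)  # N x C1 x H x W feature map
+    y = torch.randn(4, 128)          # N x C2 conditioning code
+    out = norm(x, y)
+    assert out.shape == x.shape
+
+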
+class SpatiallyAdaptiveNorm(nn.Module):
+ r"""Spatially Adaptive Normalization (SPADE) initialization.
+
+ Args:
+ num_features (int) : Number of channels in the input tensor.
+        cond_dims (int or list of int): Number(s) of channels in the
+            conditional inputs.
+ num_filters (int): Number of filters in SPADE.
+ kernel_size (int): Kernel size of the convolutional filters in
+ the SPADE layer.
+ weight_norm_type (str): Type of weight normalization.
+ ``'none'``, ``'spectral'``, or ``'weight'``.
+        separate_projection (bool): If ``True``, use two separate layers to
+            compute gamma and beta; otherwise use a single layer for both.
+            This only matters when a weight norm is applied to this layer.
+ activation_norm_type (str):
+ Type of activation normalization.
+ ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
+ ``'layer'``, ``'layer_2d'``, ``'group'``.
+ activation_norm_params (obj, optional, default=None):
+ Parameters of activation normalization.
+ If not ``None``, ``activation_norm_params.__dict__`` will be used as
+ keyword arguments when initializing activation normalization.
+ """
+
+ def __init__(self,
+ num_features,
+ cond_dims,
+ num_filters=128,
+ kernel_size=3,
+ weight_norm_type='',
+ separate_projection=False,
+ activation_norm_type='sync_batch',
+ activation_norm_params=None,
+ bias_only=False,
+ partial=False,
+ interpolation='nearest'):
+ super().__init__()
+ if activation_norm_params is None:
+ activation_norm_params = SimpleNamespace(affine=False)
+ padding = kernel_size // 2
+ self.separate_projection = separate_projection
+ self.mlps = nn.ModuleList()
+ self.gammas = nn.ModuleList()
+ self.betas = nn.ModuleList()
+ self.bias_only = bias_only
+ self.interpolation = interpolation
+
+ # Make cond_dims a list.
+ if type(cond_dims) != list:
+ cond_dims = [cond_dims]
+
+ # Make num_filters a list.
+ if not isinstance(num_filters, list):
+ num_filters = [num_filters] * len(cond_dims)
+ else:
+ assert len(num_filters) >= len(cond_dims)
+
+ # Make partial a list.
+ if not isinstance(partial, list):
+ partial = [partial] * len(cond_dims)
+ else:
+ assert len(partial) >= len(cond_dims)
+
+ for i, cond_dim in enumerate(cond_dims):
+ mlp = []
+ conv_block = PartialConv2dBlock if partial[i] else Conv2dBlock
+ sequential = PartialSequential if partial[i] else nn.Sequential
+
+ if num_filters[i] > 0:
+ mlp += [conv_block(cond_dim,
+ num_filters[i],
+ kernel_size,
+ padding=padding,
+ weight_norm_type=weight_norm_type,
+ nonlinearity='relu')]
+ mlp_ch = cond_dim if num_filters[i] == 0 else num_filters[i]
+
+ if self.separate_projection:
+ if partial[i]:
+ raise NotImplementedError(
+ 'Separate projection not yet implemented for ' +
+ 'partial conv')
+ self.mlps.append(nn.Sequential(*mlp))
+ self.gammas.append(
+ conv_block(mlp_ch, num_features,
+ kernel_size,
+ padding=padding,
+ weight_norm_type=weight_norm_type))
+ self.betas.append(
+ conv_block(mlp_ch, num_features,
+ kernel_size,
+ padding=padding,
+ weight_norm_type=weight_norm_type))
+ else:
+ mlp += [conv_block(mlp_ch, num_features * 2, kernel_size,
+ padding=padding,
+ weight_norm_type=weight_norm_type)]
+ self.mlps.append(sequential(*mlp))
+
+ self.norm = get_activation_norm_layer(num_features,
+ activation_norm_type,
+ 2,
+ **vars(activation_norm_params))
+ self.conditional = True
+
+ def forward(self, x, *cond_inputs, **_kwargs):
+ r"""Spatially Adaptive Normalization (SPADE) forward.
+
+ Args:
+ x (N x C1 x H x W tensor) : Input tensor.
+ cond_inputs (list of tensors) : Conditional maps for SPADE.
+ Returns:
+ output (4D tensor) : Output tensor.
+ """
+ output = self.norm(x) if self.norm is not None else x
+ for i in range(len(cond_inputs)):
+ if cond_inputs[i] is None:
+ continue
+ label_map = F.interpolate(cond_inputs[i], size=x.size()[2:], mode=self.interpolation)
+ if self.separate_projection:
+ hidden = self.mlps[i](label_map)
+ gamma = self.gammas[i](hidden)
+ beta = self.betas[i](hidden)
+ else:
+ affine_params = self.mlps[i](label_map)
+ gamma, beta = affine_params.chunk(2, dim=1)
+ if self.bias_only:
+ output = output + beta
+ else:
+ output = output * (1 + gamma) + beta
+ return output
+
+
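+# Minimal usage sketch for SpatiallyAdaptiveNorm (illustrative only): the
+# helper name, shapes and hyper-parameters are assumptions. 'instance' is used
+# instead of the default 'sync_batch' so the sketch runs without a distributed
+# process group; the conditioning map is resized to x internally.
+def _example_spatially_adaptive_norm():
+    spade = SpatiallyAdaptiveNorm(num_features=64, cond_dims=3,
+                                  num_filters=32, kernel_size=3,
+                                  activation_norm_type='instance')
+    x = torch.randn(2, 64, 32, 32)    # features to be normalized
+    seg = torch.randn(2, 3, 64, 64)   # spatial conditioning map
+    out = spade(x, seg)
+    assert out.shape == x.shape
+
+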
+class DualAdaptiveNorm(nn.Module):
+ def __init__(self,
+ num_features,
+ cond_dims,
+ projection_bias=True,
+ weight_norm_type='',
+ activation_norm_type='instance',
+ activation_norm_params=None,
+ apply_noise=False,
+ bias_only=False,
+ init_gain=1.0,
+ fc_scale=None,
+ is_spatial=None):
+ super().__init__()
+ if activation_norm_params is None:
+ activation_norm_params = SimpleNamespace(affine=False)
+ self.mlps = nn.ModuleList()
+ self.gammas = nn.ModuleList()
+ self.betas = nn.ModuleList()
+ self.bias_only = bias_only
+
+ # Make cond_dims a list.
+ if type(cond_dims) != list:
+ cond_dims = [cond_dims]
+
+ if is_spatial is None:
+ is_spatial = [False for _ in range(len(cond_dims))]
+ self.is_spatial = is_spatial
+
+ for cond_dim, this_is_spatial in zip(cond_dims, is_spatial):
+ kwargs = dict(weight_norm_type=weight_norm_type,
+ bias=projection_bias,
+ init_gain=init_gain,
+ output_scale=fc_scale)
+ if this_is_spatial:
+ self.gammas.append(Conv2dBlock(cond_dim, num_features, 1, 1, 0, **kwargs))
+ self.betas.append(Conv2dBlock(cond_dim, num_features, 1, 1, 0, **kwargs))
+ else:
+ self.gammas.append(LinearBlock(cond_dim, num_features, **kwargs))
+ self.betas.append(LinearBlock(cond_dim, num_features, **kwargs))
+
+ self.norm = get_activation_norm_layer(num_features,
+ activation_norm_type,
+ 2,
+ **vars(activation_norm_params))
+ self.conditional = True
+
+ def forward(self, x, *cond_inputs, **_kwargs):
+ assert len(cond_inputs) == len(self.gammas)
+ output = self.norm(x) if self.norm is not None else x
+ for cond, gamma_layer, beta_layer in zip(cond_inputs, self.gammas, self.betas):
+ if cond is None:
+ continue
+ gamma = gamma_layer(cond)
+ beta = beta_layer(cond)
+ if cond.dim() == 4 and gamma.shape != x.shape:
+ gamma = F.interpolate(gamma, size=x.size()[2:], mode='bilinear')
+ beta = F.interpolate(beta, size=x.size()[2:], mode='bilinear')
+ elif cond.dim() == 2:
+ gamma = gamma[:, :, None, None]
+ beta = beta[:, :, None, None]
+ if self.bias_only:
+ output = output + beta
+ else:
+ output = output * (1 + gamma) + beta
+ return output
+
+
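+# Minimal usage sketch for DualAdaptiveNorm (illustrative only): names and
+# shapes are assumptions. A single non-spatial 128-d code conditions a
+# 64-channel feature map; gamma and beta come from per-condition linear layers.
+def _example_dual_adaptive_norm():
+    norm = DualAdaptiveNorm(num_features=64, cond_dims=128)
+    x = torch.randn(2, 64, 16, 16)
+    z = torch.randn(2, 128)          # one conditional input per gamma/beta pair
+    out = norm(x, z)
+    assert out.shape == x.shape
+
+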
+class HyperSpatiallyAdaptiveNorm(nn.Module):
+    r"""Hyper Spatially Adaptive Normalization (Hyper SPADE) initialization.
+
+ Args:
+ num_features (int) : Number of channels in the input tensor.
+ cond_dims (int or list of int) : List of numbers of channels
+ in the conditional input.
+ num_filters (int): Number of filters in SPADE.
+ kernel_size (int): Kernel size of the convolutional filters in
+ the SPADE layer.
+ weight_norm_type (str): Type of weight normalization.
+ ``'none'``, ``'spectral'``, or ``'weight'``.
+ activation_norm_type (str):
+ Type of activation normalization.
+ ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
+ ``'layer'``, ``'layer_2d'``, ``'group'``.
+ is_hyper (bool): Whether to use hyper SPADE.
+ """
+
+ def __init__(self, num_features, cond_dims,
+ num_filters=0, kernel_size=3,
+ weight_norm_type='',
+ activation_norm_type='sync_batch', is_hyper=True):
+ super().__init__()
+ padding = kernel_size // 2
+ self.mlps = nn.ModuleList()
+ if type(cond_dims) != list:
+ cond_dims = [cond_dims]
+
+ for i, cond_dim in enumerate(cond_dims):
+ mlp = []
+ if not is_hyper or (i != 0):
+ if num_filters > 0:
+ mlp += [Conv2dBlock(cond_dim, num_filters, kernel_size,
+ padding=padding,
+ weight_norm_type=weight_norm_type,
+ nonlinearity='relu')]
+ mlp_ch = cond_dim if num_filters == 0 else num_filters
+ mlp += [Conv2dBlock(mlp_ch, num_features * 2, kernel_size,
+ padding=padding,
+ weight_norm_type=weight_norm_type)]
+ mlp = nn.Sequential(*mlp)
+ else:
+ if num_filters > 0:
+ raise ValueError('Multi hyper layer not supported yet.')
+ mlp = HyperConv2d(padding=padding)
+ self.mlps.append(mlp)
+
+ self.norm = get_activation_norm_layer(num_features,
+ activation_norm_type,
+ 2,
+ affine=False)
+
+ self.conditional = True
+
+ def forward(self, x, *cond_inputs,
+ norm_weights=(None, None), **_kwargs):
+        r"""Hyper Spatially Adaptive Normalization (Hyper SPADE) forward.
+
+ Args:
+ x (4D tensor) : Input tensor.
+ cond_inputs (list of tensors) : Conditional maps for SPADE.
+ norm_weights (5D tensor or list of tensors): conv weights or
+ [weights, biases].
+ Returns:
+ output (4D tensor) : Output tensor.
+ """
+ output = self.norm(x)
+ for i in range(len(cond_inputs)):
+ if cond_inputs[i] is None:
+ continue
+ if type(cond_inputs[i]) == list:
+ cond_input, mask = cond_inputs[i]
+ mask = F.interpolate(mask, size=x.size()[2:], mode='bilinear', align_corners=False)
+ else:
+ cond_input = cond_inputs[i]
+ mask = None
+ label_map = F.interpolate(cond_input, size=x.size()[2:])
+ if norm_weights is None or norm_weights[0] is None or i != 0:
+ affine_params = self.mlps[i](label_map)
+ else:
+ affine_params = self.mlps[i](label_map,
+ conv_weights=norm_weights)
+ gamma, beta = affine_params.chunk(2, dim=1)
+ if mask is not None:
+ gamma = gamma * (1 - mask)
+ beta = beta * (1 - mask)
+ output = output * (1 + gamma) + beta
+ return output
+
+
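+# Minimal usage sketch for HyperSpatiallyAdaptiveNorm (illustrative only): all
+# names and shapes are assumptions. With the default is_hyper=True, the first
+# branch is a HyperConv2d, so per-sample conv weights and biases mapping the
+# 8-channel conditioning map to 2 * num_features channels must be supplied via
+# norm_weights. 'instance' replaces the default 'sync_batch' so the sketch
+# runs without a distributed process group.
+def _example_hyper_spatially_adaptive_norm():
+    norm = HyperSpatiallyAdaptiveNorm(num_features=64, cond_dims=8,
+                                      activation_norm_type='instance')
+    x = torch.randn(2, 64, 16, 16)
+    cond = torch.randn(2, 8, 16, 16)
+    weight = torch.randn(2, 128, 8, 3, 3)   # N x (2*C) x cond_dim x k x k
+    bias = torch.randn(2, 128)              # N x (2*C)
+    out = norm(x, cond, norm_weights=(weight, bias))
+    assert out.shape == x.shape
+
+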
+class LayerNorm2d(nn.Module):
+ r"""Layer Normalization as introduced in
+ https://arxiv.org/abs/1607.06450.
+ This is the usual way to apply layer normalization in CNNs.
+ Note that unlike the pytorch implementation which applies per-element
+ scale and bias, here it applies per-channel scale and bias, similar to
+ batch/instance normalization.
+
+ Args:
+ num_features (int): Number of channels in the input tensor.
+ eps (float, optional, default=1e-5): a value added to the
+ denominator for numerical stability.
+ affine (bool, optional, default=False): If ``True``, performs
+ affine transformation after normalization.
+ """
+
+ def __init__(self, num_features, eps=1e-5, channel_only=False, affine=True):
+ super(LayerNorm2d, self).__init__()
+ self.num_features = num_features
+ self.affine = affine
+ self.eps = eps
+ self.channel_only = channel_only
+
+ if self.affine:
+ self.gamma = nn.Parameter(torch.Tensor(num_features).fill_(1.0))
+ self.beta = nn.Parameter(torch.zeros(num_features))
+
+ def forward(self, x):
+ r"""
+
+ Args:
+ x (tensor): Input tensor.
+ """
+ shape = [-1] + [1] * (x.dim() - 1)
+ if self.channel_only:
+ mean = x.mean(1, keepdim=True)
+ std = x.std(1, keepdim=True)
+ else:
+ mean = x.view(x.size(0), -1).mean(1).view(*shape)
+ std = x.view(x.size(0), -1).std(1).view(*shape)
+
+ x = (x - mean) / (std + self.eps)
+
+ if self.affine:
+ shape = [1, -1] + [1] * (x.dim() - 2)
+ x = x * self.gamma.view(*shape) + self.beta.view(*shape)
+ return x
+
+
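+# Minimal usage sketch for LayerNorm2d (illustrative only; shapes are
+# assumptions): statistics are computed over C, H and W of each sample, and
+# the learnable per-channel gamma/beta are applied afterwards.
+def _example_layer_norm_2d():
+    norm = LayerNorm2d(num_features=64)
+    x = torch.randn(2, 64, 8, 8)
+    out = norm(x)
+    assert out.shape == x.shape
+
+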
+class ScaleNorm(nn.Module):
+ r"""Scale normalization:
+ "Transformers without Tears: Improving the Normalization of Self-Attention"
+ Modified from:
+ https://github.com/tnq177/transformers_without_tears
+ """
+
+ def __init__(self, dim=-1, learned_scale=True, eps=1e-5):
+ super().__init__()
+ # scale = num_features ** 0.5
+ if learned_scale:
+ self.scale = nn.Parameter(torch.tensor(1.))
+ else:
+ self.scale = 1.
+ # self.num_features = num_features
+ self.dim = dim
+ self.eps = eps
+ self.learned_scale = learned_scale
+
+ def forward(self, x):
+ # noinspection PyArgumentList
+ scale = self.scale * torch.rsqrt(torch.mean(x ** 2, dim=self.dim, keepdim=True) + self.eps)
+ return x * scale
+
+ def extra_repr(self):
+ s = 'learned_scale={learned_scale}'
+ return s.format(**self.__dict__)
+
+
+class PixelNorm(ScaleNorm):
+ def __init__(self, learned_scale=False, eps=1e-5, **_kwargs):
+ super().__init__(1, learned_scale, eps)
+
+
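+# Minimal usage sketch for ScaleNorm and PixelNorm (illustrative only; shapes
+# are assumptions): ScaleNorm rescales along the last dimension of a token
+# sequence, while PixelNorm normalizes across channels at every pixel.
+def _example_scale_and_pixel_norm():
+    tokens = torch.randn(2, 10, 256)        # N x L x D
+    out_tokens = ScaleNorm(dim=-1)(tokens)
+    feats = torch.randn(2, 64, 8, 8)        # N x C x H x W
+    out_feats = PixelNorm()(feats)
+    assert out_tokens.shape == tokens.shape and out_feats.shape == feats.shape
+
+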
+class SplitMeanStd(nn.Module):
+ def __init__(self, num_features, eps=1e-5, **kwargs):
+ super().__init__()
+ self.num_features = num_features
+ self.eps = eps
+ self.multiple_outputs = True
+
+ def forward(self, x):
+ b, c, h, w = x.size()
+ mean = x.view(b, c, -1).mean(-1)[:, :, None, None]
+ var = x.view(b, c, -1).var(-1)[:, :, None, None]
+ std = torch.sqrt(var + self.eps)
+
+ # x = (x - mean) / std
+ return x, torch.cat((mean, std), dim=1)
+
+
+
+class PixelLayerNorm(nn.Module):
+ def __init__(self, *args, **kwargs):
+ super().__init__()
+ self.norm = nn.LayerNorm(*args, **kwargs)
+
+ def forward(self, x):
+ if x.dim() == 4:
+ b, c, h, w = x.shape
+ return self.norm(x.permute(0, 2, 3, 1).view(-1, c)).view(b, h, w, c).permute(0, 3, 1, 2)
+ else:
+ return self.norm(x)
+
+
+def get_activation_norm_layer(num_features, norm_type, input_dim, **norm_params):
+ r"""Return an activation normalization layer.
+
+ Args:
+ num_features (int): Number of feature channels.
+ norm_type (str):
+ Type of activation normalization.
+ ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
+ ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``,
+ ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
+ input_dim (int): Number of input dimensions.
+ norm_params: Arbitrary keyword arguments that will be used to
+ initialize the activation normalization.
+ """
+ input_dim = max(input_dim, 1) # Norm1d works with both 0d and 1d inputs
+
+ if norm_type == 'none' or norm_type == '':
+ norm_layer = None
+ elif norm_type == 'batch':
+ norm = getattr(nn, 'BatchNorm%dd' % input_dim)
+ norm_layer = norm(num_features, **norm_params)
+ elif norm_type == 'instance':
+ affine = norm_params.pop('affine', True) # Use affine=True by default
+ norm = getattr(nn, 'InstanceNorm%dd' % input_dim)
+ norm_layer = norm(num_features, affine=affine, **norm_params)
+ elif norm_type == 'sync_batch':
+ norm_layer = SyncBatchNorm(num_features, **norm_params)
+ elif norm_type == 'layer':
+ norm_layer = nn.LayerNorm(num_features, **norm_params)
+ elif norm_type == 'layer_2d':
+ norm_layer = LayerNorm2d(num_features, **norm_params)
+ elif norm_type == 'pixel_layer':
+ elementwise_affine = norm_params.pop('affine', True) # Use affine=True by default
+ norm_layer = PixelLayerNorm(num_features, elementwise_affine=elementwise_affine, **norm_params)
+ elif norm_type == 'scale':
+ norm_layer = ScaleNorm(**norm_params)
+ elif norm_type == 'pixel':
+ norm_layer = PixelNorm(**norm_params)
+ import imaginaire.config
+ if imaginaire.config.USE_JIT:
+ norm_layer = torch.jit.script(norm_layer)
+ elif norm_type == 'group':
+ num_groups = norm_params.pop('num_groups', 4)
+ norm_layer = nn.GroupNorm(num_channels=num_features, num_groups=num_groups, **norm_params)
+ elif norm_type == 'adaptive':
+ norm_layer = AdaptiveNorm(num_features, **norm_params)
+ elif norm_type == 'dual_adaptive':
+ norm_layer = DualAdaptiveNorm(num_features, **norm_params)
+ elif norm_type == 'spatially_adaptive':
+ if input_dim != 2:
+            raise ValueError('Spatially adaptive normalization layers '
+                             'only support 2D inputs')
+ norm_layer = SpatiallyAdaptiveNorm(num_features, **norm_params)
+ elif norm_type == 'hyper_spatially_adaptive':
+ if input_dim != 2:
+            raise ValueError('Spatially adaptive normalization layers '
+                             'only support 2D inputs')
+ norm_layer = HyperSpatiallyAdaptiveNorm(num_features, **norm_params)
+ else:
+ raise ValueError('Activation norm layer %s '
+ 'is not recognized' % norm_type)
+ return norm_layer
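+
+
+# Minimal usage sketch for the factory above (illustrative only; the chosen
+# norm types, channel count and num_groups are assumptions): extra keyword
+# arguments are forwarded to the underlying normalization layer.
+def _example_get_activation_norm_layer():
+    inst = get_activation_norm_layer(64, 'instance', input_dim=2)
+    group = get_activation_norm_layer(64, 'group', input_dim=2, num_groups=8)
+    x = torch.randn(2, 64, 16, 16)
+    assert inst(x).shape == x.shape and group(x).shape == x.shape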
diff --git a/imaginaire/layers/conv.py b/imaginaire/layers/conv.py
new file mode 100644
index 0000000000000000000000000000000000000000..499fc0442b77e3183225c3529a4e3590dab0bc57
--- /dev/null
+++ b/imaginaire/layers/conv.py
@@ -0,0 +1,1377 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import warnings
+from types import SimpleNamespace
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from .misc import ApplyNoise
+from imaginaire.third_party.upfirdn2d.upfirdn2d import Blur
+
+
+class _BaseConvBlock(nn.Module):
+ r"""An abstract wrapper class that wraps a torch convolution or linear layer
+ with normalization and nonlinearity.
+ """
+
+ def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode,
+ weight_norm_type, weight_norm_params, activation_norm_type, activation_norm_params, nonlinearity,
+ inplace_nonlinearity, apply_noise, blur, order, input_dim, clamp, blur_kernel, output_scale,
+ init_gain):
+ super().__init__()
+ from .nonlinearity import get_nonlinearity_layer
+ from .weight_norm import get_weight_norm_layer
+ from .activation_norm import get_activation_norm_layer
+ self.weight_norm_type = weight_norm_type
+ self.stride = stride
+ self.clamp = clamp
+ self.init_gain = init_gain
+
+ # Nonlinearity layer.
+ if 'fused' in nonlinearity:
+ # Fusing nonlinearity with bias.
+ lr_mul = getattr(weight_norm_params, 'lr_mul', 1)
+ conv_before_nonlinearity = order.find('C') < order.find('A')
+ if conv_before_nonlinearity:
+ assert bias is True
+ bias = False
+ channel = out_channels if conv_before_nonlinearity else in_channels
+ nonlinearity_layer = get_nonlinearity_layer(
+ nonlinearity, inplace=inplace_nonlinearity,
+ num_channels=channel, lr_mul=lr_mul)
+ else:
+ nonlinearity_layer = get_nonlinearity_layer(
+ nonlinearity, inplace=inplace_nonlinearity)
+
+ # Noise injection layer.
+ if apply_noise:
+ order = order.replace('C', 'CG')
+ noise_layer = ApplyNoise()
+ else:
+ noise_layer = None
+
+ # Convolutional layer.
+ if blur:
+ assert blur_kernel is not None
+ if stride == 2:
+ # Blur - Conv - Noise - Activate
+ p = (len(blur_kernel) - 2) + (kernel_size - 1)
+ pad0, pad1 = (p + 1) // 2, p // 2
+ padding = 0
+ blur_layer = Blur(
+ blur_kernel, pad=(pad0, pad1), padding_mode=padding_mode
+ )
+ order = order.replace('C', 'BC')
+ elif stride == 0.5:
+ # Conv - Blur - Noise - Activate
+ padding = 0
+ p = (len(blur_kernel) - 2) - (kernel_size - 1)
+ pad0, pad1 = (p + 1) // 2 + 1, p // 2 + 1
+ blur_layer = Blur(
+ blur_kernel, pad=(pad0, pad1), padding_mode=padding_mode
+ )
+ order = order.replace('C', 'CB')
+ elif stride == 1:
+ # No blur for now
+ blur_layer = nn.Identity()
+ else:
+ raise NotImplementedError
+ else:
+ blur_layer = nn.Identity()
+
+ if weight_norm_params is None:
+ weight_norm_params = SimpleNamespace()
+ weight_norm = get_weight_norm_layer(
+ weight_norm_type, **vars(weight_norm_params))
+ conv_layer = weight_norm(self._get_conv_layer(
+ in_channels, out_channels, kernel_size, stride, padding, dilation,
+ groups, bias, padding_mode, input_dim))
+
+ # Normalization layer.
+ conv_before_norm = order.find('C') < order.find('N')
+ norm_channels = out_channels if conv_before_norm else in_channels
+ if activation_norm_params is None:
+ activation_norm_params = SimpleNamespace()
+ activation_norm_layer = get_activation_norm_layer(
+ norm_channels,
+ activation_norm_type,
+ input_dim,
+ **vars(activation_norm_params))
+
+ # Mapping from operation names to layers.
+ mappings = {'C': {'conv': conv_layer},
+ 'N': {'norm': activation_norm_layer},
+ 'A': {'nonlinearity': nonlinearity_layer}}
+ mappings.update({'B': {'blur': blur_layer}})
+ mappings.update({'G': {'noise': noise_layer}})
+
+ # All layers in order.
+ self.layers = nn.ModuleDict()
+ for op in order:
+ if list(mappings[op].values())[0] is not None:
+ self.layers.update(mappings[op])
+
+ # Whether this block expects conditional inputs.
+ self.conditional = \
+ getattr(conv_layer, 'conditional', False) or \
+ getattr(activation_norm_layer, 'conditional', False)
+
+ # Scale the output by a learnable scaler parameter.
+ if output_scale is not None:
+ self.output_scale = nn.Parameter(torch.tensor(output_scale))
+ else:
+ self.register_parameter("output_scale", None)
+
+ def forward(self, x, *cond_inputs, **kw_cond_inputs):
+ r"""
+
+ Args:
+ x (tensor): Input tensor.
+ cond_inputs (list of tensors) : Conditional input tensors.
+ kw_cond_inputs (dict) : Keyword conditional inputs.
+ """
+ for key, layer in self.layers.items():
+ if getattr(layer, 'conditional', False):
+ # Layers that require conditional inputs.
+ x = layer(x, *cond_inputs, **kw_cond_inputs)
+ else:
+ x = layer(x)
+ if self.clamp is not None and isinstance(layer, nn.Conv2d):
+ x.clamp_(max=self.clamp)
+ if key == 'conv':
+ if self.output_scale is not None:
+ x = x * self.output_scale
+ return x
+
+ def _get_conv_layer(self, in_channels, out_channels, kernel_size, stride,
+ padding, dilation, groups, bias, padding_mode,
+ input_dim):
+ # Returns the convolutional layer.
+ if input_dim == 0:
+ layer = nn.Linear(in_channels, out_channels, bias)
+ else:
+ if stride < 1: # Fractionally-strided convolution.
+ padding_mode = 'zeros'
+ assert padding == 0
+ layer_type = getattr(nn, f'ConvTranspose{input_dim}d')
+ stride = round(1 / stride)
+ else:
+ layer_type = getattr(nn, f'Conv{input_dim}d')
+ layer = layer_type(
+ in_channels, out_channels, kernel_size, stride, padding,
+ dilation=dilation, groups=groups, bias=bias,
+ padding_mode=padding_mode
+ )
+
+ return layer
+
+ def __repr__(self):
+ main_str = self._get_name() + '('
+ child_lines = []
+ for name, layer in self.layers.items():
+ mod_str = repr(layer)
+ if name == 'conv' and self.weight_norm_type != 'none' and \
+ self.weight_norm_type != '':
+ mod_str = mod_str[:-1] + \
+ ', weight_norm={}'.format(self.weight_norm_type) + ')'
+ if name == 'conv' and getattr(layer, 'base_lr_mul', 1) != 1:
+ mod_str = mod_str[:-1] + \
+ ', lr_mul={}'.format(layer.base_lr_mul) + ')'
+ mod_str = self._addindent(mod_str, 2)
+ child_lines.append(mod_str)
+ if len(child_lines) == 1:
+ main_str += child_lines[0]
+ else:
+ main_str += '\n ' + '\n '.join(child_lines) + '\n'
+
+ main_str += ')'
+ return main_str
+
+ @staticmethod
+ def _addindent(s_, numSpaces):
+ s = s_.split('\n')
+ # don't do anything for single-line stuff
+ if len(s) == 1:
+ return s_
+ first = s.pop(0)
+ s = [(numSpaces * ' ') + line for line in s]
+ s = '\n'.join(s)
+ s = first + '\n' + s
+ return s
+
+
+class ModulatedConv2dBlock(_BaseConvBlock):
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1,
+ padding=0, dilation=1, groups=1, bias=True,
+ padding_mode='zeros',
+ weight_norm_type='none', weight_norm_params=None,
+ activation_norm_type='none', activation_norm_params=None,
+ nonlinearity='none', inplace_nonlinearity=False,
+ apply_noise=True, blur=True, order='CNA', demodulate=True,
+ eps=True, style_dim=None, clamp=None, blur_kernel=(1, 3, 3, 1), output_scale=None, init_gain=1.0):
+ self.eps = eps
+ self.demodulate = demodulate
+ assert style_dim is not None
+
+ super().__init__(in_channels, out_channels, kernel_size, stride,
+ padding, dilation, groups, bias, padding_mode,
+ weight_norm_type, weight_norm_params,
+ activation_norm_type, activation_norm_params,
+ nonlinearity, inplace_nonlinearity, apply_noise, blur,
+ order, 2, clamp, blur_kernel, output_scale, init_gain)
+ self.modulation = LinearBlock(style_dim, in_channels,
+ weight_norm_type=weight_norm_type,
+ weight_norm_params=weight_norm_params)
+
+ def _get_conv_layer(self, in_channels, out_channels, kernel_size, stride,
+ padding, dilation, groups, bias, padding_mode,
+ input_dim):
+ assert input_dim == 2
+ layer = ModulatedConv2d(
+ in_channels, out_channels, kernel_size, stride, padding,
+ dilation, groups, bias, padding_mode, self.demodulate, self.eps)
+ return layer
+
+ def forward(self, x, *cond_inputs, **kw_cond_inputs):
+ for layer in self.layers.values():
+ if getattr(layer, 'conditional', False):
+ # Layers that require conditional inputs.
+ assert len(cond_inputs) == 1
+ style = cond_inputs[0]
+ x = layer(
+ x, self.modulation(style), **kw_cond_inputs
+ )
+ else:
+ x = layer(x)
+ if self.clamp is not None and isinstance(layer, ModulatedConv2d):
+ x.clamp_(max=self.clamp)
+ return x
+
+ def __repr__(self):
+ main_str = self._get_name() + '('
+ child_lines = []
+ for name, layer in self.layers.items():
+ mod_str = repr(layer)
+ if name == 'conv' and self.weight_norm_type != 'none' and \
+ self.weight_norm_type != '':
+ mod_str = mod_str[:-1] + \
+ ', weight_norm={}'.format(self.weight_norm_type) + \
+ ', demodulate={}'.format(self.demodulate) + ')'
+ mod_str = self._addindent(mod_str, 2)
+ child_lines.append(mod_str)
+ child_lines.append(
+ self._addindent('Modulation(' + repr(self.modulation) + ')', 2)
+ )
+ if len(child_lines) == 1:
+ main_str += child_lines[0]
+ else:
+ main_str += '\n ' + '\n '.join(child_lines) + '\n'
+
+ main_str += ')'
+ return main_str
+
+
+class ModulatedConv2d(nn.Module):
+ def __init__(self, in_channels, out_channels, kernel_size, stride, padding,
+ dilation, groups, bias, padding_mode, demodulate=True,
+ eps=1e-8):
+ # in_channels, out_channels, kernel_size, stride, padding,
+ # dilation, groups, bias, padding_mode
+ assert dilation == 1 and groups == 1
+
+ super().__init__()
+
+ self.eps = eps
+ self.kernel_size = kernel_size
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.padding = padding
+ self.stride = stride
+ self.padding_mode = padding_mode
+ # kernel_size // 2
+ # assert self.padding == padding
+
+ self.weight = nn.Parameter(
+ torch.randn(out_channels, in_channels, kernel_size, kernel_size)
+ )
+
+ if bias:
+ self.bias = nn.Parameter(torch.Tensor(out_channels))
+ else:
+ # noinspection PyTypeChecker
+ self.register_parameter('bias', None)
+
+ # self.modulation = LinearBlock(style_dim, in_channels,
+ # weight_norm_type=weight_norm_type)
+ self.demodulate = demodulate
+ self.conditional = True
+
+ def forward(self, x, style, **_kwargs):
+ batch, in_channel, height, width = x.shape
+
+ # style = self.modulation(style).view(batch, 1, in_channel, 1, 1)
+ # We assume the modulation layer is outside this module.
+ style = style.view(batch, 1, in_channel, 1, 1)
+ weight = self.weight.unsqueeze(0) * style
+
+ if self.demodulate:
+ demod = torch.rsqrt(
+ weight.pow(2).sum([2, 3, 4]) + self.eps)
+ weight = weight * demod.view(batch, self.out_channels, 1, 1, 1)
+
+ weight = weight.view(
+ batch * self.out_channels,
+ in_channel, self.kernel_size, self.kernel_size
+ )
+ if self.bias is not None:
+ bias = self.bias.repeat(batch)
+ else:
+ bias = self.bias
+
+ x = x.view(1, batch * in_channel, height, width)
+
+        if self.padding_mode != 'zeros':
+            # Pad all four sides manually, then convolve with zero padding.
+            x = F.pad(x, [self.padding] * 4, mode=self.padding_mode)
+ padding = (0, 0)
+ else:
+ padding = self.padding
+
+ if self.stride == 0.5:
+ weight = weight.view(
+ batch, self.out_channels, in_channel,
+ self.kernel_size, self.kernel_size
+ )
+ weight = weight.transpose(1, 2).reshape(
+ batch * in_channel, self.out_channels,
+ self.kernel_size, self.kernel_size
+ )
+ out = F.conv_transpose2d(
+ x, weight, bias, padding=padding, stride=2, groups=batch
+ )
+
+ elif self.stride == 2:
+ out = F.conv2d(
+ x, weight, bias, padding=padding, stride=2, groups=batch
+ )
+
+ else:
+ out = F.conv2d(x, weight, bias, padding=padding, groups=batch)
+
+ _, _, height, width = out.shape
+ out = out.view(batch, self.out_channels, height, width)
+
+ return out
+
+ def extra_repr(self):
+ s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}'
+ ', stride={stride}')
+ if self.bias is None:
+ s += ', bias=False'
+ if self.padding_mode != 'zeros':
+ s += ', padding_mode={padding_mode}'
+ return s.format(**self.__dict__)
+
+
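+# Minimal usage sketch for ModulatedConv2d (illustrative only; names and
+# shapes are assumptions). The per-sample style vector is expected to already
+# have in_channels entries, since the modulation projection lives outside this
+# module (see ModulatedConv2dBlock above); bias is disabled to keep the sketch
+# minimal.
+def _example_modulated_conv2d():
+    conv = ModulatedConv2d(in_channels=8, out_channels=16, kernel_size=3,
+                           stride=1, padding=1, dilation=1, groups=1,
+                           bias=False, padding_mode='zeros')
+    x = torch.randn(2, 8, 16, 16)
+    style = torch.randn(2, 8)        # one modulation scale per input channel
+    out = conv(x, style)
+    assert out.shape == (2, 16, 16, 16)
+
+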
+class LinearBlock(_BaseConvBlock):
+ r"""A Wrapper class that wraps ``torch.nn.Linear`` with normalization and
+ nonlinearity.
+
+ Args:
+ in_features (int): Number of channels in the input tensor.
+ out_features (int): Number of channels in the output tensor.
+ bias (bool, optional, default=True):
+ If ``True``, adds a learnable bias to the output.
+ weight_norm_type (str, optional, default='none'):
+ Type of weight normalization.
+ ``'none'``, ``'spectral'``, ``'weight'``
+ or ``'weight_demod'``.
+ weight_norm_params (obj, optional, default=None):
+ Parameters of weight normalization.
+ If not ``None``, ``weight_norm_params.__dict__`` will be used as
+ keyword arguments when initializing weight normalization.
+ activation_norm_type (str, optional, default='none'):
+ Type of activation normalization.
+ ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
+ ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``,
+ ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
+ activation_norm_params (obj, optional, default=None):
+ Parameters of activation normalization.
+ If not ``None``, ``activation_norm_params.__dict__`` will be used as
+ keyword arguments when initializing activation normalization.
+ nonlinearity (str, optional, default='none'):
+ Type of nonlinear activation function.
+ ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
+ ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``.
+ inplace_nonlinearity (bool, optional, default=False): If ``True``,
+ set ``inplace=True`` when initializing the nonlinearity layer.
+ apply_noise (bool, optional, default=False): If ``True``, add
+ Gaussian noise with learnable magnitude after the
+ fully-connected layer.
+ order (str, optional, default='CNA'): Order of operations.
+ ``'C'``: fully-connected,
+ ``'N'``: normalization,
+ ``'A'``: nonlinear activation.
+ For example, a block initialized with ``order='CNA'`` will
+ do convolution first, then normalization, then nonlinearity.
+ """
+
+ def __init__(self, in_features, out_features, bias=True,
+ weight_norm_type='none', weight_norm_params=None,
+ activation_norm_type='none', activation_norm_params=None,
+ nonlinearity='none', inplace_nonlinearity=False,
+ apply_noise=False, order='CNA', clamp=None, blur_kernel=(1, 3, 3, 1), output_scale=None,
+ init_gain=1.0, **_kwargs):
+ if bool(_kwargs):
+ warnings.warn(f"Unused keyword arguments {_kwargs}")
+ super().__init__(in_features, out_features, None, None,
+ None, None, None, bias,
+ None, weight_norm_type, weight_norm_params,
+ activation_norm_type, activation_norm_params,
+ nonlinearity, inplace_nonlinearity, apply_noise,
+ False, order, 0, clamp, blur_kernel, output_scale,
+ init_gain)
+
+
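+# Minimal usage sketch for LinearBlock (illustrative only; sizes are
+# assumptions): with the default order='CNA' the block applies the linear
+# layer first, then (here, no) normalization, then the ReLU nonlinearity.
+def _example_linear_block():
+    fc = LinearBlock(in_features=128, out_features=256, nonlinearity='relu')
+    z = torch.randn(4, 128)
+    out = fc(z)
+    assert out.shape == (4, 256)
+
+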
+class EmbeddingBlock(_BaseConvBlock):
+ def __init__(self, in_features, out_features, bias=True,
+ weight_norm_type='none', weight_norm_params=None,
+ activation_norm_type='none', activation_norm_params=None,
+ nonlinearity='none', inplace_nonlinearity=False,
+ apply_noise=False, order='CNA', clamp=None, output_scale=None,
+ init_gain=1.0, **_kwargs):
+ if bool(_kwargs):
+ warnings.warn(f"Unused keyword arguments {_kwargs}")
+ super().__init__(in_features, out_features, None, None,
+ None, None, None, bias,
+ None, weight_norm_type, weight_norm_params,
+ activation_norm_type, activation_norm_params,
+ nonlinearity, inplace_nonlinearity, apply_noise,
+ False, order, 0, clamp, None, output_scale,
+ init_gain)
+
+ def _get_conv_layer(self, in_channels, out_channels, kernel_size, stride,
+ padding, dilation, groups, bias, padding_mode,
+ input_dim):
+ assert input_dim == 0
+ return nn.Embedding(in_channels, out_channels)
+
+
+class Embedding2dBlock(_BaseConvBlock):
+ def __init__(self, in_features, out_features, bias=True,
+ weight_norm_type='none', weight_norm_params=None,
+ activation_norm_type='none', activation_norm_params=None,
+ nonlinearity='none', inplace_nonlinearity=False,
+ apply_noise=False, order='CNA', clamp=None, output_scale=None,
+ init_gain=1.0, **_kwargs):
+ if bool(_kwargs):
+ warnings.warn(f"Unused keyword arguments {_kwargs}")
+ super().__init__(in_features, out_features, None, None,
+ None, None, None, bias,
+ None, weight_norm_type, weight_norm_params,
+ activation_norm_type, activation_norm_params,
+ nonlinearity, inplace_nonlinearity, apply_noise,
+ False, order, 0, clamp, None, output_scale,
+ init_gain)
+
+ def _get_conv_layer(self, in_channels, out_channels, kernel_size, stride,
+ padding, dilation, groups, bias, padding_mode,
+ input_dim):
+ assert input_dim == 0
+ return Embedding2d(in_channels, out_channels)
+
+
+class Conv1dBlock(_BaseConvBlock):
+ r"""A Wrapper class that wraps ``torch.nn.Conv1d`` with normalization and
+ nonlinearity.
+
+ Args:
+ in_channels (int): Number of channels in the input tensor.
+ out_channels (int): Number of channels in the output tensor.
+ kernel_size (int or tuple): Size of the convolving kernel.
+ stride (int or float or tuple, optional, default=1):
+ Stride of the convolution.
+ padding (int or tuple, optional, default=0):
+ Zero-padding added to both sides of the input.
+ dilation (int or tuple, optional, default=1):
+ Spacing between kernel elements.
+ groups (int, optional, default=1): Number of blocked connections
+ from input channels to output channels.
+ bias (bool, optional, default=True):
+ If ``True``, adds a learnable bias to the output.
+ padding_mode (string, optional, default='zeros'): Type of padding:
+ ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
+ weight_norm_type (str, optional, default='none'):
+ Type of weight normalization.
+ ``'none'``, ``'spectral'``, ``'weight'``
+ or ``'weight_demod'``.
+ weight_norm_params (obj, optional, default=None):
+ Parameters of weight normalization.
+ If not ``None``, ``weight_norm_params.__dict__`` will be used as
+ keyword arguments when initializing weight normalization.
+ activation_norm_type (str, optional, default='none'):
+ Type of activation normalization.
+ ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
+ ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``,
+ ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
+ activation_norm_params (obj, optional, default=None):
+ Parameters of activation normalization.
+ If not ``None``, ``activation_norm_params.__dict__`` will be used as
+ keyword arguments when initializing activation normalization.
+ nonlinearity (str, optional, default='none'):
+ Type of nonlinear activation function.
+ ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
+ ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``.
+ inplace_nonlinearity (bool, optional, default=False): If ``True``,
+ set ``inplace=True`` when initializing the nonlinearity layer.
+ apply_noise (bool, optional, default=False): If ``True``, adds
+ Gaussian noise with learnable magnitude to the convolution output.
+ order (str, optional, default='CNA'): Order of operations.
+ ``'C'``: convolution,
+ ``'N'``: normalization,
+ ``'A'``: nonlinear activation.
+ For example, a block initialized with ``order='CNA'`` will
+ do convolution first, then normalization, then nonlinearity.
+ """
+
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1,
+ padding=0, dilation=1, groups=1, bias=True,
+ padding_mode='zeros',
+ weight_norm_type='none', weight_norm_params=None,
+ activation_norm_type='none', activation_norm_params=None,
+ nonlinearity='none', inplace_nonlinearity=False,
+ apply_noise=False, blur=False, order='CNA', clamp=None, output_scale=None, init_gain=1.0, **_kwargs):
+ super().__init__(in_channels, out_channels, kernel_size, stride,
+ padding, dilation, groups, bias, padding_mode,
+ weight_norm_type, weight_norm_params,
+ activation_norm_type, activation_norm_params,
+ nonlinearity, inplace_nonlinearity, apply_noise,
+ blur, order, 1, clamp, None, output_scale, init_gain)
+
+
+class Conv2dBlock(_BaseConvBlock):
+ r"""A Wrapper class that wraps ``torch.nn.Conv2d`` with normalization and
+ nonlinearity.
+
+ Args:
+ in_channels (int): Number of channels in the input tensor.
+ out_channels (int): Number of channels in the output tensor.
+ kernel_size (int or tuple): Size of the convolving kernel.
+ stride (int or float or tuple, optional, default=1):
+ Stride of the convolution.
+ padding (int or tuple, optional, default=0):
+ Zero-padding added to both sides of the input.
+ dilation (int or tuple, optional, default=1):
+ Spacing between kernel elements.
+ groups (int, optional, default=1): Number of blocked connections
+ from input channels to output channels.
+ bias (bool, optional, default=True):
+ If ``True``, adds a learnable bias to the output.
+ padding_mode (string, optional, default='zeros'): Type of padding:
+ ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
+ weight_norm_type (str, optional, default='none'):
+ Type of weight normalization.
+ ``'none'``, ``'spectral'``, ``'weight'``
+ or ``'weight_demod'``.
+ weight_norm_params (obj, optional, default=None):
+ Parameters of weight normalization.
+ If not ``None``, ``weight_norm_params.__dict__`` will be used as
+ keyword arguments when initializing weight normalization.
+ activation_norm_type (str, optional, default='none'):
+ Type of activation normalization.
+ ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
+ ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``,
+ ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
+ activation_norm_params (obj, optional, default=None):
+ Parameters of activation normalization.
+ If not ``None``, ``activation_norm_params.__dict__`` will be used as
+ keyword arguments when initializing activation normalization.
+ nonlinearity (str, optional, default='none'):
+ Type of nonlinear activation function.
+ ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
+ ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``.
+ inplace_nonlinearity (bool, optional, default=False): If ``True``,
+ set ``inplace=True`` when initializing the nonlinearity layer.
+ apply_noise (bool, optional, default=False): If ``True``, adds
+ Gaussian noise with learnable magnitude to the convolution output.
+ order (str, optional, default='CNA'): Order of operations.
+ ``'C'``: convolution,
+ ``'N'``: normalization,
+ ``'A'``: nonlinear activation.
+ For example, a block initialized with ``order='CNA'`` will
+ do convolution first, then normalization, then nonlinearity.
+ """
+
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1,
+ padding=0, dilation=1, groups=1, bias=True,
+ padding_mode='zeros',
+ weight_norm_type='none', weight_norm_params=None,
+ activation_norm_type='none', activation_norm_params=None,
+ nonlinearity='none', inplace_nonlinearity=False,
+ apply_noise=False, blur=False, order='CNA', clamp=None, blur_kernel=(1, 3, 3, 1),
+ output_scale=None, init_gain=1.0):
+ super().__init__(in_channels, out_channels, kernel_size, stride,
+ padding, dilation, groups, bias, padding_mode,
+ weight_norm_type, weight_norm_params,
+ activation_norm_type, activation_norm_params,
+ nonlinearity, inplace_nonlinearity,
+ apply_noise, blur, order, 2, clamp, blur_kernel, output_scale, init_gain)
+
+
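+# Minimal usage sketch for Conv2dBlock (illustrative only; channel counts and
+# shapes are assumptions): a conv -> instance norm -> leaky ReLU block, i.e.
+# the default order='CNA'.
+def _example_conv2d_block():
+    block = Conv2dBlock(in_channels=3, out_channels=64, kernel_size=3,
+                        padding=1, activation_norm_type='instance',
+                        nonlinearity='leakyrelu')
+    x = torch.randn(2, 3, 32, 32)
+    out = block(x)
+    assert out.shape == (2, 64, 32, 32)
+
+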
+class Conv3dBlock(_BaseConvBlock):
+ r"""A Wrapper class that wraps ``torch.nn.Conv3d`` with normalization and
+ nonlinearity.
+
+ Args:
+ in_channels (int): Number of channels in the input tensor.
+ out_channels (int): Number of channels in the output tensor.
+ kernel_size (int or tuple): Size of the convolving kernel.
+ stride (int or float or tuple, optional, default=1):
+ Stride of the convolution.
+ padding (int or tuple, optional, default=0):
+ Zero-padding added to both sides of the input.
+ dilation (int or tuple, optional, default=1):
+ Spacing between kernel elements.
+ groups (int, optional, default=1): Number of blocked connections
+ from input channels to output channels.
+ bias (bool, optional, default=True):
+ If ``True``, adds a learnable bias to the output.
+ padding_mode (string, optional, default='zeros'): Type of padding:
+ ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
+ weight_norm_type (str, optional, default='none'):
+ Type of weight normalization.
+ ``'none'``, ``'spectral'``, ``'weight'``
+ or ``'weight_demod'``.
+ weight_norm_params (obj, optional, default=None):
+ Parameters of weight normalization.
+ If not ``None``, ``weight_norm_params.__dict__`` will be used as
+ keyword arguments when initializing weight normalization.
+ activation_norm_type (str, optional, default='none'):
+ Type of activation normalization.
+ ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
+ ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``,
+ ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
+ activation_norm_params (obj, optional, default=None):
+ Parameters of activation normalization.
+ If not ``None``, ``activation_norm_params.__dict__`` will be used as
+ keyword arguments when initializing activation normalization.
+ nonlinearity (str, optional, default='none'):
+ Type of nonlinear activation function.
+ ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
+ ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``.
+ inplace_nonlinearity (bool, optional, default=False): If ``True``,
+ set ``inplace=True`` when initializing the nonlinearity layer.
+ apply_noise (bool, optional, default=False): If ``True``, adds
+ Gaussian noise with learnable magnitude to the convolution output.
+ order (str, optional, default='CNA'): Order of operations.
+ ``'C'``: convolution,
+ ``'N'``: normalization,
+ ``'A'``: nonlinear activation.
+ For example, a block initialized with ``order='CNA'`` will
+ do convolution first, then normalization, then nonlinearity.
+ """
+
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1,
+ padding=0, dilation=1, groups=1, bias=True,
+ padding_mode='zeros',
+ weight_norm_type='none', weight_norm_params=None,
+ activation_norm_type='none', activation_norm_params=None,
+ nonlinearity='none', inplace_nonlinearity=False,
+ apply_noise=False, blur=False, order='CNA', clamp=None, blur_kernel=(1, 3, 3, 1), output_scale=None,
+ init_gain=1.0):
+ super().__init__(in_channels, out_channels, kernel_size, stride,
+ padding, dilation, groups, bias, padding_mode,
+ weight_norm_type, weight_norm_params,
+ activation_norm_type, activation_norm_params,
+ nonlinearity, inplace_nonlinearity,
+ apply_noise, blur, order, 3, clamp, blur_kernel, output_scale, init_gain)
+
+
+class _BaseHyperConvBlock(_BaseConvBlock):
+ r"""An abstract wrapper class that wraps a hyper convolutional layer
+ with normalization and nonlinearity.
+ """
+
+ def __init__(self, in_channels, out_channels, kernel_size, stride,
+ padding, dilation, groups, bias,
+ padding_mode,
+ weight_norm_type, weight_norm_params,
+ activation_norm_type, activation_norm_params,
+ nonlinearity, inplace_nonlinearity, apply_noise, blur,
+ is_hyper_conv, is_hyper_norm, order, input_dim, clamp=None, blur_kernel=(1, 3, 3, 1),
+ output_scale=None, init_gain=1.0):
+ self.is_hyper_conv = is_hyper_conv
+ if is_hyper_conv:
+ weight_norm_type = 'none'
+ if is_hyper_norm:
+ activation_norm_type = 'hyper_' + activation_norm_type
+ super().__init__(in_channels, out_channels, kernel_size, stride,
+ padding, dilation, groups, bias, padding_mode,
+ weight_norm_type, weight_norm_params,
+ activation_norm_type, activation_norm_params,
+ nonlinearity, inplace_nonlinearity, apply_noise, blur,
+ order, input_dim, clamp, blur_kernel, output_scale, init_gain)
+
+ def _get_conv_layer(self, in_channels, out_channels, kernel_size, stride,
+ padding, dilation, groups, bias, padding_mode,
+ input_dim):
+ if input_dim == 0:
+ raise ValueError('HyperLinearBlock is not supported.')
+ else:
+ name = 'HyperConv' if self.is_hyper_conv else 'nn.Conv'
+ layer_type = eval(name + '%dd' % input_dim)
+ layer = layer_type(
+ in_channels, out_channels, kernel_size, stride, padding,
+ dilation, groups, bias, padding_mode)
+ return layer
+
+
+class HyperConv2dBlock(_BaseHyperConvBlock):
+ r"""A Wrapper class that wraps ``HyperConv2d`` with normalization and
+ nonlinearity.
+
+ Args:
+ in_channels (int): Number of channels in the input tensor.
+ out_channels (int): Number of channels in the output tensor.
+ kernel_size (int or tuple): Size of the convolving kernel.
+ stride (int or float or tuple, optional, default=1):
+ Stride of the convolution.
+ padding (int or tuple, optional, default=0):
+ Zero-padding added to both sides of the input.
+ dilation (int or tuple, optional, default=1):
+ Spacing between kernel elements.
+ groups (int, optional, default=1): Number of blocked connections
+ from input channels to output channels.
+ bias (bool, optional, default=True):
+ If ``True``, adds a learnable bias to the output.
+ padding_mode (string, optional, default='zeros'): Type of padding:
+ ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
+ weight_norm_type (str, optional, default='none'):
+ Type of weight normalization.
+ ``'none'``, ``'spectral'``, ``'weight'``
+ or ``'weight_demod'``.
+ weight_norm_params (obj, optional, default=None):
+ Parameters of weight normalization.
+ If not ``None``, ``weight_norm_params.__dict__`` will be used as
+ keyword arguments when initializing weight normalization.
+ activation_norm_type (str, optional, default='none'):
+ Type of activation normalization.
+ ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
+ ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``,
+ ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
+ activation_norm_params (obj, optional, default=None):
+ Parameters of activation normalization.
+ If not ``None``, ``activation_norm_params.__dict__`` will be used as
+ keyword arguments when initializing activation normalization.
+ is_hyper_conv (bool, optional, default=False): If ``True``, use
+ ``HyperConv2d``, otherwise use ``torch.nn.Conv2d``.
+ is_hyper_norm (bool, optional, default=False): If ``True``, use
+ hyper normalizations.
+ nonlinearity (str, optional, default='none'):
+ Type of nonlinear activation function.
+ ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
+ ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``.
+ inplace_nonlinearity (bool, optional, default=False): If ``True``,
+ set ``inplace=True`` when initializing the nonlinearity layer.
+ apply_noise (bool, optional, default=False): If ``True``, adds
+ Gaussian noise with learnable magnitude to the convolution output.
+ order (str, optional, default='CNA'): Order of operations.
+ ``'C'``: convolution,
+ ``'N'``: normalization,
+ ``'A'``: nonlinear activation.
+ For example, a block initialized with ``order='CNA'`` will
+ do convolution first, then normalization, then nonlinearity.
+ """
+
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1,
+ padding=0, dilation=1, groups=1, bias=True,
+ padding_mode='zeros',
+ weight_norm_type='none', weight_norm_params=None,
+ activation_norm_type='none', activation_norm_params=None,
+ is_hyper_conv=False, is_hyper_norm=False,
+ nonlinearity='none', inplace_nonlinearity=False,
+ apply_noise=False, blur=False, order='CNA', clamp=None):
+ super().__init__(in_channels, out_channels, kernel_size, stride,
+ padding, dilation, groups, bias, padding_mode,
+ weight_norm_type, weight_norm_params,
+ activation_norm_type, activation_norm_params,
+ nonlinearity, inplace_nonlinearity, apply_noise, blur,
+ is_hyper_conv, is_hyper_norm, order, 2, clamp)
+
+
+class HyperConv2d(nn.Module):
+ r"""Hyper Conv2d initialization.
+
+ Args:
+ in_channels (int): Dummy parameter.
+ out_channels (int): Dummy parameter.
+ kernel_size (int or tuple): Dummy parameter.
+ stride (int or float or tuple, optional, default=1):
+ Stride of the convolution. Default: 1
+ padding (int or tuple, optional, default=0):
+ Zero-padding added to both sides of the input.
+ padding_mode (string, optional, default='zeros'):
+ ``'zeros'``, ``'reflect'``, ``'replicate'``
+ or ``'circular'``.
+ dilation (int or tuple, optional, default=1):
+ Spacing between kernel elements.
+ groups (int, optional, default=1): Number of blocked connections
+ from input channels to output channels.
+ bias (bool, optional, default=True): If ``True``,
+ adds a learnable bias to the output.
+ """
+
+ def __init__(self, in_channels=0, out_channels=0, kernel_size=3,
+ stride=1, padding=1, dilation=1, groups=1, bias=True,
+ padding_mode='zeros'):
+ super().__init__()
+ self.stride = stride
+ self.padding = padding
+ self.dilation = dilation
+ self.groups = groups
+ self.use_bias = bias
+ self.padding_mode = padding_mode
+ self.conditional = True
+
+ def forward(self, x, *args, conv_weights=(None, None), **kwargs):
+ r"""Hyper Conv2d forward. Convolve x using the provided weight and bias.
+
+ Args:
+ x (N x C x H x W tensor): Input tensor.
+ conv_weights (N x C2 x C1 x k x k tensor or list of tensors):
+ Convolution weights or [weight, bias].
+ Returns:
+ y (N x C2 x H x W tensor): Output tensor.
+ """
+ if conv_weights is None:
+ conv_weight, conv_bias = None, None
+ elif isinstance(conv_weights, torch.Tensor):
+ conv_weight, conv_bias = conv_weights, None
+ else:
+ conv_weight, conv_bias = conv_weights
+
+ if conv_weight is None:
+ return x
+ if conv_bias is None:
+ if self.use_bias:
+ raise ValueError('bias not provided but set to true during '
+ 'initialization')
+ conv_bias = [None] * x.size(0)
+ if self.padding_mode != 'zeros':
+ x = F.pad(x, [self.padding] * 4, mode=self.padding_mode)
+ padding = 0
+ else:
+ padding = self.padding
+
+ y = None
+ # noinspection PyArgumentList
+ for i in range(x.size(0)):
+ if self.stride >= 1:
+ yi = F.conv2d(x[i: i + 1],
+ weight=conv_weight[i], bias=conv_bias[i],
+ stride=self.stride, padding=padding,
+ dilation=self.dilation, groups=self.groups)
+ else:
+ yi = F.conv_transpose2d(x[i: i + 1], weight=conv_weight[i],
+ bias=conv_bias[i], padding=self.padding,
+ stride=int(1 / self.stride),
+ dilation=self.dilation,
+ output_padding=self.padding,
+ groups=self.groups)
+ y = torch.cat([y, yi]) if y is not None else yi
+ return y
+
+
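+# Minimal usage sketch for HyperConv2d (illustrative only; names and shapes
+# are assumptions): the convolution weights are not parameters of the module
+# but are supplied per sample at call time, here without a bias.
+def _example_hyper_conv2d():
+    conv = HyperConv2d(kernel_size=3, stride=1, padding=1, bias=False)
+    x = torch.randn(2, 8, 16, 16)
+    weights = torch.randn(2, 16, 8, 3, 3)   # N x C_out x C_in x k x k
+    out = conv(x, conv_weights=weights)
+    assert out.shape == (2, 16, 16, 16)
+
+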
+class _BasePartialConvBlock(_BaseConvBlock):
+ r"""An abstract wrapper class that wraps a partial convolutional layer
+ with normalization and nonlinearity.
+ """
+
+ def __init__(self, in_channels, out_channels, kernel_size, stride,
+ padding, dilation, groups, bias, padding_mode,
+ weight_norm_type, weight_norm_params,
+ activation_norm_type, activation_norm_params,
+ nonlinearity, inplace_nonlinearity,
+ multi_channel, return_mask,
+ apply_noise, order, input_dim, clamp=None, blur_kernel=(1, 3, 3, 1), output_scale=None, init_gain=1.0):
+ self.multi_channel = multi_channel
+ self.return_mask = return_mask
+ self.partial_conv = True
+ super().__init__(in_channels, out_channels, kernel_size, stride,
+ padding, dilation, groups, bias, padding_mode,
+ weight_norm_type, weight_norm_params,
+ activation_norm_type, activation_norm_params,
+ nonlinearity, inplace_nonlinearity, apply_noise,
+ False, order, input_dim, clamp, blur_kernel, output_scale, init_gain)
+
+ def _get_conv_layer(self, in_channels, out_channels, kernel_size, stride,
+ padding, dilation, groups, bias, padding_mode,
+ input_dim):
+ if input_dim == 2:
+ layer_type = PartialConv2d
+ elif input_dim == 3:
+ layer_type = PartialConv3d
+ else:
+ raise ValueError('Partial conv only supports 2D and 3D conv now.')
+ layer = layer_type(
+ in_channels, out_channels, kernel_size, stride, padding,
+ dilation, groups, bias, padding_mode,
+ multi_channel=self.multi_channel, return_mask=self.return_mask)
+ return layer
+
+ def forward(self, x, *cond_inputs, mask_in=None, **kw_cond_inputs):
+ r"""
+
+ Args:
+ x (tensor): Input tensor.
+ cond_inputs (list of tensors) : Conditional input tensors.
+            mask_in (tensor, optional, default=``None``): If not ``None``,
+ it masks the valid input region.
+ kw_cond_inputs (dict) : Keyword conditional inputs.
+ Returns:
+ (tuple):
+ - x (tensor): Output tensor.
+ - mask_out (tensor, optional): Masks the valid output region.
+ """
+ mask_out = None
+ for layer in self.layers.values():
+ if getattr(layer, 'conditional', False):
+ x = layer(x, *cond_inputs, **kw_cond_inputs)
+ elif getattr(layer, 'partial_conv', False):
+ x = layer(x, mask_in=mask_in, **kw_cond_inputs)
+ if type(x) == tuple:
+ x, mask_out = x
+ else:
+ x = layer(x)
+
+ if mask_out is not None:
+ return x, mask_out
+ return x
+
+
+class PartialConv2dBlock(_BasePartialConvBlock):
+ r"""A wrapper class that wraps ``PartialConv2d`` with normalization and
+ nonlinearity.
+
+ Args:
+ in_channels (int): Number of channels in the input tensor.
+ out_channels (int): Number of channels in the output tensor.
+ kernel_size (int or tuple): Size of the convolving kernel.
+ stride (int or float or tuple, optional, default=1):
+ Stride of the convolution.
+ padding (int or tuple, optional, default=0):
+ Zero-padding added to both sides of the input.
+ dilation (int or tuple, optional, default=1):
+ Spacing between kernel elements.
+ groups (int, optional, default=1): Number of blocked connections
+ from input channels to output channels.
+ bias (bool, optional, default=True):
+ If ``True``, adds a learnable bias to the output.
+ padding_mode (string, optional, default='zeros'): Type of padding:
+ ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
+ weight_norm_type (str, optional, default='none'):
+ Type of weight normalization.
+ ``'none'``, ``'spectral'``, ``'weight'``
+ or ``'weight_demod'``.
+ weight_norm_params (obj, optional, default=None):
+ Parameters of weight normalization.
+ If not ``None``, ``weight_norm_params.__dict__`` will be used as
+ keyword arguments when initializing weight normalization.
+ activation_norm_type (str, optional, default='none'):
+ Type of activation normalization.
+ ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
+ ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``,
+ ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
+ activation_norm_params (obj, optional, default=None):
+ Parameters of activation normalization.
+ If not ``None``, ``activation_norm_params.__dict__`` will be used as
+ keyword arguments when initializing activation normalization.
+ nonlinearity (str, optional, default='none'):
+ Type of nonlinear activation function.
+ ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
+ ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``.
+ inplace_nonlinearity (bool, optional, default=False): If ``True``,
+ set ``inplace=True`` when initializing the nonlinearity layer.
+ apply_noise (bool, optional, default=False): If ``True``, adds
+ Gaussian noise with learnable magnitude to the convolution output.
+ order (str, optional, default='CNA'): Order of operations.
+ ``'C'``: convolution,
+ ``'N'``: normalization,
+ ``'A'``: nonlinear activation.
+ For example, a block initialized with ``order='CNA'`` will
+ do convolution first, then normalization, then nonlinearity.
+ multi_channel (bool, optional, default=False): If ``True``, use
+ different masks for different channels.
+ return_mask (bool, optional, default=True): If ``True``, the
+ forward call also returns a new mask.
+ """
+
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1,
+ padding=0, dilation=1, groups=1, bias=True,
+ padding_mode='zeros',
+ weight_norm_type='none', weight_norm_params=None,
+ activation_norm_type='none', activation_norm_params=None,
+ nonlinearity='none', inplace_nonlinearity=False,
+ multi_channel=False, return_mask=True,
+ apply_noise=False, order='CNA', clamp=None):
+ super().__init__(in_channels, out_channels, kernel_size, stride,
+ padding, dilation, groups, bias, padding_mode,
+ weight_norm_type, weight_norm_params,
+ activation_norm_type, activation_norm_params,
+ nonlinearity, inplace_nonlinearity,
+ multi_channel, return_mask, apply_noise, order, 2,
+ clamp)
+
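+ # --- Illustrative usage sketch (not part of the original module): run a
+ # feature map with a hole through the block; with ``return_mask=True`` the
+ # call returns both the features and the updated validity mask.
+ def _demo_partial_conv2d_block():
+     block = PartialConv2dBlock(3, 8, kernel_size=3, padding=1,
+                                nonlinearity='relu', return_mask=True)
+     x = torch.randn(1, 3, 16, 16)
+     mask = torch.ones(1, 1, 16, 16)
+     mask[:, :, 4:12, 4:12] = 0          # mark a square region as invalid
+     y, new_mask = block(x, mask_in=mask)
+     return y.shape, new_mask.shape      # (1, 8, 16, 16) and (1, 1, 16, 16)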
+
+class PartialConv3dBlock(_BasePartialConvBlock):
+ r"""A wrapper class that wraps ``PartialConv3d`` with normalization and
+ nonlinearity.
+
+ Args:
+ in_channels (int): Number of channels in the input tensor.
+ out_channels (int): Number of channels in the output tensor.
+ kernel_size (int or tuple): Size of the convolving kernel.
+ stride (int or float or tuple, optional, default=1):
+ Stride of the convolution.
+ padding (int or tuple, optional, default=0):
+ Zero-padding added to both sides of the input.
+ dilation (int or tuple, optional, default=1):
+ Spacing between kernel elements.
+ groups (int, optional, default=1): Number of blocked connections
+ from input channels to output channels.
+ bias (bool, optional, default=True):
+ If ``True``, adds a learnable bias to the output.
+ padding_mode (string, optional, default='zeros'): Type of padding:
+ ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
+ weight_norm_type (str, optional, default='none'):
+ Type of weight normalization.
+ ``'none'``, ``'spectral'``, ``'weight'``
+ or ``'weight_demod'``.
+ weight_norm_params (obj, optional, default=None):
+ Parameters of weight normalization.
+ If not ``None``, ``weight_norm_params.__dict__`` will be used as
+ keyword arguments when initializing weight normalization.
+ activation_norm_type (str, optional, default='none'):
+ Type of activation normalization.
+ ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
+ ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``,
+ ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
+ activation_norm_params (obj, optional, default=None):
+ Parameters of activation normalization.
+ If not ``None``, ``activation_norm_params.__dict__`` will be used as
+ keyword arguments when initializing activation normalization.
+ nonlinearity (str, optional, default='none'):
+ Type of nonlinear activation function.
+ ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
+ ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``.
+ inplace_nonlinearity (bool, optional, default=False): If ``True``,
+ set ``inplace=True`` when initializing the nonlinearity layer.
+ apply_noise (bool, optional, default=False): If ``True``, adds
+ Gaussian noise with learnable magnitude to the convolution output.
+ order (str, optional, default='CNA'): Order of operations.
+ ``'C'``: convolution,
+ ``'N'``: normalization,
+ ``'A'``: nonlinear activation.
+ For example, a block initialized with ``order='CNA'`` will
+ do convolution first, then normalization, then nonlinearity.
+ multi_channel (bool, optional, default=False): If ``True``, use
+ different masks for different channels.
+ return_mask (bool, optional, default=True): If ``True``, the
+ forward call also returns a new mask.
+ """
+
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1,
+ padding=0, dilation=1, groups=1, bias=True,
+ padding_mode='zeros',
+ weight_norm_type='none', weight_norm_params=None,
+ activation_norm_type='none', activation_norm_params=None,
+ nonlinearity='none', inplace_nonlinearity=False,
+ multi_channel=False, return_mask=True,
+ apply_noise=False, order='CNA', clamp=None):
+ super().__init__(in_channels, out_channels, kernel_size, stride,
+ padding, dilation, groups, bias, padding_mode,
+ weight_norm_type, weight_norm_params,
+ activation_norm_type, activation_norm_params,
+ nonlinearity, inplace_nonlinearity,
+ multi_channel, return_mask, apply_noise, order, 3,
+ clamp)
+
+
+class _MultiOutBaseConvBlock(_BaseConvBlock):
+ r"""An abstract wrapper class that wraps a hyper convolutional layer with
+ normalization and nonlinearity. It can return multiple outputs, if some
+ layers in the block return more than one output.
+ """
+
+ def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode,
+ weight_norm_type, weight_norm_params, activation_norm_type, activation_norm_params, nonlinearity,
+ inplace_nonlinearity, apply_noise, blur, order, input_dim, clamp=None, blur_kernel=(1, 3, 3, 1),
+ output_scale=None, init_gain=1.0):
+ super().__init__(in_channels, out_channels, kernel_size, stride,
+ padding, dilation, groups, bias, padding_mode,
+ weight_norm_type, weight_norm_params,
+ activation_norm_type, activation_norm_params,
+ nonlinearity, inplace_nonlinearity,
+ apply_noise, blur, order, input_dim, clamp, blur_kernel, output_scale, init_gain)
+ self.multiple_outputs = True
+
+ def forward(self, x, *cond_inputs, **kw_cond_inputs):
+ r"""
+
+ Args:
+ x (tensor): Input tensor.
+ cond_inputs (list of tensors) : Conditional input tensors.
+ kw_cond_inputs (dict) : Keyword conditional inputs.
+ Returns:
+ (tuple):
+ - x (tensor): Main output tensor.
+ - other_outputs (list of tensors): Other output tensors.
+ """
+ other_outputs = []
+ for layer in self.layers.values():
+ if getattr(layer, 'conditional', False):
+ x = layer(x, *cond_inputs, **kw_cond_inputs)
+ # Use elif so a conditional layer is not invoked a second time below.
+ elif getattr(layer, 'multiple_outputs', False):
+ x, other_output = layer(x)
+ other_outputs.append(other_output)
+ else:
+ x = layer(x)
+ return (x, *other_outputs)
+
+
+class MultiOutConv2dBlock(_MultiOutBaseConvBlock):
+ r"""A wrapper class that wraps ``torch.nn.Conv2d`` with normalization and
+ nonlinearity. It can return multiple outputs, if some layers in the block
+ return more than one output.
+
+ Args:
+ in_channels (int): Number of channels in the input tensor.
+ out_channels (int): Number of channels in the output tensor.
+ kernel_size (int or tuple): Size of the convolving kernel.
+ stride (int or float or tuple, optional, default=1):
+ Stride of the convolution.
+ padding (int or tuple, optional, default=0):
+ Zero-padding added to both sides of the input.
+ dilation (int or tuple, optional, default=1):
+ Spacing between kernel elements.
+ groups (int, optional, default=1): Number of blocked connections
+ from input channels to output channels.
+ bias (bool, optional, default=True):
+ If ``True``, adds a learnable bias to the output.
+ padding_mode (string, optional, default='zeros'): Type of padding:
+ ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
+ weight_norm_type (str, optional, default='none'):
+ Type of weight normalization.
+ ``'none'``, ``'spectral'``, ``'weight'``
+ or ``'weight_demod'``.
+ weight_norm_params (obj, optional, default=None):
+ Parameters of weight normalization.
+ If not ``None``, ``weight_norm_params.__dict__`` will be used as
+ keyword arguments when initializing weight normalization.
+ activation_norm_type (str, optional, default='none'):
+ Type of activation normalization.
+ ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
+ ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``,
+ ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
+ activation_norm_params (obj, optional, default=None):
+ Parameters of activation normalization.
+ If not ``None``, ``activation_norm_params.__dict__`` will be used as
+ keyword arguments when initializing activation normalization.
+ nonlinearity (str, optional, default='none'):
+ Type of nonlinear activation function.
+ ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
+ ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``.
+ inplace_nonlinearity (bool, optional, default=False): If ``True``,
+ set ``inplace=True`` when initializing the nonlinearity layer.
+ apply_noise (bool, optional, default=False): If ``True``, adds
+ Gaussian noise with learnable magnitude to the convolution output.
+ order (str, optional, default='CNA'): Order of operations.
+ ``'C'``: convolution,
+ ``'N'``: normalization,
+ ``'A'``: nonlinear activation.
+ For example, a block initialized with ``order='CNA'`` will
+ do convolution first, then normalization, then nonlinearity.
+ """
+
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1,
+ padding=0, dilation=1, groups=1, bias=True,
+ padding_mode='zeros',
+ weight_norm_type='none', weight_norm_params=None,
+ activation_norm_type='none', activation_norm_params=None,
+ nonlinearity='none', inplace_nonlinearity=False,
+ apply_noise=False, blur=False, order='CNA', clamp=None):
+ super().__init__(in_channels, out_channels, kernel_size, stride,
+ padding, dilation, groups, bias, padding_mode,
+ weight_norm_type, weight_norm_params,
+ activation_norm_type, activation_norm_params,
+ nonlinearity, inplace_nonlinearity,
+ apply_noise, blur, order, 2, clamp)
+
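+ # --- Illustrative sketch (not part of the original module): even when no
+ # inner layer produces extra outputs, the forward call returns a tuple, so
+ # the main tensor is unpacked as the first element.
+ def _demo_multi_out_conv2d_block():
+     block = MultiOutConv2dBlock(3, 8, kernel_size=3, padding=1,
+                                 nonlinearity='relu')
+     out, = block(torch.randn(1, 3, 16, 16))
+     return out.shape                    # torch.Size([1, 8, 16, 16])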
+
+###############################################################################
+# BSD 3-Clause License
+#
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+#
+# Author & Contact: Guilin Liu (guilinl@nvidia.com)
+###############################################################################
+class PartialConv2d(nn.Conv2d):
+ r"""Partial 2D convolution in
+ "Image inpainting for irregular holes using partial convolutions."
+ Liu et al., ECCV 2018
+ """
+
+ def __init__(self, *args, multi_channel=False, return_mask=True, **kwargs):
+ # whether the mask is multi-channel or not
+ self.multi_channel = multi_channel
+ self.return_mask = return_mask
+ super(PartialConv2d, self).__init__(*args, **kwargs)
+
+ if self.multi_channel:
+ self.weight_maskUpdater = torch.ones(self.out_channels,
+ self.in_channels,
+ self.kernel_size[0],
+ self.kernel_size[1])
+ else:
+ self.weight_maskUpdater = torch.ones(1, 1, self.kernel_size[0],
+ self.kernel_size[1])
+
+ shape = self.weight_maskUpdater.shape
+ self.slide_winsize = shape[1] * shape[2] * shape[3]
+
+ self.last_size = (None, None, None, None)
+ self.update_mask = None
+ self.mask_ratio = None
+ self.partial_conv = True
+
+ def forward(self, x, mask_in=None):
+ r"""
+
+ Args:
+ x (tensor): Input tensor.
+ mask_in (tensor, optional, default=``None``): If not ``None``,
+ it masks the valid input region.
+ """
+ assert len(x.shape) == 4
+ if mask_in is not None or self.last_size != tuple(x.shape):
+ self.last_size = tuple(x.shape)
+
+ with torch.no_grad():
+ if self.weight_maskUpdater.type() != x.type():
+ self.weight_maskUpdater = self.weight_maskUpdater.to(x)
+
+ if mask_in is None:
+ # If mask is not provided, create a mask.
+ if self.multi_channel:
+ mask = torch.ones(x.data.shape[0],
+ x.data.shape[1],
+ x.data.shape[2],
+ x.data.shape[3]).to(x)
+ else:
+ mask = torch.ones(1, 1, x.data.shape[2],
+ x.data.shape[3]).to(x)
+ else:
+ mask = mask_in
+
+ self.update_mask = F.conv2d(mask, self.weight_maskUpdater,
+ bias=None, stride=self.stride,
+ padding=self.padding,
+ dilation=self.dilation, groups=1)
+
+ # Use a larger eps (1e-6 instead of 1e-8) to keep the ratio numerically
+ # stable under mixed-precision training.
+ eps = 1e-6
+ self.mask_ratio = self.slide_winsize / (self.update_mask + eps)
+ self.update_mask = torch.clamp(self.update_mask, 0, 1)
+ self.mask_ratio = torch.mul(self.mask_ratio, self.update_mask)
+
+ raw_out = super(PartialConv2d, self).forward(
+ torch.mul(x, mask) if mask_in is not None else x)
+
+ if self.bias is not None:
+ bias_view = self.bias.view(1, self.out_channels, 1, 1)
+ output = torch.mul(raw_out - bias_view, self.mask_ratio) + bias_view
+ output = torch.mul(output, self.update_mask)
+ else:
+ output = torch.mul(raw_out, self.mask_ratio)
+
+ if self.return_mask:
+ return output, self.update_mask
+ else:
+ return output
+
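+ # --- Illustrative sketch (not part of the original module): the partial
+ # convolution re-weights each output location by the fraction of valid
+ # (unmasked) pixels under the kernel window, so masked holes do not drag the
+ # activations toward zero, and it also returns the dilated validity mask.
+ def _demo_partial_conv2d():
+     conv = PartialConv2d(1, 1, kernel_size=3, padding=1, bias=False,
+                          multi_channel=False, return_mask=True)
+     x = torch.ones(1, 1, 5, 5)
+     mask = torch.ones(1, 1, 5, 5)
+     mask[..., 2, 2] = 0                 # a single invalid pixel
+     out, updated_mask = conv(x, mask_in=mask)
+     return out.shape, updated_mask.shape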
+
+class PartialConv3d(nn.Conv3d):
+ r"""Partial 3D convolution in
+ "Image inpainting for irregular holes using partial convolutions."
+ Liu et al., ECCV 2018
+ """
+
+ def __init__(self, *args, multi_channel=False, return_mask=True, **kwargs):
+ # whether the mask is multi-channel or not
+ self.multi_channel = multi_channel
+ self.return_mask = return_mask
+ super(PartialConv3d, self).__init__(*args, **kwargs)
+
+ if self.multi_channel:
+ self.weight_maskUpdater = \
+ torch.ones(self.out_channels, self.in_channels,
+ self.kernel_size[0], self.kernel_size[1],
+ self.kernel_size[2])
+ else:
+ self.weight_maskUpdater = torch.ones(1, 1, self.kernel_size[0],
+ self.kernel_size[1],
+ self.kernel_size[2])
+ self.weight_maskUpdater = self.weight_maskUpdater.to('cuda')
+
+ shape = self.weight_maskUpdater.shape
+ self.slide_winsize = shape[1] * shape[2] * shape[3] * shape[4]
+ self.partial_conv = True
+
+ def forward(self, x, mask_in=None):
+ r"""
+
+ Args:
+ x (tensor): Input tensor.
+ mask_in (tensor, optional, default=``None``): If not ``None``, it
+ masks the valid input region.
+ """
+ assert len(x.shape) == 5
+
+ with torch.no_grad():
+ mask = mask_in
+ update_mask = F.conv3d(mask, self.weight_maskUpdater, bias=None,
+ stride=self.stride, padding=self.padding,
+ dilation=self.dilation, groups=1)
+
+ mask_ratio = self.slide_winsize / (update_mask + 1e-8)
+ update_mask = torch.clamp(update_mask, 0, 1)
+ mask_ratio = torch.mul(mask_ratio, update_mask)
+
+ raw_out = super(PartialConv3d, self).forward(torch.mul(x, mask_in))
+
+ if self.bias is not None:
+ bias_view = self.bias.view(1, self.out_channels, 1, 1, 1)
+ output = torch.mul(raw_out - bias_view, mask_ratio) + bias_view
+ if mask_in is not None:
+ output = torch.mul(output, update_mask)
+ else:
+ output = torch.mul(raw_out, mask_ratio)
+
+ if self.return_mask:
+ return output, update_mask
+ else:
+ return output
+
+
+class Embedding2d(nn.Embedding):
+ def __init__(self, in_channels, out_channels):
+ super().__init__(in_channels, out_channels)
+
+ def forward(self, x):
+ return F.embedding(
+ x.squeeze(1).long(), self.weight, self.padding_idx, self.max_norm,
+ self.norm_type, self.scale_grad_by_freq, self.sparse).permute(0, 3, 1, 2).contiguous()
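+
+
+ # --- Illustrative sketch (not part of the original module): Embedding2d maps
+ # an integer label map of shape (N, 1, H, W) to a dense per-pixel feature map
+ # of shape (N, C, H, W).
+ def _demo_embedding2d():
+     embed = Embedding2d(10, 16)                  # 10 classes -> 16-dim codes
+     labels = torch.randint(0, 10, (2, 1, 8, 8))
+     return embed(labels).shape                   # torch.Size([2, 16, 8, 8])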
diff --git a/imaginaire/layers/misc.py b/imaginaire/layers/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..7731bd2fa855939a2866b211c33eb3ccce00c480
--- /dev/null
+++ b/imaginaire/layers/misc.py
@@ -0,0 +1,61 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import torch
+from torch import nn
+
+
+class ApplyNoise(nn.Module):
+ r"""Add Gaussian noise to the input tensor."""
+
+ def __init__(self):
+ super().__init__()
+ # scale of the noise
+ self.scale = nn.Parameter(torch.zeros(1))
+ self.conditional = True
+
+ def forward(self, x, *_args, noise=None, **_kwargs):
+ r"""
+
+ Args:
+ x (tensor): Input tensor.
+ noise (tensor, optional, default=``None``): Noise tensor to be
+ added to the input.
+ """
+ if noise is None:
+ sz = x.size()
+ noise = x.new_empty(sz[0], 1, *sz[2:]).normal_()
+
+ return x + self.scale * noise
+
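+ # --- Illustrative sketch (not part of the original module): the scale starts
+ # at zero, so the layer is an identity at initialization and gradually learns
+ # how much per-pixel Gaussian noise to inject.
+ def _demo_apply_noise():
+     layer = ApplyNoise()
+     x = torch.randn(2, 8, 16, 16)
+     assert torch.equal(layer(x), x)              # scale == 0 -> identity
+     layer.scale.data.fill_(0.1)
+     return layer(x)                              # x + 0.1 * N(0, 1) noise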
+
+class PartialSequential(nn.Sequential):
+ r"""Sequential block for partial convolutions."""
+ def __init__(self, *modules):
+ super(PartialSequential, self).__init__(*modules)
+
+ def forward(self, x):
+ r"""
+
+ Args:
+ x (tensor): Input tensor.
+ """
+ act = x[:, :-1]
+ mask = x[:, -1].unsqueeze(1)
+ for module in self:
+ act, mask = module(act, mask_in=mask)
+ return act
+
+
+class ConstantInput(nn.Module):
+ def __init__(self, channel, size=4):
+ super().__init__()
+ if isinstance(size, int):
+ h, w = size, size
+ else:
+ h, w = size
+ self.input = nn.Parameter(torch.randn(1, channel, h, w))
+
+ def forward(self):
+ return self.input
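+
+
+ # --- Illustrative sketch (not part of the original module): ConstantInput
+ # stores a single learned feature map (as in StyleGAN2); a generator typically
+ # repeats it along the batch dimension before the first convolution.
+ def _demo_constant_input():
+     const = ConstantInput(channel=512, size=4)
+     batch = const().repeat(8, 1, 1, 1)           # (8, 512, 4, 4)
+     return batch.shape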
diff --git a/imaginaire/layers/non_local.py b/imaginaire/layers/non_local.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d1a8b36d668377ef4c9c4d897cd3b55fe0363e7
--- /dev/null
+++ b/imaginaire/layers/non_local.py
@@ -0,0 +1,88 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+from functools import partial
+
+import torch
+import torch.nn as nn
+
+from imaginaire.layers import Conv2dBlock
+
+
+class NonLocal2dBlock(nn.Module):
+ r"""Self-attention layer.
+
+ Args:
+ in_channels (int): Number of channels in the input tensor.
+ scale (bool, optional, default=True): If ``True``, scale the
+ output by a learnable parameter.
+ clamp (bool, optional, default=``False``): If ``True``, clamp the
+ scaling parameter to (-1, 1).
+ weight_norm_type (str, optional, default='none'):
+ Type of weight normalization.
+ ``'none'``, ``'spectral'``, ``'weight'``.
+ weight_norm_params (obj, optional, default=None):
+ Parameters of weight normalization.
+ If not ``None``, weight_norm_params.__dict__ will be used as
+ keyword arguments when initializing weight normalization.
+ bias (bool, optional, default=True): If ``True``, adds bias in the
+ convolutional blocks.
+ """
+
+ def __init__(self,
+ in_channels,
+ scale=True,
+ clamp=False,
+ weight_norm_type='none',
+ weight_norm_params=None,
+ bias=True):
+ super(NonLocal2dBlock, self).__init__()
+ self.clamp = clamp
+ self.gamma = nn.Parameter(torch.zeros(1)) if scale else 1.0
+ self.in_channels = in_channels
+ base_conv2d_block = partial(Conv2dBlock,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ weight_norm_type=weight_norm_type,
+ weight_norm_params=weight_norm_params,
+ bias=bias)
+ self.theta = base_conv2d_block(in_channels, in_channels // 8)
+ self.phi = base_conv2d_block(in_channels, in_channels // 8)
+ self.g = base_conv2d_block(in_channels, in_channels // 2)
+ self.out_conv = base_conv2d_block(in_channels // 2, in_channels)
+ self.softmax = nn.Softmax(dim=-1)
+ self.max_pool = nn.MaxPool2d(2)
+
+ def forward(self, x):
+ r"""
+
+ Args:
+ x (tensor): Input feature maps (B x C x H x W).
+ Returns:
+ out (tensor): Self-attention output added to the input features
+ (same shape as the input); the attention map itself is not returned.
+ """
+ n, c, h, w = x.size()
+ theta = self.theta(x).view(n, -1, h * w).permute(0, 2, 1)
+
+ phi = self.phi(x)
+ phi = self.max_pool(phi).view(n, -1, h * w // 4)
+
+ energy = torch.bmm(theta, phi)
+ attention = self.softmax(energy)
+
+ g = self.g(x)
+ g = self.max_pool(g).view(n, -1, h * w // 4)
+
+ out = torch.bmm(g, attention.permute(0, 2, 1))
+ out = out.view(n, c // 2, h, w)
+ out = self.out_conv(out)
+
+ if self.clamp:
+ out = self.gamma.clamp(-1, 1) * out + x
+ else:
+ out = self.gamma * out + x
+ return out
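+
+
+ # --- Illustrative usage sketch (not part of the original module): the block
+ # preserves the input shape, so it can be dropped between any two
+ # convolutional stages (the spatial size should be even because of the 2x2
+ # max-pooling on the key/value branches).
+ def _demo_non_local_2d():
+     attn = NonLocal2dBlock(64)
+     x = torch.randn(1, 64, 32, 32)
+     return attn(x).shape                # torch.Size([1, 64, 32, 32])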
diff --git a/imaginaire/layers/nonlinearity.py b/imaginaire/layers/nonlinearity.py
new file mode 100644
index 0000000000000000000000000000000000000000..5fc172c74323e707e5a19f94e466c1bf0dae4418
--- /dev/null
+++ b/imaginaire/layers/nonlinearity.py
@@ -0,0 +1,65 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+from imaginaire.third_party.bias_act.bias_act import FusedNonlinearity
+
+
+class ScaledLeakyReLU(nn.Module):
+ def __init__(self, negative_slope=0.2, scale=2 ** 0.5, inplace=False):
+ super().__init__()
+
+ self.negative_slope = negative_slope
+ self.scale = scale
+ self.inplace = inplace
+
+ def forward(self, x):
+ return F.leaky_relu(x, self.negative_slope, inplace=self.inplace) * self.scale
+ # return _fused_scaled_leakyrelu(x, self.negative_slope, self.inplace, self.scale)
+
+
+# @torch.jit.script
+# def _fused_scaled_leakyrelu(x: torch.Tensor, negative_slope: float, inplace: bool, scale: float):
+# return F.leaky_relu(x, negative_slope, inplace=inplace) * scale
+
+
+def get_nonlinearity_layer(nonlinearity_type, inplace, **kwargs):
+ r"""Return a nonlinearity layer.
+
+ Args:
+ nonlinearity_type (str):
+ Type of nonlinear activation function.
+ ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
+ ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``.
+ inplace (bool): If ``True``, set ``inplace=True`` when initializing
+ the nonlinearity layer.
+ """
+ if nonlinearity_type.startswith('fused'):
+ nonlinearity = FusedNonlinearity(nonlinearity=nonlinearity_type[6:], **kwargs)
+ elif nonlinearity_type == 'relu':
+ nonlinearity = nn.ReLU(inplace=inplace)
+ elif nonlinearity_type == 'leakyrelu':
+ nonlinearity = nn.LeakyReLU(0.2, inplace=inplace)
+ elif nonlinearity_type == 'scaled_leakyrelu':
+ nonlinearity = ScaledLeakyReLU(0.2, inplace=inplace)
+ import imaginaire.config
+ if imaginaire.config.USE_JIT:
+ nonlinearity = torch.jit.script(nonlinearity)
+ elif nonlinearity_type == 'prelu':
+ nonlinearity = nn.PReLU()
+ elif nonlinearity_type == 'tanh':
+ nonlinearity = nn.Tanh()
+ elif nonlinearity_type == 'sigmoid':
+ nonlinearity = nn.Sigmoid()
+ elif nonlinearity_type.startswith('softmax'):
+ dim = nonlinearity_type.split(',')[1] if ',' in nonlinearity_type else 1
+ nonlinearity = nn.Softmax(dim=int(dim))
+ elif nonlinearity_type == 'none' or nonlinearity_type == '':
+ nonlinearity = None
+ else:
+ raise ValueError('Nonlinearity %s is not recognized' % nonlinearity_type)
+ return nonlinearity
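+
+
+ # --- Illustrative sketch (not part of the original module): the factory
+ # returns a ready nn.Module (or ``None`` for 'none'); the softmax dimension
+ # can be encoded in the type string, e.g. 'softmax,2'.
+ def _demo_get_nonlinearity_layer():
+     act = get_nonlinearity_layer('leakyrelu', inplace=False)
+     softmax = get_nonlinearity_layer('softmax,2', inplace=False)
+     x = torch.randn(2, 3, 4, 4)
+     return act(x).shape, softmax(x).sum(dim=2)   # sums along dim 2 are ones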
diff --git a/imaginaire/layers/residual.py b/imaginaire/layers/residual.py
new file mode 100644
index 0000000000000000000000000000000000000000..5e1bda4dd30922f694302803b7d606af7f3c0c21
--- /dev/null
+++ b/imaginaire/layers/residual.py
@@ -0,0 +1,1411 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import functools
+
+import torch
+from torch import nn
+from torch.nn import Upsample as NearestUpsample
+from torch.utils.checkpoint import checkpoint
+
+from .conv import (Conv1dBlock, Conv2dBlock, Conv3dBlock, HyperConv2dBlock,
+ LinearBlock, MultiOutConv2dBlock, PartialConv2dBlock,
+ PartialConv3dBlock, ModulatedConv2dBlock)
+from imaginaire.third_party.upfirdn2d.upfirdn2d import BlurUpsample
+
+
+class _BaseResBlock(nn.Module):
+ r"""An abstract class for residual blocks.
+ """
+
+ def __init__(self, in_channels, out_channels, kernel_size,
+ stride, padding, dilation, groups, bias, padding_mode,
+ weight_norm_type, weight_norm_params,
+ activation_norm_type, activation_norm_params,
+ skip_activation_norm, skip_nonlinearity,
+ nonlinearity, inplace_nonlinearity, apply_noise,
+ hidden_channels_equal_out_channels,
+ order, block, learn_shortcut, clamp, output_scale,
+ skip_block=None, blur=False, upsample_first=True, skip_weight_norm=True):
+ super().__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.output_scale = output_scale
+ self.upsample_first = upsample_first
+ self.stride = stride
+ self.blur = blur
+ if skip_block is None:
+ skip_block = block
+
+ if order == 'pre_act':
+ order = 'NACNAC'
+ if isinstance(bias, bool):
+ # The bias for conv_block_0, conv_block_1, and conv_block_s.
+ biases = [bias, bias, bias]
+ elif isinstance(bias, list):
+ if len(bias) == 3:
+ biases = bias
+ else:
+ raise ValueError('The bias list must contain exactly 3 values.')
+ else:
+ raise ValueError('Bias must be either a bool or a list of 3 bools.')
+ if learn_shortcut is None:
+ self.learn_shortcut = (in_channels != out_channels)
+ else:
+ self.learn_shortcut = learn_shortcut
+ if len(order) > 6 or len(order) < 5:
+ raise ValueError('order must be either 5 or 6 characters')
+ if hidden_channels_equal_out_channels:
+ hidden_channels = out_channels
+ else:
+ hidden_channels = min(in_channels, out_channels)
+
+ # Parameters.
+ residual_params = {}
+ shortcut_params = {}
+ base_params = dict(dilation=dilation,
+ groups=groups,
+ padding_mode=padding_mode,
+ clamp=clamp)
+ residual_params.update(base_params)
+ residual_params.update(
+ dict(activation_norm_type=activation_norm_type,
+ activation_norm_params=activation_norm_params,
+ weight_norm_type=weight_norm_type,
+ weight_norm_params=weight_norm_params,
+ padding=padding,
+ apply_noise=apply_noise))
+ shortcut_params.update(base_params)
+ shortcut_params.update(dict(kernel_size=1))
+ if skip_activation_norm:
+ shortcut_params.update(
+ dict(activation_norm_type=activation_norm_type,
+ activation_norm_params=activation_norm_params,
+ apply_noise=False))
+ if skip_weight_norm:
+ shortcut_params.update(
+ dict(weight_norm_type=weight_norm_type,
+ weight_norm_params=weight_norm_params))
+
+ # Residual branch.
+ if order.find('A') < order.find('C') and \
+ (activation_norm_type == '' or activation_norm_type == 'none'):
+ # Nonlinearity is the first operation in the residual path.
+ # In-place nonlinearity will modify the input variable and cause
+ # backward error.
+ first_inplace = False
+ else:
+ first_inplace = inplace_nonlinearity
+
+ (first_stride, second_stride, shortcut_stride,
+ first_blur, second_blur, shortcut_blur) = self._get_stride_blur()
+ self.conv_block_0 = block(
+ in_channels, hidden_channels,
+ kernel_size=kernel_size,
+ bias=biases[0],
+ nonlinearity=nonlinearity,
+ order=order[0:3],
+ inplace_nonlinearity=first_inplace,
+ stride=first_stride,
+ blur=first_blur,
+ **residual_params
+ )
+ self.conv_block_1 = block(
+ hidden_channels, out_channels,
+ kernel_size=kernel_size,
+ bias=biases[1],
+ nonlinearity=nonlinearity,
+ order=order[3:],
+ inplace_nonlinearity=inplace_nonlinearity,
+ stride=second_stride,
+ blur=second_blur,
+ **residual_params
+ )
+
+ # Shortcut branch.
+ if self.learn_shortcut:
+ if skip_nonlinearity:
+ skip_nonlinearity_type = nonlinearity
+ else:
+ skip_nonlinearity_type = ''
+ self.conv_block_s = skip_block(in_channels, out_channels,
+ bias=biases[2],
+ nonlinearity=skip_nonlinearity_type,
+ order=order[0:3],
+ stride=shortcut_stride,
+ blur=shortcut_blur,
+ **shortcut_params)
+ elif in_channels < out_channels:
+ if skip_nonlinearity:
+ skip_nonlinearity_type = nonlinearity
+ else:
+ skip_nonlinearity_type = ''
+ self.conv_block_s = skip_block(in_channels,
+ out_channels - in_channels,
+ bias=biases[2],
+ nonlinearity=skip_nonlinearity_type,
+ order=order[0:3],
+ stride=shortcut_stride,
+ blur=shortcut_blur,
+ **shortcut_params)
+
+ # Whether this block expects conditional inputs.
+ self.conditional = \
+ getattr(self.conv_block_0, 'conditional', False) or \
+ getattr(self.conv_block_1, 'conditional', False)
+
+ def _get_stride_blur(self):
+ if self.stride > 1:
+ # Downsampling.
+ first_stride, second_stride = 1, self.stride
+ first_blur, second_blur = False, self.blur
+ shortcut_stride = self.stride
+ shortcut_blur = self.blur
+ self.upsample = None
+ elif self.stride < 1:
+ # Upsampling.
+ first_stride, second_stride = self.stride, 1
+ first_blur, second_blur = self.blur, False
+ shortcut_blur = False
+ shortcut_stride = 1
+ if self.blur:
+ # The shortcut branch uses blur_upsample + stride-1 conv
+ self.upsample = BlurUpsample()
+ else:
+ shortcut_stride = self.stride
+ self.upsample = nn.Upsample(scale_factor=2)
+ else:
+ first_stride = second_stride = 1
+ first_blur = second_blur = False
+ shortcut_stride = 1
+ shortcut_blur = False
+ self.upsample = None
+ return (first_stride, second_stride, shortcut_stride,
+ first_blur, second_blur, shortcut_blur)
+
+ def conv_blocks(
+ self, x, *cond_inputs, separate_cond=False, **kw_cond_inputs
+ ):
+ r"""Returns the output of the residual branch.
+
+ Args:
+ x (tensor): Input tensor.
+ cond_inputs (list of tensors) : Conditional input tensors.
+ kw_cond_inputs (dict) : Keyword conditional inputs.
+ Returns:
+ dx (tensor): Output tensor.
+ """
+ if separate_cond:
+ dx = self.conv_block_0(x, cond_inputs[0],
+ **kw_cond_inputs.get('kwargs_0', {}))
+ dx = self.conv_block_1(dx, cond_inputs[1],
+ **kw_cond_inputs.get('kwargs_1', {}))
+ else:
+ dx = self.conv_block_0(x, *cond_inputs, **kw_cond_inputs)
+ dx = self.conv_block_1(dx, *cond_inputs, **kw_cond_inputs)
+ return dx
+
+ def forward(self, x, *cond_inputs, do_checkpoint=False, separate_cond=False,
+ **kw_cond_inputs):
+ r"""
+
+ Args:
+ x (tensor): Input tensor.
+ cond_inputs (list of tensors) : Conditional input tensors.
+ do_checkpoint (bool, optional, default=``False``): If ``True``,
+ trade compute for memory by checkpointing the model.
+ kw_cond_inputs (dict) : Keyword conditional inputs.
+ Returns:
+ output (tensor): Output tensor.
+ """
+ if do_checkpoint:
+ dx = checkpoint(self.conv_blocks, x, *cond_inputs,
+ separate_cond=separate_cond, **kw_cond_inputs)
+ else:
+ dx = self.conv_blocks(x, *cond_inputs,
+ separate_cond=separate_cond, **kw_cond_inputs)
+
+ if self.upsample_first and self.upsample is not None:
+ x = self.upsample(x)
+ if self.learn_shortcut:
+ if separate_cond:
+ x_shortcut = self.conv_block_s(
+ x, cond_inputs[2], **kw_cond_inputs.get('kwargs_2', {})
+ )
+ else:
+ x_shortcut = self.conv_block_s(
+ x, *cond_inputs, **kw_cond_inputs
+ )
+ elif self.in_channels < self.out_channels:
+ if separate_cond:
+ x_shortcut_pad = self.conv_block_s(
+ x, cond_inputs[2], **kw_cond_inputs.get('kwargs_2', {})
+ )
+ else:
+ x_shortcut_pad = self.conv_block_s(
+ x, *cond_inputs, **kw_cond_inputs
+ )
+ x_shortcut = torch.cat((x, x_shortcut_pad), dim=1)
+ elif self.in_channels > self.out_channels:
+ x_shortcut = x[:, :self.out_channels, :, :]
+ else:
+ x_shortcut = x
+ if not self.upsample_first and self.upsample is not None:
+ x_shortcut = self.upsample(x_shortcut)
+
+ output = x_shortcut + dx
+ return self.output_scale * output
+
+ def extra_repr(self):
+ s = 'output_scale={output_scale}'
+ return s.format(**self.__dict__)
+
+
+class ModulatedRes2dBlock(_BaseResBlock):
+ def __init__(self, in_channels, out_channels, style_dim, kernel_size=3,
+ stride=1, padding=1, dilation=1, groups=1, bias=True,
+ padding_mode='zeros',
+ weight_norm_type='none', weight_norm_params=None,
+ activation_norm_type='none', activation_norm_params=None,
+ skip_activation_norm=True, skip_nonlinearity=False,
+ nonlinearity='leakyrelu', inplace_nonlinearity=False,
+ apply_noise=True, hidden_channels_equal_out_channels=False,
+ order='CNACNA', learn_shortcut=None, clamp=None, output_scale=1,
+ demodulate=True, eps=1e-8):
+ block = functools.partial(ModulatedConv2dBlock,
+ style_dim=style_dim,
+ demodulate=demodulate, eps=eps)
+ skip_block = Conv2dBlock
+ super().__init__(in_channels, out_channels, kernel_size, stride,
+ padding, dilation, groups, bias, padding_mode,
+ weight_norm_type, weight_norm_params,
+ activation_norm_type, activation_norm_params,
+ skip_activation_norm, skip_nonlinearity, nonlinearity,
+ inplace_nonlinearity, apply_noise,
+ hidden_channels_equal_out_channels, order, block,
+ learn_shortcut, clamp, output_scale, skip_block=skip_block)
+
+ def conv_blocks(self, x, *cond_inputs, **kw_cond_inputs):
+ assert len(list(cond_inputs)) == 2
+ dx = self.conv_block_0(x, cond_inputs[0], **kw_cond_inputs)
+ dx = self.conv_block_1(dx, cond_inputs[1], **kw_cond_inputs)
+ return dx
+
+
+class ResLinearBlock(_BaseResBlock):
+ r"""Residual block with full-connected layers.
+
+ Args:
+ in_channels (int) : Number of channels in the input tensor.
+ out_channels (int) : Number of channels in the output tensor.
+ weight_norm_type (str, optional, default='none'):
+ Type of weight normalization.
+ ``'none'``, ``'spectral'``, ``'weight'``
+ or ``'weight_demod'``.
+ weight_norm_params (obj, optional, default=None):
+ Parameters of weight normalization.
+ If not ``None``, ``weight_norm_params.__dict__`` will be used as
+ keyword arguments when initializing weight normalization.
+ activation_norm_type (str, optional, default='none'):
+ Type of activation normalization.
+ ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
+ ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``,
+ ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
+ activation_norm_params (obj, optional, default=None):
+ Parameters of activation normalization.
+ If not ``None``, ``activation_norm_params.__dict__`` will be used as
+ keyword arguments when initializing activation normalization.
+ skip_activation_norm (bool, optional, default=True): If ``True`` and
+ ``learn_shortcut`` is also ``True``, applies activation norm to the
+ learned shortcut connection.
+ skip_nonlinearity (bool, optional, default=False): If ``True`` and
+ ``learn_shortcut`` is also ``True``, applies nonlinearity to the
+ learned shortcut connection.
+ nonlinearity (str, optional, default='leakyrelu'):
+ Type of nonlinear activation function in the residual link.
+ ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
+ ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``.
+ inplace_nonlinearity (bool, optional, default=False): If ``True``,
+ set ``inplace=True`` when initializing the nonlinearity layers.
+ apply_noise (bool, optional, default=False): If ``True``, add
+ Gaussian noise with learnable magnitude after the
+ fully-connected layer.
+ hidden_channels_equal_out_channels (bool, optional, default=False):
+ If ``True``, set the hidden channel number to be equal to the
+ output channel number. If ``False``, the hidden channel number
+ equals the smaller of the input channel number and the
+ output channel number.
+ order (str, optional, default='CNACNA'): Order of operations
+ in the residual link.
+ ``'C'``: fully-connected,
+ ``'N'``: normalization,
+ ``'A'``: nonlinear activation.
+ learn_shortcut (bool, optional, default=None): If ``True``, always use
+ a learned shortcut instead of an identity one; otherwise a learned
+ shortcut is used only when the input and output have different
+ numbers of channels.
+ """
+
+ def __init__(self, in_channels, out_channels, bias=True,
+ weight_norm_type='none', weight_norm_params=None,
+ activation_norm_type='none', activation_norm_params=None,
+ skip_activation_norm=True, skip_nonlinearity=False,
+ nonlinearity='leakyrelu', inplace_nonlinearity=False,
+ apply_noise=False, hidden_channels_equal_out_channels=False,
+ order='CNACNA', learn_shortcut=None, clamp=None,
+ output_scale=1):
+ super().__init__(in_channels, out_channels, None, 1, None, None,
+ None, bias, None, weight_norm_type, weight_norm_params,
+ activation_norm_type, activation_norm_params,
+ skip_activation_norm, skip_nonlinearity, nonlinearity,
+ inplace_nonlinearity, apply_noise,
+ hidden_channels_equal_out_channels, order, LinearBlock,
+ learn_shortcut, clamp, output_scale)
+
+
+class Res1dBlock(_BaseResBlock):
+ r"""Residual block for 1D input.
+
+ Args:
+ in_channels (int) : Number of channels in the input tensor.
+ out_channels (int) : Number of channels in the output tensor.
+ kernel_size (int, optional, default=3): Kernel size for the
+ convolutional filters in the residual link.
+ padding (int, optional, default=1): Padding size.
+ dilation (int, optional, default=1): Dilation factor.
+ groups (int, optional, default=1): Number of convolutional/linear
+ groups.
+ padding_mode (string, optional, default='zeros'): Type of padding:
+ ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
+ weight_norm_type (str, optional, default='none'):
+ Type of weight normalization.
+ ``'none'``, ``'spectral'``, ``'weight'``
+ or ``'weight_demod'``.
+ weight_norm_params (obj, optional, default=None):
+ Parameters of weight normalization.
+ If not ``None``, ``weight_norm_params.__dict__`` will be used as
+ keyword arguments when initializing weight normalization.
+ activation_norm_type (str, optional, default='none'):
+ Type of activation normalization.
+ ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
+ ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``,
+ ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
+ activation_norm_params (obj, optional, default=None):
+ Parameters of activation normalization.
+ If not ``None``, ``activation_norm_params.__dict__`` will be used as
+ keyword arguments when initializing activation normalization.
+ skip_activation_norm (bool, optional, default=True): If ``True`` and
+ ``learn_shortcut`` is also ``True``, applies activation norm to the
+ learned shortcut connection.
+ skip_nonlinearity (bool, optional, default=False): If ``True`` and
+ ``learn_shortcut`` is also ``True``, applies nonlinearity to the
+ learned shortcut connection.
+ nonlinearity (str, optional, default='leakyrelu'):
+ Type of nonlinear activation function in the residual link.
+ ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
+ ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``.
+ inplace_nonlinearity (bool, optional, default=False): If ``True``,
+ set ``inplace=True`` when initializing the nonlinearity layers.
+ apply_noise (bool, optional, default=False): If ``True``, adds
+ Gaussian noise with learnable magnitude to the convolution output.
+ hidden_channels_equal_out_channels (bool, optional, default=False):
+ If ``True``, set the hidden channel number to be equal to the
+ output channel number. If ``False``, the hidden channel number
+ equals the smaller of the input channel number and the
+ output channel number.
+ order (str, optional, default='CNACNA'): Order of operations
+ in the residual link.
+ ``'C'``: convolution,
+ ``'N'``: normalization,
+ ``'A'``: nonlinear activation.
+ learn_shortcut (bool, optional, default=None): If ``True``, always use
+ a learned shortcut instead of an identity one; otherwise a learned
+ shortcut is used only when the input and output have different
+ numbers of channels.
+ """
+
+ def __init__(self, in_channels, out_channels, kernel_size=3,
+ stride=1, padding=1, dilation=1, groups=1, bias=True,
+ padding_mode='zeros',
+ weight_norm_type='none', weight_norm_params=None,
+ activation_norm_type='none', activation_norm_params=None,
+ skip_activation_norm=True, skip_nonlinearity=False,
+ nonlinearity='leakyrelu', inplace_nonlinearity=False,
+ apply_noise=False, hidden_channels_equal_out_channels=False,
+ order='CNACNA', learn_shortcut=None, clamp=None,
+ output_scale=1):
+ super().__init__(in_channels, out_channels, kernel_size, stride,
+ padding, dilation, groups, bias, padding_mode,
+ weight_norm_type, weight_norm_params,
+ activation_norm_type, activation_norm_params,
+ skip_activation_norm, skip_nonlinearity, nonlinearity,
+ inplace_nonlinearity, apply_noise,
+ hidden_channels_equal_out_channels, order, Conv1dBlock,
+ learn_shortcut, clamp, output_scale)
+
+
+class Res2dBlock(_BaseResBlock):
+ r"""Residual block for 2D input.
+
+ Args:
+ in_channels (int) : Number of channels in the input tensor.
+ out_channels (int) : Number of channels in the output tensor.
+ kernel_size (int, optional, default=3): Kernel size for the
+ convolutional filters in the residual link.
+ padding (int, optional, default=1): Padding size.
+ dilation (int, optional, default=1): Dilation factor.
+ groups (int, optional, default=1): Number of convolutional/linear
+ groups.
+ padding_mode (string, optional, default='zeros'): Type of padding:
+ ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
+ weight_norm_type (str, optional, default='none'):
+ Type of weight normalization.
+ ``'none'``, ``'spectral'``, ``'weight'``
+ or ``'weight_demod'``.
+ weight_norm_params (obj, optional, default=None):
+ Parameters of weight normalization.
+ If not ``None``, ``weight_norm_params.__dict__`` will be used as
+ keyword arguments when initializing weight normalization.
+ activation_norm_type (str, optional, default='none'):
+ Type of activation normalization.
+ ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
+ ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``,
+ ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
+ activation_norm_params (obj, optional, default=None):
+ Parameters of activation normalization.
+ If not ``None``, ``activation_norm_params.__dict__`` will be used as
+ keyword arguments when initializing activation normalization.
+ skip_activation_norm (bool, optional, default=True): If ``True`` and
+ ``learn_shortcut`` is also ``True``, applies activation norm to the
+ learned shortcut connection.
+ skip_nonlinearity (bool, optional, default=False): If ``True`` and
+ ``learn_shortcut`` is also ``True``, applies nonlinearity to the
+ learned shortcut connection.
+ nonlinearity (str, optional, default='leakyrelu'):
+ Type of nonlinear activation function in the residual link.
+ ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
+ ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``.
+ inplace_nonlinearity (bool, optional, default=False): If ``True``,
+ set ``inplace=True`` when initializing the nonlinearity layers.
+ apply_noise (bool, optional, default=False): If ``True``, adds
+ Gaussian noise with learnable magnitude to the convolution output.
+ hidden_channels_equal_out_channels (bool, optional, default=False):
+ If ``True``, set the hidden channel number to be equal to the
+ output channel number. If ``False``, the hidden channel number
+ equals the smaller of the input channel number and the
+ output channel number.
+ order (str, optional, default='CNACNA'): Order of operations
+ in the residual link.
+ ``'C'``: convolution,
+ ``'N'``: normalization,
+ ``'A'``: nonlinear activation.
+ learn_shortcut (bool, optional, default=None): If ``True``, always use
+ a learned shortcut instead of an identity one; otherwise a learned
+ shortcut is used only when the input and output have different
+ numbers of channels.
+ """
+
+ def __init__(self, in_channels, out_channels, kernel_size=3,
+ stride=1, padding=1, dilation=1, groups=1, bias=True,
+ padding_mode='zeros',
+ weight_norm_type='none', weight_norm_params=None,
+ activation_norm_type='none', activation_norm_params=None,
+ skip_activation_norm=True, skip_nonlinearity=False,
+ skip_weight_norm=True,
+ nonlinearity='leakyrelu', inplace_nonlinearity=False,
+ apply_noise=False, hidden_channels_equal_out_channels=False,
+ order='CNACNA', learn_shortcut=None, clamp=None,
+ output_scale=1, blur=False, upsample_first=True):
+ super().__init__(in_channels, out_channels, kernel_size, stride,
+ padding, dilation, groups, bias, padding_mode,
+ weight_norm_type, weight_norm_params,
+ activation_norm_type, activation_norm_params,
+ skip_activation_norm, skip_nonlinearity, nonlinearity,
+ inplace_nonlinearity, apply_noise,
+ hidden_channels_equal_out_channels, order, Conv2dBlock,
+ learn_shortcut, clamp, output_scale, blur=blur,
+ upsample_first=upsample_first,
+ skip_weight_norm=skip_weight_norm)
+
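+ # --- Illustrative usage sketch (not part of the original module): with the
+ # defaults (kernel_size=3, stride=1, padding=1) the block keeps the spatial
+ # size, and a learned 1x1 shortcut is added automatically whenever
+ # in_channels != out_channels.
+ def _demo_res2d_block():
+     block = Res2dBlock(32, 64)
+     x = torch.randn(1, 32, 16, 16)
+     return block(x).shape               # torch.Size([1, 64, 16, 16])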
+
+class Res3dBlock(_BaseResBlock):
+ r"""Residual block for 3D input.
+
+ Args:
+ in_channels (int) : Number of channels in the input tensor.
+ out_channels (int) : Number of channels in the output tensor.
+ kernel_size (int, optional, default=3): Kernel size for the
+ convolutional filters in the residual link.
+ padding (int, optional, default=1): Padding size.
+ dilation (int, optional, default=1): Dilation factor.
+ groups (int, optional, default=1): Number of convolutional/linear
+ groups.
+ padding_mode (string, optional, default='zeros'): Type of padding:
+ ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
+ weight_norm_type (str, optional, default='none'):
+ Type of weight normalization.
+ ``'none'``, ``'spectral'``, ``'weight'``
+ or ``'weight_demod'``.
+ weight_norm_params (obj, optional, default=None):
+ Parameters of weight normalization.
+ If not ``None``, ``weight_norm_params.__dict__`` will be used as
+ keyword arguments when initializing weight normalization.
+ activation_norm_type (str, optional, default='none'):
+ Type of activation normalization.
+ ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
+ ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``,
+ ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
+ activation_norm_params (obj, optional, default=None):
+ Parameters of activation normalization.
+ If not ``None``, ``activation_norm_params.__dict__`` will be used as
+ keyword arguments when initializing activation normalization.
+ skip_activation_norm (bool, optional, default=True): If ``True`` and
+ ``learn_shortcut`` is also ``True``, applies activation norm to the
+ learned shortcut connection.
+ skip_nonlinearity (bool, optional, default=False): If ``True`` and
+ ``learn_shortcut`` is also ``True``, applies nonlinearity to the
+ learned shortcut connection.
+ nonlinearity (str, optional, default='leakyrelu'):
+ Type of nonlinear activation function in the residual link.
+ ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
+ ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``.
+ inplace_nonlinearity (bool, optional, default=False): If ``True``,
+ set ``inplace=True`` when initializing the nonlinearity layers.
+ apply_noise (bool, optional, default=False): If ``True``, adds
+ Gaussian noise with learnable magnitude to the convolution output.
+ hidden_channels_equal_out_channels (bool, optional, default=False):
+ If ``True``, set the hidden channel number to be equal to the
+ output channel number. If ``False``, the hidden channel number
+ equals the smaller of the input channel number and the
+ output channel number.
+ order (str, optional, default='CNACNA'): Order of operations
+ in the residual link.
+ ``'C'``: convolution,
+ ``'N'``: normalization,
+ ``'A'``: nonlinear activation.
+ learn_shortcut (bool, optional, default=None): If ``True``, always use
+ a learned shortcut instead of an identity one; otherwise a learned
+ shortcut is used only when the input and output have different
+ numbers of channels.
+ """
+
+ def __init__(self, in_channels, out_channels, kernel_size=3,
+ stride=1, padding=1, dilation=1, groups=1, bias=True,
+ padding_mode='zeros',
+ weight_norm_type='none', weight_norm_params=None,
+ activation_norm_type='none', activation_norm_params=None,
+ skip_activation_norm=True, skip_nonlinearity=False,
+ nonlinearity='leakyrelu', inplace_nonlinearity=False,
+ apply_noise=False, hidden_channels_equal_out_channels=False,
+ order='CNACNA', learn_shortcut=None, clamp=None,
+ output_scale=1):
+ super().__init__(in_channels, out_channels, kernel_size, stride,
+ padding, dilation, groups, bias, padding_mode,
+ weight_norm_type, weight_norm_params,
+ activation_norm_type, activation_norm_params,
+ skip_activation_norm, skip_nonlinearity, nonlinearity,
+ inplace_nonlinearity, apply_noise,
+ hidden_channels_equal_out_channels, order, Conv3dBlock,
+ learn_shortcut, clamp, output_scale)
+
+
+class _BaseHyperResBlock(_BaseResBlock):
+ r"""An abstract class for hyper residual blocks.
+ """
+
+ def __init__(self, in_channels, out_channels, kernel_size,
+ stride, padding, dilation, groups, bias, padding_mode,
+ weight_norm_type, weight_norm_params,
+ activation_norm_type, activation_norm_params,
+ skip_activation_norm, skip_nonlinearity,
+ nonlinearity, inplace_nonlinearity, apply_noise,
+ hidden_channels_equal_out_channels,
+ order, is_hyper_conv, is_hyper_norm, block, learn_shortcut,
+ clamp=None, output_scale=1):
+ block = functools.partial(block,
+ is_hyper_conv=is_hyper_conv,
+ is_hyper_norm=is_hyper_norm)
+ super().__init__(in_channels, out_channels, kernel_size, stride,
+ padding, dilation, groups, bias, padding_mode,
+ weight_norm_type, weight_norm_params,
+ activation_norm_type, activation_norm_params,
+ skip_activation_norm, skip_nonlinearity, nonlinearity,
+ inplace_nonlinearity, apply_noise,
+ hidden_channels_equal_out_channels, order, block,
+ learn_shortcut, clamp, output_scale)
+
+ def forward(self, x, *cond_inputs, conv_weights=(None,) * 3,
+ norm_weights=(None,) * 3, **kw_cond_inputs):
+ r"""
+
+ Args:
+ x (tensor): Input tensor.
+ cond_inputs (list of tensors) : Conditional input tensors.
+ conv_weights (list of tensors): Convolution weights for
+ three convolutional layers respectively.
+ norm_weights (list of tensors): Normalization weights for
+ three convolutional layers respectively.
+ kw_cond_inputs (dict) : Keyword conditional inputs.
+ Returns:
+ output (tensor): Output tensor.
+ """
+ dx = self.conv_block_0(x, *cond_inputs, conv_weights=conv_weights[0],
+ norm_weights=norm_weights[0])
+ dx = self.conv_block_1(dx, *cond_inputs, conv_weights=conv_weights[1],
+ norm_weights=norm_weights[1])
+ if self.learn_shortcut:
+ x_shortcut = self.conv_block_s(x, *cond_inputs,
+ conv_weights=conv_weights[2],
+ norm_weights=norm_weights[2])
+ else:
+ x_shortcut = x
+ output = x_shortcut + dx
+ return self.output_scale * output
+
+
+class HyperRes2dBlock(_BaseHyperResBlock):
+ r"""Hyper residual block for 2D input.
+
+ Args:
+ in_channels (int) : Number of channels in the input tensor.
+ out_channels (int) : Number of channels in the output tensor.
+ kernel_size (int, optional, default=3): Kernel size for the
+ convolutional filters in the residual link.
+ padding (int, optional, default=1): Padding size.
+ dilation (int, optional, default=1): Dilation factor.
+ groups (int, optional, default=1): Number of convolutional/linear
+ groups.
+ padding_mode (string, optional, default='zeros'): Type of padding:
+ ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
+ weight_norm_type (str, optional, default='none'):
+ Type of weight normalization.
+ ``'none'``, ``'spectral'``, ``'weight'``
+ or ``'weight_demod'``.
+ weight_norm_params (obj, optional, default=None):
+ Parameters of weight normalization.
+ If not ``None``, ``weight_norm_params.__dict__`` will be used as
+ keyword arguments when initializing weight normalization.
+ activation_norm_type (str, optional, default='none'):
+ Type of activation normalization.
+ ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
+ ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``,
+ ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
+ activation_norm_params (obj, optional, default=None):
+ Parameters of activation normalization.
+ If not ``None``, ``activation_norm_params.__dict__`` will be used as
+ keyword arguments when initializing activation normalization.
+ skip_activation_norm (bool, optional, default=True): If ``True`` and
+ ``learn_shortcut`` is also ``True``, applies activation norm to the
+ learned shortcut connection.
+ skip_nonlinearity (bool, optional, default=False): If ``True`` and
+ ``learn_shortcut`` is also ``True``, applies nonlinearity to the
+ learned shortcut connection.
+ nonlinearity (str, optional, default='leakyrelu'):
+ Type of nonlinear activation function in the residual link.
+ ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
+ ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``.
+ inplace_nonlinearity (bool, optional, default=False): If ``True``,
+ set ``inplace=True`` when initializing the nonlinearity layers.
+ apply_noise (bool, optional, default=False): If ``True``, adds
+ Gaussian noise with learnable magnitude to the convolution output.
+ hidden_channels_equal_out_channels (bool, optional, default=False):
+ If ``True``, set the hidden channel number to be equal to the
+ output channel number. If ``False``, the hidden channel number
+ equals the smaller of the input channel number and the
+ output channel number.
+ order (str, optional, default='CNACNA'): Order of operations
+ in the residual link.
+ ``'C'``: convolution,
+ ``'N'``: normalization,
+ ``'A'``: nonlinear activation.
+ is_hyper_conv (bool, optional, default=False): If ``True``, use
+ ``HyperConv2d``, otherwise use ``torch.nn.Conv2d``.
+ is_hyper_norm (bool, optional, default=False): If ``True``, use
+ hyper normalizations.
+ learn_shortcut (bool, optional, default=None): If ``True``, always use
+ a learned shortcut instead of an identity one; otherwise a learned
+ shortcut is used only when the input and output have different
+ numbers of channels.
+ """
+
+ def __init__(self, in_channels, out_channels, kernel_size=3,
+ stride=1, padding=1, dilation=1, groups=1, bias=True,
+ padding_mode='zeros',
+ weight_norm_type='', weight_norm_params=None,
+ activation_norm_type='', activation_norm_params=None,
+ skip_activation_norm=True, skip_nonlinearity=False,
+ nonlinearity='leakyrelu', inplace_nonlinearity=False,
+ apply_noise=False, hidden_channels_equal_out_channels=False,
+ order='CNACNA', is_hyper_conv=False, is_hyper_norm=False,
+ learn_shortcut=None, clamp=None, output_scale=1):
+ super().__init__(in_channels, out_channels, kernel_size,
+ stride, padding, dilation, groups, bias, padding_mode,
+ weight_norm_type, weight_norm_params,
+ activation_norm_type, activation_norm_params,
+ skip_activation_norm, skip_nonlinearity,
+ nonlinearity, inplace_nonlinearity, apply_noise,
+ hidden_channels_equal_out_channels,
+ order, is_hyper_conv, is_hyper_norm,
+ HyperConv2dBlock, learn_shortcut, clamp, output_scale)
+
+
+class _BaseDownResBlock(_BaseResBlock):
+ r"""An abstract class for residual blocks with downsampling.
+ """
+
+ def __init__(self, in_channels, out_channels, kernel_size,
+ stride, padding, dilation, groups, bias, padding_mode,
+ weight_norm_type, weight_norm_params,
+ activation_norm_type, activation_norm_params,
+ skip_activation_norm, skip_nonlinearity,
+ nonlinearity, inplace_nonlinearity,
+ apply_noise, hidden_channels_equal_out_channels,
+ order, block, pooling, down_factor, learn_shortcut,
+ clamp=None, output_scale=1):
+ super().__init__(in_channels, out_channels, kernel_size,
+ stride, padding, dilation, groups, bias, padding_mode,
+ weight_norm_type, weight_norm_params,
+ activation_norm_type, activation_norm_params,
+ skip_activation_norm, skip_nonlinearity, nonlinearity,
+ inplace_nonlinearity, apply_noise,
+ hidden_channels_equal_out_channels, order, block,
+ learn_shortcut, clamp, output_scale)
+ self.pooling = pooling(down_factor)
+
+ def forward(self, x, *cond_inputs):
+ r"""
+
+ Args:
+ x (tensor) : Input tensor.
+ cond_inputs (list of tensors) : conditional input.
+ Returns:
+ output (tensor) : Output tensor.
+ """
+ dx = self.conv_block_0(x, *cond_inputs)
+ dx = self.conv_block_1(dx, *cond_inputs)
+ dx = self.pooling(dx)
+ if self.learn_shortcut:
+ x_shortcut = self.conv_block_s(x, *cond_inputs)
+ else:
+ x_shortcut = x
+ x_shortcut = self.pooling(x_shortcut)
+ output = x_shortcut + dx
+ return self.output_scale * output
+
+
+class DownRes2dBlock(_BaseDownResBlock):
+ r"""Residual block for 2D input with downsampling.
+
+ Args:
+ in_channels (int) : Number of channels in the input tensor.
+ out_channels (int) : Number of channels in the output tensor.
+ kernel_size (int, optional, default=3): Kernel size for the
+ convolutional filters in the residual link.
+ padding (int, optional, default=1): Padding size.
+ dilation (int, optional, default=1): Dilation factor.
+ groups (int, optional, default=1): Number of convolutional/linear
+ groups.
+ padding_mode (string, optional, default='zeros'): Type of padding:
+ ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
+ weight_norm_type (str, optional, default='none'):
+ Type of weight normalization.
+ ``'none'``, ``'spectral'``, ``'weight'``
+ or ``'weight_demod'``.
+ weight_norm_params (obj, optional, default=None):
+ Parameters of weight normalization.
+ If not ``None``, ``weight_norm_params.__dict__`` will be used as
+ keyword arguments when initializing weight normalization.
+ activation_norm_type (str, optional, default='none'):
+ Type of activation normalization.
+ ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
+ ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``,
+ ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
+ activation_norm_params (obj, optional, default=None):
+ Parameters of activation normalization.
+ If not ``None``, ``activation_norm_params.__dict__`` will be used as
+ keyword arguments when initializing activation normalization.
+ skip_activation_norm (bool, optional, default=True): If ``True`` and
+ ``learn_shortcut`` is also ``True``, applies activation norm to the
+ learned shortcut connection.
+ skip_nonlinearity (bool, optional, default=True): If ``True`` and
+ ``learn_shortcut`` is also ``True``, applies nonlinearity to the
+ learned shortcut connection.
+ nonlinearity (str, optional, default='none'):
+ Type of nonlinear activation function in the residual link.
+ ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
+ ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``.
+ inplace_nonlinearity (bool, optional, default=False): If ``True``,
+ set ``inplace=True`` when initializing the nonlinearity layers.
+ apply_noise (bool, optional, default=False): If ``True``, adds
+ Gaussian noise with learnable magnitude to the convolution output.
+ hidden_channels_equal_out_channels (bool, optional, default=False):
+ If ``True``, set the hidden channel number to be equal to the
+ output channel number. If ``False``, the hidden channel number
+            equals the smaller of the input channel number and the
+ output channel number.
+ order (str, optional, default='CNACNA'): Order of operations
+ in the residual link.
+ ``'C'``: convolution,
+ ``'N'``: normalization,
+ ``'A'``: nonlinear activation.
+        pooling (class, optional, default=nn.AvgPool2d): PyTorch pooling
+ layer to be used.
+ down_factor (int, optional, default=2): Downsampling factor.
+ learn_shortcut (bool, optional, default=False): If ``True``, always use
+ a convolutional shortcut instead of an identity one, otherwise only
+ use a convolutional one if input and output have different number of
+ channels.
+ """
+
+ def __init__(self, in_channels, out_channels, kernel_size=3,
+ stride=1, padding=1, dilation=1, groups=1, bias=True,
+ padding_mode='zeros',
+ weight_norm_type='none', weight_norm_params=None,
+ activation_norm_type='none', activation_norm_params=None,
+ skip_activation_norm=True, skip_nonlinearity=False,
+ nonlinearity='leakyrelu', inplace_nonlinearity=False,
+ apply_noise=False, hidden_channels_equal_out_channels=False,
+ order='CNACNA', pooling=nn.AvgPool2d, down_factor=2,
+ learn_shortcut=None, clamp=None, output_scale=1):
+ super().__init__(in_channels, out_channels, kernel_size,
+ stride, padding, dilation, groups, bias, padding_mode,
+ weight_norm_type, weight_norm_params,
+ activation_norm_type, activation_norm_params,
+ skip_activation_norm, skip_nonlinearity,
+ nonlinearity, inplace_nonlinearity, apply_noise,
+ hidden_channels_equal_out_channels,
+ order, Conv2dBlock, pooling,
+ down_factor, learn_shortcut, clamp, output_scale)
+
+
+class _BaseUpResBlock(_BaseResBlock):
+ r"""An abstract class for residual blocks with upsampling.
+ """
+
+ def __init__(self, in_channels, out_channels, kernel_size,
+ stride, padding, dilation, groups, bias, padding_mode,
+ weight_norm_type, weight_norm_params,
+ activation_norm_type, activation_norm_params,
+ skip_activation_norm, skip_nonlinearity,
+ nonlinearity, inplace_nonlinearity,
+ apply_noise, hidden_channels_equal_out_channels,
+ order, block, upsample, up_factor, learn_shortcut, clamp=None,
+ output_scale=1):
+ super().__init__(in_channels, out_channels, kernel_size,
+ stride, padding, dilation, groups, bias, padding_mode,
+ weight_norm_type, weight_norm_params,
+ activation_norm_type, activation_norm_params,
+ skip_activation_norm, skip_nonlinearity, nonlinearity,
+ inplace_nonlinearity, apply_noise,
+ hidden_channels_equal_out_channels, order, block,
+ learn_shortcut, clamp, output_scale)
+ self.order = order
+ self.upsample = upsample(scale_factor=up_factor)
+
+ def _get_stride_blur(self):
+ # Upsampling.
+ first_stride, second_stride = self.stride, 1
+ first_blur, second_blur = self.blur, False
+ shortcut_blur = False
+ shortcut_stride = 1
+ # if self.upsample == 'blur_deconv':
+
+ if self.blur:
+ # The shortcut branch uses blur_upsample + stride-1 conv
+ self.upsample = BlurUpsample()
+ else:
+ shortcut_stride = self.stride
+ self.upsample = nn.Upsample(scale_factor=2)
+
+ return (first_stride, second_stride, shortcut_stride,
+ first_blur, second_blur, shortcut_blur)
+
+ def forward(self, x, *cond_inputs):
+ r"""Implementation of the up residual block forward function.
+        If the operation order starts with 'NAC', we first apply the
+        activation norm and the nonlinearity at the original resolution,
+        then upsample the activation map to a higher resolution, and
+        finally apply the convolution.
+        For any other order, we run the first convolution block at the
+        original resolution and upsample afterwards.
+
+ Args:
+ x (tensor) : Input tensor.
+ cond_inputs (list of tensors) : Conditional input.
+ Returns:
+ output (tensor) : Output tensor.
+ """
+ # In this particular upsample residual block operation, we first
+ # upsample the skip connection.
+ if self.learn_shortcut:
+ x_shortcut = self.upsample(x)
+ x_shortcut = self.conv_block_s(x_shortcut, *cond_inputs)
+ else:
+ x_shortcut = self.upsample(x)
+
+ if self.order[0:3] == 'NAC':
+ for ix, layer in enumerate(self.conv_block_0.layers.values()):
+ if getattr(layer, 'conditional', False):
+ x = layer(x, *cond_inputs)
+ else:
+ x = layer(x)
+ if ix == 1:
+ x = self.upsample(x)
+ else:
+ x = self.conv_block_0(x, *cond_inputs)
+ x = self.upsample(x)
+ x = self.conv_block_1(x, *cond_inputs)
+
+ output = x_shortcut + x
+ return self.output_scale * output
+
+
+class UpRes2dBlock(_BaseUpResBlock):
+    r"""Residual block for 2D input with upsampling.
+
+ Args:
+ in_channels (int) : Number of channels in the input tensor.
+ out_channels (int) : Number of channels in the output tensor.
+ kernel_size (int, optional, default=3): Kernel size for the
+ convolutional filters in the residual link.
+ padding (int, optional, default=1): Padding size.
+ dilation (int, optional, default=1): Dilation factor.
+ groups (int, optional, default=1): Number of convolutional/linear
+ groups.
+ padding_mode (string, optional, default='zeros'): Type of padding:
+ ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
+ weight_norm_type (str, optional, default='none'):
+ Type of weight normalization.
+ ``'none'``, ``'spectral'``, ``'weight'``
+ or ``'weight_demod'``.
+ weight_norm_params (obj, optional, default=None):
+ Parameters of weight normalization.
+ If not ``None``, ``weight_norm_params.__dict__`` will be used as
+ keyword arguments when initializing weight normalization.
+ activation_norm_type (str, optional, default='none'):
+ Type of activation normalization.
+ ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
+ ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``,
+ ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
+ activation_norm_params (obj, optional, default=None):
+ Parameters of activation normalization.
+ If not ``None``, ``activation_norm_params.__dict__`` will be used as
+ keyword arguments when initializing activation normalization.
+ skip_activation_norm (bool, optional, default=True): If ``True`` and
+ ``learn_shortcut`` is also ``True``, applies activation norm to the
+ learned shortcut connection.
+ skip_nonlinearity (bool, optional, default=True): If ``True`` and
+ ``learn_shortcut`` is also ``True``, applies nonlinearity to the
+ learned shortcut connection.
+ nonlinearity (str, optional, default='none'):
+ Type of nonlinear activation function in the residual link.
+ ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
+ ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``.
+ inplace_nonlinearity (bool, optional, default=False): If ``True``,
+ set ``inplace=True`` when initializing the nonlinearity layers.
+ apply_noise (bool, optional, default=False): If ``True``, adds
+ Gaussian noise with learnable magnitude to the convolution output.
+ hidden_channels_equal_out_channels (bool, optional, default=False):
+ If ``True``, set the hidden channel number to be equal to the
+ output channel number. If ``False``, the hidden channel number
+            equals the smaller of the input channel number and the
+ output channel number.
+ order (str, optional, default='CNACNA'): Order of operations
+ in the residual link.
+ ``'C'``: convolution,
+ ``'N'``: normalization,
+ ``'A'``: nonlinear activation.
+        upsample (class, optional, default=NearestUpsample): PyTorch
+ upsampling layer to be used.
+ up_factor (int, optional, default=2): Upsampling factor.
+ learn_shortcut (bool, optional, default=False): If ``True``, always use
+ a convolutional shortcut instead of an identity one, otherwise only
+ use a convolutional one if input and output have different number of
+ channels.
+ """
+
+ def __init__(self, in_channels, out_channels, kernel_size=3,
+ stride=1, padding=1, dilation=1, groups=1, bias=True,
+ padding_mode='zeros',
+ weight_norm_type='none', weight_norm_params=None,
+ activation_norm_type='none', activation_norm_params=None,
+ skip_activation_norm=True, skip_nonlinearity=False,
+ nonlinearity='leakyrelu', inplace_nonlinearity=False,
+ apply_noise=False, hidden_channels_equal_out_channels=False,
+ order='CNACNA', upsample=NearestUpsample, up_factor=2,
+ learn_shortcut=None, clamp=None, output_scale=1):
+ super().__init__(in_channels, out_channels, kernel_size,
+ stride, padding, dilation, groups, bias, padding_mode,
+ weight_norm_type, weight_norm_params,
+ activation_norm_type, activation_norm_params,
+ skip_activation_norm, skip_nonlinearity,
+ nonlinearity, inplace_nonlinearity,
+ apply_noise, hidden_channels_equal_out_channels,
+ order, Conv2dBlock,
+ upsample, up_factor, learn_shortcut, clamp,
+ output_scale)
+
+
+class _BasePartialResBlock(_BaseResBlock):
+ r"""An abstract class for residual blocks with partial convolution.
+ """
+
+ def __init__(self, in_channels, out_channels, kernel_size,
+ stride, padding, dilation, groups, bias, padding_mode,
+ weight_norm_type, weight_norm_params,
+ activation_norm_type, activation_norm_params,
+ skip_activation_norm, skip_nonlinearity,
+ nonlinearity, inplace_nonlinearity,
+ multi_channel, return_mask,
+ apply_noise, hidden_channels_equal_out_channels,
+ order, block, learn_shortcut, clamp=None, output_scale=1):
+ block = functools.partial(block,
+ multi_channel=multi_channel,
+ return_mask=return_mask)
+ self.partial_conv = True
+ super().__init__(in_channels, out_channels, kernel_size, stride,
+ padding, dilation, groups, bias, padding_mode,
+ weight_norm_type, weight_norm_params,
+ activation_norm_type, activation_norm_params,
+ skip_activation_norm, skip_nonlinearity, nonlinearity,
+ inplace_nonlinearity, apply_noise,
+ hidden_channels_equal_out_channels, order, block,
+ learn_shortcut, clamp, output_scale)
+
+ def forward(self, x, *cond_inputs, mask_in=None, **kw_cond_inputs):
+ r"""
+
+ Args:
+ x (tensor): Input tensor.
+ cond_inputs (list of tensors) : Conditional input tensors.
+            mask_in (tensor, optional, default=``None``): If not ``None``,
+ it masks the valid input region.
+ kw_cond_inputs (dict) : Keyword conditional inputs.
+ Returns:
+ (tuple):
+ - output (tensor): Output tensor.
+ - mask_out (tensor, optional): Masks the valid output region.
+ """
+ if self.conv_block_0.layers.conv.return_mask:
+ dx, mask_out = self.conv_block_0(x, *cond_inputs,
+ mask_in=mask_in, **kw_cond_inputs)
+ dx, mask_out = self.conv_block_1(dx, *cond_inputs,
+ mask_in=mask_out, **kw_cond_inputs)
+ else:
+ dx = self.conv_block_0(x, *cond_inputs,
+ mask_in=mask_in, **kw_cond_inputs)
+ dx = self.conv_block_1(dx, *cond_inputs,
+ mask_in=mask_in, **kw_cond_inputs)
+ mask_out = None
+
+ if self.learn_shortcut:
+ x_shortcut = self.conv_block_s(x, mask_in=mask_in, *cond_inputs,
+ **kw_cond_inputs)
+ if type(x_shortcut) == tuple:
+ x_shortcut, _ = x_shortcut
+ else:
+ x_shortcut = x
+ output = x_shortcut + dx
+
+ if mask_out is not None:
+ return output, mask_out
+ return self.output_scale * output
+
+
+class PartialRes2dBlock(_BasePartialResBlock):
+ r"""Residual block for 2D input with partial convolution.
+
+ Args:
+ in_channels (int) : Number of channels in the input tensor.
+ out_channels (int) : Number of channels in the output tensor.
+ kernel_size (int, optional, default=3): Kernel size for the
+ convolutional filters in the residual link.
+ padding (int, optional, default=1): Padding size.
+ dilation (int, optional, default=1): Dilation factor.
+ groups (int, optional, default=1): Number of convolutional/linear
+ groups.
+ padding_mode (string, optional, default='zeros'): Type of padding:
+ ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
+ weight_norm_type (str, optional, default='none'):
+ Type of weight normalization.
+ ``'none'``, ``'spectral'``, ``'weight'``
+ or ``'weight_demod'``.
+ weight_norm_params (obj, optional, default=None):
+ Parameters of weight normalization.
+ If not ``None``, ``weight_norm_params.__dict__`` will be used as
+ keyword arguments when initializing weight normalization.
+ activation_norm_type (str, optional, default='none'):
+ Type of activation normalization.
+ ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
+ ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``,
+ ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
+ activation_norm_params (obj, optional, default=None):
+ Parameters of activation normalization.
+ If not ``None``, ``activation_norm_params.__dict__`` will be used as
+ keyword arguments when initializing activation normalization.
+ skip_activation_norm (bool, optional, default=True): If ``True`` and
+ ``learn_shortcut`` is also ``True``, applies activation norm to the
+ learned shortcut connection.
+ skip_nonlinearity (bool, optional, default=True): If ``True`` and
+ ``learn_shortcut`` is also ``True``, applies nonlinearity to the
+ learned shortcut connection.
+ nonlinearity (str, optional, default='none'):
+ Type of nonlinear activation function in the residual link.
+ ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
+ ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``.
+ inplace_nonlinearity (bool, optional, default=False): If ``True``,
+ set ``inplace=True`` when initializing the nonlinearity layers.
+ apply_noise (bool, optional, default=False): If ``True``, adds
+ Gaussian noise with learnable magnitude to the convolution output.
+ hidden_channels_equal_out_channels (bool, optional, default=False):
+ If ``True``, set the hidden channel number to be equal to the
+ output channel number. If ``False``, the hidden channel number
+            equals the smaller of the input channel number and the
+ output channel number.
+ order (str, optional, default='CNACNA'): Order of operations
+ in the residual link.
+ ``'C'``: convolution,
+ ``'N'``: normalization,
+ ``'A'``: nonlinear activation.
+ learn_shortcut (bool, optional, default=False): If ``True``, always use
+ a convolutional shortcut instead of an identity one, otherwise only
+ use a convolutional one if input and output have different number of
+ channels.
+ """
+
+ def __init__(self, in_channels, out_channels, kernel_size=3,
+ stride=1, padding=1, dilation=1, groups=1, bias=True,
+ padding_mode='zeros',
+ weight_norm_type='none', weight_norm_params=None,
+ activation_norm_type='none', activation_norm_params=None,
+ skip_activation_norm=True, skip_nonlinearity=False,
+ nonlinearity='leakyrelu', inplace_nonlinearity=False,
+ multi_channel=False, return_mask=True,
+ apply_noise=False,
+ hidden_channels_equal_out_channels=False,
+ order='CNACNA', learn_shortcut=None, clamp=None,
+ output_scale=1):
+ super().__init__(in_channels, out_channels, kernel_size,
+ stride, padding, dilation, groups, bias,
+ padding_mode, weight_norm_type, weight_norm_params,
+ activation_norm_type, activation_norm_params,
+ skip_activation_norm, skip_nonlinearity, nonlinearity,
+ inplace_nonlinearity, multi_channel, return_mask,
+ apply_noise, hidden_channels_equal_out_channels,
+ order, PartialConv2dBlock, learn_shortcut, clamp,
+ output_scale)
+
+
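
A minimal sketch of how a partial-convolution residual block like the one above could be exercised. It assumes the file shown here is importable as `imaginaire.layers.residual` (the neighbouring `residual_deep.py` header suggests that layout) and that `PartialConv2dBlock` returns an updated mask when `return_mask=True`; the tensor sizes and the rectangular hole are purely illustrative.

```python
import torch
from imaginaire.layers.residual import PartialRes2dBlock  # import path assumed

# Hole-aware residual block: only pixels where the mask is 1 are treated as
# valid, and the updated validity mask is returned next to the features.
block = PartialRes2dBlock(16, 16, kernel_size=3, padding=1,
                          multi_channel=False, return_mask=True)

feat = torch.randn(1, 16, 32, 32)
mask = torch.ones(1, 1, 32, 32)
mask[:, :, 8:24, 8:24] = 0           # masked-out (invalid) region

out, mask_out = block(feat, mask_in=mask)   # out: (1, 16, 32, 32)
```
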
+class PartialRes3dBlock(_BasePartialResBlock):
+ r"""Residual block for 3D input with partial convolution.
+
+ Args:
+ in_channels (int) : Number of channels in the input tensor.
+ out_channels (int) : Number of channels in the output tensor.
+ kernel_size (int, optional, default=3): Kernel size for the
+ convolutional filters in the residual link.
+ padding (int, optional, default=1): Padding size.
+ dilation (int, optional, default=1): Dilation factor.
+ groups (int, optional, default=1): Number of convolutional/linear
+ groups.
+ padding_mode (string, optional, default='zeros'): Type of padding:
+ ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
+ weight_norm_type (str, optional, default='none'):
+ Type of weight normalization.
+ ``'none'``, ``'spectral'``, ``'weight'``
+ or ``'weight_demod'``.
+ weight_norm_params (obj, optional, default=None):
+ Parameters of weight normalization.
+ If not ``None``, ``weight_norm_params.__dict__`` will be used as
+ keyword arguments when initializing weight normalization.
+ activation_norm_type (str, optional, default='none'):
+ Type of activation normalization.
+ ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
+ ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``,
+ ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
+ activation_norm_params (obj, optional, default=None):
+ Parameters of activation normalization.
+ If not ``None``, ``activation_norm_params.__dict__`` will be used as
+ keyword arguments when initializing activation normalization.
+ skip_activation_norm (bool, optional, default=True): If ``True`` and
+ ``learn_shortcut`` is also ``True``, applies activation norm to the
+ learned shortcut connection.
+ skip_nonlinearity (bool, optional, default=True): If ``True`` and
+ ``learn_shortcut`` is also ``True``, applies nonlinearity to the
+ learned shortcut connection.
+ nonlinearity (str, optional, default='none'):
+ Type of nonlinear activation function in the residual link.
+ ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
+ ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``.
+ inplace_nonlinearity (bool, optional, default=False): If ``True``,
+ set ``inplace=True`` when initializing the nonlinearity layers.
+ apply_noise (bool, optional, default=False): If ``True``, adds
+ Gaussian noise with learnable magnitude to the convolution output.
+ hidden_channels_equal_out_channels (bool, optional, default=False):
+ If ``True``, set the hidden channel number to be equal to the
+ output channel number. If ``False``, the hidden channel number
+            equals the smaller of the input channel number and the
+ output channel number.
+ order (str, optional, default='CNACNA'): Order of operations
+ in the residual link.
+ ``'C'``: convolution,
+ ``'N'``: normalization,
+ ``'A'``: nonlinear activation.
+ learn_shortcut (bool, optional, default=False): If ``True``, always use
+ a convolutional shortcut instead of an identity one, otherwise only
+ use a convolutional one if input and output have different number of
+ channels.
+ """
+
+ def __init__(self, in_channels, out_channels, kernel_size=3,
+ stride=1, padding=1, dilation=1, groups=1, bias=True,
+ padding_mode='zeros',
+ weight_norm_type='none', weight_norm_params=None,
+ activation_norm_type='none', activation_norm_params=None,
+ skip_activation_norm=True, skip_nonlinearity=False,
+ nonlinearity='leakyrelu', inplace_nonlinearity=False,
+ multi_channel=False, return_mask=True,
+ apply_noise=False, hidden_channels_equal_out_channels=False,
+ order='CNACNA', learn_shortcut=None, clamp=None,
+ output_scale=1):
+ super().__init__(in_channels, out_channels, kernel_size,
+ stride, padding, dilation, groups, bias,
+ padding_mode, weight_norm_type, weight_norm_params,
+ activation_norm_type, activation_norm_params,
+ skip_activation_norm, skip_nonlinearity,
+ nonlinearity, inplace_nonlinearity, multi_channel,
+ return_mask, apply_noise,
+ hidden_channels_equal_out_channels,
+ order, PartialConv3dBlock, learn_shortcut, clamp,
+ output_scale)
+
+
+class _BaseMultiOutResBlock(_BaseResBlock):
+    r"""An abstract class for residual blocks that can return multiple outputs.
+ """
+
+ def __init__(self, in_channels, out_channels, kernel_size,
+ stride, padding, dilation, groups, bias, padding_mode,
+ weight_norm_type, weight_norm_params,
+ activation_norm_type, activation_norm_params,
+ skip_activation_norm, skip_nonlinearity,
+ nonlinearity, inplace_nonlinearity,
+ apply_noise, hidden_channels_equal_out_channels,
+ order, block, learn_shortcut, clamp=None, output_scale=1,
+ blur=False, upsample_first=True):
+ self.multiple_outputs = True
+ super().__init__(in_channels, out_channels, kernel_size, stride,
+ padding, dilation, groups, bias, padding_mode,
+ weight_norm_type, weight_norm_params,
+ activation_norm_type, activation_norm_params,
+ skip_activation_norm, skip_nonlinearity, nonlinearity,
+ inplace_nonlinearity, apply_noise,
+ hidden_channels_equal_out_channels, order, block,
+ learn_shortcut, clamp, output_scale, blur=blur,
+ upsample_first=upsample_first)
+
+ def forward(self, x, *cond_inputs):
+ r"""
+
+ Args:
+ x (tensor): Input tensor.
+ cond_inputs (list of tensors) : Conditional input tensors.
+ Returns:
+ (tuple):
+ - output (tensor): Output tensor.
+ - aux_outputs_0 (tensor): Auxiliary output of the first block.
+ - aux_outputs_1 (tensor): Auxiliary output of the second block.
+ """
+ dx, aux_outputs_0 = self.conv_block_0(x, *cond_inputs)
+ dx, aux_outputs_1 = self.conv_block_1(dx, *cond_inputs)
+ if self.learn_shortcut:
+ # We are not using the auxiliary outputs of self.conv_block_s.
+ x_shortcut, _ = self.conv_block_s(x, *cond_inputs)
+ else:
+ x_shortcut = x
+ output = x_shortcut + dx
+ return self.output_scale * output, aux_outputs_0, aux_outputs_1
+
+
+class MultiOutRes2dBlock(_BaseMultiOutResBlock):
+ r"""Residual block for 2D input. It can return multiple outputs, if some
+ layers in the block return more than one output.
+
+ Args:
+ in_channels (int) : Number of channels in the input tensor.
+ out_channels (int) : Number of channels in the output tensor.
+ kernel_size (int, optional, default=3): Kernel size for the
+ convolutional filters in the residual link.
+ padding (int, optional, default=1): Padding size.
+ dilation (int, optional, default=1): Dilation factor.
+ groups (int, optional, default=1): Number of convolutional/linear
+ groups.
+ padding_mode (string, optional, default='zeros'): Type of padding:
+ ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
+ weight_norm_type (str, optional, default='none'):
+ Type of weight normalization.
+ ``'none'``, ``'spectral'``, ``'weight'``
+ or ``'weight_demod'``.
+ weight_norm_params (obj, optional, default=None):
+ Parameters of weight normalization.
+ If not ``None``, ``weight_norm_params.__dict__`` will be used as
+ keyword arguments when initializing weight normalization.
+ activation_norm_type (str, optional, default='none'):
+ Type of activation normalization.
+ ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
+ ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``,
+ ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
+ activation_norm_params (obj, optional, default=None):
+ Parameters of activation normalization.
+ If not ``None``, ``activation_norm_params.__dict__`` will be used as
+ keyword arguments when initializing activation normalization.
+ skip_activation_norm (bool, optional, default=True): If ``True`` and
+ ``learn_shortcut`` is also ``True``, applies activation norm to the
+ learned shortcut connection.
+ skip_nonlinearity (bool, optional, default=True): If ``True`` and
+ ``learn_shortcut`` is also ``True``, applies nonlinearity to the
+ learned shortcut connection.
+ nonlinearity (str, optional, default='none'):
+ Type of nonlinear activation function in the residual link.
+ ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
+ ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``.
+ inplace_nonlinearity (bool, optional, default=False): If ``True``,
+ set ``inplace=True`` when initializing the nonlinearity layers.
+ apply_noise (bool, optional, default=False): If ``True``, adds
+ Gaussian noise with learnable magnitude to the convolution output.
+ hidden_channels_equal_out_channels (bool, optional, default=False):
+ If ``True``, set the hidden channel number to be equal to the
+ output channel number. If ``False``, the hidden channel number
+            equals the smaller of the input channel number and the
+ output channel number.
+ order (str, optional, default='CNACNA'): Order of operations
+ in the residual link.
+ ``'C'``: convolution,
+ ``'N'``: normalization,
+ ``'A'``: nonlinear activation.
+ learn_shortcut (bool, optional, default=False): If ``True``, always use
+ a convolutional shortcut instead of an identity one, otherwise only
+ use a convolutional one if input and output have different number of
+ channels.
+ """
+
+ def __init__(self, in_channels, out_channels, kernel_size=3,
+ stride=1, padding=1, dilation=1, groups=1, bias=True,
+ padding_mode='zeros',
+ weight_norm_type='none', weight_norm_params=None,
+ activation_norm_type='none', activation_norm_params=None,
+ skip_activation_norm=True, skip_nonlinearity=False,
+ nonlinearity='leakyrelu', inplace_nonlinearity=False,
+ apply_noise=False, hidden_channels_equal_out_channels=False,
+ order='CNACNA', learn_shortcut=None, clamp=None,
+ output_scale=1, blur=False, upsample_first=True):
+ super().__init__(in_channels, out_channels, kernel_size, stride,
+ padding, dilation, groups, bias, padding_mode,
+ weight_norm_type, weight_norm_params,
+ activation_norm_type, activation_norm_params,
+ skip_activation_norm, skip_nonlinearity, nonlinearity,
+ inplace_nonlinearity, apply_noise,
+ hidden_channels_equal_out_channels, order,
+ MultiOutConv2dBlock, learn_shortcut, clamp,
+ output_scale, blur=blur, upsample_first=upsample_first)
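
For orientation, here is a minimal sketch of the downsampling/upsampling pair defined in this file, again assuming it is importable as `imaginaire.layers.residual`. With the defaults, `DownRes2dBlock` halves the resolution with `nn.AvgPool2d` and `UpRes2dBlock` restores it with nearest-neighbour upsampling; the channel counts and sizes below are illustrative.

```python
import torch
from imaginaire.layers.residual import DownRes2dBlock, UpRes2dBlock  # path assumed

x = torch.randn(2, 32, 64, 64)

# Residual link + AvgPool2d(2): channels 32 -> 64, resolution 64 -> 32.
down = DownRes2dBlock(32, 64, kernel_size=3, padding=1,
                      activation_norm_type='instance')
y = down(x)                         # (2, 64, 32, 32)

# Residual link + nearest-neighbour upsampling: back to 32 channels at 64x64.
up = UpRes2dBlock(64, 32, kernel_size=3, padding=1,
                  activation_norm_type='instance')
z = up(y)                           # (2, 32, 64, 64)
```
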
diff --git a/imaginaire/layers/residual_deep.py b/imaginaire/layers/residual_deep.py
new file mode 100644
index 0000000000000000000000000000000000000000..b0bbcd497f4689bed4faf20e8e47c0fc4e282812
--- /dev/null
+++ b/imaginaire/layers/residual_deep.py
@@ -0,0 +1,346 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import torch
+from torch import nn
+from torch.utils.checkpoint import checkpoint
+
+from imaginaire.third_party.upfirdn2d import BlurDownsample, BlurUpsample
+from .conv import Conv2dBlock
+
+
+class _BaseDeepResBlock(nn.Module):
+ def __init__(self, in_channels, out_channels, kernel_size,
+ stride, padding, dilation, groups, bias, padding_mode,
+ weight_norm_type, weight_norm_params,
+ activation_norm_type, activation_norm_params,
+ skip_activation_norm, skip_nonlinearity,
+ nonlinearity, inplace_nonlinearity, apply_noise,
+ hidden_channels_equal_out_channels,
+ order, block, learn_shortcut, output_scale, skip_block=None,
+ blur=True, border_free=True, resample_first=True,
+ skip_weight_norm=True, hidden_channel_ratio=4):
+ super().__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.output_scale = output_scale
+ self.resample_first = resample_first
+ self.stride = stride
+ self.blur = blur
+ self.border_free = border_free
+ assert not border_free
+ if skip_block is None:
+ skip_block = block
+
+ if order == 'pre_act':
+ order = 'NACNAC'
+ if isinstance(bias, bool):
+ # The bias for conv_block_0, conv_block_1, and conv_block_s.
+ biases = [bias, bias, bias]
+ elif isinstance(bias, list):
+ if len(bias) == 3:
+ biases = bias
+ else:
+                raise ValueError('Bias list must have 3 elements.')
+ else:
+            raise ValueError('Bias must be a bool or a list of 3 bools.')
+ self.learn_shortcut = learn_shortcut
+ if len(order) > 6 or len(order) < 5:
+ raise ValueError('order must be either 5 or 6 characters')
+ hidden_channels = in_channels // hidden_channel_ratio
+
+ # Parameters.
+ residual_params = {}
+ shortcut_params = {}
+ base_params = dict(dilation=dilation,
+ groups=groups,
+ padding_mode=padding_mode)
+ residual_params.update(base_params)
+ residual_params.update(
+ dict(activation_norm_type=activation_norm_type,
+ activation_norm_params=activation_norm_params,
+ weight_norm_type=weight_norm_type,
+ weight_norm_params=weight_norm_params,
+ apply_noise=apply_noise)
+ )
+ shortcut_params.update(base_params)
+ shortcut_params.update(dict(kernel_size=1))
+ if skip_activation_norm:
+ shortcut_params.update(
+ dict(activation_norm_type=activation_norm_type,
+ activation_norm_params=activation_norm_params,
+ apply_noise=False))
+ if skip_weight_norm:
+ shortcut_params.update(
+ dict(weight_norm_type=weight_norm_type,
+ weight_norm_params=weight_norm_params))
+
+ # Residual branch.
+ if order.find('A') < order.find('C') and \
+ (activation_norm_type == '' or activation_norm_type == 'none'):
+ # Nonlinearity is the first operation in the residual path.
+ # In-place nonlinearity will modify the input variable and cause
+ # backward error.
+ first_inplace = False
+ else:
+ first_inplace = inplace_nonlinearity
+
+ (first_stride, second_stride, shortcut_stride,
+ first_blur, second_blur, shortcut_blur) = self._get_stride_blur()
+
+ self.conv_block_1x1_in = block(
+ in_channels, hidden_channels,
+ 1, 1, 0,
+ bias=biases[0],
+ nonlinearity=nonlinearity,
+ order=order[0:3],
+ inplace_nonlinearity=first_inplace,
+ **residual_params
+ )
+
+ self.conv_block_0 = block(
+ hidden_channels, hidden_channels,
+ kernel_size=2 if self.border_free and first_stride < 1 else
+ kernel_size,
+ padding=padding,
+ bias=biases[0],
+ nonlinearity=nonlinearity,
+ order=order[0:3],
+ inplace_nonlinearity=inplace_nonlinearity,
+ stride=first_stride,
+ blur=first_blur,
+ **residual_params
+ )
+ self.conv_block_1 = block(
+ hidden_channels, hidden_channels,
+ kernel_size=kernel_size,
+ padding=padding,
+ bias=biases[1],
+ nonlinearity=nonlinearity,
+ order=order[3:],
+ inplace_nonlinearity=inplace_nonlinearity,
+ stride=second_stride,
+ blur=second_blur,
+ **residual_params
+ )
+
+ self.conv_block_1x1_out = block(
+ hidden_channels, out_channels,
+ 1, 1, 0,
+ bias=biases[1],
+ nonlinearity=nonlinearity,
+ order=order[0:3],
+ inplace_nonlinearity=inplace_nonlinearity,
+ **residual_params
+ )
+
+ # Shortcut branch.
+ if self.learn_shortcut:
+ if skip_nonlinearity:
+ skip_nonlinearity_type = nonlinearity
+ else:
+ skip_nonlinearity_type = ''
+ self.conv_block_s = skip_block(in_channels, out_channels,
+ bias=biases[2],
+ nonlinearity=skip_nonlinearity_type,
+ order=order[0:3],
+ stride=shortcut_stride,
+ blur=shortcut_blur,
+ **shortcut_params)
+ elif in_channels < out_channels:
+ if skip_nonlinearity:
+ skip_nonlinearity_type = nonlinearity
+ else:
+ skip_nonlinearity_type = ''
+ self.conv_block_s = skip_block(in_channels,
+ out_channels - in_channels,
+ bias=biases[2],
+ nonlinearity=skip_nonlinearity_type,
+ order=order[0:3],
+ stride=shortcut_stride,
+ blur=shortcut_blur,
+ **shortcut_params)
+
+ # Whether this block expects conditional inputs.
+ self.conditional = \
+ getattr(self.conv_block_0, 'conditional', False) or \
+ getattr(self.conv_block_1, 'conditional', False) or \
+ getattr(self.conv_block_1x1_in, 'conditional', False) or \
+ getattr(self.conv_block_1x1_out, 'conditional', False)
+
+ def _get_stride_blur(self):
+ if self.stride > 1:
+ # Downsampling.
+ first_stride, second_stride = 1, self.stride
+ first_blur, second_blur = False, self.blur
+ shortcut_blur = False
+ shortcut_stride = 1
+ if self.blur:
+ # The shortcut branch uses blur_downsample + stride-1 conv
+ if self.border_free:
+ self.resample = nn.AvgPool2d(2)
+ else:
+ self.resample = BlurDownsample()
+ else:
+ shortcut_stride = self.stride
+ self.resample = nn.AvgPool2d(2)
+ elif self.stride < 1:
+ # Upsampling.
+ first_stride, second_stride = self.stride, 1
+ first_blur, second_blur = self.blur, False
+ shortcut_blur = False
+ shortcut_stride = 1
+ if self.blur:
+ # The shortcut branch uses blur_upsample + stride-1 conv
+ if self.border_free:
+ self.resample = nn.Upsample(scale_factor=2,
+ mode='bilinear')
+ else:
+ self.resample = BlurUpsample()
+ else:
+ shortcut_stride = self.stride
+ self.resample = nn.Upsample(scale_factor=2)
+ else:
+ first_stride = second_stride = 1
+ first_blur = second_blur = False
+ shortcut_stride = 1
+ shortcut_blur = False
+ self.resample = None
+ return (first_stride, second_stride, shortcut_stride,
+ first_blur, second_blur, shortcut_blur)
+
+ def conv_blocks(
+ self, x, *cond_inputs, separate_cond=False, **kw_cond_inputs
+ ):
+ if separate_cond:
+ assert len(list(cond_inputs)) == 4
+ dx = self.conv_block_1x1_in(x, cond_inputs[0],
+ **kw_cond_inputs.get('kwargs_0', {}))
+ dx = self.conv_block_0(dx, cond_inputs[1],
+ **kw_cond_inputs.get('kwargs_1', {}))
+ dx = self.conv_block_1(dx, cond_inputs[2],
+ **kw_cond_inputs.get('kwargs_2', {}))
+ dx = self.conv_block_1x1_out(dx, cond_inputs[3],
+ **kw_cond_inputs.get('kwargs_3', {}))
+ else:
+ dx = self.conv_block_1x1_in(x, *cond_inputs, **kw_cond_inputs)
+ dx = self.conv_block_0(dx, *cond_inputs, **kw_cond_inputs)
+ dx = self.conv_block_1(dx, *cond_inputs, **kw_cond_inputs)
+ dx = self.conv_block_1x1_out(dx, *cond_inputs, **kw_cond_inputs)
+ return dx
+
+ def forward(self, x, *cond_inputs, do_checkpoint=False, **kw_cond_inputs):
+ if do_checkpoint:
+ dx = checkpoint(self.conv_blocks, x, *cond_inputs, **kw_cond_inputs)
+ else:
+ dx = self.conv_blocks(x, *cond_inputs, **kw_cond_inputs)
+
+ if self.resample_first and self.resample is not None:
+ x = self.resample(x)
+ if self.learn_shortcut:
+ x_shortcut = self.conv_block_s(
+ x, *cond_inputs, **kw_cond_inputs
+ )
+ elif self.in_channels < self.out_channels:
+ x_shortcut_pad = self.conv_block_s(
+ x, *cond_inputs, **kw_cond_inputs
+ )
+ x_shortcut = torch.cat((x, x_shortcut_pad), dim=1)
+ elif self.in_channels > self.out_channels:
+ x_shortcut = x[:, :self.out_channels, :, :]
+ else:
+ x_shortcut = x
+ if not self.resample_first and self.resample is not None:
+ x_shortcut = self.resample(x_shortcut)
+
+ output = x_shortcut + dx
+ return self.output_scale * output
+
+ def extra_repr(self):
+ s = 'output_scale={output_scale}'
+ return s.format(**self.__dict__)
+
+
+class DeepRes2dBlock(_BaseDeepResBlock):
+ r"""Residual block for 2D input.
+
+ Args:
+ in_channels (int) : Number of channels in the input tensor.
+ out_channels (int) : Number of channels in the output tensor.
+ kernel_size (int, optional, default=3): Kernel size for the
+ convolutional filters in the residual link.
+ padding (int, optional, default=1): Padding size.
+ dilation (int, optional, default=1): Dilation factor.
+ groups (int, optional, default=1): Number of convolutional/linear
+ groups.
+ padding_mode (string, optional, default='zeros'): Type of padding:
+ ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
+ weight_norm_type (str, optional, default='none'):
+ Type of weight normalization.
+ ``'none'``, ``'spectral'``, ``'weight'``
+ or ``'weight_demod'``.
+ weight_norm_params (obj, optional, default=None):
+ Parameters of weight normalization.
+ If not ``None``, ``weight_norm_params.__dict__`` will be used as
+ keyword arguments when initializing weight normalization.
+ activation_norm_type (str, optional, default='none'):
+ Type of activation normalization.
+ ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
+ ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``,
+ ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
+ activation_norm_params (obj, optional, default=None):
+ Parameters of activation normalization.
+ If not ``None``, ``activation_norm_params.__dict__`` will be used as
+ keyword arguments when initializing activation normalization.
+ skip_activation_norm (bool, optional, default=True): If ``True`` and
+ ``learn_shortcut`` is also ``True``, applies activation norm to the
+ learned shortcut connection.
+ skip_nonlinearity (bool, optional, default=True): If ``True`` and
+ ``learn_shortcut`` is also ``True``, applies nonlinearity to the
+ learned shortcut connection.
+ nonlinearity (str, optional, default='none'):
+ Type of nonlinear activation function in the residual link.
+ ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
+ ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``.
+ inplace_nonlinearity (bool, optional, default=False): If ``True``,
+ set ``inplace=True`` when initializing the nonlinearity layers.
+ apply_noise (bool, optional, default=False): If ``True``, adds
+ Gaussian noise with learnable magnitude to the convolution output.
+ hidden_channels_equal_out_channels (bool, optional, default=False):
+ If ``True``, set the hidden channel number to be equal to the
+ output channel number. If ``False``, the hidden channel number
+            equals the smaller of the input channel number and the
+ output channel number.
+ order (str, optional, default='CNACNA'): Order of operations
+ in the residual link.
+ ``'C'``: convolution,
+ ``'N'``: normalization,
+ ``'A'``: nonlinear activation.
+ learn_shortcut (bool, optional, default=False): If ``True``, always use
+ a convolutional shortcut instead of an identity one, otherwise only
+ use a convolutional one if input and output have different number of
+ channels.
+ """
+
+ def __init__(self, in_channels, out_channels, kernel_size=3,
+ stride=1, padding=1, dilation=1, groups=1, bias=True,
+ padding_mode='zeros',
+ weight_norm_type='none', weight_norm_params=None,
+ activation_norm_type='none', activation_norm_params=None,
+ skip_activation_norm=True, skip_nonlinearity=False,
+ skip_weight_norm=True,
+ nonlinearity='leakyrelu', inplace_nonlinearity=False,
+ apply_noise=False, hidden_channels_equal_out_channels=False,
+ order='CNACNA', learn_shortcut=False, output_scale=1,
+ blur=True, resample_first=True, border_free=False):
+ super().__init__(in_channels, out_channels, kernel_size, stride,
+ padding, dilation, groups, bias, padding_mode,
+ weight_norm_type, weight_norm_params,
+ activation_norm_type, activation_norm_params,
+ skip_activation_norm, skip_nonlinearity, nonlinearity,
+ inplace_nonlinearity, apply_noise,
+ hidden_channels_equal_out_channels, order, Conv2dBlock,
+ learn_shortcut, output_scale, blur=blur,
+ resample_first=resample_first, border_free=border_free,
+ skip_weight_norm=skip_weight_norm)
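
`DeepRes2dBlock` is a bottleneck residual block: a 1x1 convolution squeezes the channels by `hidden_channel_ratio` (4 by default), two k x k convolutions run at the reduced width, and a final 1x1 convolution expands back. A minimal sketch, assuming the bundled `imaginaire.third_party.upfirdn2d` op imported at the top of this file is available; the sizes are illustrative.

```python
import torch
from imaginaire.layers.residual_deep import DeepRes2dBlock

# 64 input channels are squeezed to 64 // 4 = 16 hidden channels inside the
# block; the shortcut is the identity because in/out channels match and
# stride == 1.
block = DeepRes2dBlock(64, 64, kernel_size=3, padding=1,
                       weight_norm_type='spectral')

x = torch.randn(1, 64, 32, 32, requires_grad=True)
y = block(x)                        # same shape: (1, 64, 32, 32)

# The residual branch can be gradient-checkpointed per call to save memory.
y = block(x, do_checkpoint=True)
```
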
diff --git a/imaginaire/layers/vit.py b/imaginaire/layers/vit.py
new file mode 100644
index 0000000000000000000000000000000000000000..abd0d039d715444efc3c5a1e9889330bfa5b4c4f
--- /dev/null
+++ b/imaginaire/layers/vit.py
@@ -0,0 +1,204 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+from types import SimpleNamespace
+
+import torch
+from torch import nn
+
+from .misc import ApplyNoise
+from imaginaire.third_party.upfirdn2d.upfirdn2d import Blur
+
+
+class ViT2dBlock(nn.Module):
+ r"""An abstract wrapper class that wraps a torch convolution or linear layer
+ with normalization and nonlinearity.
+ """
+
+ def __init__(self, in_channels, out_channels, kernel_size, stride,
+ padding, dilation, groups, bias, padding_mode,
+ weight_norm_type, weight_norm_params,
+ activation_norm_type, activation_norm_params,
+ nonlinearity, inplace_nonlinearity,
+ apply_noise, blur, order, input_dim, clamp,
+ blur_kernel=(1, 3, 3, 1), output_scale=None,
+ init_gain=1.0):
+ super().__init__()
+ from .nonlinearity import get_nonlinearity_layer
+ from .weight_norm import get_weight_norm_layer
+ from .activation_norm import get_activation_norm_layer
+ self.weight_norm_type = weight_norm_type
+ self.stride = stride
+ self.clamp = clamp
+ self.init_gain = init_gain
+
+ # Nonlinearity layer.
+ if 'fused' in nonlinearity:
+ # Fusing nonlinearity with bias.
+ lr_mul = getattr(weight_norm_params, 'lr_mul', 1)
+ conv_before_nonlinearity = order.find('C') < order.find('A')
+ if conv_before_nonlinearity:
+ assert bias
+ bias = False
+ channel = out_channels if conv_before_nonlinearity else in_channels
+ nonlinearity_layer = get_nonlinearity_layer(
+ nonlinearity, inplace=inplace_nonlinearity,
+ num_channels=channel, lr_mul=lr_mul)
+ else:
+ nonlinearity_layer = get_nonlinearity_layer(
+ nonlinearity, inplace=inplace_nonlinearity)
+
+ # Noise injection layer.
+ if apply_noise:
+ order = order.replace('C', 'CG')
+ noise_layer = ApplyNoise()
+ else:
+ noise_layer = None
+
+ # Convolutional layer.
+ if blur:
+ if stride == 2:
+ # Blur - Conv - Noise - Activate
+ p = (len(blur_kernel) - 2) + (kernel_size - 1)
+ pad0, pad1 = (p + 1) // 2, p // 2
+ padding = 0
+ blur_layer = Blur(
+ blur_kernel, pad=(pad0, pad1), padding_mode=padding_mode
+ )
+ order = order.replace('C', 'BC')
+ elif stride == 0.5:
+ # Conv - Blur - Noise - Activate
+ padding = 0
+ p = (len(blur_kernel) - 2) - (kernel_size - 1)
+ pad0, pad1 = (p + 1) // 2 + 1, p // 2 + 1
+ blur_layer = Blur(
+ blur_kernel, pad=(pad0, pad1), padding_mode=padding_mode
+ )
+ order = order.replace('C', 'CB')
+ elif stride == 1:
+ # No blur for now
+ blur_layer = nn.Identity()
+ else:
+ raise NotImplementedError
+ else:
+ blur_layer = nn.Identity()
+
+ if weight_norm_params is None:
+ weight_norm_params = SimpleNamespace()
+ weight_norm = get_weight_norm_layer(
+ weight_norm_type, **vars(weight_norm_params))
+ conv_layer = weight_norm(self._get_conv_layer(
+ in_channels, out_channels, kernel_size, stride, padding, dilation,
+ groups, bias, padding_mode, input_dim))
+
+ # Normalization layer.
+ conv_before_norm = order.find('C') < order.find('N')
+ norm_channels = out_channels if conv_before_norm else in_channels
+ if activation_norm_params is None:
+ activation_norm_params = SimpleNamespace()
+ activation_norm_layer = get_activation_norm_layer(
+ norm_channels,
+ activation_norm_type,
+ input_dim,
+ **vars(activation_norm_params))
+
+ # Mapping from operation names to layers.
+ mappings = {'C': {'conv': conv_layer},
+ 'N': {'norm': activation_norm_layer},
+ 'A': {'nonlinearity': nonlinearity_layer}}
+ mappings.update({'B': {'blur': blur_layer}})
+ mappings.update({'G': {'noise': noise_layer}})
+
+ # All layers in order.
+ self.layers = nn.ModuleDict()
+ for op in order:
+ if list(mappings[op].values())[0] is not None:
+ self.layers.update(mappings[op])
+
+ # Whether this block expects conditional inputs.
+ self.conditional = \
+ getattr(conv_layer, 'conditional', False) or \
+ getattr(activation_norm_layer, 'conditional', False)
+
+ if output_scale is not None:
+ self.output_scale = nn.Parameter(torch.tensor(output_scale))
+ else:
+ self.register_parameter("output_scale", None)
+
+ def forward(self, x, *cond_inputs, **kw_cond_inputs):
+ r"""
+
+ Args:
+ x (tensor): Input tensor.
+ cond_inputs (list of tensors) : Conditional input tensors.
+ kw_cond_inputs (dict) : Keyword conditional inputs.
+ """
+ for key, layer in self.layers.items():
+ if getattr(layer, 'conditional', False):
+ # Layers that require conditional inputs.
+ x = layer(x, *cond_inputs, **kw_cond_inputs)
+ else:
+ x = layer(x)
+ if self.clamp is not None and isinstance(layer, nn.Conv2d):
+ x.clamp_(max=self.clamp)
+ if key == 'conv':
+ if self.output_scale is not None:
+ x = x * self.output_scale
+ return x
+
+ def _get_conv_layer(self, in_channels, out_channels, kernel_size, stride,
+ padding, dilation, groups, bias, padding_mode,
+ input_dim):
+ # Returns the convolutional layer.
+ if input_dim == 0:
+ layer = nn.Linear(in_channels, out_channels, bias)
+ else:
+ if stride < 1: # Fractionally-strided convolution.
+ padding_mode = 'zeros'
+ assert padding == 0
+ layer_type = getattr(nn, f'ConvTranspose{input_dim}d')
+ stride = round(1 / stride)
+ else:
+ layer_type = getattr(nn, f'Conv{input_dim}d')
+ layer = layer_type(
+ in_channels, out_channels, kernel_size, stride, padding,
+ dilation=dilation, groups=groups, bias=bias,
+ padding_mode=padding_mode
+ )
+
+ return layer
+
+ def __repr__(self):
+ main_str = self._get_name() + '('
+ child_lines = []
+ for name, layer in self.layers.items():
+ mod_str = repr(layer)
+ if name == 'conv' and self.weight_norm_type != 'none' and \
+ self.weight_norm_type != '':
+ mod_str = mod_str[:-1] + \
+ ', weight_norm={}'.format(self.weight_norm_type) + ')'
+ if name == 'conv' and getattr(layer, 'base_lr_mul', 1) != 1:
+ mod_str = mod_str[:-1] + \
+ ', lr_mul={}'.format(layer.base_lr_mul) + ')'
+ mod_str = self._addindent(mod_str, 2)
+ child_lines.append(mod_str)
+ if len(child_lines) == 1:
+ main_str += child_lines[0]
+ else:
+ main_str += '\n ' + '\n '.join(child_lines) + '\n'
+
+ main_str += ')'
+ return main_str
+
+ @staticmethod
+ def _addindent(s_, numSpaces):
+ s = s_.split('\n')
+ # don't do anything for single-line stuff
+ if len(s) == 1:
+ return s_
+ first = s.pop(0)
+ s = [(numSpaces * ' ') + line for line in s]
+ s = '\n'.join(s)
+ s = first + '\n' + s
+ return s
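
`ViT2dBlock` takes every constructor argument explicitly (there are no defaults up to `clamp`), and the `order` string decides which of conv (`'C'`), normalization (`'N'`), activation (`'A'`), blur (`'B'`) and noise (`'G'`) layers are built and in what sequence. Below is a minimal sketch of a plain conv + leaky-ReLU block; the values are illustrative, and the import assumes the bundled `upfirdn2d` op (pulled in at module import) is available.

```python
import torch
from imaginaire.layers.vit import ViT2dBlock

# order='CNA': convolution, then normalization (skipped here, type 'none'),
# then activation. input_dim=2 selects nn.Conv2d as the wrapped layer.
block = ViT2dBlock(
    in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1,
    dilation=1, groups=1, bias=True, padding_mode='zeros',
    weight_norm_type='none', weight_norm_params=None,
    activation_norm_type='none', activation_norm_params=None,
    nonlinearity='leakyrelu', inplace_nonlinearity=False,
    apply_noise=False, blur=False, order='CNA', input_dim=2, clamp=None)

x = torch.randn(1, 16, 24, 24)
y = block(x)                        # (1, 32, 24, 24)
```
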
diff --git a/imaginaire/layers/weight_norm.py b/imaginaire/layers/weight_norm.py
new file mode 100644
index 0000000000000000000000000000000000000000..e15ca2d21cea70062fa24ffdcd5adab51c8dcb25
--- /dev/null
+++ b/imaginaire/layers/weight_norm.py
@@ -0,0 +1,267 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import collections
+import functools
+
+import torch
+from torch import nn
+from torch.nn.utils import spectral_norm, weight_norm
+from torch.nn.utils.spectral_norm import SpectralNorm, \
+ SpectralNormStateDictHook, SpectralNormLoadStateDictPreHook
+
+from .conv import LinearBlock
+
+
+class WeightDemodulation(nn.Module):
+ r"""Weight demodulation in
+ "Analyzing and Improving the Image Quality of StyleGAN", Karras et al.
+
+ Args:
+ conv (torch.nn.Modules): Convolutional layer.
+ cond_dims (int): The number of channels in the conditional input.
+ eps (float, optional, default=1e-8): a value added to the
+ denominator for numerical stability.
+ adaptive_bias (bool, optional, default=False): If ``True``, adaptively
+ predicts bias from the conditional input.
+ demod (bool, optional, default=False): If ``True``, performs
+ weight demodulation.
+ """
+
+ def __init__(self, conv, cond_dims, eps=1e-8,
+ adaptive_bias=False, demod=True):
+ super().__init__()
+ self.conv = conv
+ self.adaptive_bias = adaptive_bias
+ if adaptive_bias:
+ self.conv.register_parameter('bias', None)
+ self.fc_beta = LinearBlock(cond_dims, self.conv.out_channels)
+ self.fc_gamma = LinearBlock(cond_dims, self.conv.in_channels)
+ self.eps = eps
+ self.demod = demod
+ self.conditional = True
+
+ def forward(self, x, y, **_kwargs):
+ r"""Weight demodulation forward"""
+ b, c, h, w = x.size()
+ self.conv.groups = b
+ gamma = self.fc_gamma(y)
+ gamma = gamma[:, None, :, None, None]
+ weight = self.conv.weight[None, :, :, :, :] * gamma
+
+ if self.demod:
+ d = torch.rsqrt(
+ (weight ** 2).sum(
+ dim=(2, 3, 4), keepdim=True) + self.eps)
+ weight = weight * d
+
+ x = x.reshape(1, -1, h, w)
+ _, _, *ws = weight.shape
+ weight = weight.reshape(b * self.conv.out_channels, *ws)
+ x = self.conv._conv_forward(x, weight)
+
+ x = x.reshape(-1, self.conv.out_channels, h, w)
+ if self.adaptive_bias:
+ x += self.fc_beta(y)[:, :, None, None]
+ return x
+
+
+def weight_demod(
+ conv, cond_dims=256, eps=1e-8, adaptive_bias=False, demod=True):
+ r"""Weight demodulation."""
+ return WeightDemodulation(conv, cond_dims, eps, adaptive_bias, demod)
+
+
+class ScaledLR(object):
+ def __init__(self, weight_name, bias_name):
+ self.weight_name = weight_name
+ self.bias_name = bias_name
+
+ def compute_weight(self, module):
+ weight = getattr(module, self.weight_name + '_ori')
+ return weight * module.weight_scale
+
+ def compute_bias(self, module):
+ bias = getattr(module, self.bias_name + '_ori')
+ if bias is not None:
+ return bias * module.bias_scale
+ else:
+ return None
+
+ @staticmethod
+ def apply(module, weight_name, bias_name, lr_mul, equalized):
+ assert weight_name == 'weight'
+ assert bias_name == 'bias'
+ fn = ScaledLR(weight_name, bias_name)
+ module.register_forward_pre_hook(fn)
+
+ if hasattr(module, bias_name):
+ # module.bias is a parameter (can be None).
+ bias = getattr(module, bias_name)
+ delattr(module, bias_name)
+ module.register_parameter(bias_name + '_ori', bias)
+ else:
+ # module.bias does not exist.
+ bias = None
+ setattr(module, bias_name + '_ori', bias)
+ if bias is not None:
+ setattr(module, bias_name, bias.data)
+ else:
+ setattr(module, bias_name, None)
+ module.register_buffer('bias_scale', torch.tensor(lr_mul))
+
+ if hasattr(module, weight_name + '_orig'):
+ # The module has been wrapped with spectral normalization.
+ # We only want to keep a single weight parameter.
+ weight = getattr(module, weight_name + '_orig')
+ delattr(module, weight_name + '_orig')
+ module.register_parameter(weight_name + '_ori', weight)
+ setattr(module, weight_name + '_orig', weight.data)
+ # Put this hook before the spectral norm hook.
+ module._forward_pre_hooks = collections.OrderedDict(
+ reversed(list(module._forward_pre_hooks.items()))
+ )
+ module.use_sn = True
+ else:
+ weight = getattr(module, weight_name)
+ delattr(module, weight_name)
+ module.register_parameter(weight_name + '_ori', weight)
+ setattr(module, weight_name, weight.data)
+ module.use_sn = False
+
+ # assert weight.dim() == 4 or weight.dim() == 2
+ if equalized:
+ fan_in = weight.data.size(1) * weight.data[0][0].numel()
+ # Theoretically, the gain should be sqrt(2) instead of 1.
+ # The official StyleGAN2 uses 1 for some reason.
+ module.register_buffer(
+ 'weight_scale', torch.tensor(lr_mul * ((1 / fan_in) ** 0.5))
+ )
+ else:
+ module.register_buffer('weight_scale', torch.tensor(lr_mul))
+
+ module.lr_mul = module.weight_scale
+ module.base_lr_mul = lr_mul
+
+ return fn
+
+ def remove(self, module):
+ with torch.no_grad():
+ weight = self.compute_weight(module)
+ delattr(module, self.weight_name + '_ori')
+
+ if module.use_sn:
+ setattr(module, self.weight_name + '_orig', weight.detach())
+ else:
+ delattr(module, self.weight_name)
+ module.register_parameter(self.weight_name,
+ torch.nn.Parameter(weight.detach()))
+
+ with torch.no_grad():
+ bias = self.compute_bias(module)
+ delattr(module, self.bias_name)
+ delattr(module, self.bias_name + '_ori')
+ if bias is not None:
+ module.register_parameter(self.bias_name,
+ torch.nn.Parameter(bias.detach()))
+ else:
+ module.register_parameter(self.bias_name, None)
+
+ module.lr_mul = 1.0
+ module.base_lr_mul = 1.0
+
+ def __call__(self, module, input):
+ weight = self.compute_weight(module)
+ if module.use_sn:
+ # The following spectral norm hook will compute the SN of
+ # "module.weight_orig" and store the normalized weight in
+ # "module.weight".
+ setattr(module, self.weight_name + '_orig', weight)
+ else:
+ setattr(module, self.weight_name, weight)
+ bias = self.compute_bias(module)
+ setattr(module, self.bias_name, bias)
+
+
+def remove_weight_norms(module, weight_name='weight', bias_name='bias'):
+ if hasattr(module, 'weight_ori') or hasattr(module, 'weight_orig'):
+ for k in list(module._forward_pre_hooks.keys()):
+ hook = module._forward_pre_hooks[k]
+ if (isinstance(hook, ScaledLR) or isinstance(hook, SpectralNorm)):
+ hook.remove(module)
+ del module._forward_pre_hooks[k]
+
+ for k, hook in module._state_dict_hooks.items():
+ if isinstance(hook, SpectralNormStateDictHook) and \
+ hook.fn.name == weight_name:
+ del module._state_dict_hooks[k]
+ break
+
+ for k, hook in module._load_state_dict_pre_hooks.items():
+ if isinstance(hook, SpectralNormLoadStateDictPreHook) and \
+ hook.fn.name == weight_name:
+ del module._load_state_dict_pre_hooks[k]
+ break
+
+ return module
+
+
+def remove_equalized_lr(module, weight_name='weight', bias_name='bias'):
+ for k, hook in module._forward_pre_hooks.items():
+ if isinstance(hook, ScaledLR) and hook.weight_name == weight_name:
+ hook.remove(module)
+ del module._forward_pre_hooks[k]
+ break
+ else:
+ raise ValueError("Equalized learning rate not found")
+
+ return module
+
+
+def scaled_lr(
+ module, weight_name='weight', bias_name='bias', lr_mul=1.,
+ equalized=False,
+):
+ ScaledLR.apply(module, weight_name, bias_name, lr_mul, equalized)
+ return module
+
+
+def get_weight_norm_layer(norm_type, **norm_params):
+ r"""Return weight normalization.
+
+ Args:
+ norm_type (str):
+ Type of weight normalization.
+ ``'none'``, ``'spectral'``, ``'weight'``
+ or ``'weight_demod'``.
+ norm_params: Arbitrary keyword arguments that will be used to
+ initialize the weight normalization.
+ """
+ if norm_type == 'none' or norm_type == '': # no normalization
+ return lambda x: x
+ elif norm_type == 'spectral': # spectral normalization
+ return functools.partial(spectral_norm, **norm_params)
+ elif norm_type == 'weight': # weight normalization
+ return functools.partial(weight_norm, **norm_params)
+ elif norm_type == 'weight_demod': # weight demodulation
+ return functools.partial(weight_demod, **norm_params)
+ elif norm_type == 'equalized_lr': # equalized learning rate
+ return functools.partial(scaled_lr, equalized=True, **norm_params)
+ elif norm_type == 'scaled_lr': # equalized learning rate
+ return functools.partial(scaled_lr, **norm_params)
+ elif norm_type == 'equalized_lr_spectral':
+ lr_mul = norm_params.pop('lr_mul', 1.0)
+ return lambda x: functools.partial(
+ scaled_lr, equalized=True, lr_mul=lr_mul)(
+ functools.partial(spectral_norm, **norm_params)(x)
+ )
+ elif norm_type == 'scaled_lr_spectral':
+ lr_mul = norm_params.pop('lr_mul', 1.0)
+ return lambda x: functools.partial(
+ scaled_lr, lr_mul=lr_mul)(
+ functools.partial(spectral_norm, **norm_params)(x)
+ )
+ else:
+ raise ValueError(
+ 'Weight norm layer %s is not recognized' % norm_type)
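
`get_weight_norm_layer` is a small factory: it maps a type string to a callable that wraps a module with the requested weight normalization. A minimal sketch using two of the options defined above; the layer sizes are arbitrary.

```python
import torch
from torch import nn
from imaginaire.layers.weight_norm import get_weight_norm_layer

# Spectral normalization via torch.nn.utils.spectral_norm.
conv = get_weight_norm_layer('spectral')(nn.Conv2d(3, 16, 3, padding=1))

# Equalized learning rate (StyleGAN2-style): ScaledLR rescales the effective
# weight by lr_mul / sqrt(fan_in) through a forward pre-hook.
fc = get_weight_norm_layer('equalized_lr', lr_mul=0.01)(nn.Linear(128, 128))

y = conv(torch.randn(1, 3, 8, 8))   # (1, 16, 8, 8)
```
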
diff --git a/imaginaire/losses/TVloss.py b/imaginaire/losses/TVloss.py
new file mode 100644
index 0000000000000000000000000000000000000000..158b12519e90c580df1c71658c86b47ce8e71a6e
--- /dev/null
+++ b/imaginaire/losses/TVloss.py
@@ -0,0 +1,15 @@
+import torch
+import torch.nn as nn
+
+
+class TV_loss(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+    def forward(self, input):
+        B, D1, D2, D3 = input.size()
+        tv_d1 = torch.pow(input[:, 1:, :, :] - input[:, :-1, :, :], 2).sum()
+        tv_d2 = torch.pow(input[:, :, 1:, :] - input[:, :, :-1, :], 2).sum()
+        tv_d3 = torch.pow(input[:, :, :, 1:] - input[:, :, :, :-1], 2).sum()
+        return (tv_d1 + tv_d2 + tv_d3) / (B * D1 * D2 * D3)
+
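
`TV_loss` is a total-variation regularizer over the three trailing dimensions of a 4D tensor (for example a `(B, D, H, W)` density volume), averaged over all elements. A minimal usage sketch, importing the class directly from the file added above; the tensor shape is illustrative.

```python
import torch
from imaginaire.losses.TVloss import TV_loss

criterion = TV_loss()
density = torch.rand(2, 64, 32, 32, requires_grad=True)  # (B, D, H, W)

# Mean of squared differences between neighbouring entries along the depth,
# height and width axes; encourages a smooth volume.
loss = criterion(density)
loss.backward()
```
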
diff --git a/imaginaire/losses/__init__.py b/imaginaire/losses/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d3b7d8e381cd8fe0771dd2ce2478af0986a9be87
--- /dev/null
+++ b/imaginaire/losses/__init__.py
@@ -0,0 +1,22 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+from .gan import GANLoss
+from .perceptual import PerceptualLoss
+from .feature_matching import FeatureMatchingLoss
+from .kl import GaussianKLLoss
+from .flow import MaskedL1Loss, FlowLoss
+from .dict import DictLoss
+from .weighted_mse import WeightedMSELoss
+from .TVloss import TV_loss
+
+__all__ = ['GANLoss', 'PerceptualLoss', 'FeatureMatchingLoss', 'GaussianKLLoss',
+ 'MaskedL1Loss', 'FlowLoss', 'DictLoss',
+ 'WeightedMSELoss','TV_loss']
+
+try:
+ from .gradient_penalty import GradientPenaltyLoss
+ __all__.extend(['GradientPenaltyLoss'])
+except: # noqa
+ pass
diff --git a/imaginaire/losses/__pycache__/TVloss.cpython-38.pyc b/imaginaire/losses/__pycache__/TVloss.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2e9cc54635ab39fc68ce1bf7fa42f3880f15ca82
Binary files /dev/null and b/imaginaire/losses/__pycache__/TVloss.cpython-38.pyc differ
diff --git a/imaginaire/losses/__pycache__/__init__.cpython-38.pyc b/imaginaire/losses/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a283a2113362dc366af1fa2376349246babbd54b
Binary files /dev/null and b/imaginaire/losses/__pycache__/__init__.cpython-38.pyc differ
diff --git a/imaginaire/losses/__pycache__/dict.cpython-38.pyc b/imaginaire/losses/__pycache__/dict.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..899655408ca317238e65846264a0fcbadd04a4d1
Binary files /dev/null and b/imaginaire/losses/__pycache__/dict.cpython-38.pyc differ
diff --git a/imaginaire/losses/__pycache__/feature_matching.cpython-38.pyc b/imaginaire/losses/__pycache__/feature_matching.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bbeb6414e6aeb71b2ec64402751478cb786495c4
Binary files /dev/null and b/imaginaire/losses/__pycache__/feature_matching.cpython-38.pyc differ
diff --git a/imaginaire/losses/__pycache__/flow.cpython-38.pyc b/imaginaire/losses/__pycache__/flow.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1043723e126ca15751d7651e0e848b69ea7df80e
Binary files /dev/null and b/imaginaire/losses/__pycache__/flow.cpython-38.pyc differ
diff --git a/imaginaire/losses/__pycache__/gan.cpython-38.pyc b/imaginaire/losses/__pycache__/gan.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9fa848211b872fdf205c9c30b31d19bb2f3aae0e
Binary files /dev/null and b/imaginaire/losses/__pycache__/gan.cpython-38.pyc differ
diff --git a/imaginaire/losses/__pycache__/info_nce.cpython-38.pyc b/imaginaire/losses/__pycache__/info_nce.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5e56e1ed11ac004f1aa741f38c930f033457feca
Binary files /dev/null and b/imaginaire/losses/__pycache__/info_nce.cpython-38.pyc differ
diff --git a/imaginaire/losses/__pycache__/kl.cpython-38.pyc b/imaginaire/losses/__pycache__/kl.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..140c611c8fda1667885013dfd8f35225b00bd8fb
Binary files /dev/null and b/imaginaire/losses/__pycache__/kl.cpython-38.pyc differ
diff --git a/imaginaire/losses/__pycache__/perceptual.cpython-38.pyc b/imaginaire/losses/__pycache__/perceptual.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3aa1b81d9bc18105b7f3c3940b4b49831f19a22b
Binary files /dev/null and b/imaginaire/losses/__pycache__/perceptual.cpython-38.pyc differ
diff --git a/imaginaire/losses/__pycache__/weighted_mse.cpython-38.pyc b/imaginaire/losses/__pycache__/weighted_mse.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0e0a751c997154fbaec8d66e1527da42ff45bb73
Binary files /dev/null and b/imaginaire/losses/__pycache__/weighted_mse.cpython-38.pyc differ
diff --git a/imaginaire/losses/dict.py b/imaginaire/losses/dict.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d0e5ee9516013036b92c564fa349755fc9fc1c9
--- /dev/null
+++ b/imaginaire/losses/dict.py
@@ -0,0 +1,36 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import torch.nn as nn
+
+
+class DictLoss(nn.Module):
+ def __init__(self, criterion='l1'):
+ super(DictLoss, self).__init__()
+ if criterion == 'l1':
+ self.criterion = nn.L1Loss()
+ elif criterion == 'l2' or criterion == 'mse':
+ self.criterion = nn.MSELoss()
+ else:
+ raise ValueError('Criterion %s is not recognized' % criterion)
+
+ def forward(self, fake, real):
+        """Return the L1/L2 distance between fake and real features.
+
+ Args:
+ fake (dict, list or tuple): Discriminator features of fake images.
+ real (dict, list or tuple): Discriminator features of real images.
+ Returns:
+ loss (tensor): Loss value.
+ """
+ loss = 0
+ if type(fake) == dict:
+ for key in fake.keys():
+ loss += self.criterion(fake[key], real[key].detach())
+ elif type(fake) == list or type(fake) == tuple:
+ for f, r in zip(fake, real):
+ loss += self.criterion(f, r.detach())
+ else:
+ loss += self.criterion(fake, real.detach())
+ return loss
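+
+
+# Usage sketch (illustrative): matching dicts, lists/tuples or plain tensors are
+# accepted; gradients only flow through ``fake`` since ``real`` is detached:
+#   crit = DictLoss('l1')
+#   loss = crit({'feat': fake_feat}, {'feat': real_feat})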
diff --git a/imaginaire/losses/feature_matching.py b/imaginaire/losses/feature_matching.py
new file mode 100644
index 0000000000000000000000000000000000000000..f70034b3c3afba5a914261b55cf0abeab832391c
--- /dev/null
+++ b/imaginaire/losses/feature_matching.py
@@ -0,0 +1,38 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import torch.nn as nn
+
+
+class FeatureMatchingLoss(nn.Module):
+ r"""Compute feature matching loss"""
+ def __init__(self, criterion='l1'):
+ super(FeatureMatchingLoss, self).__init__()
+ if criterion == 'l1':
+ self.criterion = nn.L1Loss()
+ elif criterion == 'l2' or criterion == 'mse':
+ self.criterion = nn.MSELoss()
+ else:
+ raise ValueError('Criterion %s is not recognized' % criterion)
+
+ def forward(self, fake_features, real_features):
+        r"""Return the feature matching loss between discriminator features
+        of fake and real images.
+
+ Args:
+ fake_features (list of lists): Discriminator features of fake images.
+ real_features (list of lists): Discriminator features of real images.
+
+ Returns:
+ (tensor): Loss value.
+ """
+ num_d = len(fake_features)
+ dis_weight = 1.0 / num_d
+ loss = fake_features[0][0].new_tensor(0)
+ for i in range(num_d):
+ for j in range(len(fake_features[i])):
+ tmp_loss = self.criterion(fake_features[i][j],
+ real_features[i][j].detach())
+ loss += dis_weight * tmp_loss
+ return loss
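+
+
+# Usage sketch (illustrative): both inputs are lists (one entry per
+# discriminator) of lists of intermediate feature maps; each discriminator
+# contributes 1 / num_discriminators to the total:
+#   fm = FeatureMatchingLoss('l1')
+#   loss = fm(fake_features, real_features)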
diff --git a/imaginaire/losses/flow.py b/imaginaire/losses/flow.py
new file mode 100644
index 0000000000000000000000000000000000000000..3464d949bba8cf148f7bc5c8fa092995f66c663a
--- /dev/null
+++ b/imaginaire/losses/flow.py
@@ -0,0 +1,313 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+# flake8: noqa
+import importlib
+import warnings
+
+import torch
+import torch.nn as nn
+
+from imaginaire.model_utils.fs_vid2vid import (get_face_mask, get_fg_mask,
+ get_part_mask, pick_image,
+ resample)
+
+
+class MaskedL1Loss(nn.Module):
+ r"""Masked L1 loss constructor."""
+
+ def __init__(self, normalize_over_valid=False):
+ super(MaskedL1Loss, self).__init__()
+ self.criterion = nn.L1Loss()
+ self.normalize_over_valid = normalize_over_valid
+
+ def forward(self, input, target, mask):
+ r"""Masked L1 loss computation.
+
+ Args:
+ input (tensor): Input tensor.
+ target (tensor): Target tensor.
+ mask (tensor): Mask to be applied to the output loss.
+
+ Returns:
+ (tensor): Loss value.
+ """
+ mask = mask.expand_as(input)
+ loss = self.criterion(input * mask, target * mask)
+ if self.normalize_over_valid:
+ # The loss has been averaged over all pixels.
+ # Only average over regions which are valid.
+ loss = loss * torch.numel(mask) / (torch.sum(mask) + 1e-6)
+ return loss
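+
+    # Usage sketch (illustrative): ``mask`` is expanded to the prediction's
+    # shape, so an Nx1xHxW validity mask works directly:
+    #   crit = MaskedL1Loss(normalize_over_valid=True)
+    #   loss = crit(pred, target, valid_mask)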
+
+
+class FlowLoss(nn.Module):
+ r"""Flow loss constructor.
+
+ Args:
+ cfg (obj): Configuration.
+ """
+
+ def __init__(self, cfg):
+ super(FlowLoss, self).__init__()
+ self.cfg = cfg
+ self.data_cfg = cfg.data
+ self.criterion = nn.L1Loss()
+ self.criterionMasked = MaskedL1Loss()
+ flow_module = importlib.import_module(cfg.flow_network.type)
+ self.flowNet = flow_module.FlowNet(pretrained=True)
+ self.warp_ref = getattr(cfg.gen.flow, 'warp_ref', False)
+ self.pose_cfg = pose_cfg = getattr(cfg.data, 'for_pose_dataset', None)
+ self.for_pose_dataset = pose_cfg is not None
+ self.has_fg = getattr(cfg.data, 'has_foreground', False)
+
+ def forward(self, data, net_G_output, current_epoch):
+ r"""Compute losses on the output flow and occlusion mask.
+
+ Args:
+ data (dict): Input data.
+ net_G_output (dict): Generator output.
+ current_epoch (int): Current training epoch number.
+ Returns:
+            (tuple):
+ - loss_flow_L1 (tensor): L1 loss compared to ground truth flow.
+ - loss_flow_warp (tensor): L1 loss between the warped image and the
+ target image when using the flow to warp.
+ - loss_mask (tensor): Loss for the occlusion mask.
+ """
+ tgt_label, tgt_image = data['label'], data['image']
+
+ fake_image = net_G_output['fake_images']
+ warped_images = net_G_output['warped_images']
+ flow = net_G_output['fake_flow_maps']
+ occ_mask = net_G_output['fake_occlusion_masks']
+
+ if self.warp_ref:
+ # Pick the most similar reference image to warp.
+ ref_labels, ref_images = data['ref_labels'], data['ref_images']
+ ref_idx = net_G_output['ref_idx']
+ ref_label, ref_image = pick_image([ref_labels, ref_images], ref_idx)
+ else:
+ ref_label = ref_image = None
+
+ # Compute the ground truth flows and confidence maps.
+ flow_gt_prev = flow_gt_ref = conf_gt_prev = conf_gt_ref = None
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ if self.warp_ref:
+ # Compute GT for warping reference -> target.
+ if self.for_pose_dataset:
+ # Use DensePose maps to compute flows for pose dataset.
+ flow_gt_ref, conf_gt_ref = self.flowNet(tgt_label[:, :3],
+ ref_label[:, :3])
+ else:
+ # Use RGB images for other datasets.
+ flow_gt_ref, conf_gt_ref = self.flowNet(tgt_image,
+ ref_image)
+
+ if current_epoch >= self.cfg.single_frame_epoch and \
+ data['real_prev_image'] is not None:
+ # Compute GT for warping previous -> target.
+ tgt_image_prev = data['real_prev_image']
+ flow_gt_prev, conf_gt_prev = self.flowNet(tgt_image,
+ tgt_image_prev)
+
+ flow_gt = [flow_gt_ref, flow_gt_prev]
+ flow_conf_gt = [conf_gt_ref, conf_gt_prev]
+ # Get the foreground masks.
+ fg_mask, ref_fg_mask = get_fg_mask([tgt_label, ref_label], self.has_fg)
+
+ # Compute losses for flow maps and masks.
+ loss_flow_L1, loss_flow_warp, body_mask_diff = \
+ self.compute_flow_losses(flow, warped_images, tgt_image, flow_gt,
+ flow_conf_gt, fg_mask, tgt_label,
+ ref_label)
+
+ loss_mask = self.compute_mask_losses(
+ occ_mask, fake_image, warped_images, tgt_label, tgt_image,
+ fg_mask, ref_fg_mask, body_mask_diff)
+
+ return loss_flow_L1, loss_flow_warp, loss_mask
+
+ def compute_flow_losses(self, flow, warped_images, tgt_image, flow_gt,
+ flow_conf_gt, fg_mask, tgt_label, ref_label):
+ r"""Compute losses on the generated flow maps.
+
+ Args:
+ flow (tensor or list of tensors): Generated flow maps.
+ warped_images (tensor or list of tensors): Warped images using the
+ flow maps.
+ tgt_image (tensor): Target image for the warped image.
+ flow_gt (tensor or list of tensors): Ground truth flow maps.
+ flow_conf_gt (tensor or list of tensors): Confidence for the ground
+ truth flow maps.
+ fg_mask (tensor): Foreground mask for the target image.
+ tgt_label (tensor): Target label map.
+ ref_label (tensor): Reference label map.
+ Returns:
+            (tuple):
+ - loss_flow_L1 (tensor): L1 loss compared to ground truth flow.
+ - loss_flow_warp (tensor): L1 loss between the warped image and the
+ target image when using the flow to warp.
+ - body_mask_diff (tensor): Difference between warped body part map
+ and target body part map. Used for pose dataset only.
+ """
+ loss_flow_L1 = torch.tensor(0., device=torch.device('cuda'))
+ loss_flow_warp = torch.tensor(0., device=torch.device('cuda'))
+ if isinstance(flow, list):
+ # Compute flow losses for both warping reference -> target and
+ # previous -> target.
+ for i in range(len(flow)):
+ loss_flow_L1_i, loss_flow_warp_i = \
+ self.compute_flow_loss(flow[i], warped_images[i], tgt_image,
+ flow_gt[i], flow_conf_gt[i], fg_mask)
+ loss_flow_L1 += loss_flow_L1_i
+ loss_flow_warp += loss_flow_warp_i
+ else:
+ # Compute loss for warping either reference or previous images.
+ loss_flow_L1, loss_flow_warp = \
+ self.compute_flow_loss(flow, warped_images, tgt_image,
+ flow_gt[-1], flow_conf_gt[-1], fg_mask)
+
+ # For pose dataset only.
+ body_mask_diff = None
+ if self.warp_ref:
+ if self.for_pose_dataset:
+ # Warped reference body part map should be similar to target
+ # body part map.
+ body_mask = get_part_mask(tgt_label[:, 2])
+ ref_body_mask = get_part_mask(ref_label[:, 2])
+ warped_ref_body_mask = resample(ref_body_mask, flow[0])
+ loss_flow_warp += self.criterion(warped_ref_body_mask,
+ body_mask)
+ body_mask_diff = torch.sum(
+ abs(warped_ref_body_mask - body_mask), dim=1, keepdim=True)
+
+ if self.has_fg:
+ # Warped reference foreground map should be similar to target
+ # foreground map.
+ fg_mask, ref_fg_mask = \
+ get_fg_mask([tgt_label, ref_label], True)
+ warped_ref_fg_mask = resample(ref_fg_mask, flow[0])
+ loss_flow_warp += self.criterion(warped_ref_fg_mask, fg_mask)
+
+ return loss_flow_L1, loss_flow_warp, body_mask_diff
+
+ def compute_flow_loss(self, flow, warped_image, tgt_image, flow_gt,
+ flow_conf_gt, fg_mask):
+ r"""Compute losses on the generated flow map.
+
+ Args:
+ flow (tensor): Generated flow map.
+ warped_image (tensor): Warped image using the flow map.
+ tgt_image (tensor): Target image for the warped image.
+ flow_gt (tensor): Ground truth flow map.
+ flow_conf_gt (tensor): Confidence for the ground truth flow map.
+ fg_mask (tensor): Foreground mask for the target image.
+ Returns:
+            (tuple):
+ - loss_flow_L1 (tensor): L1 loss compared to ground truth flow.
+ - loss_flow_warp (tensor): L1 loss between the warped image and
+ the target image when using the flow to warp.
+ """
+ loss_flow_L1 = torch.tensor(0., device=torch.device('cuda'))
+ loss_flow_warp = torch.tensor(0., device=torch.device('cuda'))
+ if flow is not None and flow_gt is not None:
+ # L1 loss compared to flow ground truth.
+ loss_flow_L1 = self.criterionMasked(flow, flow_gt,
+ flow_conf_gt * fg_mask)
+ if warped_image is not None:
+ # L1 loss between warped image and target image.
+ loss_flow_warp = self.criterion(warped_image, tgt_image)
+ return loss_flow_L1, loss_flow_warp
+
+ def compute_mask_losses(self, occ_mask, fake_image, warped_image,
+ tgt_label, tgt_image, fg_mask, ref_fg_mask,
+ body_mask_diff):
+ r"""Compute losses on the generated occlusion masks.
+
+ Args:
+ occ_mask (tensor or list of tensors): Generated occlusion masks.
+ fake_image (tensor): Generated image.
+ warped_image (tensor or list of tensors): Warped images using the
+ flow maps.
+ tgt_label (tensor): Target label map.
+ tgt_image (tensor): Target image for the warped image.
+ fg_mask (tensor): Foreground mask for the target image.
+ ref_fg_mask (tensor): Foreground mask for the reference image.
+ body_mask_diff (tensor): Difference between warped body part map
+ and target body part map. Used for pose dataset only.
+ Returns:
+ (tensor): Loss for the mask.
+ """
+ loss_mask = torch.tensor(0., device=torch.device('cuda'))
+
+ if isinstance(occ_mask, list):
+ # Compute occlusion mask losses for both warping reference -> target
+ # and previous -> target.
+ for i in range(len(occ_mask)):
+ loss_mask += self.compute_mask_loss(occ_mask[i],
+ warped_image[i],
+ tgt_image)
+ else:
+ # Compute loss for warping either reference or previous images.
+ loss_mask += self.compute_mask_loss(occ_mask, warped_image,
+ tgt_image)
+
+ if self.warp_ref:
+ ref_occ_mask = occ_mask[0]
+ dummy0 = torch.zeros_like(ref_occ_mask)
+ dummy1 = torch.ones_like(ref_occ_mask)
+ if self.for_pose_dataset:
+ # Enforce output to use more warped reference image for
+ # face region.
+ face_mask = get_face_mask(tgt_label[:, 2]).unsqueeze(1)
+ AvgPool = torch.nn.AvgPool2d(15, padding=7, stride=1)
+ face_mask = AvgPool(face_mask)
+ loss_mask += self.criterionMasked(ref_occ_mask, dummy0,
+ face_mask)
+ loss_mask += self.criterionMasked(fake_image, warped_image[0],
+ face_mask)
+ # Enforce output to use more hallucinated image for discrepancy
+ # regions of body part masks between warped reference and
+ # target image.
+ loss_mask += self.criterionMasked(ref_occ_mask, dummy1,
+ body_mask_diff)
+
+ if self.has_fg:
+ # Enforce output to use more hallucinated image for discrepancy
+ # regions of foreground masks between reference and target
+ # image.
+ fg_mask_diff = ((ref_fg_mask - fg_mask) > 0).float()
+ loss_mask += self.criterionMasked(ref_occ_mask, dummy1,
+ fg_mask_diff)
+ return loss_mask
+
+ def compute_mask_loss(self, occ_mask, warped_image, tgt_image):
+ r"""Compute losses on the generated occlusion mask.
+
+ Args:
+ occ_mask (tensor): Generated occlusion mask.
+ warped_image (tensor): Warped image using the flow map.
+ tgt_image (tensor): Target image for the warped image.
+ Returns:
+ (tensor): Loss for the mask.
+ """
+ loss_mask = torch.tensor(0., device=torch.device('cuda'))
+ if occ_mask is not None:
+ dummy0 = torch.zeros_like(occ_mask)
+ dummy1 = torch.ones_like(occ_mask)
+
+ # Compute the confidence map based on L1 distance between warped
+ # and GT image.
+ img_diff = torch.sum(abs(warped_image - tgt_image), dim=1,
+ keepdim=True)
+ conf = torch.clamp(1 - img_diff, 0, 1)
+
+ # Force mask value to be small if warped image is similar to GT,
+ # and vice versa.
+ loss_mask = self.criterionMasked(occ_mask, dummy0, conf)
+ loss_mask += self.criterionMasked(occ_mask, dummy1, 1 - conf)
+
+ return loss_mask
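+
+
+# Usage sketch (illustrative): ``net_G_output`` is expected to provide
+# 'fake_images', 'warped_images', 'fake_flow_maps' and 'fake_occlusion_masks'
+# (see ``forward`` above); the call returns three scalar tensors:
+#   flow_crit = FlowLoss(cfg)
+#   l_flow_l1, l_flow_warp, l_mask = flow_crit(data, net_G_output, current_epoch)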
diff --git a/imaginaire/losses/gan.py b/imaginaire/losses/gan.py
new file mode 100644
index 0000000000000000000000000000000000000000..aaa9c30dd51887b25b439b90fa728e94fe2b03a9
--- /dev/null
+++ b/imaginaire/losses/gan.py
@@ -0,0 +1,173 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from imaginaire.utils.distributed import master_only_print as print
+
+
+@torch.jit.script
+def fuse_math_min_mean_pos(x):
+ r"""Fuse operation min mean for hinge loss computation of positive
+ samples"""
+ minval = torch.min(x - 1, x * 0)
+ loss = -torch.mean(minval)
+ return loss
+
+
+@torch.jit.script
+def fuse_math_min_mean_neg(x):
+ r"""Fuse operation min mean for hinge loss computation of negative
+ samples"""
+ minval = torch.min(-x - 1, x * 0)
+ loss = -torch.mean(minval)
+ return loss
+
+
+class GANLoss(nn.Module):
+ r"""GAN loss constructor.
+
+ Args:
+        gan_mode (str): Type of GAN loss. ``'hinge'``, ``'least_square'``,
+            ``'non_saturated'``, ``'wasserstein'`` or ``'softplus'``.
+ target_real_label (float): The desired output label for real images.
+ target_fake_label (float): The desired output label for fake images.
+ decay_k (float): The decay factor per epoch for top-k training.
+ min_k (float): The minimum percentage of samples to select.
+ separate_topk (bool): If ``True``, selects top-k for each sample
+ separately, otherwise selects top-k among all samples.
+ """
+ def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0,
+ decay_k=1., min_k=1., separate_topk=False):
+ super(GANLoss, self).__init__()
+ self.real_label = target_real_label
+ self.fake_label = target_fake_label
+ self.real_label_tensor = None
+ self.fake_label_tensor = None
+ self.gan_mode = gan_mode
+ self.decay_k = decay_k
+ self.min_k = min_k
+ self.separate_topk = separate_topk
+ self.register_buffer('k', torch.tensor(1.0))
+ print('GAN mode: %s' % gan_mode)
+
+ def forward(self, dis_output, t_real, dis_update=True, reduce=True):
+ r"""GAN loss computation.
+
+ Args:
+ dis_output (tensor or list of tensors): Discriminator outputs.
+ t_real (bool): If ``True``, uses the real label as target, otherwise uses the fake label as target.
+ dis_update (bool): If ``True``, the loss will be used to update the discriminator, otherwise the generator.
+ reduce (bool): If ``True``, when a list of discriminator outputs are provided, it will return the average
+ of all losses, otherwise it will return a list of losses.
+ Returns:
+ loss (tensor): Loss value.
+ """
+ if isinstance(dis_output, list):
+ # For multi-scale discriminators.
+ # In this implementation, the loss is first averaged for each scale
+ # (batch size and number of locations) then averaged across scales,
+ # so that the gradient is not dominated by the discriminator that
+ # has the most output values (highest resolution).
+ losses = []
+ for dis_output_i in dis_output:
+ assert isinstance(dis_output_i, torch.Tensor)
+ losses.append(self.loss(dis_output_i, t_real, dis_update))
+ if reduce:
+ return torch.mean(torch.stack(losses))
+ else:
+ return losses
+ else:
+ return self.loss(dis_output, t_real, dis_update)
+
+ def loss(self, dis_output, t_real, dis_update=True):
+ r"""GAN loss computation.
+
+ Args:
+ dis_output (tensor): Discriminator outputs.
+ t_real (bool): If ``True``, uses the real label as target, otherwise
+ uses the fake label as target.
+ dis_update (bool): Updating the discriminator or the generator.
+ Returns:
+ loss (tensor): Loss value.
+ """
+ if not dis_update:
+ assert t_real, \
+ "The target should be real when updating the generator."
+
+ if not dis_update and self.k < 1:
+ r"""
+ Use top-k training:
+ "Top-k Training of GANs: Improving GAN Performance by Throwing
+ Away Bad Samples"
+ Here, each sample may have multiple discriminator output values
+ (patch discriminator). We could either select top-k for each sample
+ separately (when ``self.separate_topk=True``), or collect values
+ from all samples and then select top-k (default, when
+ ``self.separate_topk=False``).
+ """
+ if self.separate_topk:
+ dis_output = dis_output.view(dis_output.size(0), -1)
+ else:
+ dis_output = dis_output.view(-1)
+ k = math.ceil(self.k * dis_output.size(-1))
+ dis_output, _ = torch.topk(dis_output, k)
+
+ if self.gan_mode == 'non_saturated':
+ target_tensor = self.get_target_tensor(dis_output, t_real)
+ loss = F.binary_cross_entropy_with_logits(dis_output,
+ target_tensor)
+ elif self.gan_mode == 'least_square':
+ target_tensor = self.get_target_tensor(dis_output, t_real)
+ loss = 0.5 * F.mse_loss(dis_output, target_tensor)
+ elif self.gan_mode == 'hinge':
+ if dis_update:
+ if t_real:
+ loss = fuse_math_min_mean_pos(dis_output)
+ else:
+ loss = fuse_math_min_mean_neg(dis_output)
+ else:
+ loss = -torch.mean(dis_output)
+ elif self.gan_mode == 'wasserstein':
+ if t_real:
+ loss = -torch.mean(dis_output)
+ else:
+ loss = torch.mean(dis_output)
+ elif self.gan_mode == 'softplus':
+ target_tensor = self.get_target_tensor(dis_output, t_real)
+ loss = F.binary_cross_entropy_with_logits(dis_output,
+ target_tensor)
+ else:
+ raise ValueError('Unexpected gan_mode {}'.format(self.gan_mode))
+ return loss
+
+ def get_target_tensor(self, dis_output, t_real):
+ r"""Return the target vector for the binary cross entropy loss
+ computation.
+
+ Args:
+ dis_output (tensor): Discriminator outputs.
+ t_real (bool): If ``True``, uses the real label as target, otherwise
+ uses the fake label as target.
+ Returns:
+ target (tensor): Target tensor vector.
+ """
+ if t_real:
+ if self.real_label_tensor is None:
+ self.real_label_tensor = dis_output.new_tensor(self.real_label)
+ return self.real_label_tensor.expand_as(dis_output)
+ else:
+ if self.fake_label_tensor is None:
+ self.fake_label_tensor = dis_output.new_tensor(self.fake_label)
+ return self.fake_label_tensor.expand_as(dis_output)
+
+ def topk_anneal(self):
+ r"""Anneal k after each epoch."""
+ if self.decay_k < 1:
+ # noinspection PyAttributeOutsideInit
+ self.k.fill_(max(self.decay_k * self.k, self.min_k))
+ print("Top-k training: update k to {}.".format(self.k))
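+
+
+# Usage sketch (illustrative): one criterion serves both networks and
+# ``dis_update`` switches between the two objectives, e.g. with the hinge loss:
+#   gan_crit = GANLoss('hinge')
+#   d_loss = gan_crit(d_real, True, dis_update=True) + \
+#            gan_crit(d_fake, False, dis_update=True)
+#   g_loss = gan_crit(d_fake_for_g, True, dis_update=False)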
diff --git a/imaginaire/losses/info_nce.py b/imaginaire/losses/info_nce.py
new file mode 100644
index 0000000000000000000000000000000000000000..8033e828f0b99d12d6e8f8b71811982d0ab568f6
--- /dev/null
+++ b/imaginaire/losses/info_nce.py
@@ -0,0 +1,87 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.distributed as dist
+
+from imaginaire.utils.distributed import get_world_size, get_rank, \
+ dist_all_reduce_tensor
+
+
+class GatherLayer(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, input):
+ ctx.save_for_backward(input)
+ output = [torch.zeros_like(input) for _ in range(dist.get_world_size())]
+ dist.all_gather(output, input)
+ return tuple(output)
+
+ @staticmethod
+ def backward(ctx, *grads):
+ input, = ctx.saved_tensors
+ grad_out = torch.zeros_like(input)
+ all_grads = torch.stack(grads)
+ all_grads = dist_all_reduce_tensor(all_grads, reduce='sum')
+ grad_out[:] = all_grads[get_rank()]
+ return grad_out
+
+
+class InfoNCELoss(nn.Module):
+ def __init__(self,
+ temperature=0.07,
+ gather_distributed=True,
+ learn_temperature=True,
+ single_direction=False,
+ flatten=True):
+ super(InfoNCELoss, self).__init__()
+ self.logit_scale = nn.Parameter(torch.tensor([math.log(1/temperature)]))
+ self.logit_scale.requires_grad = learn_temperature
+ self.gather_distributed = gather_distributed
+ self.single_direction = single_direction
+ self.flatten = flatten
+
+ def forward(self, features_a, features_b, gather_distributed=None, eps=1e-8):
+ if gather_distributed is None:
+ gather_distributed = self.gather_distributed
+
+ if features_a is None or features_b is None:
+ return torch.tensor(0, device='cuda'), torch.tensor(0, device='cuda')
+
+ bs_a, bs_b = features_a.size(0), features_b.size(0)
+ if self.flatten:
+ features_a, features_b = features_a.reshape(bs_a, -1), features_b.reshape(bs_b, -1)
+ else:
+ features_a = features_a.reshape(bs_a, features_a.size(1), -1).mean(-1)
+ features_b = features_b.reshape(bs_b, features_b.size(1), -1).mean(-1)
+
+ # Temperature clipping.
+ self.logit_scale.data = torch.clamp(self.logit_scale.data, 0, 4.6052)
+
+ # normalized features
+ features_a = features_a / (features_a.norm(dim=1, keepdim=True) + eps)
+ features_b = features_b / (features_b.norm(dim=1, keepdim=True) + eps)
+
+ loss_a = self._forward_single_direction(features_a, features_b, gather_distributed)
+ if self.single_direction:
+ return loss_a
+ else:
+ loss_b = self._forward_single_direction(features_b, features_a, gather_distributed)
+ return loss_a + loss_b
+
+ def _forward_single_direction(
+ self, features_a, features_b, gather_distributed):
+ bs_a = features_a.shape[0]
+ logit_scale = self.logit_scale.exp()
+ if get_world_size() > 1 and gather_distributed:
+ gather_features_b = torch.cat(GatherLayer.apply(features_b))
+ gather_labels_a = torch.arange(bs_a, device='cuda') + get_rank() * bs_a
+ logits_a = logit_scale * features_a @ gather_features_b.t()
+ else:
+ gather_labels_a = torch.arange(bs_a, device='cuda')
+ logits_a = logit_scale * features_a @ features_b.t()
+ loss_a = F.cross_entropy(logits_a, gather_labels_a)
+ return loss_a
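+
+
+# Usage sketch (illustrative): features are flattened (or channel-averaged when
+# ``flatten=False``), L2-normalised and contrasted with a learnable temperature;
+# unless ``single_direction=True`` the symmetric a->b and b->a losses are summed:
+#   nce = InfoNCELoss(temperature=0.07, gather_distributed=False)
+#   loss = nce(feat_a, feat_b)  # B x C x H x W tensors on the GPU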
diff --git a/imaginaire/losses/kl.py b/imaginaire/losses/kl.py
new file mode 100644
index 0000000000000000000000000000000000000000..dc9a2da14b06ccfef143312acd20ccd3784bdb34
--- /dev/null
+++ b/imaginaire/losses/kl.py
@@ -0,0 +1,22 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import torch
+import torch.nn as nn
+
+class GaussianKLLoss(nn.Module):
+ r"""Compute KL loss in VAE for Gaussian distributions"""
+ def __init__(self):
+ super(GaussianKLLoss, self).__init__()
+
+ def forward(self, mu, logvar=None):
+ r"""Compute loss
+
+ Args:
+ mu (tensor): mean
+ logvar (tensor): logarithm of variance
+ """
+ if logvar is None:
+ logvar = torch.zeros_like(mu)
+ return -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
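+
+
+# Note (illustrative): this is the closed-form KL divergence between
+# N(mu, exp(logvar)) and the standard normal, averaged over all elements:
+#   KL = -0.5 * mean(1 + logvar - mu^2 - exp(logvar))
+#   kl = GaussianKLLoss(); loss = kl(mu, logvar)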
diff --git a/imaginaire/losses/perceptual.py b/imaginaire/losses/perceptual.py
new file mode 100644
index 0000000000000000000000000000000000000000..424656fa09b65333e4fa28cae2de7114de69ebfa
--- /dev/null
+++ b/imaginaire/losses/perceptual.py
@@ -0,0 +1,395 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import torch
+import torch.nn.functional as F
+import torchvision
+from torch import nn, distributed as dist
+
+from imaginaire.losses.info_nce import InfoNCELoss
+from imaginaire.utils.distributed import master_only_print as print, \
+ is_local_master
+from imaginaire.utils.misc import apply_imagenet_normalization, to_float
+
+
+class PerceptualLoss(nn.Module):
+ r"""Perceptual loss initialization.
+
+ Args:
+ network (str) : The name of the loss network: 'vgg16' | 'vgg19'.
+ layers (str or list of str) : The layers used to compute the loss.
+        weights (float or list of float) : The loss weights of each layer.
+ criterion (str): The type of distance function: 'l1' | 'l2'.
+ resize (bool) : If ``True``, resize the input images to 224x224.
+ resize_mode (str): Algorithm used for resizing.
+ num_scales (int): The loss will be evaluated at original size and
+ this many times downsampled sizes.
+ per_sample_weight (bool): Output loss for individual samples in the
+ batch instead of mean loss.
+ """
+
+ def __init__(self, network='vgg19', layers='relu_4_1', weights=None,
+ criterion='l1', resize=False, resize_mode='bilinear',
+ num_scales=1, per_sample_weight=False,
+ info_nce_temperature=0.07,
+ info_nce_gather_distributed=True,
+ info_nce_learn_temperature=True,
+ info_nce_flatten=True):
+ super().__init__()
+ if isinstance(layers, str):
+ layers = [layers]
+ if weights is None:
+ weights = [1.] * len(layers)
+        elif isinstance(weights, (float, int)):
+ weights = [weights]
+
+ if dist.is_initialized() and not is_local_master():
+ # Make sure only the first process in distributed training downloads
+ # the model, and the others will use the cache
+ # noinspection PyUnresolvedReferences
+ torch.distributed.barrier()
+
+ assert len(layers) == len(weights), \
+ 'The number of layers (%s) must be equal to ' \
+ 'the number of weights (%s).' % (len(layers), len(weights))
+ if network == 'vgg19':
+ self.model = _vgg19(layers)
+ elif network == 'vgg16':
+ self.model = _vgg16(layers)
+ elif network == 'alexnet':
+ self.model = _alexnet(layers)
+ elif network == 'inception_v3':
+ self.model = _inception_v3(layers)
+ elif network == 'resnet50':
+ self.model = _resnet50(layers)
+ elif network == 'robust_resnet50':
+ self.model = _robust_resnet50(layers)
+ elif network == 'vgg_face_dag':
+ self.model = _vgg_face_dag(layers)
+ else:
+ raise ValueError('Network %s is not recognized' % network)
+
+ if dist.is_initialized() and is_local_master():
+ # Make sure only the first process in distributed training downloads
+ # the model, and the others will use the cache
+ # noinspection PyUnresolvedReferences
+ torch.distributed.barrier()
+
+ self.num_scales = num_scales
+ self.layers = layers
+ self.weights = weights
+ reduction = 'mean' if not per_sample_weight else 'none'
+ if criterion == 'l1':
+ self.criterion = nn.L1Loss(reduction=reduction)
+ elif criterion == 'l2' or criterion == 'mse':
+ self.criterion = nn.MSELoss(reduction=reduction)
+ elif criterion == 'info_nce':
+ self.criterion = InfoNCELoss(
+ temperature=info_nce_temperature,
+ gather_distributed=info_nce_gather_distributed,
+ learn_temperature=info_nce_learn_temperature,
+ flatten=info_nce_flatten,
+ single_direction=True
+ )
+ else:
+ raise ValueError('Criterion %s is not recognized' % criterion)
+ self.resize = resize
+ self.resize_mode = resize_mode
+ print('Perceptual loss:')
+ print('\tMode: {}'.format(network))
+
+ def forward(self, inp, target, per_sample_weights=None):
+ r"""Perceptual loss forward.
+
+ Args:
+ inp (4D tensor) : Input tensor.
+ target (4D tensor) : Ground truth tensor, same shape as the input.
+            per_sample_weights (tensor, optional): If given, the loss is
+                reduced per sample instead of over the whole batch.
+ Returns:
+ (scalar tensor) : The perceptual loss.
+ """
+ if not torch.is_autocast_enabled():
+ inp, target = to_float([inp, target])
+
+ # Perceptual loss should operate in eval mode by default.
+ self.model.eval()
+ inp, target = apply_imagenet_normalization(inp), apply_imagenet_normalization(target)
+ if self.resize:
+ inp = F.interpolate(inp, mode=self.resize_mode, size=(224, 224), align_corners=False)
+ target = F.interpolate(target, mode=self.resize_mode, size=(224, 224), align_corners=False)
+
+ # Evaluate perceptual loss at each scale.
+ loss = 0
+ for scale in range(self.num_scales):
+ input_features, target_features = self.model(inp), self.model(target)
+
+ for layer, weight in zip(self.layers, self.weights):
+ # Example per-layer VGG19 loss values after applying
+ # [0.03125, 0.0625, 0.125, 0.25, 1.0] weighting.
+ # relu_1_1, 0.014698
+ # relu_2_1, 0.085817
+ # relu_3_1, 0.349977
+ # relu_4_1, 0.544188
+ # relu_5_1, 0.906261
+ # print('%s, %f' % (
+ # layer,
+ # weight * self.criterion(
+ # input_features[layer],
+ # target_features[
+ # layer].detach()).item()))
+ l_tmp = self.criterion(input_features[layer], target_features[layer].detach())
+ if per_sample_weights is not None:
+ l_tmp = l_tmp.mean(1).mean(1).mean(1)
+ loss += weight * l_tmp
+ # Downsample the input and target.
+ if scale != self.num_scales - 1:
+ inp = F.interpolate(
+ inp, mode=self.resize_mode, scale_factor=0.5,
+ align_corners=False, recompute_scale_factor=True)
+ target = F.interpolate(
+ target, mode=self.resize_mode, scale_factor=0.5,
+ align_corners=False, recompute_scale_factor=True)
+
+ return loss.float()
+
+
+class _PerceptualNetwork(nn.Module):
+ r"""The network that extracts features to compute the perceptual loss.
+
+ Args:
+ network (nn.Sequential) : The network that extracts features.
+ layer_name_mapping (dict) : The dictionary that
+ maps a layer's index to its name.
+ layers (list of str): The list of layer names that we are using.
+ """
+
+ def __init__(self, network, layer_name_mapping, layers):
+ super().__init__()
+ assert isinstance(network, nn.Sequential), \
+ 'The network needs to be of type "nn.Sequential".'
+ self.network = network
+ self.layer_name_mapping = layer_name_mapping
+ self.layers = layers
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def forward(self, x):
+ r"""Extract perceptual features."""
+ output = {}
+ for i, layer in enumerate(self.network):
+ x = layer(x)
+ layer_name = self.layer_name_mapping.get(i, None)
+ if layer_name in self.layers:
+ # If the current layer is used by the perceptual loss.
+ output[layer_name] = x
+ return output
+
+
+def _vgg19(layers):
+ r"""Get vgg19 layers"""
+ vgg = torchvision.models.vgg19(pretrained=True)
+ # network = vgg.features
+ network = torch.nn.Sequential(*(list(vgg.features) + [vgg.avgpool] + [nn.Flatten()] + list(vgg.classifier)))
+ layer_name_mapping = {1: 'relu_1_1',
+ 3: 'relu_1_2',
+ 6: 'relu_2_1',
+ 8: 'relu_2_2',
+ 11: 'relu_3_1',
+ 13: 'relu_3_2',
+ 15: 'relu_3_3',
+ 17: 'relu_3_4',
+ 20: 'relu_4_1',
+ 22: 'relu_4_2',
+ 24: 'relu_4_3',
+ 26: 'relu_4_4',
+ 29: 'relu_5_1',
+ 31: 'relu_5_2',
+ 33: 'relu_5_3',
+ 35: 'relu_5_4',
+ 36: 'pool_5',
+ 42: 'fc_2'}
+ return _PerceptualNetwork(network, layer_name_mapping, layers)
+
+
+def _vgg16(layers):
+ r"""Get vgg16 layers"""
+ network = torchvision.models.vgg16(pretrained=True).features
+ layer_name_mapping = {1: 'relu_1_1',
+ 3: 'relu_1_2',
+ 6: 'relu_2_1',
+ 8: 'relu_2_2',
+ 11: 'relu_3_1',
+ 13: 'relu_3_2',
+ 15: 'relu_3_3',
+ 18: 'relu_4_1',
+ 20: 'relu_4_2',
+ 22: 'relu_4_3',
+ 25: 'relu_5_1'}
+ return _PerceptualNetwork(network, layer_name_mapping, layers)
+
+
+def _alexnet(layers):
+ r"""Get alexnet layers"""
+ network = torchvision.models.alexnet(pretrained=True).features
+ layer_name_mapping = {0: 'conv_1',
+ 1: 'relu_1',
+ 3: 'conv_2',
+ 4: 'relu_2',
+ 6: 'conv_3',
+ 7: 'relu_3',
+ 8: 'conv_4',
+ 9: 'relu_4',
+ 10: 'conv_5',
+ 11: 'relu_5'}
+ return _PerceptualNetwork(network, layer_name_mapping, layers)
+
+
+def _inception_v3(layers):
+ r"""Get inception v3 layers"""
+ inception = torchvision.models.inception_v3(pretrained=True)
+ network = nn.Sequential(inception.Conv2d_1a_3x3,
+ inception.Conv2d_2a_3x3,
+ inception.Conv2d_2b_3x3,
+ nn.MaxPool2d(kernel_size=3, stride=2),
+ inception.Conv2d_3b_1x1,
+ inception.Conv2d_4a_3x3,
+ nn.MaxPool2d(kernel_size=3, stride=2),
+ inception.Mixed_5b,
+ inception.Mixed_5c,
+ inception.Mixed_5d,
+ inception.Mixed_6a,
+ inception.Mixed_6b,
+ inception.Mixed_6c,
+ inception.Mixed_6d,
+ inception.Mixed_6e,
+ inception.Mixed_7a,
+ inception.Mixed_7b,
+ inception.Mixed_7c,
+ nn.AdaptiveAvgPool2d(output_size=(1, 1)))
+ layer_name_mapping = {3: 'pool_1',
+ 6: 'pool_2',
+ 14: 'mixed_6e',
+ 18: 'pool_3'}
+ return _PerceptualNetwork(network, layer_name_mapping, layers)
+
+
+def _resnet50(layers):
+ r"""Get resnet50 layers"""
+ resnet50 = torchvision.models.resnet50(pretrained=True)
+ network = nn.Sequential(resnet50.conv1,
+ resnet50.bn1,
+ resnet50.relu,
+ resnet50.maxpool,
+ resnet50.layer1,
+ resnet50.layer2,
+ resnet50.layer3,
+ resnet50.layer4,
+ resnet50.avgpool)
+ layer_name_mapping = {4: 'layer_1',
+ 5: 'layer_2',
+ 6: 'layer_3',
+ 7: 'layer_4'}
+ return _PerceptualNetwork(network, layer_name_mapping, layers)
+
+
+def _robust_resnet50(layers):
+ r"""Get robust resnet50 layers"""
+ resnet50 = torchvision.models.resnet50(pretrained=False)
+ state_dict = torch.utils.model_zoo.load_url(
+ 'http://andrewilyas.com/ImageNet.pt')
+ new_state_dict = {}
+ for k, v in state_dict['model'].items():
+ if k.startswith('module.model.'):
+ new_state_dict[k[13:]] = v
+ resnet50.load_state_dict(new_state_dict)
+ network = nn.Sequential(resnet50.conv1,
+ resnet50.bn1,
+ resnet50.relu,
+ resnet50.maxpool,
+ resnet50.layer1,
+ resnet50.layer2,
+ resnet50.layer3,
+ resnet50.layer4,
+ resnet50.avgpool)
+ layer_name_mapping = {4: 'layer_1',
+ 5: 'layer_2',
+ 6: 'layer_3',
+ 7: 'layer_4'}
+ return _PerceptualNetwork(network, layer_name_mapping, layers)
+
+
+def _vgg_face_dag(layers):
+ network = torchvision.models.vgg16(num_classes=2622)
+ state_dict = torch.utils.model_zoo.load_url(
+ 'http://www.robots.ox.ac.uk/~albanie/models/pytorch-mcn/'
+ 'vgg_face_dag.pth')
+ feature_layer_name_mapping = {
+ 0: 'conv1_1',
+ 2: 'conv1_2',
+ 5: 'conv2_1',
+ 7: 'conv2_2',
+ 10: 'conv3_1',
+ 12: 'conv3_2',
+ 14: 'conv3_3',
+ 17: 'conv4_1',
+ 19: 'conv4_2',
+ 21: 'conv4_3',
+ 24: 'conv5_1',
+ 26: 'conv5_2',
+ 28: 'conv5_3'}
+ new_state_dict = {}
+ for k, v in feature_layer_name_mapping.items():
+ new_state_dict['features.' + str(k) + '.weight'] = \
+ state_dict[v + '.weight']
+ new_state_dict['features.' + str(k) + '.bias'] = \
+ state_dict[v + '.bias']
+
+ classifier_layer_name_mapping = {
+ 0: 'fc6',
+ 3: 'fc7',
+ 6: 'fc8'}
+ for k, v in classifier_layer_name_mapping.items():
+ new_state_dict['classifier.' + str(k) + '.weight'] = \
+ state_dict[v + '.weight']
+ new_state_dict['classifier.' + str(k) + '.bias'] = \
+ state_dict[v + '.bias']
+
+ network.load_state_dict(new_state_dict)
+
+ class Flatten(nn.Module):
+ def forward(self, x):
+ return x.view(x.shape[0], -1)
+
+ layer_name_mapping = {
+ 0: 'conv_1_1',
+ 1: 'relu_1_1',
+ 2: 'conv_1_2',
+ 5: 'conv_2_1', # 1/2
+ 6: 'relu_2_1',
+ 7: 'conv_2_2',
+ 10: 'conv_3_1', # 1/4
+ 11: 'relu_3_1',
+ 12: 'conv_3_2',
+ 14: 'conv_3_3',
+ 17: 'conv_4_1', # 1/8
+ 18: 'relu_4_1',
+ 19: 'conv_4_2',
+ 21: 'conv_4_3',
+ 24: 'conv_5_1', # 1/16
+ 25: 'relu_5_1',
+ 26: 'conv_5_2',
+ 28: 'conv_5_3',
+ 33: 'fc6',
+ 36: 'fc7',
+ 39: 'fc8'
+ }
+ seq_layers = []
+ for feature in network.features:
+ seq_layers += [feature]
+ seq_layers += [network.avgpool, Flatten()]
+ for classifier in network.classifier:
+ seq_layers += [classifier]
+ network = nn.Sequential(*seq_layers)
+ return _PerceptualNetwork(network, layer_name_mapping, layers)
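+
+
+# Usage sketch (illustrative):
+#   p_loss = PerceptualLoss(network='vgg19', layers=['relu_3_1', 'relu_4_1'],
+#                           weights=[0.5, 1.0], criterion='l1').cuda()
+#   loss = p_loss(fake_rgb, real_rgb)  # B x 3 x H x W tensors
+# ImageNet normalisation is applied internally, and ``num_scales > 1`` also
+# evaluates the loss on progressively downsampled copies of the inputs.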
diff --git a/imaginaire/losses/weighted_mse.py b/imaginaire/losses/weighted_mse.py
new file mode 100644
index 0000000000000000000000000000000000000000..b4e49989a5c3ee8576dcf4dea8a98a16c1911cc9
--- /dev/null
+++ b/imaginaire/losses/weighted_mse.py
@@ -0,0 +1,28 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import torch
+import torch.nn as nn
+
+
+class WeightedMSELoss(nn.Module):
+ r"""Compute Weighted MSE loss"""
+ def __init__(self, reduction='mean'):
+ super(WeightedMSELoss, self).__init__()
+ self.reduction = reduction
+
+ def forward(self, input, target, weight):
+        r"""Return the weighted MSE loss.
+
+        Args:
+            input (tensor): Predicted values.
+            target (tensor): Target values.
+            weight (tensor): Per-element weights, broadcastable against
+                ``input - target``.
+        Returns:
+            (tensor): Loss value.
+        """
+ if self.reduction == 'mean':
+ loss = torch.mean(weight * (input - target) ** 2)
+ else:
+ loss = torch.sum(weight * (input - target) ** 2)
+ return loss
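+
+
+# Usage sketch (illustrative):
+#   crit = WeightedMSELoss()
+#   loss = crit(pred, target, weight)  # ``weight`` broadcasts against the error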
diff --git a/imaginaire/model_utils/__init__.py b/imaginaire/model_utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..13acefe2181136b1629ec31f9d122fb46bf26780
--- /dev/null
+++ b/imaginaire/model_utils/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
diff --git a/imaginaire/model_utils/__pycache__/__init__.cpython-38.pyc b/imaginaire/model_utils/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6c45aa8fec0f4180fe7efd93161a67e545e3fb71
Binary files /dev/null and b/imaginaire/model_utils/__pycache__/__init__.cpython-38.pyc differ
diff --git a/imaginaire/model_utils/__pycache__/fs_vid2vid.cpython-38.pyc b/imaginaire/model_utils/__pycache__/fs_vid2vid.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..aa52fd3e68a006d480e174ee47d5066eaf79b98f
Binary files /dev/null and b/imaginaire/model_utils/__pycache__/fs_vid2vid.cpython-38.pyc differ
diff --git a/imaginaire/model_utils/fs_vid2vid.py b/imaginaire/model_utils/fs_vid2vid.py
new file mode 100644
index 0000000000000000000000000000000000000000..b52faf73d3f37221e9d0f089d1f4a85d5378877d
--- /dev/null
+++ b/imaginaire/model_utils/fs_vid2vid.py
@@ -0,0 +1,865 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+"""Utils for the few shot vid2vid model."""
+import random
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+
+def resample(image, flow):
+ r"""Resamples an image using the provided flow.
+
+ Args:
+ image (NxCxHxW tensor) : Image to resample.
+ flow (Nx2xHxW tensor) : Optical flow to resample the image.
+ Returns:
+ output (NxCxHxW tensor) : Resampled image.
+ """
+ assert flow.shape[1] == 2
+ b, c, h, w = image.size()
+ grid = get_grid(b, (h, w))
+ flow = torch.cat([flow[:, 0:1, :, :] / ((w - 1.0) / 2.0),
+ flow[:, 1:2, :, :] / ((h - 1.0) / 2.0)], dim=1)
+ final_grid = (grid + flow).permute(0, 2, 3, 1)
+ try:
+ output = F.grid_sample(image, final_grid, mode='bilinear',
+ padding_mode='border', align_corners=True)
+ except Exception:
+ output = F.grid_sample(image, final_grid, mode='bilinear',
+ padding_mode='border')
+ return output
+
+
+def get_grid(batchsize, size, minval=-1.0, maxval=1.0):
+ r"""Get a grid ranging [-1, 1] of 2D/3D coordinates.
+
+ Args:
+ batchsize (int) : Batch size.
+ size (tuple) : (height, width) or (depth, height, width).
+ minval (float) : minimum value in returned grid.
+ maxval (float) : maximum value in returned grid.
+ Returns:
+ t_grid (4D tensor) : Grid of coordinates.
+ """
+ if len(size) == 2:
+ rows, cols = size
+ elif len(size) == 3:
+ deps, rows, cols = size
+ else:
+ raise ValueError('Dimension can only be 2 or 3.')
+ x = torch.linspace(minval, maxval, cols)
+ x = x.view(1, 1, 1, cols)
+ x = x.expand(batchsize, 1, rows, cols)
+
+ y = torch.linspace(minval, maxval, rows)
+ y = y.view(1, 1, rows, 1)
+ y = y.expand(batchsize, 1, rows, cols)
+
+ t_grid = torch.cat([x, y], dim=1)
+
+ if len(size) == 3:
+ z = torch.linspace(minval, maxval, deps)
+ z = z.view(1, 1, deps, 1, 1)
+ z = z.expand(batchsize, 1, deps, rows, cols)
+
+ t_grid = t_grid.unsqueeze(2).expand(batchsize, 2, deps, rows, cols)
+ t_grid = torch.cat([t_grid, z], dim=1)
+
+ t_grid.requires_grad = False
+ return t_grid.to('cuda')
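+
+
+# Usage sketch (illustrative): ``resample`` warps an image with a dense optical
+# flow given in pixels; the flow is rescaled to the [-1, 1] grid expected by
+# ``F.grid_sample`` using ``get_grid``, which places its output on the GPU:
+#   warped = resample(ref_image, flow)  # ref_image: NxCxHxW, flow: Nx2xHxW (CUDA)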
+
+
+def pick_image(images, idx):
+ r"""Pick the image among images according to idx.
+
+ Args:
+ images (B x N x C x H x W tensor or list of tensors) : N images.
+ idx (B tensor) : indices to select.
+ Returns:
+ image (B x C x H x W) : Selected images.
+ """
+ if type(images) == list:
+ return [pick_image(r, idx) for r in images]
+ if idx is None:
+ return images[:, 0]
+ elif type(idx) == int:
+ return images[:, idx]
+ idx = idx.long().view(-1, 1, 1, 1, 1)
+ image = images.gather(1, idx.expand_as(images)[:, 0:1])[:, 0]
+ return image
+
+
+def crop_face_from_data(cfg, is_inference, data):
+ r"""Crop the face regions in input data and resize to the target size.
+ This is for training face datasets.
+
+ Args:
+ cfg (obj): Data configuration.
+ is_inference (bool): Is doing inference or not.
+ data (dict): Input data.
+ Returns:
+ data (dict): Cropped data.
+ """
+ label = data['label'] if 'label' in data else None
+ image = data['images']
+ landmarks = data['landmarks-dlib68_xy']
+ ref_labels = data['few_shot_label'] if 'few_shot_label' in data else None
+ ref_images = data['few_shot_images']
+ ref_landmarks = data['few_shot_landmarks-dlib68_xy']
+ img_size = image.shape[-2:]
+ h, w = cfg.output_h_w.split(',')
+ h, w = int(h), int(w)
+
+    # When doing inference, we need to sync common attributes like the crop
+    # coordinates between different workers, so all workers crop the same region.
+ if 'common_attr' in data and 'crop_coords' in data['common_attr']:
+ # Has been computed before, reusing the previous one.
+ crop_coords, ref_crop_coords = data['common_attr']['crop_coords']
+ else:
+ # Is the first frame, need to compute the bbox.
+ ref_crop_coords, scale = get_face_bbox_for_data(
+ ref_landmarks[0], img_size, None, is_inference)
+ crop_coords, _ = get_face_bbox_for_data(
+ landmarks[0], img_size, scale, is_inference)
+
+ # Crop the images according to the bbox and resize them to target size.
+ label, image = crop_and_resize([label, image], crop_coords, (h, w))
+ ref_labels, ref_images = crop_and_resize([ref_labels, ref_images],
+ ref_crop_coords, (h, w))
+
+ data['images'], data['few_shot_images'] = image, ref_images
+ if label is not None:
+ data['label'], data['few_shot_label'] = label, ref_labels
+ if is_inference:
+ if 'common_attr' not in data:
+ data['common_attr'] = dict()
+ data['common_attr']['crop_coords'] = crop_coords, ref_crop_coords
+ return data
+
+
+def get_face_bbox_for_data(keypoints, orig_img_size, scale, is_inference):
+ r"""Get the bbox coordinates for face region.
+
+ Args:
+ keypoints (Nx2 tensor): Facial landmarks.
+ orig_img_size (int tuple): Height and width of the input image size.
+ scale (float): When training, randomly scale the crop size for
+ augmentation.
+ is_inference (bool): Is doing inference or not.
+ Returns:
+ crop_coords (list of int): bbox for face region.
+ scale (float): Also returns scale to ensure reference and target frames
+            are cropped using the same scale.
+ """
+ min_y, max_y = int(keypoints[:, 1].min()), int(keypoints[:, 1].max())
+ min_x, max_x = int(keypoints[:, 0].min()), int(keypoints[:, 0].max())
+ x_cen, y_cen = (min_x + max_x) // 2, (min_y + max_y) // 2
+ H, W = orig_img_size
+ w = h = (max_x - min_x)
+ if not is_inference:
+ # During training, randomly jitter the cropping position by offset
+ # amount for augmentation.
+ offset_max = 0.2
+ offset = [np.random.uniform(-offset_max, offset_max),
+ np.random.uniform(-offset_max, offset_max)]
+ # Also augment the crop size.
+ if scale is None:
+ scale_max = 0.2
+ scale = [np.random.uniform(1 - scale_max, 1 + scale_max),
+ np.random.uniform(1 - scale_max, 1 + scale_max)]
+ w *= scale[0]
+ h *= scale[1]
+ x_cen += int(offset[0] * w)
+ y_cen += int(offset[1] * h)
+
+ # Get the cropping coordinates.
+ x_cen = max(w, min(W - w, x_cen))
+ y_cen = max(h * 1.25, min(H - h * 0.75, y_cen))
+
+ min_x = x_cen - w
+ min_y = y_cen - h * 1.25
+ max_x = min_x + w * 2
+ max_y = min_y + h * 2
+
+ crop_coords = [min_y, max_y, min_x, max_x]
+ return [int(x) for x in crop_coords], scale
+
+
+def crop_person_from_data(cfg, is_inference, data):
+ r"""Crop the person regions in data and resize to the target size.
+ This is for training full body datasets.
+
+ Args:
+ cfg (obj): Data configuration.
+ is_inference (bool): Is doing inference or not.
+ data (dict): Input data.
+ Returns:
+ data (dict): Cropped data.
+ """
+ label = data['label']
+ image = data['images']
+ use_few_shot = 'few_shot_label' in data
+ if use_few_shot:
+ ref_labels = data['few_shot_label']
+ ref_images = data['few_shot_images']
+
+ img_size = image.shape[-2:]
+ output_h, output_w = cfg.output_h_w.split(',')
+ output_h, output_w = int(output_h), int(output_w)
+ output_aspect_ratio = output_w / output_h
+
+ if 'human_instance_maps' in data:
+ # Remove other people in the DensePose map except for the current
+ # target.
+ label = remove_other_ppl(label, data['human_instance_maps'])
+ if use_few_shot:
+ ref_labels = remove_other_ppl(ref_labels,
+ data['few_shot_human_instance_maps'])
+
+ # Randomly jitter the crop position by offset amount for augmentation.
+ offset = ref_offset = None
+ if not is_inference:
+ offset = np.random.randn(2) * 0.05
+ offset = np.minimum(1, np.maximum(-1, offset))
+ ref_offset = np.random.randn(2) * 0.02
+ ref_offset = np.minimum(1, np.maximum(-1, ref_offset))
+
+ # Randomly scale the crop size for augmentation.
+ # Final cropped size = person height * scale.
+ scale = ref_scale = 1.5
+ if not is_inference:
+ scale = min(2, max(1, scale + np.random.randn() * 0.05))
+ ref_scale = min(2, max(1, ref_scale + np.random.randn() * 0.02))
+
+    # When doing inference, we need to sync common attributes like the crop
+    # coordinates between different workers, so all workers crop the same region.
+ if 'common_attr' in data:
+ # Has been computed before, reusing the previous one.
+ crop_coords, ref_crop_coords = data['common_attr']['crop_coords']
+ else:
+ # Is the first frame, need to compute the bbox.
+ crop_coords = get_person_bbox_for_data(label, img_size, scale,
+ output_aspect_ratio, offset)
+ if use_few_shot:
+ ref_crop_coords = get_person_bbox_for_data(
+ ref_labels, img_size, ref_scale,
+ output_aspect_ratio, ref_offset)
+ else:
+ ref_crop_coords = None
+
+ # Crop the images according to the bbox and resize them to target size.
+ label = crop_and_resize(label, crop_coords, (output_h, output_w), 'nearest')
+ image = crop_and_resize(image, crop_coords, (output_h, output_w))
+ if use_few_shot:
+ ref_labels = crop_and_resize(ref_labels, ref_crop_coords,
+ (output_h, output_w), 'nearest')
+ ref_images = crop_and_resize(ref_images, ref_crop_coords,
+ (output_h, output_w))
+
+ data['label'], data['images'] = label, image
+ if use_few_shot:
+ data['few_shot_label'], data['few_shot_images'] = ref_labels, ref_images
+ if 'human_instance_maps' in data:
+ del data['human_instance_maps']
+ if 'few_shot_human_instance_maps' in data:
+ del data['few_shot_human_instance_maps']
+ if is_inference:
+ data['common_attr'] = dict()
+ data['common_attr']['crop_coords'] = crop_coords, ref_crop_coords
+
+ return data
+
+
+def get_person_bbox_for_data(pose_map, orig_img_size, scale=1.5,
+ crop_aspect_ratio=1, offset=None):
+ r"""Get the bbox (pixel coordinates) to crop for person body region.
+
+ Args:
+ pose_map (NxCxHxW tensor): Input pose map.
+ orig_img_size (int tuple): Height and width of the input image size.
+ scale (float): When training, randomly scale the crop size for
+ augmentation.
+ crop_aspect_ratio (float): Output aspect ratio,
+ offset (list of float): Offset for crop position.
+ Returns:
+ crop_coords (list of int): bbox for body region.
+ """
+ H, W = orig_img_size
+ assert pose_map.dim() == 4
+ nonzero_indices = (pose_map[:, :3] > 0).nonzero(as_tuple=False)
+ if nonzero_indices.size(0) == 0:
+ bw = int(H * crop_aspect_ratio // 2)
+ return [0, H, W // 2 - bw, W // 2 + bw]
+
+ y_indices, x_indices = nonzero_indices[:, 2], nonzero_indices[:, 3]
+ y_min, y_max = y_indices.min().item(), y_indices.max().item()
+ x_min, x_max = x_indices.min().item(), x_indices.max().item()
+ y_cen = int(y_min + y_max) // 2
+ x_cen = int(x_min + x_max) // 2
+ y_len = y_max - y_min
+ x_len = x_max - x_min
+
+ # bh, bw: half of height / width of final cropped size.
+ bh = int(min(H, max(H // 2, y_len * scale))) // 2
+ bh = max(bh, int(x_len * scale / crop_aspect_ratio) // 2)
+ bw = int(bh * crop_aspect_ratio)
+
+ # Randomly offset the cropped position for augmentation.
+ if offset is not None:
+ x_cen += int(offset[0] * bw)
+ y_cen += int(offset[1] * bh)
+ x_cen = max(bw, min(W - bw, x_cen))
+ y_cen = max(bh, min(H - bh, y_cen))
+
+ return [(y_cen - bh), (y_cen + bh), (x_cen - bw), (x_cen + bw)]
+
+
+def crop_and_resize(img, coords, size=None, method='bilinear'):
+ r"""Crop the image using the given coordinates and resize to target size.
+
+ Args:
+ img (tensor or list of tensors): Input image.
+ coords (list of int): Pixel coordinates to crop.
+ size (list of int): Output size.
+ method (str): Interpolation method.
+ Returns:
+ img (tensor or list of tensors): Output image.
+ """
+ if isinstance(img, list):
+ return [crop_and_resize(x, coords, size, method) for x in img]
+ if img is None:
+ return None
+ min_y, max_y, min_x, max_x = coords
+
+ img = img[:, :, min_y:max_y, min_x:max_x]
+ if size is not None:
+ if method == 'nearest':
+ img = F.interpolate(img, size=size, mode=method)
+ else:
+ img = F.interpolate(img, size=size, mode=method,
+ align_corners=False)
+ return img
+
+
+def remove_other_ppl(labels, densemasks):
+ r"""Remove other people in the label map except for the current target
+ by looking at the id in the densemask map.
+
+ Args:
+ labels (NxCxHxW tensor): Input labels.
+ densemasks (Nx1xHxW tensor): Densemask maps.
+ Returns:
+ labels (NxCxHxW tensor): Output labels.
+ """
+ densemasks = densemasks[:, 0:1] * 255
+ for idx in range(labels.shape[0]):
+ label, densemask = labels[idx], densemasks[idx]
+ # Get OpenPose and find the person id in Densemask that has the most
+ # overlap with the person in OpenPose result.
+ openpose = label[3:]
+ valid = (openpose[0] > 0) | (openpose[1] > 0) | (openpose[2] > 0)
+ dp_valid = densemask[valid.unsqueeze(0)]
+ if dp_valid.shape[0]:
+ ind = np.bincount(dp_valid).argmax()
+ # Remove all other people that have different indices.
+ label = label * (densemask == ind).float()
+ labels[idx] = label
+ return labels
+
+
+def select_object(data, obj_indices=None):
+ r"""Select the object/person in the dict according to the object index.
+ Currently it's used to select the target person in OpenPose dict.
+
+ Args:
+ data (dict): Input data.
+ obj_indices (list of int): Indices for the objects to select.
+ Returns:
+ data (dict): Output data.
+ """
+ op_keys = ['poses-openpose', 'captions-clip']
+ for op_key in op_keys:
+ if op_key in data:
+ for i in range(len(data[op_key])):
+ # data[op_key] is a list of dicts for different frames.
+ # people = data[op_key][i]['people']
+ people = data[op_key][i]
+ # "people" is a list of people dicts found by OpenPose. We will
+ # use the obj_index to get the target person from the list, and
+ # write it back to the dict.
+ # data[op_key][i]['people'] = [people[obj_indices[i]]]
+ if obj_indices is not None:
+ data[op_key][i] = people[obj_indices[i]]
+ else:
+ if op_key == 'poses-openpose':
+ data[op_key][i] = people[0]
+ else:
+ idx = random.randint(0, len(people) - 1)
+ data[op_key][i] = people[idx]
+ return data
+
+
+def concat_frames(prev, now, n_frames):
+ r"""Concat previous and current frames and only keep the latest $(n_frames).
+ If concatenated frames are longer than $(n_frames), drop the oldest one.
+
+ Args:
+ prev (NxTxCxHxW tensor): Tensor for previous frames.
+ now (NxCxHxW tensor): Tensor for current frame.
+ n_frames (int): Max number of frames to store.
+ Returns:
+ result (NxTxCxHxW tensor): Updated tensor.
+ """
+ now = now.unsqueeze(1)
+ if prev is None:
+ return now
+ if prev.shape[1] == n_frames:
+ prev = prev[:, 1:]
+ return torch.cat([prev, now], dim=1)
+
+
+def combine_fg_mask(fg_mask, ref_fg_mask, has_fg):
+ r"""Get the union of target and reference foreground masks.
+ Args:
+ fg_mask (tensor): Foreground mask for target image.
+ ref_fg_mask (tensor): Foreground mask for reference image.
+ has_fg (bool): Whether the image can be classified into fg/bg.
+ Returns:
+ output (tensor or int): Combined foreground mask.
+ """
+ return ((fg_mask > 0) | (ref_fg_mask > 0)).float() if has_fg else 1
+
+
+def get_fg_mask(densepose_map, has_fg):
+ r"""Obtain the foreground mask for pose sequences, which only includes
+ the human. This is done by looking at the body part map from DensePose.
+
+ Args:
+ densepose_map (NxCxHxW tensor): DensePose map.
+ has_fg (bool): Whether data has foreground or not.
+ Returns:
+ mask (Nx1xHxW tensor): fg mask.
+ """
+ if type(densepose_map) == list:
+ return [get_fg_mask(label, has_fg) for label in densepose_map]
+ if not has_fg or densepose_map is None:
+ return 1
+ if densepose_map.dim() == 5:
+ densepose_map = densepose_map[:, 0]
+ # Get the body part map from DensePose.
+ mask = densepose_map[:, 2:3]
+
+ # Make the mask slightly larger.
+ mask = torch.nn.MaxPool2d(15, padding=7, stride=1)(mask)
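+ # The part channel is in [-1, 1] with -1 as background, so after the max-pool dilation any pixel near a body pixel exceeds -1.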
+ mask = (mask > -1).float()
+ return mask
+
+
+def get_part_mask(densepose_map):
+ r"""Obtain mask of different body parts of humans. This is done by
+ looking at the body part map from DensePose.
+
+ Args:
+ densepose_map (NxCxHxW tensor): DensePose map.
+ Returns:
+ mask (NxKxHxW tensor): Body part mask, where K is the number of parts.
+ """
+ # Groups of body parts. Each group contains IDs of body part labels in
+ # DensePose. The 9 groups here are: background, torso, hands, feet,
+ # upper legs, lower legs, upper arms, lower arms, head.
+ part_groups = [[0], [1, 2], [3, 4], [5, 6], [7, 9, 8, 10], [11, 13, 12, 14],
+ [15, 17, 16, 18], [19, 21, 20, 22], [23, 24]]
+ n_parts = len(part_groups)
+
+ need_reshape = densepose_map.dim() == 4
+ if need_reshape:
+ bo, t, h, w = densepose_map.size()
+ densepose_map = densepose_map.view(-1, h, w)
+ b, h, w = densepose_map.size()
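+ # The DensePose channel is stored in [-1, 1]; map it back to part IDs in [0, 24].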
+ part_map = (densepose_map / 2 + 0.5) * 24
+ assert (part_map >= 0).all() and (part_map < 25).all()
+
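+ # The DensePose map is assumed to live on the GPU here (cf. the CPU/GPU check in get_face_mask below).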
+ mask = torch.cuda.ByteTensor(b, n_parts, h, w).fill_(0)
+ for i in range(n_parts):
+ for j in part_groups[i]:
+ # Account for numerical errors.
+ mask[:, i] = mask[:, i] | (
+ (part_map > j - 0.1) & (part_map < j + 0.1)).byte()
+ if need_reshape:
+ mask = mask.view(bo, t, -1, h, w)
+ return mask.float()
+
+
+def get_face_mask(densepose_map):
+ r"""Obtain mask of faces.
+ Args:
+ densepose_map (3D or 4D tensor): DensePose map.
+ Returns:
+ mask (3D or 4D tensor): Face mask.
+ """
+ need_reshape = densepose_map.dim() == 4
+ if need_reshape:
+ bo, t, h, w = densepose_map.size()
+ densepose_map = densepose_map.view(-1, h, w)
+
+ b, h, w = densepose_map.size()
+ part_map = (densepose_map / 2 + 0.5) * 24
+ assert (part_map >= 0).all() and (part_map < 25).all()
+ if densepose_map.is_cuda:
+ mask = torch.cuda.ByteTensor(b, h, w).fill_(0)
+ else:
+ mask = torch.ByteTensor(b, h, w).fill_(0)
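+ # DensePose part IDs 23 and 24 correspond to the head/face region.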
+ for j in [23, 24]:
+ mask = mask | ((part_map > j - 0.1) & (part_map < j + 0.1)).byte()
+ if need_reshape:
+ mask = mask.view(bo, t, h, w)
+ return mask.float()
+
+
+def extract_valid_pose_labels(pose_map, pose_type, remove_face_labels,
+ do_remove=True):
+ r"""Remove some labels (e.g. face regions) in the pose map if necessary.
+
+ Args:
+ pose_map (3D, 4D or 5D tensor): Input pose map.
+ pose_type (str): 'both' or 'open'.
+ remove_face_labels (bool): Whether to remove labels for the face region.
+ do_remove (bool): Whether to actually remove the face labels.
+ Returns:
+ pose_map (3D, 4D or 5D tensor): Output pose map.
+ """
+ if pose_map is None:
+ return pose_map
+ if type(pose_map) == list:
+ return [extract_valid_pose_labels(p, pose_type, remove_face_labels,
+ do_remove) for p in pose_map]
+
+ orig_dim = pose_map.dim()
+ assert (orig_dim >= 3 and orig_dim <= 5)
+ if orig_dim == 3:
+ pose_map = pose_map.unsqueeze(0).unsqueeze(0)
+ elif orig_dim == 4:
+ pose_map = pose_map.unsqueeze(0)
+
+ if pose_type == 'open':
+ # If input is only openpose, remove densepose part.
+ pose_map = pose_map[:, :, 3:]
+
+ elif remove_face_labels and do_remove:
+ # Remove face part for densepose input.
+ densepose, openpose = pose_map[:, :, :3], pose_map[:, :, 3:]
+ face_mask = get_face_mask(pose_map[:, :, 2]).unsqueeze(2)
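+ # Zero out the face region in the DensePose channels and set it to -1, the background value.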
+ pose_map = torch.cat([densepose * (1 - face_mask) - face_mask,
+ openpose], dim=2)
+
+ if orig_dim == 3:
+ pose_map = pose_map[0, 0]
+ elif orig_dim == 4:
+ pose_map = pose_map[0]
+ return pose_map
+
+
+def normalize_faces(keypoints, ref_keypoints,
+ dist_scale_x=None, dist_scale_y=None):
+ r"""Normalize face keypoints w.r.t. the reference face keypoints.
+
+ Args:
+ keypoints (Kx2 numpy array): target facial keypoints.
+ ref_keypoints (Kx2 numpy array): reference facial keypoints.
+ Returns:
+ keypoints (Kx2 numpy array): normalized facial keypoints.
+ """
+ if keypoints.shape[0] == 68:
+ central_keypoints = [8]
+ add_upper_face = False
+ part_list = [[0, 16], [1, 15], [2, 14], [3, 13], [4, 12],
+ [5, 11], [6, 10], [7, 9, 8],
+ [17, 26], [18, 25], [19, 24], [20, 23], [21, 22],
+ [27], [28], [29], [30], [31, 35], [32, 34], [33],
+ [36, 45], [37, 44], [38, 43], [39, 42], [40, 47], [41, 46],
+ [48, 54], [49, 53], [50, 52], [51], [55, 59], [56, 58],
+ [57],
+ [60, 64], [61, 63], [62], [65, 67], [66]
+ ]
+ if add_upper_face:
+ part_list += [[68, 82], [69, 81], [70, 80], [71, 79], [72, 78],
+ [73, 77], [74, 76, 75]]
+ elif keypoints.shape[0] == 126:
+ central_keypoints = [16]
+ part_list = [[i] for i in range(126)]
+ else:
+ raise ValueError('Input keypoints type not supported.')
+
+ face_cen = np.mean(keypoints[central_keypoints, :], axis=0)
+ ref_face_cen = np.mean(ref_keypoints[central_keypoints, :], axis=0)
+
+ def get_mean_dists(pts, face_cen):
+ r"""Get the mean xy distances of keypoints wrt face center."""
+ mean_dists_x, mean_dists_y = [], []
+ pts_cen = np.mean(pts, axis=0)
+ for p, pt in enumerate(pts):
+ mean_dists_x.append(np.linalg.norm(pt - pts_cen))
+ mean_dists_y.append(np.linalg.norm(pts_cen - face_cen))
+ mean_dist_x = sum(mean_dists_x) / len(mean_dists_x) + 1e-3
+ mean_dist_y = sum(mean_dists_y) / len(mean_dists_y) + 1e-3
+ return mean_dist_x, mean_dist_y
+
+ if dist_scale_x is None:
+ dist_scale_x, dist_scale_y = [None] * len(part_list), \
+ [None] * len(part_list)
+
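+ # For each facial part, rescale its spread (dist_scale_x) and its distance to the face center (dist_scale_y) to match the reference face.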
+ for i, pts_idx in enumerate(part_list):
+ pts = keypoints[pts_idx]
+ if dist_scale_x[i] is None:
+ ref_pts = ref_keypoints[pts_idx]
+ mean_dist_x, mean_dist_y = get_mean_dists(pts, face_cen)
+ ref_dist_x, ref_dist_y = get_mean_dists(ref_pts, ref_face_cen)
+
+ dist_scale_x[i] = ref_dist_x / mean_dist_x
+ dist_scale_y[i] = ref_dist_y / mean_dist_y
+
+ pts_cen = np.mean(pts, axis=0)
+ pts = (pts - pts_cen) * dist_scale_x[i] + \
+ (pts_cen - face_cen) * dist_scale_y[i] + face_cen
+ keypoints[pts_idx] = pts
+ return keypoints, [dist_scale_x, dist_scale_y]
+
+
+def crop_face_from_output(data_cfg, image, input_label, crop_smaller=0):
+ r"""Crop out the face region of the image (and resize if necessary to feed
+ into generator/discriminator).
+
+ Args:
+ data_cfg (obj): Data configuration.
+ image (NxC1xHxW tensor or list of tensors): Image to crop.
+ input_label (NxC2xHxW tensor): Input label map.
+ crop_smaller (int): Number of pixels to crop slightly smaller region.
+ Returns:
+ output (NxC1xHxW tensor or list of tensors): Cropped image.
+ """
+ if type(image) == list:
+ return [crop_face_from_output(data_cfg, im, input_label, crop_smaller)
+ for im in image]
+
+ output = None
+ face_size = image.shape[-2] // 32 * 8
+ for i in range(input_label.size(0)):
+ ys, ye, xs, xe = get_face_bbox_for_output(data_cfg,
+ input_label[i:i + 1],
+ crop_smaller=crop_smaller)
+ output_i = F.interpolate(image[i:i + 1, -3:, ys:ye, xs:xe],
+ size=(face_size, face_size), mode='bilinear',
+ align_corners=True)
+ # output_i = image[i:i + 1, -3:, ys:ye, xs:xe]
+ output = torch.cat([output, output_i]) if i != 0 else output_i
+ return output
+
+
+def get_face_bbox_for_output(data_cfg, pose, crop_smaller=0):
+ r"""Get pixel coordinates of the face bounding box.
+
+ Args:
+ data_cfg (obj): Data configuration.
+ pose (NxCxHxW tensor): Pose label map.
+ crop_smaller (int): Number of pixels to crop slightly smaller region.
+ Returns:
+ output (list of int): Face bbox.
+ """
+ if pose.dim() == 3:
+ pose = pose.unsqueeze(0)
+ elif pose.dim() == 5:
+ pose = pose[-1, -1:]
+ _, _, h, w = pose.size()
+
+ use_openpose = 'pose_maps-densepose' not in data_cfg.input_labels
+ if use_openpose: # Use openpose face keypoints to identify face region.
+ for input_type in data_cfg.input_types:
+ if 'poses-openpose' in input_type:
+ num_ch = input_type['poses-openpose'].num_channels
+ if num_ch > 3:
+ face = (pose[:, -1] > 0).nonzero(as_tuple=False)
+ else:
+ raise ValueError('Not implemented yet.')
+ else: # Use densepose labels.
+ face = (pose[:, 2] > 0.9).nonzero(as_tuple=False)
+
+ ylen = xlen = h // 32 * 8
+ if face.size(0):
+ y, x = face[:, 1], face[:, 2]
+ ys, ye = y.min().item(), y.max().item()
+ xs, xe = x.min().item(), x.max().item()
+ if use_openpose:
+ xc, yc = (xs + xe) // 2, (ys * 3 + ye * 2) // 5
+ ylen = int((xe - xs) * 2.5)
+ else:
+ xc, yc = (xs + xe) // 2, (ys + ye) // 2
+ ylen = int((ye - ys) * 1.25)
+ ylen = xlen = min(w, max(32, ylen))
+ yc = max(ylen // 2, min(h - 1 - ylen // 2, yc))
+ xc = max(xlen // 2, min(w - 1 - xlen // 2, xc))
+ else:
+ yc = h // 4
+ xc = w // 2
+
+ ys, ye = yc - ylen // 2, yc + ylen // 2
+ xs, xe = xc - xlen // 2, xc + xlen // 2
+ if crop_smaller != 0: # Crop slightly smaller region inside face.
+ ys += crop_smaller
+ xs += crop_smaller
+ ye -= crop_smaller
+ xe -= crop_smaller
+ return [ys, ye, xs, xe]
+
+
+def crop_hand_from_output(data_cfg, image, input_label):
+ r"""Crop out the hand region of the image.
+
+ Args:
+ data_cfg (obj): Data configuration.
+ image (NxC1xHxW tensor or list of tensors): Image to crop.
+ input_label (NxC2xHxW tensor): Input label map.
+ Returns:
+ output (NxC1xHxW tensor or list of tensors): Cropped image.
+ """
+ if type(image) == list:
+ return [crop_hand_from_output(data_cfg, im, input_label)
+ for im in image]
+
+ output = None
+ for i in range(input_label.size(0)):
+ coords = get_hand_bbox_for_output(data_cfg, input_label[i:i + 1])
+ if coords:
+ for coord in coords:
+ ys, ye, xs, xe = coord
+ output_i = image[i:i + 1, -3:, ys:ye, xs:xe]
+ output = torch.cat([output, output_i]) \
+ if output is not None else output_i
+ return output
+
+
+def get_hand_bbox_for_output(data_cfg, pose):
+ r"""Get coordinates of the hand bounding box.
+
+ Args:
+ data_cfg (obj): Data configuration.
+ pose (NxCxHxW tensor): Pose label map.
+ Returns:
+ output (list of int): Hand bbox.
+ """
+ if pose.dim() == 3:
+ pose = pose.unsqueeze(0)
+ elif pose.dim() == 5:
+ pose = pose[-1, -1:]
+ _, _, h, w = pose.size()
+ ylen = xlen = h // 64 * 8
+
+ coords = []
+ colors = [[0.95, 0.5, 0.95], [0.95, 0.95, 0.5]]
+ for i, color in enumerate(colors):
+ if pose.shape[1] > 6: # Using one-hot encoding for openpose.
+ idx = -3 if i == 0 else -2
+ hand = (pose[:, idx] == 1).nonzero(as_tuple=False)
+ else:
+ raise ValueError('Not implemented yet.')
+ if hand.size(0):
+ y, x = hand[:, 1], hand[:, 2]
+ ys, ye, xs, xe = y.min().item(), y.max().item(), \
+ x.min().item(), x.max().item()
+ xc, yc = (xs + xe) // 2, (ys + ye) // 2
+ yc = max(ylen // 2, min(h - 1 - ylen // 2, yc))
+ xc = max(xlen // 2, min(w - 1 - xlen // 2, xc))
+ ys, ye, xs, xe = yc - ylen // 2, yc + ylen // 2, \
+ xc - xlen // 2, xc + xlen // 2
+ coords.append([ys, ye, xs, xe])
+ return coords
+
+
+def pre_process_densepose(pose_cfg, pose_map, is_infer=False):
+ r"""Pre-process the DensePose part of input label map.
+
+ Args:
+ pose_cfg (obj): Pose data configuration.
+ pose_map (NxCxHxW tensor): Pose label map.
+ is_infer (bool): Is doing inference.
+ Returns:
+ pose_map (NxCxHxW tensor): Processed pose label map.
+ """
+ part_map = pose_map[:, :, 2] * 255 # should be within [0-24]
+ assert (part_map >= 0).all() and (part_map < 25).all()
+
+ # Randomly drop some body part during training.
+ if not is_infer:
+ random_drop_prob = getattr(pose_cfg, 'random_drop_prob', 0)
+ else:
+ random_drop_prob = 0
+ if random_drop_prob > 0:
+ densepose_map = pose_map[:, :, :3]
+ for part_id in range(1, 25):
+ if (random.random() < random_drop_prob):
+ part_mask = abs(part_map - part_id) < 0.1
+ densepose_map[part_mask.unsqueeze(2).expand_as(
+ densepose_map)] = 0
+ pose_map[:, :, :3] = densepose_map
+
+ # Renormalize the DensePose channel from [0, 24] to [0, 255].
+ pose_map[:, :, 2] = pose_map[:, :, 2] * (255 / 24)
+ # Normalize from [0, 1] to [-1, 1].
+ pose_map = pose_map * 2 - 1
+ return pose_map
+
+
+def random_roll(tensors):
+ r"""Randomly roll the input tensors along x and y dimensions. Also randomly
+ flip the tensors.
+
+ Args:
+ tensors (list of 4D tensors): Input tensors.
+ Returns:
+ output (list of 4D tensors): Rolled tensors.
+ """
+ h, w = tensors[0].shape[2:]
+ ny = np.random.choice([np.random.randint(h//16),
+ h-np.random.randint(h//16)])
+ nx = np.random.choice([np.random.randint(w//16),
+ w-np.random.randint(w//16)])
+ flip = np.random.rand() > 0.5
+ return [roll(t, ny, nx, flip) for t in tensors]
+
+
+def roll(t, ny, nx, flip=False):
+ r"""Roll and flip the tensor by specified amounts.
+
+ Args:
+ t (4D tensor): Input tensor.
+ ny (int): Amount to roll along y dimension.
+ nx (int): Amount to roll along x dimension.
+ flip (bool): Whether to flip input.
+ Returns:
+ t (4D tensor): Output tensor.
+ """
+ t = torch.cat([t[:, :, -ny:], t[:, :, :-ny]], dim=2)
+ t = torch.cat([t[:, :, :, -nx:], t[:, :, :, :-nx]], dim=3)
+ if flip:
+ t = torch.flip(t, dims=[3])
+ return t
+
+
+def detach(output):
+ r"""Detach tensors in the dict.
+
+ Args:
+ output (dict): Output dict.
+ Returns:
+ output (dict): Detached output dict.
+ """
+ if type(output) == dict:
+ new_dict = dict()
+ for k, v in output.items():
+ new_dict[k] = detach(v)
+ return new_dict
+ elif type(output) == torch.Tensor:
+ return output.detach()
+ return output
diff --git a/imaginaire/model_utils/gancraft/camctl.py b/imaginaire/model_utils/gancraft/camctl.py
new file mode 100644
index 0000000000000000000000000000000000000000..26e4ab674b7f6b44d484d06c661267ed7ce69d56
--- /dev/null
+++ b/imaginaire/model_utils/gancraft/camctl.py
@@ -0,0 +1,640 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import numpy as np
+import torch
+
+
+class EvalCameraController:
+ def __init__(self, voxel, maxstep=128, pattern=0, cam_ang=73, smooth_decay_multiplier=1.0):
+ self.voxel = voxel
+ self.maxstep = maxstep
+ self.camera_poses = [] # ori, dir, up, f
+ circle = torch.linspace(0, 2*np.pi, steps=maxstep)
+ size = min(voxel.voxel_t.size(1), voxel.voxel_t.size(2)) / 2
+ # Shrink the circle a bit.
+ shift = size * 0.2
+ size = size * 0.8
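+ # Each pattern appends (cam_ori, cam_dir, cam_up, cam_f) tuples: a far point and a near point are placed around the scene, the camera looks from one toward the other, and in most patterns the camera height is smoothed with filtfilt().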
+
+ if pattern == 0:
+ height_history = []
+ # Calculate smooth height.
+ for i in range(maxstep):
+ farpoint = torch.tensor([
+ 70,
+ torch.sin(circle[i])*size + voxel.voxel_t.size(1)/2 + shift,
+ torch.cos(circle[i])*size + voxel.voxel_t.size(2)/2 + shift])
+ height_history.append(self._get_height(farpoint[1], farpoint[2], farpoint[0]))
+
+ # Filtfilt
+ height_history = self.filtfilt(height_history, decay=0.2*smooth_decay_multiplier)
+
+ for i in range(maxstep):
+ farpoint = torch.tensor([
+ 70,
+ torch.sin(circle[i])*size + voxel.voxel_t.size(1)/2 + shift,
+ torch.cos(circle[i])*size + voxel.voxel_t.size(2)/2 + shift])
+
+ farpoint[0] = height_history[i]
+
+ nearpoint = torch.tensor([
+ 60,
+ torch.sin(circle[i]+0.5*np.pi)*size*0.5 + voxel.voxel_t.size(1)/2 + shift,
+ torch.cos(circle[i]+0.5*np.pi)*size*0.5 + voxel.voxel_t.size(2)/2 + shift])
+ cam_ori = self.voxel.world2local(farpoint)
+ cam_dir = self.voxel.world2local(nearpoint - farpoint, is_vec=True)
+ cam_up = self.voxel.world2local(torch.tensor([1, 0, 0], dtype=torch.float32), is_vec=True)
+ cam_f = 0.5/np.tan(np.deg2rad(cam_ang/2)) # about 24mm fov
+
+ self.camera_poses.append((cam_ori, cam_dir, cam_up, cam_f))
+
+ elif pattern == 1:
+ zoom = torch.linspace(1.0, 0.25, steps=maxstep)
+ height_history = []
+ for i in range(maxstep):
+ farpoint = torch.tensor([
+ 90,
+ torch.sin(circle[i])*size + voxel.voxel_t.size(1)/2 + shift,
+ torch.cos(circle[i])*size + voxel.voxel_t.size(2)/2 + shift])
+
+ height_history.append(self._get_height(farpoint[1], farpoint[2], farpoint[0]))
+
+ height_history = self.filtfilt(height_history, decay=0.2*smooth_decay_multiplier)
+
+ for i in range(maxstep):
+ farpoint = torch.tensor([
+ 90,
+ torch.sin(circle[i])*size + voxel.voxel_t.size(1)/2 + shift,
+ torch.cos(circle[i])*size + voxel.voxel_t.size(2)/2 + shift])
+
+ farpoint[0] = height_history[i]
+
+ nearpoint = torch.tensor([
+ 60,
+ torch.sin(circle[i]-0.3*np.pi)*size*0.3 + voxel.voxel_t.size(1)/2 + shift,
+ torch.cos(circle[i]-0.3*np.pi)*size*0.3 + voxel.voxel_t.size(2)/2 + shift])
+ cam_ori = self.voxel.world2local(farpoint)
+ cam_dir = self.voxel.world2local(nearpoint - farpoint, is_vec=True)
+ cam_up = self.voxel.world2local(torch.tensor([1, 0, 0], dtype=torch.float32), is_vec=True)
+ cam_f = 0.5/np.tan(np.deg2rad(cam_ang/2)*zoom[i]) # about 24mm fov
+
+ self.camera_poses.append((cam_ori, cam_dir, cam_up, cam_f))
+
+ elif pattern == 2:
+ move = torch.linspace(1.0, 0.2, steps=maxstep)
+ height_history = []
+ for i in range(maxstep):
+ farpoint = torch.tensor([
+ 90,
+ torch.sin(circle[i])*size*move[i] + voxel.voxel_t.size(1)/2 + shift,
+ torch.cos(circle[i])*size*move[i] + voxel.voxel_t.size(2)/2 + shift])
+
+ height_history.append(self._get_height(farpoint[1], farpoint[2], farpoint[0]))
+
+ height_history = self.filtfilt(height_history, decay=0.2*smooth_decay_multiplier)
+
+ for i in range(maxstep):
+ farpoint = torch.tensor([
+ 90,
+ torch.sin(circle[i])*size*move[i] + voxel.voxel_t.size(1)/2 + shift,
+ torch.cos(circle[i])*size*move[i] + voxel.voxel_t.size(2)/2 + shift])
+
+ farpoint[0] = height_history[i]
+
+ nearpoint = torch.tensor([
+ 60,
+ torch.sin(circle[i]+0.5*np.pi)*size*0.3*move[i] + voxel.voxel_t.size(1)/2 + shift,
+ torch.cos(circle[i]+0.5*np.pi)*size*0.3*move[i] + voxel.voxel_t.size(2)/2 + shift])
+ cam_ori = self.voxel.world2local(farpoint)
+ cam_dir = self.voxel.world2local(nearpoint - farpoint, is_vec=True)
+ cam_up = self.voxel.world2local(torch.tensor([1, 0, 0], dtype=torch.float32), is_vec=True)
+ cam_f = 0.5/np.tan(np.deg2rad(cam_ang/2)) # about 24mm fov
+
+ self.camera_poses.append((cam_ori, cam_dir, cam_up, cam_f))
+
+ elif pattern == 3:
+ move = torch.linspace(0.75, 0.2, steps=maxstep)
+ height_history = []
+ for i in range(maxstep):
+ farpoint = torch.tensor([
+ 70,
+ torch.sin(-circle[i])*size*move[i] + voxel.voxel_t.size(1)/2 + shift,
+ torch.cos(-circle[i])*size*move[i] + voxel.voxel_t.size(2)/2 + shift])
+
+ height_history.append(self._get_height(farpoint[1], farpoint[2], farpoint[0]))
+
+ height_history = self.filtfilt(height_history, decay=0.2*smooth_decay_multiplier)
+
+ for i in range(maxstep):
+ farpoint = torch.tensor([
+ 70,
+ torch.sin(-circle[i])*size*move[i] + voxel.voxel_t.size(1)/2 + shift,
+ torch.cos(-circle[i])*size*move[i] + voxel.voxel_t.size(2)/2 + shift])
+
+ farpoint[0] = height_history[i]
+
+ nearpoint = torch.tensor([
+ 60,
+ torch.sin(-circle[i]-0.4*np.pi)*size*0.9*move[i] + voxel.voxel_t.size(1)/2 + shift,
+ torch.cos(-circle[i]-0.4*np.pi)*size*0.9*move[i] + voxel.voxel_t.size(2)/2 + shift])
+ cam_ori = self.voxel.world2local(farpoint)
+ cam_dir = self.voxel.world2local(nearpoint - farpoint, is_vec=True)
+ cam_up = self.voxel.world2local(torch.tensor([1, 0, 0], dtype=torch.float32), is_vec=True)
+ cam_f = 0.5/np.tan(np.deg2rad(cam_ang/2)) # about 24mm fov
+
+ self.camera_poses.append((cam_ori, cam_dir, cam_up, cam_f))
+
+ elif pattern == 4:
+ move = torch.linspace(1.0, 0.5, steps=maxstep)
+ height_history = []
+ for i in range(maxstep):
+ farpoint = torch.tensor([
+ 90,
+ torch.sin(circle[i])*size*move[i] + voxel.voxel_t.size(1)/2 + shift,
+ torch.cos(circle[i])*size*move[i] + voxel.voxel_t.size(2)/2 + shift])
+
+ height_history.append(self._get_height(farpoint[1], farpoint[2], farpoint[0]))
+
+ height_history = self.filtfilt(height_history, decay=0.2*smooth_decay_multiplier)
+
+ for i in range(maxstep):
+ farpoint = torch.tensor([
+ 90,
+ torch.sin(circle[i])*size*move[i] + voxel.voxel_t.size(1)/2 + shift,
+ torch.cos(circle[i])*size*move[i] + voxel.voxel_t.size(2)/2 + shift])
+
+ farpoint[0] = height_history[i]
+
+ nearpoint = torch.tensor([
+ 60,
+ torch.sin(circle[i]+0.5*np.pi)*size*0.3*move[i] + voxel.voxel_t.size(1)/2 + shift,
+ torch.cos(circle[i]+0.5*np.pi)*size*0.3*move[i] + voxel.voxel_t.size(2)/2 + shift])
+ cam_ori = self.voxel.world2local(farpoint)
+ cam_dir = self.voxel.world2local(nearpoint - farpoint, is_vec=True)
+ cam_up = self.voxel.world2local(torch.tensor([1, 0, 0], dtype=torch.float32), is_vec=True)
+ cam_f = 0.5/np.tan(np.deg2rad(cam_ang/2)) # about 24mm fov
+
+ self.camera_poses.append((cam_ori, cam_dir, cam_up, cam_f))
+
+ # look outward
+ elif pattern == 5:
+ move = torch.linspace(1.0, 0.5, steps=maxstep)
+ height_history = []
+ for i in range(maxstep):
+ nearpoint = torch.tensor([
+ 60,
+ torch.sin(circle[i]+0.5*np.pi)*size*0.3*move[i] + voxel.voxel_t.size(1)/2 + shift,
+ torch.cos(circle[i]+0.5*np.pi)*size*0.3*move[i] + voxel.voxel_t.size(2)/2 + shift])
+
+ height_history.append(self._get_height(nearpoint[1], nearpoint[2], nearpoint[0]))
+
+ height_history = self.filtfilt(height_history, decay=0.2*smooth_decay_multiplier)
+
+ for i in range(maxstep):
+ nearpoint = torch.tensor([
+ 60,
+ torch.sin(circle[i]+0.5*np.pi)*size*0.3*move[i] + voxel.voxel_t.size(1)/2 + shift,
+ torch.cos(circle[i]+0.5*np.pi)*size*0.3*move[i] + voxel.voxel_t.size(2)/2 + shift])
+
+ nearpoint[0] = height_history[i]
+
+ farpoint = torch.tensor([
+ 60,
+ torch.sin(circle[i])*size*move[i] + voxel.voxel_t.size(1)/2 + shift,
+ torch.cos(circle[i])*size*move[i] + voxel.voxel_t.size(2)/2 + shift])
+
+ cam_ori = self.voxel.world2local(nearpoint)
+ cam_dir = self.voxel.world2local(farpoint - nearpoint, is_vec=True)
+ cam_up = self.voxel.world2local(torch.tensor([1, 0, 0], dtype=torch.float32), is_vec=True)
+ cam_f = 0.5/np.tan(np.deg2rad(cam_ang/2)) # about 24mm fov
+
+ self.camera_poses.append((cam_ori, cam_dir, cam_up, cam_f))
+ # Rise
+ elif pattern == 6:
+ shift = 0
+ lift = torch.linspace(0.0, 200.0, steps=maxstep)
+ zoom = torch.linspace(0.8, 1.6, steps=maxstep)
+ for i in range(maxstep):
+ farpoint = torch.tensor([
+ 80+lift[i],
+ torch.sin(circle[i]/4)*size*0.2 + voxel.voxel_t.size(1)/2 + shift,
+ torch.cos(circle[i]/4)*size*0.2 + voxel.voxel_t.size(2)/2 + shift])
+
+ farpoint[0] = self._get_height(farpoint[1], farpoint[2], farpoint[0])
+
+ nearpoint = torch.tensor([
+ 65,
+ torch.sin(circle[i]/4+0.5*np.pi)*size*0.1 + voxel.voxel_t.size(1)/2 + shift,
+ torch.cos(circle[i]/4+0.5*np.pi)*size*0.1 + voxel.voxel_t.size(2)/2 + shift])
+ cam_ori = self.voxel.world2local(farpoint)
+ cam_dir = self.voxel.world2local(nearpoint - farpoint, is_vec=True)
+ cam_up = self.voxel.world2local(torch.tensor([1, 0, 0], dtype=torch.float32), is_vec=True)
+ cam_f = 0.5/np.tan(np.deg2rad(73/2)*zoom[i]) # about 24mm fov
+
+ self.camera_poses.append((cam_ori, cam_dir, cam_up, cam_f))
+ # 45deg
+ elif pattern == 7:
+ rad = torch.tensor([np.deg2rad(45).astype(np.float32)])
+ size = 1536
+ for i in range(maxstep):
+ farpoint = torch.tensor([
+ 61+size,
+ torch.sin(rad)*size + voxel.voxel_t.size(1)/2,
+ torch.cos(rad)*size + voxel.voxel_t.size(2)/2])
+
+ nearpoint = torch.tensor([
+ 61,
+ voxel.voxel_t.size(1)/2,
+ voxel.voxel_t.size(2)/2])
+ cam_ori = self.voxel.world2local(farpoint)
+ cam_dir = self.voxel.world2local(nearpoint - farpoint, is_vec=True)
+ cam_up = self.voxel.world2local(torch.tensor([1, 0, 0], dtype=torch.float32), is_vec=True)
+ cam_f = 0.5/np.tan(np.deg2rad(19.5/2)) # about 50mm fov
+
+ self.camera_poses.append((cam_ori, cam_dir, cam_up, cam_f))
+
+ def _get_height(self, loc0, loc1, minheight):
+ loc0 = int(loc0)
+ loc1 = int(loc1)
+ height = minheight
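+ # Take the maximum terrain height (plus a clearance of 2) over a 7x7 neighborhood, never going below minheight.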
+ for dx in range(-3, 4):
+ for dy in range(-3, 4):
+ if (loc0+dx) < 0 or (loc0+dx) >= self.voxel.heightmap.shape[0] or (loc1+dy) < 0 or \
+ (loc1+dy) >= self.voxel.heightmap.shape[1]:
+ height = max(height, minheight)
+ else:
+ height = max(height, self.voxel.heightmap[loc0+dx, loc1+dy] + 2)
+ return height
+
+ def filtfilt(self, height_history, decay=0.2):
+ # Filtfilt
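+ # Two passes (forward, then backward): in each direction the value may drop
+ # by at most `decay` per step but is clamped from below by the raw height;
+ # the result is the element-wise max of both passes, i.e. a smoothed upper envelope.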
+ height_history2 = []
+ maxstep = len(height_history)
+ prev_height = height_history[0]
+ for i in range(maxstep):
+ prev_height = prev_height - decay
+ if prev_height < height_history[i]:
+ prev_height = height_history[i]
+ height_history2.append(prev_height)
+ prev_height = height_history[-1]
+ for i in range(maxstep-1, -1, -1):
+ prev_height = prev_height - decay
+ if prev_height < height_history[i]:
+ prev_height = height_history[i]
+ height_history2[i] = max(prev_height, height_history2[i])
+ return height_history2
+
+ def __len__(self):
+ return len(self.camera_poses)
+
+ def __getitem__(self, idx):
+ return self.camera_poses[idx]
+
+
+class TourCameraController:
+ def __init__(self, voxel, maxstep=128):
+ self.voxel = voxel
+ self.maxstep = maxstep
+ self.camera_poses = [] # ori, dir, up, f
+ circle = torch.linspace(0, 2*np.pi, steps=maxstep//4)
+ size = min(voxel.voxel_t.size(1), voxel.voxel_t.size(2)) / 2
+ # Shrink the circle a bit
+ shift = size * 0.2
+ size = size * 0.8
+
+ for i in range(maxstep//4):
+ farpoint = torch.tensor([
+ 70,
+ torch.sin(circle[i])*size + voxel.voxel_t.size(1)/2 + shift,
+ torch.cos(circle[i])*size + voxel.voxel_t.size(2)/2 + shift])
+
+ farpoint[0] = self._get_height(farpoint[1], farpoint[2], farpoint[0])
+
+ nearpoint = torch.tensor([
+ 60,
+ torch.sin(circle[i]+0.5*np.pi)*size*0.5 + voxel.voxel_t.size(1)/2 + shift,
+ torch.cos(circle[i]+0.5*np.pi)*size*0.5 + voxel.voxel_t.size(2)/2 + shift])
+ cam_ori = self.voxel.world2local(farpoint)
+ cam_dir = self.voxel.world2local(nearpoint - farpoint, is_vec=True)
+ cam_up = self.voxel.world2local(torch.tensor([1, 0, 0], dtype=torch.float32), is_vec=True)
+ cam_f = 0.5/np.tan(np.deg2rad(73/2)) # about 24mm fov
+
+ self.camera_poses.append((cam_ori, cam_dir, cam_up, cam_f))
+
+ zoom = torch.linspace(1.0, 0.25, steps=maxstep//4)
+ for i in range(maxstep//4):
+ farpoint = torch.tensor([
+ 90,
+ torch.sin(circle[i])*size + voxel.voxel_t.size(1)/2 + shift,
+ torch.cos(circle[i])*size + voxel.voxel_t.size(2)/2 + shift])
+
+ farpoint[0] = self._get_height(farpoint[1], farpoint[2], farpoint[0])
+
+ nearpoint = torch.tensor([
+ 60,
+ torch.sin(circle[i]-0.3*np.pi)*size*0.3 + voxel.voxel_t.size(1)/2 + shift,
+ torch.cos(circle[i]-0.3*np.pi)*size*0.3 + voxel.voxel_t.size(2)/2 + shift])
+ cam_ori = self.voxel.world2local(farpoint)
+ cam_dir = self.voxel.world2local(nearpoint - farpoint, is_vec=True)
+ cam_up = self.voxel.world2local(torch.tensor([1, 0, 0], dtype=torch.float32), is_vec=True)
+ cam_f = 0.5/np.tan(np.deg2rad(73/2)*zoom[i]) # about 24mm fov
+
+ self.camera_poses.append((cam_ori, cam_dir, cam_up, cam_f))
+
+ move = torch.linspace(1.0, 0.2, steps=maxstep//4)
+ for i in range(maxstep//4):
+ farpoint = torch.tensor([
+ 90,
+ torch.sin(circle[i])*size*move[i] + voxel.voxel_t.size(1)/2 + shift,
+ torch.cos(circle[i])*size*move[i] + voxel.voxel_t.size(2)/2 + shift])
+
+ farpoint[0] = self._get_height(farpoint[1], farpoint[2], farpoint[0])
+
+ nearpoint = torch.tensor([
+ 60,
+ torch.sin(circle[i]+0.5*np.pi)*size*0.3*move[i] + voxel.voxel_t.size(1)/2 + shift,
+ torch.cos(circle[i]+0.5*np.pi)*size*0.3*move[i] + voxel.voxel_t.size(2)/2 + shift])
+ cam_ori = self.voxel.world2local(farpoint)
+ cam_dir = self.voxel.world2local(nearpoint - farpoint, is_vec=True)
+ cam_up = self.voxel.world2local(torch.tensor([1, 0, 0], dtype=torch.float32), is_vec=True)
+ cam_f = 0.5/np.tan(np.deg2rad(73/2)) # about 24mm fov
+
+ self.camera_poses.append((cam_ori, cam_dir, cam_up, cam_f))
+
+ lift = torch.linspace(0.0, 200.0, steps=maxstep//4)
+ zoom = torch.linspace(0.6, 1.2, steps=maxstep//4)
+ for i in range(maxstep//4):
+ farpoint = torch.tensor([
+ 80+lift[i],
+ torch.sin(circle[i])*size*0.2 + voxel.voxel_t.size(1)/2 + shift,
+ torch.cos(circle[i])*size*0.2 + voxel.voxel_t.size(2)/2 + shift])
+
+ farpoint[0] = self._get_height(farpoint[1], farpoint[2], farpoint[0])
+
+ nearpoint = torch.tensor([
+ 60,
+ torch.sin(circle[i]+0.5*np.pi)*size*0.1 + voxel.voxel_t.size(1)/2 + shift,
+ torch.cos(circle[i]+0.5*np.pi)*size*0.1 + voxel.voxel_t.size(2)/2 + shift])
+ cam_ori = self.voxel.world2local(farpoint)
+ cam_dir = self.voxel.world2local(nearpoint - farpoint, is_vec=True)
+ cam_up = self.voxel.world2local(torch.tensor([1, 0, 0], dtype=torch.float32), is_vec=True)
+ cam_f = 0.5/np.tan(np.deg2rad(73/2)*zoom[i]) # about 24mm fov
+
+ self.camera_poses.append((cam_ori, cam_dir, cam_up, cam_f))
+
+ def _get_height(self, loc0, loc1, minheight):
+ loc0 = int(loc0)
+ loc1 = int(loc1)
+ height = minheight
+ for dx in range(-3, 4):
+ for dy in range(-3, 4):
+ if (loc0+dx) < 0 or (loc0+dx) >= self.voxel.heightmap.shape[0] or (loc1+dy) < 0 or \
+ (loc1+dy) >= self.voxel.heightmap.shape[1]:
+ height = max(height, minheight)
+ else:
+ height = max(height, self.voxel.heightmap[loc0+dx, loc1+dy] + 2)
+ return height
+
+ def __len__(self):
+ return len(self.camera_poses)
+
+ def __getitem__(self, idx):
+ return self.camera_poses[idx]
+
+
+def rand_camera_pose_birdseye(voxel, border=128):
+ r"""Generate a random camera pose in the upper hemisphere, in origin-direction-up format.
+ Assumes a [Y X Z] coordinate system where Y is the negative gravity direction.
+ The camera pose is converted into the voxel coordinate system so that it can be used directly for rendering.
+ 1. Uniformly sample a point on the upper hemisphere of a unit sphere as cam_ori.
+ 2. Set cam_dir to point from cam_ori to the origin.
+ 3. cam_up always points towards the sky.
+ 4. Move cam_ori to a random place according to the voxel size.
+ """
+ cam_dir = torch.randn(3, dtype=torch.float32)
+ cam_dir = cam_dir / torch.sqrt(torch.sum(cam_dir*cam_dir))
+ cam_dir[0] = -torch.abs(cam_dir[0])
+ cam_up = torch.tensor([1, 0, 0], dtype=torch.float32)
+
+ # generate camera lookat target
+ r = np.random.rand(2)
+ r[0] *= voxel.voxel_t.size(1)-border-border
+ r[1] *= voxel.voxel_t.size(2)-border-border
+ r = r + border
+ y = voxel.heightmap[int(r[0]+0.5), int(r[1]+0.5)] + (np.random.rand(1)-0.5) * 5
+ cam_target = torch.tensor([y, r[0], r[1]], dtype=torch.float32)
+ cam_ori = cam_target - cam_dir * (np.random.rand(1).item() * 100)
+ cam_ori[0] = max(voxel.heightmap[int(cam_ori[1]+0.5), int(cam_ori[2]+0.5)]+2, cam_ori[0])
+ # Translate to voxel coordinate
+ cam_ori = voxel.world2local(cam_ori)
+ cam_dir = voxel.world2local(cam_dir, is_vec=True)
+ cam_up = voxel.world2local(cam_up, is_vec=True)
+
+ return cam_ori, cam_dir, cam_up
+
+
+def get_neighbor_height(heightmap, loc0, loc1, minheight, neighbor_size=7):
+ loc0 = int(loc0)
+ loc1 = int(loc1)
+ height = 0
+ for dx in range(-neighbor_size//2, neighbor_size//2+1):
+ for dy in range(-neighbor_size//2, neighbor_size//2+1):
+ if (loc0+dx) < 0 or (loc0+dx) >= heightmap.shape[0] or (loc1+dy) < 0 or (loc1+dy) >= heightmap.shape[1]:
+ height = max(height, minheight)
+ else:
+ height = max(minheight, heightmap[loc0+dx, loc1+dy] + 2)
+ return height
+
+
+def rand_camera_pose_firstperson(voxel, border=128):
+ r"""Generate a random first-person camera pose, in origin-direction-up format.
+ """
+ r = np.random.rand(5)
+ r[0] *= voxel.voxel_t.size(1)-border-border
+ r[1] *= voxel.voxel_t.size(2)-border-border
+ r[0] = r[0] + border
+ r[1] = r[1] + border
+
+ y = get_neighbor_height(voxel.heightmap, r[0], r[1], 0) + np.random.rand(1) * 15
+
+ cam_ori = torch.tensor([y, r[0], r[1]], dtype=torch.float32)
+
+ rand_ang_h = r[2] * 2 * np.pi
+ cam_target = torch.tensor([0, cam_ori[1]+np.sin(rand_ang_h)*border*r[4], cam_ori[2] +
+ np.cos(rand_ang_h)*border*r[4]], dtype=torch.float32)
+ cam_target[0] = get_neighbor_height(voxel.heightmap, cam_target[1],
+ cam_target[2], 0, neighbor_size=1) - 2 + r[3] * 10
+
+ cam_dir = cam_target - cam_ori
+
+ cam_up = torch.tensor([1, 0, 0], dtype=torch.float32)
+
+ cam_ori = voxel.world2local(cam_ori)
+ cam_dir = voxel.world2local(cam_dir, is_vec=True)
+ cam_up = voxel.world2local(cam_up, is_vec=True)
+
+ return cam_ori, cam_dir, cam_up
+
+
+def rand_camera_pose_thridperson(voxel, border=96):
+ r = torch.rand(2)
+ r[0] *= voxel.voxel_t.size(1)
+ r[1] *= voxel.voxel_t.size(2)
+ rand_height = 60 + torch.rand(1) * 40
+ rand_height = get_neighbor_height(voxel.heightmap, r[0], r[1], rand_height, neighbor_size=5)
+ farpoint = torch.tensor([rand_height, r[0], r[1]], dtype=torch.float32)
+
+ r = torch.rand(2)
+ r[0] *= voxel.voxel_t.size(1) - border - border
+ r[1] *= voxel.voxel_t.size(2) - border - border
+ r[0] = r[0] + border
+ r[1] = r[1] + border
+ rand_height = get_neighbor_height(voxel.heightmap, r[0], r[1], 65, neighbor_size=1) - 5
+ nearpoint = torch.tensor([rand_height, r[0], r[1]], dtype=torch.float32)
+
+ cam_ori = voxel.world2local(farpoint)
+ cam_dir = voxel.world2local(nearpoint - farpoint, is_vec=True)
+ cam_up = voxel.world2local(torch.tensor([1, 0, 0], dtype=torch.float32), is_vec=True)
+
+ return cam_ori, cam_dir, cam_up
+
+
+def rand_camera_pose_thridperson2(voxel, border=48):
+ r = torch.rand(2)
+ r[0] *= voxel.voxel_t.size(1) - border - border
+ r[1] *= voxel.voxel_t.size(2) - border - border
+ r[0] = r[0] + border
+ r[1] = r[1] + border
+ rand_height = 60 + torch.rand(1) * 40
+ rand_height = get_neighbor_height(voxel.heightmap, r[0], r[1], rand_height, neighbor_size=5)
+ farpoint = torch.tensor([rand_height, r[0], r[1]], dtype=torch.float32)
+
+ r = torch.rand(2)
+ r[0] *= voxel.voxel_t.size(1) - border - border
+ r[1] *= voxel.voxel_t.size(2) - border - border
+ r[0] = r[0] + border
+ r[1] = r[1] + border
+ rand_height = get_neighbor_height(voxel.heightmap, r[0], r[1], 65, neighbor_size=1) - 5
+ nearpoint = torch.tensor([rand_height, r[0], r[1]], dtype=torch.float32)
+
+ # Random Up vector (tilt a little bit)
+ # up = torch.randn(3) * 0.05 # cutoff +-0.1, Tan(10deg) = 0.176
+ up = torch.randn(3) * 0.02
+ up[0] = 1.0
+ up = up / up.norm(p=2)
+ cam_ori = voxel.world2local(farpoint)
+ cam_dir = voxel.world2local(nearpoint - farpoint, is_vec=True)
+ cam_up = voxel.world2local(up, is_vec=True)
+
+ return cam_ori, cam_dir, cam_up
+
+
+def rand_camera_pose_thridperson3(voxel, border=64):
+ r"""Variant that tries to keep the camera from getting too close to walls and adds more aerial poses."""
+ r = torch.rand(2)
+ r[0] *= voxel.voxel_t.size(1) - border - border
+ r[1] *= voxel.voxel_t.size(2) - border - border
+ r[0] = r[0] + border
+ r[1] = r[1] + border
+ rand_height = 60 + torch.rand(1) * 40
+ if torch.rand(1) > 0.8:
+ rand_height = 60 + torch.rand(1) * 60
+ rand_height = get_neighbor_height(voxel.heightmap, r[0], r[1], rand_height, neighbor_size=7)
+ farpoint = torch.tensor([rand_height, r[0], r[1]], dtype=torch.float32)
+
+ r = torch.rand(2)
+ r[0] *= voxel.voxel_t.size(1) - border - border
+ r[1] *= voxel.voxel_t.size(2) - border - border
+ r[0] = r[0] + border
+ r[1] = r[1] + border
+ rand_height = get_neighbor_height(voxel.heightmap, r[0], r[1], 65, neighbor_size=3) - 5
+ nearpoint = torch.tensor([rand_height, r[0], r[1]], dtype=torch.float32)
+
+ # Random Up vector (tilt a little bit)
+ # up = torch.randn(3) * 0.05 # cutoff +-0.1, Tan(10deg) = 0.176
+ up = torch.randn(3) * 0.02
+ up[0] = 1.0
+ up = up / up.norm(p=2)
+ # print(up)
+ cam_ori = voxel.world2local(farpoint)
+ cam_dir = voxel.world2local(nearpoint - farpoint, is_vec=True)
+ cam_up = voxel.world2local(up, is_vec=True)
+
+ return cam_ori, cam_dir, cam_up
+
+
+def rand_camera_pose_tour(voxel):
+ size = min(voxel.voxel_t.size(1), voxel.voxel_t.size(2)) / 2
+ center = [voxel.voxel_t.size(1)/2, voxel.voxel_t.size(2)/2]
+
+ rnd = torch.rand(8)
+
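+ # Place a far point on a circle around the scene center and a near look-at point closer in; the camera sits at the far point and looks toward the near point.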
+ rnd_deg = torch.rand(1) * 2 * np.pi
+ far_radius = rnd[0]*0.8+0.2
+ far_height = rnd[1]*30 + 60
+ farpoint = torch.tensor([
+ far_height,
+ torch.sin(rnd_deg)*size*far_radius + center[0],
+ torch.cos(rnd_deg)*size*far_radius + center[1]])
+
+ farpoint[0] = get_neighbor_height(voxel.heightmap, farpoint[1], farpoint[2], farpoint[0], neighbor_size=7)
+
+ near_radius = far_radius * rnd[2]
+ near_shift_rad = np.pi*(rnd[3]-0.5)
+ near_height = 60 + rnd[4] * 10
+ nearpoint = torch.tensor([
+ near_height,
+ torch.sin(rnd_deg+near_shift_rad)*size*near_radius + center[0],
+ torch.cos(rnd_deg+near_shift_rad)*size*near_radius + center[1]])
+
+ # Random Up vector (tilt a little bit)
+ # up = torch.randn(3) * 0.05 # cutoff +-0.1, Tan(10deg) = 0.176
+ up = torch.randn(3) * 0.02
+ up[0] = 1.0
+ up = up / up.norm(p=2)
+ cam_ori = voxel.world2local(farpoint)
+ cam_dir = voxel.world2local(nearpoint - farpoint, is_vec=True)
+ cam_up = voxel.world2local(up, is_vec=True)
+ cam_f = 0.5/np.tan(np.deg2rad(73/2)*(rnd[5]*0.75+0.25)) # about 24mm fov
+
+ return cam_ori, cam_dir, cam_up, cam_f
+
+# Look from center to outward
+
+
+def rand_camera_pose_insideout(voxel):
+ size = min(voxel.voxel_t.size(1), voxel.voxel_t.size(2)) / 2
+ center = [voxel.voxel_t.size(1)/2, voxel.voxel_t.size(2)/2]
+
+ rnd = torch.rand(8)
+
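+ # Inside-out variant: the camera sits at the near point and looks outward toward the far point on the circle.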
+ rnd_deg = torch.rand(1) * 2 * np.pi
+ far_radius = rnd[0]*0.8+0.2
+ far_height = rnd[1]*10 + 60
+ farpoint = torch.tensor([
+ far_height,
+ torch.sin(rnd_deg)*size*far_radius + center[0],
+ torch.cos(rnd_deg)*size*far_radius + center[1]])
+
+ near_radius = far_radius * rnd[2]
+ near_shift_rad = np.pi*(rnd[3]-0.5)
+ near_height = 60 + rnd[4] * 30
+ nearpoint = torch.tensor([
+ near_height,
+ torch.sin(rnd_deg+near_shift_rad)*size*near_radius + center[0],
+ torch.cos(rnd_deg+near_shift_rad)*size*near_radius + center[1]])
+
+ nearpoint[0] = get_neighbor_height(voxel.heightmap, nearpoint[1], nearpoint[2], nearpoint[0], neighbor_size=7)
+
+ # Random Up vector (tilt a little bit)
+ # up = torch.randn(3) * 0.05 # cutoff +-0.1, Tan(10deg) = 0.176
+ up = torch.randn(3) * 0.02
+ up[0] = 1.0
+ up = up / up.norm(p=2)
+ cam_ori = voxel.world2local(nearpoint)
+ cam_dir = voxel.world2local(farpoint-nearpoint, is_vec=True)
+ cam_up = voxel.world2local(up, is_vec=True)
+ cam_f = 0.5/np.tan(np.deg2rad(73/2)*(rnd[5]*0.75+0.25)) # about 24mm fov
+
+ return cam_ori, cam_dir, cam_up, cam_f
diff --git a/imaginaire/model_utils/gancraft/gaugan_lbl2col.csv b/imaginaire/model_utils/gancraft/gaugan_lbl2col.csv
new file mode 100644
index 0000000000000000000000000000000000000000..ba061b7a5bda98899c8b1f653d5f204fccfa38e4
--- /dev/null
+++ b/imaginaire/model_utils/gancraft/gaugan_lbl2col.csv
@@ -0,0 +1,182 @@
+person,#00AC0D
+bicycle,#012F47
+car,#0275B8
+motorcycle,#03C098
+airplane,#04434F
+bus,#05FB29
+train,#06C312
+truck,#076728
+boat,#0809B6
+traffic-light,#09D3CF
+fire-hydrant,#0A150B
+street-sign,#0BF2A6
+stop-sign,#0C246F
+parking-meter,#0D575D
+bench,#0E46F9
+bird,#0FD881
+cat,#1058DF
+dog,#118C76
+horse,#123A2C
+sheep,#13C1D8
+cow,#14E67D
+elephant,#152718
+bear,#165743
+zebra,#17AED2
+giraffe,#1858EF
+hat,#195103
+backpack,#1AA5EA
+umbrella,#1B19CC
+shoe,#1C4DE6
+eye-glasses,#1D4823
+handbag,#1E09D6
+tie,#1F94FE
+suitcase,#2073BD
+frisbee,#21D0C5
+skis,#22F3D7
+snowboard,#23C52B
+sports-ball,#24FE20
+kite,#254F0B
+baseball-bat,#26AF68
+baseball-glove,#27C0D4
+skateboard,#28528A
+surfboard,#2963B6
+tennis-racket,#2AD8EB
+bottle,#2BB1A5
+plate,#2CF37D
+wine-glass,#2D1D9C
+cup,#2E936F
+fork,#2F93E8
+knife,#308E02
+spoon,#31A71B
+bowl,#3220D3
+banana,#33C1D9
+apple,#340997
+sandwich,#35B935
+orange,#367F33
+broccoli,#3720AE
+carrot,#381F94
+hot-dog,#39CAB5
+pizza,#3AF41D
+donut,#3B9743
+cake,#3CA323
+chair,#3DFE27
+couch,#3ECB89
+potted-plant,#3F7249
+bed,#40B729
+mirror,#411C97
+dining-table,#422283
+window,#43802E
+desk,#4480DA
+toilet,#45A4B2
+door,#46356C
+tv,#478503
+laptop,#48261F
+mouse,#49E809
+remote,#4AF48A
+keyboard,#4B111B
+cell-phone,#4C4FAD
+microwave,#4D84C7
+oven,#4E69A7
+toaster,#4F2A3D
+sink,#50BA55
+refrigerator,#511F61
+blender,#52782C
+book,#530122
+clock,#5441A2
+vase,#55E758
+scissors,#56A921
+teddy-bear,#573985
+hair-drier,#5823E8
+toothbrush,#5966FF
+hair-brush,#5A7724
+banner,#5B0B00
+blanket,#5CAECB
+branch,#5D5222
+bridge,#5E5BC5
+building-other,#5F807E
+bush,#606E32
+cabinet,#6163FE
+cage,#623550
+cardboard,#638CBE
+carpet,#647988
+ceiling-other,#65AABD
+ceiling-tile,#665481
+cloth,#67CBD1
+clothes,#684470
+clouds,#696969
+counter,#6AC478
+cupboard,#6B2F5B
+curtain,#6C7FA8
+desk-stuff,#6DF474
+dirt,#6E6E28
+door-stuff,#6FCCB0
+fence,#706419
+floor-marble,#71B443
+floor-other,#72E867
+floor-stone,#734EFC
+floor-tile,#748F23
+floor-wood,#759472
+flower,#760000
+fog,#77BA1D
+food-other,#7817F1
+fruit,#79CF21
+furniture-other,#7A8D92
+grass,#7BC800
+gravel,#7C32C8
+ground-other,#7D3054
+hill,#7EC864
+house,#7F4502
+leaves,#80A945
+light,#81A365
+mat,#82C08C
+metal,#835F2C
+mirror-stuff,#84C575
+moss,#855EFD
+mountain,#869664
+mud,#87716F
+napkin,#88B25B
+net,#892455
+paper,#8AA2A7
+pavement,#8B3027
+pillow,#8C5DCB
+plant,#8DE61E
+plastic,#8E629E
+platform,#8F2A91
+playingfield,#90CDC6
+railing,#9170C7
+railroad,#92E712
+river,#9364C8
+road,#946E28
+rock,#956432
+roof,#9600B1
+rug,#978A29
+salad,#98725D
+sand,#999900
+sea,#9AC6DA
+shelf,#9B7FC9
+sky,#9CEEDD
+skyscraper,#9DBBF2
+snow,#9E9EAA
+solid-other,#9F79DB
+stairs,#A06249
+stone,#A1A164
+straw,#A2A3EB
+structural,#A3DED1
+table,#A47B69
+tent,#A5C3BA
+textile-other,#A65280
+towel,#A7AED6
+tree,#A8C832
+vegetable,#A99410
+wall-brick,#AAD16A
+wall-concrete,#AB32A4
+wall-other,#AC9B5E
+wall-panel,#AD0E18
+wall-stone,#AE2974
+wall-tile,#AF3ABF
+wall-wood,#B0C1C3
+water,#B1C8FF
+waterdrops,#B20A88
+window-blind,#B356B8
+window-other,#B42B5B
+wood,#B57B00
diff --git a/imaginaire/model_utils/gancraft/gaugan_reduction.csv b/imaginaire/model_utils/gancraft/gaugan_reduction.csv
new file mode 100644
index 0000000000000000000000000000000000000000..a49ad38c2f8bddd04a12a08a09b1bbdab5944fc2
--- /dev/null
+++ b/imaginaire/model_utils/gancraft/gaugan_reduction.csv
@@ -0,0 +1,182 @@
+person,ignore
+bicycle,ignore
+car,ignore
+motorcycle,ignore
+airplane,ignore
+bus,ignore
+train,ignore
+truck,ignore
+boat,ignore
+traffic-light,ignore
+fire-hydrant,ignore
+street-sign,ignore
+stop-sign,ignore
+parking-meter,ignore
+bench,ignore
+bird,ignore
+cat,ignore
+dog,ignore
+horse,ignore
+sheep,ignore
+cow,ignore
+elephant,ignore
+bear,ignore
+zebra,ignore
+giraffe,ignore
+hat,ignore
+backpack,ignore
+umbrella,ignore
+shoe,ignore
+eye-glasses,ignore
+handbag,ignore
+tie,ignore
+suitcase,ignore
+frisbee,ignore
+skis,ignore
+snowboard,ignore
+sports-ball,ignore
+kite,ignore
+baseball-bat,ignore
+baseball-glove,ignore
+skateboard,ignore
+surfboard,ignore
+tennis-racket,ignore
+bottle,ignore
+plate,ignore
+wine-glass,ignore
+cup,ignore
+fork,ignore
+knife,ignore
+spoon,ignore
+bowl,ignore
+banana,ignore
+apple,ignore
+sandwich,ignore
+orange,ignore
+broccoli,ignore
+carrot,ignore
+hot-dog,ignore
+pizza,ignore
+donut,ignore
+cake,ignore
+chair,ignore
+couch,ignore
+potted-plant,ignore
+bed,ignore
+mirror,ignore
+dining-table,ignore
+window,ignore
+desk,ignore
+toilet,ignore
+door,ignore
+tv,ignore
+laptop,ignore
+mouse,ignore
+remote,ignore
+keyboard,ignore
+cell-phone,ignore
+microwave,ignore
+oven,ignore
+toaster,ignore
+sink,ignore
+refrigerator,ignore
+blender,ignore
+book,ignore
+clock,ignore
+vase,ignore
+scissors,ignore
+teddy-bear,ignore
+hair-drier,ignore
+toothbrush,ignore
+hair-brush,ignore
+banner,ignore
+blanket,ignore
+branch,tree
+bridge,ignore
+building-other,ignore
+bush,tree
+cabinet,ignore
+cage,ignore
+cardboard,ignore
+carpet,ignore
+ceiling-other,ignore
+ceiling-tile,ignore
+cloth,ignore
+clothes,ignore
+clouds,sky
+counter,ignore
+cupboard,ignore
+curtain,ignore
+desk-stuff,ignore
+dirt,dirt
+door-stuff,ignore
+fence,ignore
+floor-marble,ignore
+floor-other,ignore
+floor-stone,ignore
+floor-tile,ignore
+floor-wood,ignore
+flower,flower
+fog,sky
+food-other,ignore
+fruit,ignore
+furniture-other,ignore
+grass,grass
+gravel,gravel
+ground-other,ignore
+hill,grass
+house,ignore
+leaves,tree
+light,ignore
+mat,ignore
+metal,ignore
+mirror-stuff,ignore
+moss,grass
+mountain,grass
+mud,dirt
+napkin,ignore
+net,ignore
+paper,ignore
+pavement,ignore
+pillow,ignore
+plant,flower
+plastic,ignore
+platform,ignore
+playingfield,ignore
+railing,ignore
+railroad,ignore
+river,water
+road,ignore
+rock,rock
+roof,ignore
+rug,ignore
+salad,ignore
+sand,sand
+sea,water
+shelf,ignore
+sky,sky
+skyscraper,ignore
+snow,snow
+solid-other,ignore
+stairs,ignore
+stone,stone
+straw,grass
+structural,ignore
+table,ignore
+tent,ignore
+textile-other,ignore
+towel,ignore
+tree,tree
+vegetable,ignore
+wall-brick,ignore
+wall-concrete,ignore
+wall-other,ignore
+wall-panel,ignore
+wall-stone,ignore
+wall-tile,ignore
+wall-wood,ignore
+water,water
+waterdrops,ignore
+window-blind,ignore
+window-other,ignore
+wood,ignore
diff --git a/imaginaire/model_utils/gancraft/id2name_gg.csv b/imaginaire/model_utils/gancraft/id2name_gg.csv
new file mode 100644
index 0000000000000000000000000000000000000000..bb52afe4132cdae36494c08dab6ac4982f572386
--- /dev/null
+++ b/imaginaire/model_utils/gancraft/id2name_gg.csv
@@ -0,0 +1,680 @@
+0,air,0,sky
+1,stone,7368816,stone
+2,granite,7368816,rock
+3,polished_granite,7368816,rock
+4,diorite,7368816,rock
+5,polished_diorite,7368816,rock
+6,andesite,7368816,rock
+7,polished_andesite,7368816,rock
+8,grass_block,8368696,grass
+9,dirt,9923917,dirt
+10,coarse_dirt,9923917,dirt
+11,podzol,9923917,dirt
+12,cobblestone,7368816,stone
+13,oak_planks,9402184,wood
+14,spruce_planks,9402184,wood
+15,birch_planks,9402184,wood
+16,jungle_planks,9402184,wood
+17,acacia_planks,9402184,wood
+18,dark_oak_planks,9402184,wood
+19,oak_sapling,31744,plant
+20,spruce_sapling,31744,plant
+21,birch_sapling,31744,plant
+22,jungle_sapling,31744,plant
+23,acacia_sapling,31744,plant
+24,dark_oak_sapling,31744,plant
+25,bedrock,7368816,rock
+26,water,4210943,water
+27,lava,16711680,
+28,sand,16247203,sand
+29,red_sand,16247203,sand
+30,gravel,16247203,gravel
+31,gold_ore,7368816,rock
+32,iron_ore,7368816,rock
+33,coal_ore,7368816,rock
+34,oak_log,9402184,tree
+35,spruce_log,9402184,tree
+36,birch_log,9402184,tree
+37,jungle_log,9402184,tree
+38,acacia_log,9402184,tree
+39,dark_oak_log,9402184,tree
+40,stripped_spruce_log,9402184,wood
+41,stripped_birch_log,9402184,wood
+42,stripped_jungle_log,9402184,wood
+43,stripped_acacia_log,9402184,wood
+44,stripped_dark_oak_log,9402184,wood
+45,stripped_oak_log,9402184,wood
+46,oak_wood,9402184,wood
+47,spruce_wood,9402184,wood
+48,birch_wood,9402184,wood
+49,jungle_wood,9402184,wood
+50,acacia_wood,9402184,wood
+51,dark_oak_wood,9402184,wood
+52,stripped_oak_wood,9402184,wood
+53,stripped_spruce_wood,9402184,wood
+54,stripped_birch_wood,9402184,wood
+55,stripped_jungle_wood,9402184,wood
+56,stripped_acacia_wood,9402184,wood
+57,stripped_dark_oak_wood,9402184,wood
+58,oak_leaves,31744,tree
+59,spruce_leaves,31744,tree
+60,birch_leaves,31744,tree
+61,jungle_leaves,31744,tree
+62,acacia_leaves,31744,tree
+63,dark_oak_leaves,31744,tree
+64,sponge,15066419,
+65,wet_sponge,15066419,
+66,glass,0,
+67,lapis_ore,7368816,
+68,lapis_block,10987431,
+69,dispenser,7368816,
+70,sandstone,7368816,sand
+71,chiseled_sandstone,7368816,sand
+72,cut_sandstone,7368816,sand
+73,note_block,9402184,
+74,white_bed,13092807,
+75,orange_bed,13092807,
+76,magenta_bed,13092807,
+77,light_blue_bed,13092807,
+78,yellow_bed,13092807,
+79,lime_bed,13092807,
+80,pink_bed,13092807,
+81,gray_bed,13092807,
+82,light_gray_bed,13092807,
+83,cyan_bed,13092807,
+84,purple_bed,13092807,
+85,blue_bed,13092807,
+86,brown_bed,13092807,
+87,green_bed,13092807,
+88,red_bed,13092807,
+89,black_bed,13092807,
+90,powered_rail,0,
+91,detector_rail,0,
+92,sticky_piston,7368816,
+93,cobweb,13092807,
+94,grass,31744,grass
+95,fern,31744,grass
+96,dead_bush,31744,grass
+97,seagrass,4210943,water
+98,tall_seagrass,4210943,water
+99,piston,7368816,
+100,piston_head,7368816,
+101,white_wool,13092807,
+102,orange_wool,13092807,
+103,magenta_wool,13092807,
+104,light_blue_wool,13092807,
+105,yellow_wool,13092807,
+106,lime_wool,13092807,
+107,pink_wool,13092807,
+108,gray_wool,13092807,
+109,light_gray_wool,13092807,
+110,cyan_wool,13092807,
+111,purple_wool,13092807,
+112,blue_wool,13092807,
+113,brown_wool,13092807,
+114,green_wool,13092807,
+115,red_wool,13092807,
+116,black_wool,13092807,
+117,moving_piston,7368816,
+118,dandelion,31744,flower
+119,poppy,31744,flower
+120,blue_orchid,31744,flower
+121,allium,31744,flower
+122,azure_bluet,31744,flower
+123,red_tulip,31744,flower
+124,orange_tulip,31744,flower
+125,white_tulip,31744,flower
+126,pink_tulip,31744,flower
+127,oxeye_daisy,31744,flower
+128,cornflower,31744,flower
+129,wither_rose,31744,flower
+130,lily_of_the_valley,31744,flower
+131,brown_mushroom,31744,flower
+132,red_mushroom,31744,flower
+133,gold_block,10987431,
+134,iron_block,10987431,
+135,bricks,7368816,
+136,tnt,16711680,
+137,bookshelf,9402184,
+138,mossy_cobblestone,7368816,
+139,obsidian,7368816,
+140,torch,0,
+141,wall_torch,0,
+142,fire,0,
+143,spawner,7368816,
+144,oak_stairs,9402184,
+145,chest,9402184,
+146,redstone_wire,0,
+147,diamond_ore,7368816,
+148,diamond_block,10987431,
+149,crafting_table,9402184,
+150,wheat,31744,
+151,farmland,9923917,
+152,furnace,7368816,
+153,oak_sign,9402184,
+154,spruce_sign,9402184,
+155,birch_sign,9402184,
+156,acacia_sign,9402184,
+157,jungle_sign,9402184,
+158,dark_oak_sign,9402184,
+159,oak_door,9402184,
+160,ladder,0,
+161,rail,0,
+162,cobblestone_stairs,7368816,
+163,oak_wall_sign,9402184,
+164,spruce_wall_sign,9402184,
+165,birch_wall_sign,9402184,
+166,acacia_wall_sign,9402184,
+167,jungle_wall_sign,9402184,
+168,dark_oak_wall_sign,9402184,
+169,lever,0,
+170,stone_pressure_plate,7368816,
+171,iron_door,10987431,
+172,oak_pressure_plate,9402184,
+173,spruce_pressure_plate,9402184,
+174,birch_pressure_plate,9402184,
+175,jungle_pressure_plate,9402184,
+176,acacia_pressure_plate,9402184,
+177,dark_oak_pressure_plate,9402184,
+178,redstone_ore,7368816,
+179,redstone_torch,0,
+180,redstone_wall_torch,0,
+181,stone_button,0,
+182,snow,16777215,snow
+183,ice,10526975,snow
+184,snow_block,16777215,snow
+185,cactus,31744,plant
+186,clay,10791096,
+187,sugar_cane,31744,plant
+188,jukebox,9402184,
+189,oak_fence,9402184,
+190,pumpkin,31744,
+191,netherrack,7368816,
+192,soul_sand,16247203,
+193,glowstone,0,
+194,nether_portal,0,
+195,carved_pumpkin,31744,
+196,jack_o_lantern,31744,
+197,cake,0,
+198,repeater,0,
+199,white_stained_glass,0,
+200,orange_stained_glass,0,
+201,magenta_stained_glass,0,
+202,light_blue_stained_glass,0,
+203,yellow_stained_glass,0,
+204,lime_stained_glass,0,
+205,pink_stained_glass,0,
+206,gray_stained_glass,0,
+207,light_gray_stained_glass,0,
+208,cyan_stained_glass,0,
+209,purple_stained_glass,0,
+210,blue_stained_glass,0,
+211,brown_stained_glass,0,
+212,green_stained_glass,0,
+213,red_stained_glass,0,
+214,black_stained_glass,0,
+215,oak_trapdoor,9402184,
+216,spruce_trapdoor,9402184,
+217,birch_trapdoor,9402184,
+218,jungle_trapdoor,9402184,
+219,acacia_trapdoor,9402184,
+220,dark_oak_trapdoor,9402184,
+221,stone_bricks,7368816,
+222,mossy_stone_bricks,7368816,
+223,cracked_stone_bricks,7368816,
+224,chiseled_stone_bricks,7368816,
+225,infested_stone,10791096,
+226,infested_cobblestone,10791096,
+227,infested_stone_bricks,10791096,
+228,infested_mossy_stone_bricks,10791096,
+229,infested_cracked_stone_bricks,10791096,
+230,infested_chiseled_stone_bricks,10791096,
+231,brown_mushroom_block,9402184,tree
+232,red_mushroom_block,9402184,tree
+233,mushroom_stem,9402184,tree
+234,iron_bars,10987431,
+235,glass_pane,0,
+236,melon,31744,
+237,attached_pumpkin_stem,31744,
+238,attached_melon_stem,31744,
+239,pumpkin_stem,31744,
+240,melon_stem,31744,
+241,vine,31744,plant
+242,oak_fence_gate,9402184,
+243,brick_stairs,7368816,
+244,stone_brick_stairs,7368816,
+245,mycelium,8368696,
+246,lily_pad,31744,grass
+247,nether_bricks,7368816,
+248,nether_brick_fence,7368816,
+249,nether_brick_stairs,7368816,
+250,nether_wart,31744,
+251,enchanting_table,7368816,
+252,brewing_stand,10987431,
+253,cauldron,10987431,
+254,end_portal,0,
+255,end_portal_frame,7368816,
+256,end_stone,7368816,
+257,dragon_egg,31744,
+258,redstone_lamp,0,
+259,cocoa,31744,
+260,sandstone_stairs,7368816,
+261,emerald_ore,7368816,
+262,ender_chest,7368816,
+263,tripwire_hook,0,
+264,tripwire,0,
+265,emerald_block,10987431,
+266,spruce_stairs,9402184,
+267,birch_stairs,9402184,
+268,jungle_stairs,9402184,
+269,command_block,10987431,
+270,beacon,0,
+271,cobblestone_wall,7368816,
+272,mossy_cobblestone_wall,7368816,
+273,flower_pot,0,
+274,potted_oak_sapling,0,
+275,potted_spruce_sapling,0,
+276,potted_birch_sapling,0,
+277,potted_jungle_sapling,0,
+278,potted_acacia_sapling,0,
+279,potted_dark_oak_sapling,0,
+280,potted_fern,0,
+281,potted_dandelion,0,
+282,potted_poppy,0,
+283,potted_blue_orchid,0,
+284,potted_allium,0,
+285,potted_azure_bluet,0,
+286,potted_red_tulip,0,
+287,potted_orange_tulip,0,
+288,potted_white_tulip,0,
+289,potted_pink_tulip,0,
+290,potted_oxeye_daisy,0,
+291,potted_cornflower,0,
+292,potted_lily_of_the_valley,0,
+293,potted_wither_rose,0,
+294,potted_red_mushroom,0,
+295,potted_brown_mushroom,0,
+296,potted_dead_bush,0,
+297,potted_cactus,0,
+298,carrots,31744,
+299,potatoes,31744,
+300,oak_button,0,
+301,spruce_button,0,
+302,birch_button,0,
+303,jungle_button,0,
+304,acacia_button,0,
+305,dark_oak_button,0,
+306,skeleton_skull,0,
+307,skeleton_wall_skull,0,
+308,wither_skeleton_skull,0,
+309,wither_skeleton_wall_skull,0,
+310,zombie_head,0,
+311,zombie_wall_head,0,
+312,player_head,0,
+313,player_wall_head,0,
+314,creeper_head,0,
+315,creeper_wall_head,0,
+316,dragon_head,0,
+317,dragon_wall_head,0,
+318,anvil,10987431,
+319,chipped_anvil,10987431,
+320,damaged_anvil,10987431,
+321,trapped_chest,9402184,
+322,light_weighted_pressure_plate,10987431,
+323,heavy_weighted_pressure_plate,10987431,
+324,comparator,0,
+325,daylight_detector,9402184,
+326,redstone_block,10987431,
+327,nether_quartz_ore,7368816,
+328,hopper,10987431,
+329,quartz_block,7368816,
+330,chiseled_quartz_block,7368816,
+331,quartz_pillar,7368816,
+332,quartz_stairs,7368816,
+333,activator_rail,0,
+334,dropper,7368816,
+335,white_terracotta,7368816,
+336,orange_terracotta,7368816,
+337,magenta_terracotta,7368816,
+338,light_blue_terracotta,7368816,
+339,yellow_terracotta,7368816,
+340,lime_terracotta,7368816,
+341,pink_terracotta,7368816,
+342,gray_terracotta,7368816,
+343,light_gray_terracotta,7368816,
+344,cyan_terracotta,7368816,
+345,purple_terracotta,7368816,
+346,blue_terracotta,7368816,
+347,brown_terracotta,7368816,
+348,green_terracotta,7368816,
+349,red_terracotta,7368816,
+350,black_terracotta,7368816,
+351,white_stained_glass_pane,0,
+352,orange_stained_glass_pane,0,
+353,magenta_stained_glass_pane,0,
+354,light_blue_stained_glass_pane,0,
+355,yellow_stained_glass_pane,0,
+356,lime_stained_glass_pane,0,
+357,pink_stained_glass_pane,0,
+358,gray_stained_glass_pane,0,
+359,light_gray_stained_glass_pane,0,
+360,cyan_stained_glass_pane,0,
+361,purple_stained_glass_pane,0,
+362,blue_stained_glass_pane,0,
+363,brown_stained_glass_pane,0,
+364,green_stained_glass_pane,0,
+365,red_stained_glass_pane,0,
+366,black_stained_glass_pane,0,
+367,acacia_stairs,9402184,
+368,dark_oak_stairs,9402184,
+369,slime_block,10791096,
+370,barrier,0,
+371,iron_trapdoor,10987431,
+372,prismarine,7368816,
+373,prismarine_bricks,7368816,
+374,dark_prismarine,7368816,
+375,prismarine_stairs,7368816,
+376,prismarine_brick_stairs,7368816,
+377,dark_prismarine_stairs,7368816,
+378,prismarine_slab,7368816,
+379,prismarine_brick_slab,7368816,
+380,dark_prismarine_slab,7368816,
+381,sea_lantern,0,
+382,hay_block,8368696,
+383,white_carpet,13092807,
+384,orange_carpet,13092807,
+385,magenta_carpet,13092807,
+386,light_blue_carpet,13092807,
+387,yellow_carpet,13092807,
+388,lime_carpet,13092807,
+389,pink_carpet,13092807,
+390,gray_carpet,13092807,
+391,light_gray_carpet,13092807,
+392,cyan_carpet,13092807,
+393,purple_carpet,13092807,
+394,blue_carpet,13092807,
+395,brown_carpet,13092807,
+396,green_carpet,13092807,
+397,red_carpet,13092807,
+398,black_carpet,13092807,
+399,terracotta,7368816,
+400,coal_block,7368816,
+401,packed_ice,10526975,
+402,sunflower,31744,flower
+403,lilac,31744,flower
+404,rose_bush,31744,flower
+405,peony,31744,flower
+406,tall_grass,31744,plant
+407,large_fern,31744,plant
+408,white_banner,9402184,
+409,orange_banner,9402184,
+410,magenta_banner,9402184,
+411,light_blue_banner,9402184,
+412,yellow_banner,9402184,
+413,lime_banner,9402184,
+414,pink_banner,9402184,
+415,gray_banner,9402184,
+416,light_gray_banner,9402184,
+417,cyan_banner,9402184,
+418,purple_banner,9402184,
+419,blue_banner,9402184,
+420,brown_banner,9402184,
+421,green_banner,9402184,
+422,red_banner,9402184,
+423,black_banner,9402184,
+424,white_wall_banner,9402184,
+425,orange_wall_banner,9402184,
+426,magenta_wall_banner,9402184,
+427,light_blue_wall_banner,9402184,
+428,yellow_wall_banner,9402184,
+429,lime_wall_banner,9402184,
+430,pink_wall_banner,9402184,
+431,gray_wall_banner,9402184,
+432,light_gray_wall_banner,9402184,
+433,cyan_wall_banner,9402184,
+434,purple_wall_banner,9402184,
+435,blue_wall_banner,9402184,
+436,brown_wall_banner,9402184,
+437,green_wall_banner,9402184,
+438,red_wall_banner,9402184,
+439,black_wall_banner,9402184,
+440,red_sandstone,7368816,
+441,chiseled_red_sandstone,7368816,
+442,cut_red_sandstone,7368816,
+443,red_sandstone_stairs,7368816,
+444,oak_slab,9402184,
+445,spruce_slab,9402184,
+446,birch_slab,9402184,
+447,jungle_slab,9402184,
+448,acacia_slab,9402184,
+449,dark_oak_slab,9402184,
+450,stone_slab,7368816,
+451,smooth_stone_slab,7368816,
+452,sandstone_slab,7368816,
+453,cut_sandstone_slab,7368816,
+454,petrified_oak_slab,7368816,
+455,cobblestone_slab,7368816,
+456,brick_slab,7368816,
+457,stone_brick_slab,7368816,
+458,nether_brick_slab,7368816,
+459,quartz_slab,7368816,
+460,red_sandstone_slab,7368816,
+461,cut_red_sandstone_slab,7368816,
+462,purpur_slab,7368816,
+463,smooth_stone,7368816,
+464,smooth_sandstone,7368816,
+465,smooth_quartz,7368816,
+466,smooth_red_sandstone,7368816,
+467,spruce_fence_gate,9402184,
+468,birch_fence_gate,9402184,
+469,jungle_fence_gate,9402184,
+470,acacia_fence_gate,9402184,
+471,dark_oak_fence_gate,9402184,
+472,spruce_fence,9402184,
+473,birch_fence,9402184,
+474,jungle_fence,9402184,
+475,acacia_fence,9402184,
+476,dark_oak_fence,9402184,
+477,spruce_door,9402184,
+478,birch_door,9402184,
+479,jungle_door,9402184,
+480,acacia_door,9402184,
+481,dark_oak_door,9402184,
+482,end_rod,0,
+483,chorus_plant,31744,
+484,chorus_flower,31744,
+485,purpur_block,7368816,
+486,purpur_pillar,7368816,
+487,purpur_stairs,7368816,
+488,end_stone_bricks,7368816,
+489,beetroots,31744,
+490,grass_path,9923917,
+491,end_gateway,0,
+492,repeating_command_block,10987431,
+493,chain_command_block,10987431,
+494,frosted_ice,10526975,
+495,magma_block,7368816,
+496,nether_wart_block,8368696,
+497,red_nether_bricks,7368816,
+498,bone_block,7368816,
+499,structure_void,0,
+500,observer,7368816,
+501,shulker_box,8339378,
+502,white_shulker_box,8339378,
+503,orange_shulker_box,8339378,
+504,magenta_shulker_box,8339378,
+505,light_blue_shulker_box,8339378,
+506,yellow_shulker_box,8339378,
+507,lime_shulker_box,8339378,
+508,pink_shulker_box,8339378,
+509,gray_shulker_box,8339378,
+510,light_gray_shulker_box,8339378,
+511,cyan_shulker_box,8339378,
+512,purple_shulker_box,8339378,
+513,blue_shulker_box,8339378,
+514,brown_shulker_box,8339378,
+515,green_shulker_box,8339378,
+516,red_shulker_box,8339378,
+517,black_shulker_box,8339378,
+518,white_glazed_terracotta,7368816,
+519,orange_glazed_terracotta,7368816,
+520,magenta_glazed_terracotta,7368816,
+521,light_blue_glazed_terracotta,7368816,
+522,yellow_glazed_terracotta,7368816,
+523,lime_glazed_terracotta,7368816,
+524,pink_glazed_terracotta,7368816,
+525,gray_glazed_terracotta,7368816,
+526,light_gray_glazed_terracotta,7368816,
+527,cyan_glazed_terracotta,7368816,
+528,purple_glazed_terracotta,7368816,
+529,blue_glazed_terracotta,7368816,
+530,brown_glazed_terracotta,7368816,
+531,green_glazed_terracotta,7368816,
+532,red_glazed_terracotta,7368816,
+533,black_glazed_terracotta,7368816,
+534,white_concrete,7368816,
+535,orange_concrete,7368816,
+536,magenta_concrete,7368816,
+537,light_blue_concrete,7368816,
+538,yellow_concrete,7368816,
+539,lime_concrete,7368816,
+540,pink_concrete,7368816,
+541,gray_concrete,7368816,
+542,light_gray_concrete,7368816,
+543,cyan_concrete,7368816,
+544,purple_concrete,7368816,
+545,blue_concrete,7368816,
+546,brown_concrete,7368816,
+547,green_concrete,7368816,
+548,red_concrete,7368816,
+549,black_concrete,7368816,
+550,white_concrete_powder,16247203,
+551,orange_concrete_powder,16247203,
+552,magenta_concrete_powder,16247203,
+553,light_blue_concrete_powder,16247203,
+554,yellow_concrete_powder,16247203,
+555,lime_concrete_powder,16247203,
+556,pink_concrete_powder,16247203,
+557,gray_concrete_powder,16247203,
+558,light_gray_concrete_powder,16247203,
+559,cyan_concrete_powder,16247203,
+560,purple_concrete_powder,16247203,
+561,blue_concrete_powder,16247203,
+562,brown_concrete_powder,16247203,
+563,green_concrete_powder,16247203,
+564,red_concrete_powder,16247203,
+565,black_concrete_powder,16247203,
+566,kelp,4210943,
+567,kelp_plant,4210943,
+568,dried_kelp_block,8368696,
+569,turtle_egg,31744,
+570,dead_tube_coral_block,7368816,
+571,dead_brain_coral_block,7368816,
+572,dead_bubble_coral_block,7368816,
+573,dead_fire_coral_block,7368816,
+574,dead_horn_coral_block,7368816,
+575,tube_coral_block,7368816,
+576,brain_coral_block,7368816,
+577,bubble_coral_block,7368816,
+578,fire_coral_block,7368816,
+579,horn_coral_block,7368816,
+580,dead_tube_coral,7368816,
+581,dead_brain_coral,7368816,
+582,dead_bubble_coral,7368816,
+583,dead_fire_coral,7368816,
+584,dead_horn_coral,7368816,
+585,tube_coral,4210943,
+586,brain_coral,4210943,
+587,bubble_coral,4210943,
+588,fire_coral,4210943,
+589,horn_coral,4210943,
+590,dead_tube_coral_fan,7368816,
+591,dead_brain_coral_fan,7368816,
+592,dead_bubble_coral_fan,7368816,
+593,dead_fire_coral_fan,7368816,
+594,dead_horn_coral_fan,7368816,
+595,tube_coral_fan,4210943,
+596,brain_coral_fan,4210943,
+597,bubble_coral_fan,4210943,
+598,fire_coral_fan,4210943,
+599,horn_coral_fan,4210943,
+600,dead_tube_coral_wall_fan,7368816,
+601,dead_brain_coral_wall_fan,7368816,
+602,dead_bubble_coral_wall_fan,7368816,
+603,dead_fire_coral_wall_fan,7368816,
+604,dead_horn_coral_wall_fan,7368816,
+605,tube_coral_wall_fan,4210943,
+606,brain_coral_wall_fan,4210943,
+607,bubble_coral_wall_fan,4210943,
+608,fire_coral_wall_fan,4210943,
+609,horn_coral_wall_fan,4210943,
+610,sea_pickle,4210943,
+611,blue_ice,10526975,
+612,conduit,0,
+613,bamboo_sapling,9402184,plant
+614,bamboo,9402184,plant
+615,potted_bamboo,0,
+616,void_air,0,dirt
+617,cave_air,0,dirt
+618,bubble_column,4210943,
+619,polished_granite_stairs,7368816,
+620,smooth_red_sandstone_stairs,7368816,
+621,mossy_stone_brick_stairs,7368816,
+622,polished_diorite_stairs,7368816,
+623,mossy_cobblestone_stairs,7368816,
+624,end_stone_brick_stairs,7368816,
+625,stone_stairs,7368816,
+626,smooth_sandstone_stairs,7368816,
+627,smooth_quartz_stairs,7368816,
+628,granite_stairs,7368816,
+629,andesite_stairs,7368816,
+630,red_nether_brick_stairs,7368816,
+631,polished_andesite_stairs,7368816,
+632,diorite_stairs,7368816,
+633,polished_granite_slab,7368816,
+634,smooth_red_sandstone_slab,7368816,
+635,mossy_stone_brick_slab,7368816,
+636,polished_diorite_slab,7368816,
+637,mossy_cobblestone_slab,7368816,
+638,end_stone_brick_slab,7368816,
+639,smooth_sandstone_slab,7368816,
+640,smooth_quartz_slab,7368816,
+641,granite_slab,7368816,
+642,andesite_slab,7368816,
+643,red_nether_brick_slab,7368816,
+644,polished_andesite_slab,7368816,
+645,diorite_slab,7368816,
+646,brick_wall,7368816,
+647,prismarine_wall,7368816,
+648,red_sandstone_wall,7368816,
+649,mossy_stone_brick_wall,7368816,
+650,granite_wall,7368816,
+651,stone_brick_wall,7368816,
+652,nether_brick_wall,7368816,
+653,andesite_wall,7368816,
+654,red_nether_brick_wall,7368816,
+655,sandstone_wall,7368816,
+656,end_stone_brick_wall,7368816,
+657,diorite_wall,7368816,
+658,scaffolding,0,
+659,loom,9402184,
+660,barrel,9402184,
+661,smoker,7368816,
+662,blast_furnace,7368816,
+663,cartography_table,9402184,
+664,fletching_table,9402184,
+665,grindstone,10987431,
+666,lectern,9402184,
+667,smithing_table,9402184,
+668,stonecutter,7368816,
+669,bell,10987431,
+670,lantern,10987431,
+671,campfire,9402184,
+672,sweet_berry_bush,31744,
+673,structure_block,10987431,
+674,jigsaw,10987431,
+675,composter,9402184,
+676,bee_nest,9402184,
+677,beehive,9402184,
+678,honey_block,10791096,
+679,honeycomb_block,10791096,
diff --git a/imaginaire/model_utils/gancraft/layers.py b/imaginaire/model_utils/gancraft/layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..d49900b0187c575d37c59a4fa6f62fb06413ef1d
--- /dev/null
+++ b/imaginaire/model_utils/gancraft/layers.py
@@ -0,0 +1,153 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import numpy as np
+import torch
+import torch.nn as nn
+
+
+class AffineMod(nn.Module):
+ r"""Learning affine modulation of activation.
+
+ Args:
+ in_features (int): Number of input features.
+ style_features (int): Number of style features.
+ mod_bias (bool): Whether to modulate bias.
+ """
+
+ def __init__(self,
+ in_features,
+ style_features,
+ mod_bias=True
+ ):
+ super().__init__()
+ self.weight_alpha = nn.Parameter(torch.randn([in_features, style_features]) / np.sqrt(style_features))
+ self.bias_alpha = nn.Parameter(torch.full([in_features], 1, dtype=torch.float)) # init to 1
+ self.weight_beta = None
+ self.bias_beta = None
+ self.mod_bias = mod_bias
+ if mod_bias:
+ self.weight_beta = nn.Parameter(torch.randn([in_features, style_features]) / np.sqrt(style_features))
+ self.bias_beta = nn.Parameter(torch.full([in_features], 0, dtype=torch.float))
+
+ @staticmethod
+ def _linear_f(x, w, b):
+ w = w.to(x.dtype)
+ x_shape = x.shape
+ x = x.reshape(-1, x_shape[-1])
+ if b is not None:
+ b = b.to(x.dtype)
+ x = torch.addmm(b.unsqueeze(0), x, w.t())
+ else:
+ x = x.matmul(w.t())
+ x = x.reshape(*x_shape[:-1], -1)
+ return x
+
+ # x: B, ... , Cin
+    # z: B, 1, 1, ..., Cz
+ def forward(self, x, z):
+ x_shape = x.shape
+ z_shape = z.shape
+ x = x.reshape(x_shape[0], -1, x_shape[-1])
+ z = z.reshape(z_shape[0], 1, z_shape[-1])
+
+ alpha = self._linear_f(z, self.weight_alpha, self.bias_alpha) # [B, ..., I]
+ x = x * alpha
+
+ if self.mod_bias:
+ beta = self._linear_f(z, self.weight_beta, self.bias_beta) # [B, ..., I]
+ x = x + beta
+
+ x = x.reshape(*x_shape[:-1], x.shape[-1])
+ return x
+
+
+class ModLinear(nn.Module):
+ r"""Linear layer with affine modulation (Based on StyleGAN2 mod demod).
+ Equivalent to affine modulation following linear, but faster when the same modulation parameters are shared across
+ multiple inputs.
+ Args:
+ in_features (int): Number of input features.
+ out_features (int): Number of output features.
+ style_features (int): Number of style features.
+ bias (bool): Apply additive bias before the activation function?
+ mod_bias (bool): Whether to modulate bias.
+ output_mode (bool): If True, modulate output instead of input.
+        weight_gain (float): Initialization gain.
+        bias_init (float): Initial value of the (unmodulated) additive bias.
+ """
+
+ def __init__(self,
+ in_features,
+ out_features,
+ style_features,
+ bias=True,
+ mod_bias=True,
+ output_mode=False,
+ weight_gain=1,
+ bias_init=0
+ ):
+ super().__init__()
+ weight_gain = weight_gain / np.sqrt(in_features)
+ self.weight = nn.Parameter(torch.randn([out_features, in_features]) * weight_gain)
+ self.bias = nn.Parameter(torch.full([out_features], np.float32(bias_init))) if bias else None
+ self.weight_alpha = nn.Parameter(torch.randn([in_features, style_features]) / np.sqrt(style_features))
+ self.bias_alpha = nn.Parameter(torch.full([in_features], 1, dtype=torch.float)) # init to 1
+ self.weight_beta = None
+ self.bias_beta = None
+ self.mod_bias = mod_bias
+ self.output_mode = output_mode
+ if mod_bias:
+ if output_mode:
+ mod_bias_dims = out_features
+ else:
+ mod_bias_dims = in_features
+ self.weight_beta = nn.Parameter(torch.randn([mod_bias_dims, style_features]) / np.sqrt(style_features))
+ self.bias_beta = nn.Parameter(torch.full([mod_bias_dims], 0, dtype=torch.float))
+
+ @staticmethod
+ def _linear_f(x, w, b):
+ w = w.to(x.dtype)
+ x_shape = x.shape
+ x = x.reshape(-1, x_shape[-1])
+ if b is not None:
+ b = b.to(x.dtype)
+ x = torch.addmm(b.unsqueeze(0), x, w.t())
+ else:
+ x = x.matmul(w.t())
+ x = x.reshape(*x_shape[:-1], -1)
+ return x
+
+ # x: B, ... , Cin
+    # z: B, 1, 1, ..., Cz
+ def forward(self, x, z):
+ x_shape = x.shape
+ z_shape = z.shape
+ x = x.reshape(x_shape[0], -1, x_shape[-1])
+ z = z.reshape(z_shape[0], 1, z_shape[-1])
+
+ alpha = self._linear_f(z, self.weight_alpha, self.bias_alpha) # [B, ..., I]
+ w = self.weight.to(x.dtype) # [O I]
+ w = w.unsqueeze(0) * alpha # [1 O I] * [B 1 I] = [B O I]
+
+ if self.mod_bias:
+ beta = self._linear_f(z, self.weight_beta, self.bias_beta) # [B, ..., I]
+ if not self.output_mode:
+ x = x + beta
+
+ b = self.bias
+ if b is not None:
+ b = b.to(x.dtype)[None, None, :]
+ if self.mod_bias and self.output_mode:
+ if b is None:
+ b = beta
+ else:
+ b = b + beta
+
+ # [B ? I] @ [B I O] = [B ? O]
+ if b is not None:
+ x = torch.baddbmm(b, x, w.transpose(1, 2))
+ else:
+ x = x.bmm(w.transpose(1, 2))
+ x = x.reshape(*x_shape[:-1], x.shape[-1])
+ return x
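
A minimal usage sketch of the modulated layers above (the batch and spatial sizes are illustrative assumptions; only the constructors and the `forward(x, z)` signature come from `layers.py`). Both modules take channel-last features and one style vector per batch element:

```python
import torch
from imaginaire.model_utils.gancraft.layers import AffineMod, ModLinear

B, H, W = 2, 16, 16                      # hypothetical batch / spatial sizes
x = torch.randn(B, H, W, 64)             # features, channel-last: [B, ..., Cin]
z = torch.randn(B, 32)                   # one style code per batch element: [B, Cz]

mod = AffineMod(in_features=64, style_features=32)
lin = ModLinear(in_features=64, out_features=128, style_features=32)

print(mod(x, z).shape)                   # torch.Size([2, 16, 16, 64])
print(lin(x, z).shape)                   # torch.Size([2, 16, 16, 128])
```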
diff --git a/imaginaire/model_utils/gancraft/loss.py b/imaginaire/model_utils/gancraft/loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..b1811de5307535167f645b4ea8a889a468b41780
--- /dev/null
+++ b/imaginaire/model_utils/gancraft/loss.py
@@ -0,0 +1,96 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class GANLoss(nn.Module):
+ def __init__(self, target_real_label=1.0, target_fake_label=0.0):
+ r"""GAN loss constructor.
+
+ Args:
+ target_real_label (float): Desired output label for the real images.
+ target_fake_label (float): Desired output label for the fake images.
+ """
+ super(GANLoss, self).__init__()
+ self.real_label = target_real_label
+ self.fake_label = target_fake_label
+ self.real_label_tensor = None
+ self.fake_label_tensor = None
+
+ def forward(self, input_x, t_real, weight=None,
+ reduce_dim=True, dis_update=True):
+ r"""GAN loss computation.
+
+ Args:
+            input_x (dict or list): Discriminator output dict (with 'pred' and 'label' entries),
+                or a list of such dicts, one per discriminator scale.
+ t_real (boolean): Is this output value for real images.
+ reduce_dim (boolean): Whether we reduce the dimensions first. This makes a difference when we use
+ multi-resolution discriminators.
+ weight (float): Weight to scale the loss value.
+ dis_update (boolean): Updating the discriminator or the generator.
+ Returns:
+ loss (tensor): Loss value.
+ """
+ if isinstance(input_x, list):
+ loss = 0
+ for pred_i in input_x:
+ if isinstance(pred_i, list):
+ pred_i = pred_i[-1]
+ loss_tensor = self.loss(pred_i, t_real, weight,
+ reduce_dim, dis_update)
+ bs = 1 if len(loss_tensor.size()) == 0 else loss_tensor.size(0)
+ new_loss = torch.mean(loss_tensor.view(bs, -1), dim=1)
+ loss += new_loss
+ return loss / len(input_x)
+ else:
+ return self.loss(input_x, t_real, weight, reduce_dim, dis_update)
+
+ def loss(self, input_x, t_real, weight=None,
+ reduce_dim=True, dis_update=True):
+ r"""N+1 label GAN loss computation.
+
+ Args:
+            input_x (dict): Discriminator output with 'pred' ([B, N+1, H, W] logits, the last
+                channel being the 'fake' class) and 'label' ([B, N, H, W] semantic map).
+ t_real (boolean): Is this output value for real images.
+ reduce_dim (boolean): Whether we reduce the dimensions first. This makes a difference when we use
+ multi-resolution discriminators.
+ weight (float): Weight to scale the loss value.
+ dis_update (boolean): Updating the discriminator or the generator.
+ Returns:
+ loss (tensor): Loss value.
+ """
+ assert reduce_dim is True
+ pred = input_x['pred'].clone()
+ label = input_x['label'].clone()
+ batch_size = pred.size(0)
+
+ # ignore label 0
+ label[:, 0, ...] = 0
+ pred[:, 0, ...] = 0
+ pred = F.log_softmax(pred, dim=1)
+ assert pred.size(1) == (label.size(1) + 1)
+ if dis_update:
+ if t_real:
+ pred_real = pred[:, :-1, :, :]
+ loss = - label * pred_real
+ loss = torch.sum(loss, dim=1, keepdim=True)
+ else:
+ pred_fake = pred[:, -1, None, :, :] # N plus 1
+ loss = - pred_fake
+ else:
+ assert t_real, "GAN loss must be aiming for real."
+ pred_real = pred[:, :-1, :, :]
+ loss = - label * pred_real
+ loss = torch.sum(loss, dim=1, keepdim=True)
+
+ if weight is not None:
+ loss = loss * weight
+ if reduce_dim:
+ loss = torch.mean(loss)
+ else:
+ loss = loss.view(batch_size, -1).mean(dim=1)
+ return loss
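
For reference, a hypothetical call sketch for the N+1-label loss above (the class count and spatial sizes are made up; only the `GANLoss` API and the dict keys come from `loss.py`):

```python
import torch
from imaginaire.model_utils.gancraft.loss import GANLoss

criterion = GANLoss()
B, N, H, W = 2, 11, 32, 32                              # N semantic classes + 1 fake channel
out = {
    'pred': torch.randn(B, N + 1, H, W),                # discriminator logits
    'label': torch.rand(B, N, H, W),                    # per-pixel semantic weights
}
d_real = criterion(out, t_real=True, dis_update=True)   # real branch of the D update
d_fake = criterion(out, t_real=False, dis_update=True)  # fake branch of the D update
g_loss = criterion(out, t_real=True, dis_update=False)  # generator update
```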
diff --git a/imaginaire/model_utils/gancraft/mc_lbl_reduction.py b/imaginaire/model_utils/gancraft/mc_lbl_reduction.py
new file mode 100644
index 0000000000000000000000000000000000000000..03fec1d3b600cfd31358cf480924da5232e0104a
--- /dev/null
+++ b/imaginaire/model_utils/gancraft/mc_lbl_reduction.py
@@ -0,0 +1,83 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import os
+import csv
+
+
+class ReducedLabelMapper:
+ def __init__(self):
+ this_path = os.path.dirname(os.path.abspath(__file__))
+ print('[ReducedLabelMapper] Loading from {}'.format(this_path))
+
+ # Load Minecraft LUT
+ mcid2rdlbl_lut = {}
+ mcid2mclbl_lut = {}
+ with open(os.path.join(this_path, 'mc_reduction.csv'), newline='') as csvfile:
+ csvreader = csv.reader(csvfile, delimiter=',')
+ for row in csvreader:
+ mcid = int(row[0])
+ mcid2rdlbl_lut[mcid] = row[3]
+ mcid2mclbl_lut[mcid] = row[1]
+
+ # Load reduced label set
+ reduced_lbls = []
+ rdlbl2rdid = {}
+ with open(os.path.join(this_path, 'reduced_coco_lbls.csv'), newline='') as csvfile:
+ csvreader = csv.reader(csvfile, delimiter=',')
+ for idx, row in enumerate(csvreader):
+ rdlbl2rdid[row[0]] = idx
+ reduced_lbls.append(row[0])
+ print(['{}: {}'.format(rdid, rdlbl) for rdid, rdlbl in enumerate(reduced_lbls)])
+ # The first label should always be 'ignore'
+ assert reduced_lbls[0] == 'ignore'
+
+ # Generate Minecraft ID to Reduced ID LUT
+ mcid2rdid_lut = []
+ for mcid in range(len(mcid2rdlbl_lut)):
+ rdlbl = mcid2rdlbl_lut[mcid]
+ if rdlbl == '':
+ rdlbl = 'ignore'
+ rdid = rdlbl2rdid[rdlbl]
+ mcid2rdid_lut.append(rdid)
+
+ # ================= coco part ==================
+ gg_label_list = []
+ gglbl2ggid = {}
+ with open(os.path.join(this_path, 'gaugan_lbl2col.csv'), newline='') as csvfile:
+ csvreader = csv.reader(csvfile, delimiter=',')
+ for idx, row in enumerate(csvreader):
+ gg_label_list.append(row[0])
+ gglbl2ggid[row[0]] = idx
+
+ # Load coco -> reduced mapping table
+ gglbl2rdid = {}
+ with open(os.path.join(this_path, 'gaugan_reduction.csv'), newline='') as csvfile:
+ csvreader = csv.reader(csvfile, delimiter=',')
+ for idx, row in enumerate(csvreader):
+ gglbl = row[0]
+ target_rdlbl = row[1]
+ ggid = gglbl2ggid[gglbl]
+ target_rdid = rdlbl2rdid[target_rdlbl]
+ gglbl2rdid[ggid] = target_rdid
+ ggid2rdid = [gglbl2rdid[i] for i in range(len(gglbl2rdid))]
+
+ print('[ReducedLabelMapper] #Reduced Labels: {}'.format(len(reduced_lbls)))
+
+ self.mcid2rdid_lut = mcid2rdid_lut
+ self.ggid2rdid = ggid2rdid
+ self.reduced_lbls = reduced_lbls
+
+ self.ignore_id = rdlbl2rdid['ignore']
+ self.dirt_id = rdlbl2rdid['dirt']
+ self.water_id = rdlbl2rdid['water']
+
+ self.gglbl2ggid = gglbl2ggid
+
+    # NOTE: the dict assigned to `self.gglbl2ggid` in __init__ shadows this method on
+    # instances, so callers index the mapping directly (e.g. mapper.gglbl2ggid['water']).
+    def gglbl2ggid(self, gglbl):
+        return self.gglbl2ggid[gglbl]
+
+
+if __name__ == '__main__':
+ mapper = ReducedLabelMapper()
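
A quick sanity check of the mapper (this assumes the four CSVs it reads, `mc_reduction.csv`, `reduced_coco_lbls.csv`, `gaugan_lbl2col.csv`, and `gaugan_reduction.csv`, sit next to this module as in the repository layout):

```python
from imaginaire.model_utils.gancraft.mc_lbl_reduction import ReducedLabelMapper

mapper = ReducedLabelMapper()
print(mapper.reduced_lbls)                    # ['ignore', 'sky', 'tree', 'dirt', ...]
print(mapper.mcid2rdid_lut[26])               # Minecraft block 26 (water) -> reduced id
assert mapper.mcid2rdid_lut[26] == mapper.water_id
```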
diff --git a/imaginaire/model_utils/gancraft/mc_reduction.csv b/imaginaire/model_utils/gancraft/mc_reduction.csv
new file mode 100644
index 0000000000000000000000000000000000000000..254af7255d67be76b5d41cdbd162173e55ec0b9c
--- /dev/null
+++ b/imaginaire/model_utils/gancraft/mc_reduction.csv
@@ -0,0 +1,680 @@
+0,air,0,sky
+1,stone,7368816,stone
+2,granite,7368816,rock
+3,polished_granite,7368816,rock
+4,diorite,7368816,rock
+5,polished_diorite,7368816,rock
+6,andesite,7368816,rock
+7,polished_andesite,7368816,rock
+8,grass_block,8368696,grass
+9,dirt,9923917,dirt
+10,coarse_dirt,9923917,dirt
+11,podzol,9923917,dirt
+12,cobblestone,7368816,stone
+13,oak_planks,9402184,
+14,spruce_planks,9402184,
+15,birch_planks,9402184,
+16,jungle_planks,9402184,
+17,acacia_planks,9402184,
+18,dark_oak_planks,9402184,
+19,oak_sapling,31744,grass
+20,spruce_sapling,31744,grass
+21,birch_sapling,31744,grass
+22,jungle_sapling,31744,grass
+23,acacia_sapling,31744,grass
+24,dark_oak_sapling,31744,grass
+25,bedrock,7368816,rock
+26,water,4210943,water
+27,lava,16711680,
+28,sand,16247203,sand
+29,red_sand,16247203,sand
+30,gravel,16247203,gravel
+31,gold_ore,7368816,rock
+32,iron_ore,7368816,rock
+33,coal_ore,7368816,rock
+34,oak_log,9402184,tree
+35,spruce_log,9402184,tree
+36,birch_log,9402184,tree
+37,jungle_log,9402184,tree
+38,acacia_log,9402184,tree
+39,dark_oak_log,9402184,tree
+40,stripped_spruce_log,9402184,
+41,stripped_birch_log,9402184,
+42,stripped_jungle_log,9402184,
+43,stripped_acacia_log,9402184,
+44,stripped_dark_oak_log,9402184,
+45,stripped_oak_log,9402184,
+46,oak_wood,9402184,
+47,spruce_wood,9402184,
+48,birch_wood,9402184,
+49,jungle_wood,9402184,
+50,acacia_wood,9402184,
+51,dark_oak_wood,9402184,
+52,stripped_oak_wood,9402184,
+53,stripped_spruce_wood,9402184,
+54,stripped_birch_wood,9402184,
+55,stripped_jungle_wood,9402184,
+56,stripped_acacia_wood,9402184,
+57,stripped_dark_oak_wood,9402184,
+58,oak_leaves,31744,tree
+59,spruce_leaves,31744,tree
+60,birch_leaves,31744,tree
+61,jungle_leaves,31744,tree
+62,acacia_leaves,31744,tree
+63,dark_oak_leaves,31744,tree
+64,sponge,15066419,
+65,wet_sponge,15066419,
+66,glass,0,
+67,lapis_ore,7368816,
+68,lapis_block,10987431,
+69,dispenser,7368816,
+70,sandstone,7368816,sand
+71,chiseled_sandstone,7368816,sand
+72,cut_sandstone,7368816,sand
+73,note_block,9402184,
+74,white_bed,13092807,
+75,orange_bed,13092807,
+76,magenta_bed,13092807,
+77,light_blue_bed,13092807,
+78,yellow_bed,13092807,
+79,lime_bed,13092807,
+80,pink_bed,13092807,
+81,gray_bed,13092807,
+82,light_gray_bed,13092807,
+83,cyan_bed,13092807,
+84,purple_bed,13092807,
+85,blue_bed,13092807,
+86,brown_bed,13092807,
+87,green_bed,13092807,
+88,red_bed,13092807,
+89,black_bed,13092807,
+90,powered_rail,0,
+91,detector_rail,0,
+92,sticky_piston,7368816,
+93,cobweb,13092807,
+94,grass,31744,grass
+95,fern,31744,grass
+96,dead_bush,31744,grass
+97,seagrass,4210943,water
+98,tall_seagrass,4210943,water
+99,piston,7368816,
+100,piston_head,7368816,
+101,white_wool,13092807,
+102,orange_wool,13092807,
+103,magenta_wool,13092807,
+104,light_blue_wool,13092807,
+105,yellow_wool,13092807,
+106,lime_wool,13092807,
+107,pink_wool,13092807,
+108,gray_wool,13092807,
+109,light_gray_wool,13092807,
+110,cyan_wool,13092807,
+111,purple_wool,13092807,
+112,blue_wool,13092807,
+113,brown_wool,13092807,
+114,green_wool,13092807,
+115,red_wool,13092807,
+116,black_wool,13092807,
+117,moving_piston,7368816,
+118,dandelion,31744,flower
+119,poppy,31744,flower
+120,blue_orchid,31744,flower
+121,allium,31744,flower
+122,azure_bluet,31744,flower
+123,red_tulip,31744,flower
+124,orange_tulip,31744,flower
+125,white_tulip,31744,flower
+126,pink_tulip,31744,flower
+127,oxeye_daisy,31744,flower
+128,cornflower,31744,flower
+129,wither_rose,31744,flower
+130,lily_of_the_valley,31744,flower
+131,brown_mushroom,31744,flower
+132,red_mushroom,31744,flower
+133,gold_block,10987431,
+134,iron_block,10987431,
+135,bricks,7368816,
+136,tnt,16711680,
+137,bookshelf,9402184,
+138,mossy_cobblestone,7368816,
+139,obsidian,7368816,
+140,torch,0,
+141,wall_torch,0,
+142,fire,0,
+143,spawner,7368816,
+144,oak_stairs,9402184,
+145,chest,9402184,
+146,redstone_wire,0,
+147,diamond_ore,7368816,
+148,diamond_block,10987431,
+149,crafting_table,9402184,
+150,wheat,31744,
+151,farmland,9923917,
+152,furnace,7368816,
+153,oak_sign,9402184,
+154,spruce_sign,9402184,
+155,birch_sign,9402184,
+156,acacia_sign,9402184,
+157,jungle_sign,9402184,
+158,dark_oak_sign,9402184,
+159,oak_door,9402184,
+160,ladder,0,
+161,rail,0,
+162,cobblestone_stairs,7368816,
+163,oak_wall_sign,9402184,
+164,spruce_wall_sign,9402184,
+165,birch_wall_sign,9402184,
+166,acacia_wall_sign,9402184,
+167,jungle_wall_sign,9402184,
+168,dark_oak_wall_sign,9402184,
+169,lever,0,
+170,stone_pressure_plate,7368816,
+171,iron_door,10987431,
+172,oak_pressure_plate,9402184,
+173,spruce_pressure_plate,9402184,
+174,birch_pressure_plate,9402184,
+175,jungle_pressure_plate,9402184,
+176,acacia_pressure_plate,9402184,
+177,dark_oak_pressure_plate,9402184,
+178,redstone_ore,7368816,
+179,redstone_torch,0,
+180,redstone_wall_torch,0,
+181,stone_button,0,
+182,snow,16777215,snow
+183,ice,10526975,snow
+184,snow_block,16777215,snow
+185,cactus,31744,flower
+186,clay,10791096,dirt
+187,sugar_cane,31744,flower
+188,jukebox,9402184,
+189,oak_fence,9402184,
+190,pumpkin,31744,
+191,netherrack,7368816,
+192,soul_sand,16247203,
+193,glowstone,0,
+194,nether_portal,0,
+195,carved_pumpkin,31744,
+196,jack_o_lantern,31744,
+197,cake,0,
+198,repeater,0,
+199,white_stained_glass,0,
+200,orange_stained_glass,0,
+201,magenta_stained_glass,0,
+202,light_blue_stained_glass,0,
+203,yellow_stained_glass,0,
+204,lime_stained_glass,0,
+205,pink_stained_glass,0,
+206,gray_stained_glass,0,
+207,light_gray_stained_glass,0,
+208,cyan_stained_glass,0,
+209,purple_stained_glass,0,
+210,blue_stained_glass,0,
+211,brown_stained_glass,0,
+212,green_stained_glass,0,
+213,red_stained_glass,0,
+214,black_stained_glass,0,
+215,oak_trapdoor,9402184,
+216,spruce_trapdoor,9402184,
+217,birch_trapdoor,9402184,
+218,jungle_trapdoor,9402184,
+219,acacia_trapdoor,9402184,
+220,dark_oak_trapdoor,9402184,
+221,stone_bricks,7368816,
+222,mossy_stone_bricks,7368816,
+223,cracked_stone_bricks,7368816,
+224,chiseled_stone_bricks,7368816,
+225,infested_stone,10791096,
+226,infested_cobblestone,10791096,
+227,infested_stone_bricks,10791096,
+228,infested_mossy_stone_bricks,10791096,
+229,infested_cracked_stone_bricks,10791096,
+230,infested_chiseled_stone_bricks,10791096,
+231,brown_mushroom_block,9402184,tree
+232,red_mushroom_block,9402184,tree
+233,mushroom_stem,9402184,tree
+234,iron_bars,10987431,
+235,glass_pane,0,
+236,melon,31744,
+237,attached_pumpkin_stem,31744,
+238,attached_melon_stem,31744,
+239,pumpkin_stem,31744,
+240,melon_stem,31744,
+241,vine,31744,tree
+242,oak_fence_gate,9402184,
+243,brick_stairs,7368816,
+244,stone_brick_stairs,7368816,
+245,mycelium,8368696,
+246,lily_pad,31744,grass
+247,nether_bricks,7368816,
+248,nether_brick_fence,7368816,
+249,nether_brick_stairs,7368816,
+250,nether_wart,31744,
+251,enchanting_table,7368816,
+252,brewing_stand,10987431,
+253,cauldron,10987431,
+254,end_portal,0,
+255,end_portal_frame,7368816,
+256,end_stone,7368816,
+257,dragon_egg,31744,
+258,redstone_lamp,0,
+259,cocoa,31744,
+260,sandstone_stairs,7368816,
+261,emerald_ore,7368816,
+262,ender_chest,7368816,
+263,tripwire_hook,0,
+264,tripwire,0,
+265,emerald_block,10987431,
+266,spruce_stairs,9402184,
+267,birch_stairs,9402184,
+268,jungle_stairs,9402184,
+269,command_block,10987431,
+270,beacon,0,
+271,cobblestone_wall,7368816,
+272,mossy_cobblestone_wall,7368816,
+273,flower_pot,0,
+274,potted_oak_sapling,0,
+275,potted_spruce_sapling,0,
+276,potted_birch_sapling,0,
+277,potted_jungle_sapling,0,
+278,potted_acacia_sapling,0,
+279,potted_dark_oak_sapling,0,
+280,potted_fern,0,
+281,potted_dandelion,0,
+282,potted_poppy,0,
+283,potted_blue_orchid,0,
+284,potted_allium,0,
+285,potted_azure_bluet,0,
+286,potted_red_tulip,0,
+287,potted_orange_tulip,0,
+288,potted_white_tulip,0,
+289,potted_pink_tulip,0,
+290,potted_oxeye_daisy,0,
+291,potted_cornflower,0,
+292,potted_lily_of_the_valley,0,
+293,potted_wither_rose,0,
+294,potted_red_mushroom,0,
+295,potted_brown_mushroom,0,
+296,potted_dead_bush,0,
+297,potted_cactus,0,
+298,carrots,31744,
+299,potatoes,31744,
+300,oak_button,0,
+301,spruce_button,0,
+302,birch_button,0,
+303,jungle_button,0,
+304,acacia_button,0,
+305,dark_oak_button,0,
+306,skeleton_skull,0,
+307,skeleton_wall_skull,0,
+308,wither_skeleton_skull,0,
+309,wither_skeleton_wall_skull,0,
+310,zombie_head,0,
+311,zombie_wall_head,0,
+312,player_head,0,
+313,player_wall_head,0,
+314,creeper_head,0,
+315,creeper_wall_head,0,
+316,dragon_head,0,
+317,dragon_wall_head,0,
+318,anvil,10987431,
+319,chipped_anvil,10987431,
+320,damaged_anvil,10987431,
+321,trapped_chest,9402184,
+322,light_weighted_pressure_plate,10987431,
+323,heavy_weighted_pressure_plate,10987431,
+324,comparator,0,
+325,daylight_detector,9402184,
+326,redstone_block,10987431,
+327,nether_quartz_ore,7368816,
+328,hopper,10987431,
+329,quartz_block,7368816,
+330,chiseled_quartz_block,7368816,
+331,quartz_pillar,7368816,
+332,quartz_stairs,7368816,
+333,activator_rail,0,
+334,dropper,7368816,
+335,white_terracotta,7368816,
+336,orange_terracotta,7368816,
+337,magenta_terracotta,7368816,
+338,light_blue_terracotta,7368816,
+339,yellow_terracotta,7368816,
+340,lime_terracotta,7368816,
+341,pink_terracotta,7368816,
+342,gray_terracotta,7368816,
+343,light_gray_terracotta,7368816,
+344,cyan_terracotta,7368816,
+345,purple_terracotta,7368816,
+346,blue_terracotta,7368816,
+347,brown_terracotta,7368816,
+348,green_terracotta,7368816,
+349,red_terracotta,7368816,
+350,black_terracotta,7368816,
+351,white_stained_glass_pane,0,
+352,orange_stained_glass_pane,0,
+353,magenta_stained_glass_pane,0,
+354,light_blue_stained_glass_pane,0,
+355,yellow_stained_glass_pane,0,
+356,lime_stained_glass_pane,0,
+357,pink_stained_glass_pane,0,
+358,gray_stained_glass_pane,0,
+359,light_gray_stained_glass_pane,0,
+360,cyan_stained_glass_pane,0,
+361,purple_stained_glass_pane,0,
+362,blue_stained_glass_pane,0,
+363,brown_stained_glass_pane,0,
+364,green_stained_glass_pane,0,
+365,red_stained_glass_pane,0,
+366,black_stained_glass_pane,0,
+367,acacia_stairs,9402184,
+368,dark_oak_stairs,9402184,
+369,slime_block,10791096,
+370,barrier,0,
+371,iron_trapdoor,10987431,
+372,prismarine,7368816,
+373,prismarine_bricks,7368816,
+374,dark_prismarine,7368816,
+375,prismarine_stairs,7368816,
+376,prismarine_brick_stairs,7368816,
+377,dark_prismarine_stairs,7368816,
+378,prismarine_slab,7368816,
+379,prismarine_brick_slab,7368816,
+380,dark_prismarine_slab,7368816,
+381,sea_lantern,0,
+382,hay_block,8368696,
+383,white_carpet,13092807,
+384,orange_carpet,13092807,
+385,magenta_carpet,13092807,
+386,light_blue_carpet,13092807,
+387,yellow_carpet,13092807,
+388,lime_carpet,13092807,
+389,pink_carpet,13092807,
+390,gray_carpet,13092807,
+391,light_gray_carpet,13092807,
+392,cyan_carpet,13092807,
+393,purple_carpet,13092807,
+394,blue_carpet,13092807,
+395,brown_carpet,13092807,
+396,green_carpet,13092807,
+397,red_carpet,13092807,
+398,black_carpet,13092807,
+399,terracotta,7368816,
+400,coal_block,7368816,
+401,packed_ice,10526975,snow
+402,sunflower,31744,flower
+403,lilac,31744,flower
+404,rose_bush,31744,flower
+405,peony,31744,flower
+406,tall_grass,31744,flower
+407,large_fern,31744,flower
+408,white_banner,9402184,
+409,orange_banner,9402184,
+410,magenta_banner,9402184,
+411,light_blue_banner,9402184,
+412,yellow_banner,9402184,
+413,lime_banner,9402184,
+414,pink_banner,9402184,
+415,gray_banner,9402184,
+416,light_gray_banner,9402184,
+417,cyan_banner,9402184,
+418,purple_banner,9402184,
+419,blue_banner,9402184,
+420,brown_banner,9402184,
+421,green_banner,9402184,
+422,red_banner,9402184,
+423,black_banner,9402184,
+424,white_wall_banner,9402184,
+425,orange_wall_banner,9402184,
+426,magenta_wall_banner,9402184,
+427,light_blue_wall_banner,9402184,
+428,yellow_wall_banner,9402184,
+429,lime_wall_banner,9402184,
+430,pink_wall_banner,9402184,
+431,gray_wall_banner,9402184,
+432,light_gray_wall_banner,9402184,
+433,cyan_wall_banner,9402184,
+434,purple_wall_banner,9402184,
+435,blue_wall_banner,9402184,
+436,brown_wall_banner,9402184,
+437,green_wall_banner,9402184,
+438,red_wall_banner,9402184,
+439,black_wall_banner,9402184,
+440,red_sandstone,7368816,
+441,chiseled_red_sandstone,7368816,
+442,cut_red_sandstone,7368816,
+443,red_sandstone_stairs,7368816,
+444,oak_slab,9402184,
+445,spruce_slab,9402184,
+446,birch_slab,9402184,
+447,jungle_slab,9402184,
+448,acacia_slab,9402184,
+449,dark_oak_slab,9402184,
+450,stone_slab,7368816,
+451,smooth_stone_slab,7368816,
+452,sandstone_slab,7368816,
+453,cut_sandstone_slab,7368816,
+454,petrified_oak_slab,7368816,
+455,cobblestone_slab,7368816,
+456,brick_slab,7368816,
+457,stone_brick_slab,7368816,
+458,nether_brick_slab,7368816,
+459,quartz_slab,7368816,
+460,red_sandstone_slab,7368816,
+461,cut_red_sandstone_slab,7368816,
+462,purpur_slab,7368816,
+463,smooth_stone,7368816,
+464,smooth_sandstone,7368816,
+465,smooth_quartz,7368816,
+466,smooth_red_sandstone,7368816,
+467,spruce_fence_gate,9402184,
+468,birch_fence_gate,9402184,
+469,jungle_fence_gate,9402184,
+470,acacia_fence_gate,9402184,
+471,dark_oak_fence_gate,9402184,
+472,spruce_fence,9402184,
+473,birch_fence,9402184,
+474,jungle_fence,9402184,
+475,acacia_fence,9402184,
+476,dark_oak_fence,9402184,
+477,spruce_door,9402184,
+478,birch_door,9402184,
+479,jungle_door,9402184,
+480,acacia_door,9402184,
+481,dark_oak_door,9402184,
+482,end_rod,0,
+483,chorus_plant,31744,
+484,chorus_flower,31744,
+485,purpur_block,7368816,
+486,purpur_pillar,7368816,
+487,purpur_stairs,7368816,
+488,end_stone_bricks,7368816,
+489,beetroots,31744,
+490,grass_path,9923917,
+491,end_gateway,0,
+492,repeating_command_block,10987431,
+493,chain_command_block,10987431,
+494,frosted_ice,10526975,snow
+495,magma_block,7368816,
+496,nether_wart_block,8368696,
+497,red_nether_bricks,7368816,
+498,bone_block,7368816,
+499,structure_void,0,
+500,observer,7368816,
+501,shulker_box,8339378,
+502,white_shulker_box,8339378,
+503,orange_shulker_box,8339378,
+504,magenta_shulker_box,8339378,
+505,light_blue_shulker_box,8339378,
+506,yellow_shulker_box,8339378,
+507,lime_shulker_box,8339378,
+508,pink_shulker_box,8339378,
+509,gray_shulker_box,8339378,
+510,light_gray_shulker_box,8339378,
+511,cyan_shulker_box,8339378,
+512,purple_shulker_box,8339378,
+513,blue_shulker_box,8339378,
+514,brown_shulker_box,8339378,
+515,green_shulker_box,8339378,
+516,red_shulker_box,8339378,
+517,black_shulker_box,8339378,
+518,white_glazed_terracotta,7368816,
+519,orange_glazed_terracotta,7368816,
+520,magenta_glazed_terracotta,7368816,
+521,light_blue_glazed_terracotta,7368816,
+522,yellow_glazed_terracotta,7368816,
+523,lime_glazed_terracotta,7368816,
+524,pink_glazed_terracotta,7368816,
+525,gray_glazed_terracotta,7368816,
+526,light_gray_glazed_terracotta,7368816,
+527,cyan_glazed_terracotta,7368816,
+528,purple_glazed_terracotta,7368816,
+529,blue_glazed_terracotta,7368816,
+530,brown_glazed_terracotta,7368816,
+531,green_glazed_terracotta,7368816,
+532,red_glazed_terracotta,7368816,
+533,black_glazed_terracotta,7368816,
+534,white_concrete,7368816,
+535,orange_concrete,7368816,
+536,magenta_concrete,7368816,
+537,light_blue_concrete,7368816,
+538,yellow_concrete,7368816,
+539,lime_concrete,7368816,
+540,pink_concrete,7368816,
+541,gray_concrete,7368816,
+542,light_gray_concrete,7368816,
+543,cyan_concrete,7368816,
+544,purple_concrete,7368816,
+545,blue_concrete,7368816,
+546,brown_concrete,7368816,
+547,green_concrete,7368816,
+548,red_concrete,7368816,
+549,black_concrete,7368816,
+550,white_concrete_powder,16247203,
+551,orange_concrete_powder,16247203,
+552,magenta_concrete_powder,16247203,
+553,light_blue_concrete_powder,16247203,
+554,yellow_concrete_powder,16247203,
+555,lime_concrete_powder,16247203,
+556,pink_concrete_powder,16247203,
+557,gray_concrete_powder,16247203,
+558,light_gray_concrete_powder,16247203,
+559,cyan_concrete_powder,16247203,
+560,purple_concrete_powder,16247203,
+561,blue_concrete_powder,16247203,
+562,brown_concrete_powder,16247203,
+563,green_concrete_powder,16247203,
+564,red_concrete_powder,16247203,
+565,black_concrete_powder,16247203,
+566,kelp,4210943,
+567,kelp_plant,4210943,
+568,dried_kelp_block,8368696,
+569,turtle_egg,31744,
+570,dead_tube_coral_block,7368816,
+571,dead_brain_coral_block,7368816,
+572,dead_bubble_coral_block,7368816,
+573,dead_fire_coral_block,7368816,
+574,dead_horn_coral_block,7368816,
+575,tube_coral_block,7368816,
+576,brain_coral_block,7368816,
+577,bubble_coral_block,7368816,
+578,fire_coral_block,7368816,
+579,horn_coral_block,7368816,
+580,dead_tube_coral,7368816,
+581,dead_brain_coral,7368816,
+582,dead_bubble_coral,7368816,
+583,dead_fire_coral,7368816,
+584,dead_horn_coral,7368816,
+585,tube_coral,4210943,
+586,brain_coral,4210943,
+587,bubble_coral,4210943,
+588,fire_coral,4210943,
+589,horn_coral,4210943,
+590,dead_tube_coral_fan,7368816,
+591,dead_brain_coral_fan,7368816,
+592,dead_bubble_coral_fan,7368816,
+593,dead_fire_coral_fan,7368816,
+594,dead_horn_coral_fan,7368816,
+595,tube_coral_fan,4210943,
+596,brain_coral_fan,4210943,
+597,bubble_coral_fan,4210943,
+598,fire_coral_fan,4210943,
+599,horn_coral_fan,4210943,
+600,dead_tube_coral_wall_fan,7368816,
+601,dead_brain_coral_wall_fan,7368816,
+602,dead_bubble_coral_wall_fan,7368816,
+603,dead_fire_coral_wall_fan,7368816,
+604,dead_horn_coral_wall_fan,7368816,
+605,tube_coral_wall_fan,4210943,
+606,brain_coral_wall_fan,4210943,
+607,bubble_coral_wall_fan,4210943,
+608,fire_coral_wall_fan,4210943,
+609,horn_coral_wall_fan,4210943,
+610,sea_pickle,4210943,
+611,blue_ice,10526975,snow
+612,conduit,0,
+613,bamboo_sapling,9402184,flower
+614,bamboo,9402184,tree
+615,potted_bamboo,0,
+616,void_air,0,dirt
+617,cave_air,0,dirt
+618,bubble_column,4210943,
+619,polished_granite_stairs,7368816,
+620,smooth_red_sandstone_stairs,7368816,
+621,mossy_stone_brick_stairs,7368816,
+622,polished_diorite_stairs,7368816,
+623,mossy_cobblestone_stairs,7368816,
+624,end_stone_brick_stairs,7368816,
+625,stone_stairs,7368816,
+626,smooth_sandstone_stairs,7368816,
+627,smooth_quartz_stairs,7368816,
+628,granite_stairs,7368816,
+629,andesite_stairs,7368816,
+630,red_nether_brick_stairs,7368816,
+631,polished_andesite_stairs,7368816,
+632,diorite_stairs,7368816,
+633,polished_granite_slab,7368816,
+634,smooth_red_sandstone_slab,7368816,
+635,mossy_stone_brick_slab,7368816,
+636,polished_diorite_slab,7368816,
+637,mossy_cobblestone_slab,7368816,
+638,end_stone_brick_slab,7368816,
+639,smooth_sandstone_slab,7368816,
+640,smooth_quartz_slab,7368816,
+641,granite_slab,7368816,
+642,andesite_slab,7368816,
+643,red_nether_brick_slab,7368816,
+644,polished_andesite_slab,7368816,
+645,diorite_slab,7368816,
+646,brick_wall,7368816,
+647,prismarine_wall,7368816,
+648,red_sandstone_wall,7368816,
+649,mossy_stone_brick_wall,7368816,
+650,granite_wall,7368816,
+651,stone_brick_wall,7368816,
+652,nether_brick_wall,7368816,
+653,andesite_wall,7368816,
+654,red_nether_brick_wall,7368816,
+655,sandstone_wall,7368816,
+656,end_stone_brick_wall,7368816,
+657,diorite_wall,7368816,
+658,scaffolding,0,
+659,loom,9402184,
+660,barrel,9402184,
+661,smoker,7368816,
+662,blast_furnace,7368816,
+663,cartography_table,9402184,
+664,fletching_table,9402184,
+665,grindstone,10987431,
+666,lectern,9402184,
+667,smithing_table,9402184,
+668,stonecutter,7368816,
+669,bell,10987431,
+670,lantern,10987431,
+671,campfire,9402184,
+672,sweet_berry_bush,31744,
+673,structure_block,10987431,
+674,jigsaw,10987431,
+675,composter,9402184,
+676,bee_nest,9402184,
+677,beehive,9402184,
+678,honey_block,10791096,
+679,honeycomb_block,10791096,
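
Each row of `mc_reduction.csv` is `<Minecraft block id>,<block name>,<packed map color>,<reduced label>`; an empty last field is treated as `ignore` by `ReducedLabelMapper` above. The third field appears to pack a 24-bit RGB map color into a single integer (an inference from the values, not stated in the source); for example, the `7368816` used for stone-like blocks decodes to a mid gray:

```python
packed = 7368816                       # third field of the 'stone' row
r, g, b = (packed >> 16) & 0xFF, (packed >> 8) & 0xFF, packed & 0xFF
print(hex(packed), (r, g, b))          # 0x707070 (112, 112, 112)
```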
diff --git a/imaginaire/model_utils/gancraft/mc_utils.py b/imaginaire/model_utils/gancraft/mc_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..faa87e04060866541761586ffe8e41bcce9ca44c
--- /dev/null
+++ b/imaginaire/model_utils/gancraft/mc_utils.py
@@ -0,0 +1,388 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import csv
+import time
+# For binary dilation
+from scipy import ndimage
+import os
+from imaginaire.model_utils.gancraft.mc_lbl_reduction import ReducedLabelMapper
+
+
+def load_voxel_new(voxel_path, shape=[256, 512, 512]):
+ voxel_world = np.fromfile(voxel_path, dtype='int32')
+ voxel_world = voxel_world.reshape(
+ shape[1]//16, shape[2]//16, 16, 16, shape[0])
+ voxel_world = voxel_world.transpose(4, 0, 2, 1, 3)
+ voxel_world = voxel_world.reshape(shape[0], shape[1], shape[2])
+ voxel_world = np.ascontiguousarray(voxel_world)
+ voxel_world = torch.from_numpy(voxel_world.astype(np.int32))
+ return voxel_world
+
+
+def gen_corner_voxel(voxel):
+ r"""Converting voxel center array to voxel corner array. The size of the
+ produced array grows by 1 on every dimension.
+
+ Args:
+ voxel (torch.IntTensor, CPU): Input voxel of three dimensions
+ """
+    structure = np.zeros([3, 3, 3], dtype=bool)  # plain bool; np.bool is removed in recent NumPy
+ structure[1:, 1:, 1:] = True
+ voxel_p = F.pad(voxel, (0, 1, 0, 1, 0, 1))
+ corners = ndimage.binary_dilation(voxel_p.numpy(), structure)
+ corners = torch.tensor(corners, dtype=torch.int32)
+ return corners
+
+
+def calc_height_map(voxel_t):
+ r"""Calculate height map given a voxel grid [Y, X, Z] as input.
+ The height is defined as the Y index of the surface (non-air) block
+
+ Args:
+ voxel (Y x X x Z torch.IntTensor, CPU): Input voxel of three dimensions
+ Output:
+ heightmap (X x Z torch.IntTensor)
+ """
+ start_time = time.time()
+ m, h = torch.max((torch.flip(voxel_t, [0]) != 0).int(), dim=0, keepdim=False)
+ heightmap = voxel_t.shape[0] - 1 - h
+ heightmap[m == 0] = 0 # Special case when the whole vertical column is empty
+
+ elapsed_time = time.time() - start_time
+ print("[GANcraft-utils] Heightmap time: {}".format(elapsed_time))
+ return heightmap
+
+
+def trans_vec_homo(m, v, is_vec=False):
+ r"""3-dimensional Homogeneous matrix and regular vector multiplication
+ Convert v to homogeneous vector, perform M-V multiplication, and convert back
+ Note that this function does not support autograd.
+
+ Args:
+ m (4 x 4 tensor): a homogeneous matrix
+ v (3 tensor): a 3-d vector
+ vec (bool): if true, v is direction. Otherwise v is point
+ """
+ if is_vec:
+ v = torch.tensor([v[0], v[1], v[2], 0], dtype=v.dtype)
+ else:
+ v = torch.tensor([v[0], v[1], v[2], 1], dtype=v.dtype)
+ v = torch.mv(m, v)
+ if not is_vec:
+ v = v / v[3]
+ v = v[:3]
+ return v
+
+
+def cumsum_exclusive(tensor, dim):
+ cumsum = torch.cumsum(tensor, dim)
+ cumsum = torch.roll(cumsum, 1, dim)
+ cumsum.index_fill_(dim, torch.tensor([0], dtype=torch.long, device=tensor.device), 0)
+ return cumsum
+
+
+def sample_depth_batched(depth2, nsamples, deterministic=False, use_box_boundaries=True, sample_depth=4):
+ r""" Make best effort to sample points within the same distance for every ray.
+ Exception: When there is not enough voxel.
+
+ Args:
+ depth2 (N x 2 x 256 x 256 x 4 x 1 tensor):
+ - N: Batch.
+ - 2: Entrance / exit depth for each intersected box.
+ - 256, 256: Height, Width.
+ - 4: Number of intersected boxes along the ray.
+ - 1: One extra dim for consistent tensor dims.
+ depth2 can include NaNs.
+ deterministic (bool): Whether to use equal-distance sampling instead of random stratified sampling.
+ use_box_boundaries (bool): Whether to add the entrance / exit points into the sample.
+ sample_depth (float): Truncate the ray when it travels further than sample_depth inside voxels.
+ """
+
+ bs = depth2.size(0)
+ dim0 = depth2.size(2)
+ dim1 = depth2.size(3)
+ dists = depth2[:, 1] - depth2[:, 0]
+ dists[torch.isnan(dists)] = 0 # N, 256, 256, 4, 1
+ accu_depth = torch.cumsum(dists, dim=-2) # N, 256, 256, 4, 1
+ total_depth = accu_depth[..., [-1], :] # N, 256, 256, 1, 1
+
+ total_depth = torch.clamp(total_depth, None, sample_depth)
+
+ # Ignore out of range box boundaries. Fill with random samples.
+ if use_box_boundaries:
+ boundary_samples = accu_depth.clone().detach()
+ boundary_samples_filler = torch.rand_like(boundary_samples) * total_depth
+ bad_mask = (accu_depth > sample_depth) | (dists == 0)
+ boundary_samples[bad_mask] = boundary_samples_filler[bad_mask]
+
+ rand_shape = [bs, dim0, dim1, nsamples, 1]
+ # 256, 256, N, 1
+ if deterministic:
+ rand_samples = torch.empty(rand_shape, dtype=total_depth.dtype, device=total_depth.device)
+        rand_samples[..., :, 0] = torch.linspace(0, 1, nsamples+2, device=total_depth.device)[1:-1]
+ else:
+ rand_samples = torch.rand(rand_shape, dtype=total_depth.dtype, device=total_depth.device) # 256, 256, N, 1
+ # Stratified sampling as in NeRF
+ rand_samples = rand_samples / nsamples
+ rand_samples[..., :, 0] += torch.linspace(0, 1, nsamples+1, device=rand_samples.device)[:-1]
+ rand_samples = rand_samples * total_depth # 256, 256, N, 1
+
+ # Can also include boundaries
+ if use_box_boundaries:
+ rand_samples = torch.cat([rand_samples, boundary_samples, torch.zeros(
+ [bs, dim0, dim1, 1, 1], dtype=total_depth.dtype, device=total_depth.device)], dim=-2)
+ rand_samples, _ = torch.sort(rand_samples, dim=-2, descending=False)
+
+ midpoints = (rand_samples[..., 1:, :] + rand_samples[..., :-1, :]) / 2
+ new_dists = rand_samples[..., 1:, :] - rand_samples[..., :-1, :]
+
+ # Scatter the random samples back
+ # 256, 256, 1, M, 1 > 256, 256, N, 1, 1
+ idx = torch.sum(midpoints.unsqueeze(-3) > accu_depth.unsqueeze(-2), dim=-3) # 256, 256, M, 1
+ # print(idx.shape, idx.max(), idx.min()) # max 3, min 0
+
+ depth_deltas = depth2[:, 0, :, :, 1:, :] - depth2[:, 1, :, :, :-1, :] # There might be NaNs!
+ depth_deltas = torch.cumsum(depth_deltas, dim=-2)
+ depth_deltas = torch.cat([depth2[:, 0, :, :, [0], :], depth_deltas+depth2[:, 0, :, :, [0], :]], dim=-2)
+ heads = torch.gather(depth_deltas, -2, idx) # 256 256 M 1
+ # heads = torch.gather(depth2[0], -2, idx) # 256 256 M 1
+
+ # print(torch.any(torch.isnan(heads)))
+ rand_depth = heads + midpoints # 256 256 N 1
+
+ return rand_depth, new_dists, idx
+
+
+def volum_rendering_relu(sigma, dists, dim=2):
+ free_energy = F.relu(sigma) * dists
+
+ a = 1 - torch.exp(-free_energy.float()) # probability of it is not empty here
+ b = torch.exp(-cumsum_exclusive(free_energy, dim=dim)) # probability of everything is empty up to now
+ probs = a * b # probability of the ray hits something here
+
+ return probs
+
+
+class McVoxel(nn.Module):
+ r"""Voxel management."""
+
+ def __init__(self, voxel_t, preproc_ver):
+ super(McVoxel, self).__init__()
+ # Filter voxel
+ voxel_t[voxel_t == 246] = 0 # lily_pad
+ voxel_t[voxel_t == 241] = 0 # vine
+ voxel_t[voxel_t == 611] = 26 # Blue ice -> water
+ voxel_t[voxel_t == 183] = 26 # ice -> water
+ voxel_t[voxel_t == 401] = 25 # Packed ice -> bedrock
+
+ if preproc_ver >= 3 and preproc_ver < 6:
+ voxel_t[voxel_t == 27] = 25 # Lava -> bedrock
+ voxel_t[voxel_t == 616] = 9 # void_air -> dirt
+ voxel_t[voxel_t == 617] = 25 # cave_air -> bedrock
+
+ if preproc_ver >= 6:
+ voxel_t[voxel_t == 616] = 0 # void_air -> air
+ voxel_t[voxel_t == 617] = 0 # cave_air -> air
+
+ # Simplify voxel
+ structure = ndimage.generate_binary_structure(3, 3)
+ mask = voxel_t.numpy() > 0
+ if preproc_ver == 4: # Hollow bottom
+            mask = ndimage.binary_erosion(mask, structure=structure, iterations=2, border_value=1)
+            voxel_t[mask] = 0
+        if preproc_ver >= 5:  # Close cell before hollow bottom
+            mask = ndimage.binary_dilation(mask, iterations=1, border_value=1)
+            mask = ndimage.binary_erosion(mask, iterations=1, border_value=1)
+            mask = ndimage.binary_erosion(mask, structure=structure, iterations=2, border_value=1)
+ voxel_t[mask] = 0
+
+ self.register_buffer('voxel_t', voxel_t, persistent=False)
+
+ self.trans_mat = torch.eye(4) # Transform voxel to world
+ # Generate heightmap for camera trajectory generation
+ self.heightmap = calc_height_map(self.voxel_t)
+ self._truncate_voxel()
+ # Convert voxel ([X, Y, Z], int32) to corner ([X+1, Y+1, Z+1], int32) (Requires CPU tensor)
+ corner_t = gen_corner_voxel(self.voxel_t)
+ self.register_buffer('corner_t', corner_t, persistent=False)
+
+ # Generate 3D position to 1D feature LUT table
+ nfilledvox = torch.sum(self.corner_t > 0)
+ print('[GANcraft-utils] Number of filled voxels: {} / {}'.format(nfilledvox.item(), torch.numel(self.corner_t)))
+ # Zero means non-existent voxel.
+ self.corner_t[self.corner_t > 0] = torch.arange(start=1, end=nfilledvox+1, step=1, dtype=torch.int32)
+ self.nfilledvox = nfilledvox
+
+ def world2local(self, v, is_vec=False):
+ mat_world2local = torch.inverse(self.trans_mat)
+ return trans_vec_homo(mat_world2local, v, is_vec)
+
+ def _truncate_voxel(self):
+ gnd_level = self.heightmap.min()
+ sky_level = self.heightmap.max() + 1
+ self.voxel_t = self.voxel_t[gnd_level:sky_level, :, :]
+ self.trans_mat[0, 3] += gnd_level
+ print('[GANcraft-utils] Voxel truncated. Gnd: {}; Sky: {}.'.format(gnd_level.item(), sky_level.item()))
+
+ def is_sea(self, loc):
+ r"""loc: [2]: x, z."""
+ x = int(loc[1])
+ z = int(loc[2])
+ if x < 0 or x > self.heightmap.size(0) or z < 0 or z > self.heightmap.size(1):
+ print('[McVoxel] is_sea(): Index out of bound.')
+ return True
+ y = self.heightmap[x, z] - self.trans_mat[0, 3]
+ y = int(y)
+ if self.voxel_t[y, x, z] == 26:
+ print('[McVoxel] is_sea(): Get a sea.')
+ print(self.voxel_t[y, x, z], self.voxel_t[y+1, x, z])
+ return True
+ else:
+ return False
+
+
+class MCLabelTranslator:
+ r"""Resolving mapping across Minecraft voxel, coco-stuff label and GANcraft reduced label set."""
+
+ def __init__(self):
+ this_path = os.path.dirname(os.path.abspath(__file__))
+ # Load voxel name lut
+ id2name_lut = {}
+ id2color_lut = {}
+ id2glbl_lut = {}
+ with open(os.path.join(this_path, 'id2name_gg.csv'), newline='') as csvfile:
+ csvreader = csv.reader(csvfile, delimiter=',')
+ for row in csvreader:
+ id2name_lut[int(row[0])] = row[1]
+ id2color_lut[int(row[0])] = int(row[2])
+ id2glbl_lut[int(row[0])] = row[3]
+
+ # Load GauGAN color lut
+ glbl2color_lut = {}
+ glbl2cocoidx_lut = {}
+ with open(os.path.join(this_path, 'gaugan_lbl2col.csv'), newline='') as csvfile:
+ csvreader = csv.reader(csvfile, delimiter=',')
+ cocoidx = 1 # 0 is "Others"
+ for row in csvreader:
+ color = int(row[1].lstrip('#'), 16)
+ glbl2color_lut[row[0]] = color
+ glbl2cocoidx_lut[row[0]] = cocoidx
+ cocoidx += 1
+
+ # Generate id2ggcolor lut
+ id2ggcolor_lut = {}
+ for k, v in id2glbl_lut.items():
+ if v:
+ id2ggcolor_lut[k] = glbl2color_lut[v]
+ else:
+ id2ggcolor_lut[k] = 0
+
+ # Generate id2cocoidx
+ id2cocoidx_lut = {}
+ for k, v in id2glbl_lut.items():
+ if v:
+ id2cocoidx_lut[k] = glbl2cocoidx_lut[v]
+ else:
+ id2cocoidx_lut[k] = 0
+
+ self.id2color_lut = id2color_lut
+ self.id2name_lut = id2name_lut
+ self.id2glbl_lut = id2glbl_lut
+ self.id2ggcolor_lut = id2ggcolor_lut
+ self.id2cocoidx_lut = id2cocoidx_lut
+
+ if True:
+ mapper = ReducedLabelMapper()
+ mcid2rdid_lut = mapper.mcid2rdid_lut
+ mcid2rdid_lut = torch.tensor(mcid2rdid_lut, dtype=torch.long)
+ self.mcid2rdid_lut = mcid2rdid_lut
+ self.num_reduced_lbls = len(mapper.reduced_lbls)
+ self.ignore_id = mapper.ignore_id
+ self.dirt_id = mapper.dirt_id
+ self.water_id = mapper.water_id
+
+ self.mapper = mapper
+
+ ggid2rdid_lut = mapper.ggid2rdid + [0] # Last index is ignore
+ ggid2rdid_lut = torch.tensor(ggid2rdid_lut, dtype=torch.long)
+ self.ggid2rdid_lut = ggid2rdid_lut
+ if True:
+ mc2coco_lut = list(zip(*sorted([(k, v) for k, v in self.id2cocoidx_lut.items()])))[1]
+ mc2coco_lut = torch.tensor(mc2coco_lut, dtype=torch.long)
+ self.mc2coco_lut = mc2coco_lut
+
+ def gglbl2ggid(self, gglbl):
+ return self.mapper.gglbl2ggid[gglbl]
+
+ def mc2coco(self, mc):
+ self.mc2coco_lut = self.mc2coco_lut.to(mc.device)
+ coco = self.mc2coco_lut[mc.long()]
+ return coco
+
+ def mc2reduced(self, mc, ign2dirt=False):
+ self.mcid2rdid_lut = self.mcid2rdid_lut.to(mc.device)
+ reduced = self.mcid2rdid_lut[mc.long()]
+ if ign2dirt:
+ reduced[reduced == self.ignore_id] = self.dirt_id
+ return reduced
+
+ def coco2reduced(self, coco):
+ self.ggid2rdid_lut = self.ggid2rdid_lut.to(coco.device)
+ reduced = self.ggid2rdid_lut[coco.long()]
+ return reduced
+
+ def get_num_reduced_lbls(self):
+ return self.num_reduced_lbls
+
+ @staticmethod
+ def uint32_to_4uint8(x):
+ dt1 = np.dtype(('i4', [('bytes', 'u1', 4)]))
+ color = x.view(dtype=dt1)['bytes']
+ return color
+
+ def mc_color(self, img):
+ r"""Obtaining Minecraft default color.
+
+ Args:
+ img (H x W x 1 int32 numpy tensor): Segmentation map.
+ """
+ lut = self.id2color_lut
+ lut = list(zip(*sorted([(k, v) for k, v in lut.items()])))[1]
+ lut = np.array(lut, dtype=np.uint32)
+ rgb = lut[img]
+ rgb = self.uint32_to_4uint8(rgb)[..., :3]
+
+ return rgb
+
+
+def rand_crop(cam_c, cam_res, target_res):
+ r"""Produces a new cam_c so that the effect of rendering with the new cam_c and target_res is the same as rendering
+ with the old parameters and then crop out target_res.
+ """
+ d0 = np.random.randint(cam_res[0] - target_res[0] + 1)
+ d1 = np.random.randint(cam_res[1] - target_res[1] + 1)
+ cam_c = [cam_c[0]-d0, cam_c[1]-d1]
+ return cam_c
+
+
+def segmask_smooth(seg_mask, kernel_size=7):
+ labels = F.avg_pool2d(seg_mask, kernel_size, 1, kernel_size//2)
+    onehot_idx = torch.argmax(labels, dim=1, keepdim=True)
+ labels.fill_(0.0)
+ labels.scatter_(1, onehot_idx, 1.0)
+ return labels
+
+
+def colormap(x, cmap='viridis'):
+    x = np.nan_to_num(x, nan=np.nan, posinf=np.nan, neginf=np.nan)  # keep NaNs, map +/-inf to NaN
+ x = x - np.nanmin(x)
+ x = x / np.nanmax(x)
+ rgb = plt.get_cmap(cmap)(x)[..., :3]
+ return rgb
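
The helpers `cumsum_exclusive` and `volum_rendering_relu` implement the usual NeRF-style compositing: for each sample along a ray, the probability of the ray stopping there is `(1 - exp(-sigma * dist))` times the probability it has not stopped earlier. A tiny numeric sketch (values are illustrative):

```python
import torch
from imaginaire.model_utils.gancraft.mc_utils import cumsum_exclusive, volum_rendering_relu

print(cumsum_exclusive(torch.ones(4), dim=0))   # tensor([0., 1., 2., 3.])

sigma = torch.tensor([10.0, 10.0, 0.0, 10.0])   # densities along one ray
dists = torch.full((4,), 0.1)                   # segment lengths
w = volum_rendering_relu(sigma, dists, dim=0)
print(w)                                        # ~[0.632, 0.233, 0.000, 0.086]
print(w.sum())                                  # ~0.95; the remainder is the "escape" (sky) mass
```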
diff --git a/imaginaire/model_utils/gancraft/reduced_coco_lbls.csv b/imaginaire/model_utils/gancraft/reduced_coco_lbls.csv
new file mode 100644
index 0000000000000000000000000000000000000000..c82cc05572bbace78643911e9789f4f2cfd15f0e
--- /dev/null
+++ b/imaginaire/model_utils/gancraft/reduced_coco_lbls.csv
@@ -0,0 +1,12 @@
+ignore
+sky
+tree
+dirt
+flower
+grass
+gravel
+water
+rock
+stone
+sand
+snow
\ No newline at end of file
diff --git a/imaginaire/model_utils/gancraft/voxlib/Makefile b/imaginaire/model_utils/gancraft/voxlib/Makefile
new file mode 100644
index 0000000000000000000000000000000000000000..de903af09d2feda89118edc743048d70ce963d24
--- /dev/null
+++ b/imaginaire/model_utils/gancraft/voxlib/Makefile
@@ -0,0 +1,11 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+
+all:
+ python setup.py build_ext --inplace
+ python setup.py install
+
+clean:
+ rm -rf *.o *.a *.so test build
diff --git a/imaginaire/model_utils/gancraft/voxlib/__init__.py b/imaginaire/model_utils/gancraft/voxlib/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5fce15c92b99ef0ad9feaeb31d75664a0971b385
--- /dev/null
+++ b/imaginaire/model_utils/gancraft/voxlib/__init__.py
@@ -0,0 +1,7 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+from .positional_encoding import positional_encoding
+from .sp_trilinear import sparse_trilinear_interp_worldcoord
+from voxlib import ray_voxel_intersection_perspective
diff --git a/imaginaire/model_utils/gancraft/voxlib/positional_encoding.py b/imaginaire/model_utils/gancraft/voxlib/positional_encoding.py
new file mode 100644
index 0000000000000000000000000000000000000000..ef95d0bd47103233e1c4f32c70dd5b463c9eac1d
--- /dev/null
+++ b/imaginaire/model_utils/gancraft/voxlib/positional_encoding.py
@@ -0,0 +1,63 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import torch
+from torch.autograd import Function
+import voxlib
+
+# Cheatsheet:
+# mark_dirty() must be used to mark any input that is modified inplace by the forward function.
+# mark_non_differentiable()
+
+
+class PositionalEncodingFunction(Function):
+ @staticmethod
+ def forward(ctx, in_feature, pe_degrees, dim, incl_orig):
+ out_feature = voxlib.positional_encoding(in_feature, pe_degrees, dim, incl_orig)
+
+ ctx.save_for_backward(out_feature)
+ ctx.pe_degrees = pe_degrees
+ ctx.dim = dim
+ ctx.incl_orig = incl_orig
+
+ return out_feature
+
+ @staticmethod
+ def backward(ctx, out_feature_grad):
+ out_feature, = ctx.saved_tensors
+
+ # torch::Tensor positional_encoding_backward(const torch::Tensor& out_feature_grad,
+ # const torch::Tensor& out_feature, int ndegrees, int dim, bool incl_orig) {
+ in_feature_grad = voxlib.positional_encoding_backward(
+ out_feature_grad, out_feature, ctx.pe_degrees, ctx.dim, ctx.incl_orig)
+
+ return in_feature_grad, None, None, None
+
+
+def positional_encoding(in_feature, pe_degrees, dim=-1, incl_orig=False):
+ return PositionalEncodingFunction.apply(in_feature, pe_degrees, dim, incl_orig)
+
+# input: N, C
+# output: N, 2*pe_degrees*C (plus C more if incl_orig=True)
+
+
+def positional_encoding_pt(pts, pe_degrees, dim=-1, incl_orig=False):
+ import numpy as np
+ pe_stor = []
+ for i in range(pe_degrees):
+ pe_stor.append(torch.sin(pts * np.pi * 2 ** i))
+ pe_stor.append(torch.cos(pts * np.pi * 2 ** i))
+ if incl_orig:
+ pe_stor.append(pts)
+ pe = torch.cat(pe_stor, dim=dim)
+ return pe
+
+
+if __name__ == '__main__':
+ x = torch.rand(384, 512, 5, 48).cuda() * 1024
+ y = positional_encoding_pt(x, 4, incl_orig=True)
+ y2 = positional_encoding(x, 4, incl_orig=True)
+
+ print(torch.abs(y - y2))
+ print(torch.allclose(y, y2, rtol=1e-05, atol=1e-05))
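For reference, the fused op concatenates `sin` and `cos` at `pe_degrees` frequencies (optionally followed by the raw input), so a `C`-channel input becomes `C * (2 * pe_degrees + incl_orig)` channels. A CPU-only sketch that mirrors `positional_encoding_pt` above, with illustrative sizes:

```python
# Shape check for the pure-PyTorch reference; no CUDA extension needed.
import numpy as np
import torch

def positional_encoding_pt(pts, pe_degrees, dim=-1, incl_orig=False):
    pe_stor = []
    for i in range(pe_degrees):
        pe_stor.append(torch.sin(pts * np.pi * 2 ** i))  # sin at frequency pi * 2^i
        pe_stor.append(torch.cos(pts * np.pi * 2 ** i))  # cos at frequency pi * 2^i
    if incl_orig:
        pe_stor.append(pts)                              # keep the raw coordinates
    return torch.cat(pe_stor, dim=dim)

x = torch.rand(8, 3)                                     # N=8 entries, C=3 channels
y = positional_encoding_pt(x, pe_degrees=4, incl_orig=True)
print(y.shape)                                           # torch.Size([8, 27]) = 3 * (2*4 + 1)
```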
diff --git a/imaginaire/model_utils/gancraft/voxlib/positional_encoding_kernel.cu b/imaginaire/model_utils/gancraft/voxlib/positional_encoding_kernel.cu
new file mode 100644
index 0000000000000000000000000000000000000000..278fc165991d70aae549dd376b979f53703a0647
--- /dev/null
+++ b/imaginaire/model_utils/gancraft/voxlib/positional_encoding_kernel.cu
@@ -0,0 +1,285 @@
+// Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+//
+// This work is made available under the Nvidia Source Code License-NC.
+// To view a copy of this license, check out LICENSE.md
+
+#include
+
+#include
+#include
+#include
+#include
+
+#include
+#include
+#include
+#include
+
+
+#include
+#include
+#include
+
+
+#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
+#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
+#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
+
+
+struct PE_Params {
+ int ndegrees;
+ int pre_size;
+ int post_size;
+ bool incl_orig;
+};
+
+// const int TILE_DIM_X = 16; // channel dim
+// const int TILE_DIM_Y = 64; // entry dim
+// dim3 dimGrid((p.post_size+TILE_DIM_X-1)/TILE_DIM_X, (p.pre_size+TILE_DIM_Y-1)/TILE_DIM_Y, 1);
+// dim3 dimBlock(TILE_DIM_X, TILE_DIM_Y, 1);
+template <int TILE_DIM_X, int TILE_DIM_Y, int DUP_Y>
+__global__ void positional_encoding_kernel(
+ float* __restrict__ out_feature,
+ const float* __restrict__ in_feature, const PE_Params p) {
+
+ const int idx_feat = blockIdx.x * TILE_DIM_X + threadIdx.x;
+ const int idx_entry_base = blockIdx.y * TILE_DIM_Y * DUP_Y + threadIdx.y * DUP_Y;
+ if (idx_feat >= p.post_size) {
+ return;
+ }
+
+ int stride = p.ndegrees*2;
+ if (p.incl_orig) {
+ stride += 1;
+ }
+
+ for (int j=0; j= p.pre_size) {
+ return;
+ }
+ float data = in_feature[idx_entry*p.post_size + idx_feat];
+
+ for (int i=0; i
+__global__ void positional_encoding_backward_kernel(
+ float* __restrict__ in_feature_grad,
+ const float* __restrict__ out_feature_grad, const float* __restrict__ out_feature, const PE_Params p) {
+
+ int idx_feat = blockIdx.x * TILE_DIM_X + threadIdx.x;
+ const int idx_entry_base = blockIdx.y * TILE_DIM_Y * DUP_Y + threadIdx.y * DUP_Y;
+
+ if (idx_feat >= p.post_size) {
+ return;
+ }
+
+ int stride = p.ndegrees*2;
+ if (p.incl_orig) {
+ stride += 1;
+ }
+
+ for (int j=0; j= p.pre_size) {
+ return;
+ }
+
+ float grad = 0.0f;
+ for (int i=0; i
+torch::Tensor positional_encoding_cuda(const torch::Tensor& in_feature, int ndegrees, int dim, bool incl_orig) {
+ CHECK_CUDA(in_feature);
+
+ int curDevice = -1;
+ cudaGetDevice(&curDevice);
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
+ torch::Device device = in_feature.device();
+
+ assert(in_feature.dtype() == torch::kFloat32);
+
+ // Handle negative index
+ if (dim < 0) {
+ dim = in_feature.dim() + dim;
+ }
+ assert(dim >= 0 && dim < in_feature.dim());
+
+ // No need to be contiguous. Input and output has the same memory layout.
+ CHECK_CONTIGUOUS(in_feature);
+
+ PE_Params p;
+ p.ndegrees = ndegrees;
+ p.incl_orig = incl_orig;
+
+ // This only works for contiguous tensors...
+ int pre_size = 1;
+ int post_size = 1;
+ for (int i=0; i out_feature_shape;
+ for (int i=0; i Each thread handle a single post_size
+ // Case 2: Concat at the middle (post_size > pre_size) --> Each thread handle
+ const int TILE_DIM_X = 16; // channel dim
+ const int TILE_DIM_Y = 64; // entry dim
+ //const int DUP_Y = 4; // Each thread handle multiple entries to save threads
+ const int DUP_Y = 8; // DGXA 64 samples per ray @ 256x256
+ dim3 dimGrid((p.post_size+TILE_DIM_X-1)/TILE_DIM_X, (p.pre_size+(TILE_DIM_Y*DUP_Y)-1)/(TILE_DIM_Y*DUP_Y), 1);
+ dim3 dimBlock(TILE_DIM_X, TILE_DIM_Y, 1);
+ positional_encoding_kernel<TILE_DIM_X, TILE_DIM_Y, DUP_Y><<<dimGrid, dimBlock, 0, stream>>>(
+ out_feature.data_ptr<float>(),
+ in_feature.data_ptr<float>(), p
+ );
+
+ THCudaCheck(cudaGetLastError());
+ return out_feature;
+}
+
+//in_feature_grad = voxrender_op.positional_encoding_backward(out_feature_grad, out_feature, ctx.pe_degrees, ctx.dim, ctx.incl_orig);
+// Input:
+// out_feature_grad: float32 [..., N*ndegree*2+incl_orig, ...]
+// out_feature: float32 [..., N*ndegree*2+incl_orig, ...]
+// ndegrees: int32 Degrees of PE encoding
+// dim: int32 Dimension to concatenate
+// incl_orig: bool Whether to include original feature vector or not
+// Output:
+// in_feature_grad: float32 [..., N, ...]
+// std::vector
+torch::Tensor positional_encoding_backward_cuda(const torch::Tensor& out_feature_grad_, const torch::Tensor& out_feature, int ndegrees, int dim, bool incl_orig) {
+ CHECK_CUDA(out_feature_grad_);
+ CHECK_CUDA(out_feature);
+
+ const torch::Tensor out_feature_grad = out_feature_grad_.contiguous();
+
+ int curDevice = -1;
+ cudaGetDevice(&curDevice);
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
+ torch::Device device = out_feature_grad.device();
+
+ assert(out_feature_grad.dtype() == torch::kFloat32);
+ assert(out_feature.dtype() == torch::kFloat32);
+ assert(out_feature_grad.sizes() == out_feature.sizes());
+
+ // Handle negative index
+ if (dim < 0) {
+ dim = out_feature.dim() + dim;
+ }
+ assert(dim >= 0 && dim < out_feature.dim());
+
+ CHECK_CONTIGUOUS(out_feature_grad);
+ CHECK_CONTIGUOUS(out_feature);
+
+ PE_Params p;
+ p.ndegrees = ndegrees;
+ p.incl_orig = incl_orig;
+
+ int expansion_factor = ndegrees*2;
+ if (incl_orig) {
+ expansion_factor += 1;
+ }
+ // This only works for contiguous tensors...
+ int pre_size = 1;
+ int post_size = 1;
+ for (int i=0; i out_feature_shape;
+ for (int i=0; i Each thread handle a single post_size
+ // Case 2: Concat at the middle (post_size > pre_size) --> Each thread handle
+ const int TILE_DIM_X = 16; // channel dim
+ const int TILE_DIM_Y = 64; // entry dim
+ //const int DUP_Y = 4; // Nothing to amortize
+ const int DUP_Y = 8; // DGXA
+ dim3 dimGrid((p.post_size+TILE_DIM_X-1)/TILE_DIM_X, (p.pre_size+(TILE_DIM_Y*DUP_Y)-1)/(TILE_DIM_Y*DUP_Y), 1);
+ dim3 dimBlock(TILE_DIM_X, TILE_DIM_Y, 1);
+ positional_encoding_backward_kernel<TILE_DIM_X, TILE_DIM_Y, DUP_Y><<<dimGrid, dimBlock, 0, stream>>>(
+ in_feature_grad.data_ptr<float>(),
+ out_feature_grad.data_ptr<float>(), out_feature.data_ptr<float>(), p
+ );
+
+ THCudaCheck(cudaGetLastError());
+
+ return in_feature_grad;
+}
diff --git a/imaginaire/model_utils/gancraft/voxlib/ray_voxel_intersection.cu b/imaginaire/model_utils/gancraft/voxlib/ray_voxel_intersection.cu
new file mode 100644
index 0000000000000000000000000000000000000000..7ef22dc309e2eb6d944c50d917235f0c62219cb6
--- /dev/null
+++ b/imaginaire/model_utils/gancraft/voxlib/ray_voxel_intersection.cu
@@ -0,0 +1,325 @@
+// Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+//
+// This work is made available under the Nvidia Source Code License-NC.
+// To view a copy of this license, check out LICENSE.md
+//
+// The ray marching algorithm used in this file is a variant of the modified Bresenham method:
+// http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.42.3443&rep=rep1&type=pdf
+// Search for "voxel traversal algorithm" for related information
+
+#include
+
+#include
+#include
+#include
+#include
+
+#include
+#include
+#include
+#include
+#include
+
+//#include
+#include
+#include
+#include
+
+#include "voxlib_common.h"
+
+struct RVIP_Params {
+ int voxel_dims[3];
+ int voxel_strides[3];
+ int max_samples;
+ int img_dims[2];
+ // Camera parameters
+ float cam_ori[3];
+ float cam_fwd[3];
+ float cam_side[3];
+ float cam_up[3];
+ float cam_c[2];
+ float cam_f;
+ //unsigned long seed;
+};
+
+/*
+ out_voxel_id: torch CUDA int32 [ img_dims[0], img_dims[1], max_samples, 1]
+ out_depth: torch CUDA float [2, img_dims[0], img_dims[1], max_samples, 1]
+ out_raydirs: torch CUDA float [ img_dims[0], img_dims[1], 1, 3]
+ Image coordinates refer to the center of the pixel
+ [0, 0, 0] at voxel coordinate is at the corner of the corner block (instead of at the center)
+*/
+template <int TILE_DIM>
+static __global__ void ray_voxel_intersection_perspective_kernel(int32_t* __restrict__ out_voxel_id, float* __restrict__ out_depth, float* __restrict__ out_raydirs,
+const int32_t* __restrict__ in_voxel, const RVIP_Params p) {
+
+ int img_coords[2];
+ img_coords[1] = blockIdx.x*TILE_DIM+threadIdx.x;
+ img_coords[0] = blockIdx.y*TILE_DIM+threadIdx.y;
+ if (img_coords[0] >= p.img_dims[0] || img_coords[1] >= p.img_dims[1]) {
+ return;
+ }
+ int pix_index = img_coords[0] * p.img_dims[1] + img_coords[1];
+
+ // Calculate ray origin and direction
+ float rayori[3], raydir[3];
+ rayori[0] = p.cam_ori[0];
+ rayori[1] = p.cam_ori[1];
+ rayori[2] = p.cam_ori[2];
+
+ // Camera intrinsics
+ float ndc_imcoords[2];
+ ndc_imcoords[0] = p.cam_c[0] - (float)img_coords[0]; // Flip height
+ ndc_imcoords[1] = (float)img_coords[1] - p.cam_c[1];
+
+ raydir[0] = p.cam_up[0] * ndc_imcoords[0] + p.cam_side[0] * ndc_imcoords[1] + p.cam_fwd[0] * p.cam_f;
+ raydir[1] = p.cam_up[1] * ndc_imcoords[0] + p.cam_side[1] * ndc_imcoords[1] + p.cam_fwd[1] * p.cam_f;
+ raydir[2] = p.cam_up[2] * ndc_imcoords[0] + p.cam_side[2] * ndc_imcoords[1] + p.cam_fwd[2] * p.cam_f;
+ normalize(raydir);
+
+ // Save out_raydirs
+ out_raydirs[pix_index*3] = raydir[0];
+ out_raydirs[pix_index*3+1] = raydir[1];
+ out_raydirs[pix_index*3+2] = raydir[2];
+
+ float axis_t[3];
+ int axis_int[3];
+ //int axis_intbound[3];
+
+ // Current voxel
+ axis_int[0] = floorf(rayori[0]);
+ axis_int[1] = floorf(rayori[1]);
+ axis_int[2] = floorf(rayori[2]);
+
+ #pragma unroll
+ for (int i=0; i<3; i++) {
+ if (raydir[i] > 0) {
+ // Initial t value
+ // Handle boundary case where rayori[i] is a whole number. Always round Up for the next block
+ //axis_t[i] = (ceilf(nextafterf(rayori[i], HUGE_VALF)) - rayori[i]) / raydir[i];
+ axis_t[i] = ((float)(axis_int[i]+1) - rayori[i]) / raydir[i];
+ } else if (raydir[i] < 0) {
+ axis_t[i] = ((float)axis_int[i] - rayori[i]) / raydir[i];
+ } else {
+ axis_t[i] = HUGE_VALF;
+ }
+ }
+
+ // Fused raymarching and sampling
+ bool quit = false;
+ for (int cur_plane=0; cur_plane < p.max_samples; cur_plane++) { // Last cycle is for calculating p2
+ float t = nanf("0");
+ float t2 = nanf("0");
+ int32_t blk_id = 0;
+ // Find the next intersection
+ while (!quit) {
+ // Find the next smallest t
+ float tnow;
+ /*
+ #pragma unroll
+ for (int i=0; i<3; i++) {
+ if (axis_t[i] <= axis_t[(i+1)%3] && axis_t[i] <= axis_t[(i+2)%3]) {
+ // Update current t
+ tnow = axis_t[i];
+ // Update t candidates
+ if (raydir[i] > 0) {
+ axis_int[i] += 1;
+ if (axis_int[i] >= p.voxel_dims[i]) {
+ quit = true;
+ }
+ axis_t[i] = ((float)(axis_int[i]+1) - rayori[i]) / raydir[i];
+ } else {
+ axis_int[i] -= 1;
+ if (axis_int[i] < 0) {
+ quit = true;
+ }
+ axis_t[i] = ((float)axis_int[i] - rayori[i]) / raydir[i];
+ }
+ break; // Avoid advancing multiple steps as axis_t is updated
+ }
+ }
+ */
+ // Hand unroll
+ if (axis_t[0] <= axis_t[1] && axis_t[0] <= axis_t[2]) {
+ // Update current t
+ tnow = axis_t[0];
+ // Update t candidates
+ if (raydir[0] > 0) {
+ axis_int[0] += 1;
+ if (axis_int[0] >= p.voxel_dims[0]) {
+ quit = true;
+ }
+ axis_t[0] = ((float)(axis_int[0]+1) - rayori[0]) / raydir[0];
+ } else {
+ axis_int[0] -= 1;
+ if (axis_int[0] < 0) {
+ quit = true;
+ }
+ axis_t[0] = ((float)axis_int[0] - rayori[0]) / raydir[0];
+ }
+ } else if (axis_t[1] <= axis_t[2]) {
+ tnow = axis_t[1];
+ if (raydir[1] > 0) {
+ axis_int[1] += 1;
+ if (axis_int[1] >= p.voxel_dims[1]) {
+ quit = true;
+ }
+ axis_t[1] = ((float)(axis_int[1]+1) - rayori[1]) / raydir[1];
+ } else {
+ axis_int[1] -= 1;
+ if (axis_int[1] < 0) {
+ quit = true;
+ }
+ axis_t[1] = ((float)axis_int[1] - rayori[1]) / raydir[1];
+ }
+ } else {
+ tnow = axis_t[2];
+ if (raydir[2] > 0) {
+ axis_int[2] += 1;
+ if (axis_int[2] >= p.voxel_dims[2]) {
+ quit = true;
+ }
+ axis_t[2] = ((float)(axis_int[2]+1) - rayori[2]) / raydir[2];
+ } else {
+ axis_int[2] -= 1;
+ if (axis_int[2] < 0) {
+ quit = true;
+ }
+ axis_t[2] = ((float)axis_int[2] - rayori[2]) / raydir[2];
+ }
+ }
+
+ if (quit) {
+ break;
+ }
+
+ // Skip empty space
+ // Could there be deadlock if the ray direction is away from the world?
+ if (axis_int[0] < 0 || axis_int[0] >= p.voxel_dims[0] || axis_int[1] < 0 || axis_int[1] >= p.voxel_dims[1] || axis_int[2] < 0 || axis_int[2] >= p.voxel_dims[2]) {
+ continue;
+ }
+
+ // Test intersection using voxel grid
+ blk_id = in_voxel[axis_int[0]*p.voxel_strides[0] + axis_int[1]*p.voxel_strides[1] + axis_int[2]*p.voxel_strides[2]];
+ if (blk_id == 0) {
+ continue;
+ }
+
+ // Now that there is an intersection
+ t = tnow;
+ // Calculate t2
+ /*
+ #pragma unroll
+ for (int i=0; i<3; i++) {
+ if (axis_t[i] <= axis_t[(i+1)%3] && axis_t[i] <= axis_t[(i+2)%3]) {
+ t2 = axis_t[i];
+ break;
+ }
+ }
+ */
+ // Hand unroll
+ if (axis_t[0] <= axis_t[1] && axis_t[0] <= axis_t[2]) {
+ t2 = axis_t[0];
+ } else if (axis_t[1] <= axis_t[2]) {
+ t2 = axis_t[1];
+ } else {
+ t2 = axis_t[2];
+ }
+ break;
+ } // while !quit (ray marching loop)
+
+ out_depth[pix_index*p.max_samples+cur_plane] = t;
+ out_depth[p.img_dims[0]*p.img_dims[1]*p.max_samples + pix_index*p.max_samples+cur_plane] = t2;
+ out_voxel_id[pix_index*p.max_samples+cur_plane] = blk_id;
+ } // cur_plane
+}
+
+
+/*
+ out:
+ out_voxel_id: torch CUDA int32 [ img_dims[0], img_dims[1], max_samples, 1]
+ out_depth: torch CUDA float [2, img_dims[0], img_dims[1], max_samples, 1]
+ out_raydirs: torch CUDA float [ img_dims[0], img_dims[1], 1, 3]
+ in:
+ in_voxel: torch CUDA int32 [X, Y, Z] [40, 512, 512]
+ cam_ori: torch float [3]
+ cam_dir: torch float [3]
+ cam_up: torch float [3]
+ cam_f: float
+ cam_c: int [2]
+ img_dims: int [2]
+ max_samples: int
+*/
+std::vector ray_voxel_intersection_perspective_cuda(const torch::Tensor& in_voxel, const torch::Tensor& cam_ori, const torch::Tensor& cam_dir, const torch::Tensor& cam_up, float cam_f, const std::vector& cam_c, const std::vector& img_dims, int max_samples) {
+ CHECK_CUDA(in_voxel);
+
+ int curDevice = -1;
+ cudaGetDevice(&curDevice);
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
+ torch::Device device = in_voxel.device();
+
+ //assert(in_voxel.dtype() == torch::kU8);
+ assert(in_voxel.dtype() == torch::kInt32); // Minecraft compatibility
+ assert(in_voxel.dim() == 3);
+ assert(cam_ori.dtype() == torch::kFloat32);
+ assert(cam_ori.numel() == 3);
+ assert(cam_dir.dtype() == torch::kFloat32);
+ assert(cam_dir.numel() == 3);
+ assert(cam_up.dtype() == torch::kFloat32);
+ assert(cam_up.numel() == 3);
+ assert(img_dims.size() == 2);
+
+ RVIP_Params p;
+
+ // Calculate camera rays
+ const torch::Tensor cam_ori_c = cam_ori.cpu();
+ const torch::Tensor cam_dir_c = cam_dir.cpu();
+ const torch::Tensor cam_up_c = cam_up.cpu();
+
+ // Get the coordinate frame of camera space in world space
+ normalize(p.cam_fwd, cam_dir_c.data_ptr<float>());
+ cross(p.cam_side, p.cam_fwd, cam_up_c.data_ptr<float>());
+ normalize(p.cam_side);
+ cross(p.cam_up, p.cam_side, p.cam_fwd);
+ normalize(p.cam_up); // Not absolutely necessary as both vectors are normalized. But just in case...
+
+ copyarr(p.cam_ori, cam_ori_c.data_ptr<float>());
+
+ p.cam_f = cam_f;
+ p.cam_c[0] = cam_c[0];
+ p.cam_c[1] = cam_c[1];
+ p.max_samples = max_samples;
+ //printf("[Renderer] max_dist: %ld\n", max_dist);
+
+ p.voxel_dims[0] = in_voxel.size(0);
+ p.voxel_dims[1] = in_voxel.size(1);
+ p.voxel_dims[2] = in_voxel.size(2);
+ p.voxel_strides[0] = in_voxel.stride(0);
+ p.voxel_strides[1] = in_voxel.stride(1);
+ p.voxel_strides[2] = in_voxel.stride(2);
+
+ //printf("[Renderer] Voxel resolution: %ld, %ld, %ld\n", p.voxel_dims[0], p.voxel_dims[1], p.voxel_dims[2]);
+
+ p.img_dims[0] = img_dims[0];
+ p.img_dims[1] = img_dims[1];
+
+ // Create output tensors
+ // For Minecraft Seg Mask
+ torch::Tensor out_voxel_id = torch::empty({p.img_dims[0], p.img_dims[1], p.max_samples, 1}, torch::TensorOptions().dtype(torch::kInt32).device(device));
+
+ torch::Tensor out_depth;
+ // Produce two sets of localcoords, one for entry point, the other one for exit point. They share the same corner_ids.
+ out_depth = torch::empty({2, p.img_dims[0], p.img_dims[1], p.max_samples, 1}, torch::TensorOptions().dtype(torch::kFloat32).device(device));
+
+ torch::Tensor out_raydirs = torch::empty({p.img_dims[0], p.img_dims[1], 1, 3}, torch::TensorOptions().dtype(torch::kFloat32).device(device).requires_grad(false));
+
+ const int TILE_DIM = 8;
+ dim3 dimGrid((p.img_dims[1]+TILE_DIM-1)/TILE_DIM, (p.img_dims[0]+TILE_DIM-1)/TILE_DIM, 1);
+ dim3 dimBlock(TILE_DIM, TILE_DIM, 1);
+
+ ray_voxel_intersection_perspective_kernel<TILE_DIM><<<dimGrid, dimBlock, 0, stream>>>(
+ out_voxel_id.data_ptr<int32_t>(), out_depth.data_ptr<float>(), out_raydirs.data_ptr<float>(), in_voxel.data_ptr<int32_t>(), p
+ );
+
+ return {out_voxel_id, out_depth, out_raydirs};
+}
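The hand-unrolled loop in this kernel is the usual 3D voxel-traversal bookkeeping: keep, per axis, the ray parameter `t` of the next grid-plane crossing, always step along the axis with the smallest `t`, and record entry and exit depths when an occupied voxel is reached. A single-ray NumPy sketch of that idea (illustrative only; `traverse_voxels` is a hypothetical helper, not the CUDA kernel):

```python
import numpy as np

def traverse_voxels(grid, ray_ori, ray_dir, max_hits=4):
    """Return (t_entry, t_exit, voxel_index) for occupied voxels along a ray."""
    ray_dir = ray_dir / np.linalg.norm(ray_dir)
    vox = np.floor(ray_ori).astype(int)            # voxel containing the ray origin
    axis_t = np.empty(3)                           # t of the next grid-plane crossing per axis
    for i in range(3):
        if ray_dir[i] > 0:
            axis_t[i] = (vox[i] + 1 - ray_ori[i]) / ray_dir[i]
        elif ray_dir[i] < 0:
            axis_t[i] = (vox[i] - ray_ori[i]) / ray_dir[i]
        else:
            axis_t[i] = np.inf

    hits = []
    while len(hits) < max_hits:
        i = int(np.argmin(axis_t))                 # axis of the nearest crossing
        t_entry = axis_t[i]
        vox[i] += 1 if ray_dir[i] > 0 else -1      # step into the neighbouring voxel
        if not (0 <= vox[i] < grid.shape[i]):      # ray left the volume
            break
        axis_t[i] = ((vox[i] + 1 - ray_ori[i]) / ray_dir[i] if ray_dir[i] > 0
                     else (vox[i] - ray_ori[i]) / ray_dir[i])
        if grid[tuple(vox)]:                       # occupied voxel: record entry/exit depths
            hits.append((t_entry, float(axis_t.min()), tuple(int(v) for v in vox)))
    return hits

grid = np.zeros((8, 8, 8), dtype=bool)
grid[4, 4, 4] = True
print(traverse_voxels(grid, np.array([0.5, 0.5, 0.5]), np.array([1.0, 1.0, 1.0])))
```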
diff --git a/imaginaire/model_utils/gancraft/voxlib/setup.py b/imaginaire/model_utils/gancraft/voxlib/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..1eca848211370c8ddf9dda55d5c67804a73061e9
--- /dev/null
+++ b/imaginaire/model_utils/gancraft/voxlib/setup.py
@@ -0,0 +1,25 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+from setuptools import setup
+from torch.utils.cpp_extension import BuildExtension, CUDAExtension
+
+cxx_args = ['-fopenmp']
+nvcc_args = []
+
+setup(
+ name='voxrender',
+ ext_modules=[
+ CUDAExtension('voxlib', [
+ 'voxlib.cpp',
+ 'ray_voxel_intersection.cu',
+ 'sp_trilinear_worldcoord_kernel.cu',
+ 'positional_encoding_kernel.cu'
+ ],
+ extra_compile_args={'cxx': cxx_args, 'nvcc': nvcc_args}
+ )
+ ],
+ cmdclass={
+ 'build_ext': BuildExtension
+ })
diff --git a/imaginaire/model_utils/gancraft/voxlib/sp_trilinear.py b/imaginaire/model_utils/gancraft/voxlib/sp_trilinear.py
new file mode 100644
index 0000000000000000000000000000000000000000..1bad56fb23f6b8e2a8e41573f8b6b85f9a5693f1
--- /dev/null
+++ b/imaginaire/model_utils/gancraft/voxlib/sp_trilinear.py
@@ -0,0 +1,35 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+from torch.autograd import Function
+import voxlib
+
+"""
+It takes world coordinate as input instead of block-local coordinate. Corner IDs are looked up on-the-fly to
+save memory.
+"""
+
+
+class SparseTrilinearWorldCoordFunction(Function):
+ @staticmethod
+ def forward(ctx, in_feature, corner_lut_t, in_worldcoord, ign_zero):
+
+ out_feature = voxlib.sp_trilinear_worldcoord(in_feature, corner_lut_t, in_worldcoord, ign_zero, -1)
+ ctx.ign_zero = ign_zero
+ ctx.save_for_backward(in_feature, corner_lut_t, in_worldcoord)
+
+ return out_feature
+
+ @staticmethod
+ def backward(ctx, out_feature_grad):
+ in_feature, corner_lut_t, in_worldcoord = ctx.saved_tensors
+
+ assert ctx.needs_input_grad[2] is False
+ in_feature_grad, = voxlib.sp_trilinear_worldcoord_backward(
+ out_feature_grad, in_feature, corner_lut_t, in_worldcoord, ctx.ign_zero, False)
+ return in_feature_grad, None, None, None, None
+
+
+def sparse_trilinear_interp_worldcoord(in_feature, corner_lut_t, in_worldcoord, ign_zero=False):
+ return SparseTrilinearWorldCoordFunction.apply(in_feature, corner_lut_t, in_worldcoord, ign_zero)
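Under the hood, the CUDA op turns each world coordinate into eight trilinear corner weights (products of the fractional offsets along x, y, z) and blends the feature rows found through the corner lookup table. A dense, pure-PyTorch sketch of the weighting only; `trilinear_interp_dense` and its `[X, Y, Z, C]` feature grid are assumptions for the example, not the sparse layout used by the op:

```python
import torch

def trilinear_interp_dense(corner_features, world_coords):
    """Interpolate features at world_coords ([..., 3]) from a dense [X, Y, Z, C] grid."""
    x, y, z = world_coords.unbind(-1)
    x0, y0, z0 = x.floor().long(), y.floor().long(), z.floor().long()
    lx, ly, lz = x - x0, y - y0, z - z0                  # fractional (block-local) coordinates
    out = 0
    for dx in (0, 1):
        for dy in (0, 1):
            for dz in (0, 1):
                w = ((lx if dx else 1 - lx)
                     * (ly if dy else 1 - ly)
                     * (lz if dz else 1 - lz))           # one of the eight corner weights
                out = out + w.unsqueeze(-1) * corner_features[x0 + dx, y0 + dy, z0 + dz]
    return out

feats = torch.rand(4, 4, 4, 16)                          # dense [X, Y, Z, C] grid (toy size)
coords = torch.tensor([[1.25, 2.5, 0.75]])               # [..., 3] world coordinates
print(trilinear_interp_dense(feats, coords).shape)       # torch.Size([1, 16])
```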
diff --git a/imaginaire/model_utils/gancraft/voxlib/sp_trilinear_worldcoord_kernel.cu b/imaginaire/model_utils/gancraft/voxlib/sp_trilinear_worldcoord_kernel.cu
new file mode 100644
index 0000000000000000000000000000000000000000..403d01fbda3f528211dd262d019a606f1f8f1640
--- /dev/null
+++ b/imaginaire/model_utils/gancraft/voxlib/sp_trilinear_worldcoord_kernel.cu
@@ -0,0 +1,527 @@
+// Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+//
+// This work is made available under the Nvidia Source Code License-NC.
+// To view a copy of this license, check out LICENSE.md
+//
+// Fast routine for sparse tri-linear interpolation of high dimensional features.
+// Ignore label is supported.
+
+
+#include
+
+#include
+#include
+#include
+#include
+
+#include
+#include
+#include
+
+
+#include
+#include
+#include
+
+
+#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
+#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
+#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
+
+
+struct SpTrilinear_wc_Params {
+ int in_feature_dim;
+ int in_feature_numentries;
+ int corner_lut_dims[3];
+ int corner_lut_strides[3];
+ int in_worldcoord_dims[8];
+ int in_worldcoord_strides[8];
+ int in_worldcoord_ndim;
+ int out_feature_dims[8];
+ int out_feature_strides[8];
+ bool ign_zero;
+};
+
+
+// out_feature.data_ptr(),
+// in_feature.data_ptr(), corner_lut_t.data_ptr(), in_worldcoord.data_ptr(), p
+template
+__global__ void sp_trilinear_worldcoord_kernel(
+ float* __restrict__ out_feature,
+ const float* __restrict__ in_feature, const int32_t* __restrict__ corner_lut_t, const float* __restrict__ in_worldcoord, SpTrilinear_wc_Params p) {
+
+ const int GRID_X = gridDim.y;
+ int idx_entry = blockIdx.x * TILE_DIM_Y + threadIdx.y;
+
+ // Index processing
+ //int index[7];
+ int t = idx_entry;
+ int idx_in_worldcoord = 0;
+ int idx_out_feature = 0;
+ for (int i=p.in_worldcoord_ndim-2; i>=0; i--) {
+ int idx_t = t % p.in_worldcoord_dims[i];
+ t = t / p.in_worldcoord_dims[i];
+ idx_in_worldcoord += p.in_worldcoord_strides[i] * idx_t;
+ idx_out_feature += p.out_feature_strides[i] * idx_t;
+ }
+ if (t > 0) {
+ return;
+ }
+ int stride_in_worldcoord = p.in_worldcoord_strides[p.in_worldcoord_ndim-1];
+ int stride_out_feature = p.out_feature_strides[p.in_worldcoord_ndim-1];
+
+
+ float world_coords[3];
+ world_coords[0] = in_worldcoord[idx_in_worldcoord];
+ world_coords[1] = in_worldcoord[idx_in_worldcoord+stride_in_worldcoord];
+ world_coords[2] = in_worldcoord[idx_in_worldcoord+stride_in_worldcoord*2];
+
+ float local_coords[3];
+ int vox_coords[3];
+ local_coords[0] = world_coords[0] - floorf(world_coords[0]);
+ vox_coords[0] = (int)floorf(world_coords[0]);
+ local_coords[1] = world_coords[1] - floorf(world_coords[1]);
+ vox_coords[1] = (int)floorf(world_coords[1]);
+ local_coords[2] = world_coords[2] - floorf(world_coords[2]);
+ vox_coords[2] = (int)floorf(world_coords[2]);
+
+ float interp_weight[8];
+ // 0,0,0
+ interp_weight[0] = (1.0f-local_coords[0])*(1.0f-local_coords[1])*(1.0f-local_coords[2]);
+ // 0,0,1
+ interp_weight[1] = (1.0f-local_coords[0])*(1.0f-local_coords[1])*(local_coords[2]);
+ // 0,1,0
+ interp_weight[2] = (1.0f-local_coords[0])*(local_coords[1])*(1.0f-local_coords[2]);
+ // 0,1,1
+ interp_weight[3] = (1.0f-local_coords[0])*(local_coords[1])*(local_coords[2]);
+ // 1,0,0
+ interp_weight[4] = (local_coords[0])*(1.0f-local_coords[1])*(1.0f-local_coords[2]);
+ // 1,0,1
+ interp_weight[5] = (local_coords[0])*(1.0f-local_coords[1])*(local_coords[2]);
+ // 1,1,0
+ interp_weight[6] = (local_coords[0])*(local_coords[1])*(1.0f-local_coords[2]);
+ // 1,1,1
+ interp_weight[7] = (local_coords[0])*(local_coords[1])*(local_coords[2]);
+
+ int indices[8];
+ // Hard boundary check (zero padding)
+ if (isnan(world_coords[0]) || isnan(world_coords[1]) || isnan(world_coords[2])) {
+ indices[0] = -1;indices[1] = -1;indices[2] = -1;indices[3] = -1;
+ indices[4] = -1;indices[5] = -1;indices[6] = -1;indices[7] = -1;
+ } else {
+ // Clamp to boundaries
+ int vox_coords_1[3];
+ vox_coords_1[0] = min(max(vox_coords[0]+1, 0), p.corner_lut_dims[0]-1);
+ vox_coords_1[1] = min(max(vox_coords[1]+1, 0), p.corner_lut_dims[1]-1);
+ vox_coords_1[2] = min(max(vox_coords[2]+1, 0), p.corner_lut_dims[2]-1);
+ vox_coords[0] = min(max(vox_coords[0], 0), p.corner_lut_dims[0]-1);
+ vox_coords[1] = min(max(vox_coords[1], 0), p.corner_lut_dims[1]-1);
+ vox_coords[2] = min(max(vox_coords[2], 0), p.corner_lut_dims[2]-1);
+ int idx_corner_lut;
+ // 000
+ idx_corner_lut = p.corner_lut_strides[0] * vox_coords[0] +
+ p.corner_lut_strides[1] * vox_coords[1] +
+ p.corner_lut_strides[2] * vox_coords[2];
+ indices[0] = corner_lut_t[idx_corner_lut];
+ // 001
+ idx_corner_lut = p.corner_lut_strides[0] * vox_coords[0] +
+ p.corner_lut_strides[1] * vox_coords[1] +
+ p.corner_lut_strides[2] * vox_coords_1[2];
+ indices[1] = corner_lut_t[idx_corner_lut];
+ // 010
+ idx_corner_lut = p.corner_lut_strides[0] * vox_coords[0] +
+ p.corner_lut_strides[1] * vox_coords_1[1] +
+ p.corner_lut_strides[2] * vox_coords[2];
+ indices[2] = corner_lut_t[idx_corner_lut];
+ // 011
+ idx_corner_lut = p.corner_lut_strides[0] * vox_coords[0] +
+ p.corner_lut_strides[1] * vox_coords_1[1] +
+ p.corner_lut_strides[2] * vox_coords_1[2];
+ indices[3] = corner_lut_t[idx_corner_lut];
+ // 100
+ idx_corner_lut = p.corner_lut_strides[0] * vox_coords_1[0] +
+ p.corner_lut_strides[1] * vox_coords[1] +
+ p.corner_lut_strides[2] * vox_coords[2];
+ indices[4] = corner_lut_t[idx_corner_lut];
+ // 101
+ idx_corner_lut = p.corner_lut_strides[0] * vox_coords_1[0] +
+ p.corner_lut_strides[1] * vox_coords[1] +
+ p.corner_lut_strides[2] * vox_coords_1[2];
+ indices[5] = corner_lut_t[idx_corner_lut];
+ // 110
+ idx_corner_lut = p.corner_lut_strides[0] * vox_coords_1[0] +
+ p.corner_lut_strides[1] * vox_coords_1[1] +
+ p.corner_lut_strides[2] * vox_coords[2];
+ indices[6] = corner_lut_t[idx_corner_lut];
+ // 111
+ idx_corner_lut = p.corner_lut_strides[0] * vox_coords_1[0] +
+ p.corner_lut_strides[1] * vox_coords_1[1] +
+ p.corner_lut_strides[2] * vox_coords_1[2];
+ indices[7] = corner_lut_t[idx_corner_lut];
+ }
+
+ if (p.ign_zero) {
+ // Zero indices are to be ignored
+#pragma unroll
+ for (int i=0; i<8; i++) {
+ indices[i] -= 1;
+ }
+ }
+
+ //int idx_feat = blockIdx.x * TILE_DIM_X * DUP_X + threadIdx.x;
+ int idx_feat = blockIdx.y * TILE_DIM_X + threadIdx.x;
+ for (int i=0; i= p.in_feature_dim) {
+ return;
+ }
+ float interp_feat = 0.0f;
+#pragma unroll
+ for (int j=0; j<8; j++) {
+ if (indices[j] >= 0) {
+ interp_feat = fmaf(in_feature[indices[j]*p.in_feature_dim+idx_feat], interp_weight[j], interp_feat);
+ }
+ }
+ //out_feature[idx_entry*p.in_feature_dim+idx_feat] = interp_feat;
+ out_feature[idx_out_feature+stride_out_feature*idx_feat] = interp_feat;
+ //idx_feat += TILE_DIM_X;
+ idx_feat += TILE_DIM_X * GRID_X;
+ }
+}
+
+
+//sp_trilinear_worldcoord_backward2feature_kernel<<>>(
+// in_feature_grad.data_ptr(),
+// out_feature_grad.data_ptr(), in_feature.data_ptr(), in_corner_lut.data_ptr(), in_worldcoord.data_ptr(), p
+// Backward to feature
+template
+__global__ void sp_trilinear_worldcoord_backward2feature_kernel(
+ float* __restrict__ in_feature_grad,
+ const float* __restrict__ out_feature_grad, const int32_t* __restrict__ corner_lut_t, const float* __restrict__ in_worldcoord, SpTrilinear_wc_Params p) {
+
+ const int GRID_X = gridDim.x;
+ int idx_entry = blockIdx.y * TILE_DIM_Y + threadIdx.y;
+
+ // Index processing
+ //int index[7];
+ int t = idx_entry;
+ int idx_in_worldcoord = 0;
+ int idx_out_feature = 0;
+ for (int i=p.in_worldcoord_ndim-2; i>=0; i--) {
+ int idx_t = t % p.in_worldcoord_dims[i];
+ t = t / p.in_worldcoord_dims[i];
+ //index[i] = idx_t;
+ idx_in_worldcoord += p.in_worldcoord_strides[i] * idx_t;
+ idx_out_feature += p.out_feature_strides[i] * idx_t;
+ }
+ if (t > 0) {
+ return;
+ }
+ int stride_in_worldcoord = p.in_worldcoord_strides[p.in_worldcoord_ndim-1];
+ int stride_out_feature = p.out_feature_strides[p.in_worldcoord_ndim-1];
+
+ float world_coords[3];
+ world_coords[0] = in_worldcoord[idx_in_worldcoord];
+ world_coords[1] = in_worldcoord[idx_in_worldcoord+stride_in_worldcoord];
+ world_coords[2] = in_worldcoord[idx_in_worldcoord+stride_in_worldcoord*2];
+
+ float local_coords[3];
+ int vox_coords[3];
+ local_coords[0] = world_coords[0] - floorf(world_coords[0]);
+ vox_coords[0] = (int)floorf(world_coords[0]);
+ local_coords[1] = world_coords[1] - floorf(world_coords[1]);
+ vox_coords[1] = (int)floorf(world_coords[1]);
+ local_coords[2] = world_coords[2] - floorf(world_coords[2]);
+ vox_coords[2] = (int)floorf(world_coords[2]);
+
+ float interp_weight[8];
+ // 0,0,0
+ interp_weight[0] = (1.0f-local_coords[0])*(1.0f-local_coords[1])*(1.0f-local_coords[2]);
+ // 0,0,1
+ interp_weight[1] = (1.0f-local_coords[0])*(1.0f-local_coords[1])*(local_coords[2]);
+ // 0,1,0
+ interp_weight[2] = (1.0f-local_coords[0])*(local_coords[1])*(1.0f-local_coords[2]);
+ // 0,1,1
+ interp_weight[3] = (1.0f-local_coords[0])*(local_coords[1])*(local_coords[2]);
+ // 1,0,0
+ interp_weight[4] = (local_coords[0])*(1.0f-local_coords[1])*(1.0f-local_coords[2]);
+ // 1,0,1
+ interp_weight[5] = (local_coords[0])*(1.0f-local_coords[1])*(local_coords[2]);
+ // 1,1,0
+ interp_weight[6] = (local_coords[0])*(local_coords[1])*(1.0f-local_coords[2]);
+ // 1,1,1
+ interp_weight[7] = (local_coords[0])*(local_coords[1])*(local_coords[2]);
+
+ int indices[8];
+ // Hard boundary check (zero padding)
+ if (isnan(world_coords[0]) || isnan(world_coords[1]) || isnan(world_coords[2])) {// ||
+ //vox_coords[0] < 0 || vox_coords[0] >= (p.corner_lut_dims[0]-1) ||
+ //vox_coords[1] < 0 || vox_coords[1] >= (p.corner_lut_dims[1]-1) ||
+ //vox_coords[2] < 0 || vox_coords[2] >= (p.corner_lut_dims[2]-1)) {
+ indices[0] = -1;indices[1] = -1;indices[2] = -1;indices[3] = -1;
+ indices[4] = -1;indices[5] = -1;indices[6] = -1;indices[7] = -1;
+ } else {
+ // Clamp to boundaries
+ int vox_coords_1[3];
+ vox_coords_1[0] = min(max(vox_coords[0]+1, 0), p.corner_lut_dims[0]-1);
+ vox_coords_1[1] = min(max(vox_coords[1]+1, 0), p.corner_lut_dims[1]-1);
+ vox_coords_1[2] = min(max(vox_coords[2]+1, 0), p.corner_lut_dims[2]-1);
+ vox_coords[0] = min(max(vox_coords[0], 0), p.corner_lut_dims[0]-1);
+ vox_coords[1] = min(max(vox_coords[1], 0), p.corner_lut_dims[1]-1);
+ vox_coords[2] = min(max(vox_coords[2], 0), p.corner_lut_dims[2]-1);
+ int idx_corner_lut;
+ // 000
+ idx_corner_lut = p.corner_lut_strides[0] * vox_coords[0] +
+ p.corner_lut_strides[1] * vox_coords[1] +
+ p.corner_lut_strides[2] * vox_coords[2];
+ indices[0] = corner_lut_t[idx_corner_lut];
+ // 001
+ idx_corner_lut = p.corner_lut_strides[0] * vox_coords[0] +
+ p.corner_lut_strides[1] * vox_coords[1] +
+ p.corner_lut_strides[2] * vox_coords_1[2];
+ indices[1] = corner_lut_t[idx_corner_lut];
+ // 010
+ idx_corner_lut = p.corner_lut_strides[0] * vox_coords[0] +
+ p.corner_lut_strides[1] * vox_coords_1[1] +
+ p.corner_lut_strides[2] * vox_coords[2];
+ indices[2] = corner_lut_t[idx_corner_lut];
+ // 011
+ idx_corner_lut = p.corner_lut_strides[0] * vox_coords[0] +
+ p.corner_lut_strides[1] * vox_coords_1[1] +
+ p.corner_lut_strides[2] * vox_coords_1[2];
+ indices[3] = corner_lut_t[idx_corner_lut];
+ // 100
+ idx_corner_lut = p.corner_lut_strides[0] * vox_coords_1[0] +
+ p.corner_lut_strides[1] * vox_coords[1] +
+ p.corner_lut_strides[2] * vox_coords[2];
+ indices[4] = corner_lut_t[idx_corner_lut];
+ // 101
+ idx_corner_lut = p.corner_lut_strides[0] * vox_coords_1[0] +
+ p.corner_lut_strides[1] * vox_coords[1] +
+ p.corner_lut_strides[2] * vox_coords_1[2];
+ indices[5] = corner_lut_t[idx_corner_lut];
+ // 110
+ idx_corner_lut = p.corner_lut_strides[0] * vox_coords_1[0] +
+ p.corner_lut_strides[1] * vox_coords_1[1] +
+ p.corner_lut_strides[2] * vox_coords[2];
+ indices[6] = corner_lut_t[idx_corner_lut];
+ // 111
+ idx_corner_lut = p.corner_lut_strides[0] * vox_coords_1[0] +
+ p.corner_lut_strides[1] * vox_coords_1[1] +
+ p.corner_lut_strides[2] * vox_coords_1[2];
+ indices[7] = corner_lut_t[idx_corner_lut];
+ }
+
+ if (p.ign_zero) {
+#pragma unroll
+ for (int i=0; i<8; i++) {
+ indices[i] -= 1;
+ }
+ }
+
+ //int idx_feat = blockIdx.x * TILE_DIM_X * DUP_X + threadIdx.x;
+ int idx_feat = blockIdx.x * TILE_DIM_X + threadIdx.x;
+ for (int i=0; i= p.in_feature_dim) {
+ return;
+ }
+ float grad = out_feature_grad[idx_out_feature+stride_out_feature*idx_feat];
+#pragma unroll
+ for (int j=0; j<8; j++) {
+ if (indices[j] >= 0) {
+ //indices[j]*p.in_feature_dim+idx_feat
+ atomicAdd(&in_feature_grad[indices[j]*p.in_feature_dim+idx_feat], grad * interp_weight[j]);
+ }
+ }
+ //idx_feat += TILE_DIM_X;
+ idx_feat += TILE_DIM_X * GRID_X;
+ }
+}
+
+// in_feature, corner_lut_t, in_world_coord, ign_zero=False
+// Input:
+// in_feature: float32 [M C]
+// in_corner_lut: int32 [X Y Z]
+// in_worldcoord: float32 [..., 3]
+// ---Index: int32 [..., 8], containing [0, M]. 0 is ignore label.
+// ---Coord: float32 [..., 3]
+// Output:
+// Interp. Feat: float32 [..., C]
+// std::vector
+torch::Tensor sp_trilinear_worldcoord_cuda(const torch::Tensor& in_feature, const torch::Tensor& in_corner_lut, const torch::Tensor& in_worldcoord, bool ign_zero, int channel_pos) {
+ CHECK_CUDA(in_feature);
+ CHECK_CUDA(in_corner_lut);
+ CHECK_CUDA(in_worldcoord);
+
+ int curDevice = -1;
+ cudaGetDevice(&curDevice);
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
+ torch::Device device = in_feature.device();
+
+ // assert(tensor.sizes() == std::vector{3, 4, 5});
+ assert(in_feature.dtype() == torch::kFloat32);
+ assert(in_feature.dim() == 2);
+ assert(in_corner_lut.dtype() == torch::kInt32);
+ assert(in_corner_lut.dim() == 3);
+ assert(in_worldcoord.dtype() == torch::kFloat32);
+ assert(in_worldcoord.size(-1) == 3);
+ assert(in_worldcoord.dim() <= 8);
+
+ CHECK_CONTIGUOUS(in_feature);
+ //CHECK_CONTIGUOUS(in_corner_lut); // Will still run correctly, but performance will suffer.
+ //CHECK_CONTIGUOUS(in_worldcoord);
+
+ //int channel_pos = -1; // -1 for HWC, -3 for CHW
+ if (channel_pos < 0) {
+ channel_pos += in_worldcoord.dim();
+ }
+ assert(channel_pos >= 0 && channel_pos < in_worldcoord.dim());
+
+ SpTrilinear_wc_Params p;
+ p.in_feature_dim = in_feature.size(1);
+ p.in_feature_numentries = in_feature.size(0);
+ p.in_worldcoord_ndim = in_worldcoord.dim();
+ for (int i=0; i out_feature_shape;
+ //if (channel_first) { // Channel First format, suitable for 2D convolution
+ // //assert(false);
+ for (int i=0; i<<>>(
+ out_feature.data_ptr(),
+ in_feature.data_ptr(), in_corner_lut.data_ptr(), in_worldcoord.data_ptr(), p
+ );
+ THCudaCheck(cudaGetLastError());
+ return out_feature;
+}
+
+
+// Backward function for sparse trilinear interpolation
+// Input:
+// out_feature_grad: float32 [..., C]
+// in_feature: float32 [M, C]
+// in_corner_lut: int32 [X Y Z]
+// ---in_index: int32 [..., 8], containing [0, M]. 0 is ignore label.
+// in_worldcoord: float32 [..., 3]
+// ign_zero: bool
+// need_coord_grad: bool
+// Output:
+// in_feature_grad: float32 [M, C]
+// in_coord_grad: float32 [..., 3]
+std::vector sp_trilinear_worldcoord_backward_cuda(const torch::Tensor& out_feature_grad , const torch::Tensor& in_feature, const torch::Tensor& in_corner_lut, const torch::Tensor& in_worldcoord, bool ign_zero, bool need_coord_grad) {
+ assert(need_coord_grad == false);
+ CHECK_CUDA(out_feature_grad);
+ CHECK_CUDA(in_feature);
+ CHECK_CUDA(in_corner_lut);
+ CHECK_CUDA(in_worldcoord);
+
+ int curDevice = -1;
+ cudaGetDevice(&curDevice);
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
+ torch::Device device = out_feature_grad.device();
+
+ //for (int i=0; i{3, 4, 5});
+ assert(out_feature_grad.dtype() == torch::kFloat32);
+ for (int i=0; i<<>>(
+ in_feature_grad.data_ptr(),
+ out_feature_grad.data_ptr(), in_corner_lut.data_ptr(), in_worldcoord.data_ptr(), p
+ );
+ }
+
+ THCudaCheck(cudaGetLastError());
+ return {in_feature_grad};
+}
diff --git a/imaginaire/model_utils/gancraft/voxlib/voxlib.cpp b/imaginaire/model_utils/gancraft/voxlib/voxlib.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..70095052d71f53e5e519a5f57b3c0848998a1b22
--- /dev/null
+++ b/imaginaire/model_utils/gancraft/voxlib/voxlib.cpp
@@ -0,0 +1,31 @@
+// Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+//
+// This work is made available under the Nvidia Source Code License-NC.
+// To view a copy of this license, check out LICENSE.md
+#include
+#include
+#include
+#include
+
+// Fast voxel traversal along rays
+std::vector ray_voxel_intersection_perspective_cuda(const torch::Tensor& in_voxel, const torch::Tensor& cam_ori, const torch::Tensor& cam_dir, const torch::Tensor& cam_up, float cam_f, const std::vector& cam_c, const std::vector& img_dims, int max_samples);
+
+
+// World Coordinate Sparse Trilinear Interpolation
+torch::Tensor sp_trilinear_worldcoord_cuda(const torch::Tensor& in_feature, const torch::Tensor& in_corner_lut, const torch::Tensor& in_worldcoord, bool ign_zero, int channel_pos);
+
+std::vector sp_trilinear_worldcoord_backward_cuda(const torch::Tensor& out_feature_grad , const torch::Tensor& in_feature, const torch::Tensor& in_corner_lut, const torch::Tensor& in_worldcoord, bool ign_zero, bool need_coord_grad);
+
+// Fast & Memory Efficient Positional Encoding
+torch::Tensor positional_encoding_cuda(const torch::Tensor& in_feature, int ndegrees, int dim, bool incl_orig);
+
+torch::Tensor positional_encoding_backward_cuda(const torch::Tensor& out_feature_grad, const torch::Tensor& out_feature, int ndegrees, int dim, bool incl_orig);
+
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ m.def("ray_voxel_intersection_perspective", &ray_voxel_intersection_perspective_cuda, "Ray-voxel intersections given perspective camera parameters (CUDA)");
+ m.def("sp_trilinear_worldcoord", &sp_trilinear_worldcoord_cuda, "Sparse Trilinear interpolation, world coordinate [forward] (CUDA)");
+ m.def("sp_trilinear_worldcoord_backward", &sp_trilinear_worldcoord_backward_cuda, "Sparse Trilinear interpolation, world coordinate [backward] (CUDA)");
+ m.def("positional_encoding", &positional_encoding_cuda, "Fused Positional Encoding [forward] (CUDA)");
+ m.def("positional_encoding_backward", &positional_encoding_backward_cuda, "Fused Positional Encoding [backward] (CUDA)");
+}
\ No newline at end of file
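A hedged usage sketch for these bindings, assuming the extension has been compiled and a CUDA GPU is available. The argument order and output shapes follow the declarations above and the comments in `ray_voxel_intersection.cu`; the concrete camera values are illustrative only:

```python
import torch
import voxlib

voxels = torch.zeros(40, 512, 512, dtype=torch.int32, device="cuda")  # [X, Y, Z] block ids, 0 = empty
voxels[8, 256, 256] = 3                                                # one occupied voxel

cam_ori = torch.tensor([20.0, 256.0, 10.0], device="cuda")
cam_dir = torch.tensor([0.0, 0.0, 1.0], device="cuda")
cam_up = torch.tensor([1.0, 0.0, 0.0], device="cuda")
cam_f = 128.0                 # focal length in pixels
cam_c = [128, 128]            # principal point
img_dims = [256, 256]         # output height, width
max_samples = 6               # intersections kept per ray

voxel_id, depth, raydirs = voxlib.ray_voxel_intersection_perspective(
    voxels, cam_ori, cam_dir, cam_up, cam_f, cam_c, img_dims, max_samples)
print(voxel_id.shape, depth.shape, raydirs.shape)
# Per the comments in ray_voxel_intersection.cu:
# (256, 256, 6, 1) (2, 256, 256, 6, 1) (256, 256, 1, 3)
```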
diff --git a/imaginaire/model_utils/gancraft/voxlib/voxlib_common.h b/imaginaire/model_utils/gancraft/voxlib/voxlib_common.h
new file mode 100644
index 0000000000000000000000000000000000000000..46b47fc80ecf802347607395ff04565732a4ee87
--- /dev/null
+++ b/imaginaire/model_utils/gancraft/voxlib/voxlib_common.h
@@ -0,0 +1,76 @@
+// Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+//
+// This work is made available under the Nvidia Source Code License-NC.
+// To view a copy of this license, check out LICENSE.md
+#ifndef VOXLIB_COMMON_H
+#define VOXLIB_COMMON_H
+
+#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
+#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
+#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
+#define CHECK_CPU(x) TORCH_CHECK(x.device().is_cpu(), #x " must be a CPU tensor")
+
+#include
+#include
+// CUDA vector math functions
+__host__ __device__ __forceinline__ int floor_div(int a, int b) {
+ int c = a / b;
+
+ if (c * b > a) {
+ c--;
+ }
+
+ return c;
+}
+
+template <typename scalar_t>
+__host__ __forceinline__ void cross(scalar_t* r, const scalar_t* a, const scalar_t* b) {
+ r[0] = a[1]*b[2] - a[2]*b[1];
+ r[1] = a[2]*b[0] - a[0]*b[2];
+ r[2] = a[0]*b[1] - a[1]*b[0];
+}
+
+__device__ __host__ __forceinline__ float dot(const float* a, const float* b) {
+ return a[0] * b[0] + a[1] * b[1] + a[2] * b[2];
+}
+
+template
+__device__ __host__ __forceinline__ void copyarr(scalar_t* r, const scalar_t* a) {
+ #pragma unroll
+ for (int i=0; i
+__device__ __host__ __forceinline__ void normalize(scalar_t* a) {
+ scalar_t vec_len=0.0f;
+ #pragma unroll
+ for (int i=0; i
+__device__ __host__ __forceinline__ void normalize(scalar_t* r, const scalar_t* a) {
+ scalar_t vec_len=0.0f;
+ #pragma unroll
+ for (int i=0; i= data_type_num_classes] = data_type_num_classes
+ data[data_type] = label_map / 255.0
+ return data
+
+
+def _encode_onehot(label_map, num_classes, use_dont_care):
+ r"""Make input one-hot.
+
+ Args:
+ label_map (torch.Tensor): (C, H, W) tensor containing indices.
+ num_classes (int): Number of labels to expand tensor to.
+ use_dont_care (bool): Whether to keep the don't care label as an extra channel.
+ Returns:
+ output (torch.Tensor): (num_classes, H, W) one-hot tensor.
+ """
+ # All labels lie in [0, num_classes - 1].
+ # Encode don't care as num_classes.
+ label_map[label_map < 0] = num_classes
+ label_map[label_map >= num_classes] = num_classes
+
+ size = label_map.size()
+ output_size = (num_classes + 1, size[-2], size[-1])
+ output = torch.zeros(*output_size)
+ if label_map.dim() == 4:
+ output = output.unsqueeze(0).repeat(label_map.size(0), 1, 1, 1)
+ output = output.scatter_(1, label_map.data.long(), 1.0)
+ if not use_dont_care:
+ output = output[:, :num_classes, ...]
+ else:
+ output = output.scatter_(0, label_map.data.long(), 1.0)
+ if not use_dont_care:
+ output = output[:num_classes, ...]
+ return output
diff --git a/imaginaire/model_utils/pix2pixHD.py b/imaginaire/model_utils/pix2pixHD.py
new file mode 100644
index 0000000000000000000000000000000000000000..862eb591a665e3ea81b3f918696feac3640a6c94
--- /dev/null
+++ b/imaginaire/model_utils/pix2pixHD.py
@@ -0,0 +1,227 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+"""Utils for the pix2pixHD model."""
+import numpy as np
+import torch
+
+from imaginaire.utils.data import get_paired_input_label_channel_number
+from imaginaire.utils.distributed import dist_all_gather_tensor, is_master
+from imaginaire.utils.distributed import master_only_print as print
+from imaginaire.utils.trainer import (get_optimizer, get_optimizer_for_params,
+ wrap_model_and_optimizer)
+from sklearn.cluster import KMeans
+
+
+def cluster_features(cfg, train_data_loader, net_E,
+ preprocess=None, small_ratio=0.0625, is_cityscapes=True):
+ r"""Use clustering to compute the features.
+
+ Args:
+ cfg (obj): Global configuration file.
+ train_data_loader (obj): Dataloader for iterating through the training
+ set.
+ net_E (nn.Module): PyTorch network.
+ preprocess (function): Pre-processing function.
+ small_ratio (float): We only consider instances that occupy at least
+ $(small_ratio) of the image area.
+ is_cityscapes (bool): Is this the Cityscapes dataset? In the
+ Cityscapes dataset, the instance labels for cars start with 26001,
+ 26002, ...
+
+ Returns:
+ ( num_labels x num_cluster_centers x feature_dims): cluster centers.
+ """
+ # Encode features.
+ label_nc = get_paired_input_label_channel_number(cfg.data)
+ feat_nc = cfg.gen.enc.num_feat_channels
+ n_clusters = getattr(cfg.gen.enc, 'num_clusters', 10)
+ # Compute features.
+ features = {}
+ for label in range(label_nc):
+ features[label] = np.zeros((0, feat_nc + 1))
+ for data in train_data_loader:
+ if preprocess is not None:
+ data = preprocess(data)
+ feat = encode_features(net_E, feat_nc, label_nc,
+ data['images'], data['instance_maps'],
+ is_cityscapes)
+ # We only collect the feature vectors for the master GPU.
+ if is_master():
+ for label in range(label_nc):
+ features[label] = np.append(
+ features[label], feat[label], axis=0)
+ # Clustering.
+ # We only perform clustering for the master GPU.
+ if is_master():
+ for label in range(label_nc):
+ feat = features[label]
+ # We only consider segments that are greater than a pre-set
+ # threshold.
+ feat = feat[feat[:, -1] > small_ratio, :-1]
+ if feat.shape[0]:
+ n_clusters = min(feat.shape[0], n_clusters)
+ kmeans = KMeans(n_clusters=n_clusters, random_state=0).fit(feat)
+ n, d = kmeans.cluster_centers_.shape
+ this_cluster = getattr(net_E, 'cluster_%d' % label)
+ this_cluster[0:n, :] = torch.Tensor(
+ kmeans.cluster_centers_).float()
+
+
+def encode_features(net_E, feat_nc, label_nc, image, inst,
+ is_cityscapes=True):
+ r"""Compute feature embeddings for an image image.
+ TODO(Ting-Chun): To make this funciton dataset independent.
+
+ Args:
+ net_E (nn.Module): The encoder network.
+ feat_nc (int): Feature dimensions
+ label_nc (int): Number of segmentation labels.
+ image (tensor): Input image tensor.
+ inst (tensor): Input instance map.
+ is_cityscapes (bool): Is this the Cityscapes dataset? In the
+ Cityscapes dataset, the instance labels for cars start with 26001,
+ 26002, ...
+ Returns:
+ (list of list of numpy vectors): We will have $(label_nc)
+ lists. Each list records feature vectors of dimension
+ $(feat_nc+1), where the first $(feat_nc) dimensions are the
+ representative feature of an instance and the last dimension
+ is the proportion of the image that the instance occupies.
+ """
+ # h, w = inst.size()[2:]
+ feat_map = net_E(image, inst)
+ feature_map_gather = dist_all_gather_tensor(feat_map)
+ inst_gathered = dist_all_gather_tensor(inst)
+ # Initialize the cluster centers.
+ # For each feature vector,
+ # 0:feat_nc will be the feature vector.
+ # The feat_nc-th dimension records the fraction of the image the instance occupies.
+ feature = {}
+ for i in range(label_nc):
+ feature[i] = np.zeros((0, feat_nc + 1))
+ if is_master():
+ all_feat_map = torch.cat(feature_map_gather, 0)
+ all_inst_map = torch.cat(inst_gathered, 0)
+ # Scan through the batches.
+ for n in range(all_feat_map.size()[0]):
+ feat_map = all_feat_map[n:(n + 1), :, :, :]
+ inst = all_inst_map[n:(n + 1), :, :, :]
+ fh, fw = feat_map.size()[2:]
+ inst_np = inst.cpu().numpy().astype(int)
+ for i in np.unique(inst_np):
+ if is_cityscapes:
+ label = i if i < 1000 else i // 1000
+ else:
+ label = i
+ idx = (inst == int(i)).nonzero()
+ num = idx.size()[0]
+ # We will just pick the middle pixel as its representative
+ # feature.
+ idx = idx[num // 2, :]
+ val = np.zeros((1, feat_nc + 1))
+ for k in range(feat_nc):
+ # We expect idx[0]=0 and idx[1]=0, as only one sample is
+ # processed at a time (idx[0]=0) and the instance map has a
+ # single channel (idx[1]=0).
+ val[0, k] = feat_map[
+ idx[0], idx[1] + k, idx[2], idx[3]].item()
+ val[0, feat_nc] = float(num) / (fh * fw)
+ feature[label] = np.append(feature[label], val, axis=0)
+ return feature
+ else:
+ return feature
+
+
+def get_edges(t):
+ r""" Compute edge maps for a given input instance map.
+
+ Args:
+ t (4D tensor): Input instance map.
+ Returns:
+ (4D tensor): Output edge map.
+ """
+ edge = torch.cuda.ByteTensor(t.size()).zero_()
+ edge[:, :, :, 1:] = edge[:, :, :, 1:] | (
+ t[:, :, :, 1:] != t[:, :, :, :-1]).byte()
+ edge[:, :, :, :-1] = edge[:, :, :, :-1] | (
+ t[:, :, :, 1:] != t[:, :, :, :-1]).byte()
+ edge[:, :, 1:, :] = edge[:, :, 1:, :] | (
+ t[:, :, 1:, :] != t[:, :, :-1, :]).byte()
+ edge[:, :, :-1, :] = edge[:, :, :-1, :] | (
+ t[:, :, 1:, :] != t[:, :, :-1, :]).byte()
+ return edge.float()
+
+
+def get_train_params(net, param_names_start_with=[], param_names_include=[]):
+ r"""Get train parameters.
+
+ Args:
+ net (obj): Network object.
+ param_names_start_with (list of strings): Params whose names
+ start with any of the strings will be trained.
+ param_names_include (list of strings): Params whose names include
+ any of the strings will be trained.
+ """
+ params_to_train = []
+ params_dict = net.state_dict()
+ list_of_param_names_to_train = set()
+ # Iterate through all params in the network and check if we need to
+ # train it.
+ for key, value in params_dict.items():
+ do_train = False
+ # If the param name starts with the target string (excluding
+ # the 'module' part etc), we will train this param.
+ key_s = key.replace('module.', '').replace('averaged_model.', '')
+ for param_name in param_names_start_with:
+ if key_s.startswith(param_name):
+ do_train = True
+ list_of_param_names_to_train.add(param_name)
+
+ # Otherwise, if the param name includes the target string,
+ # we will also train it.
+ if not do_train:
+ for param_name in param_names_include:
+ if param_name in key_s:
+ do_train = True
+ full_param_name = \
+ key_s[:(key_s.find(param_name) + len(param_name))]
+ list_of_param_names_to_train.add(full_param_name)
+
+ # If we decide to train the param, add it to the list to train.
+ if do_train:
+ module = net
+ key_list = key.split('.')
+ for k in key_list:
+ module = getattr(module, k)
+ params_to_train += [module]
+
+ print('Training layers: ', sorted(list_of_param_names_to_train))
+ return params_to_train
+
+
+def get_optimizer_with_params(cfg, net_G, net_D, param_names_start_with=[],
+ param_names_include=[]):
+ r"""Return the optimizer object.
+
+ Args:
+ cfg (obj): Global config.
+ net_G (obj): Generator network.
+ net_D (obj): Discriminator network.
+ param_names_start_with (list of strings): Params whose names
+ start with any of the strings will be trained.
+ param_names_include (list of strings): Params whose names include
+ any of the strings will be trained.
+ """
+ # If any of the param name lists is not empty, will only train
+ # these params. Otherwise will train the entire network (all params).
+ if param_names_start_with or param_names_include:
+ params = get_train_params(net_G, param_names_start_with,
+ param_names_include)
+ else:
+ params = net_G.parameters()
+
+ opt_G = get_optimizer_for_params(cfg.gen_opt, params)
+ opt_D = get_optimizer(cfg.dis_opt, net_D)
+ return wrap_model_and_optimizer(cfg, net_G, net_D, opt_G, opt_D)
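For example, `get_train_params` walks the state dict and returns only the tensors whose cleaned names match the given prefixes or substrings. A sketch with a made-up two-module network, assuming the `imaginaire` package and its dependencies are importable:

```python
import torch.nn as nn
from imaginaire.model_utils.pix2pixHD import get_train_params  # path as added in this diff

net = nn.Sequential()
net.add_module("encoder", nn.Linear(8, 8))
net.add_module("decoder", nn.Linear(8, 8))

# Train only parameters whose names start with 'decoder'.
params = get_train_params(net, param_names_start_with=["decoder"], param_names_include=[])
print(len(params))  # 2 -> decoder.weight and decoder.bias
```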
diff --git a/imaginaire/model_utils/rename_inputs.py b/imaginaire/model_utils/rename_inputs.py
new file mode 100644
index 0000000000000000000000000000000000000000..f40b3f98f6bf21f9efb21c9b7cd99226adbb2769
--- /dev/null
+++ b/imaginaire/model_utils/rename_inputs.py
@@ -0,0 +1,15 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+
+
+def rename_inputs(cfg, is_inference, data):
+ assert hasattr(cfg, 'rename_inputs')
+ attr = getattr(cfg, 'rename_inputs')
+ for key in attr.keys():
+ value = attr[key]
+ data[key] = data[value]
+ # Delete the old key.
+ del data[value]
+ return data
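A minimal sketch of what `rename_inputs` does, using a `SimpleNamespace` in place of the real config object (the actual `cfg` comes from the imaginaire YAML configs):

```python
from types import SimpleNamespace
from imaginaire.model_utils.rename_inputs import rename_inputs

cfg = SimpleNamespace(rename_inputs={"images": "seg_maps"})  # new_key -> old_key
data = {"seg_maps": "tensor-placeholder"}
data = rename_inputs(cfg, is_inference=False, data=data)
print(data)  # {'images': 'tensor-placeholder'}
```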
diff --git a/imaginaire/model_utils/wc_vid2vid/render.py b/imaginaire/model_utils/wc_vid2vid/render.py
new file mode 100644
index 0000000000000000000000000000000000000000..304b1cb3b58384ad8fb99bacf43ddda0bc0b2ff6
--- /dev/null
+++ b/imaginaire/model_utils/wc_vid2vid/render.py
@@ -0,0 +1,199 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import pickle
+import time
+
+import numpy as np
+
+
+class SplatRenderer(object):
+ """Splatting 3D point cloud into image using precomputed mapping."""
+
+ def __init__(self):
+ self.reset()
+
+ def reset(self):
+ """Reset the renderer."""
+ # 1 = point seen before, 0 = not seen.
+ # This is numpy uint8 array of size (N, 1)
+ self.seen_mask = None
+
+ # Time of first colorization of 3D point.
+ # This is numpy uint16 array of size (N, 1)
+ self.seen_time = None
+
+ # colors[kp_idx] is color of kp_idx'th keypoint.
+ # This is a numpy uint8 array of size (N, 3)
+ self.colors = None
+
+ self.time_taken = 0
+ self.call_idx = 0
+
+ def num_points(self):
+ r"""Number of points with assigned colors."""
+ return np.sum(self.seen_mask)
+
+ def _resize_arrays(self, max_point_idx):
+ r"""Makes arrays bigger, if needed.
+ Args:
+ max_point_idx (int): Highest 3D point index seen so far.
+ """
+ if self.colors is None:
+ old_max_point_idx = 0
+ else:
+ old_max_point_idx = self.colors.shape[0]
+
+ if max_point_idx > old_max_point_idx:
+ # Init new bigger arrays.
+ colors = np.zeros((max_point_idx, 3), dtype=np.uint8)
+ seen_mask = np.zeros((max_point_idx, 1), dtype=np.uint8)
+ seen_time = np.zeros((max_point_idx, 1), dtype=np.uint16)
+ # Copy old colors, if exist.
+ if old_max_point_idx > 0:
+ colors[:old_max_point_idx] = self.colors
+ seen_mask[:old_max_point_idx] = self.seen_mask
+ seen_time[:old_max_point_idx] = self.seen_time
+ # Reset pointers.
+ self.colors = colors
+ self.seen_mask = seen_mask
+ self.seen_time = seen_time
+
+ def update_point_cloud(self, image, point_info):
+ r"""Updates point cloud with new points and colors.
+ Args:
+ image (H x W x 3, uint8): Select colors from this image to assign to
+ 3D points which do not have previously assigned colors.
+ point_info (N x 3): (i, j, 3D point idx) per row containing
+ mapping of image pixel to 3D point in point cloud.
+ """
+ if point_info is None or len(point_info) == 0:
+ return
+
+ start = time.time()
+ self.call_idx += 1
+
+ i_idxs = point_info[:, 0]
+ j_idxs = point_info[:, 1]
+ point_idxs = point_info[:, 2]
+
+ # Allocate memory for new colors.
+ max_point_idx = np.max(np.array(point_idxs)) + 1
+ self._resize_arrays(max_point_idx)
+ # print('max point idx:', max_point_idx)
+
+ # Save only the new colors.
+ self.colors[point_idxs] = \
+ self.seen_mask[point_idxs] * self.colors[point_idxs] + \
+ (1 - self.seen_mask[point_idxs]) * image[i_idxs, j_idxs]
+
+ # Save point seen times.
+ self.seen_time[point_idxs] = \
+ self.seen_mask[point_idxs] * self.seen_time[point_idxs] + \
+ (1 - self.seen_mask[point_idxs]) * self.call_idx
+
+ # Update seen point mask.
+ self.seen_mask[point_idxs] = 1
+
+ end = time.time()
+ self.time_taken += (end - start)
+
+ def render_image(self, point_info, w, h, return_mask=False):
+ r"""Creates image of (h, w) and fills in colors.
+ Args:
+ point_info (N x 3): (i, j, 3D point idx) per row containing
+ mapping of image pixel to 3D point in point cloud.
+ w (int): Width of output image.
+ h (int): Height of output image.
+ return_mask (bool): Return binary mask of coloring.
+ Returns:
+ (tuple):
+ - output (H x W x 3, uint8): Image formed with mapping and colors.
+ - mask (H x W x 1, uint8): Binary (255 or 0) mask of colorization.
+ """
+ output = np.zeros((h, w, 3), dtype=np.uint8)
+ mask = np.zeros((h, w, 1), dtype=np.uint8)
+
+ if point_info is None or len(point_info) == 0:
+ if return_mask:
+ return output, mask
+ else:
+ return output
+
+ start = time.time()
+
+ i_idxs = point_info[:, 0]
+ j_idxs = point_info[:, 1]
+ point_idxs = point_info[:, 2]
+
+ # Allocate memory for new colors.
+ max_point_idx = np.max(np.array(point_idxs)) + 1
+ self._resize_arrays(max_point_idx)
+
+ # num_found = np.sum(self.seen_mask[point_idxs])
+ # print('Found %d points to color' % (num_found))
+
+ # Copy colors.
+ output[i_idxs, j_idxs] = self.colors[point_idxs]
+
+ end = time.time()
+ self.time_taken += (end - start)
+
+ if return_mask:
+ mask[i_idxs, j_idxs] = 255 * self.seen_mask[point_idxs]
+ return output, mask
+ else:
+ return output
+
+
+def decode_unprojections(data):
+ r"""Unpickle unprojections and make array.
+ Args:
+ data (array of pickled info): Each pickled string has keypoint mapping
+ info.
+ Returns:
+ output (dict): Keys are the different resolutions, and values are padded
+ mapping information.
+ """
+
+ # Unpickle unprojections and store them in a dict with resolutions as keys.
+ all_unprojections = {}
+ for item in data:
+ info = pickle.loads(item)
+
+ for resolution, value in info.items():
+ if resolution not in all_unprojections:
+ all_unprojections[resolution] = []
+
+ if not value or value is None:
+ point_info = []
+ else:
+ point_info = value
+ all_unprojections[resolution].append(point_info)
+
+ outputs = {}
+ for resolution, values in all_unprojections.items():
+ # Get max length of mapping.
+ max_len = 0
+ for value in values:
+ max_len = max(max_len, len(value))
+ # Entries are a 3-tuple of (i_idx, j_idx, point_idx).
+ assert len(value) % 3 == 0
+
+ # Pad each mapping to max_len.
+ values = [
+ value + # Original info.
+ [-1] * (max_len - len(value)) + # Padding.
+ [len(value) // 3] * 3 # End sentinel with length.
+ for value in values
+ ]
+
+ # Convert each mapping to numpy and reshape.
+ values = [np.array(value).reshape(-1, 3) for value in values]
+
+ # Stack and put in output.
+ # Shape is (T, N, 3). T is time steps, N is num mappings.
+ outputs[resolution] = np.stack(values, axis=0)
+
+ return outputs
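A minimal usage sketch of `SplatRenderer` with toy sizes: colors are assigned to 3D points from one frame's pixel-to-point mapping, then splatted into a second view through that view's mapping. The mappings below are illustrative:

```python
import numpy as np
from imaginaire.model_utils.wc_vid2vid.render import SplatRenderer

renderer = SplatRenderer()
h, w = 4, 4
frame = np.full((h, w, 3), 200, dtype=np.uint8)

# Rows are (i, j, point_idx): pixel (1, 2) sees 3D point 7, pixel (3, 0) sees point 2.
mapping_a = np.array([[1, 2, 7], [3, 0, 2]])
renderer.update_point_cloud(frame, mapping_a)

# In a later view the same points project to different pixels.
mapping_b = np.array([[0, 0, 7], [2, 3, 2]])
out, mask = renderer.render_image(mapping_b, w, h, return_mask=True)
print(renderer.num_points(), out[0, 0], int(mask[0, 0, 0]))  # 2 [200 200 200] 255
```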
diff --git a/imaginaire/optimizers/__init__.py b/imaginaire/optimizers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..69bedc71589e178b96730174c9860b8cc8430e55
--- /dev/null
+++ b/imaginaire/optimizers/__init__.py
@@ -0,0 +1,8 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+from .fromage import Fromage
+from .madam import Madam
+
+__all__ = ['Fromage', 'Madam']
diff --git a/imaginaire/optimizers/__pycache__/__init__.cpython-38.pyc b/imaginaire/optimizers/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f4a0a8977871d0626a62bfdd2561625a4fc3b5d7
Binary files /dev/null and b/imaginaire/optimizers/__pycache__/__init__.cpython-38.pyc differ
diff --git a/imaginaire/optimizers/__pycache__/fromage.cpython-38.pyc b/imaginaire/optimizers/__pycache__/fromage.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4f8740580fc34277745504381d463ec83bd9d493
Binary files /dev/null and b/imaginaire/optimizers/__pycache__/fromage.cpython-38.pyc differ
diff --git a/imaginaire/optimizers/__pycache__/madam.cpython-38.pyc b/imaginaire/optimizers/__pycache__/madam.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a384b1798545cd7b6d9b3515e1afc2b7fd892f68
Binary files /dev/null and b/imaginaire/optimizers/__pycache__/madam.cpython-38.pyc differ
diff --git a/imaginaire/optimizers/fromage.py b/imaginaire/optimizers/fromage.py
new file mode 100644
index 0000000000000000000000000000000000000000..d00203de89f55fd122f71b7de8718ed7ef681ec8
--- /dev/null
+++ b/imaginaire/optimizers/fromage.py
@@ -0,0 +1,44 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+# import torch
+import math
+
+from torch.optim.optimizer import Optimizer, required
+
+
+class Fromage(Optimizer):
+ r"""Fromage optimizer implementation (https://arxiv.org/abs/2002.03432)"""
+
+ def __init__(self, params, lr=required, momentum=0):
+ if lr is not required and lr < 0.0:
+ raise ValueError("Invalid learning rate: {}".format(lr))
+ defaults = dict(lr=lr, momentum=momentum)
+ super(Fromage, self).__init__(params, defaults)
+
+ def step(self, closure=None):
+ r"""Performs a single optimization step.
+
+ Args:
+ closure (callable, optional): A closure that reevaluates the model
+ and returns the loss.
+ """
+ loss = None
+ if closure is not None:
+ loss = closure()
+
+ for group in self.param_groups:
+ for p in group['params']:
+ if p.grad is None:
+ continue
+ d_p = p.grad.data
+ d_p_norm = p.grad.norm()
+ p_norm = p.norm()
+ if p_norm > 0.0 and d_p_norm > 0.0:
+ p.data.add_(-group['lr'], d_p * (p_norm / d_p_norm))
+ else:
+ p.data.add_(-group['lr'], d_p)
+ p.data /= math.sqrt(1 + group['lr'] ** 2)
+
+ return loss
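
A minimal usage sketch for `Fromage` (the model and data below are placeholders; the deprecated two-argument `add_` overload used in `step()` assumes an older PyTorch release, in line with the rest of this code base):

```python
import torch
from imaginaire.optimizers import Fromage

model = torch.nn.Linear(16, 4)                    # placeholder model
optimizer = Fromage(model.parameters(), lr=0.01)  # only a learning rate is required

x, target = torch.randn(8, 16), torch.randn(8, 4)
loss = torch.nn.functional.mse_loss(model(x), target)
optimizer.zero_grad()
loss.backward()
optimizer.step()  # norm-rescaled update followed by the 1/sqrt(1 + lr^2) shrink
```
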
diff --git a/imaginaire/optimizers/madam.py b/imaginaire/optimizers/madam.py
new file mode 100644
index 0000000000000000000000000000000000000000..11bf71d049d9323e9ba646713413578ae5eb4503
--- /dev/null
+++ b/imaginaire/optimizers/madam.py
@@ -0,0 +1,54 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import torch
+from torch.optim.optimizer import Optimizer, required
+
+
+class Madam(Optimizer):
+ r"""MADAM optimizer implementation (https://arxiv.org/abs/2006.14560)"""
+ def __init__(self, params, lr=required, scale=3.0,
+ g_bound=None, momentum=0):
+ self.scale = scale
+ self.g_bound = g_bound
+ defaults = dict(lr=lr, momentum=momentum)
+ super(Madam, self).__init__(params, defaults)
+
+ def step(self, closure=None):
+ r"""Performs a single optimization step.
+
+ Args:
+ closure (callable, optional): A closure that reevaluates the model
+ and returns the loss.
+ """
+ loss = None
+ if closure is not None:
+ loss = closure()
+
+ for group in self.param_groups:
+ for p in group['params']:
+ if p.grad is None:
+ continue
+
+ state = self.state[p]
+ if len(state) == 0:
+ state['max'] = self.scale * (p * p).mean().sqrt().item()
+ state['step'] = 0
+ state['exp_avg_sq'] = torch.zeros_like(p)
+
+ state['step'] += 1
+ bias_correction = 1 - 0.999 ** state['step']
+ state['exp_avg_sq'] = 0.999 * state[
+ 'exp_avg_sq'] + 0.001 * p.grad.data ** 2
+ g_normed = \
+ p.grad.data / (state['exp_avg_sq'] / bias_correction).sqrt()
+ g_normed[torch.isnan(g_normed)] = 0
+ if self.g_bound is not None:
+ g_normed.clamp_(-self.g_bound, self.g_bound)
+
+ p.data *= torch.exp(
+ -group['lr'] * g_normed * torch.sign(p.data))
+ p.data.clamp_(-state['max'], state['max'])
+
+ return loss
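
A corresponding sketch for `Madam` (placeholder model and hyper-parameters; `g_bound` is optional and only clips the normalized gradient):

```python
import torch
from imaginaire.optimizers import Madam

model = torch.nn.Linear(16, 4)
optimizer = Madam(model.parameters(), lr=0.01, scale=3.0, g_bound=10.0)

loss = model(torch.randn(8, 16)).pow(2).mean()
optimizer.zero_grad()
loss.backward()
# Multiplicative update: weights keep their sign and are clamped to
# `scale` times their initial RMS value.
optimizer.step()
```
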
diff --git a/imaginaire/third_party/__init__.py b/imaginaire/third_party/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/imaginaire/third_party/__pycache__/__init__.cpython-38.pyc b/imaginaire/third_party/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d5c112b422e5e97bfa8f09f096fbac9b1e6bae54
Binary files /dev/null and b/imaginaire/third_party/__pycache__/__init__.cpython-38.pyc differ
diff --git a/imaginaire/third_party/bias_act/__init__.py b/imaginaire/third_party/bias_act/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9dfe0aec6e5fd4a1538ed959abff7c5106784c9b
--- /dev/null
+++ b/imaginaire/third_party/bias_act/__init__.py
@@ -0,0 +1,3 @@
+from .bias_act import FusedNonlinearity
+
+__all__ = ['FusedNonlinearity']
diff --git a/imaginaire/third_party/bias_act/__pycache__/__init__.cpython-38.pyc b/imaginaire/third_party/bias_act/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..11e3a34fdbfff20d3f67e073c1b6048c72f57216
Binary files /dev/null and b/imaginaire/third_party/bias_act/__pycache__/__init__.cpython-38.pyc differ
diff --git a/imaginaire/third_party/bias_act/__pycache__/bias_act.cpython-38.pyc b/imaginaire/third_party/bias_act/__pycache__/bias_act.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d99f8e6245f42b7450319d0f9bd9974ee53b5402
Binary files /dev/null and b/imaginaire/third_party/bias_act/__pycache__/bias_act.cpython-38.pyc differ
diff --git a/imaginaire/third_party/bias_act/bias_act.py b/imaginaire/third_party/bias_act/bias_act.py
new file mode 100644
index 0000000000000000000000000000000000000000..29b01dc97884036aec1c42feb184c510c5ad0870
--- /dev/null
+++ b/imaginaire/third_party/bias_act/bias_act.py
@@ -0,0 +1,219 @@
+# flake8: noqa
+import numpy as np
+from types import SimpleNamespace
+
+import torch
+from torch import nn
+
+import bias_act_cuda
+
+# ----------------------------------------------------------------------------
+
+activation_funcs = {
+ 'linear': SimpleNamespace(func=lambda x, **_: x, def_alpha=0, def_gain=1,
+ cuda_idx=1, ref='', has_2nd_grad=False),
+ 'relu': SimpleNamespace(func=lambda x, **_: torch.nn.functional.relu(x),
+ def_alpha=0, def_gain=np.sqrt(2), cuda_idx=2,
+ ref='y', has_2nd_grad=False),
+ 'leakyrelu': SimpleNamespace(
+ func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha),
+ def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y',
+ has_2nd_grad=False),
+ 'tanh': SimpleNamespace(func=lambda x, **_: torch.tanh(x), def_alpha=0,
+ def_gain=1, cuda_idx=4, ref='y', has_2nd_grad=True),
+ 'sigmoid': SimpleNamespace(func=lambda x, **_: torch.sigmoid(x),
+ def_alpha=0, def_gain=1, cuda_idx=5, ref='y',
+ has_2nd_grad=True),
+ 'elu': SimpleNamespace(func=lambda x, **_: torch.nn.functional.elu(x),
+ def_alpha=0, def_gain=1, cuda_idx=6, ref='y',
+ has_2nd_grad=True),
+ 'selu': SimpleNamespace(func=lambda x, **_: torch.nn.functional.selu(x),
+ def_alpha=0, def_gain=1, cuda_idx=7, ref='y',
+ has_2nd_grad=True),
+ 'softplus': SimpleNamespace(
+ func=lambda x, **_: torch.nn.functional.softplus(x), def_alpha=0,
+ def_gain=1, cuda_idx=8, ref='y', has_2nd_grad=True),
+ 'swish': SimpleNamespace(func=lambda x, **_: torch.sigmoid(x) * x,
+ def_alpha=0, def_gain=np.sqrt(2), cuda_idx=9,
+ ref='x', has_2nd_grad=True),
+}
+
+# ----------------------------------------------------------------------------
+
+_null_tensor = torch.empty([0])
+
+
+def _bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None,
+ impl='cuda'):
+ assert isinstance(x, torch.Tensor)
+ assert impl in ['ref', 'cuda']
+ if impl == 'cuda' and x.device.type == 'cuda':
+ return _bias_act_cuda(dim=dim, act=act, alpha=alpha, gain=gain,
+ clamp=clamp).apply(x, b)
+ return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain,
+ clamp=clamp)
+
+
+# ----------------------------------------------------------------------------
+
+def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None):
+ assert isinstance(x, torch.Tensor)
+ assert clamp is None or clamp >= 0
+ spec = activation_funcs[act]
+ alpha = float(alpha if alpha is not None else spec.def_alpha)
+ gain = float(gain if gain is not None else spec.def_gain)
+ clamp = float(clamp if clamp is not None else -1)
+
+ # Add bias.
+ if b is not None:
+ assert isinstance(b, torch.Tensor) and b.ndim == 1
+ assert 0 <= dim < x.ndim
+ assert b.shape[0] == x.shape[dim]
+ x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)])
+
+ # Evaluate activation function.
+ alpha = float(alpha)
+ x = spec.func(x, alpha=alpha)
+
+ # Scale by gain.
+ gain = float(gain)
+ if gain != 1:
+ x = x * gain
+
+ # Clamp.
+ if clamp >= 0:
+ x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type
+ return x
+
+
+# ----------------------------------------------------------------------------
+
+_bias_act_cuda_cache = dict()
+
+
+def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None):
+ """Fast CUDA implementation of `bias_act()` using custom ops.
+ """
+ # Parse arguments.
+ assert clamp is None or clamp >= 0
+ spec = activation_funcs[act]
+ alpha = float(alpha if alpha is not None else spec.def_alpha)
+ gain = float(gain if gain is not None else spec.def_gain)
+ clamp = float(clamp if clamp is not None else -1)
+
+ # Lookup from cache.
+ key = (dim, act, alpha, gain, clamp)
+ if key in _bias_act_cuda_cache:
+ return _bias_act_cuda_cache[key]
+
+ # Forward op.
+ class BiasActCuda(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, x, b): # pylint: disable=arguments-differ
+ if x.ndim > 2 and x.stride()[1] == 1:
+ ctx.memory_format = torch.channels_last
+ else:
+ ctx.memory_format = torch.contiguous_format
+ x = x.contiguous(memory_format=ctx.memory_format)
+ b = b.contiguous() if b is not None else _null_tensor
+ y = x
+ if act != 'linear' or gain != 1 or clamp >= 0 or b is not \
+ _null_tensor:
+ y = bias_act_cuda.bias_act_cuda(x, b, _null_tensor, _null_tensor,
+ _null_tensor, 0, dim, spec.cuda_idx, alpha,
+ gain, clamp)
+ ctx.save_for_backward(
+ x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
+ b if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
+ y if 'y' in spec.ref else _null_tensor)
+ return y
+
+ @staticmethod
+ def backward(ctx, dy): # pylint: disable=arguments-differ
+ dy = dy.contiguous(memory_format=ctx.memory_format)
+ x, b, y = ctx.saved_tensors
+ dx = None
+ db = None
+
+ if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
+ dx = dy
+ if act != 'linear' or gain != 1 or clamp >= 0:
+ dx = BiasActCudaGrad.apply(dy, x, b, y)
+
+ if ctx.needs_input_grad[1]:
+ db = dx.sum([i for i in range(dx.ndim) if i != dim])
+
+ return dx, db
+
+ # Backward op.
+ class BiasActCudaGrad(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ
+ if x.ndim > 2 and x.stride()[1] == 1:
+ ctx.memory_format = torch.channels_last
+ else:
+ ctx.memory_format = torch.contiguous_format
+ dx = bias_act_cuda.bias_act_cuda(dy, b, x, y, _null_tensor, 1, dim,
+ spec.cuda_idx, alpha, gain, clamp)
+ ctx.save_for_backward(
+ dy if spec.has_2nd_grad else _null_tensor,
+ x, b, y)
+ return dx
+
+ @staticmethod
+ def backward(ctx, d_dx): # pylint: disable=arguments-differ
+ d_dx = d_dx.contiguous(memory_format=ctx.memory_format)
+ dy, x, b, y = ctx.saved_tensors
+ d_dy = None
+ d_x = None
+ d_b = None
+ d_y = None
+
+ if ctx.needs_input_grad[0]:
+ d_dy = BiasActCudaGrad.apply(d_dx, x, b, y)
+
+ if spec.has_2nd_grad and (
+ ctx.needs_input_grad[1] or ctx.needs_input_grad[2]):
+ d_x = bias_act_cuda.bias_act_cuda(d_dx, b, x, y, dy, 2, dim, spec.cuda_idx,
+ alpha, gain, clamp)
+
+ if spec.has_2nd_grad and ctx.needs_input_grad[2]:
+ d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim])
+
+ return d_dy, d_x, d_b, d_y
+
+ # Add to cache.
+ _bias_act_cuda_cache[key] = BiasActCuda
+ return BiasActCuda
+
+
+class FusedNonlinearity(nn.Module):
+ def __init__(self, nonlinearity, num_channels=None, lr_mul=1.0, alpha=None, impl='cuda', gain=None):
+ super().__init__()
+ if num_channels is not None:
+ self.bias = nn.Parameter(torch.zeros(num_channels))
+ else:
+ self.register_parameter('bias', None)
+ self.nonlinearity = nonlinearity
+ self.gain = gain
+ self.alpha = alpha
+ self.lr_mul = lr_mul
+ self.impl = impl
+
+ def forward(self, x):
+ bias = self.bias.type_as(x) * self.lr_mul if self.bias is not None else None
+ return _bias_act(
+ x, b=bias, dim=1, act=self.nonlinearity,
+ alpha=self.alpha, gain=self.gain, clamp=None, impl=self.impl
+ )
+
+ def __repr__(self):
+ mod_str = f'{self.__class__.__name__}(type={self.nonlinearity}'
+ if self.gain is not None:
+ mod_str += f', gain={self.gain}'
+ if self.alpha is not None:
+ mod_str += f', alpha={self.alpha}'
+ if self.lr_mul != 1:
+ mod_str += f', lr_mul={self.lr_mul}'
+ mod_str += ')'
+ return mod_str
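
Usage sketch for `FusedNonlinearity`. Note that `bias_act.py` imports the compiled `bias_act_cuda` extension at module load, so the extension built from the `setup.py` below must be installed even when `impl='ref'` (which then takes the pure-PyTorch reference path at run time):

```python
import torch
from imaginaire.third_party.bias_act import FusedNonlinearity

act = FusedNonlinearity(nonlinearity='leakyrelu', num_channels=8, impl='ref')
x = torch.randn(2, 8, 32, 32)
y = act(x)    # fused bias + leaky ReLU, scaled by the default gain sqrt(2)
print(act)    # FusedNonlinearity(type=leakyrelu)
```
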
diff --git a/imaginaire/third_party/bias_act/setup.py b/imaginaire/third_party/bias_act/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..13607b6083bd59eba24c5f4ac48a34048b55f642
--- /dev/null
+++ b/imaginaire/third_party/bias_act/setup.py
@@ -0,0 +1,43 @@
+# flake8: noqa
+from setuptools import setup
+from torch.utils.cpp_extension import BuildExtension, CUDAExtension
+import os
+
+
+cuda_version = os.getenv('CUDA_VERSION')
+print('CUDA_VERSION: {}'.format(cuda_version))
+
+nvcc_args = list()
+# nvcc_args.append('-gencode')
+# nvcc_args.append('arch=compute_50,code=sm_50')
+# nvcc_args.append('-gencode')
+# nvcc_args.append('arch=compute_52,code=sm_52')
+# nvcc_args.append('-gencode')
+# nvcc_args.append('arch=compute_60,code=sm_60')
+# nvcc_args.append('-gencode')
+# nvcc_args.append('arch=compute_61,code=sm_61')
+nvcc_args.append('-gencode')
+nvcc_args.append('arch=compute_70,code=sm_70')
+nvcc_args.append('-gencode')
+nvcc_args.append('arch=compute_75,code=sm_75')
+if cuda_version is not None:
+ if cuda_version >= '11.0':
+ nvcc_args.append('-gencode')
+ nvcc_args.append('arch=compute_80,code=sm_80')
+nvcc_args.append('-Xcompiler')
+nvcc_args.append('-Wall')
+nvcc_args.append('-std=c++14')
+
+setup(
+ name='bias_act_cuda',
+ py_modules=['bias_act'],
+ ext_modules=[
+ CUDAExtension('bias_act_cuda', [
+ './src/bias_act_cuda.cc',
+ './src/bias_act_cuda_kernel.cu'
+ ], extra_compile_args={'cxx': ['-Wall', '-std=c++14'],
+ 'nvcc': nvcc_args})
+ ],
+ cmdclass={
+ 'build_ext': BuildExtension
+ })
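
All three extensions under `imaginaire/third_party` follow this same `setup.py` pattern and are normally installed with `python setup.py install` from their respective directories. As an alternative sketch (an assumption, not part of the patch), the same sources can be JIT-compiled with `torch.utils.cpp_extension.load`; this builds and returns the module in place rather than installing it, so `import bias_act_cuda` elsewhere still requires the regular install. The paths assume the snippet is run from the repository root with a matching CUDA toolkit available.

```python
import os
from torch.utils.cpp_extension import load

src_dir = os.path.join('imaginaire', 'third_party', 'bias_act', 'src')
bias_act_cuda = load(
    name='bias_act_cuda',
    sources=[
        os.path.join(src_dir, 'bias_act_cuda.cc'),
        os.path.join(src_dir, 'bias_act_cuda_kernel.cu'),
    ],
    extra_cflags=['-std=c++14'],
    verbose=True,
)
```
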
diff --git a/imaginaire/third_party/bias_act/src/bias_act_cuda.cc b/imaginaire/third_party/bias_act/src/bias_act_cuda.cc
new file mode 100644
index 0000000000000000000000000000000000000000..cf975dbe6784e89cfa056574da8780d1e5f5b97d
--- /dev/null
+++ b/imaginaire/third_party/bias_act/src/bias_act_cuda.cc
@@ -0,0 +1,103 @@
+// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+//
+// NVIDIA CORPORATION and its licensors retain all intellectual property
+// and proprietary rights in and to this software, related documentation
+// and any modifications thereto. Any use, reproduction, disclosure or
+// distribution of this software and related documentation without an express
+// license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+#include <torch/extension.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+#include <ATen/cuda/Exceptions.h>
+#include <cuda_runtime.h>
+#include <climits>
+
+#include "bias_act_cuda.h"
+
+//------------------------------------------------------------------------
+
+static bool has_same_layout(torch::Tensor x, torch::Tensor y)
+{
+ if (x.dim() != y.dim())
+ return false;
+ for (int64_t i = 0; i < x.dim(); i++)
+ {
+ if (x.size(i) != y.size(i))
+ return false;
+ if (x.size(i) >= 2 && x.stride(i) != y.stride(i))
+ return false;
+ }
+ return true;
+}
+
+//------------------------------------------------------------------------
+
+static torch::Tensor bias_act_cuda(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp)
+{
+ // Validate arguments.
+ TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
+ TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x");
+ TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x");
+ TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x");
+ TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same dtype and device as x");
+ TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
+ TORCH_CHECK(b.dim() == 1, "b must have rank 1");
+ TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds");
+ TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements");
+ TORCH_CHECK(grad >= 0, "grad must be non-negative");
+
+ // Validate layout.
+ TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense");
+ TORCH_CHECK(b.is_contiguous(), "b must be contiguous");
+ TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x");
+ TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x");
+ TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x");
+
+ // Create output tensor.
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
+ torch::Tensor y = torch::empty_like(x);
+ TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x");
+
+ // Initialize CUDA kernel parameters.
+ bias_act_kernel_params p;
+ p.x = x.data_ptr();
+ p.b = (b.numel()) ? b.data_ptr() : NULL;
+ p.xref = (xref.numel()) ? xref.data_ptr() : NULL;
+ p.yref = (yref.numel()) ? yref.data_ptr() : NULL;
+ p.dy = (dy.numel()) ? dy.data_ptr() : NULL;
+ p.y = y.data_ptr();
+ p.grad = grad;
+ p.act = act;
+ p.alpha = alpha;
+ p.gain = gain;
+ p.clamp = clamp;
+ p.sizeX = (int)x.numel();
+ p.sizeB = (int)b.numel();
+ p.stepB = (b.numel()) ? (int)x.stride(dim) : 1;
+
+ // Choose CUDA kernel.
+ void* kernel;
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda_kernel", [&]
+ {
+ kernel = choose_bias_act_kernel<scalar_t>(p);
+ });
+ TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func");
+
+ // Launch CUDA kernel.
+ p.loopX = 4;
+ int blockSize = 4 * 32;
+ int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1;
+ void* args[] = {&p};
+ AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
+ return y;
+}
+
+//------------------------------------------------------------------------
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
+{
+ m.def("bias_act_cuda", &bias_act_cuda);
+}
+
+//------------------------------------------------------------------------
diff --git a/imaginaire/third_party/bias_act/src/bias_act_cuda.h b/imaginaire/third_party/bias_act/src/bias_act_cuda.h
new file mode 100644
index 0000000000000000000000000000000000000000..a32187e1fb7e3bae509d4eceaf900866866875a4
--- /dev/null
+++ b/imaginaire/third_party/bias_act/src/bias_act_cuda.h
@@ -0,0 +1,38 @@
+// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+//
+// NVIDIA CORPORATION and its licensors retain all intellectual property
+// and proprietary rights in and to this software, related documentation
+// and any modifications thereto. Any use, reproduction, disclosure or
+// distribution of this software and related documentation without an express
+// license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+//------------------------------------------------------------------------
+// CUDA kernel parameters.
+
+struct bias_act_kernel_params
+{
+ const void* x; // [sizeX]
+ const void* b; // [sizeB] or NULL
+ const void* xref; // [sizeX] or NULL
+ const void* yref; // [sizeX] or NULL
+ const void* dy; // [sizeX] or NULL
+ void* y; // [sizeX]
+
+ int grad;
+ int act;
+ float alpha;
+ float gain;
+ float clamp;
+
+ int sizeX;
+ int sizeB;
+ int stepB;
+ int loopX;
+};
+
+//------------------------------------------------------------------------
+// CUDA kernel selection.
+
+template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p);
+
+//------------------------------------------------------------------------
diff --git a/imaginaire/third_party/bias_act/src/bias_act_cuda_kernel.cu b/imaginaire/third_party/bias_act/src/bias_act_cuda_kernel.cu
new file mode 100644
index 0000000000000000000000000000000000000000..9adbb942b5ce5740a5527449995e1887cda12816
--- /dev/null
+++ b/imaginaire/third_party/bias_act/src/bias_act_cuda_kernel.cu
@@ -0,0 +1,173 @@
+// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+//
+// NVIDIA CORPORATION and its licensors retain all intellectual property
+// and proprietary rights in and to this software, related documentation
+// and any modifications thereto. Any use, reproduction, disclosure or
+// distribution of this software and related documentation without an express
+// license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+#include <c10/util/Half.h>
+#include "bias_act_cuda.h"
+
+//------------------------------------------------------------------------
+// Helpers.
+
+template <class T> struct InternalType;
+template <> struct InternalType<double> { typedef double scalar_t; };
+template <> struct InternalType<float> { typedef float scalar_t; };
+template <> struct InternalType<c10::Half> { typedef float scalar_t; };
+
+//------------------------------------------------------------------------
+// CUDA kernel.
+
+template <class T, int A>
+__global__ void bias_act_kernel(bias_act_kernel_params p)
+{
+ typedef typename InternalType::scalar_t scalar_t;
+ int G = p.grad;
+ scalar_t alpha = (scalar_t)p.alpha;
+ scalar_t gain = (scalar_t)p.gain;
+ scalar_t clamp = (scalar_t)p.clamp;
+ scalar_t one = (scalar_t)1;
+ scalar_t two = (scalar_t)2;
+ scalar_t expRange = (scalar_t)80;
+ scalar_t halfExpRange = (scalar_t)40;
+ scalar_t seluScale = (scalar_t)1.0507009873554804934193349852946;
+ scalar_t seluAlpha = (scalar_t)1.6732632423543772848170429916717;
+
+ // Loop over elements.
+ int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x;
+ for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x)
+ {
+ // Load.
+ scalar_t x = (scalar_t)((const T*)p.x)[xi];
+ scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0;
+ scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0;
+ scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0;
+ scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one;
+ scalar_t yy = (gain != 0) ? yref / gain : 0;
+ scalar_t y = 0;
+
+ // Apply bias.
+ ((G == 0) ? x : xref) += b;
+
+ // linear
+ if (A == 1)
+ {
+ if (G == 0) y = x;
+ if (G == 1) y = x;
+ }
+
+ // relu
+ if (A == 2)
+ {
+ if (G == 0) y = (x > 0) ? x : 0;
+ if (G == 1) y = (yy > 0) ? x : 0;
+ }
+
+ // lrelu
+ if (A == 3)
+ {
+ if (G == 0) y = (x > 0) ? x : x * alpha;
+ if (G == 1) y = (yy > 0) ? x : x * alpha;
+ }
+
+ // tanh
+ if (A == 4)
+ {
+ if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); }
+ if (G == 1) y = x * (one - yy * yy);
+ if (G == 2) y = x * (one - yy * yy) * (-two * yy);
+ }
+
+ // sigmoid
+ if (A == 5)
+ {
+ if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one);
+ if (G == 1) y = x * yy * (one - yy);
+ if (G == 2) y = x * yy * (one - yy) * (one - two * yy);
+ }
+
+ // elu
+ if (A == 6)
+ {
+ if (G == 0) y = (x >= 0) ? x : exp(x) - one;
+ if (G == 1) y = (yy >= 0) ? x : x * (yy + one);
+ if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one);
+ }
+
+ // selu
+ if (A == 7)
+ {
+ if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one);
+ if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha);
+ if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha);
+ }
+
+ // softplus
+ if (A == 8)
+ {
+ if (G == 0) y = (x > expRange) ? x : log(exp(x) + one);
+ if (G == 1) y = x * (one - exp(-yy));
+ if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); }
+ }
+
+ // swish
+ if (A == 9)
+ {
+ if (G == 0)
+ y = (x < -expRange) ? 0 : x / (exp(-x) + one);
+ else
+ {
+ scalar_t c = exp(xref);
+ scalar_t d = c + one;
+ if (G == 1)
+ y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d);
+ else
+ y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d);
+ yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain;
+ }
+ }
+
+ // Apply gain.
+ y *= gain * dy;
+
+ // Clamp.
+ if (clamp >= 0)
+ {
+ if (G == 0)
+ y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp;
+ else
+ y = (yref > -clamp & yref < clamp) ? y : 0;
+ }
+
+ // Store.
+ ((T*)p.y)[xi] = (T)y;
+ }
+}
+
+//------------------------------------------------------------------------
+// CUDA kernel selection.
+
+template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p)
+{
+ if (p.act == 1) return (void*)bias_act_kernel<T, 1>;
+ if (p.act == 2) return (void*)bias_act_kernel<T, 2>;
+ if (p.act == 3) return (void*)bias_act_kernel<T, 3>;
+ if (p.act == 4) return (void*)bias_act_kernel<T, 4>;
+ if (p.act == 5) return (void*)bias_act_kernel<T, 5>;
+ if (p.act == 6) return (void*)bias_act_kernel<T, 6>;
+ if (p.act == 7) return (void*)bias_act_kernel<T, 7>;
+ if (p.act == 8) return (void*)bias_act_kernel<T, 8>;
+ if (p.act == 9) return (void*)bias_act_kernel<T, 9>;
+ return NULL;
+}
+
+//------------------------------------------------------------------------
+// Template specializations.
+
+template void* choose_bias_act_kernel<double> (const bias_act_kernel_params& p);
+template void* choose_bias_act_kernel<float> (const bias_act_kernel_params& p);
+template void* choose_bias_act_kernel<c10::Half> (const bias_act_kernel_params& p);
+
+//------------------------------------------------------------------------
diff --git a/imaginaire/third_party/channelnorm/channelnorm.py b/imaginaire/third_party/channelnorm/channelnorm.py
new file mode 100644
index 0000000000000000000000000000000000000000..fdd46711ca0bf2b6bb650112fa364100f6d4c927
--- /dev/null
+++ b/imaginaire/third_party/channelnorm/channelnorm.py
@@ -0,0 +1,39 @@
+# flake8: noqa
+from torch.autograd import Function, Variable
+from torch.nn.modules.module import Module
+import channelnorm_cuda
+
+
+class ChannelNormFunction(Function):
+ @staticmethod
+ def forward(ctx, input1, norm_deg=2):
+ assert input1.is_contiguous()
+ b, _, h, w = input1.size()
+ output = input1.new(b, 1, h, w).zero_()
+
+ channelnorm_cuda.forward(input1, output, norm_deg)
+ ctx.save_for_backward(input1, output)
+ ctx.norm_deg = norm_deg
+
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ input1, output = ctx.saved_tensors
+
+ grad_input1 = Variable(input1.new(input1.size()).zero_())
+
+ channelnorm_cuda.backward(input1, output, grad_output.data,
+ grad_input1.data, ctx.norm_deg)
+
+ return grad_input1, None
+
+
+class ChannelNorm(Module):
+
+ def __init__(self, norm_deg=2):
+ super(ChannelNorm, self).__init__()
+ self.norm_deg = norm_deg
+
+ def forward(self, input1):
+ return ChannelNormFunction.apply(input1, self.norm_deg)
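
Usage sketch for `ChannelNorm` (requires the compiled `channelnorm_cuda` extension and CUDA tensors, since both forward and backward call straight into the kernels; installing via the `setup.py` below exposes `channelnorm` as a top-level module):

```python
import torch
from channelnorm import ChannelNorm

norm = ChannelNorm(norm_deg=2)
x = torch.randn(1, 3, 64, 64, device='cuda', requires_grad=True)
y = norm(x)          # (1, 1, 64, 64): per-pixel L2 norm over the channel axis
y.sum().backward()   # gradients flow back through the custom Function
```
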
diff --git a/imaginaire/third_party/channelnorm/setup.py b/imaginaire/third_party/channelnorm/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..8503ad2b254915bbbab391eb31baf0dcdc9a6bd1
--- /dev/null
+++ b/imaginaire/third_party/channelnorm/setup.py
@@ -0,0 +1,43 @@
+# flake8: noqa
+from setuptools import setup
+from torch.utils.cpp_extension import BuildExtension, CUDAExtension
+import os
+
+
+cuda_version = os.getenv('CUDA_VERSION')
+print('CUDA_VERSION: {}'.format(cuda_version))
+
+nvcc_args = list()
+# nvcc_args.append('-gencode')
+# nvcc_args.append('arch=compute_50,code=sm_50')
+# nvcc_args.append('-gencode')
+# nvcc_args.append('arch=compute_52,code=sm_52')
+# nvcc_args.append('-gencode')
+# nvcc_args.append('arch=compute_60,code=sm_60')
+# nvcc_args.append('-gencode')
+# nvcc_args.append('arch=compute_61,code=sm_61')
+nvcc_args.append('-gencode')
+nvcc_args.append('arch=compute_70,code=sm_70')
+nvcc_args.append('-gencode')
+nvcc_args.append('arch=compute_75,code=sm_75')
+if cuda_version is not None:
+ if cuda_version >= '11.0':
+ nvcc_args.append('-gencode')
+ nvcc_args.append('arch=compute_80,code=sm_80')
+nvcc_args.append('-Xcompiler')
+nvcc_args.append('-Wall')
+nvcc_args.append('-std=c++14')
+
+setup(
+ name='channelnorm_cuda',
+ py_modules=['channelnorm'],
+ ext_modules=[
+ CUDAExtension('channelnorm_cuda', [
+ './src/channelnorm_cuda.cc',
+ './src/channelnorm_kernel.cu'
+ ], extra_compile_args={'cxx': ['-Wall', '-std=c++14'],
+ 'nvcc': nvcc_args})
+ ],
+ cmdclass={
+ 'build_ext': BuildExtension
+ })
diff --git a/imaginaire/third_party/channelnorm/src/channelnorm_cuda.cc b/imaginaire/third_party/channelnorm/src/channelnorm_cuda.cc
new file mode 100644
index 0000000000000000000000000000000000000000..69d82eb184e97b2eefa9810ad156d1104cf84745
--- /dev/null
+++ b/imaginaire/third_party/channelnorm/src/channelnorm_cuda.cc
@@ -0,0 +1,31 @@
+#include <torch/torch.h>
+#include <ATen/ATen.h>
+
+#include "channelnorm_kernel.cuh"
+
+int channelnorm_cuda_forward(
+ at::Tensor& input1,
+ at::Tensor& output,
+ int norm_deg) {
+
+ channelnorm_kernel_forward(input1, output, norm_deg);
+ return 1;
+}
+
+
+int channelnorm_cuda_backward(
+ at::Tensor& input1,
+ at::Tensor& output,
+ at::Tensor& gradOutput,
+ at::Tensor& gradInput1,
+ int norm_deg) {
+
+ channelnorm_kernel_backward(input1, output, gradOutput, gradInput1, norm_deg);
+ return 1;
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ m.def("forward", &channelnorm_cuda_forward, "Channel norm forward (CUDA)");
+ m.def("backward", &channelnorm_cuda_backward, "Channel norm backward (CUDA)");
+}
+
diff --git a/imaginaire/third_party/channelnorm/src/channelnorm_kernel.cu b/imaginaire/third_party/channelnorm/src/channelnorm_kernel.cu
new file mode 100644
index 0000000000000000000000000000000000000000..99ace6855a61373443a6ddff7c7858eb474e9e48
--- /dev/null
+++ b/imaginaire/third_party/channelnorm/src/channelnorm_kernel.cu
@@ -0,0 +1,177 @@
+#include <ATen/ATen.h>
+#include <ATen/Context.h>
+#include <ATen/cuda/CUDAContext.h>
+
+#include "channelnorm_kernel.cuh"
+
+#define CUDA_NUM_THREADS 512
+
+#define DIM0(TENSOR) ((TENSOR).x)
+#define DIM1(TENSOR) ((TENSOR).y)
+#define DIM2(TENSOR) ((TENSOR).z)
+#define DIM3(TENSOR) ((TENSOR).w)
+
+#define DIM3_INDEX(TENSOR, xx, yy, zz, ww) ((TENSOR)[((xx) * (TENSOR##_stride.x)) + ((yy) * (TENSOR##_stride.y)) + ((zz) * (TENSOR##_stride.z)) + ((ww) * (TENSOR##_stride.w))])
+
+using at::Half;
+
+template <typename scalar_t>
+__global__ void kernel_channelnorm_update_output(
+ const int n,
+ const scalar_t* __restrict__ input1,
+ const long4 input1_size,
+ const long4 input1_stride,
+ scalar_t* __restrict__ output,
+ const long4 output_size,
+ const long4 output_stride,
+ int norm_deg) {
+
+ int index = blockIdx.x * blockDim.x + threadIdx.x;
+
+ if (index >= n) {
+ return;
+ }
+
+ int dim_b = DIM0(output_size);
+ int dim_c = DIM1(output_size);
+ int dim_h = DIM2(output_size);
+ int dim_w = DIM3(output_size);
+ int dim_chw = dim_c * dim_h * dim_w;
+
+ int b = ( index / dim_chw ) % dim_b;
+ int y = ( index / dim_w ) % dim_h;
+ int x = ( index ) % dim_w;
+
+ int i1dim_c = DIM1(input1_size);
+ int i1dim_h = DIM2(input1_size);
+ int i1dim_w = DIM3(input1_size);
+ int i1dim_chw = i1dim_c * i1dim_h * i1dim_w;
+ int i1dim_hw = i1dim_h * i1dim_w;
+
+ float result = 0.0;
+
+ for (int c = 0; c < i1dim_c; ++c) {
+ int i1Index = b * i1dim_chw + c * i1dim_hw + y * i1dim_w + x;
+ scalar_t val = input1[i1Index];
+ result += static_cast<float>(val * val);
+ }
+ result = sqrt(result);
+ output[index] = static_cast<scalar_t>(result);
+}
+
+
+template <typename scalar_t>
+__global__ void kernel_channelnorm_backward_input1(
+ const int n,
+ const scalar_t* __restrict__ input1, const long4 input1_size, const long4 input1_stride,
+ const scalar_t* __restrict__ output, const long4 output_size, const long4 output_stride,
+ const scalar_t* __restrict__ gradOutput, const long4 gradOutput_size, const long4 gradOutput_stride,
+ scalar_t* __restrict__ gradInput, const long4 gradInput_size, const long4 gradInput_stride,
+ int norm_deg) {
+
+ int index = blockIdx.x * blockDim.x + threadIdx.x;
+
+ if (index >= n) {
+ return;
+ }
+
+ float val = 0.0;
+
+ int dim_b = DIM0(gradInput_size);
+ int dim_c = DIM1(gradInput_size);
+ int dim_h = DIM2(gradInput_size);
+ int dim_w = DIM3(gradInput_size);
+ int dim_chw = dim_c * dim_h * dim_w;
+ int dim_hw = dim_h * dim_w;
+
+ int b = ( index / dim_chw ) % dim_b;
+ int y = ( index / dim_w ) % dim_h;
+ int x = ( index ) % dim_w;
+
+
+ int outIndex = b * dim_hw + y * dim_w + x;
+ val = static_cast<float>(gradOutput[outIndex]) * static_cast<float>(input1[index]) / (static_cast<float>(output[outIndex])+1e-9);
+ gradInput[index] = static_cast<scalar_t>(val);
+
+}
+
+void channelnorm_kernel_forward(
+ at::Tensor& input1,
+ at::Tensor& output,
+ int norm_deg) {
+
+ const long4 input1_size = make_long4(input1.size(0), input1.size(1), input1.size(2), input1.size(3));
+ const long4 input1_stride = make_long4(input1.stride(0), input1.stride(1), input1.stride(2), input1.stride(3));
+
+ const long4 output_size = make_long4(output.size(0), output.size(1), output.size(2), output.size(3));
+ const long4 output_stride = make_long4(output.stride(0), output.stride(1), output.stride(2), output.stride(3));
+
+ int n = output.numel();
+
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(input1.type(), "channelnorm_forward", ([&] {
+
+ kernel_channelnorm_update_output<scalar_t><<< (n + CUDA_NUM_THREADS - 1)/CUDA_NUM_THREADS, CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream() >>>(
+//at::globalContext().getCurrentCUDAStream() >>>(
+ n,
+ input1.data<scalar_t>(),
+ input1_size,
+ input1_stride,
+ output.data<scalar_t>(),
+ output_size,
+ output_stride,
+ norm_deg);
+
+ }));
+
+ // TODO: ATen-equivalent check
+
+ // THCudaCheck(cudaGetLastError());
+}
+
+void channelnorm_kernel_backward(
+ at::Tensor& input1,
+ at::Tensor& output,
+ at::Tensor& gradOutput,
+ at::Tensor& gradInput1,
+ int norm_deg) {
+
+ const long4 input1_size = make_long4(input1.size(0), input1.size(1), input1.size(2), input1.size(3));
+ const long4 input1_stride = make_long4(input1.stride(0), input1.stride(1), input1.stride(2), input1.stride(3));
+
+ const long4 output_size = make_long4(output.size(0), output.size(1), output.size(2), output.size(3));
+ const long4 output_stride = make_long4(output.stride(0), output.stride(1), output.stride(2), output.stride(3));
+
+ const long4 gradOutput_size = make_long4(gradOutput.size(0), gradOutput.size(1), gradOutput.size(2), gradOutput.size(3));
+ const long4 gradOutput_stride = make_long4(gradOutput.stride(0), gradOutput.stride(1), gradOutput.stride(2), gradOutput.stride(3));
+
+ const long4 gradInput1_size = make_long4(gradInput1.size(0), gradInput1.size(1), gradInput1.size(2), gradInput1.size(3));
+ const long4 gradInput1_stride = make_long4(gradInput1.stride(0), gradInput1.stride(1), gradInput1.stride(2), gradInput1.stride(3));
+
+ int n = gradInput1.numel();
+
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(input1.type(), "channelnorm_backward_input1", ([&] {
+
+ kernel_channelnorm_backward_input1<scalar_t><<< (n + CUDA_NUM_THREADS - 1)/CUDA_NUM_THREADS, CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream() >>>(
+//at::globalContext().getCurrentCUDAStream() >>>(
+ n,
+ input1.data<scalar_t>(),
+ input1_size,
+ input1_stride,
+ output.data<scalar_t>(),
+ output_size,
+ output_stride,
+ gradOutput.data<scalar_t>(),
+ gradOutput_size,
+ gradOutput_stride,
+ gradInput1.data<scalar_t>(),
+ gradInput1_size,
+ gradInput1_stride,
+ norm_deg
+ );
+
+ }));
+
+ // TODO: Add ATen-equivalent check
+
+// THCudaCheck(cudaGetLastError());
+}
diff --git a/imaginaire/third_party/channelnorm/src/channelnorm_kernel.cuh b/imaginaire/third_party/channelnorm/src/channelnorm_kernel.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..3e6223f7fe60feb4bf9e4f66c3d849b84c89dcda
--- /dev/null
+++ b/imaginaire/third_party/channelnorm/src/channelnorm_kernel.cuh
@@ -0,0 +1,16 @@
+#pragma once
+
+#include <ATen/ATen.h>
+
+void channelnorm_kernel_forward(
+ at::Tensor& input1,
+ at::Tensor& output,
+ int norm_deg);
+
+
+void channelnorm_kernel_backward(
+ at::Tensor& input1,
+ at::Tensor& output,
+ at::Tensor& gradOutput,
+ at::Tensor& gradInput1,
+ int norm_deg);
diff --git a/imaginaire/third_party/correlation/correlation.py b/imaginaire/third_party/correlation/correlation.py
new file mode 100644
index 0000000000000000000000000000000000000000..e47739dff7475c0f29bff32bc2dc9f097161d144
--- /dev/null
+++ b/imaginaire/third_party/correlation/correlation.py
@@ -0,0 +1,105 @@
+# flake8: noqa
+import torch
+from torch.nn.modules.module import Module
+from torch.autograd import Function
+import correlation_cuda
+
+
+class CorrelationFunction(Function):
+
+ @staticmethod
+ def forward(ctx,
+ pad_size,
+ kernel_size,
+ max_displacement,
+ stride1,
+ stride2,
+ corr_multiply,
+ input1,
+ input2):
+ ctx.save_for_backward(input1, input2)
+ ctx.pad_size = pad_size
+ ctx.kernel_size = kernel_size
+ ctx.max_displacement = max_displacement
+ ctx.stride1 = stride1
+ ctx.stride2 = stride2
+ ctx.corr_multiply = corr_multiply
+
+ with torch.cuda.device_of(input1):
+ rbot1 = input1.new()
+ rbot2 = input2.new()
+ output = input1.new()
+
+ correlation_cuda.forward(
+ input1,
+ input2,
+ rbot1,
+ rbot2,
+ output,
+ ctx.pad_size,
+ ctx.kernel_size,
+ ctx.max_displacement,
+ ctx.stride1,
+ ctx.stride2,
+ ctx.corr_multiply)
+
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ input1, input2 = ctx.saved_tensors
+
+ with torch.cuda.device_of(input1):
+ rbot1 = input1.new()
+ rbot2 = input2.new()
+
+ grad_input1 = input1.new()
+ grad_input2 = input2.new()
+
+ correlation_cuda.backward(
+ input1,
+ input2,
+ rbot1,
+ rbot2,
+ grad_output,
+ grad_input1,
+ grad_input2,
+ ctx.pad_size,
+ ctx.kernel_size,
+ ctx.max_displacement,
+ ctx.stride1,
+ ctx.stride2,
+ ctx.corr_multiply)
+
+ # One gradient must be returned per forward() input; the six integer
+ # hyper-parameters receive None.
+ return None, None, None, None, None, None, grad_input1, grad_input2
+
+class Correlation(Module):
+ def __init__(
+ self,
+ pad_size=0,
+ kernel_size=0,
+ max_displacement=0,
+ stride1=1,
+ stride2=2,
+ corr_multiply=1):
+ super(Correlation, self).__init__()
+ self.pad_size = pad_size
+ self.kernel_size = kernel_size
+ self.max_displacement = max_displacement
+ self.stride1 = stride1
+ self.stride2 = stride2
+ self.corr_multiply = corr_multiply
+
+ def forward(self, input1, input2):
+
+ result = CorrelationFunction.apply(
+ self.pad_size,
+ self.kernel_size,
+ self.max_displacement,
+ self.stride1,
+ self.stride2,
+ self.corr_multiply,
+ input1,
+ input2)
+
+ return result
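
Usage sketch for `Correlation` with FlowNetC-style settings (the specific hyper-parameters are an assumption, not dictated by this file; the compiled `correlation_cuda` extension and CUDA tensors are required):

```python
import torch
from correlation import Correlation

corr = Correlation(pad_size=20, kernel_size=1, max_displacement=20,
                   stride1=1, stride2=2, corr_multiply=1)
feat1 = torch.randn(1, 64, 48, 64, device='cuda')
feat2 = torch.randn(1, 64, 48, 64, device='cuda')
cost_volume = corr(feat1, feat2)
print(cost_volume.shape)   # (1, 441, 48, 64): (2 * 20 / 2 + 1) ** 2 channels
```
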
diff --git a/imaginaire/third_party/correlation/setup.py b/imaginaire/third_party/correlation/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c02aacc53102a5ef534db1a7cd69c546004f268
--- /dev/null
+++ b/imaginaire/third_party/correlation/setup.py
@@ -0,0 +1,43 @@
+# flake8: noqa
+from setuptools import setup
+from torch.utils.cpp_extension import BuildExtension, CUDAExtension
+import os
+
+
+cuda_version = os.getenv('CUDA_VERSION')
+print('CUDA_VERSION: {}'.format(cuda_version))
+
+nvcc_args = list()
+# nvcc_args.append('-gencode')
+# nvcc_args.append('arch=compute_50,code=sm_50')
+# nvcc_args.append('-gencode')
+# nvcc_args.append('arch=compute_52,code=sm_52')
+# nvcc_args.append('-gencode')
+# nvcc_args.append('arch=compute_60,code=sm_60')
+# nvcc_args.append('-gencode')
+# nvcc_args.append('arch=compute_61,code=sm_61')
+nvcc_args.append('-gencode')
+nvcc_args.append('arch=compute_70,code=sm_70')
+nvcc_args.append('-gencode')
+nvcc_args.append('arch=compute_75,code=sm_75')
+if cuda_version is not None:
+ if cuda_version >= '11.0':
+ nvcc_args.append('-gencode')
+ nvcc_args.append('arch=compute_80,code=sm_80')
+nvcc_args.append('-Xcompiler')
+nvcc_args.append('-Wall')
+nvcc_args.append('-std=c++14')
+
+setup(
+ name='correlation_cuda',
+ py_modules=['correlation'],
+ ext_modules=[
+ CUDAExtension('correlation_cuda', [
+ './src/correlation_cuda.cc',
+ './src/correlation_cuda_kernel.cu'
+ ], extra_compile_args={'cxx': ['-Wall', '-std=c++14'],
+ 'nvcc': nvcc_args})
+ ],
+ cmdclass={
+ 'build_ext': BuildExtension
+ })
diff --git a/imaginaire/third_party/correlation/src/correlation_cuda.cc b/imaginaire/third_party/correlation/src/correlation_cuda.cc
new file mode 100644
index 0000000000000000000000000000000000000000..feccd65295fa90a22564b08fc80464a76361a1aa
--- /dev/null
+++ b/imaginaire/third_party/correlation/src/correlation_cuda.cc
@@ -0,0 +1,173 @@
+#include <torch/torch.h>
+#include <ATen/ATen.h>
+#include <ATen/Context.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <stdio.h>
+#include <iostream>
+
+#include "correlation_cuda_kernel.cuh"
+
+int correlation_forward_cuda(at::Tensor& input1, at::Tensor& input2, at::Tensor& rInput1, at::Tensor& rInput2, at::Tensor& output,
+ int pad_size,
+ int kernel_size,
+ int max_displacement,
+ int stride1,
+ int stride2,
+ int corr_type_multiply)
+{
+
+ int batchSize = input1.size(0);
+
+ int nInputChannels = input1.size(1);
+ int inputHeight = input1.size(2);
+ int inputWidth = input1.size(3);
+
+ int kernel_radius = (kernel_size - 1) / 2;
+ int border_radius = kernel_radius + max_displacement;
+
+ int paddedInputHeight = inputHeight + 2 * pad_size;
+ int paddedInputWidth = inputWidth + 2 * pad_size;
+
+ int nOutputChannels = ((max_displacement/stride2)*2 + 1) * ((max_displacement/stride2)*2 + 1);
+
+ int outputHeight = ceil(static_cast<float>(paddedInputHeight - 2 * border_radius) / static_cast<float>(stride1));
+ int outputwidth = ceil(static_cast<float>(paddedInputWidth - 2 * border_radius) / static_cast<float>(stride1));
+
+ rInput1.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels});
+ rInput2.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels});
+ output.resize_({batchSize, nOutputChannels, outputHeight, outputwidth});
+
+ rInput1.fill_(0);
+ rInput2.fill_(0);
+ output.fill_(0);
+
+ int success = correlation_forward_cuda_kernel(
+ output,
+ output.size(0),
+ output.size(1),
+ output.size(2),
+ output.size(3),
+ output.stride(0),
+ output.stride(1),
+ output.stride(2),
+ output.stride(3),
+ input1,
+ input1.size(1),
+ input1.size(2),
+ input1.size(3),
+ input1.stride(0),
+ input1.stride(1),
+ input1.stride(2),
+ input1.stride(3),
+ input2,
+ input2.size(1),
+ input2.stride(0),
+ input2.stride(1),
+ input2.stride(2),
+ input2.stride(3),
+ rInput1,
+ rInput2,
+ pad_size,
+ kernel_size,
+ max_displacement,
+ stride1,
+ stride2,
+ corr_type_multiply,
+ at::cuda::getCurrentCUDAStream()
+ //at::globalContext().getCurrentCUDAStream()
+ );
+
+ //check for errors
+ if (!success) {
+ AT_ERROR("CUDA call failed");
+ }
+
+ return 1;
+
+}
+
+int correlation_backward_cuda(at::Tensor& input1, at::Tensor& input2, at::Tensor& rInput1, at::Tensor& rInput2, at::Tensor& gradOutput,
+ at::Tensor& gradInput1, at::Tensor& gradInput2,
+ int pad_size,
+ int kernel_size,
+ int max_displacement,
+ int stride1,
+ int stride2,
+ int corr_type_multiply)
+{
+
+ int batchSize = input1.size(0);
+ int nInputChannels = input1.size(1);
+ int paddedInputHeight = input1.size(2)+ 2 * pad_size;
+ int paddedInputWidth = input1.size(3)+ 2 * pad_size;
+
+ int height = input1.size(2);
+ int width = input1.size(3);
+
+ rInput1.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels});
+ rInput2.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels});
+ gradInput1.resize_({batchSize, nInputChannels, height, width});
+ gradInput2.resize_({batchSize, nInputChannels, height, width});
+
+ rInput1.fill_(0);
+ rInput2.fill_(0);
+ gradInput1.fill_(0);
+ gradInput2.fill_(0);
+
+ int success = correlation_backward_cuda_kernel(gradOutput,
+ gradOutput.size(0),
+ gradOutput.size(1),
+ gradOutput.size(2),
+ gradOutput.size(3),
+ gradOutput.stride(0),
+ gradOutput.stride(1),
+ gradOutput.stride(2),
+ gradOutput.stride(3),
+ input1,
+ input1.size(1),
+ input1.size(2),
+ input1.size(3),
+ input1.stride(0),
+ input1.stride(1),
+ input1.stride(2),
+ input1.stride(3),
+ input2,
+ input2.stride(0),
+ input2.stride(1),
+ input2.stride(2),
+ input2.stride(3),
+ gradInput1,
+ gradInput1.stride(0),
+ gradInput1.stride(1),
+ gradInput1.stride(2),
+ gradInput1.stride(3),
+ gradInput2,
+ gradInput2.size(1),
+ gradInput2.stride(0),
+ gradInput2.stride(1),
+ gradInput2.stride(2),
+ gradInput2.stride(3),
+ rInput1,
+ rInput2,
+ pad_size,
+ kernel_size,
+ max_displacement,
+ stride1,
+ stride2,
+ corr_type_multiply,
+ at::cuda::getCurrentCUDAStream()
+ //at::globalContext().getCurrentCUDAStream()
+ );
+
+ if (!success) {
+ AT_ERROR("CUDA call failed");
+ }
+
+ return 1;
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ m.def("forward", &correlation_forward_cuda, "Correlation forward (CUDA)");
+ m.def("backward", &correlation_backward_cuda, "Correlation backward (CUDA)");
+}
+
diff --git a/imaginaire/third_party/correlation/src/correlation_cuda_kernel.cu b/imaginaire/third_party/correlation/src/correlation_cuda_kernel.cu
new file mode 100644
index 0000000000000000000000000000000000000000..eaf86fc129137d055de7400916567c6669b45c19
--- /dev/null
+++ b/imaginaire/third_party/correlation/src/correlation_cuda_kernel.cu
@@ -0,0 +1,564 @@
+#include <stdio.h>
+
+#include "correlation_cuda_kernel.cuh"
+
+#define CUDA_NUM_THREADS 1024
+#define THREADS_PER_BLOCK 32
+#define FULL_MASK 0xffffffff
+
+#include <ATen/ATen.h>
+#include <ATen/NativeFunctions.h>
+#include <ATen/Dispatch.h>
+#include <ATen/cuda/CUDAApplyUtils.cuh>
+
+using at::Half;
+
+template <typename scalar_t>
+__forceinline__ __device__ scalar_t warpReduceSum(scalar_t val) {
+ for (int offset = 16; offset > 0; offset /= 2)
+ val += __shfl_down_sync(FULL_MASK, val, offset);
+ return val;
+}
+
+template <typename scalar_t>
+__forceinline__ __device__ scalar_t blockReduceSum(scalar_t val) {
+
+ static __shared__ scalar_t shared[32];
+ int lane = threadIdx.x % warpSize;
+ int wid = threadIdx.x / warpSize;
+
+ val = warpReduceSum(val);
+
+ if (lane == 0)
+ shared[wid] = val;
+
+ __syncthreads();
+
+ val = (threadIdx.x < blockDim.x / warpSize) ? shared[lane] : 0;
+
+ if (wid == 0)
+ val = warpReduceSum(val);
+
+ return val;
+}
+
+
+template <typename scalar_t>
+__global__ void channels_first(const scalar_t* __restrict__ input, scalar_t* rinput, int channels, int height, int width, int pad_size)
+{
+
+ // n (batch size), c (num of channels), y (height), x (width)
+ int n = blockIdx.x;
+ int y = blockIdx.y;
+ int x = blockIdx.z;
+
+ int ch_off = threadIdx.x;
+ scalar_t value;
+
+ int dimcyx = channels * height * width;
+ int dimyx = height * width;
+
+ int p_dimx = (width + 2 * pad_size);
+ int p_dimy = (height + 2 * pad_size);
+ int p_dimyxc = channels * p_dimy * p_dimx;
+ int p_dimxc = p_dimx * channels;
+
+ for (int c = ch_off; c < channels; c += THREADS_PER_BLOCK) {
+ value = input[n * dimcyx + c * dimyx + y * width + x];
+ rinput[n * p_dimyxc + (y + pad_size) * p_dimxc + (x + pad_size) * channels + c] = value;
+ }
+}
+
+
+template <typename scalar_t>
+__global__ void correlation_forward(scalar_t* __restrict__ output, const int nOutputChannels,
+ const int outputHeight, const int outputWidth, const scalar_t* __restrict__ rInput1,
+ const int nInputChannels, const int inputHeight, const int inputWidth,
+ const scalar_t* __restrict__ rInput2, const int pad_size, const int kernel_size,
+ const int max_displacement, const int stride1, const int stride2) {
+
+ int32_t pInputWidth = inputWidth + 2 * pad_size;
+ int32_t pInputHeight = inputHeight + 2 * pad_size;
+
+ int32_t kernel_rad = (kernel_size - 1) / 2;
+
+ int32_t displacement_rad = max_displacement / stride2;
+
+ int32_t displacement_size = 2 * displacement_rad + 1;
+
+ int32_t n = blockIdx.x;
+ int32_t y1 = blockIdx.y * stride1 + max_displacement;
+ int32_t x1 = blockIdx.z * stride1 + max_displacement;
+ int32_t c = threadIdx.x;
+
+ int32_t pdimyxc = pInputHeight * pInputWidth * nInputChannels;
+
+ int32_t pdimxc = pInputWidth * nInputChannels;
+
+ int32_t pdimc = nInputChannels;
+
+ int32_t tdimcyx = nOutputChannels * outputHeight * outputWidth;
+ int32_t tdimyx = outputHeight * outputWidth;
+ int32_t tdimx = outputWidth;
+
+ int32_t nelems = kernel_size * kernel_size * pdimc;
+
+ // element-wise product along channel axis
+ for (int tj = -displacement_rad; tj <= displacement_rad; ++tj) {
+ for (int ti = -displacement_rad; ti <= displacement_rad; ++ti) {
+ int x2 = x1 + ti * stride2;
+ int y2 = y1 + tj * stride2;
+
+ float acc0 = 0.0f;
+
+ for (int j = -kernel_rad; j <= kernel_rad; ++j) {
+ for (int i = -kernel_rad; i <= kernel_rad; ++i) {
+ // THREADS_PER_BLOCK
+ #pragma unroll
+ for (int ch = c; ch < pdimc; ch += blockDim.x) {
+
+ int indx1 = n * pdimyxc + (y1 + j) * pdimxc
+ + (x1 + i) * pdimc + ch;
+ int indx2 = n * pdimyxc + (y2 + j) * pdimxc
+ + (x2 + i) * pdimc + ch;
+ acc0 += static_cast<float>(rInput1[indx1] * rInput2[indx2]);
+ }
+ }
+ }
+
+ if (blockDim.x == warpSize) {
+ __syncwarp();
+ acc0 = warpReduceSum(acc0);
+ } else {
+ __syncthreads();
+ acc0 = blockReduceSum(acc0);
+ }
+
+ if (threadIdx.x == 0) {
+
+ int tc = (tj + displacement_rad) * displacement_size
+ + (ti + displacement_rad);
+ const int tindx = n * tdimcyx + tc * tdimyx + blockIdx.y * tdimx
+ + blockIdx.z;
+ output[tindx] = static_cast<scalar_t>(acc0 / nelems);
+ }
+ }
+ }
+}
+
+
+template <typename scalar_t>
+__global__ void correlation_backward_input1(int item, scalar_t* gradInput1, int nInputChannels, int inputHeight, int inputWidth,
+ const scalar_t* __restrict__ gradOutput, int nOutputChannels, int outputHeight, int outputWidth,
+ const scalar_t* __restrict__ rInput2,
+ int pad_size,
+ int kernel_size,
+ int max_displacement,
+ int stride1,
+ int stride2)
+ {
+ // n (batch size), c (num of channels), y (height), x (width)
+
+ int n = item;
+ int y = blockIdx.x * stride1 + pad_size;
+ int x = blockIdx.y * stride1 + pad_size;
+ int c = blockIdx.z;
+ int tch_off = threadIdx.x;
+
+ int kernel_rad = (kernel_size - 1) / 2;
+ int displacement_rad = max_displacement / stride2;
+ int displacement_size = 2 * displacement_rad + 1;
+
+ int xmin = (x - kernel_rad - max_displacement) / stride1;
+ int ymin = (y - kernel_rad - max_displacement) / stride1;
+
+ int xmax = (x + kernel_rad - max_displacement) / stride1;
+ int ymax = (y + kernel_rad - max_displacement) / stride1;
+
+ if (xmax < 0 || ymax < 0 || xmin >= outputWidth || ymin >= outputHeight) {
+ // assumes gradInput1 is pre-allocated and zero filled
+ return;
+ }
+
+ if (xmin > xmax || ymin > ymax) {
+ // assumes gradInput1 is pre-allocated and zero filled
+ return;
+ }
+
+ xmin = max(0,xmin);
+ xmax = min(outputWidth-1,xmax);
+
+ ymin = max(0,ymin);
+ ymax = min(outputHeight-1,ymax);
+
+ int pInputWidth = inputWidth + 2 * pad_size;
+ int pInputHeight = inputHeight + 2 * pad_size;
+
+ int pdimyxc = pInputHeight * pInputWidth * nInputChannels;
+ int pdimxc = pInputWidth * nInputChannels;
+ int pdimc = nInputChannels;
+
+ int tdimcyx = nOutputChannels * outputHeight * outputWidth;
+ int tdimyx = outputHeight * outputWidth;
+ int tdimx = outputWidth;
+
+ int odimcyx = nInputChannels * inputHeight* inputWidth;
+ int odimyx = inputHeight * inputWidth;
+ int odimx = inputWidth;
+
+ scalar_t nelems = kernel_size * kernel_size * nInputChannels;
+
+ __shared__ scalar_t prod_sum[THREADS_PER_BLOCK];
+ prod_sum[tch_off] = 0;
+
+ for (int tc = tch_off; tc < nOutputChannels; tc += THREADS_PER_BLOCK) {
+
+ int i2 = (tc % displacement_size - displacement_rad) * stride2;
+ int j2 = (tc / displacement_size - displacement_rad) * stride2;
+
+ int indx2 = n * pdimyxc + (y + j2)* pdimxc + (x + i2) * pdimc + c;
+
+ scalar_t val2 = rInput2[indx2];
+
+ for (int j = ymin; j <= ymax; ++j) {
+ for (int i = xmin; i <= xmax; ++i) {
+ int tindx = n * tdimcyx + tc * tdimyx + j * tdimx + i;
+ prod_sum[tch_off] += gradOutput[tindx] * val2;
+ }
+ }
+ }
+ __syncthreads();
+
+ if(tch_off == 0) {
+ scalar_t reduce_sum = 0;
+ for(int idx = 0; idx < THREADS_PER_BLOCK; idx++) {
+ reduce_sum += prod_sum[idx];
+ }
+ const int indx1 = n * odimcyx + c * odimyx + (y - pad_size) * odimx + (x - pad_size);
+ gradInput1[indx1] = reduce_sum / nelems;
+ }
+
+}
+
+template <typename scalar_t>
+__global__ void correlation_backward_input2(int item, scalar_t* gradInput2, int nInputChannels, int inputHeight, int inputWidth,
+ const scalar_t* __restrict__ gradOutput, int nOutputChannels, int outputHeight, int outputWidth,
+ const scalar_t* __restrict__ rInput1,
+ int pad_size,
+ int kernel_size,
+ int max_displacement,
+ int stride1,
+ int stride2)
+{
+ // n (batch size), c (num of channels), y (height), x (width)
+
+ int n = item;
+ int y = blockIdx.x * stride1 + pad_size;
+ int x = blockIdx.y * stride1 + pad_size;
+ int c = blockIdx.z;
+
+ int tch_off = threadIdx.x;
+
+ int kernel_rad = (kernel_size - 1) / 2;
+ int displacement_rad = max_displacement / stride2;
+ int displacement_size = 2 * displacement_rad + 1;
+
+ int pInputWidth = inputWidth + 2 * pad_size;
+ int pInputHeight = inputHeight + 2 * pad_size;
+
+ int pdimyxc = pInputHeight * pInputWidth * nInputChannels;
+ int pdimxc = pInputWidth * nInputChannels;
+ int pdimc = nInputChannels;
+
+ int tdimcyx = nOutputChannels * outputHeight * outputWidth;
+ int tdimyx = outputHeight * outputWidth;
+ int tdimx = outputWidth;
+
+ int odimcyx = nInputChannels * inputHeight* inputWidth;
+ int odimyx = inputHeight * inputWidth;
+ int odimx = inputWidth;
+
+ scalar_t nelems = kernel_size * kernel_size * nInputChannels;
+
+ __shared__ scalar_t prod_sum[THREADS_PER_BLOCK];
+ prod_sum[tch_off] = 0;
+
+ for (int tc = tch_off; tc < nOutputChannels; tc += THREADS_PER_BLOCK) {
+ int i2 = (tc % displacement_size - displacement_rad) * stride2;
+ int j2 = (tc / displacement_size - displacement_rad) * stride2;
+
+ int xmin = (x - kernel_rad - max_displacement - i2) / stride1;
+ int ymin = (y - kernel_rad - max_displacement - j2) / stride1;
+
+ int xmax = (x + kernel_rad - max_displacement - i2) / stride1;
+ int ymax = (y + kernel_rad - max_displacement - j2) / stride1;
+
+ if (xmax < 0 || ymax < 0 || xmin >= outputWidth || ymin >= outputHeight) {
+ // assumes gradInput2 is pre-allocated and zero filled
+ continue;
+ }
+
+ if (xmin > xmax || ymin > ymax) {
+ // assumes gradInput2 is pre-allocated and zero filled
+ continue;
+ }
+
+ xmin = max(0,xmin);
+ xmax = min(outputWidth-1,xmax);
+
+ ymin = max(0,ymin);
+ ymax = min(outputHeight-1,ymax);
+
+ int indx1 = n * pdimyxc + (y - j2)* pdimxc + (x - i2) * pdimc + c;
+ scalar_t val1 = rInput1[indx1];
+
+ for (int j = ymin; j <= ymax; ++j) {
+ for (int i = xmin; i <= xmax; ++i) {
+ int tindx = n * tdimcyx + tc * tdimyx + j * tdimx + i;
+ prod_sum[tch_off] += gradOutput[tindx] * val1;
+ }
+ }
+ }
+
+ __syncthreads();
+
+ if(tch_off == 0) {
+ scalar_t reduce_sum = 0;
+ for(int idx = 0; idx < THREADS_PER_BLOCK; idx++) {
+ reduce_sum += prod_sum[idx];
+ }
+ const int indx2 = n * odimcyx + c * odimyx + (y - pad_size) * odimx + (x - pad_size);
+ gradInput2[indx2] = reduce_sum / nelems;
+ }
+
+}
+
+int correlation_forward_cuda_kernel(at::Tensor& output,
+ int ob,
+ int oc,
+ int oh,
+ int ow,
+ int osb,
+ int osc,
+ int osh,
+ int osw,
+
+ at::Tensor& input1,
+ int ic,
+ int ih,
+ int iw,
+ int isb,
+ int isc,
+ int ish,
+ int isw,
+
+ at::Tensor& input2,
+ int gc,
+ int gsb,
+ int gsc,
+ int gsh,
+ int gsw,
+
+ at::Tensor& rInput1,
+ at::Tensor& rInput2,
+ int pad_size,
+ int kernel_size,
+ int max_displacement,
+ int stride1,
+ int stride2,
+ int corr_type_multiply,
+ cudaStream_t stream)
+{
+
+ int batchSize = ob;
+
+ int nInputChannels = ic;
+ int inputWidth = iw;
+ int inputHeight = ih;
+
+ int nOutputChannels = oc;
+ int outputWidth = ow;
+ int outputHeight = oh;
+
+ dim3 blocks_grid(batchSize, inputHeight, inputWidth);
+ dim3 threads_block(THREADS_PER_BLOCK);
+
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(input1.type(), "channels_first_fwd_1", ([&] {
+
+ channels_first<scalar_t><<<blocks_grid, threads_block, 0, stream>>>(
+ input1.data<scalar_t>(), rInput1.data<scalar_t>(), nInputChannels, inputHeight, inputWidth, pad_size);
+
+ }));
+
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(input2.type(), "channels_first_fwd_2", ([&] {
+
+ channels_first<scalar_t><<<blocks_grid, threads_block, 0, stream>>> (
+ input2.data<scalar_t>(), rInput2.data<scalar_t>(), nInputChannels, inputHeight, inputWidth, pad_size);
+
+ }));
+
+ dim3 threadsPerBlock(THREADS_PER_BLOCK);
+ dim3 totalBlocksCorr(batchSize, outputHeight, outputWidth);
+
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(input1.type(), "correlation_forward", ([&] {
+
+ correlation_forward<scalar_t><<<totalBlocksCorr, threadsPerBlock, 0, stream>>>
+ (output.data<scalar_t>(), nOutputChannels, outputHeight, outputWidth,
+ rInput1.data<scalar_t>(), nInputChannels, inputHeight, inputWidth,
+ rInput2.data<scalar_t>(),
+ pad_size,
+ kernel_size,
+ max_displacement,
+ stride1,
+ stride2);
+
+ }));
+
+ cudaError_t err = cudaGetLastError();
+
+
+ // check for errors
+ if (err != cudaSuccess) {
+ printf("error in correlation_forward_cuda_kernel: %s\n", cudaGetErrorString(err));
+ return 0;
+ }
+
+ return 1;
+}
+
+
+int correlation_backward_cuda_kernel(
+ at::Tensor& gradOutput,
+ int gob,
+ int goc,
+ int goh,
+ int gow,
+ int gosb,
+ int gosc,
+ int gosh,
+ int gosw,
+
+ at::Tensor& input1,
+ int ic,
+ int ih,
+ int iw,
+ int isb,
+ int isc,
+ int ish,
+ int isw,
+
+ at::Tensor& input2,
+ int gsb,
+ int gsc,
+ int gsh,
+ int gsw,
+
+ at::Tensor& gradInput1,
+ int gisb,
+ int gisc,
+ int gish,
+ int gisw,
+
+ at::Tensor& gradInput2,
+ int ggc,
+ int ggsb,
+ int ggsc,
+ int ggsh,
+ int ggsw,
+
+ at::Tensor& rInput1,
+ at::Tensor& rInput2,
+ int pad_size,
+ int kernel_size,
+ int max_displacement,
+ int stride1,
+ int stride2,
+ int corr_type_multiply,
+ cudaStream_t stream)
+{
+
+ int batchSize = gob;
+ int num = batchSize;
+
+ int nInputChannels = ic;
+ int inputWidth = iw;
+ int inputHeight = ih;
+
+ int nOutputChannels = goc;
+ int outputWidth = gow;
+ int outputHeight = goh;
+
+ dim3 blocks_grid(batchSize, inputHeight, inputWidth);
+ dim3 threads_block(THREADS_PER_BLOCK);
+
+
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(input1.type(), "lltm_forward_cuda", ([&] {
+
+ channels_first<scalar_t><<<blocks_grid, threads_block, 0, stream>>>(
+ input1.data<scalar_t>(),
+ rInput1.data<scalar_t>(),
+ nInputChannels,
+ inputHeight,
+ inputWidth,
+ pad_size
+ );
+ }));
+
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(input2.type(), "lltm_forward_cuda", ([&] {
+
+ channels_first<scalar_t><<<blocks_grid, threads_block, 0, stream>>>(
+ input2.data<scalar_t>(),
+ rInput2.data<scalar_t>(),
+ nInputChannels,
+ inputHeight,
+ inputWidth,
+ pad_size
+ );
+ }));
+
+ dim3 threadsPerBlock(THREADS_PER_BLOCK);
+ dim3 totalBlocksCorr(inputHeight, inputWidth, nInputChannels);
+
+ for (int n = 0; n < num; ++n) {
+
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(input2.type(), "lltm_forward_cuda", ([&] {
+
+
+ correlation_backward_input1<scalar_t><<<totalBlocksCorr, threadsPerBlock, 0, stream>>> (
+ n, gradInput1.data<scalar_t>(), nInputChannels, inputHeight, inputWidth,
+ gradOutput.data<scalar_t>(), nOutputChannels, outputHeight, outputWidth,
+ rInput2.data<scalar_t>(),
+ pad_size,
+ kernel_size,
+ max_displacement,
+ stride1,
+ stride2);
+ }));
+ }
+
+ for(int n = 0; n < batchSize; n++) {
+
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(rInput1.type(), "lltm_forward_cuda", ([&] {
+
+ correlation_backward_input2<scalar_t><<<totalBlocksCorr, threadsPerBlock, 0, stream>>>(
+ n, gradInput2.data<scalar_t>(), nInputChannels, inputHeight, inputWidth,
+ gradOutput.data<scalar_t>(), nOutputChannels, outputHeight, outputWidth,
+ rInput1.data<scalar_t>(),
+ pad_size,
+ kernel_size,
+ max_displacement,
+ stride1,
+ stride2);
+
+ }));
+ }
+
+ // check for errors
+ cudaError_t err = cudaGetLastError();
+ if (err != cudaSuccess) {
+ printf("error in correlation_backward_cuda_kernel: %s\n", cudaGetErrorString(err));
+ return 0;
+ }
+
+ return 1;
+}
diff --git a/imaginaire/third_party/correlation/src/correlation_cuda_kernel.cuh b/imaginaire/third_party/correlation/src/correlation_cuda_kernel.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..1586d3af6bc184bfea8482a991a6625f865f02b3
--- /dev/null
+++ b/imaginaire/third_party/correlation/src/correlation_cuda_kernel.cuh
@@ -0,0 +1,91 @@
+#pragma once
+
+#include <ATen/ATen.h>
+#include <ATen/Context.h>
+#include <cuda_runtime.h>
+
+int correlation_forward_cuda_kernel(at::Tensor& output,
+ int ob,
+ int oc,
+ int oh,
+ int ow,
+ int osb,
+ int osc,
+ int osh,
+ int osw,
+
+ at::Tensor& input1,
+ int ic,
+ int ih,
+ int iw,
+ int isb,
+ int isc,
+ int ish,
+ int isw,
+
+ at::Tensor& input2,
+ int gc,
+ int gsb,
+ int gsc,
+ int gsh,
+ int gsw,
+
+ at::Tensor& rInput1,
+ at::Tensor& rInput2,
+ int pad_size,
+ int kernel_size,
+ int max_displacement,
+ int stride1,
+ int stride2,
+ int corr_type_multiply,
+ cudaStream_t stream);
+
+
+int correlation_backward_cuda_kernel(
+ at::Tensor& gradOutput,
+ int gob,
+ int goc,
+ int goh,
+ int gow,
+ int gosb,
+ int gosc,
+ int gosh,
+ int gosw,
+
+ at::Tensor& input1,
+ int ic,
+ int ih,
+ int iw,
+ int isb,
+ int isc,
+ int ish,
+ int isw,
+
+ at::Tensor& input2,
+ int gsb,
+ int gsc,
+ int gsh,
+ int gsw,
+
+ at::Tensor& gradInput1,
+ int gisb,
+ int gisc,
+ int gish,
+ int gisw,
+
+ at::Tensor& gradInput2,
+ int ggc,
+ int ggsb,
+ int ggsc,
+ int ggsh,
+ int ggsw,
+
+ at::Tensor& rInput1,
+ at::Tensor& rInput2,
+ int pad_size,
+ int kernel_size,
+ int max_displacement,
+ int stride1,
+ int stride2,
+ int corr_type_multiply,
+ cudaStream_t stream);
diff --git a/imaginaire/third_party/flow_net/__init__.py b/imaginaire/third_party/flow_net/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/imaginaire/third_party/flow_net/flow_net.py b/imaginaire/third_party/flow_net/flow_net.py
new file mode 100644
index 0000000000000000000000000000000000000000..41759c50fa1389b6fbe2e5db725a4b71ff3f2342
--- /dev/null
+++ b/imaginaire/third_party/flow_net/flow_net.py
@@ -0,0 +1,89 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import types
+from imaginaire.third_party.flow_net.flownet2 import models as \
+ flownet2_models
+from imaginaire.third_party.flow_net.flownet2.utils import tools \
+ as flownet2_tools
+from imaginaire.model_utils.fs_vid2vid import resample
+from imaginaire.utils.io import get_checkpoint
+
+
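+# Thin wrapper around FlowNet2: optionally loads the pretrained checkpoint and
+# exposes forward(input_A, input_B) -> (flow, confidence), where confidence is
+# 1 wherever the error between input_A and the flow-warped input_B is small.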
+class FlowNet(nn.Module):
+ def __init__(self, pretrained=True, fp16=False):
+ super().__init__()
+ flownet2_args = types.SimpleNamespace()
+ setattr(flownet2_args, 'fp16', fp16)
+ setattr(flownet2_args, 'rgb_max', 1.0)
+ if fp16:
+ print('FlowNet2 is running in fp16 mode.')
+ self.flowNet = flownet2_tools.module_to_dict(flownet2_models)[
+ 'FlowNet2'](flownet2_args).to('cuda')
+ if pretrained:
+ flownet2_path = get_checkpoint('flownet2.pth.tar',
+ '1hF8vS6YeHkx3j2pfCeQqqZGwA_PJq_Da')
+ checkpoint = torch.load(flownet2_path,
+ map_location=torch.device('cpu'))
+ self.flowNet.load_state_dict(checkpoint['state_dict'])
+ self.flowNet.eval()
+
+ def forward(self, input_A, input_B):
+ size = input_A.size()
+ assert(len(size) == 4 or len(size) == 5 or len(size) == 6)
+ if len(size) >= 5:
+ if len(size) == 5:
+ b, n, c, h, w = size
+ else:
+ b, t, n, c, h, w = size
+ input_A = input_A.contiguous().view(-1, c, h, w)
+ input_B = input_B.contiguous().view(-1, c, h, w)
+ flow, conf = self.compute_flow_and_conf(input_A, input_B)
+ if len(size) == 5:
+ return flow.view(b, n, 2, h, w), conf.view(b, n, 1, h, w)
+ else:
+ return flow.view(b, t, n, 2, h, w), conf.view(b, t, n, 1, h, w)
+ else:
+ return self.compute_flow_and_conf(input_A, input_B)
+
+ def compute_flow_and_conf(self, im1, im2):
+ assert(im1.size()[1] == 3)
+ assert(im1.size() == im2.size())
+ old_h, old_w = im1.size()[2], im1.size()[3]
+ new_h, new_w = old_h // 64 * 64, old_w // 64 * 64
+ if old_h != new_h:
+ im1 = F.interpolate(im1, size=(new_h, new_w), mode='bilinear',
+ align_corners=False)
+ im2 = F.interpolate(im2, size=(new_h, new_w), mode='bilinear',
+ align_corners=False)
+ data1 = torch.cat([im1.unsqueeze(2), im2.unsqueeze(2)], dim=2)
+ with torch.no_grad():
+ flow1 = self.flowNet(data1)
+ # img_diff = torch.sum(abs(im1 - resample(im2, flow1)),
+ # dim=1, keepdim=True)
+ # conf = torch.clamp(1 - img_diff, 0, 1)
+
+ conf = (self.norm(im1 - resample(im2, flow1)) < 0.02).float()
+
+ # data2 = torch.cat([im2.unsqueeze(2), im1.unsqueeze(2)], dim=2)
+ # with torch.no_grad():
+ # flow2 = self.flowNet(data2)
+ # warped_flow2 = resample(flow2, flow1)
+ # flow_sum = self.norm(flow1 + warped_flow2)
+ # disocc = flow_sum > (0.05 * (self.norm(flow1) +
+ # self.norm(warped_flow2)) + 0.5)
+ # conf = 1 - disocc.float()
+
+ if old_h != new_h:
+ flow1 = F.interpolate(flow1, size=(old_h, old_w), mode='bilinear',
+ align_corners=False) * old_h / new_h
+ conf = F.interpolate(conf, size=(old_h, old_w), mode='bilinear',
+ align_corners=False)
+ return flow1, conf
+
+ def norm(self, t):
+ return torch.sum(t * t, dim=1, keepdim=True)
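+
+# Minimal usage sketch (illustrative only; tensor names are placeholders):
+#   net = FlowNet(pretrained=True)
+#   img_a, img_b: RGB tensors of shape (B, 3, H, W) in [0, 1] on a CUDA device
+#   flow, conf = net(img_a, img_b)  # flow: (B, 2, H, W), conf: (B, 1, H, W)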
diff --git a/imaginaire/third_party/flow_net/flownet2/models.py b/imaginaire/third_party/flow_net/flownet2/models.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0066464a01942c85909a7e5ddbc97e39f244623
--- /dev/null
+++ b/imaginaire/third_party/flow_net/flownet2/models.py
@@ -0,0 +1,474 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import torch
+from torch.nn import init
+import torch.nn as nn
+import resample2d
+import channelnorm
+import numpy as np
+from imaginaire.third_party.flow_net.flownet2.networks import flownet_c
+from imaginaire.third_party.flow_net.flownet2.networks import flownet_s
+from imaginaire.third_party.flow_net.flownet2.networks import flownet_sd
+from imaginaire.third_party.flow_net.flownet2.networks import flownet_fusion
+from imaginaire.third_party.flow_net.flownet2.networks.submodules import \
+ tofp16, tofp32
+# Parameter count = 162,518,834
+
+
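+# Full FlowNet2 stack: a FlowNetC -> FlowNetS1 -> FlowNetS2 branch and a
+# parallel FlowNetSD branch, merged by FlowNetFusion into the final flow.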
+class FlowNet2(nn.Module):
+ def __init__(self, args, use_batch_norm=False, div_flow=20.):
+ super(FlowNet2, self).__init__()
+ self.batch_norm = use_batch_norm
+ self.div_flow = div_flow
+ self.rgb_max = args.rgb_max
+ self.args = args
+ self.channelnorm = channelnorm.ChannelNorm()
+ # First Block (FlowNetC)
+ self.flownetc = flownet_c.FlowNetC(
+ args, use_batch_norm=self.batch_norm)
+ self.upsample1 = nn.Upsample(scale_factor=4, mode='bilinear',
+ align_corners=False)
+ self.args = args
+ # if args.fp16:
+ # self.resample1 = nn.Sequential(
+ # tofp32(), resample2d.Resample2d(), tofp16())
+ # else:
+ self.resample1 = resample2d.Resample2d()
+ # Block (FlowNetS1)
+ self.flownets_1 = flownet_s.FlowNetS(
+ args, use_batch_norm=self.batch_norm)
+ self.upsample2 = nn.Upsample(scale_factor=4, mode='bilinear',
+ align_corners=False)
+ # if args.fp16:
+ # self.resample2 = nn.Sequential(
+ # tofp32(), resample2d.Resample2d(), tofp16())
+ # else:
+ self.resample2 = resample2d.Resample2d()
+ # Block (FlowNetS2)
+ self.flownets_2 = flownet_s.FlowNetS(
+ args, use_batch_norm=self.batch_norm)
+ # Block (FlowNetSD)
+ self.flownets_d = flownet_sd.FlowNetSD(
+ args, use_batch_norm=self.batch_norm)
+ self.upsample3 = nn.Upsample(scale_factor=4, mode='nearest')
+ self.upsample4 = nn.Upsample(scale_factor=4, mode='nearest')
+ # if args.fp16:
+ # self.resample3 = nn.Sequential(
+ # tofp32(), resample2d.Resample2d(), tofp16())
+ # else:
+ self.resample3 = resample2d.Resample2d()
+ # if args.fp16:
+ # self.resample4 = nn.Sequential(
+ # tofp32(), resample2d.Resample2d(), tofp16())
+ # else:
+ self.resample4 = resample2d.Resample2d()
+ # Block (FLowNetFusion)
+ self.flownetfusion = flownet_fusion.FlowNetFusion(
+ args, use_batch_norm=self.batch_norm)
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ if m.bias is not None:
+ init.uniform_(m.bias)
+ init.xavier_uniform_(m.weight)
+ if isinstance(m, nn.ConvTranspose2d):
+ if m.bias is not None:
+ init.uniform_(m.bias)
+ init.xavier_uniform_(m.weight)
+
+ def init_deconv_bilinear(self, weight):
+ f_shape = weight.size()
+ height, width = f_shape[-2], f_shape[-1]
+ f = np.ceil(width / 2.0)
+ c = (2 * f - 1 - f % 2) / (2.0 * f)
+ bilinear = np.zeros([height, width])
+ for x in range(width):
+ for y in range(height):
+ value = (1 - abs(x / f - c)) * (1 - abs(y / f - c))
+ bilinear[x, y] = value
+ min_dim = min(f_shape[0], f_shape[1])
+ weight.data.fill_(0.)
+ for i in range(min_dim):
+ weight.data[i, i, :, :] = torch.from_numpy(bilinear)
+ return
+
+ def forward(self, inputs):
+ rgb_mean = inputs.contiguous().view(inputs.size()[:2] + (-1,)).mean(
+ dim=-1).view(inputs.size()[:2] + (1, 1, 1,))
+ x = (inputs - rgb_mean) / self.rgb_max
+ x1 = x[:, :, 0, :, :]
+ x2 = x[:, :, 1, :, :]
+ x = torch.cat((x1, x2), dim=1)
+ # flownetc
+ flownetc_flow2 = self.flownetc(x)[0]
+ flownetc_flow = self.upsample1(flownetc_flow2 * self.div_flow)
+ # warp img1 to img0;
+ # magnitude of diff between img0 and warped_img1,
+ if self.args.fp16:
+ resampled_img1 = self.resample1(tofp32()(x[:, 3:, :, :]),
+ flownetc_flow)
+ resampled_img1 = tofp16()(resampled_img1)
+ else:
+ resampled_img1 = self.resample1(x[:, 3:, :, :], flownetc_flow)
+ diff_img0 = x[:, :3, :, :] - resampled_img1
+ norm_diff_img0 = self.channelnorm(diff_img0)
+ # concat img0, img1, img1->img0, flow, diff-mag ;
+ concat1 = torch.cat(
+ (x, resampled_img1, flownetc_flow / self.div_flow, norm_diff_img0),
+ dim=1)
+ # flownets1
+ flownets1_flow2 = self.flownets_1(concat1)[0]
+ flownets1_flow = self.upsample2(flownets1_flow2 * self.div_flow)
+ # warp img1 to img0 using flownets1;
+ # magnitude of diff between img0 and warped_img1
+ if self.args.fp16:
+ resampled_img1 = self.resample2(tofp32()(x[:, 3:, :, :]),
+ flownets1_flow)
+ resampled_img1 = tofp16()(resampled_img1)
+ else:
+ resampled_img1 = self.resample2(x[:, 3:, :, :], flownets1_flow)
+ diff_img0 = x[:, :3, :, :] - resampled_img1
+ norm_diff_img0 = self.channelnorm(diff_img0)
+ # concat img0, img1, img1->img0, flow, diff-mag
+ concat2 = torch.cat(
+ (x,
+ resampled_img1,
+ flownets1_flow /
+ self.div_flow,
+ norm_diff_img0),
+ dim=1)
+ # flownets2
+ flownets2_flow2 = self.flownets_2(concat2)[0]
+ flownets2_flow = self.upsample4(flownets2_flow2 * self.div_flow)
+ norm_flownets2_flow = self.channelnorm(flownets2_flow)
+ if self.args.fp16:
+ diff_flownets2_flow = self.resample4(tofp32()(x[:, 3:, :, :]),
+ flownets2_flow)
+ diff_flownets2_flow = tofp16()(diff_flownets2_flow)
+ else:
+ diff_flownets2_flow = self.resample4(x[:, 3:, :, :], flownets2_flow)
+ diff_flownets2_img1 = self.channelnorm(
+ (x[:, :3, :, :] - diff_flownets2_flow))
+ # flownetsd
+ flownetsd_flow2 = self.flownets_d(x)[0]
+ flownetsd_flow = self.upsample3(flownetsd_flow2 / self.div_flow)
+ norm_flownetsd_flow = self.channelnorm(flownetsd_flow)
+ if self.args.fp16:
+ diff_flownetsd_flow = self.resample3(tofp32()(x[:, 3:, :, :]),
+ flownetsd_flow)
+ diff_flownetsd_flow = tofp16()(diff_flownetsd_flow)
+ else:
+ diff_flownetsd_flow = self.resample3(x[:, 3:, :, :], flownetsd_flow)
+ diff_flownetsd_img1 = self.channelnorm(
+ (x[:, :3, :, :] - diff_flownetsd_flow))
+ # concat img1 flownetsd, flownets2, norm_flownetsd,
+ # norm_flownets2, diff_flownetsd_img1, diff_flownets2_img1
+ concat3 = torch.cat((x[:, :3, :, :], flownetsd_flow, flownets2_flow,
+ norm_flownetsd_flow, norm_flownets2_flow,
+ diff_flownetsd_img1, diff_flownets2_img1), dim=1)
+ flownetfusion_flow = self.flownetfusion(concat3)
+ return flownetfusion_flow
+
+
+class FlowNet2C(flownet_c.FlowNetC):
+ def __init__(self, args, use_batch_norm=False, div_flow=20):
+ super(
+ FlowNet2C,
+ self).__init__(
+ args,
+ use_batch_norm=use_batch_norm,
+ div_flow=20)
+ self.rgb_max = args.rgb_max
+
+ def forward(self, inputs):
+ rgb_mean = inputs.contiguous().view(inputs.size()[:2] + (-1,)).mean(
+ dim=-1).view(inputs.size()[:2] + (1, 1, 1,))
+ x = (inputs - rgb_mean) / self.rgb_max
+ x1 = x[:, :, 0, :, :]
+ x2 = x[:, :, 1, :, :]
+ # FlownetC top input stream
+ out_conv1a = self.conv1(x1)
+ out_conv2a = self.conv2(out_conv1a)
+ out_conv3a = self.conv3(out_conv2a)
+ # FlownetC bottom input stream
+ out_conv1b = self.conv1(x2)
+ out_conv2b = self.conv2(out_conv1b)
+ out_conv3b = self.conv3(out_conv2b)
+ # Merge streams
+ out_corr = self.corr(out_conv3a, out_conv3b) # False
+ out_corr = self.corr_activation(out_corr)
+ # Redirect top input stream and concatenate
+ out_conv_redir = self.conv_redir(out_conv3a)
+ in_conv3_1 = torch.cat((out_conv_redir, out_corr), 1)
+ # Merged conv layers
+ out_conv3_1 = self.conv3_1(in_conv3_1)
+ out_conv4 = self.conv4_1(self.conv4(out_conv3_1))
+ out_conv5 = self.conv5_1(self.conv5(out_conv4))
+ out_conv6 = self.conv6_1(self.conv6(out_conv5))
+ flow6 = self.predict_flow6(out_conv6)
+ flow6_up = self.upsampled_flow6_to_5(flow6)
+ out_deconv5 = self.deconv5(out_conv6)
+ concat5 = torch.cat((out_conv5, out_deconv5, flow6_up), 1)
+ flow5 = self.predict_flow5(concat5)
+ flow5_up = self.upsampled_flow5_to_4(flow5)
+ out_deconv4 = self.deconv4(concat5)
+ concat4 = torch.cat((out_conv4, out_deconv4, flow5_up), 1)
+ flow4 = self.predict_flow4(concat4)
+ flow4_up = self.upsampled_flow4_to_3(flow4)
+ out_deconv3 = self.deconv3(concat4)
+ concat3 = torch.cat((out_conv3_1, out_deconv3, flow4_up), 1)
+ flow3 = self.predict_flow3(concat3)
+ flow3_up = self.upsampled_flow3_to_2(flow3)
+ out_deconv2 = self.deconv2(concat3)
+ concat2 = torch.cat((out_conv2a, out_deconv2, flow3_up), 1)
+ flow2 = self.predict_flow2(concat2)
+ if self.training:
+ return flow2, flow3, flow4, flow5, flow6
+ else:
+ return self.upsample1(flow2 * self.div_flow)
+
+
+class FlowNet2S(flownet_s.FlowNetS):
+ def __init__(self, args, use_batch_norm=False, div_flow=20):
+ super(FlowNet2S, self).__init__(args, input_channels=6,
+ use_batch_norm=use_batch_norm)
+ self.rgb_max = args.rgb_max
+ self.div_flow = div_flow
+
+ def forward(self, inputs):
+ rgb_mean = inputs.contiguous().view(inputs.size()[:2] + (-1,)).mean(
+ dim=-1).view(inputs.size()[:2] + (1, 1, 1,))
+ x = (inputs - rgb_mean) / self.rgb_max
+ x = torch.cat((x[:, :, 0, :, :], x[:, :, 1, :, :]), dim=1)
+ out_conv1 = self.conv1(x)
+ out_conv2 = self.conv2(out_conv1)
+ out_conv3 = self.conv3_1(self.conv3(out_conv2))
+ out_conv4 = self.conv4_1(self.conv4(out_conv3))
+ out_conv5 = self.conv5_1(self.conv5(out_conv4))
+ out_conv6 = self.conv6_1(self.conv6(out_conv5))
+ flow6 = self.predict_flow6(out_conv6)
+ flow6_up = self.upsampled_flow6_to_5(flow6)
+ out_deconv5 = self.deconv5(out_conv6)
+ concat5 = torch.cat((out_conv5, out_deconv5, flow6_up), 1)
+ flow5 = self.predict_flow5(concat5)
+ flow5_up = self.upsampled_flow5_to_4(flow5)
+ out_deconv4 = self.deconv4(concat5)
+ concat4 = torch.cat((out_conv4, out_deconv4, flow5_up), 1)
+ flow4 = self.predict_flow4(concat4)
+ flow4_up = self.upsampled_flow4_to_3(flow4)
+ out_deconv3 = self.deconv3(concat4)
+ concat3 = torch.cat((out_conv3, out_deconv3, flow4_up), 1)
+ flow3 = self.predict_flow3(concat3)
+ flow3_up = self.upsampled_flow3_to_2(flow3)
+ out_deconv2 = self.deconv2(concat3)
+ concat2 = torch.cat((out_conv2, out_deconv2, flow3_up), 1)
+ flow2 = self.predict_flow2(concat2)
+ if self.training:
+ return flow2, flow3, flow4, flow5, flow6
+ else:
+ return self.upsample1(flow2 * self.div_flow)
+
+
+class FlowNet2SD(flownet_sd.FlowNetSD):
+ def __init__(self, args, use_batch_norm=False, div_flow=20):
+ super(FlowNet2SD, self).__init__(args, use_batch_norm=use_batch_norm)
+ self.rgb_max = args.rgb_max
+ self.div_flow = div_flow
+
+ def forward(self, inputs):
+ rgb_mean = inputs.contiguous().view(inputs.size()[:2] + (-1,)).mean(
+ dim=-1).view(inputs.size()[:2] + (1, 1, 1,))
+ x = (inputs - rgb_mean) / self.rgb_max
+ x = torch.cat((x[:, :, 0, :, :], x[:, :, 1, :, :]), dim=1)
+ out_conv0 = self.conv0(x)
+ out_conv1 = self.conv1_1(self.conv1(out_conv0))
+ out_conv2 = self.conv2_1(self.conv2(out_conv1))
+ out_conv3 = self.conv3_1(self.conv3(out_conv2))
+ out_conv4 = self.conv4_1(self.conv4(out_conv3))
+ out_conv5 = self.conv5_1(self.conv5(out_conv4))
+ out_conv6 = self.conv6_1(self.conv6(out_conv5))
+ flow6 = self.predict_flow6(out_conv6)
+ flow6_up = self.upsampled_flow6_to_5(flow6)
+ out_deconv5 = self.deconv5(out_conv6)
+ concat5 = torch.cat((out_conv5, out_deconv5, flow6_up), 1)
+ out_interconv5 = self.inter_conv5(concat5)
+ flow5 = self.predict_flow5(out_interconv5)
+ flow5_up = self.upsampled_flow5_to_4(flow5)
+ out_deconv4 = self.deconv4(concat5)
+ concat4 = torch.cat((out_conv4, out_deconv4, flow5_up), 1)
+ out_interconv4 = self.inter_conv4(concat4)
+ flow4 = self.predict_flow4(out_interconv4)
+ flow4_up = self.upsampled_flow4_to_3(flow4)
+ out_deconv3 = self.deconv3(concat4)
+ concat3 = torch.cat((out_conv3, out_deconv3, flow4_up), 1)
+ out_interconv3 = self.inter_conv3(concat3)
+ flow3 = self.predict_flow3(out_interconv3)
+ flow3_up = self.upsampled_flow3_to_2(flow3)
+ out_deconv2 = self.deconv2(concat3)
+ concat2 = torch.cat((out_conv2, out_deconv2, flow3_up), 1)
+ out_interconv2 = self.inter_conv2(concat2)
+ flow2 = self.predict_flow2(out_interconv2)
+ if self.training:
+ return flow2, flow3, flow4, flow5, flow6
+ else:
+ return self.upsample1(flow2 * self.div_flow)
+
+
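+# FlowNet2CS: only the FlowNetC -> FlowNetS1 portion of the full stack.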
+class FlowNet2CS(nn.Module):
+ def __init__(self, args, use_batch_norm=False, div_flow=20.):
+ super(FlowNet2CS, self).__init__()
+ self.use_batch_norm = use_batch_norm
+ self.div_flow = div_flow
+ self.rgb_max = args.rgb_max
+ self.args = args
+ self.channelnorm = channelnorm.ChannelNorm()
+ # First Block (FlowNetC)
+ self.flownetc = flownet_c.FlowNetC(
+ args, use_batch_norm=self.use_batch_norm)
+ self.upsample1 = nn.Upsample(scale_factor=4, mode='bilinear',
+ align_corners=False)
+ self.args = args
+ # if args.fp16:
+ # self.resample1 = nn.Sequential(
+ # tofp32(), resample2d.Resample2d(), tofp16())
+ # else:
+ self.resample1 = resample2d.Resample2d()
+ # Block (FlowNetS1)
+ self.flownets_1 = flownet_s.FlowNetS(
+ args, use_batch_norm=self.use_batch_norm)
+ self.upsample2 = nn.Upsample(scale_factor=4, mode='bilinear',
+ align_corners=False)
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ if m.bias is not None:
+ init.uniform_(m.bias)
+ init.xavier_uniform_(m.weight)
+ if isinstance(m, nn.ConvTranspose2d):
+ if m.bias is not None:
+ init.uniform_(m.bias)
+ init.xavier_uniform_(m.weight)
+
+ def forward(self, inputs):
+ rgb_mean = inputs.contiguous().view(inputs.size()[:2] + (-1,)).mean(
+ dim=-1).view(inputs.size()[:2] + (1, 1, 1,))
+ x = (inputs - rgb_mean) / self.rgb_max
+ x1 = x[:, :, 0, :, :]
+ x2 = x[:, :, 1, :, :]
+ x = torch.cat((x1, x2), dim=1)
+ # flownetc
+ flownetc_flow2 = self.flownetc(x)[0]
+ flownetc_flow = self.upsample1(flownetc_flow2 * self.div_flow)
+ # warp img1 to img0;
+ # magnitude of diff between img0 and warped_img1,
+ if self.args.fp16:
+ resampled_img1 = self.resample1(tofp32()(x[:, 3:, :, :]),
+ flownetc_flow)
+ resampled_img1 = tofp16()(resampled_img1)
+ else:
+ resampled_img1 = self.resample1(x[:, 3:, :, :], flownetc_flow)
+ diff_img0 = x[:, :3, :, :] - resampled_img1
+ norm_diff_img0 = self.channelnorm(diff_img0)
+ # concat img0, img1, img1->img0, flow, diff-mag ;
+ concat1 = torch.cat(
+ (x, resampled_img1, flownetc_flow / self.div_flow, norm_diff_img0),
+ dim=1)
+ # flownets1
+ flownets1_flow2 = self.flownets_1(concat1)[0]
+ flownets1_flow = self.upsample2(flownets1_flow2 * self.div_flow)
+ return flownets1_flow
+
+
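+# FlowNet2CSS: FlowNetC -> FlowNetS1 -> FlowNetS2, without the SD/fusion branches.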
+class FlowNet2CSS(nn.Module):
+ def __init__(self, args, use_batch_norm=False, div_flow=20.):
+ super(FlowNet2CSS, self).__init__()
+ self.use_batch_norm = use_batch_norm
+ self.div_flow = div_flow
+ self.rgb_max = args.rgb_max
+ self.args = args
+ self.channelnorm = channelnorm.ChannelNorm()
+ # First Block (FlowNetC)
+ self.flownetc = flownet_c.FlowNetC(
+ args, use_batch_norm=self.use_batch_norm)
+ self.upsample1 = nn.Upsample(scale_factor=4, mode='bilinear',
+ align_corners=False)
+ self.args = args
+ # if args.fp16:
+ # self.resample1 = nn.Sequential(
+ # tofp32(), resample2d.Resample2d(), tofp16())
+ # else:
+ self.resample1 = resample2d.Resample2d()
+ # Block (FlowNetS1)
+ self.flownets_1 = flownet_s.FlowNetS(
+ args, use_batch_norm=self.use_batch_norm)
+ self.upsample2 = nn.Upsample(scale_factor=4, mode='bilinear',
+ align_corners=False)
+ # if args.fp16:
+ # self.resample2 = nn.Sequential(
+ # tofp32(), resample2d.Resample2d(), tofp16())
+ # else:
+ self.resample2 = resample2d.Resample2d()
+ # Block (FlowNetS2)
+ self.flownets_2 = flownet_s.FlowNetS(
+ args, use_batch_norm=self.use_batch_norm)
+ self.upsample3 = nn.Upsample(scale_factor=4, mode='nearest')
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ if m.bias is not None:
+ init.uniform_(m.bias)
+ init.xavier_uniform_(m.weight)
+ if isinstance(m, nn.ConvTranspose2d):
+ if m.bias is not None:
+ init.uniform_(m.bias)
+ init.xavier_uniform_(m.weight)
+
+ def forward(self, inputs):
+ rgb_mean = inputs.contiguous().view(inputs.size()[:2] + (-1,)).mean(
+ dim=-1).view(inputs.size()[:2] + (1, 1, 1,))
+ x = (inputs - rgb_mean) / self.rgb_max
+ x1 = x[:, :, 0, :, :]
+ x2 = x[:, :, 1, :, :]
+ x = torch.cat((x1, x2), dim=1)
+ # flownetc
+ flownetc_flow2 = self.flownetc(x)[0]
+ flownetc_flow = self.upsample1(flownetc_flow2 * self.div_flow)
+ # Warp img1 to img0;
+ # Magnitude of diff between img0 and warped_img1,
+ if self.args.fp16:
+ resampled_img1 = self.resample1(tofp32()(x[:, 3:, :, :]),
+ flownetc_flow)
+ resampled_img1 = tofp16()(resampled_img1)
+ else:
+ resampled_img1 = self.resample1(x[:, 3:, :, :], flownetc_flow)
+ diff_img0 = x[:, :3, :, :] - resampled_img1
+ norm_diff_img0 = self.channelnorm(diff_img0)
+ # concat img0, img1, img1->img0, flow, diff-mag ;
+ concat1 = torch.cat(
+ (x, resampled_img1, flownetc_flow / self.div_flow, norm_diff_img0),
+ dim=1)
+ # flownets1
+ flownets1_flow2 = self.flownets_1(concat1)[0]
+ flownets1_flow = self.upsample2(flownets1_flow2 * self.div_flow)
+ # Warp img1 to img0 using flownets1;
+ # magnitude of diff between img0 and warped_img1
+ if self.args.fp16:
+ resampled_img1 = self.resample2(tofp32()(x[:, 3:, :, :]),
+ flownets1_flow)
+ resampled_img1 = tofp16()(resampled_img1)
+ else:
+ resampled_img1 = self.resample2(x[:, 3:, :, :], flownets1_flow)
+ diff_img0 = x[:, :3, :, :] - resampled_img1
+ norm_diff_img0 = self.channelnorm(diff_img0)
+ # concat img0, img1, img1->img0, flow, diff-mag
+ concat2 = torch.cat(
+ (x,
+ resampled_img1,
+ flownets1_flow /
+ self.div_flow,
+ norm_diff_img0),
+ dim=1)
+ # flownets2
+ flownets2_flow2 = self.flownets_2(concat2)[0]
+ flownets2_flow = self.upsample3(flownets2_flow2 * self.div_flow)
+ return flownets2_flow
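+
+# Construction sketch (illustrative; the args container mirrors flow_net.py):
+#   args = types.SimpleNamespace(fp16=False, rgb_max=1.0)
+#   net = FlowNet2(args).cuda().eval()
+#   inputs: a (B, 3, 2, H, W) tensor stacking the two frames along dim=2,
+#   with H and W divisible by 64
+#   flow = net(inputs)  # (B, 2, H, W)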
diff --git a/imaginaire/third_party/flow_net/flownet2/networks/__init__.py b/imaginaire/third_party/flow_net/flownet2/networks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3bb7ad34e7ce1c37ea1653d73cde323cdb5569e4
--- /dev/null
+++ b/imaginaire/third_party/flow_net/flownet2/networks/__init__.py
@@ -0,0 +1,6 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+# The file is duplicated from https://github.com/NVIDIA/flownet2-pytorch
+# with some modifications.
diff --git a/imaginaire/third_party/flow_net/flownet2/networks/flownet_c.py b/imaginaire/third_party/flow_net/flownet2/networks/flownet_c.py
new file mode 100644
index 0000000000000000000000000000000000000000..e6b3719c26eda72d61429c3c52b49707cf48c558
--- /dev/null
+++ b/imaginaire/third_party/flow_net/flownet2/networks/flownet_c.py
@@ -0,0 +1,160 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+# The file is duplicated from https://github.com/NVIDIA/flownet2-pytorch
+# with some modifications.
+from torch.nn import init
+import correlation
+import torch
+import torch.nn as nn
+from .submodules import conv, predict_flow, deconv, tofp16, tofp32
+
+
+class FlowNetC(nn.Module):
+ def __init__(self, args, use_batch_norm=True, div_flow=20):
+ r"""FlowNet2 C module. Check out the FlowNet2 paper for more details
+ https://arxiv.org/abs/1612.01925
+
+ Args:
+ args (obj): Network initialization arguments
+ use_batch_norm (bool): Use batch norm or not. Default is true.
+ div_flow (int): Flow devision factor. Default is 20.
+ """
+ super(FlowNetC, self).__init__()
+
+ self.use_batch_norm = use_batch_norm
+ self.div_flow = div_flow
+
+ self.conv1 = conv(self.use_batch_norm, 3, 64, kernel_size=7, stride=2)
+ self.conv2 = conv(self.use_batch_norm, 64, 128, kernel_size=5, stride=2)
+ self.conv3 = conv(self.use_batch_norm, 128, 256, kernel_size=5,
+ stride=2)
+ self.conv_redir = conv(self.use_batch_norm, 256, 32,
+ kernel_size=1, stride=1)
+ self.args = args
+ # if args.fp16:
+ # self.corr = nn.Sequential(
+ # tofp32(),
+ # correlation.Correlation(pad_size=20, kernel_size=1,
+ # max_displacement=20, stride1=1,
+ # stride2=2, corr_multiply=1),
+ # tofp16())
+ # else:
+ self.corr = correlation.Correlation(pad_size=20, kernel_size=1,
+ max_displacement=20, stride1=1,
+ stride2=2, corr_multiply=1)
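+ # With max_displacement=20 and stride2=2 the correlation layer outputs
+ # (2 * 20 / 2 + 1)^2 = 441 channels; together with the 32-channel
+ # conv_redir output this gives the 473 input channels of conv3_1.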
+
+ self.corr_activation = nn.LeakyReLU(0.1, inplace=True)
+ self.conv3_1 = conv(self.use_batch_norm, 473, 256)
+ self.conv4 = conv(self.use_batch_norm, 256, 512, stride=2)
+ self.conv4_1 = conv(self.use_batch_norm, 512, 512)
+ self.conv5 = conv(self.use_batch_norm, 512, 512, stride=2)
+ self.conv5_1 = conv(self.use_batch_norm, 512, 512)
+ self.conv6 = conv(self.use_batch_norm, 512, 1024, stride=2)
+ self.conv6_1 = conv(self.use_batch_norm, 1024, 1024)
+
+ self.deconv5 = deconv(1024, 512)
+ self.deconv4 = deconv(1026, 256)
+ self.deconv3 = deconv(770, 128)
+ self.deconv2 = deconv(386, 64)
+
+ self.predict_flow6 = predict_flow(1024)
+ self.predict_flow5 = predict_flow(1026)
+ self.predict_flow4 = predict_flow(770)
+ self.predict_flow3 = predict_flow(386)
+ self.predict_flow2 = predict_flow(194)
+
+ self.upsampled_flow6_to_5 = nn.ConvTranspose2d(
+ 2, 2, 4, 2, 1, bias=True)
+ self.upsampled_flow5_to_4 = nn.ConvTranspose2d(
+ 2, 2, 4, 2, 1, bias=True)
+ self.upsampled_flow4_to_3 = nn.ConvTranspose2d(
+ 2, 2, 4, 2, 1, bias=True)
+ self.upsampled_flow3_to_2 = nn.ConvTranspose2d(
+ 2, 2, 4, 2, 1, bias=True)
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ if m.bias is not None:
+ init.uniform_(m.bias)
+ init.xavier_uniform_(m.weight)
+
+ if isinstance(m, nn.ConvTranspose2d):
+ if m.bias is not None:
+ init.uniform_(m.bias)
+ init.xavier_uniform_(m.weight)
+ # init_deconv_bilinear(m.weight)
+ self.upsample1 = nn.Upsample(scale_factor=4, mode='bilinear',
+ align_corners=False)
+
+ def forward(self, x):
+ r"""
+
+ Args:
+ x (tensor): Input tensors of concatenated images.
+ Returns:
+ flow2 (tensor): Output flow tensors.
+ """
+ x1 = x[:, 0:3, :, :]
+ x2 = x[:, 3::, :, :]
+
+ out_conv1a = self.conv1(x1)
+ out_conv2a = self.conv2(out_conv1a)
+ out_conv3a = self.conv3(out_conv2a)
+
+ # FlownetC bottom input stream
+ out_conv1b = self.conv1(x2)
+
+ out_conv2b = self.conv2(out_conv1b)
+ out_conv3b = self.conv3(out_conv2b)
+
+ # Merge streams
+ if self.args.fp16:
+ out_corr = self.corr(tofp32()(out_conv3a),
+ tofp32()(out_conv3b)) # False
+ out_corr = tofp16()(out_corr)
+ else:
+ out_corr = self.corr(out_conv3a, out_conv3b) # False
+ out_corr = self.corr_activation(out_corr)
+
+ # Redirect top input stream and concatenate
+ out_conv_redir = self.conv_redir(out_conv3a)
+
+ in_conv3_1 = torch.cat((out_conv_redir, out_corr), 1)
+
+ # Merged conv layers
+ out_conv3_1 = self.conv3_1(in_conv3_1)
+
+ out_conv4 = self.conv4_1(self.conv4(out_conv3_1))
+
+ out_conv5 = self.conv5_1(self.conv5(out_conv4))
+ out_conv6 = self.conv6_1(self.conv6(out_conv5))
+
+ flow6 = self.predict_flow6(out_conv6)
+ flow6_up = self.upsampled_flow6_to_5(flow6)
+ out_deconv5 = self.deconv5(out_conv6)
+
+ concat5 = torch.cat((out_conv5, out_deconv5, flow6_up), 1)
+
+ flow5 = self.predict_flow5(concat5)
+ flow5_up = self.upsampled_flow5_to_4(flow5)
+ out_deconv4 = self.deconv4(concat5)
+ concat4 = torch.cat((out_conv4, out_deconv4, flow5_up), 1)
+
+ flow4 = self.predict_flow4(concat4)
+ flow4_up = self.upsampled_flow4_to_3(flow4)
+ out_deconv3 = self.deconv3(concat4)
+ concat3 = torch.cat((out_conv3_1, out_deconv3, flow4_up), 1)
+
+ flow3 = self.predict_flow3(concat3)
+ flow3_up = self.upsampled_flow3_to_2(flow3)
+ out_deconv2 = self.deconv2(concat3)
+ concat2 = torch.cat((out_conv2a, out_deconv2, flow3_up), 1)
+
+ flow2 = self.predict_flow2(concat2)
+
+ if self.training:
+ return flow2, flow3, flow4, flow5, flow6
+ else:
+ return flow2,
diff --git a/imaginaire/third_party/flow_net/flownet2/networks/flownet_fusion.py b/imaginaire/third_party/flow_net/flownet2/networks/flownet_fusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..edddd2446d906bfc0b93df47b6f18a45ac42bc79
--- /dev/null
+++ b/imaginaire/third_party/flow_net/flownet2/networks/flownet_fusion.py
@@ -0,0 +1,82 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+# The file is duplicated from https://github.com/NVIDIA/flownet2-pytorch
+# with some modifications.
+from torch.nn import init
+import torch
+import torch.nn as nn
+from .submodules import conv, i_conv, predict_flow, deconv
+
+
+class FlowNetFusion(nn.Module):
+ r"""FlowNet2 Fusion module. Check out the FlowNet2 paper for more details
+ https://arxiv.org/abs/1612.01925
+
+ Args:
+ args (obj): Network initialization arguments
+ use_batch_norm (bool): Use batch norm or not. Default is true.
+ """
+ def __init__(self, args, use_batch_norm=True):
+ super(FlowNetFusion, self).__init__()
+
+ self.use_batch_norm = use_batch_norm
+ self.conv0 = conv(self.use_batch_norm, 11, 64)
+ self.conv1 = conv(self.use_batch_norm, 64, 64, stride=2)
+ self.conv1_1 = conv(self.use_batch_norm, 64, 128)
+ self.conv2 = conv(self.use_batch_norm, 128, 128, stride=2)
+ self.conv2_1 = conv(self.use_batch_norm, 128, 128)
+
+ self.deconv1 = deconv(128, 32)
+ self.deconv0 = deconv(162, 16)
+
+ self.inter_conv1 = i_conv(self.use_batch_norm, 162, 32)
+ self.inter_conv0 = i_conv(self.use_batch_norm, 82, 16)
+
+ self.predict_flow2 = predict_flow(128)
+ self.predict_flow1 = predict_flow(32)
+ self.predict_flow0 = predict_flow(16)
+
+ self.upsampled_flow2_to_1 = nn.ConvTranspose2d(2, 2, 4, 2, 1)
+ self.upsampled_flow1_to_0 = nn.ConvTranspose2d(2, 2, 4, 2, 1)
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ if m.bias is not None:
+ init.uniform_(m.bias)
+ init.xavier_uniform_(m.weight)
+
+ if isinstance(m, nn.ConvTranspose2d):
+ if m.bias is not None:
+ init.uniform_(m.bias)
+ init.xavier_uniform_(m.weight)
+ # init_deconv_bilinear(m.weight)
+
+ def forward(self, x):
+ r"""
+
+ Args:
+ x (tensor): Input tensors of concatenated images.
+ Returns:
+ flow2 (tensor): Output flow tensors.
+ """
+ out_conv0 = self.conv0(x)
+ out_conv1 = self.conv1_1(self.conv1(out_conv0))
+ out_conv2 = self.conv2_1(self.conv2(out_conv1))
+
+ flow2 = self.predict_flow2(out_conv2)
+ flow2_up = self.upsampled_flow2_to_1(flow2)
+ out_deconv1 = self.deconv1(out_conv2)
+
+ concat1 = torch.cat((out_conv1, out_deconv1, flow2_up), 1)
+ out_interconv1 = self.inter_conv1(concat1)
+ flow1 = self.predict_flow1(out_interconv1)
+ flow1_up = self.upsampled_flow1_to_0(flow1)
+ out_deconv0 = self.deconv0(concat1)
+
+ concat0 = torch.cat((out_conv0, out_deconv0, flow1_up), 1)
+ out_interconv0 = self.inter_conv0(concat0)
+ flow0 = self.predict_flow0(out_interconv0)
+
+ return flow0
diff --git a/imaginaire/third_party/flow_net/flownet2/networks/flownet_s.py b/imaginaire/third_party/flow_net/flownet2/networks/flownet_s.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba8a76c5a5d66354e07aad2522a782961a50c24c
--- /dev/null
+++ b/imaginaire/third_party/flow_net/flownet2/networks/flownet_s.py
@@ -0,0 +1,121 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+# The file is duplicated from https://github.com/NVIDIA/flownet2-pytorch
+# with some modifications.
+'''
+Portions of this code copyright 2017, Clement Pinard
+'''
+from torch.nn import init
+import torch
+import torch.nn as nn
+from .submodules import conv, predict_flow, deconv
+
+
+class FlowNetS(nn.Module):
+ r"""FlowNet2 S module. Check out the FlowNet2 paper for more details
+ https://arxiv.org/abs/1612.01925
+
+ Args:
+ args (obj): Network initialization arguments
+ input_channels (int): Number of input channels. Default is 12.
+ use_batch_norm (bool): Use batch norm or not. Default is true.
+ """
+ def __init__(self, args, input_channels=12, use_batch_norm=True):
+ super(FlowNetS, self).__init__()
+
+ self.use_batch_norm = use_batch_norm
+ self.conv1 = conv(
+ self.use_batch_norm,
+ input_channels,
+ 64,
+ kernel_size=7,
+ stride=2)
+ self.conv2 = conv(self.use_batch_norm, 64, 128, kernel_size=5, stride=2)
+ self.conv3 = conv(self.use_batch_norm, 128, 256, kernel_size=5,
+ stride=2)
+ self.conv3_1 = conv(self.use_batch_norm, 256, 256)
+ self.conv4 = conv(self.use_batch_norm, 256, 512, stride=2)
+ self.conv4_1 = conv(self.use_batch_norm, 512, 512)
+ self.conv5 = conv(self.use_batch_norm, 512, 512, stride=2)
+ self.conv5_1 = conv(self.use_batch_norm, 512, 512)
+ self.conv6 = conv(self.use_batch_norm, 512, 1024, stride=2)
+ self.conv6_1 = conv(self.use_batch_norm, 1024, 1024)
+
+ self.deconv5 = deconv(1024, 512)
+ self.deconv4 = deconv(1026, 256)
+ self.deconv3 = deconv(770, 128)
+ self.deconv2 = deconv(386, 64)
+
+ self.predict_flow6 = predict_flow(1024)
+ self.predict_flow5 = predict_flow(1026)
+ self.predict_flow4 = predict_flow(770)
+ self.predict_flow3 = predict_flow(386)
+ self.predict_flow2 = predict_flow(194)
+
+ self.upsampled_flow6_to_5 = nn.ConvTranspose2d(
+ 2, 2, 4, 2, 1, bias=False)
+ self.upsampled_flow5_to_4 = nn.ConvTranspose2d(
+ 2, 2, 4, 2, 1, bias=False)
+ self.upsampled_flow4_to_3 = nn.ConvTranspose2d(
+ 2, 2, 4, 2, 1, bias=False)
+ self.upsampled_flow3_to_2 = nn.ConvTranspose2d(
+ 2, 2, 4, 2, 1, bias=False)
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ if m.bias is not None:
+ init.uniform_(m.bias)
+ init.xavier_uniform_(m.weight)
+
+ if isinstance(m, nn.ConvTranspose2d):
+ if m.bias is not None:
+ init.uniform_(m.bias)
+ init.xavier_uniform_(m.weight)
+ # init_deconv_bilinear(m.weight)
+ self.upsample1 = nn.Upsample(scale_factor=4, mode='bilinear',
+ align_corners=False)
+
+ def forward(self, x):
+ r"""
+
+ Args:
+ x (tensor): Input tensors of concatenated images.
+ Returns:
+ flow2 (tensor): Output flow tensors.
+ """
+ out_conv1 = self.conv1(x)
+
+ out_conv2 = self.conv2(out_conv1)
+ out_conv3 = self.conv3_1(self.conv3(out_conv2))
+ out_conv4 = self.conv4_1(self.conv4(out_conv3))
+ out_conv5 = self.conv5_1(self.conv5(out_conv4))
+ out_conv6 = self.conv6_1(self.conv6(out_conv5))
+
+ flow6 = self.predict_flow6(out_conv6)
+ flow6_up = self.upsampled_flow6_to_5(flow6)
+ out_deconv5 = self.deconv5(out_conv6)
+
+ concat5 = torch.cat((out_conv5, out_deconv5, flow6_up), 1)
+ flow5 = self.predict_flow5(concat5)
+ flow5_up = self.upsampled_flow5_to_4(flow5)
+ out_deconv4 = self.deconv4(concat5)
+
+ concat4 = torch.cat((out_conv4, out_deconv4, flow5_up), 1)
+ flow4 = self.predict_flow4(concat4)
+ flow4_up = self.upsampled_flow4_to_3(flow4)
+ out_deconv3 = self.deconv3(concat4)
+
+ concat3 = torch.cat((out_conv3, out_deconv3, flow4_up), 1)
+ flow3 = self.predict_flow3(concat3)
+ flow3_up = self.upsampled_flow3_to_2(flow3)
+ out_deconv2 = self.deconv2(concat3)
+
+ concat2 = torch.cat((out_conv2, out_deconv2, flow3_up), 1)
+ flow2 = self.predict_flow2(concat2)
+
+ if self.training:
+ return flow2, flow3, flow4, flow5, flow6
+ else:
+ return flow2,
diff --git a/imaginaire/third_party/flow_net/flownet2/networks/flownet_sd.py b/imaginaire/third_party/flow_net/flownet2/networks/flownet_sd.py
new file mode 100644
index 0000000000000000000000000000000000000000..3f4340347252a9591d7540689abaae821d759060
--- /dev/null
+++ b/imaginaire/third_party/flow_net/flownet2/networks/flownet_sd.py
@@ -0,0 +1,121 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+# The file is duplicated from https://github.com/NVIDIA/flownet2-pytorch
+# with some modifications.
+import torch
+import torch.nn as nn
+from .submodules import conv, i_conv, predict_flow, deconv
+from torch.nn import init
+
+
+class FlowNetSD(nn.Module):
+ r"""FlowNet2 SD module. Check out the FlowNet2 paper for more details
+ https://arxiv.org/abs/1612.01925
+
+ Args:
+ args (obj): Network initialization arguments
+ use_batch_norm (bool): Use batch norm or not. Default is true.
+ """
+ def __init__(self, args, use_batch_norm=True):
+ super(FlowNetSD, self).__init__()
+
+ self.use_batch_norm = use_batch_norm
+ self.conv0 = conv(self.use_batch_norm, 6, 64)
+ self.conv1 = conv(self.use_batch_norm, 64, 64, stride=2)
+ self.conv1_1 = conv(self.use_batch_norm, 64, 128)
+ self.conv2 = conv(self.use_batch_norm, 128, 128, stride=2)
+ self.conv2_1 = conv(self.use_batch_norm, 128, 128)
+ self.conv3 = conv(self.use_batch_norm, 128, 256, stride=2)
+ self.conv3_1 = conv(self.use_batch_norm, 256, 256)
+ self.conv4 = conv(self.use_batch_norm, 256, 512, stride=2)
+ self.conv4_1 = conv(self.use_batch_norm, 512, 512)
+ self.conv5 = conv(self.use_batch_norm, 512, 512, stride=2)
+ self.conv5_1 = conv(self.use_batch_norm, 512, 512)
+ self.conv6 = conv(self.use_batch_norm, 512, 1024, stride=2)
+ self.conv6_1 = conv(self.use_batch_norm, 1024, 1024)
+
+ self.deconv5 = deconv(1024, 512)
+ self.deconv4 = deconv(1026, 256)
+ self.deconv3 = deconv(770, 128)
+ self.deconv2 = deconv(386, 64)
+
+ self.inter_conv5 = i_conv(self.use_batch_norm, 1026, 512)
+ self.inter_conv4 = i_conv(self.use_batch_norm, 770, 256)
+ self.inter_conv3 = i_conv(self.use_batch_norm, 386, 128)
+ self.inter_conv2 = i_conv(self.use_batch_norm, 194, 64)
+
+ self.predict_flow6 = predict_flow(1024)
+ self.predict_flow5 = predict_flow(512)
+ self.predict_flow4 = predict_flow(256)
+ self.predict_flow3 = predict_flow(128)
+ self.predict_flow2 = predict_flow(64)
+
+ self.upsampled_flow6_to_5 = nn.ConvTranspose2d(2, 2, 4, 2, 1)
+ self.upsampled_flow5_to_4 = nn.ConvTranspose2d(2, 2, 4, 2, 1)
+ self.upsampled_flow4_to_3 = nn.ConvTranspose2d(2, 2, 4, 2, 1)
+ self.upsampled_flow3_to_2 = nn.ConvTranspose2d(2, 2, 4, 2, 1)
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ if m.bias is not None:
+ init.uniform_(m.bias)
+ init.xavier_uniform_(m.weight)
+
+ if isinstance(m, nn.ConvTranspose2d):
+ if m.bias is not None:
+ init.uniform_(m.bias)
+ init.xavier_uniform_(m.weight)
+ # init_deconv_bilinear(m.weight)
+ self.upsample1 = nn.Upsample(scale_factor=4, mode='bilinear',
+ align_corners=False)
+
+ def forward(self, x):
+ r"""
+
+ Args:
+ x (tensor): Input tensors of concatenated images.
+ Returns:
+ flow2 (tensor): Output flow tensors.
+ """
+ out_conv0 = self.conv0(x)
+ out_conv1 = self.conv1_1(self.conv1(out_conv0))
+ out_conv2 = self.conv2_1(self.conv2(out_conv1))
+
+ out_conv3 = self.conv3_1(self.conv3(out_conv2))
+ out_conv4 = self.conv4_1(self.conv4(out_conv3))
+ out_conv5 = self.conv5_1(self.conv5(out_conv4))
+ out_conv6 = self.conv6_1(self.conv6(out_conv5))
+
+ flow6 = self.predict_flow6(out_conv6)
+ flow6_up = self.upsampled_flow6_to_5(flow6)
+ out_deconv5 = self.deconv5(out_conv6)
+
+ concat5 = torch.cat((out_conv5, out_deconv5, flow6_up), 1)
+ out_interconv5 = self.inter_conv5(concat5)
+ flow5 = self.predict_flow5(out_interconv5)
+
+ flow5_up = self.upsampled_flow5_to_4(flow5)
+ out_deconv4 = self.deconv4(concat5)
+
+ concat4 = torch.cat((out_conv4, out_deconv4, flow5_up), 1)
+ out_interconv4 = self.inter_conv4(concat4)
+ flow4 = self.predict_flow4(out_interconv4)
+ flow4_up = self.upsampled_flow4_to_3(flow4)
+ out_deconv3 = self.deconv3(concat4)
+
+ concat3 = torch.cat((out_conv3, out_deconv3, flow4_up), 1)
+ out_interconv3 = self.inter_conv3(concat3)
+ flow3 = self.predict_flow3(out_interconv3)
+ flow3_up = self.upsampled_flow3_to_2(flow3)
+ out_deconv2 = self.deconv2(concat3)
+
+ concat2 = torch.cat((out_conv2, out_deconv2, flow3_up), 1)
+ out_interconv2 = self.inter_conv2(concat2)
+ flow2 = self.predict_flow2(out_interconv2)
+
+ if self.training:
+ return flow2, flow3, flow4, flow5, flow6
+ else:
+ return flow2,
diff --git a/imaginaire/third_party/flow_net/flownet2/networks/submodules.py b/imaginaire/third_party/flow_net/flownet2/networks/submodules.py
new file mode 100644
index 0000000000000000000000000000000000000000..f4ab504401c1473bcc52ae4a1029afd74eed6d11
--- /dev/null
+++ b/imaginaire/third_party/flow_net/flownet2/networks/submodules.py
@@ -0,0 +1,113 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+# The file is duplicated from https://github.com/NVIDIA/flownet2-pytorch
+# with some modifications.
+import torch
+import torch.nn as nn
+import numpy as np
+
+
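+# Small layer factories shared by the FlowNet variants: conv() is Conv2d +
+# optional BatchNorm + LeakyReLU (the conv bias is dropped when BatchNorm is
+# used), i_conv() is the same without the activation, deconv() is a stride-2
+# transposed conv + LeakyReLU, and predict_flow() maps features to 2 flow channels.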
+def conv(use_batch_norm, in_planes, out_planes, kernel_size=3, stride=1):
+ if use_batch_norm:
+ return nn.Sequential(
+ nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size,
+ stride=stride, padding=(kernel_size - 1) // 2,
+ bias=False),
+ nn.BatchNorm2d(out_planes),
+ nn.LeakyReLU(0.1, inplace=True)
+ )
+ else:
+ return nn.Sequential(
+ nn.Conv2d(
+ in_planes,
+ out_planes,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=(
+ kernel_size - 1) // 2,
+ bias=True),
+ nn.LeakyReLU(
+ 0.1,
+ inplace=True))
+
+
+def i_conv(use_batch_norm, in_planes, out_planes, kernel_size=3, stride=1,
+ bias=True):
+ if use_batch_norm:
+ return nn.Sequential(
+ nn.Conv2d(
+ in_planes,
+ out_planes,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=(
+ kernel_size - 1) // 2,
+ bias=bias),
+ nn.BatchNorm2d(out_planes),
+ )
+ else:
+ return nn.Sequential(
+ nn.Conv2d(
+ in_planes,
+ out_planes,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=(
+ kernel_size -
+ 1) //
+ 2,
+ bias=bias),
+ )
+
+
+def predict_flow(in_planes):
+ return nn.Conv2d(in_planes, 2, kernel_size=3, stride=1, padding=1,
+ bias=True)
+
+
+def deconv(in_planes, out_planes):
+ return nn.Sequential(
+ nn.ConvTranspose2d(in_planes, out_planes, kernel_size=4, stride=2,
+ padding=1, bias=True),
+ nn.LeakyReLU(0.1, inplace=True)
+ )
+
+
+class tofp16(nn.Module):
+ def __init__(self):
+ super(tofp16, self).__init__()
+
+ def forward(self, input):
+ return input.half()
+
+
+class tofp32(nn.Module):
+ def __init__(self):
+ super(tofp32, self).__init__()
+
+ def forward(self, input):
+ return input.float()
+
+
+def init_deconv_bilinear(weight):
+ f_shape = weight.size()
+ heigh, width = f_shape[-2], f_shape[-1]
+ f = np.ceil(width / 2.0)
+ c = (2 * f - 1 - f % 2) / (2.0 * f)
+ bilinear = np.zeros([heigh, width])
+ for x in range(width):
+ for y in range(heigh):
+ value = (1 - abs(x / f - c)) * (1 - abs(y / f - c))
+ bilinear[x, y] = value
+ weight.data.fill_(0.)
+ for i in range(f_shape[0]):
+ for j in range(f_shape[1]):
+ weight.data[i, j, :, :] = torch.from_numpy(bilinear)
+
+
+def save_grad(grads, name):
+ def hook(grad):
+ grads[name] = grad
+ return hook
diff --git a/imaginaire/third_party/flow_net/flownet2/utils/__init__.py b/imaginaire/third_party/flow_net/flownet2/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..13acefe2181136b1629ec31f9d122fb46bf26780
--- /dev/null
+++ b/imaginaire/third_party/flow_net/flownet2/utils/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
diff --git a/imaginaire/third_party/flow_net/flownet2/utils/flow_utils.py b/imaginaire/third_party/flow_net/flownet2/utils/flow_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..7bffeba58a93de4379c8e9ed54af58b56baa13eb
--- /dev/null
+++ b/imaginaire/third_party/flow_net/flownet2/utils/flow_utils.py
@@ -0,0 +1,219 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import numpy as np
+import matplotlib.pyplot as plt
+import os.path
+
+TAG_CHAR = np.array([202021.25], np.float32)
+
+
+def readFlow(fn):
+ """ Read .flo file in Middlebury format"""
+ # Code adapted from:
+ # http://stackoverflow.com/questions/28013200/
+ # reading-middlebury-flow-files-with-python-bytes-array-numpy
+
+ # WARNING: this will work on little-endian architectures
+ # (eg Intel x86) only!
+ # print 'fn = %s'%(fn)
+ with open(fn, 'rb') as f:
+ magic = np.fromfile(f, np.float32, count=1)
+ if 202021.25 != magic:
+ print('Magic number incorrect. Invalid .flo file')
+ return None
+ else:
+ w = np.fromfile(f, np.int32, count=1)
+ h = np.fromfile(f, np.int32, count=1)
+ # print 'Reading %d x %d flo file\n' % (w, h)
+ data = np.fromfile(f, np.float32, count=2 * int(w) * int(h))
+ # Reshape data into 3D array (columns, rows, bands)
+ # The reshape here is for visualization, the original code is
+ # (w,h,2)
+ return np.resize(data, (int(h), int(w), 2))
+
+
+def writeFlow(filename, uv, v=None):
+ """ Write optical flow to file.
+
+ If v is None, uv is assumed to contain both u and v channels,
+ stacked in deep.
+ Original code by Deqing Sun, adapted from Daniel Scharstein.
+ """
+ nBands = 2
+
+ if v is None:
+ assert(uv.ndim == 3)
+ assert(uv.shape[2] == 2)
+ u = uv[:, :, 0]
+ v = uv[:, :, 1]
+ else:
+ u = uv
+
+ assert(u.shape == v.shape)
+ height, width = u.shape
+ f = open(filename, 'wb')
+ # write the header
+ f.write(TAG_CHAR)
+ np.array(width).astype(np.int32).tofile(f)
+ np.array(height).astype(np.int32).tofile(f)
+ # arrange into matrix form
+ tmp = np.zeros((height, width * nBands))
+ tmp[:, np.arange(width) * 2] = u
+ tmp[:, np.arange(width) * 2 + 1] = v
+ tmp.astype(np.float32).tofile(f)
+ f.close()
+
+
+# ref: https://github.com/sampepose/flownet2-tf/
+# blob/18f87081db44939414fc4a48834f9e0da3e69f4c/src/flowlib.py#L240
+def visulize_flow_file(flow_filename, save_dir=None):
+ flow_data = readFlow(flow_filename)
+ img = flow2img(flow_data)
+ # plt.imshow(img)
+ # plt.show()
+ if save_dir:
+ idx = flow_filename.rfind("/") + 1
+ plt.imsave(os.path.join(save_dir, "%s-vis.png" %
+ flow_filename[idx:-4]), img)
+
+
+def flow2img(flow_data):
+ """
+ convert optical flow into color image
+ :param flow_data:
+ :return: color image
+ """
+ # print(flow_data.shape)
+ # print(type(flow_data))
+ u = flow_data[:, :, 0]
+ v = flow_data[:, :, 1]
+
+ UNKNOW_FLOW_THRESHOLD = 1e7
+ pr1 = abs(u) > UNKNOW_FLOW_THRESHOLD
+ pr2 = abs(v) > UNKNOW_FLOW_THRESHOLD
+ idx_unknown = (pr1 | pr2)
+ u[idx_unknown] = v[idx_unknown] = 0
+
+ # get max value in each direction
+ maxu = -999.
+ maxv = -999.
+ minu = 999.
+ minv = 999.
+ maxu = max(maxu, np.max(u))
+ maxv = max(maxv, np.max(v))
+ minu = min(minu, np.min(u))
+ minv = min(minv, np.min(v))
+
+ rad = np.sqrt(u ** 2 + v ** 2)
+ maxrad = max(-1, np.max(rad))
+ u = u / (maxrad + np.finfo(float).eps)
+ v = v / (maxrad + np.finfo(float).eps)
+
+ img = compute_color(u, v)
+
+ idx = np.repeat(idx_unknown[:, :, np.newaxis], 3, axis=2)
+ img[idx] = 0
+
+ return np.uint8(img)
+
+
+def compute_color(u, v):
+ """
+ compute optical flow color map
+ :param u: horizontal optical flow
+ :param v: vertical optical flow
+ :return:
+ """
+
+ height, width = u.shape
+ img = np.zeros((height, width, 3))
+
+ NAN_idx = np.isnan(u) | np.isnan(v)
+ u[NAN_idx] = v[NAN_idx] = 0
+
+ colorwheel = make_color_wheel()
+ ncols = np.size(colorwheel, 0)
+
+ rad = np.sqrt(u ** 2 + v ** 2)
+
+ a = np.arctan2(-v, -u) / np.pi
+
+ fk = (a + 1) / 2 * (ncols - 1) + 1
+
+ k0 = np.floor(fk).astype(int)
+
+ k1 = k0 + 1
+ k1[k1 == ncols + 1] = 1
+ f = fk - k0
+
+ for i in range(0, np.size(colorwheel, 1)):
+ tmp = colorwheel[:, i]
+ col0 = tmp[k0 - 1] / 255
+ col1 = tmp[k1 - 1] / 255
+ col = (1 - f) * col0 + f * col1
+
+ idx = rad <= 1
+ col[idx] = 1 - rad[idx] * (1 - col[idx])
+ notidx = np.logical_not(idx)
+
+ col[notidx] *= 0.75
+ img[:, :, i] = np.uint8(np.floor(255 * col * (1 - NAN_idx)))
+
+ return img
+
+
+def make_color_wheel():
+ """
+ Generate color wheel according Middlebury color code
+ :return: Color wheel
+ """
+ RY = 15
+ YG = 6
+ GC = 4
+ CB = 11
+ BM = 13
+ MR = 6
+
+ ncols = RY + YG + GC + CB + BM + MR
+
+ colorwheel = np.zeros([ncols, 3])
+
+ col = 0
+
+ # RY
+ colorwheel[0:RY, 0] = 255
+ colorwheel[0:RY, 1] = np.transpose(np.floor(255 * np.arange(0, RY) / RY))
+ col += RY
+
+ # YG
+ colorwheel[col:col + YG, 0] = 255 - \
+ np.transpose(np.floor(255 * np.arange(0, YG) / YG))
+ colorwheel[col:col + YG, 1] = 255
+ col += YG
+
+ # GC
+ colorwheel[col:col + GC, 1] = 255
+ colorwheel[col:col + GC,
+ 2] = np.transpose(np.floor(255 * np.arange(0, GC) / GC))
+ col += GC
+
+ # CB
+ colorwheel[col:col + CB, 1] = 255 - \
+ np.transpose(np.floor(255 * np.arange(0, CB) / CB))
+ colorwheel[col:col + CB, 2] = 255
+ col += CB
+
+ # BM
+ colorwheel[col:col + BM, 2] = 255
+ colorwheel[col:col + BM,
+ 0] = np.transpose(np.floor(255 * np.arange(0, BM) / BM))
+ col += BM
+
+ # MR
+ colorwheel[col:col + MR, 2] = 255 - \
+ np.transpose(np.floor(255 * np.arange(0, MR) / MR))
+ colorwheel[col:col + MR, 0] = 255
+
+ return colorwheel
diff --git a/imaginaire/third_party/flow_net/flownet2/utils/frame_utils.py b/imaginaire/third_party/flow_net/flownet2/utils/frame_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..ac1e3f83d6179afcb266e5923af5dd54fc3dd3fc
--- /dev/null
+++ b/imaginaire/third_party/flow_net/flownet2/utils/frame_utils.py
@@ -0,0 +1,23 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import numpy as np
+from os.path import splitext
+from imageio import imread  # scipy.misc.imread was removed in SciPy >= 1.2
+from . import flow_utils
+
+
+def read_gen(file_name):
+ ext = splitext(file_name)[-1]
+ if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg':
+ im = imread(file_name)
+ if im.shape[2] > 3:
+ return im[:, :, :3]
+ else:
+ return im
+ elif ext == '.bin' or ext == '.raw':
+ return np.load(file_name)
+ elif ext == '.flo':
+ return flow_utils.readFlow(file_name).astype(np.float32)
+ return []
diff --git a/imaginaire/third_party/flow_net/flownet2/utils/param_utils.py b/imaginaire/third_party/flow_net/flownet2/utils/param_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..b084c9c35b957888acea86987ab25073e43feac0
--- /dev/null
+++ b/imaginaire/third_party/flow_net/flownet2/utils/param_utils.py
@@ -0,0 +1,275 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import torch
+import torch.nn as nn
+import numpy as np
+
+
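+# The parse_* helpers below copy weights converted from the original Caffe
+# FlowNet2 release into the corresponding PyTorch modules; the channel flip on
+# the first convolution accounts for Caffe's BGR input ordering.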
+def parse_flownetc(modules, weights, biases):
+ keys = [
+ 'conv1',
+ 'conv2',
+ 'conv3',
+ 'conv_redir',
+ 'conv3_1',
+ 'conv4',
+ 'conv4_1',
+ 'conv5',
+ 'conv5_1',
+ 'conv6',
+ 'conv6_1',
+
+ 'deconv5',
+ 'deconv4',
+ 'deconv3',
+ 'deconv2',
+
+ 'Convolution1',
+ 'Convolution2',
+ 'Convolution3',
+ 'Convolution4',
+ 'Convolution5',
+
+ 'upsample_flow6to5',
+ 'upsample_flow5to4',
+ 'upsample_flow4to3',
+ 'upsample_flow3to2',
+
+ ]
+ i = 0
+ for m in modules:
+ if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
+ weight = weights[keys[i]].copy()
+ bias = biases[keys[i]].copy()
+ if keys[i] == 'conv1':
+ m.weight.data[:, :, :, :] = torch.from_numpy(
+ np.flip(weight, axis=1).copy())
+ m.bias.data[:] = torch.from_numpy(bias)
+ else:
+ m.weight.data[:, :, :, :] = torch.from_numpy(weight)
+ m.bias.data[:] = torch.from_numpy(bias)
+
+ i = i + 1
+ return
+
+
+def parse_flownets(modules, weights, biases, param_prefix='net2_'):
+ keys = [
+ 'conv1',
+ 'conv2',
+ 'conv3',
+ 'conv3_1',
+ 'conv4',
+ 'conv4_1',
+ 'conv5',
+ 'conv5_1',
+ 'conv6',
+ 'conv6_1',
+
+ 'deconv5',
+ 'deconv4',
+ 'deconv3',
+ 'deconv2',
+
+ 'predict_conv6',
+ 'predict_conv5',
+ 'predict_conv4',
+ 'predict_conv3',
+ 'predict_conv2',
+
+ 'upsample_flow6to5',
+ 'upsample_flow5to4',
+ 'upsample_flow4to3',
+ 'upsample_flow3to2',
+ ]
+ for i, k in enumerate(keys):
+ if 'upsample' in k:
+ keys[i] = param_prefix + param_prefix + k
+ else:
+ keys[i] = param_prefix + k
+ i = 0
+ for m in modules:
+ if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
+ weight = weights[keys[i]].copy()
+ bias = biases[keys[i]].copy()
+ if keys[i] == param_prefix + 'conv1':
+ m.weight.data[:, 0:3, :, :] = torch.from_numpy(
+ np.flip(weight[:, 0:3, :, :], axis=1).copy())
+ m.weight.data[:, 3:6, :, :] = torch.from_numpy(
+ np.flip(weight[:, 3:6, :, :], axis=1).copy())
+ m.weight.data[:, 6:9, :, :] = torch.from_numpy(
+ np.flip(weight[:, 6:9, :, :], axis=1).copy())
+ m.weight.data[:, 9::, :, :] = torch.from_numpy(
+ weight[:, 9:, :, :].copy())
+ if m.bias is not None:
+ m.bias.data[:] = torch.from_numpy(bias)
+ else:
+ m.weight.data[:, :, :, :] = torch.from_numpy(weight)
+ if m.bias is not None:
+ m.bias.data[:] = torch.from_numpy(bias)
+ i = i + 1
+ return
+
+
+def parse_flownetsonly(modules, weights, biases, param_prefix=''):
+ keys = [
+ 'conv1',
+ 'conv2',
+ 'conv3',
+ 'conv3_1',
+ 'conv4',
+ 'conv4_1',
+ 'conv5',
+ 'conv5_1',
+ 'conv6',
+ 'conv6_1',
+
+ 'deconv5',
+ 'deconv4',
+ 'deconv3',
+ 'deconv2',
+
+ 'Convolution1',
+ 'Convolution2',
+ 'Convolution3',
+ 'Convolution4',
+ 'Convolution5',
+
+ 'upsample_flow6to5',
+ 'upsample_flow5to4',
+ 'upsample_flow4to3',
+ 'upsample_flow3to2',
+ ]
+ for i, k in enumerate(keys):
+ if 'upsample' in k:
+ keys[i] = param_prefix + param_prefix + k
+ else:
+ keys[i] = param_prefix + k
+ i = 0
+ for m in modules:
+ if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
+ weight = weights[keys[i]].copy()
+ bias = biases[keys[i]].copy()
+ if keys[i] == param_prefix + 'conv1':
+ # print ("%s :"%(keys[i]), m.weight.size(), m.bias.size(),
+ # tf_w[keys[i]].shape[::-1])
+ m.weight.data[:, 0:3, :, :] = torch.from_numpy(
+ np.flip(weight[:, 0:3, :, :], axis=1).copy())
+ m.weight.data[:, 3:6, :, :] = torch.from_numpy(
+ np.flip(weight[:, 3:6, :, :], axis=1).copy())
+ if m.bias is not None:
+ m.bias.data[:] = torch.from_numpy(bias)
+ else:
+ m.weight.data[:, :, :, :] = torch.from_numpy(weight)
+ if m.bias is not None:
+ m.bias.data[:] = torch.from_numpy(bias)
+ i = i + 1
+ return
+
+
+def parse_flownetsd(modules, weights, biases, param_prefix='netsd_'):
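+    # FlowNetSD: same layout, but the network starts at conv0, whose
+    # image-pair input is likewise flipped in RGB triplets.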
+ keys = [
+ 'conv0',
+ 'conv1',
+ 'conv1_1',
+ 'conv2',
+ 'conv2_1',
+ 'conv3',
+ 'conv3_1',
+ 'conv4',
+ 'conv4_1',
+ 'conv5',
+ 'conv5_1',
+ 'conv6',
+ 'conv6_1',
+
+ 'deconv5',
+ 'deconv4',
+ 'deconv3',
+ 'deconv2',
+
+ 'interconv5',
+ 'interconv4',
+ 'interconv3',
+ 'interconv2',
+
+ 'Convolution1',
+ 'Convolution2',
+ 'Convolution3',
+ 'Convolution4',
+ 'Convolution5',
+
+ 'upsample_flow6to5',
+ 'upsample_flow5to4',
+ 'upsample_flow4to3',
+ 'upsample_flow3to2',
+ ]
+ for i, k in enumerate(keys):
+ keys[i] = param_prefix + k
+
+ i = 0
+ for m in modules:
+ if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
+ weight = weights[keys[i]].copy()
+ bias = biases[keys[i]].copy()
+ if keys[i] == param_prefix + 'conv0':
+ m.weight.data[:, 0:3, :, :] = torch.from_numpy(
+ np.flip(weight[:, 0:3, :, :], axis=1).copy())
+ m.weight.data[:, 3:6, :, :] = torch.from_numpy(
+ np.flip(weight[:, 3:6, :, :], axis=1).copy())
+ if m.bias is not None:
+ m.bias.data[:] = torch.from_numpy(bias)
+ else:
+ m.weight.data[:, :, :, :] = torch.from_numpy(weight)
+ if m.bias is not None:
+ m.bias.data[:] = torch.from_numpy(bias)
+ i = i + 1
+
+ return
+
+
+def parse_flownetfusion(modules, weights, biases, param_prefix='fuse_'):
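+    # Fusion network: only the first three input channels of conv0 are
+    # flipped (the RGB image); the remaining feature channels are copied
+    # unchanged.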
+ keys = [
+ 'conv0',
+ 'conv1',
+ 'conv1_1',
+ 'conv2',
+ 'conv2_1',
+
+ 'deconv1',
+ 'deconv0',
+
+ 'interconv1',
+ 'interconv0',
+
+ '_Convolution5',
+ '_Convolution6',
+ '_Convolution7',
+
+ 'upsample_flow2to1',
+ 'upsample_flow1to0',
+ ]
+ for i, k in enumerate(keys):
+ keys[i] = param_prefix + k
+
+ i = 0
+ for m in modules:
+ if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
+ weight = weights[keys[i]].copy()
+ bias = biases[keys[i]].copy()
+ if keys[i] == param_prefix + 'conv0':
+ m.weight.data[:, 0:3, :, :] = torch.from_numpy(
+ np.flip(weight[:, 0:3, :, :], axis=1).copy())
+ m.weight.data[:, 3::, :, :] = torch.from_numpy(
+ weight[:, 3:, :, :].copy())
+ if m.bias is not None:
+ m.bias.data[:] = torch.from_numpy(bias)
+ else:
+ m.weight.data[:, :, :, :] = torch.from_numpy(weight)
+ if m.bias is not None:
+ m.bias.data[:] = torch.from_numpy(bias)
+ i = i + 1
+
+ return
diff --git a/imaginaire/third_party/flow_net/flownet2/utils/tools.py b/imaginaire/third_party/flow_net/flownet2/utils/tools.py
new file mode 100644
index 0000000000000000000000000000000000000000..d6c9208117824c712aa54e3a3871273b851de63c
--- /dev/null
+++ b/imaginaire/third_party/flow_net/flownet2/utils/tools.py
@@ -0,0 +1,194 @@
+# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is made available under the Nvidia Source Code License-NC.
+# To view a copy of this license, check out LICENSE.md
+import os
+import time
+import math
+import subprocess
+import shutil
+from os.path import join
+import numpy as np
+from inspect import isclass
+from pytz import timezone
+from datetime import datetime
+import inspect
+import torch
+
+
+def datestr():
+ pacific = timezone('US/Pacific')
+ now = datetime.now(pacific)
+ return '{}{:02}{:02}_{:02}{:02}'.format(
+ now.year, now.month, now.day, now.hour, now.minute)
+
+
+def module_to_dict(module, exclude=[]):
+ return dict([(x, getattr(module, x)) for x in dir(module)
+ if isclass(getattr(module, x))
+ and x not in exclude
+ and getattr(module, x) not in exclude])
+
+
+class TimerBlock:
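+    # Context manager that prints its title on construction and reports the
+    # elapsed wall-clock time (seconds or minutes) for each log() call and on
+    # exit.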
+ def __init__(self, title):
+ print(("{}".format(title)))
+
+ def __enter__(self):
+        # time.clock() was removed in Python 3.8; use perf_counter() instead.
+        self.start = time.perf_counter()
+ return self
+
+ def __exit__(self, exc_type, exc_value, traceback):
+        self.end = time.perf_counter()
+ self.interval = self.end - self.start
+
+ if exc_type is not None:
+ self.log("Operation failed\n")
+ else:
+ self.log("Operation finished\n")
+
+ def log(self, string):
+        duration = time.perf_counter() - self.start
+ units = 's'
+ if duration > 60:
+ duration = duration / 60.
+ units = 'm'
+ print((" [{:.3f}{}] {}".format(duration, units, string)))
+
+ def log2file(self, fid, string):
+ fid = open(fid, 'a')
+ fid.write("%s\n" % (string))
+ fid.close()
+
+
+def add_arguments_for_module(
+ parser,
+ module,
+ argument_for_class,
+ default,
+ skip_params=[],
+ parameter_defaults={}):
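+    # Expose '--<argument_for_class>' for picking a class from `module`, then
+    # reflect on that class's __init__ and add one
+    # '--<argument_for_class>_<arg>' option per keyword argument, using its
+    # default (or an override from `parameter_defaults`).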
+ argument_group = parser.add_argument_group(argument_for_class.capitalize())
+
+ module_dict = module_to_dict(module)
+ argument_group.add_argument(
+ '--' + argument_for_class,
+ type=str,
+ default=default,
+ choices=list(
+ module_dict.keys()))
+
+ args, unknown_args = parser.parse_known_args()
+ class_obj = module_dict[vars(args)[argument_for_class]]
+
+    # inspect.getargspec is deprecated (removed in Python 3.11);
+    # getfullargspec provides the same .args/.defaults fields used below.
+    argspec = inspect.getfullargspec(class_obj.__init__)
+
+ defaults = argspec.defaults[::-1] if argspec.defaults else None
+
+ args = argspec.args[::-1]
+ for i, arg in enumerate(args):
+ cmd_arg = '{}_{}'.format(argument_for_class, arg)
+ if arg not in skip_params + ['self', 'args']:
+ if arg in list(parameter_defaults.keys()):
+ argument_group.add_argument(
+ '--{}'.format(cmd_arg),
+ type=type(
+ parameter_defaults[arg]),
+ default=parameter_defaults[arg])
+ elif (defaults is not None and i < len(defaults)):
+ argument_group.add_argument(
+ '--{}'.format(cmd_arg),
+ type=type(
+ defaults[i]),
+ default=defaults[i])
+ else:
+ print(("[Warning]: non-default argument '{}' "
+ "detected on class '{}'. This argument "
+ "cannot be modified via the command line"
+ .format(arg, module.__class__.__name__)))
+ # We don't have a good way of dealing with
+ # inferring the type of the argument
+ # TODO: try creating a custom action and using ast's infer type?
+ # else:
+ # argument_group.add_argument('--{}'.format(
+ # cmd_arg), required=True)
+
+
+def kwargs_from_args(args, argument_for_class):
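+    # Gather the '--<argument_for_class>_*' values parsed above back into a
+    # kwargs dict for constructing the selected class.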
+ argument_for_class = argument_for_class + '_'
+ return {key[len(argument_for_class):]: value for key, value in list(vars(
+ args).items()) if
+ argument_for_class in key and key != argument_for_class + 'class'}
+
+
+def format_dictionary_of_losses(labels, values):
+ try:
+ string = ', '.join([('{}: {:' +
+ ('.3f' if value >= 0.001 else '.1e') +
+ '}').format(name, value) for name, value in
+ zip(labels, values)])
+ except (TypeError, ValueError) as e:
+ print((list(zip(labels, values))))
+ string = '[Log Error] ' + str(e)
+
+ return string
+
+
+class IteratorTimer():
+ def __init__(self, iterable):
+ self.iterable = iterable
+ self.iterator = self.iterable.__iter__()
+
+ def __iter__(self):
+ return self
+
+ def __len__(self):
+ return len(self.iterable)
+
+ def __next__(self):
+ start = time.time()
+ n = next(self.iterator)
+ self.last_duration = (time.time() - start)
+ return n
+
+ next = __next__
+
+
+def gpumemusage():
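+    # Build a per-GPU "util%--used/total" summary string by parsing
+    # `nvidia-smi`; intended only for lightweight logging.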
+    # subprocess.check_output returns bytes in Python 3, so decode before the
+    # string processing below; the loop count must also be an integer.
+    gpu_mem = subprocess.check_output(
+        "nvidia-smi | grep MiB | cut -f 3 -d '|'",
+        shell=True).decode('utf-8').replace(
+        ' ', '').replace('\n', '').replace('i', '')
+    all_stat = [float(a) for a in gpu_mem.replace('/', '').split('MB')[:-1]]
+
+    gpu_mem = ''
+    for i in range(len(all_stat) // 2):
+ curr, tot = all_stat[2 * i], all_stat[2 * i + 1]
+ util = "%1.2f" % (100 * curr / tot) + '%'
+ cmem = str(int(math.ceil(curr / 1024.))) + 'GB'
+ gmem = str(int(math.ceil(tot / 1024.))) + 'GB'
+ gpu_mem += util + '--' + join(cmem, gmem) + ' '
+ return gpu_mem
+
+
+def update_hyperparameter_schedule(args, epoch, global_iteration, optimizer):
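+    # Step-decay the learning rate by `schedule_lr_fraction` every
+    # `schedule_lr_frequency` iterations, floored at 1e-6.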
+ if args.schedule_lr_frequency > 0:
+ for param_group in optimizer.param_groups:
+ if (global_iteration + 1) % args.schedule_lr_frequency == 0:
+ param_group['lr'] /= float(args.schedule_lr_fraction)
+ param_group['lr'] = float(
+ np.maximum(param_group['lr'], 0.000001))
+
+
+def save_checkpoint(state, is_best, path, prefix,
+ filename='checkpoint.pth.tar'):
+ prefix_save = os.path.join(path, prefix)
+ name = prefix_save + '_' + filename
+ torch.save(state, name)
+ if is_best:
+ shutil.copyfile(name, prefix_save + '_model_best.pth.tar')
diff --git a/imaginaire/third_party/resample2d/resample2d.py b/imaginaire/third_party/resample2d/resample2d.py
new file mode 100644
index 0000000000000000000000000000000000000000..cbdea3fa9941894090aa124adda1b62e1ea5e012
--- /dev/null
+++ b/imaginaire/third_party/resample2d/resample2d.py
@@ -0,0 +1,62 @@
+# flake8: noqa
+from torch.nn.modules.module import Module
+from torch.autograd import Function, Variable
+from torch.cuda.amp import autocast
+import resample2d_cuda
+
+
+class Resample2dFunction(Function):
+
+ @staticmethod
+ # def forward(ctx, input1, input2, kernel_size=1, bilinear=True):
+ def forward(ctx, input1, input2, kernel_size=1):
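+        # Warp input1 with the flow field input2 via the compiled CUDA op;
+        # the output spatial size follows input2.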
+ assert input1.is_contiguous()
+ assert input2.is_contiguous()
+
+ ctx.save_for_backward(input1, input2)
+ ctx.kernel_size = kernel_size
+ ctx.bilinear = True
+
+ _, d, _, _ = input1.size()
+ b, _, h, w = input2.size()
+ output = input1.new(b, d, h, w).zero_()
+
+ resample2d_cuda.forward(input1, input2, output, kernel_size)
+
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ grad_output = grad_output.contiguous()
+ assert grad_output.is_contiguous()
+
+ input1, input2 = ctx.saved_tensors
+
+ grad_input1 = Variable(input1.new(input1.size()).zero_())
+ grad_input2 = Variable(input1.new(input2.size()).zero_())
+
+ # resample2d_cuda.backward(input1, input2, grad_output.data,
+ # grad_input1.data, grad_input2.data,
+ # ctx.kernel_size, ctx.bilinear)
+ resample2d_cuda.backward(input1, input2, grad_output.data,
+ grad_input1.data, grad_input2.data,
+ ctx.kernel_size)
+
+ return grad_input1, grad_input2, None, None
+
+
+class Resample2d(Module):
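+    # nn.Module wrapper around the resample2d CUDA extension: samples input1
+    # at locations displaced by the flow field input2. The `bilinear` flag is
+    # kept for API compatibility, but the kernel always samples bilinearly.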
+
+ def __init__(self, kernel_size=1, bilinear=True):
+ super(Resample2d, self).__init__()
+ self.kernel_size = kernel_size
+ self.bilinear = bilinear
+
+ @autocast(False)
+ def forward(self, input1, input2):
+ input1, input2 = input1.float(), input2.float()
+ input1_c = input1.contiguous()
+ # return Resample2dFunction.apply(
+ # input1_c, input2, self.kernel_size, self.bilinear)
+ return Resample2dFunction.apply(
+ input1_c, input2, self.kernel_size)
\ No newline at end of file
diff --git a/imaginaire/third_party/resample2d/setup.py b/imaginaire/third_party/resample2d/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..54c14d9743bf514e01661bd1bc307a4e3f8986fe
--- /dev/null
+++ b/imaginaire/third_party/resample2d/setup.py
@@ -0,0 +1,43 @@
+# flake8: noqa
+from setuptools import setup
+from torch.utils.cpp_extension import BuildExtension, CUDAExtension
+import os
+
+
+cuda_version = os.getenv('CUDA_VERSION')
+print('CUDA_VERSION: {}'.format(cuda_version))
+
+nvcc_args = list()
+# nvcc_args.append('-gencode')
+# nvcc_args.append('arch=compute_50,code=sm_50')
+# nvcc_args.append('-gencode')
+# nvcc_args.append('arch=compute_52,code=sm_52')
+# nvcc_args.append('-gencode')
+# nvcc_args.append('arch=compute_60,code=sm_60')
+# nvcc_args.append('-gencode')
+# nvcc_args.append('arch=compute_61,code=sm_61')
+nvcc_args.append('-gencode')
+nvcc_args.append('arch=compute_70,code=sm_70')
+nvcc_args.append('-gencode')
+nvcc_args.append('arch=compute_75,code=sm_75')
+if cuda_version is not None:
+    # Compare the major version numerically; a lexicographic string comparison
+    # would misorder versions such as '9.0' and '11.0'.
+    if int(cuda_version.split('.')[0]) >= 11:
+ nvcc_args.append('-gencode')
+ nvcc_args.append('arch=compute_80,code=sm_80')
+nvcc_args.append('-Xcompiler')
+nvcc_args.append('-Wall')
+nvcc_args.append('-std=c++14')
+
+setup(
+ name='resample2d_cuda',
+ py_modules=['resample2d'],
+ ext_modules=[
+ CUDAExtension('resample2d_cuda', [
+ './src/resample2d_cuda.cc',
+ './src/resample2d_kernel.cu'
+ ], extra_compile_args={'cxx': ['-Wall', '-std=c++14'],
+ 'nvcc': nvcc_args})
+ ],
+ cmdclass={
+ 'build_ext': BuildExtension
+ })
diff --git a/imaginaire/third_party/resample2d/src/resample2d_cuda.cc b/imaginaire/third_party/resample2d/src/resample2d_cuda.cc
new file mode 100644
index 0000000000000000000000000000000000000000..b330a06bc0f20fe82c275e9a784f7ed91faf7717
--- /dev/null
+++ b/imaginaire/third_party/resample2d/src/resample2d_cuda.cc
@@ -0,0 +1,34 @@
+#include <ATen/ATen.h>
+#include <torch/extension.h>
+
+#include "resample2d_kernel.cuh"
+
+int resample2d_cuda_forward(
+ at::Tensor& input1,
+ at::Tensor& input2,
+ at::Tensor& output,
+ int kernel_size/*, bool bilinear*/) {
+ resample2d_kernel_forward(input1, input2, output, kernel_size/*,
+ bilinear*/);
+ return 1;
+}
+
+int resample2d_cuda_backward(
+ at::Tensor& input1,
+ at::Tensor& input2,
+ at::Tensor& gradOutput,
+ at::Tensor& gradInput1,
+ at::Tensor& gradInput2,
+ int kernel_size/*, bool bilinear*/) {
+ resample2d_kernel_backward(input1, input2, gradOutput, gradInput1,
+ gradInput2, kernel_size/*, bilinear*/);
+ return 1;
+}
+
+
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ m.def("forward", &resample2d_cuda_forward, "Resample2D forward (CUDA)");
+ m.def("backward", &resample2d_cuda_backward, "Resample2D backward (CUDA)");
+}
+
diff --git a/imaginaire/third_party/resample2d/src/resample2d_kernel.cu b/imaginaire/third_party/resample2d/src/resample2d_kernel.cu
new file mode 100644
index 0000000000000000000000000000000000000000..654ca8e417b6d22ff623d72d19e24997a2db284c
--- /dev/null
+++ b/imaginaire/third_party/resample2d/src/resample2d_kernel.cu
@@ -0,0 +1,328 @@
+#include <ATen/ATen.h>
+#include <ATen/Context.h>
+#include <ATen/cuda/CUDAContext.h>
+
+#define CUDA_NUM_THREADS 512
+#define THREADS_PER_BLOCK 64
+
+#define DIM0(TENSOR) ((TENSOR).x)
+#define DIM1(TENSOR) ((TENSOR).y)
+#define DIM2(TENSOR) ((TENSOR).z)
+#define DIM3(TENSOR) ((TENSOR).w)
+
+#define DIM3_INDEX(TENSOR, xx, yy, zz, ww) ((TENSOR)[((xx) * (TENSOR##_stride.x)) + ((yy) * (TENSOR##_stride.y)) + ((zz) * (TENSOR##_stride.z)) + ((ww) * (TENSOR##_stride.w))])
+
+template <typename scalar_t>
+__global__ void kernel_resample2d_update_output(const int n,
+ const scalar_t* __restrict__ input1, const long4 input1_size, const long4 input1_stride,
+ const scalar_t* __restrict__ input2, const long4 input2_size, const long4 input2_stride,
+ scalar_t* __restrict__ output,
+ const long4 output_size, const
+ long4 output_stride, int
+ kernel_size/*, bool bilinear*/) {
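+    // One thread per output element: read the flow (dx, dy) at (x, y) and
+    // bilinearly sample input1 at the displaced, bounds-clamped location.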
+ int index = blockIdx.x * blockDim.x + threadIdx.x;
+ bool bilinear = true;
+ if (index >= n) {
+ return;
+ }
+
+ scalar_t val = 0.0f;
+
+ int dim_b = DIM0(output_size);
+ int dim_c = DIM1(output_size);
+ int dim_h = DIM2(output_size);
+ int dim_w = DIM3(output_size);
+ int dim_chw = dim_c * dim_h * dim_w;
+ int dim_hw = dim_h * dim_w;
+
+ int b = ( index / dim_chw ) % dim_b;
+ int c = ( index / dim_hw ) % dim_c;
+ int y = ( index / dim_w ) % dim_h;
+ int x = ( index ) % dim_w;
+
+ scalar_t dx = DIM3_INDEX(input2, b, 0, y, x);
+ scalar_t dy = DIM3_INDEX(input2, b, 1, y, x);
+
+    scalar_t xf = static_cast<scalar_t>(x) + dx;
+    scalar_t yf = static_cast<scalar_t>(y) + dy;
+ scalar_t alpha = xf - floor(xf); // alpha
+ scalar_t beta = yf - floor(yf); // beta
+
+ if (bilinear) {
+ int xL = max(min( int (floor(xf)), dim_w-1), 0);
+ int xR = max(min( int (floor(xf)+1), dim_w -1), 0);
+ int yT = max(min( int (floor(yf)), dim_h-1), 0);
+ int yB = max(min( int (floor(yf)+1), dim_h-1), 0);
+
+ for (int fy = 0; fy < kernel_size; fy += 1) {
+ for (int fx = 0; fx < kernel_size; fx += 1) {
+                val += static_cast<scalar_t>((1. - alpha)*(1. - beta) * DIM3_INDEX(input1, b, c, yT + fy, xL + fx));
+                val += static_cast<scalar_t>((alpha)*(1. - beta) * DIM3_INDEX(input1, b, c, yT + fy, xR + fx));
+                val += static_cast<scalar_t>((1. - alpha)*(beta) * DIM3_INDEX(input1, b, c, yB + fy, xL + fx));
+                val += static_cast<scalar_t>((alpha)*(beta) * DIM3_INDEX(input1, b, c, yB + fy, xR + fx));
+ }
+ }
+
+ output[index] = val;
+ }
+ else {
+ int xN = max(min( int (floor(xf + 0.5)), dim_w - 1), 0);
+ int yN = max(min( int (floor(yf + 0.5)), dim_h - 1), 0);
+
+        output[index] = static_cast<scalar_t>(DIM3_INDEX(input1, b, c, yN, xN));
+ }
+
+}
+
+
+template <typename scalar_t>
+__global__ void kernel_resample2d_backward_input1(
+ const int n, const scalar_t* __restrict__ input1, const long4 input1_size, const long4 input1_stride,
+ const scalar_t* __restrict__ input2, const long4 input2_size, const long4 input2_stride,
+ const scalar_t* __restrict__ gradOutput, const long4 gradOutput_size, const long4 gradOutput_stride,
+ scalar_t* __restrict__ gradInput, const long4 gradInput_size, const long4
+ gradInput_stride, int kernel_size/*, bool bilinear*/) {
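+    // Gradient w.r.t. input1: scatter each output gradient back onto the four
+    // bilinear taps with atomicAdd, using the same clamped neighbourhood as in
+    // the forward pass.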
+
+ int index = blockIdx.x * blockDim.x + threadIdx.x;
+ bool bilinear = true;
+ if (index >= n) {
+ return;
+ }
+
+ int dim_b = DIM0(gradOutput_size);
+ int dim_c = DIM1(gradOutput_size);
+ int dim_h = DIM2(gradOutput_size);
+ int dim_w = DIM3(gradOutput_size);
+ int dim_chw = dim_c * dim_h * dim_w;
+ int dim_hw = dim_h * dim_w;
+
+ int b = ( index / dim_chw ) % dim_b;
+ int c = ( index / dim_hw ) % dim_c;
+ int y = ( index / dim_w ) % dim_h;
+ int x = ( index ) % dim_w;
+
+ scalar_t dx = DIM3_INDEX(input2, b, 0, y, x);
+ scalar_t dy = DIM3_INDEX(input2, b, 1, y, x);
+
+    scalar_t xf = static_cast<scalar_t>(x) + dx;
+    scalar_t yf = static_cast<scalar_t>(y) + dy;
+ scalar_t alpha = xf - int(xf); // alpha
+ scalar_t beta = yf - int(yf); // beta
+
+ int idim_h = DIM2(input1_size);
+ int idim_w = DIM3(input1_size);
+
+ int xL = max(min( int (floor(xf)), idim_w-1), 0);
+ int xR = max(min( int (floor(xf)+1), idim_w -1), 0);
+ int yT = max(min( int (floor(yf)), idim_h-1), 0);
+ int yB = max(min( int (floor(yf)+1), idim_h-1), 0);
+
+ for (int fy = 0; fy < kernel_size; fy += 1) {
+ for (int fx = 0; fx < kernel_size; fx += 1) {
+ atomicAdd(&DIM3_INDEX(gradInput, b, c, (yT + fy), (xL + fx)), (1-alpha)*(1-beta) * DIM3_INDEX(gradOutput, b, c, y, x));
+ atomicAdd(&DIM3_INDEX(gradInput, b, c, (yT + fy), (xR + fx)), (alpha)*(1-beta) * DIM3_INDEX(gradOutput, b, c, y, x));
+ atomicAdd(&DIM3_INDEX(gradInput, b, c, (yB + fy), (xL + fx)), (1-alpha)*(beta) * DIM3_INDEX(gradOutput, b, c, y, x));
+ atomicAdd(&DIM3_INDEX(gradInput, b, c, (yB + fy), (xR + fx)), (alpha)*(beta) * DIM3_INDEX(gradOutput, b, c, y, x));
+ }
+ }
+
+}
+
+template <typename scalar_t>
+__global__ void kernel_resample2d_backward_input2(
+ const int n, const scalar_t* __restrict__ input1, const long4 input1_size, const long4 input1_stride,
+ const scalar_t* __restrict__ input2, const long4 input2_size, const long4 input2_stride,
+ const scalar_t* __restrict__ gradOutput, const long4 gradOutput_size, const long4 gradOutput_stride,
+ scalar_t* __restrict__ gradInput, const long4 gradInput_size, const long4
+ gradInput_stride, int kernel_size/*, bool bilinear*/) {
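+    // Gradient w.r.t. the flow field: channel 0 (dx) accumulates the
+    // horizontal finite difference of the bilinear taps and channel 1 (dy)
+    // the vertical one, summed over all output channels.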
+
+ int index = blockIdx.x * blockDim.x + threadIdx.x;
+ bool bilinear = true;
+ if (index >= n) {
+ return;
+ }
+
+ scalar_t output = 0.0;
+ int kernel_rad = (kernel_size - 1)/2;
+
+ int dim_b = DIM0(gradInput_size);
+ int dim_c = DIM1(gradInput_size);
+ int dim_h = DIM2(gradInput_size);
+ int dim_w = DIM3(gradInput_size);
+ int dim_chw = dim_c * dim_h * dim_w;
+ int dim_hw = dim_h * dim_w;
+
+ int b = ( index / dim_chw ) % dim_b;
+ int c = ( index / dim_hw ) % dim_c;
+ int y = ( index / dim_w ) % dim_h;
+ int x = ( index ) % dim_w;
+
+ int odim_c = DIM1(gradOutput_size);
+
+ scalar_t dx = DIM3_INDEX(input2, b, 0, y, x);
+ scalar_t dy = DIM3_INDEX(input2, b, 1, y, x);
+
+    scalar_t xf = static_cast<scalar_t>(x) + dx;
+    scalar_t yf = static_cast<scalar_t>(y) + dy;
+
+ int xL = max(min( int (floor(xf)), dim_w-1), 0);
+ int xR = max(min( int (floor(xf)+1), dim_w -1), 0);
+ int yT = max(min( int (floor(yf)), dim_h-1), 0);
+ int yB = max(min( int (floor(yf)+1), dim_h-1), 0);
+
+ if (c % 2) {
+ float gamma = 1 - (xf - floor(xf)); // alpha
+ for (int i = 0; i <= 2*kernel_rad; ++i) {
+ for (int j = 0; j <= 2*kernel_rad; ++j) {
+ for (int ch = 0; ch < odim_c; ++ch) {
+ output += (gamma) * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, (yB + j), (xL + i));
+ output -= (gamma) * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, (yT + j), (xL + i));
+ output += (1-gamma) * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, (yB + j), (xR + i));
+ output -= (1-gamma) * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, (yT + j), (xR + i));
+ }
+ }
+ }
+ }
+ else {
+ float gamma = 1 - (yf - floor(yf)); // alpha
+ for (int i = 0; i <= 2*kernel_rad; ++i) {
+ for (int j = 0; j <= 2*kernel_rad; ++j) {
+ for (int ch = 0; ch < odim_c; ++ch) {
+ output += (gamma) * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, (yT + j), (xR + i));
+ output -= (gamma) * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, (yT + j), (xL + i));
+ output += (1-gamma) * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, (yB + j), (xR + i));
+ output -= (1-gamma) * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, (yB + j), (xL + i));
+ }
+ }
+ }
+
+ }
+
+ gradInput[index] = output;
+
+}
+
+void resample2d_kernel_forward(
+ at::Tensor& input1,
+ at::Tensor& input2,
+ at::Tensor& output,
+ int kernel_size/*,
+ bool bilinear*/) {
+
+ int n = output.numel();
+
+ const long4 input1_size = make_long4(input1.size(0), input1.size(1), input1.size(2), input1.size(3));
+ const long4 input1_stride = make_long4(input1.stride(0), input1.stride(1), input1.stride(2), input1.stride(3));
+
+ const long4 input2_size = make_long4(input2.size(0), input2.size(1), input2.size(2), input2.size(3));
+ const long4 input2_stride = make_long4(input2.stride(0), input2.stride(1), input2.stride(2), input2.stride(3));
+
+ const long4 output_size = make_long4(output.size(0), output.size(1), output.size(2), output.size(3));
+ const long4 output_stride = make_long4(output.stride(0), output.stride(1), output.stride(2), output.stride(3));
+
+ // TODO: when atomicAdd gets resolved, change to AT_DISPATCH_FLOATING_TYPES_AND_HALF
+// AT_DISPATCH_FLOATING_TYPES(input1.type(), "resample_forward_kernel", ([&] {
+
+    kernel_resample2d_update_output<float><<< (n + CUDA_NUM_THREADS - 1)/CUDA_NUM_THREADS, CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream() >>>(
+//at::globalContext().getCurrentCUDAStream() >>>(
+        n,
+        input1.data<float>(),
+        input1_size,
+        input1_stride,
+        input2.data<float>(),
+        input2_size,
+        input2_stride,
+ output.data