diff --git a/README.md b/README.md index 44644c0def69245cf1e7d20e921c928e36e0f6d9..8264706fc779a73624707592d59f3b8257ea32be 100644 --- a/README.md +++ b/README.md @@ -1,13 +1,122 @@ ---- -title: Sat3density -emoji: 🏆 -colorFrom: green -colorTo: blue -sdk: gradio -sdk_version: 3.41.2 -app_file: app.py -pinned: false -license: other ---- - -Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference +# Sat2Density: Faithful Density Learning from Satellite-Ground Image Pairs + +> [Ming Qian](https://qianmingduowan.github.io/), Jincheng Xiong, [Gui-Song Xia](http://www.captain-whu.com/xia_En.html), [Nan Xue](https://xuenan.net) +> +> IEEE/CVF International Conference on Computer Vision (ICCV), 2023 +> +> [Project](https://sat2density.github.io/) | [Paper](https://arxiv.org/abs/2303.14672) | [Data]() | [Install.md](docs/INSTALL.md) + +>

+> drawing +> drawing +> drawing +>

+ +>

+> drawing +> drawing +> drawing +>

+ +>

+> drawing +> drawing +> drawing +>

+ +>

+> drawing +> drawing +> drawing +>

+
+## Downloading Checkpoints
+> Pretrained checkpoints for CVACT and CVUSA are available from [this release page](https://github.com/sat2density/checkpoints/releases). You can also download them by running:
+```bash
+bash scripts/download_weights.sh
+```
+
+## Quick-Start Demo
+### Video Synthesis
+#### Example Usage
+```bash
+python test.py --yaml=sat2density_cvact \
+  --test_ckpt_path=2u87bj8w \
+  --task=test_vid \
+  --demo_img=demo_img/case1/satview-input.png \
+  --sty_img=demo_img/case1/groundview.image.png \
+  --save_dir=results/case1
+```
+
+### Illumination Interpolation
+
+```bash
+python test.py --task=test_interpolation \
+  --yaml=sat2density_cvact \
+  --test_ckpt_path=2u87bj8w \
+  --sty_img1=demo_img/case9/groundview.image.png \
+  --sty_img2=demo_img/case7/groundview.image.png \
+  --demo_img=demo_img/case3/satview-input.png \
+  --save_dir=results/case2
+```
+
+## Train & Inference
+- *We trained our model on a single V100 (32 GB) GPU; training takes about 20 hours.*
+- *For data preparation, see [dataset/INSTALL.md](dataset/INSTALL.md).*
+
+### Inference
+
+To evaluate the Center Ground-View Synthesis setting, run one of the commands below. Add `--task=vis_test` if you want to save the rendered results.
+```bash
+# CVACT
+python offline_train_test.py --yaml=sat2density_cvact --test_ckpt_path=2u87bj8w
+# CVUSA
+python offline_train_test.py --yaml=sat2density_cvusa --test_ckpt_path=2cqv8uh4
+```
+
+To run inference with different illumination styles:
+```bash
+# CVACT
+bash inference/single_style_test_cvact.sh
+# CVUSA
+bash inference/single_style_test_cvusa.sh
+```
+
+To synthesize ground-view videos:
+```bash
+bash inference/synthesis_video.sh
+```
+
+### Training
+
+```bash
+# CVACT
+CUDA_VISIBLE_DEVICES=X python train.py --yaml=sat2density_cvact
+# CVUSA
+CUDA_VISIBLE_DEVICES=X python train.py --yaml=sat2density_cvusa
+```
+
+## Citation
+If you use this code for your research, please cite:
+
+```bibtex
+@inproceedings{qian2023sat2density,
+  title={Sat2Density: Faithful Density Learning from Satellite-Ground Image Pairs},
+  author={Qian, Ming and Xiong, Jincheng and Xia, Gui-Song and Xue, Nan},
+  booktitle={ICCV},
+  year={2023}
+}
+```
+
+## License
+This work is licensed under a Creative Commons Attribution-NonCommercial 4.0 International License.
+For commercial use, please contact [mingqian@whu.edu.cn](mailto:mingqian@whu.edu.cn).
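The README drives everything through `test.py` and the shell scripts; the Gradio demo added in `app.py` below calls the same model from Python. The snippet here is a condensed sketch of that flow for a single center ground-view synthesis, lifted from `get_single` in `app.py`; the checkpoint path is a placeholder, and the demo-image paths assume the bundled `demo_img/case1` sample.

```python
# Condensed from get_single() in app.py; the checkpoint path below is a placeholder.
import importlib

import torch
import torchvision.transforms as transforms
from PIL import Image

import options

# Build the test-time options the same way app.py does.
args = ["--yaml=sat2density_cvact",
        "--test_ckpt_path=path/to/checkpoint/model.pth",   # placeholder checkpoint path
        "--task=test_vid",
        "--demo_img=demo_img/case1/satview-input.png",
        "--sty_img=demo_img/case1/groundview.image.png",
        "--save_dir=output"]
opt = options.set(opt_cmd=options.parse_arguments(args=args))
opt.isTrain = False
opt.name = opt.yaml if opt.name is None else opt.name
opt.batch_size = 1

# Instantiate the generator and load the released weights.
model = importlib.import_module("model.{}".format(opt.model)).Model(opt)
model.build_networks(opt)
ckpt = torch.load(opt.test_ckpt_path, map_location="cpu")
model.netG.load_state_dict(ckpt["netG"])
model.netG.eval()

# Prepare one satellite/ground pair plus the sky-masked colour histogram used as the style code.
trans = transforms.ToTensor()
sat = trans(Image.open("demo_img/case1/satview-input.png").convert("RGB"))
pano = trans(Image.open("demo_img/case1/groundview.image.png").convert("RGB"))
sky = trans(Image.open("demo_img/case1/groundview.sky.png").convert("L"))
masked = pano * sky
sky_histc = torch.cat([masked[c].histc()[10:] for c in reversed(range(3))])

batch = {"sat": sat[None], "pano": pano[None], "sky_mask": sky[None],
         "sky_histc": sky_histc[None], "paths": "demo.png"}
model.set_input(batch)
model.style_temp = model.sky_histc

# Render from the image centre; origin_H_W maps pixel offsets into [-1, 1] coordinates.
opt.origin_H_W = [0.0, 0.0]
model.forward(opt)
rgb = model.out_put.pred[0].clamp(0, 1).cpu().detach().numpy().transpose(1, 2, 0)
```

The `test.py` commands in the Quick-Start section wrap essentially the same calls; `--task=test_vid` additionally sweeps the rendering origin along a trajectory to produce a video (see `get_video` in `app.py`).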
diff --git a/__pycache__/options.cpython-38.pyc b/__pycache__/options.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ae86524e6a0305cbf3f915a831f72c3d7e62ade3 Binary files /dev/null and b/__pycache__/options.cpython-38.pyc differ diff --git a/__pycache__/test.cpython-38.pyc b/__pycache__/test.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3c26acf14978dd41cdcadaf20b1fec863948c7f3 Binary files /dev/null and b/__pycache__/test.cpython-38.pyc differ diff --git a/__pycache__/utils.cpython-38.pyc b/__pycache__/utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c490c388bbf3dd8d18683faeebf085e3623d0b76 Binary files /dev/null and b/__pycache__/utils.cpython-38.pyc differ diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..e56f1de13f4647636798b3e51522d193342b9ea4 --- /dev/null +++ b/app.py @@ -0,0 +1,244 @@ +import gradio as gr +import numpy as np +import os +from PIL import Image +import torch +import torchvision.transforms as transforms +import options +import test +import importlib +from scipy.interpolate import interp1d, splev, splprep +import cv2 + + +def get_single(sat_img, style_img, x_offset, y_offset): + name = '' + for i in [name for name in os.listdir('demo_img') if 'case' in name]: + style = Image.open('demo_img/{}/groundview.image.png'.format(i)).convert('RGB') + style =np.array(style) + if (style == style_img).all(): + name = i + break + + input_dict = {} + trans = transforms.ToTensor() + input_dict['sat'] = trans(sat_img) + input_dict['pano'] = trans(style_img) + input_dict['paths'] = "demo.png" + sky = trans(Image.open('demo_img/{}/groundview.sky.png'.format(name)).convert("L")) + input_a = input_dict['pano']*sky + sky_histc = torch.cat([input_a[i].histc()[10:] for i in reversed(range(3))]) + input_dict['sky_histc'] = sky_histc + input_dict['sky_mask'] = sky + + for key in input_dict.keys(): + if isinstance(input_dict[key], torch.Tensor): + input_dict[key] = input_dict[key].unsqueeze(0) + + args = ["--yaml=sat2density_cvact", "--test_ckpt_path=wandb/run-20230219_141512-2u87bj8w/files/checkpoint/model.pth", "--task=test_vid", "--demo_img=demo_img/case1/satview-input.png", + "--sty_img=demo_img/case1/groundview.image.png", "--save_dir=output"] + opt_cmd = options.parse_arguments(args=args) + opt = options.set(opt_cmd=opt_cmd) + opt.isTrain = False + opt.name = opt.yaml if opt.name is None else opt.name + opt.batch_size = 1 + + m = importlib.import_module("model.{}".format(opt.model)) + model = m.Model(opt) + + # m.load_dataset(opt) + model.build_networks(opt) + ckpt = torch.load(opt.test_ckpt_path, map_location='cpu') + model.netG.load_state_dict(ckpt['netG']) + model.netG.eval() + + model.set_input(input_dict) + + model.style_temp = model.sky_histc + opt.origin_H_W = [-(y_offset*256-128)/128, (x_offset*256-128)/128] # TODO: hard code should be removed in the future + + model.forward(opt) + + rgb = model.out_put.pred[0].clamp(min=0,max=1.0).cpu().detach().numpy().transpose((1,2,0)) + rgb = np.array(rgb*255, dtype=np.uint8) + return rgb + +def get_video(sat_img, style_img, positions): + name = '' + for i in [name for name in os.listdir('demo_img') if 'case' in name]: + style = Image.open('demo_img/{}/groundview.image.png'.format(i)).convert('RGB') + style =np.array(style) + if (style == style_img).all(): + name = i + break + + input_dict = {} + trans = transforms.ToTensor() + input_dict['sat'] = trans(sat_img) + input_dict['pano'] 
= trans(style_img)
+    input_dict['paths'] = "demo.png"
+    sky = trans(Image.open('demo_img/{}/groundview.sky.png'.format(name)).convert("L"))
+    input_a = input_dict['pano']*sky
+    sky_histc = torch.cat([input_a[i].histc()[10:] for i in reversed(range(3))])
+    input_dict['sky_histc'] = sky_histc
+    input_dict['sky_mask'] = sky
+
+    for key in input_dict.keys():
+        if isinstance(input_dict[key], torch.Tensor):
+            input_dict[key] = input_dict[key].unsqueeze(0)
+
+    args = ["--yaml=sat2density_cvact", "--test_ckpt_path=wandb/run-20230219_141512-2u87bj8w/files/checkpoint/model.pth", "--task=test_vid", "--demo_img=demo_img/case1/satview-input.png",
+            "--sty_img=demo_img/case1/groundview.image.png", "--save_dir=output"]
+    opt_cmd = options.parse_arguments(args=args)
+    opt = options.set(opt_cmd=opt_cmd)
+    opt.isTrain = False
+    opt.name = opt.yaml if opt.name is None else opt.name
+    opt.batch_size = 1
+
+    m = importlib.import_module("model.{}".format(opt.model))
+    model = m.Model(opt)
+
+    # m.load_dataset(opt)
+    model.build_networks(opt)
+    ckpt = torch.load(opt.test_ckpt_path, map_location='cpu')
+    model.netG.load_state_dict(ckpt['netG'])
+    model.netG.eval()
+
+    model.set_input(input_dict)
+
+    model.style_temp = model.sky_histc
+
+    # Deduplicate the clicked points while preserving their order.
+    unique_lst = list(dict.fromkeys(positions))
+    pixels = []
+    for x in positions:
+        if x in unique_lst:
+            if x not in pixels:
+                pixels.append(x)
+    pixels = np.array(pixels)
+    tck, u = splprep(pixels.T, s=25, per=0)
+    u_new = np.linspace(u.min(), u.max(), 80)
+    x_new, y_new = splev(u_new, tck)
+    smooth_path = np.array([x_new, y_new]).T
+
+    rendered_image_list = []
+    rendered_depth_list = []
+
+    for i, (x, y) in enumerate(smooth_path):
+        opt.origin_H_W = [(y-128)/128, (x-128)/128]  # TODO: remove this hard-coded 256x256 assumption
+        print('Rendering at ({}, {})'.format(x, y))
+        model.forward(opt)
+
+        rgb = model.out_put.pred[0].clamp(min=0, max=1.0).cpu().detach().numpy().transpose((1, 2, 0))
+        rgb = np.array(rgb*255, dtype=np.uint8)
+        rendered_image_list.append(rgb)
+
+        rendered_depth_list.append(
+            model.out_put.depth[0, 0].cpu().detach().numpy()
+        )
+
+    output_video_path = 'output_video.mp4'
+
+    # Set the frame rate and frame size of the output video.
+    frame_rate = 15
+    frame_width = 512
+    frame_height = 128
+
+    # Create the OpenCV video writer (MPEG-4 'mp4v' codec).
+    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+    out = cv2.VideoWriter(output_video_path, fourcc, frame_rate, (frame_width, frame_height))
+
+    # Write the rendered frames to the video.
+    for image_np in rendered_image_list:
+        image_np = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB)
+        out.write(image_np)
+
+    # Release the video writer.
+    out.release()
+
+    return "output_video.mp4"
+
+def copy_image(image):
+    return image
+
+def show_image_and_point(image, x, y):
+    x = int(x*image.shape[1])
+    y = image.shape[0]-int(y*image.shape[0])
+    mask = np.zeros(image.shape[:2])
+    radius = min(image.shape[0], image.shape[1])//60
+    for i in range(x-radius-2, x+radius+2):
+        for j in range(y-radius-2, y+radius+2):
+            if (i-x)**2+(j-y)**2 <= radius**2:
+                mask[j, i] = 1
+    return (image, [(mask, 'render point')])
+
+def add_select_point(image, evt: gr.SelectData, state1):
+    if state1 is None:
+        state1 = []
+    x, y = evt.index
+    state1.append((x, y))
+    print(state1)
+    radius = min(image.shape[0], image.shape[1])//60
+    for i in range(x-radius-2, x+radius+2):
+        for j in range(y-radius-2, y+radius+2):
+            if (i-x)**2+(j-y)**2 <= radius**2:
+                image[j, i, :] = 0
+    return image, state1
+
+def reset_select_points(image):
+    return image, []
+
+
+with gr.Blocks() as demo:
+    gr.Markdown("# Sat2Density Demos")
+    gr.Markdown("### Select/upload the satellite image and 
select the style image") + with gr.Row(): + with gr.Column(): + sat_img = gr.Image(source='upload', shape=[256, 256], interactive=True) + img_examples = gr.Examples(examples=['demo_img/{}/satview-input.png'.format(i) for i in os.listdir('demo_img') if 'case' in i], + inputs=sat_img, outputs=None, examples_per_page=20) + with gr.Column(): + style_img = gr.Image() + style_examples = gr.Examples(examples=['demo_img/{}/groundview.image.png'.format(i) for i in os.listdir('demo_img') if 'case' in i], + inputs=style_img, outputs=None, examples_per_page=20) + + + gr.Markdown("### select a certain point to generate single groundview image") + with gr.Row(): + with gr.Column(): + with gr.Row(): + with gr.Column(): + slider_x = gr.Slider(0.2, 0.8, 0.5, label="x-axis position") + slider_y = gr.Slider(0.2, 0.8, 0.5, label="y-axis position") + btn_single = gr.Button(label="demo1") + + annotation_image = gr.AnnotatedImage() + + out_single = gr.Image() + + gr.Markdown("### draw a trajectory on the map to generate video") + state_select_points = gr.State() + with gr.Row(): + with gr.Column(): + draw_img = gr.Image(shape=[256, 256], interactive=True) + with gr.Column(): + out_video = gr.Video() + reset_btn =gr.Button(value="Reset") + btn_video = gr.Button(label="demo1") + + sat_img.change(copy_image, inputs = sat_img, outputs=draw_img) + + draw_img.select(add_select_point, [draw_img, state_select_points], [draw_img, state_select_points]) + sat_img.change(show_image_and_point, inputs = [sat_img, slider_x, slider_y], outputs = annotation_image) + slider_x.change(show_image_and_point, inputs = [sat_img, slider_x, slider_y], outputs = annotation_image, show_progress='hidden') + slider_y.change(show_image_and_point, inputs = [sat_img, slider_x, slider_y], outputs = annotation_image, show_progress='hidden') + btn_single.click(get_single, inputs = [sat_img, style_img, slider_x, slider_y], outputs=out_single) + reset_btn.click(reset_select_points, [sat_img], [draw_img, state_select_points]) + btn_video.click(get_video, inputs=[sat_img, style_img, state_select_points], outputs=out_video) # 触发 + + +demo.launch() \ No newline at end of file diff --git a/data/CVACT_Shi.py b/data/CVACT_Shi.py new file mode 100644 index 0000000000000000000000000000000000000000..3b4588ccbe61427e4bec55e7d32bb59428c1228f --- /dev/null +++ b/data/CVACT_Shi.py @@ -0,0 +1,119 @@ +import torch,os +from torch.utils.data.dataset import Dataset +from PIL import Image +import scipy.io as sio +import torchvision.transforms as transforms + +def data_list(img_root,mode): + exist_aer_list = os.listdir(os.path.join(img_root , 'satview_correct')) + exist_grd_list = os.listdir(os.path.join(img_root , 'streetview')) + allDataList = os.path.join(img_root, 'ACT_data.mat') + anuData = sio.loadmat(allDataList) + + all_data_list = [] + for i in range(0, len(anuData['panoIds'])): + grd_id_align = anuData['panoIds'][i] + '_grdView.png' + sat_id_ori = anuData['panoIds'][i] + '_satView_polish.png' + all_data_list.append([grd_id_align, sat_id_ori]) + + data_list = [] + + if mode=='train': + training_inds = anuData['trainSet']['trainInd'][0][0] - 1 + trainNum = len(training_inds) + for k in range(trainNum): + data_list.append(all_data_list[training_inds[k][0]]) + else: + val_inds = anuData['valSet']['valInd'][0][0] - 1 + valNum = len(val_inds) + for k in range(valNum): + data_list.append(all_data_list[val_inds[k][0]]) + + + pano_list = [img_root + 'streetview/' + item[0] for item in data_list if item[0] in exist_grd_list and item[1] in exist_aer_list] + + return 
pano_list + +def img_read(img,size=None,datatype='RGB'): + img = Image.open(img).convert('RGB' if datatype=='RGB' else "L") + if size: + if type(size) is int: + size = (size,size) + img = img.resize(size = size,resample=Image.BICUBIC if datatype=='RGB' else Image.NEAREST) + img = transforms.ToTensor()(img) + return img + + +class Dataset(Dataset): + def __init__(self, opt,split='train',sub=None,sty_img=None): + if sty_img: + assert sty_img.endswith('grdView.png') + demo_img_path = os.path.join(opt.data.root,'streetview',sty_img) + self.pano_list = [demo_img_path] + + elif opt.task in ['test_vid','test_interpolation'] : + demo_img_path = os.path.join(opt.data.root,'streetview',opt.demo_img.replace('satView_polish.png','grdView.png')) + self.pano_list = [demo_img_path] + + else: + self.pano_list = data_list(img_root=opt.data.root,mode=split) + if sub: + self.pano_list = self.pano_list[:sub] + + # select some ground images to test the influence of different skys. + # different skys guide different illumination intensity, colors, and etc. + if opt.task == 'test_sty': + demo_name = [ + 'dataset/CVACT/streetview/pPfo7qQ1fP_24rXrJ2Uxog_grdView.png', + 'dataset/CVACT/streetview/YL81FiK9PucIvAkr1FHkpA_grdView.png', + 'dataset/CVACT/streetview/Tzis1jBKHjbXiVB2oRYwAQ_grdView.png', + 'dataset/CVACT/streetview/eqGgeBLGXRhSj6c-0h0KoQ_grdView.png', + 'dataset/CVACT/streetview/pdZmLHYEhe2PHj_8-WHMhw_grdView.png', + 'dataset/CVACT/streetview/ehsu9Q3iTin5t52DM-MwyQ_grdView.png', + 'dataset/CVACT/streetview/agLEcuq3_-qFj7wwGbktVg_grdView.png', + 'dataset/CVACT/streetview/HwQIDdMI3GfHyPGtCSo6aA_grdView.png', + 'dataset/CVACT/streetview/hV8svb3ZVXcQ0AtTRFE1dQ_grdView.png', + 'dataset/CVACT/streetview/fzq2mBfKP3UIczAd9KpMMg_grdView.png', + 'dataset/CVACT/streetview/acRP98sACUIlwl2ZIsEyiQ_grdView.png', + 'dataset/CVACT/streetview/WSh9tNVryLdupUlU0ri2tQ_grdView.png', + 'dataset/CVACT/streetview/FhEuB9NA5o08VJ_TBCbHjw_grdView.png', + 'dataset/CVACT/streetview/YHfpn2Mgu1lqgT2OUeBpOg_grdView.png', + 'dataset/CVACT/streetview/vNhv7ZP1dUkJ93UwFXagJw_grdView.png', + ] + self.pano_list = demo_name + + self.opt = opt + + def __len__(self): + return len(self.pano_list) + + def __getitem__(self, index): + pano = self.pano_list[index] + aer = pano.replace('streetview','satview_correct').replace('_grdView','_satView_polish') + if self.opt.data.sky_mask: + sky = pano.replace('streetview','pano_sky_mask') + name = pano + aer = img_read(aer, size = self.opt.data.sat_size) + pano = img_read(pano,size = self.opt.data.pano_size) + if self.opt.data.sky_mask: + sky = img_read(sky,size=self.opt.data.pano_size,datatype='L') + + input = {} + input['sat']=aer + input['pano']=pano + input['paths']=name + if self.opt.data.sky_mask: + input['sky_mask']=sky + black_ground = torch.zeros_like(pano) + if self.opt.data.histo_mode =='grey': + input['sky_histc'] = (pano*sky+black_ground*(1-sky)).histc()[10:] + elif self.opt.data.histo_mode in ['rgb','RGB']: + input_a = (pano*sky+black_ground*(1-sky)) + for idx in range(len(input_a)): + if idx == 0: + sky_histc = input_a[idx].histc()[10:] + else: + sky_histc = torch.cat([input_a[idx].histc()[10:],sky_histc],dim=0) + input['sky_histc'] = sky_histc + return input + diff --git a/data/CVUSA.py b/data/CVUSA.py new file mode 100644 index 0000000000000000000000000000000000000000..2179427bf165335df33ab0761ea28975f7672a45 --- /dev/null +++ b/data/CVUSA.py @@ -0,0 +1,86 @@ +import torch,os +from torch.utils.data.dataset import Dataset +from PIL import Image +import torchvision.transforms as transforms 
+import re +from easydict import EasyDict as edict + +def data_list(img_root,mode): + data_list=[] + if mode=='train': + split_file=os.path.join(img_root, 'splits/train-19zl.csv') + with open(split_file) as f: + list = f.readlines() + for i in list: + aerial_name=re.split(r',', re.split('\n', i)[0])[0] + panorama_name = re.split(r',', re.split('\n', i)[0])[1] + data_list.append([aerial_name, panorama_name]) + else: + split_file=os.path.join(img_root+'splits/val-19zl.csv') + with open(split_file) as f: + list = f.readlines() + for i in list: + aerial_name=re.split(r',', re.split('\n', i)[0])[0] + panorama_name = re.split(r',', re.split('\n', i)[0])[1] + data_list.append([aerial_name, panorama_name]) + print('length of dataset is: ', len(data_list)) + return [os.path.join(img_root, i[1]) for i in data_list] + +def img_read(img,size=None,datatype='RGB'): + img = Image.open(img).convert('RGB' if datatype=='RGB' else "L") + if size: + if type(size) is int: + size = (size,size) + img = img.resize(size = size,resample=Image.BICUBIC if datatype=='RGB' else Image.NEAREST) + img = transforms.ToTensor()(img) + return img + + +class Dataset(Dataset): + def __init__(self, opt,split='train',sub=None,sty_img=None): + self.pano_list = data_list(img_root=opt.data.root,mode=split) + if sub: + self.pano_list = self.pano_list[:sub] + if opt.task == 'test_vid': + demo_img_path = os.path.join(opt.data.root, 'streetview/panos', opt.demo_img) + self.pano_list = [demo_img_path] + if sty_img: + assert opt.sty_img.split('.')[-1] == 'jpg' + demo_img_path = os.path.join(opt.data.root, 'streetview/panos', opt.sty_img) + self.pano_list = [demo_img_path] + + self.opt = opt + + def __len__(self): + return len(self.pano_list) + + def __getitem__(self, index): + pano = self.pano_list[index] + aer = pano.replace('streetview/panos', 'bingmap/19') + if self.opt.data.sky_mask: + sky = pano.replace('streetview/panos','sky_mask').replace('jpg', 'png') + name = pano + aer = img_read(aer, size = self.opt.data.sat_size) + pano = img_read(pano,size = self.opt.data.pano_size) + if self.opt.data.sky_mask: + sky = img_read(sky,size=self.opt.data.pano_size,datatype='L') + + input = {} + input['sat']=aer + input['pano']=pano + input['paths']=name + if self.opt.data.sky_mask: + input['sky_mask']=sky + black_ground = torch.zeros_like(pano) + if self.opt.data.histo_mode =='grey': + input['sky_histc'] = (pano*sky+black_ground*(1-sky)).histc()[10:] + elif self.opt.data.histo_mode in ['rgb','RGB']: + input_a = (pano*sky+black_ground*(1-sky)) + for idx in range(len(input_a)): + if idx == 0: + sky_histc = input_a[idx].histc()[10:] + else: + sky_histc = torch.cat([input_a[idx].histc()[10:],sky_histc],dim=0) + input['sky_histc'] = sky_histc + return input + diff --git a/dataset/INSTALL.md b/dataset/INSTALL.md new file mode 100644 index 0000000000000000000000000000000000000000..1d0784b008b078f757c910b16b08edc08a7ed6f8 --- /dev/null +++ b/dataset/INSTALL.md @@ -0,0 +1,32 @@ +For reproduce our paper, + +you should first download 4 zip file: + +` +CVACT/satview_correct.zip , +CVACT/streetview.zip , +CVUSA/bingmap/19.zip , +CVUSA/streetview/panos.zip +` + from [here](https://anu365-my.sharepoint.com/:f:/g/personal/u6293587_anu_edu_au/EuOBUDUQNClJvCpQ8bD1hnoBjdRBWxsHOVp946YVahiMGg?e=F4yRAC), the project page is [Sat2StrPanoramaSynthesis](https://github.com/shiyujiao/Sat2StrPanoramaSynthesis). 
+ +Then download the sky mask from [here](https://drive.google.com/drive/folders/1pfzwONg4P-Mzvxvzb2HoCpuZFynElPCk?usp=sharing) + +Last,the users should organize the dataset just like: +``` +├dataset +├── CVACT +│ ├── streetview +│ ├── satview_correct +│ ├── pano_sky_mask +│ ├── ACT_data.mat +└── CVUSA +│ ├── bingmap +│ │ ├── 19 +│ └── streetview +│ │ ├── panos +│ ├── sky_mask +│ ├── splits +``` + +Tip: The sky masks are processed with [Trans4PASS](https://github.com/jamycheung/Trans4PASS). diff --git a/demo_img/case1/groundview.image.png b/demo_img/case1/groundview.image.png new file mode 100644 index 0000000000000000000000000000000000000000..66f6050ec9e635a04e58242826af667f894adde0 Binary files /dev/null and b/demo_img/case1/groundview.image.png differ diff --git a/demo_img/case1/groundview.sky.png b/demo_img/case1/groundview.sky.png new file mode 100644 index 0000000000000000000000000000000000000000..413a82e27fd781030a1314513e859610ab48f652 Binary files /dev/null and b/demo_img/case1/groundview.sky.png differ diff --git a/demo_img/case1/satview-input.png b/demo_img/case1/satview-input.png new file mode 100644 index 0000000000000000000000000000000000000000..375575a45a1d1b33a87c0133e9237d0b89af9f3e Binary files /dev/null and b/demo_img/case1/satview-input.png differ diff --git a/demo_img/case10/groundview.image.png b/demo_img/case10/groundview.image.png new file mode 100644 index 0000000000000000000000000000000000000000..66f6050ec9e635a04e58242826af667f894adde0 Binary files /dev/null and b/demo_img/case10/groundview.image.png differ diff --git a/demo_img/case10/groundview.sky.png b/demo_img/case10/groundview.sky.png new file mode 100644 index 0000000000000000000000000000000000000000..5252275f364573f59f2f3b57a177327c297f523a Binary files /dev/null and b/demo_img/case10/groundview.sky.png differ diff --git a/demo_img/case10/satview-input.png b/demo_img/case10/satview-input.png new file mode 100644 index 0000000000000000000000000000000000000000..efe5929d06fb9be3eeda5deb84733614ce571b80 Binary files /dev/null and b/demo_img/case10/satview-input.png differ diff --git a/demo_img/case11/groundview.image.png b/demo_img/case11/groundview.image.png new file mode 100644 index 0000000000000000000000000000000000000000..a4c7ad7b5281dd57dc161173699a570592978f6a Binary files /dev/null and b/demo_img/case11/groundview.image.png differ diff --git a/demo_img/case11/groundview.sky.png b/demo_img/case11/groundview.sky.png new file mode 100644 index 0000000000000000000000000000000000000000..31d697c17321a98ad5838420419a2c8028da20b4 Binary files /dev/null and b/demo_img/case11/groundview.sky.png differ diff --git a/demo_img/case11/satview-input.png b/demo_img/case11/satview-input.png new file mode 100644 index 0000000000000000000000000000000000000000..78c725dc36987de0b11c1f044294deeb86d7e20c Binary files /dev/null and b/demo_img/case11/satview-input.png differ diff --git a/demo_img/case12/groundview.image.png b/demo_img/case12/groundview.image.png new file mode 100644 index 0000000000000000000000000000000000000000..3e74053765745892dc5847955d14fc8210efcc94 Binary files /dev/null and b/demo_img/case12/groundview.image.png differ diff --git a/demo_img/case12/groundview.sky.png b/demo_img/case12/groundview.sky.png new file mode 100644 index 0000000000000000000000000000000000000000..70c7652da7c488dccc1393b3abf7f916e00662bb Binary files /dev/null and b/demo_img/case12/groundview.sky.png differ diff --git a/demo_img/case12/satview-input.png b/demo_img/case12/satview-input.png new file mode 100644 index 
0000000000000000000000000000000000000000..28858ab35e3c67162cf717e297867e44f545c2ab Binary files /dev/null and b/demo_img/case12/satview-input.png differ diff --git a/demo_img/case13/groundview.image.png b/demo_img/case13/groundview.image.png new file mode 100644 index 0000000000000000000000000000000000000000..9fd2f5da5ba5e7947074df32bb723d2631115cf5 Binary files /dev/null and b/demo_img/case13/groundview.image.png differ diff --git a/demo_img/case13/groundview.sky.png b/demo_img/case13/groundview.sky.png new file mode 100644 index 0000000000000000000000000000000000000000..b0c9f795448f2c9a4cb4253c4036a31f66356ae2 Binary files /dev/null and b/demo_img/case13/groundview.sky.png differ diff --git a/demo_img/case13/satview-input.png b/demo_img/case13/satview-input.png new file mode 100644 index 0000000000000000000000000000000000000000..53b78fd4a1252d9c5f730ce27bd485810ced4db1 Binary files /dev/null and b/demo_img/case13/satview-input.png differ diff --git a/demo_img/case2/groundview.image.png b/demo_img/case2/groundview.image.png new file mode 100644 index 0000000000000000000000000000000000000000..9fd2f5da5ba5e7947074df32bb723d2631115cf5 Binary files /dev/null and b/demo_img/case2/groundview.image.png differ diff --git a/demo_img/case2/groundview.sky.png b/demo_img/case2/groundview.sky.png new file mode 100644 index 0000000000000000000000000000000000000000..b0c9f795448f2c9a4cb4253c4036a31f66356ae2 Binary files /dev/null and b/demo_img/case2/groundview.sky.png differ diff --git a/demo_img/case2/satview-input.png b/demo_img/case2/satview-input.png new file mode 100644 index 0000000000000000000000000000000000000000..53b78fd4a1252d9c5f730ce27bd485810ced4db1 Binary files /dev/null and b/demo_img/case2/satview-input.png differ diff --git a/demo_img/case3/groundview.image.png b/demo_img/case3/groundview.image.png new file mode 100644 index 0000000000000000000000000000000000000000..ee07e66c811c49961b5b4cf83e2bde17c1c4ab96 Binary files /dev/null and b/demo_img/case3/groundview.image.png differ diff --git a/demo_img/case3/groundview.sky.png b/demo_img/case3/groundview.sky.png new file mode 100644 index 0000000000000000000000000000000000000000..bf1580c092c0e7b80f8199d845f71acc8bf23ee3 Binary files /dev/null and b/demo_img/case3/groundview.sky.png differ diff --git a/demo_img/case3/satview-input.png b/demo_img/case3/satview-input.png new file mode 100644 index 0000000000000000000000000000000000000000..d270d4f7ae9a135b0df74e35c7f5ba5440bd07cd Binary files /dev/null and b/demo_img/case3/satview-input.png differ diff --git a/demo_img/case4/groundview.image.png b/demo_img/case4/groundview.image.png new file mode 100644 index 0000000000000000000000000000000000000000..bdd86213b5efdf40e457a5fbead6d2592a40dc0d Binary files /dev/null and b/demo_img/case4/groundview.image.png differ diff --git a/demo_img/case4/groundview.sky.png b/demo_img/case4/groundview.sky.png new file mode 100644 index 0000000000000000000000000000000000000000..996b7120fbb55c5987c9b58ba1ccf28d16747960 Binary files /dev/null and b/demo_img/case4/groundview.sky.png differ diff --git a/demo_img/case4/satview-input.png b/demo_img/case4/satview-input.png new file mode 100644 index 0000000000000000000000000000000000000000..e16179fe548be5115eb5906ae7408c56fdae62f8 Binary files /dev/null and b/demo_img/case4/satview-input.png differ diff --git a/demo_img/case5/groundview.image.png b/demo_img/case5/groundview.image.png new file mode 100644 index 0000000000000000000000000000000000000000..a9266ebdaead7935d8ea0f834cf33e351da0dd5a Binary files /dev/null and 
b/demo_img/case5/groundview.image.png differ diff --git a/demo_img/case5/groundview.sky.png b/demo_img/case5/groundview.sky.png new file mode 100644 index 0000000000000000000000000000000000000000..8598b9e878bcbc54d1b5eaa90e729d2dd77308b0 Binary files /dev/null and b/demo_img/case5/groundview.sky.png differ diff --git a/demo_img/case5/satview-input.png b/demo_img/case5/satview-input.png new file mode 100644 index 0000000000000000000000000000000000000000..4c12242298ef0f2721ec629e780dc75d3af8e810 Binary files /dev/null and b/demo_img/case5/satview-input.png differ diff --git a/demo_img/case6/groundview.image.png b/demo_img/case6/groundview.image.png new file mode 100644 index 0000000000000000000000000000000000000000..71aab5b4aeca63f04b8d0f4eb8e29f5075612c48 Binary files /dev/null and b/demo_img/case6/groundview.image.png differ diff --git a/demo_img/case6/groundview.sky.png b/demo_img/case6/groundview.sky.png new file mode 100644 index 0000000000000000000000000000000000000000..6d00e485d6cf4ce46aba8274c527598c54e9b970 Binary files /dev/null and b/demo_img/case6/groundview.sky.png differ diff --git a/demo_img/case6/satview-input.png b/demo_img/case6/satview-input.png new file mode 100644 index 0000000000000000000000000000000000000000..f0837bce0790d1795f0079f8c0c76b53a3fa36b9 Binary files /dev/null and b/demo_img/case6/satview-input.png differ diff --git a/demo_img/case7/groundview.image.png b/demo_img/case7/groundview.image.png new file mode 100644 index 0000000000000000000000000000000000000000..fcbfd8e23a6411d0cbbac43c15efcea4d6a6dd20 Binary files /dev/null and b/demo_img/case7/groundview.image.png differ diff --git a/demo_img/case7/groundview.sky.png b/demo_img/case7/groundview.sky.png new file mode 100644 index 0000000000000000000000000000000000000000..60ce35217b3faf8894d9cbd3e8ce58c77c29d28b Binary files /dev/null and b/demo_img/case7/groundview.sky.png differ diff --git a/demo_img/case7/satview-input.png b/demo_img/case7/satview-input.png new file mode 100644 index 0000000000000000000000000000000000000000..708349179ddfdbeff7fae9ce516a5e49323bc2f8 Binary files /dev/null and b/demo_img/case7/satview-input.png differ diff --git a/demo_img/case8/groundview.image.png b/demo_img/case8/groundview.image.png new file mode 100644 index 0000000000000000000000000000000000000000..18d52022a5becfb48728b8ea23679dffeee870ea Binary files /dev/null and b/demo_img/case8/groundview.image.png differ diff --git a/demo_img/case8/groundview.sky.png b/demo_img/case8/groundview.sky.png new file mode 100644 index 0000000000000000000000000000000000000000..94731c15c751e9df883806ac0d43fbfedb270ab7 Binary files /dev/null and b/demo_img/case8/groundview.sky.png differ diff --git a/demo_img/case8/satview-input.png b/demo_img/case8/satview-input.png new file mode 100644 index 0000000000000000000000000000000000000000..981ce9f97d7d182ff441a3f925fb1d03598f9ae2 Binary files /dev/null and b/demo_img/case8/satview-input.png differ diff --git a/demo_img/case9/groundview.image.png b/demo_img/case9/groundview.image.png new file mode 100644 index 0000000000000000000000000000000000000000..0084c60d4618d184db5e0b383403a20ce1644555 Binary files /dev/null and b/demo_img/case9/groundview.image.png differ diff --git a/demo_img/case9/groundview.sky.png b/demo_img/case9/groundview.sky.png new file mode 100644 index 0000000000000000000000000000000000000000..649db36fe97c73ea18f7712941de6c6b0f7bad19 Binary files /dev/null and b/demo_img/case9/groundview.sky.png differ diff --git a/demo_img/case9/satview-input.png 
b/demo_img/case9/satview-input.png new file mode 100644 index 0000000000000000000000000000000000000000..18d0484df8b32f566532c6345abb98eec6e9195e Binary files /dev/null and b/demo_img/case9/satview-input.png differ diff --git a/demo_img/runall.sh b/demo_img/runall.sh new file mode 100644 index 0000000000000000000000000000000000000000..975924bab8dab46fb3e42c0454985aa93afbdb1e --- /dev/null +++ b/demo_img/runall.sh @@ -0,0 +1,30 @@ +# for case in `ls -d demo_img/case*` +for case_id in 1 2 3 4 +do + case=demo_img/case$case_id + echo $case + python test.py --yaml=sat2density_cvact \ + --test_ckpt_path=2u87bj8w \ + --task=test_vid \ + --demo_img=$case/satview-input.png \ + --sty_img=$case/groundview.image.png \ + --save_dir=results/$case + # ffmpeg -framerate 10 -i results/$case/rendered_images+depths/%5d.png results/$case/render.gif + ffmpeg -framerate 10 -i results/$case/rendered_images+depths/%5d.png -vf "palettegen" results/$case-palette.png + ffmpeg -framerate 10 -i results/$case/rendered_images+depths/%5d.png -i results/$case-palette.png -filter_complex "paletteuse" results/$case/render.gif + + ffmpeg -framerate 10 -i results/$case/sat_images/%5d.png -vf "palettegen" results/$case-palette.png + ffmpeg -framerate 10 -i results/$case/sat_images/%5d.png -i results/$case-palette.png -filter_complex "paletteuse" results/$case/sat.gif + # ffmpeg -framerate 10 -i results/$case/sat_images/%5d.png results/$case/sat.gif +done + +# for case in `ls -d demo_img/case*` +for case_id in 1 2 3 4 +do + case=demo_img/case$case_id + sat_gif=results/$case/sat.gif + render_gif=results/$case/render.gif + # echo $sat_gif + cp $sat_gif docs/figures/demo/case$case_id.sat.gif + cp $render_gif docs/figures/demo/case$case_id.render.gif +done \ No newline at end of file diff --git a/imaginaire/__init__.py b/imaginaire/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..13acefe2181136b1629ec31f9d122fb46bf26780 --- /dev/null +++ b/imaginaire/__init__.py @@ -0,0 +1,4 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md diff --git a/imaginaire/__pycache__/__init__.cpython-38.pyc b/imaginaire/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e39ab8b1c6c47a307883db29f98f32adce2f4ea6 Binary files /dev/null and b/imaginaire/__pycache__/__init__.cpython-38.pyc differ diff --git a/imaginaire/config.py b/imaginaire/config.py new file mode 100644 index 0000000000000000000000000000000000000000..3a728a5aaee8d040288ff9ffd17a4fa83a7e2ca7 --- /dev/null +++ b/imaginaire/config.py @@ -0,0 +1,238 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. 
+# To view a copy of this license, check out LICENSE.md +"""Config utilities for yml file.""" + +import collections +import functools +import os +import re + +import yaml +from imaginaire.utils.distributed import master_only_print as print + +DEBUG = False +USE_JIT = False + + +class AttrDict(dict): + """Dict as attribute trick.""" + + def __init__(self, *args, **kwargs): + super(AttrDict, self).__init__(*args, **kwargs) + self.__dict__ = self + for key, value in self.__dict__.items(): + if isinstance(value, dict): + self.__dict__[key] = AttrDict(value) + elif isinstance(value, (list, tuple)): + if isinstance(value[0], dict): + self.__dict__[key] = [AttrDict(item) for item in value] + else: + self.__dict__[key] = value + + def yaml(self): + """Convert object to yaml dict and return.""" + yaml_dict = {} + for key, value in self.__dict__.items(): + if isinstance(value, AttrDict): + yaml_dict[key] = value.yaml() + elif isinstance(value, list): + if isinstance(value[0], AttrDict): + new_l = [] + for item in value: + new_l.append(item.yaml()) + yaml_dict[key] = new_l + else: + yaml_dict[key] = value + else: + yaml_dict[key] = value + return yaml_dict + + def __repr__(self): + """Print all variables.""" + ret_str = [] + for key, value in self.__dict__.items(): + if isinstance(value, AttrDict): + ret_str.append('{}:'.format(key)) + child_ret_str = value.__repr__().split('\n') + for item in child_ret_str: + ret_str.append(' ' + item) + elif isinstance(value, list): + if isinstance(value[0], AttrDict): + ret_str.append('{}:'.format(key)) + for item in value: + # Treat as AttrDict above. + child_ret_str = item.__repr__().split('\n') + for item in child_ret_str: + ret_str.append(' ' + item) + else: + ret_str.append('{}: {}'.format(key, value)) + else: + ret_str.append('{}: {}'.format(key, value)) + return '\n'.join(ret_str) + + +class Config(AttrDict): + r"""Configuration class. This should include every human specifiable + hyperparameter values for your training.""" + + def __init__(self, filename=None, verbose=False): + super(Config, self).__init__() + self.source_filename = filename + # Set default parameters. + # Logging. + large_number = 1000000000 + self.snapshot_save_iter = large_number + self.snapshot_save_epoch = large_number + self.metrics_iter = None + self.metrics_epoch = None + self.snapshot_save_start_iter = 0 + self.snapshot_save_start_epoch = 0 + self.image_save_iter = large_number + self.image_display_iter = large_number + self.max_epoch = large_number + self.max_iter = large_number + self.logging_iter = 100 + self.speed_benchmark = False + + # Trainer. + self.trainer = AttrDict( + model_average_config=AttrDict(enabled=False, + beta=0.9999, + start_iteration=1000, + num_batch_norm_estimation_iterations=30, + remove_sn=True), + # model_average=False, + # model_average_beta=0.9999, + # model_average_start_iteration=1000, + # model_average_batch_norm_estimation_iteration=30, + # model_average_remove_sn=True, + image_to_tensorboard=False, + hparam_to_tensorboard=False, + distributed_data_parallel='pytorch', + distributed_data_parallel_params=AttrDict( + find_unused_parameters=False), + delay_allreduce=True, + gan_relativistic=False, + gen_step=1, + dis_step=1, + gan_decay_k=1., + gan_min_k=1., + gan_separate_topk=False, + aug_policy='', + channels_last=False, + strict_resume=True, + amp_gp=False, + amp_config=AttrDict(init_scale=65536.0, + growth_factor=2.0, + backoff_factor=0.5, + growth_interval=2000, + enabled=False)) + + # Networks. 
+ self.gen = AttrDict(type='imaginaire.generators.dummy') + self.dis = AttrDict(type='imaginaire.discriminators.dummy') + + # Optimizers. + self.gen_opt = AttrDict(type='adam', + fused_opt=False, + lr=0.0001, + adam_beta1=0.0, + adam_beta2=0.999, + eps=1e-8, + lr_policy=AttrDict(iteration_mode=False, + type='step', + step_size=large_number, + gamma=1)) + self.dis_opt = AttrDict(type='adam', + fused_opt=False, + lr=0.0001, + adam_beta1=0.0, + adam_beta2=0.999, + eps=1e-8, + lr_policy=AttrDict(iteration_mode=False, + type='step', + step_size=large_number, + gamma=1)) + # Data. + self.data = AttrDict(name='dummy', + type='imaginaire.datasets.images', + num_workers=0) + self.test_data = AttrDict(name='dummy', + type='imaginaire.datasets.images', + num_workers=0, + test=AttrDict(is_lmdb=False, + roots='', + batch_size=1)) + + +# Cudnn. + self.cudnn = AttrDict(deterministic=False, + benchmark=True) + + # Others. + self.pretrained_weight = '' + self.inference_args = AttrDict() + + # Update with given configurations. + assert os.path.exists(filename), 'File {} not exist.'.format(filename) + loader = yaml.SafeLoader + loader.add_implicit_resolver( + u'tag:yaml.org,2002:float', + re.compile(u'''^(?: + [-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)? + |[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+) + |\\.[0-9_]+(?:[eE][-+][0-9]+)? + |[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]* + |[-+]?\\.(?:inf|Inf|INF) + |\\.(?:nan|NaN|NAN))$''', re.X), + list(u'-+0123456789.')) + try: + with open(filename, 'r') as f: + cfg_dict = yaml.load(f, Loader=loader) + except EnvironmentError: + print('Please check the file with name of "%s"', filename) + recursive_update(self, cfg_dict) + + # Put common opts in both gen and dis. + if 'common' in cfg_dict: + self.common = AttrDict(**cfg_dict['common']) + self.gen.common = self.common + self.dis.common = self.common + + if verbose: + print(' imaginaire config '.center(80, '-')) + print(self.__repr__()) + print(''.center(80, '-')) + + +def rsetattr(obj, attr, val): + """Recursively find object and set value""" + pre, _, post = attr.rpartition('.') + return setattr(rgetattr(obj, pre) if pre else obj, post, val) + + +def rgetattr(obj, attr, *args): + """Recursively find object and return value""" + + def _getattr(obj, attr): + r"""Get attribute.""" + return getattr(obj, attr, *args) + + return functools.reduce(_getattr, [obj] + attr.split('.')) + + +def recursive_update(d, u): + """Recursively update AttrDict d with AttrDict u""" + for key, value in u.items(): + if isinstance(value, collections.abc.Mapping): + d.__dict__[key] = recursive_update(d.get(key, AttrDict({})), value) + elif isinstance(value, (list, tuple)): + if isinstance(value[0], dict): + d.__dict__[key] = [AttrDict(item) for item in value] + else: + d.__dict__[key] = value + else: + d.__dict__[key] = value + return d diff --git a/imaginaire/datasets/__init__.py b/imaginaire/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..13acefe2181136b1629ec31f9d122fb46bf26780 --- /dev/null +++ b/imaginaire/datasets/__init__.py @@ -0,0 +1,4 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. 
+# To view a copy of this license, check out LICENSE.md diff --git a/imaginaire/datasets/base.py b/imaginaire/datasets/base.py new file mode 100644 index 0000000000000000000000000000000000000000..c9192f78d4c3cfb65ac73751b632f897893d8288 --- /dev/null +++ b/imaginaire/datasets/base.py @@ -0,0 +1,596 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +"""All datasets are inherited from this class.""" + +import importlib +import json +import os +import pickle +from collections import OrderedDict +from functools import partial +from inspect import signature + +import numpy as np +import torch +import torch.utils.data as data +import torchvision.transforms as transforms + +from imaginaire.datasets.folder import FolderDataset +from imaginaire.datasets.lmdb import \ + IMG_EXTENSIONS, HDR_IMG_EXTENSIONS, LMDBDataset +from imaginaire.datasets.object_store import ObjectStoreDataset +from imaginaire.utils.data import \ + (VIDEO_EXTENSIONS, Augmentor, + load_from_folder, load_from_lmdb, load_from_object_store) +from imaginaire.utils.lmdb import create_metadata + + +DATASET_TYPES = ['lmdb', 'folder', 'object_store'] + + +class BaseDataset(data.Dataset): + r"""Base class for image/video datasets. + + Args: + cfg (Config object): Input config. + is_inference (bool): Training if False, else validation. + is_test (bool): Final test set after training and validation. + """ + + def __init__(self, cfg, is_inference, is_test): + super(BaseDataset, self).__init__() + + self.cfg = cfg + self.is_inference = is_inference + self.is_test = is_test + if self.is_test: + self.cfgdata = self.cfg.test_data + data_info = self.cfgdata.test + else: + self.cfgdata = self.cfg.data + if self.is_inference: + data_info = self.cfgdata.val + else: + data_info = self.cfgdata.train + self.name = self.cfgdata.name + self.lmdb_roots = data_info.roots + self.dataset_type = getattr(data_info, 'dataset_type', None) + self.cache = getattr(self.cfgdata, 'cache', None) + self.interpolator = getattr(self.cfgdata, 'interpolator', "INTER_LINEAR") + + # Get AWS secret keys. + if self.dataset_type == 'object_store': + assert hasattr(cfg, 'aws_credentials_file') + self.aws_credentials_file = cfg.aws_credentials_file + + # Legacy lmdb/folder only support. + if self.dataset_type is None: + self.dataset_is_lmdb = getattr(data_info, 'is_lmdb', False) + if self.dataset_is_lmdb: + self.dataset_type = 'lmdb' + else: + self.dataset_type = 'folder' + # Legacy support ends. + + assert self.dataset_type in DATASET_TYPES + if self.dataset_type == 'lmdb': + # Add handle to function to load data from LMDB. + self.load_from_dataset = load_from_lmdb + elif self.dataset_type == 'folder': + # For some unpaired experiments, we would like the dataset to be presented in a paired way + + if hasattr(self.cfgdata, 'paired') is False: + self.cfgdata.paired = self.paired + # Add handle to function to load data from folder. + self.load_from_dataset = load_from_folder + # Create metadata for folders. 
+ print('Creating metadata') + all_filenames, all_metadata = [], [] + if self.is_test: + cfg.data_backup = cfg.data + cfg.data = cfg.test_data + for root in self.lmdb_roots: + filenames, metadata = create_metadata( + data_root=root, cfg=cfg, paired=self.cfgdata['paired']) + all_filenames.append(filenames) + all_metadata.append(metadata) + if self.is_test: + cfg.data = cfg.data_backup + elif self.dataset_type == 'object_store': + # Add handle to function to load data from AWS S3. + self.load_from_dataset = load_from_object_store + + # Get the types of data stored in dataset, and their extensions. + self.data_types = [] # Names of data types. + self.dataset_data_types = [] # These data types are in the dataset. + self.image_data_types = [] # These types are images. + self.hdr_image_data_types = [] # These types are HDR images. + self.normalize = {} # Does this data type need normalization? + self.extensions = {} # What is this data type's file extension. + self.is_mask = {} # Whether this data type is discrete masks? + self.num_channels = {} # How many channels does this data type have? + self.pre_aug_ops = {} # Ops on data type before augmentation. + self.post_aug_ops = {} # Ops on data type after augmentation. + + # Extract info from data types. + for data_type in self.cfgdata.input_types: + name = list(data_type.keys()) + assert len(name) == 1 + name = name[0] + info = data_type[name] + + if 'ext' not in info: + info['ext'] = None + if 'normalize' not in info: + info['normalize'] = False + if 'is_mask' not in info: + info['is_mask'] = False + if 'pre_aug_ops' not in info: + info['pre_aug_ops'] = 'None' + if 'post_aug_ops' not in info: + info['post_aug_ops'] = 'None' + if 'computed_on_the_fly' not in info: + info['computed_on_the_fly'] = False + if 'num_channels' not in info: + info['num_channels'] = None + + self.data_types.append(name) + if not info['computed_on_the_fly']: + self.dataset_data_types.append(name) + + self.extensions[name] = info['ext'] + self.normalize[name] = info['normalize'] + self.num_channels[name] = info['num_channels'] + self.pre_aug_ops[name] = [op.strip() for op in + info['pre_aug_ops'].split(',')] + self.post_aug_ops[name] = [op.strip() for op in + info['post_aug_ops'].split(',')] + self.is_mask[name] = info['is_mask'] + if info['ext'] is not None and (info['ext'] in IMG_EXTENSIONS or info['ext'] in VIDEO_EXTENSIONS): + self.image_data_types.append(name) + if info['ext'] is not None and info['ext'] in HDR_IMG_EXTENSIONS: + self.hdr_image_data_types.append(name) + + # Add some info into cfgdata for legacy support. + self.cfgdata.data_types = self.data_types + self.cfgdata.num_channels = [self.num_channels[name] + for name in self.data_types] + + # Augmentations which need full dict. + self.full_data_post_aug_ops, self.full_data_ops = [], [] + if hasattr(self.cfgdata, 'full_data_ops'): + ops = self.cfgdata.full_data_ops + self.full_data_ops.extend([op.strip() for op in ops.split(',')]) + if hasattr(self.cfgdata, 'full_data_post_aug_ops'): + ops = self.cfgdata.full_data_post_aug_ops + self.full_data_post_aug_ops.extend( + [op.strip() for op in ops.split(',')]) + + # These are the labels which will be concatenated for generator input. + self.input_labels = [] + if hasattr(self.cfgdata, 'input_labels'): + self.input_labels = self.cfgdata.input_labels + + # These are the keypoints which also need to be augmented. 
+ self.keypoint_data_types = [] + if hasattr(self.cfgdata, 'keypoint_data_types'): + self.keypoint_data_types = self.cfgdata.keypoint_data_types + + # Create augmentation operations. + aug_list = data_info.augmentations + individual_video_frame_aug_list = getattr(data_info, 'individual_video_frame_augmentations', dict()) + self.augmentor = Augmentor( + aug_list, individual_video_frame_aug_list, self.image_data_types, self.is_mask, + self.keypoint_data_types, self.interpolator) + self.augmentable_types = self.image_data_types + \ + self.keypoint_data_types + + # Create torch transformations. + self.transform = {} + for data_type in self.image_data_types: + normalize = self.normalize[data_type] + self.transform[data_type] = self._get_transform( + normalize, self.num_channels[data_type]) + + # Create torch transformations for HDR images. + for data_type in self.hdr_image_data_types: + normalize = self.normalize[data_type] + self.transform[data_type] = self._get_transform( + normalize, self.num_channels[data_type]) + + # Initialize handles. + self.sequence_lists = [] # List of sequences per dataset root. + self.lmdbs = {} # Dict for list of lmdb handles per data type. + for data_type in self.dataset_data_types: + self.lmdbs[data_type] = [] + self.dataset_probability = None + self.additional_lists = [] + + # Load each dataset. + for idx, root in enumerate(self.lmdb_roots): + if self.dataset_type == 'lmdb': + self._add_dataset(root) + elif self.dataset_type == 'folder': + self._add_dataset(root, filenames=all_filenames[idx], + metadata=all_metadata[idx]) + elif self.dataset_type == 'object_store': + self._add_dataset( + root, aws_credentials_file=self.aws_credentials_file) + + # Compute dataset statistics and create whatever self.variables required + # for the specific dataloader. + self._compute_dataset_stats() + + # Build index of data to sample. + self.mapping, self.epoch_length = self._create_mapping() + + def _create_mapping(self): + r"""Creates mapping from data sample idx to actual LMDB keys. + All children need to implement their own. + + Returns: + self.mapping (list): List of LMDB keys. + """ + raise NotImplementedError + + def _compute_dataset_stats(self): + r"""Computes required statistics about dataset. + All children need to implement their own. + """ + pass + + def __getitem__(self, index): + r"""Entry function for dataset.""" + raise NotImplementedError + + def _get_transform(self, normalize, num_channels): + r"""Convert numpy to torch tensor. + + Args: + normalize (bool): Normalize image i.e. (x - 0.5) * 2. + Goes from [0, 1] -> [-1, 1]. + Returns: + Composed list of torch transforms. + """ + transform_list = [transforms.ToTensor()] + if normalize: + transform_list.append( + transforms.Normalize((0.5, ) * num_channels, + (0.5, ) * num_channels, inplace=True)) + return transforms.Compose(transform_list) + + def _add_dataset(self, root, filenames=None, metadata=None, + aws_credentials_file=None): + r"""Adds an LMDB dataset to a list of datasets. + + Args: + root (str): Path to LMDB or folder dataset. + filenames: List of filenames for folder dataset. + metadata: Metadata for folder dataset. + aws_credentials_file: Path to file containing AWS credentials. + """ + if aws_credentials_file and self.dataset_type == 'object_store': + object_store_dataset = ObjectStoreDataset( + root, aws_credentials_file, cache=self.cache) + sequence_list = object_store_dataset.sequence_list + else: + # Get sequences associated with this dataset. 
+ if filenames is None: + list_path = 'all_filenames.json' + with open(os.path.join(root, list_path)) as fin: + sequence_list = OrderedDict(json.load(fin)) + else: + sequence_list = filenames + + additional_path = 'all_indices.json' + if os.path.exists(os.path.join(root, additional_path)): + print('Using additional list for object indices.') + with open(os.path.join(root, additional_path)) as fin: + additional_list = OrderedDict(json.load(fin)) + self.additional_lists.append(additional_list) + self.sequence_lists.append(sequence_list) + + # Get LMDB dataset handles. + for data_type in self.dataset_data_types: + if self.dataset_type == 'lmdb': + self.lmdbs[data_type].append( + LMDBDataset(os.path.join(root, data_type))) + elif self.dataset_type == 'folder': + self.lmdbs[data_type].append( + FolderDataset(os.path.join(root, data_type), metadata)) + elif self.dataset_type == 'object_store': + # All data types use the same handle. + self.lmdbs[data_type].append(object_store_dataset) + + def perform_individual_video_frame(self, data, augment_ops): + r"""Perform data augmentation on images only. + + Args: + data (dict): Keys are from data types. Values can be numpy.ndarray + or list of numpy.ndarray (image or list of images). + augment_ops (list): The augmentation operations for individual frames. + Returns: + (tuple): + - data (dict): Augmented data, with same keys as input data. + - is_flipped (bool): Flag which tells if images have been + left-right flipped. + """ + if augment_ops: + all_data = dict() + for ix, key in enumerate(data.keys()): + if ix == 0: + num = len(data[key]) + for j in range(num): + all_data['%d' % j] = dict() + for j in range(num): + all_data['%d' % j][key] = data[key][j:(j+1)] + for j in range(num): + all_data['%d' % j], _ = self.perform_augmentation( + all_data['%d' % j], paired=True, augment_ops=augment_ops) + for key in data.keys(): + tmp = [] + for j in range(num): + tmp += all_data['%d' % j][key] + data[key] = tmp + return data + + def perform_augmentation(self, data, paired, augment_ops=None): + r"""Perform data augmentation on images only. + + Args: + data (dict): Keys are from data types. Values can be numpy.ndarray + or list of numpy.ndarray (image or list of images). + paired (bool): Apply same augmentation to all input keys? + augment_ops (list): The augmentation operations. + Returns: + (tuple): + - data (dict): Augmented data, with same keys as input data. + - is_flipped (bool): Flag which tells if images have been + left-right flipped. + """ + aug_inputs = {} + for data_type in self.augmentable_types: + aug_inputs[data_type] = data[data_type] + + augmented, is_flipped = self.augmentor.perform_augmentation( + aug_inputs, paired=paired, augment_ops=augment_ops) + + for data_type in self.augmentable_types: + data[data_type] = augmented[data_type] + + return data, is_flipped + + def flip_hdr(self, data, is_flipped=False): + r"""Flip hdr images. + + Args: + data (dict): Keys are from data types. Values can be numpy.ndarray + or list of numpy.ndarray (image or list of images). + is_flipped (bool): Applying left-right flip to the hdr images + Returns: + (tuple): + - data (dict): Augmented data, with same keys as input data. + """ + if is_flipped is False: + return data + + for data_type in self.hdr_image_data_types: + # print('Length of data: {}'.format(len(data[data_type]))) + data[data_type][0] = data[data_type][0][:, ::-1, :].copy() + return data + + def to_tensor(self, data): + r"""Convert all images to tensor. 
+ + Args: + data (dict): Dict containing data_type as key, with each value + as a list of numpy.ndarrays. + Returns: + data (dict): Dict containing data_type as key, with each value + as a list of torch.Tensors. + """ + for data_type in self.image_data_types: + for idx in range(len(data[data_type])): + if data[data_type][idx].dtype == np.uint16: + data[data_type][idx] = data[data_type][idx].astype( + np.float32) + data[data_type][idx] = self.transform[data_type]( + data[data_type][idx]) + for data_type in self.hdr_image_data_types: + for idx in range(len(data[data_type])): + data[data_type][idx] = self.transform[data_type]( + data[data_type][idx]) + return data + + def apply_ops(self, data, op_dict, full_data=False): + r"""Apply any ops from op_dict to data types. + + Args: + data (dict): Dict containing data_type as key, with each value + as a list of numpy.ndarrays. + op_dict (dict): Dict containing data_type as key, with each value + containing string of operations to apply. + full_data (bool): Do these ops require access to the full data? + Returns: + data (dict): Dict containing data_type as key, with each value + modified by the op if any. + """ + if full_data: + # op needs entire data dict. + for op in op_dict: + if op == 'None': + continue + op, op_type = self.get_op(op) + assert op_type == 'full_data' + data = op(data) + else: + # op per data type. + if not op_dict: + return data + for data_type in data: + for op in op_dict[data_type]: + if op == 'None': + continue + op, op_type = self.get_op(op) + data[data_type] = op(data[data_type]) + + if op_type == 'vis': + # We have converted this data type to an image. Enter it + # in self.image_data_types and give it a torch + # transform. + if data_type not in self.image_data_types: + self.image_data_types.append(data_type) + normalize = self.normalize[data_type] + num_channels = self.num_channels[data_type] + self.transform[data_type] = \ + self._get_transform(normalize, num_channels) + elif op_type == 'convert': + continue + elif op_type is None: + continue + else: + raise NotImplementedError + return data + + def get_op(self, op): + r"""Get function to apply for specific op. + + Args: + op (str): Name of the op. + Returns: + function handle. 
+ """ + def list_to_tensor(data): + r"""Convert list of numeric values to tensor.""" + assert isinstance(data, list) + return torch.from_numpy(np.array(data, dtype=np.float32)) + + def decode_json_list(data): + r"""Decode list of strings in json to objects.""" + assert isinstance(data, list) + return [json.loads(item) for item in data] + + def decode_pkl_list(data): + r"""Decode list of pickled strings to objects.""" + assert isinstance(data, list) + return [pickle.loads(item) for item in data] + + def list_to_numpy(data): + r"""Convert list of numeric values to numpy array.""" + assert isinstance(data, list) + return np.array(data) + + def l2_normalize(data): + r"""L2 normalization.""" + assert isinstance(data, torch.Tensor) + import torch.nn.functional as F + return F.normalize(data, dim=1) + + if op == 'to_tensor': + return list_to_tensor, None + elif op == 'decode_json': + return decode_json_list, None + elif op == 'decode_pkl': + return decode_pkl_list, None + elif op == 'to_numpy': + return list_to_numpy, None + elif op == 'l2_norm': + return l2_normalize, None + elif '::' in op: + parts = op.split('::') + if len(parts) == 2: + module, function = parts + module = importlib.import_module(module) + function = getattr(module, function) + sig = signature(function) + num_params = len(sig.parameters) + assert num_params in [3, 4], \ + 'Full data functions take in (cfgdata, is_inference, ' \ + 'full_data) or (cfgdata, is_inference, self, full_data) ' \ + 'as input.' + if num_params == 3: + function = partial( + function, self.cfgdata, self.is_inference) + elif num_params == 4: + function = partial( + function, self.cfgdata, self.is_inference, self) + function_type = 'full_data' + elif len(parts) == 3: + function_type, module, function = parts + module = importlib.import_module(module) + + # Get function inputs, if provided. + partial_fn = False + if '(' in function and ')' in function: + partial_fn = True + function, params = self._get_fn_params(function) + + function = getattr(module, function) + + # Create partial function. + if partial_fn: + function = partial(function, **params) + + # Get function signature. + sig = signature(function) + num_params = 0 + for param in sig.parameters.values(): + if param.kind == param.POSITIONAL_OR_KEYWORD: + num_params += 1 + + if function_type == 'vis': + if num_params != 9: + raise ValueError( + 'vis function type needs to take ' + + '(resize_h, resize_w, crop_h, crop_w, ' + + 'original_h, original_w, is_flipped, cfgdata, ' + + 'data) as input.') + function = partial(function, + self.augmentor.resize_h, + self.augmentor.resize_w, + self.augmentor.crop_h, + self.augmentor.crop_w, + self.augmentor.original_h, + self.augmentor.original_w, + self.augmentor.is_flipped, + self.cfgdata) + elif function_type == 'convert': + if num_params != 1: + raise ValueError( + 'convert function type needs to take ' + + '(data) as input.') + else: + raise ValueError('Unknown op: %s' % (op)) + else: + raise ValueError('Unknown op: %s' % (op)) + return function, function_type + else: + raise ValueError('Unknown op: %s' % (op)) + + def _get_fn_params(self, function_string): + r"""Find key-value inputs to function from string definition. + + Args: + function_string (str): String with function name and args. e.g. + my_function(a=10, b=20). + Returns: + function (str): Name of function. + params (dict): Key-value params for function. 
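+
+        Note: as parsed below, multiple key-value pairs are separated by ':'
+        rather than ',', e.g. (illustrative) 'my_function(a=10:b=cat)'
+        -> ('my_function', {'a': 10.0, 'b': 'cat'}); values that cannot be
+        cast to float are kept as strings.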
+ """ + start = function_string.find('(') + end = function_string.find(')') + function = function_string[:start] + params_str = function_string[start+1:end] + params = {} + for item in params_str.split(':'): + key, value = item.split('=') + try: + params[key] = float(value) + except: # noqa + params[key] = value + return function, params + + def __len__(self): + return self.epoch_length diff --git a/imaginaire/datasets/cache.py b/imaginaire/datasets/cache.py new file mode 100644 index 0000000000000000000000000000000000000000..6c28752dc290c9cbf15ffcf6ca2093415082f93e --- /dev/null +++ b/imaginaire/datasets/cache.py @@ -0,0 +1,40 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +import diskcache + +""" +INFO: +Cache objects are thread-safe and may be shared between threads. +Two Cache objects may also reference the same directory from separate +threads or processes. In this way, they are also process-safe and support +cross-process communication. +""" + + +class Cache(object): + r"""This creates an on disk cache, which saves files as bytes. + Args: + root (str): Path to the cache dir. + size_MB (float): Size of cache in MB. + """ + + def __init__(self, root, size_GB): + self.root = root + self.size_limit_B = size_GB * 1024 * 1024 * 1024 + self.cache = diskcache.Cache(root, size_limit=self.size_limit_B) + print('Created cache of max size %d GB at %s' % + (size_GB, self.cache.directory)) + + def read(self, key): + if key in self.cache: + return self.cache[key] + return False + + def write(self, key, value): + try: + self.cache[key] = value + except Exception as e: # noqa + print(e) + return False diff --git a/imaginaire/datasets/dummy.py b/imaginaire/datasets/dummy.py new file mode 100644 index 0000000000000000000000000000000000000000..9783eb3f38007652b84dde33f2da9202491686e1 --- /dev/null +++ b/imaginaire/datasets/dummy.py @@ -0,0 +1,18 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +import torch + + +class Dataset(torch.utils.data.Dataset): + r"""Dummy dataset, returns nothing.""" + + def __init__(self, cfg, is_inference=False, is_test=False): + super(Dataset, self).__init__() + + def __getitem__(self, index): + return {} + + def __len__(self): + return 65535 diff --git a/imaginaire/datasets/folder.py b/imaginaire/datasets/folder.py new file mode 100644 index 0000000000000000000000000000000000000000..fd3fcc679834044ac9bf11afe81a9e9fe8697aa8 --- /dev/null +++ b/imaginaire/datasets/folder.py @@ -0,0 +1,86 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +import os + +import cv2 +import numpy as np +import torch.utils.data as data +from PIL import Image + +from imaginaire.utils.data import IMG_EXTENSIONS, HDR_IMG_EXTENSIONS +import imageio + + +class FolderDataset(data.Dataset): + r"""This deals with opening, and reading from an Folder dataset. + + Args: + root (str): Path to the folder. + metadata (dict): Containing extensions. + """ + + def __init__(self, root, metadata): + self.root = os.path.expanduser(root) + self.extensions = metadata + + print('Folder at %s opened.' 
% (root)) + + def getitem_by_path(self, path, data_type): + r"""Load data item stored for key = path. + + Args: + path (str): Key into Folder dataset. + data_type (str): Key into self.extensions e.g. data/data_segmaps/... + Returns: + img (PIL.Image) or buf (str): Contents of file for this key. + """ + # Figure out decoding params. + ext = self.extensions[data_type] + is_image = False + is_hdr = False + if ext in IMG_EXTENSIONS: + is_image = True + if 'tif' in ext: + dtype, mode = np.uint16, -1 + elif 'JPEG' in ext or 'JPG' in ext \ + or 'jpeg' in ext or 'jpg' in ext: + dtype, mode = np.uint8, 3 + else: + dtype, mode = np.uint8, -1 + elif ext in HDR_IMG_EXTENSIONS: + is_hdr = True + else: + is_image = False + + # Get value from key. + filepath = os.path.join(self.root, path.decode() + '.' + ext) + assert os.path.exists(filepath), '%s does not exist' % (filepath) + with open(filepath, 'rb') as f: + buf = f.read() + + # Decode and return. + if is_image: + try: + img = cv2.imdecode(np.fromstring(buf, dtype=dtype), mode) + except Exception: + print(path) + # BGR to RGB if 3 channels. + if img.ndim == 3 and img.shape[-1] == 3: + img = img[:, :, ::-1] + img = Image.fromarray(img) + return img + elif is_hdr: + try: + imageio.plugins.freeimage.download() + img = imageio.imread(buf) + except Exception: + print(path) + return img # Return a numpy array + else: + return buf + + def __len__(self): + r"""Return number of keys in Folder dataset.""" + return self.length diff --git a/imaginaire/datasets/images.py b/imaginaire/datasets/images.py new file mode 100644 index 0000000000000000000000000000000000000000..943752be11d823025f33956dcdcaf6a35c1fb899 --- /dev/null +++ b/imaginaire/datasets/images.py @@ -0,0 +1,168 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +import random + +from imaginaire.datasets.base import BaseDataset + + +class Dataset(BaseDataset): + r"""Image dataset for use in class conditional GAN. + + Args: + cfg (Config): Loaded config object. + is_inference (bool): In train or inference mode? + """ + + def __init__(self, cfg, is_inference=False, is_test=False): + self.paired = False + super(Dataset, self).__init__(cfg, is_inference, is_test) + self.num_classes = len(self.class_name_to_idx['images']) + self.sample_class_idx = None + + def set_sample_class_idx(self, class_idx): + r"""Set sample class idx. This is not used in this class... + + Args: + class_idx (int): Which class idx to sample from. + """ + self.sample_class_idx = class_idx + self.epoch_length = \ + max([len(lmdb_keys) for _, lmdb_keys in self.mapping.items()]) + + def _create_mapping(self): + r"""Creates mapping from idx to key in LMDB. + + Returns: + (tuple): + - self.mapping (dict): Dict with data type as key mapping idx to + LMDB key. + - self.epoch_length (int): Number of samples in an epoch. 
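+
+        Each entry of self.mapping[data_type] is a dict of the form
+        (values illustrative):
+            {'lmdb_root': ..., 'lmdb_idx': 0, 'sequence_name': 'cat/seq0',
+             'filename': 'frame_000001', 'class_name': 'cat', 'class_idx': 3}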
+ """ + idx_to_key, class_names = {}, {} + for lmdb_idx, sequence_list in enumerate(self.sequence_lists): + for data_type, data_type_sequence_list in sequence_list.items(): + class_names[data_type] = [] + if data_type not in idx_to_key: + idx_to_key[data_type] = [] + for sequence_name, filenames in data_type_sequence_list.items(): + class_name = sequence_name.split('/')[0] + for filename in filenames: + idx_to_key[data_type].append({ + 'lmdb_root': self.lmdb_roots[lmdb_idx], + 'lmdb_idx': lmdb_idx, + 'sequence_name': sequence_name, + 'filename': filename, + 'class_name': class_name + }) + class_names[data_type].append(class_name) + self.mapping = idx_to_key + self.epoch_length = max([len(lmdb_keys) + for _, lmdb_keys in self.mapping.items()]) + + # Create mapping from class name to class idx. + self.class_name_to_idx = {} + for data_type, class_names_data_type in class_names.items(): + self.class_name_to_idx[data_type] = {} + class_names_data_type = sorted(list(set(class_names_data_type))) + for class_idx, class_name in enumerate(class_names_data_type): + self.class_name_to_idx[data_type][class_name] = class_idx + + # Add class idx to mapping. + for data_type in self.mapping: + for key in self.mapping[data_type]: + key['class_idx'] = \ + self.class_name_to_idx[data_type][key['class_name']] + + # Create a mapping from index to lmdb key for each class. + idx_to_key_class = {} + for data_type in self.mapping: + idx_to_key_class[data_type] = {} + for class_idx, class_name in enumerate(class_names[data_type]): + idx_to_key_class[data_type][class_idx] = [] + for key in self.mapping[data_type]: + idx_to_key_class[data_type][key['class_idx']].append(key) + self.mapping_class = idx_to_key_class + + return self.mapping, self.epoch_length + + def _sample_keys(self, index): + r"""Gets files to load for this sample. + + Args: + index (int): Index in [0, len(dataset)]. + Returns: + keys (dict): Each key of this dict is a data type. + - lmdb_key (dict): + - lmdb_idx (int): Chosen LMDB dataset root. + - sequence_name (str): Chosen sequence in chosen dataset. + - filename (str): Chosen filename in chosen sequence. + """ + + keys = {} + if self.is_inference: # evaluation mode + lmdb_keys = self.mapping['images'] + keys['images'] = lmdb_keys[index % len(lmdb_keys)] + else: + lmdb_keys = self.mapping['images'] + keys['images'] = random.choice(lmdb_keys) + return keys + + def __getitem__(self, index): + r"""Gets selected files. + + Args: + index (int): Index into dataset. + concat (bool): Concatenate all items in labels? + Returns: + data (dict): Dict with all chosen data_types. + """ + # Select a sample from the available data. + keys_per_data_type = self._sample_keys(index) + + # Get class idx into a list. + class_idxs = [] + for data_type in keys_per_data_type: + class_idxs.append(keys_per_data_type[data_type]['class_idx']) + + # Get keys and lmdbs. + keys, lmdbs = {}, {} + for data_type in self.dataset_data_types: + # Unpack keys. + lmdb_idx = keys_per_data_type[data_type]['lmdb_idx'] + sequence_name = keys_per_data_type[data_type]['sequence_name'] + filename = keys_per_data_type[data_type]['filename'] + keys[data_type] = '%s/%s' % (sequence_name, filename) + lmdbs[data_type] = self.lmdbs[data_type][lmdb_idx] + + # Load all data for this index. + data = self.load_from_dataset(keys, lmdbs) + + # Apply ops pre augmentation. + data = self.apply_ops(data, self.pre_aug_ops) + + # Do augmentations for images. 
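+        # paired=False lets the augmentor transform each data type
+        # independently; the returned flip flag is packaged into the output
+        # dict below as data['is_flipped'].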
+ data, is_flipped = self.perform_augmentation(data, paired=False, augment_ops=self.augmentor.augment_ops) + + # Apply ops post augmentation. + data = self.apply_ops(data, self.post_aug_ops) + data = self.apply_ops(data, self.full_data_post_aug_ops, + full_data=True) + + # Convert images to tensor. + for data_type in self.image_data_types: + for idx in range(len(data[data_type])): + data[data_type][idx] = \ + data[data_type][idx][:, :, :self.num_channels[data_type]] + data = self.to_tensor(data) + + # Remove any extra dimensions. + for data_type in self.image_data_types: + data[data_type] = data[data_type][0] + + # Package output. + data['is_flipped'] = is_flipped + data['key'] = keys_per_data_type + data['labels'] = class_idxs[0] + return data diff --git a/imaginaire/datasets/lmdb.py b/imaginaire/datasets/lmdb.py new file mode 100644 index 0000000000000000000000000000000000000000..136642c1e624b886b05aaffe46f14694b0eaa29a --- /dev/null +++ b/imaginaire/datasets/lmdb.py @@ -0,0 +1,92 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +import json +import os + +import cv2 +import lmdb +import numpy as np +import torch.utils.data as data +from PIL import Image + +from imaginaire.utils.data import IMG_EXTENSIONS, HDR_IMG_EXTENSIONS +from imaginaire.utils.distributed import master_only_print as print +import imageio + + +class LMDBDataset(data.Dataset): + r"""This deals with opening, and reading from an LMDB dataset. + Args: + root (str): Path to the LMDB file. + """ + + def __init__(self, root): + self.root = os.path.expanduser(root) + self.env = lmdb.open(root, max_readers=126, readonly=True, lock=False, + readahead=False, meminit=False) + with self.env.begin(write=False) as txn: + self.length = txn.stat()['entries'] + + # Read metadata. + with open(os.path.join(self.root, '..', 'metadata.json')) as fin: + self.extensions = json.load(fin) + + print('LMDB file at %s opened.' % (root)) + + def getitem_by_path(self, path, data_type): + r"""Load data item stored for key = path. + + Args: + path (str): Key into LMDB dataset. + data_type (str): Key into self.extensions e.g. data/data_segmaps/... + Returns: + img (PIL.Image) or buf (str): Contents of LMDB value for this key. + """ + # Figure out decoding params. + ext = self.extensions[data_type] + is_image = False + is_hdr = False + if ext in IMG_EXTENSIONS: + is_image = True + if 'tif' in ext: + dtype, mode = np.uint16, -1 + elif 'JPEG' in ext or 'JPG' in ext \ + or 'jpeg' in ext or 'jpg' in ext: + dtype, mode = np.uint8, 3 + else: + dtype, mode = np.uint8, -1 + elif ext in HDR_IMG_EXTENSIONS: + is_hdr = True + else: + is_image = False + + # Get value from key. + with self.env.begin(write=False) as txn: + buf = txn.get(path) + + # Decode and return. + if is_image: + try: + img = cv2.imdecode(np.fromstring(buf, dtype=dtype), mode) + except Exception: + print(path) + # BGR to RGB if 3 channels. 
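+            # (cv2.imdecode yields BGR channel order; PIL expects RGB.)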
+ if img.ndim == 3 and img.shape[-1] == 3: + img = img[:, :, ::-1] + img = Image.fromarray(img) + return img + elif is_hdr: + try: + imageio.plugins.freeimage.download() + img = imageio.imread(buf) + except Exception: + print(path) + return img # Return a numpy array + else: + return buf + + def __len__(self): + r"""Return number of keys in LMDB dataset.""" + return self.length diff --git a/imaginaire/datasets/object_store.py b/imaginaire/datasets/object_store.py new file mode 100644 index 0000000000000000000000000000000000000000..2dd4f2d765def17ed52d5a442cc3bb87d31b9ded --- /dev/null +++ b/imaginaire/datasets/object_store.py @@ -0,0 +1,142 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +import io +import json + +# import cv2 +import boto3 +from botocore.config import Config +import numpy as np +import torch.utils.data as data +from PIL import Image +import imageio +from botocore.exceptions import ClientError + +from imaginaire.datasets.cache import Cache +from imaginaire.utils.data import IMG_EXTENSIONS, HDR_IMG_EXTENSIONS + +Image.MAX_IMAGE_PIXELS = None + + +class ObjectStoreDataset(data.Dataset): + r"""This deals with opening, and reading from an AWS S3 bucket. + Args: + + root (str): Path to the AWS S3 bucket. + aws_credentials_file (str): Path to file containing AWS credentials. + data_type (str): Which data type should this dataset load? + """ + + def __init__(self, root, aws_credentials_file, data_type='', cache=None): + # Cache. + self.cache = False + if cache is not None: + # raise NotImplementedError + self.cache = Cache(cache.root, cache.size_GB) + + # Get bucket info, and keys to info about dataset. + with open(aws_credentials_file) as fin: + self.credentials = json.load(fin) + + parts = root.split('/') + self.bucket = parts[0] + self.all_filenames_key = '/'.join(parts[1:]) + '/all_filenames.json' + self.metadata_key = '/'.join(parts[1:]) + '/metadata.json' + + # Get list of filenames. + filename_info = self._get_object(self.all_filenames_key) + self.sequence_list = json.loads(filename_info.decode('utf-8')) + + # Get length. + length = 0 + for _, value in self.sequence_list.items(): + length += len(value) + self.length = length + + # Read metadata. + metadata_info = self._get_object(self.metadata_key) + self.extensions = json.loads(metadata_info.decode('utf-8')) + self.data_type = data_type + + print('AWS S3 bucket at %s opened.' % (root + '/' + self.data_type)) + + def _get_object(self, key): + r"""Download object from bucket. + + Args: + key (str): Key inside bucket. + """ + # Look up value in cache. + object_content = self.cache.read(key) if self.cache else False + if not object_content: + # Either no cache used or key not found in cache. + config = Config(connect_timeout=30, + signature_version="s3", + retries={"max_attempts": 999999}) + s3 = boto3.client('s3', **self.credentials, config=config) + try: + s3_response_object = s3.get_object(Bucket=self.bucket, Key=key) + object_content = s3_response_object['Body'].read() + except Exception as e: + print('%s not found' % (key)) + print(e) + # Save content to cache. + if self.cache: + self.cache.write(key, object_content) + return object_content + + def getitem_by_path(self, path, data_type): + r"""Load data item stored for key = path. + + Args: + path (str): Path into AWS S3 bucket, without data_type prefix. + data_type (str): Key into self.extensions e.g. 
data/data_segmaps/... + Returns: + img (PIL.Image) or buf (str): Contents of LMDB value for this key. + """ + # Figure out decoding params. + ext = self.extensions[data_type] + is_image = False + is_hdr = False + parts = path.split('/') + key = parts[0] + '/' + data_type + '/' + '/'.join(parts[1:]) + '.' + ext + if ext in IMG_EXTENSIONS: + is_image = True + if 'tif' in ext: + _, mode = np.uint16, -1 + elif 'JPEG' in ext or 'JPG' in ext \ + or 'jpeg' in ext or 'jpg' in ext: + _, mode = np.uint8, 3 + else: + _, mode = np.uint8, -1 + elif ext in HDR_IMG_EXTENSIONS: + is_hdr = True + else: + is_image = False + + # Get value from key. + buf = self._get_object(key) + + # Decode and return. + if is_image: + # This is totally a hack. + # We should have a better way to handle grayscale images. + img = Image.open(io.BytesIO(buf)) + if mode == 3: + img = img.convert('RGB') + return img + elif is_hdr: + try: + imageio.plugins.freeimage.download() + img = imageio.imread(buf) + except Exception: + print(path) + return img # Return a numpy array + else: + return buf + + def __len__(self): + r"""Return number of keys in LMDB dataset.""" + return self.length diff --git a/imaginaire/datasets/paired_few_shot_videos.py b/imaginaire/datasets/paired_few_shot_videos.py new file mode 100644 index 0000000000000000000000000000000000000000..a0b311bc36fbc05aceaf552a47d3d89b118728be --- /dev/null +++ b/imaginaire/datasets/paired_few_shot_videos.py @@ -0,0 +1,308 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +import copy +import random +import torch + +from imaginaire.datasets.paired_videos import Dataset as VideoDataset +from imaginaire.model_utils.fs_vid2vid import select_object +from imaginaire.utils.distributed import master_only_print as print + + +class Dataset(VideoDataset): + r"""Paired video dataset for use in few-shot vid2vid. + + Args: + cfg (Config): Loaded config object. + is_inference (bool): In train or inference mode? + sequence_length (int): What sequence of images to provide? + few_shot_K (int): How many images to provide for few-shot? + """ + + def __init__(self, cfg, is_inference=False, sequence_length=None, + few_shot_K=None, is_test=False): + self.paired = True + # Get initial few shot K. + if few_shot_K is None: + self.few_shot_K = cfg.data.initial_few_shot_K + else: + self.few_shot_K = few_shot_K + # Initialize. + super(Dataset, self).__init__( + cfg, is_inference, sequence_length=sequence_length, is_test=is_test) + + def set_inference_sequence_idx(self, index, k_shot_index, + k_shot_frame_index): + r"""Get frames from this sequence during inference. + + Args: + index (int): Index of inference sequence. + k_shot_index (int): Index of sequence from which k_shot is sampled. + k_shot_frame_index (int): Index of frame to sample. + """ + assert self.is_inference + assert index < len(self.mapping) + assert k_shot_index < len(self.mapping) + assert k_shot_frame_index < len(self.mapping[k_shot_index]) + + self.inference_sequence_idx = index + self.inference_k_shot_sequence_index = k_shot_index + self.inference_k_shot_frame_index = k_shot_frame_index + self.epoch_length = len( + self.mapping[self.inference_sequence_idx]['filenames']) + + def set_sequence_length(self, sequence_length, few_shot_K=None): + r"""Set the length of sequence you want as output from dataloader. + + Args: + sequence_length (int): Length of output sequences. 
+ few_shot_K (int): Number of few-shot frames. + """ + if few_shot_K is None: + few_shot_K = self.few_shot_K + assert isinstance(sequence_length, int) + assert isinstance(few_shot_K, int) + if (sequence_length + few_shot_K) > self.sequence_length_max: + error_message = \ + 'Requested sequence length (%d) ' % (sequence_length) + \ + '+ few shot K (%d) > ' % (few_shot_K) + \ + 'max sequence length (%d). ' % (self.sequence_length_max) + print(error_message) + sequence_length = self.sequence_length_max - few_shot_K + print('Reduced sequence length to %s' % (sequence_length)) + self.sequence_length = sequence_length + self.few_shot_K = few_shot_K + # Recalculate mapping as some sequences might no longer be useful. + self.mapping, self.epoch_length = self._create_mapping() + print('Epoch length:', self.epoch_length) + + def _create_mapping(self): + r"""Creates mapping from idx to key in LMDB. + + Returns: + (tuple): + - self.mapping (dict): Dict of seq_len to list of sequences. + - self.epoch_length (int): Number of samples in an epoch. + """ + # Create dict mapping length to sequence. + length_to_key, num_selected_seq = {}, 0 + has_additional_lists = len(self.additional_lists) > 0 + for lmdb_idx, sequence_list in enumerate(self.sequence_lists): + for sequence_name, filenames in sequence_list.items(): + if len(filenames) >= (self.sequence_length + self.few_shot_K): + if len(filenames) not in length_to_key: + length_to_key[len(filenames)] = [] + if has_additional_lists: + obj_indices = self.additional_lists[lmdb_idx][ + sequence_name] + else: + obj_indices = [0 for _ in range(len(filenames))] + length_to_key[len(filenames)].append({ + 'lmdb_root': self.lmdb_roots[lmdb_idx], + 'lmdb_idx': lmdb_idx, + 'sequence_name': sequence_name, + 'filenames': filenames, + 'obj_indices': obj_indices, + }) + num_selected_seq += 1 + self.mapping = length_to_key + self.epoch_length = num_selected_seq + + # At inference time, we want to use all sequences, + # irrespective of length. + if self.is_inference: + sequence_list = [] + for key, sequences in self.mapping.items(): + sequence_list.extend(sequences) + self.mapping = sequence_list + + return self.mapping, self.epoch_length + + def _sample_keys(self, index): + r"""Gets files to load for this sample. + + Args: + index (int): Index in [0, len(dataset)]. + Returns: + key (dict): + - lmdb_idx (int): Chosen LMDB dataset root. + - sequence_name (str): Chosen sequence in chosen dataset. + - filenames (list of str): Chosen filenames in chosen sequence. + """ + if self.is_inference: + assert index < self.epoch_length + chosen_sequence = self.mapping[self.inference_sequence_idx] + chosen_filenames = [chosen_sequence['filenames'][index]] + chosen_obj_indices = [chosen_sequence['obj_indices'][index]] + k_shot_chosen_sequence = self.mapping[ + self.inference_k_shot_sequence_index] + k_shot_chosen_filenames = [k_shot_chosen_sequence['filenames'][ + self.inference_k_shot_frame_index]] + k_shot_chosen_obj_indices = [k_shot_chosen_sequence['obj_indices'][ + self.inference_k_shot_frame_index]] + # Prepare few shot key. + few_shot_key = copy.deepcopy(k_shot_chosen_sequence) + few_shot_key['filenames'] = k_shot_chosen_filenames + few_shot_key['obj_indices'] = k_shot_chosen_obj_indices + else: + # Pick a time step for temporal augmentation. + time_step = random.randint(1, self.augmentor.max_time_step) + required_sequence_length = 1 + \ + (self.sequence_length - 1) * time_step + + # If step is too large, default to step size of 1. 
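+            # (The K few-shot frames come from the same sequence, so the
+            # strided clip plus K extra frames must fit in the longest
+            # available sequence.)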
+ if required_sequence_length + self.few_shot_K > \ + self.sequence_length_max: + required_sequence_length = self.sequence_length + time_step = 1 + + # Find valid sequences. + valid_sequences = [] + for sequence_length, sequences in self.mapping.items(): + if sequence_length >= required_sequence_length + \ + self.few_shot_K: + valid_sequences.extend(sequences) + + # Pick a sequence. + chosen_sequence = random.choice(valid_sequences) + + # Choose filenames. + max_start_idx = len(chosen_sequence['filenames']) - \ + required_sequence_length + start_idx = random.randint(0, max_start_idx) + end_idx = start_idx + required_sequence_length + chosen_filenames = chosen_sequence['filenames'][ + start_idx:end_idx:time_step] + chosen_obj_indices = chosen_sequence['obj_indices'][ + start_idx:end_idx:time_step] + + # Find the K few shot filenames. + valid_range = list(range(start_idx)) + \ + list(range(end_idx, len(chosen_sequence['filenames']))) + k_shot_chosen = sorted(random.sample(valid_range, self.few_shot_K)) + k_shot_chosen_filenames = [chosen_sequence['filenames'][idx] + for idx in k_shot_chosen] + k_shot_chosen_obj_indices = [chosen_sequence['obj_indices'][idx] + for idx in k_shot_chosen] + assert not (set(chosen_filenames) & set(k_shot_chosen_filenames)) + + assert len(chosen_filenames) == self.sequence_length + assert len(k_shot_chosen_filenames) == self.few_shot_K + + # Prepare few shot key. + few_shot_key = copy.deepcopy(chosen_sequence) + few_shot_key['filenames'] = k_shot_chosen_filenames + few_shot_key['obj_indices'] = k_shot_chosen_obj_indices + + # Prepre output key. + key = copy.deepcopy(chosen_sequence) + key['filenames'] = chosen_filenames + key['obj_indices'] = chosen_obj_indices + return key, few_shot_key + + def _prepare_data(self, keys): + r"""Load data and perform augmentation. + + Args: + keys (dict): Key into LMDB/folder dataset for this item. + Returns: + data (dict): Dict with all chosen data_types. + """ + # Unpack keys. + lmdb_idx = keys['lmdb_idx'] + sequence_name = keys['sequence_name'] + filenames = keys['filenames'] + obj_indices = keys['obj_indices'] + + # Get key and lmdbs. + keys, lmdbs = {}, {} + for data_type in self.dataset_data_types: + keys[data_type] = self._create_sequence_keys( + sequence_name, filenames) + lmdbs[data_type] = self.lmdbs[data_type][lmdb_idx] + + # Load all data for this index. + data = self.load_from_dataset(keys, lmdbs) + + # Apply ops pre augmentation. + data = self.apply_ops(data, self.pre_aug_ops) + + # Select the object in data using the object indices. + data = select_object(data, obj_indices) + + # Do augmentations for images. + data, is_flipped = self.perform_augmentation(data, paired=True, augment_ops=self.augmentor.augment_ops) + + # Create copy of keypoint data types before post aug. + # kp_data = {} + # for data_type in self.keypoint_data_types: + # new_key = data_type + '_xy' + # kp_data[new_key] = copy.deepcopy(data[data_type]) + + # Create copy of keypoint data types before post aug. + kp_data = {} + for data_type in self.keypoint_data_types: + new_key = data_type + '_xy' + kp_data[new_key] = copy.deepcopy(data[data_type]) + + # Apply ops post augmentation. + data = self.apply_ops(data, self.post_aug_ops) + + data = self.apply_ops(data, self.full_data_post_aug_ops, full_data=True) + + # Convert images to tensor. + data = self.to_tensor(data) + + # Pack the sequence of images. 
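+        # Each frame tensor gains a leading time dimension and the frames are
+        # concatenated, giving one (T, C, H, W) tensor per image data type.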
+ for data_type in self.image_data_types: + for idx in range(len(data[data_type])): + data[data_type][idx] = data[data_type][idx].unsqueeze(0) + data[data_type] = torch.cat(data[data_type], dim=0) + + # Add keypoint xy to data. + data.update(kp_data) + + data['is_flipped'] = is_flipped + data['key'] = keys + + return data + + def _getitem(self, index): + r"""Gets selected files. + + Args: + index (int): Index into dataset. + Returns: + data (dict): Dict with all chosen data_types. + """ + # Select a sample from the available data. + keys, few_shot_keys = self._sample_keys(index) + + data = self._prepare_data(keys) + few_shot_data = self._prepare_data(few_shot_keys) + + # Add few shot data into data. + for key, value in few_shot_data.items(): + data['few_shot_' + key] = few_shot_data[key] + + # Apply full data ops. + if self.is_inference: + if index == 0: + pass + elif index < self.cfg.data.num_workers: + data_0 = self._getitem(0) + if 'common_attr' in data_0: + self.common_attr = data['common_attr'] = \ + data_0['common_attr'] + else: + if hasattr(self, 'common_attr'): + data['common_attr'] = self.common_attr + + data = self.apply_ops(data, self.full_data_ops, full_data=True) + + if self.is_inference and index == 0 and 'common_attr' in data: + self.common_attr = data['common_attr'] + + return data diff --git a/imaginaire/datasets/paired_few_shot_videos_native.py b/imaginaire/datasets/paired_few_shot_videos_native.py new file mode 100644 index 0000000000000000000000000000000000000000..17bd23d8b1bd7b8d007f9a9bcf9f13ae6a16bdf0 --- /dev/null +++ b/imaginaire/datasets/paired_few_shot_videos_native.py @@ -0,0 +1,233 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +import random +import tempfile +from collections import OrderedDict +import warnings +import numpy as np +import torch +# import torchvision.io as io +import cv2 +from PIL import Image + +from imaginaire.datasets.base import BaseDataset + + +class Dataset(BaseDataset): + r"""Dataset for paired few shot videos. + + Args: + cfg (Config): Loaded config object. + is_inference (bool): In train or inference mode? + """ + + def __init__(self, cfg, is_inference=False, is_test=False): + self.paired = True + super(Dataset, self).__init__(cfg, is_inference, is_test) + self.is_video_dataset = True + self.few_shot_K = 1 + self.first_last_only = getattr(cfg.data, 'first_last_only', False) + self.sample_far_frames_more = getattr(cfg.data, 'sample_far_frames_more', False) + + def get_label_lengths(self): + r"""Get num channels of all labels to be concated. + + Returns: + label_lengths (OrderedDict): Dict mapping image data_type to num + channels. + """ + label_lengths = OrderedDict() + for data_type in self.input_labels: + data_cfg = self.cfgdata + if hasattr(data_cfg, 'one_hot_num_classes') and \ + data_type in data_cfg.one_hot_num_classes: + label_lengths[data_type] = data_cfg.one_hot_num_classes[data_type] + if getattr(data_cfg, 'use_dont_care', False): + label_lengths[data_type] += 1 + else: + label_lengths[data_type] = self.num_channels[data_type] + return label_lengths + + def num_inference_sequences(self): + r"""Number of sequences available for inference. + + Returns: + (int) + """ + assert self.is_inference + return len(self.mapping) + + def _create_mapping(self): + r"""Creates mapping from idx to key in LMDB. 
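+        One flat entry is created per video file; a known-corrupt clip is
+        filtered out below.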
+ + Returns: + (tuple): + - self.mapping (dict): Dict of seq_len to list of sequences. + - self.epoch_length (int): Number of samples in an epoch. + """ + # Create dict mapping length to sequence. + mapping = [] + for lmdb_idx, sequence_list in enumerate(self.sequence_lists): + for sequence_name, filenames in sequence_list.items(): + for filename in filenames: + # This file is corrupt. + if filename == 'z-KziTO_5so_0019_start0_end85_h596_w596': + continue + mapping.append({ + 'lmdb_root': self.lmdb_roots[lmdb_idx], + 'lmdb_idx': lmdb_idx, + 'sequence_name': sequence_name, + 'filenames': [filename], + }) + self.mapping = mapping + self.epoch_length = len(mapping) + + return self.mapping, self.epoch_length + + def _sample_keys(self, index): + r"""Gets files to load for this sample. + + Args: + index (int): Index in [0, len(dataset)]. + Returns: + (tuple): + - key (dict): + - lmdb_idx (int): Chosen LMDB dataset root. + - sequence_name (str): Chosen sequence in chosen dataset. + - filenames (list of str): Chosen filenames in chosen sequence. + """ + if self.is_inference: + assert index < self.epoch_length + raise NotImplementedError + else: + # Select a video at random. + key = random.choice(self.mapping) + return key + + def _create_sequence_keys(self, sequence_name, filenames): + r"""Create the LMDB key for this piece of information. + + Args: + sequence_name (str): Which sequence from the chosen dataset. + filenames (list of str): List of filenames in this sequence. + Returns: + keys (list): List of full keys. + """ + assert isinstance(filenames, list), 'Filenames should be a list.' + keys = [] + for filename in filenames: + keys.append('%s/%s' % (sequence_name, filename)) + return keys + + def _getitem(self, index): + r"""Gets selected files. + + Args: + index (int): Index into dataset. + concat (bool): Concatenate all items in labels? + Returns: + data (dict): Dict with all chosen data_types. + """ + # Select a sample from the available data. + keys = self._sample_keys(index) + + # Unpack keys. + lmdb_idx = keys['lmdb_idx'] + sequence_name = keys['sequence_name'] + filenames = keys['filenames'] + + # Get key and lmdbs. + keys, lmdbs = {}, {} + for data_type in self.dataset_data_types: + keys[data_type] = self._create_sequence_keys( + sequence_name, filenames) + lmdbs[data_type] = self.lmdbs[data_type][lmdb_idx] + + # Load all data for this index. + data = self.load_from_dataset(keys, lmdbs) + + # Get frames from video. 
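+        # The raw video bytes are written to a temporary file so that
+        # cv2.VideoCapture can seek into them; one driving frame is chosen at
+        # random and K source frames are drawn from the remaining frames.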
+ try: + temp = tempfile.NamedTemporaryFile() + temp.write(data['videos'][0]) + temp.seek(0) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + # frames, _, info = io.read_video(temp) + # num_frames = frames.size(0) + cap = cv2.VideoCapture(temp.name) + num_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + if self.first_last_only: + chosen_idxs = [0, num_frames - 1] + else: + # chosen_idxs = random.sample(range(frames.size(0)), 2) + + chosen_idx = random.sample(range(num_frames), 1)[0] + few_shot_choose_range = list(range(chosen_idx)) + list(range(chosen_idx + 1, num_frames)) + if self.sample_far_frames_more: + choose_weight = list(reversed(range(chosen_idx))) + list(range(num_frames - chosen_idx - 1)) + few_shot_idx = random.choices(few_shot_choose_range, choose_weight, k=self.few_shot_K) + else: + few_shot_idx = random.sample(few_shot_choose_range, k=self.few_shot_K) + chosen_idxs = few_shot_idx + [chosen_idx] + + chosen_images = [] + for idx in chosen_idxs: + # chosen_images.append(Image.fromarray(frames[idx].numpy())) + cap.set(1, idx) + _, frame = cap.read() + chosen_images.append(Image.fromarray(frame[:, :, ::-1])) + except Exception: + print('Issue with file:', sequence_name, filenames) + blank = np.zeros((512, 512, 3), dtype=np.uint8) + chosen_images = [Image.fromarray(blank), Image.fromarray(blank)] + + data['videos'] = chosen_images + + # Apply ops pre augmentation. + data = self.apply_ops(data, self.pre_aug_ops) + + # Do augmentations for images. + data, is_flipped = self.perform_augmentation( + data, paired=True, augment_ops=self.augmentor.augment_ops) + # Individual video frame augmentation is used in face-vid2vid. + data = self.perform_individual_video_frame( + data, self.augmentor.individual_video_frame_augmentation_ops) + + # Apply ops post augmentation. + data = self.apply_ops(data, self.post_aug_ops) + + # Convert images to tensor. + data = self.to_tensor(data) + + # Pack the sequence of images. + for data_type in self.image_data_types: + for idx in range(len(data[data_type])): + data[data_type][idx] = data[data_type][idx].unsqueeze(0) + data[data_type] = torch.cat(data[data_type], dim=0) + + if not self.is_video_dataset: + # Remove any extra dimensions. + for data_type in self.image_data_types: + if data_type in data: + data[data_type] = data[data_type].squeeze(0) + + # Prepare output. + data['driving_images'] = data['videos'][self.few_shot_K:] + data['source_images'] = data['videos'][:self.few_shot_K] + data.pop('videos') + data['is_flipped'] = is_flipped + data['key'] = keys + data['original_h_w'] = torch.IntTensor([ + self.augmentor.original_h, self.augmentor.original_w]) + + # Apply full data ops. + data = self.apply_ops(data, self.full_data_ops, full_data=True) + + return data + + def __getitem__(self, index): + return self._getitem(index) diff --git a/imaginaire/datasets/paired_images.py b/imaginaire/datasets/paired_images.py new file mode 100644 index 0000000000000000000000000000000000000000..b3112a68de57e7c85d031ee690b62ac1fbc4f96d --- /dev/null +++ b/imaginaire/datasets/paired_images.py @@ -0,0 +1,87 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md + +from imaginaire.datasets.paired_videos import Dataset as VideoDataset + + +class Dataset(VideoDataset): + r"""Paired image dataset for use in pix2pixHD, SPADE. + + Args: + cfg (Config): Loaded config object. 
+ is_inference (bool): In train or inference mode? + """ + + def __init__(self, cfg, is_inference=False, is_test=False): + self.paired = True + super(Dataset, self).__init__(cfg, is_inference, + sequence_length=1, + is_test=is_test) + self.is_video_dataset = False + + def _create_mapping(self): + r"""Creates mapping from idx to key in LMDB. + + Returns: + (tuple): + - self.mapping (list): List mapping idx to key. + - self.epoch_length (int): Number of samples in an epoch. + """ + idx_to_key = [] + for lmdb_idx, sequence_list in enumerate(self.sequence_lists): + for sequence_name, filenames in sequence_list.items(): + for filename in filenames: + idx_to_key.append({ + 'lmdb_root': self.lmdb_roots[lmdb_idx], + 'lmdb_idx': lmdb_idx, + 'sequence_name': sequence_name, + 'filenames': [filename], + }) + self.mapping = idx_to_key + self.epoch_length = len(self.mapping) + return self.mapping, self.epoch_length + + def _sample_keys(self, index): + r"""Gets files to load for this sample. + + Args: + index (int): Index in [0, len(dataset)]. + Returns: + key (dict): + - lmdb_idx (int): Chosen LMDB dataset root. + - sequence_name (str): Chosen sequence in chosen dataset. + - filenames (list of str): Chosen filenames in chosen sequence. + """ + assert self.sequence_length == 1, \ + 'Image dataset can only have sequence length = 1, not %d' % ( + self.sequence_length) + return self.mapping[index] + + def set_sequence_length(self, sequence_length): + r"""Set the length of sequence you want as output from dataloader. + Ignore this as this is an image loader. + + Args: + sequence_length (int): Length of output sequences. + """ + pass + + def set_inference_sequence_idx(self, index): + r"""Get frames from this sequence during inference. + Overriden from super as this is not applicable for images. + + Args: + index (int): Index of inference sequence. + """ + raise RuntimeError('Image dataset does not have sequences.') + + def num_inference_sequences(self): + r"""Number of sequences available for inference. + Overriden from super as this is not applicable for images. + + Returns: + (int) + """ + raise RuntimeError('Image dataset does not have sequences.') diff --git a/imaginaire/datasets/paired_videos.py b/imaginaire/datasets/paired_videos.py new file mode 100644 index 0000000000000000000000000000000000000000..38e5c645f9de00184fa6cf7cc0b6b910c5815417 --- /dev/null +++ b/imaginaire/datasets/paired_videos.py @@ -0,0 +1,288 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +import copy +import random +from collections import OrderedDict + +import torch + +from imaginaire.datasets.base import BaseDataset +from imaginaire.model_utils.fs_vid2vid import select_object +from imaginaire.utils.distributed import master_only_print as print + + +class Dataset(BaseDataset): + r"""Paired video dataset for use in vid2vid, wc_vid2vid. + + Args: + cfg (Config): Loaded config object. + is_inference (bool): In train or inference mode? + sequence_length (int): What sequence of images to provide? + """ + + def __init__(self, cfg, + is_inference=False, + sequence_length=None, + is_test=False): + self.paired = True + # Get initial sequence length. 
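+        # Training starts from cfg.data.train.initial_sequence_length (and can
+        # be adjusted later via set_sequence_length()); inference defaults to
+        # 2-frame sequences unless an explicit length is passed in.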
+ if sequence_length is None and not is_inference: + self.sequence_length = cfg.data.train.initial_sequence_length + elif sequence_length is None and is_inference: + self.sequence_length = 2 + else: + self.sequence_length = sequence_length + super(Dataset, self).__init__(cfg, is_inference, is_test) + self.set_sequence_length(self.sequence_length) + self.is_video_dataset = True + + def get_label_lengths(self): + r"""Get num channels of all labels to be concated. + + Returns: + label_lengths (OrderedDict): Dict mapping image data_type to num + channels. + """ + label_lengths = OrderedDict() + for data_type in self.input_labels: + data_cfg = self.cfgdata + if hasattr(data_cfg, 'one_hot_num_classes') and data_type in data_cfg.one_hot_num_classes: + label_lengths[data_type] = data_cfg.one_hot_num_classes[data_type] + if getattr(data_cfg, 'use_dont_care', False): + label_lengths[data_type] += 1 + else: + label_lengths[data_type] = self.num_channels[data_type] + return label_lengths + + def num_inference_sequences(self): + r"""Number of sequences available for inference. + + Returns: + (int) + """ + assert self.is_inference + return len(self.mapping) + + def set_inference_sequence_idx(self, index): + r"""Get frames from this sequence during inference. + + Args: + index (int): Index of inference sequence. + """ + assert self.is_inference + assert index < len(self.mapping) + self.inference_sequence_idx = index + self.epoch_length = len( + self.mapping[self.inference_sequence_idx]['filenames']) + + def set_sequence_length(self, sequence_length): + r"""Set the length of sequence you want as output from dataloader. + + Args: + sequence_length (int): Length of output sequences. + """ + assert isinstance(sequence_length, int) + if sequence_length > self.sequence_length_max: + print('Requested sequence length (%d) > ' % (sequence_length) + + 'max sequence length (%d). ' % (self.sequence_length_max) + + 'Limiting sequence length to max sequence length.') + sequence_length = self.sequence_length_max + self.sequence_length = sequence_length + # Recalculate mapping as some sequences might no longer be useful. + self.mapping, self.epoch_length = self._create_mapping() + print('Epoch length:', self.epoch_length) + + def _compute_dataset_stats(self): + r"""Compute statistics of video sequence dataset. + + Returns: + sequence_length_max (int): Maximum sequence length. + """ + print('Num datasets:', len(self.sequence_lists)) + + if self.sequence_length >= 1: + num_sequences, sequence_length_max = 0, 0 + for sequence in self.sequence_lists: + for _, filenames in sequence.items(): + sequence_length_max = max( + sequence_length_max, len(filenames)) + num_sequences += 1 + print('Num sequences:', num_sequences) + print('Max sequence length:', sequence_length_max) + self.sequence_length_max = sequence_length_max + + def _create_mapping(self): + r"""Creates mapping from idx to key in LMDB. + + Returns: + (tuple): + - self.mapping (dict): Dict of seq_len to list of sequences. + - self.epoch_length (int): Number of samples in an epoch. + """ + # Create dict mapping length to sequence. 
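+        # Sequences shorter than the requested sequence length are skipped;
+        # the rest are grouped by frame count so _sample_keys can quickly
+        # find sequences long enough for the requested clip.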
+ length_to_key, num_selected_seq = {}, 0 + total_num_of_frames = 0 + for lmdb_idx, sequence_list in enumerate(self.sequence_lists): + for sequence_name, filenames in sequence_list.items(): + if len(filenames) >= self.sequence_length: + total_num_of_frames += len(filenames) + if len(filenames) not in length_to_key: + length_to_key[len(filenames)] = [] + length_to_key[len(filenames)].append({ + 'lmdb_root': self.lmdb_roots[lmdb_idx], + 'lmdb_idx': lmdb_idx, + 'sequence_name': sequence_name, + 'filenames': filenames, + }) + num_selected_seq += 1 + self.mapping = length_to_key + self.epoch_length = num_selected_seq + if not self.is_inference and self.epoch_length < \ + self.cfgdata.train.batch_size * 8: + self.epoch_length = total_num_of_frames + + # At inference time, we want to use all sequences, + # irrespective of length. + if self.is_inference: + sequence_list = [] + for key, sequences in self.mapping.items(): + sequence_list.extend(sequences) + self.mapping = sequence_list + + return self.mapping, self.epoch_length + + def _sample_keys(self, index): + r"""Gets files to load for this sample. + + Args: + index (int): Index in [0, len(dataset)]. + Returns: + key (dict): + - lmdb_idx (int): Chosen LMDB dataset root. + - sequence_name (str): Chosen sequence in chosen dataset. + - filenames (list of str): Chosen filenames in chosen sequence. + """ + if self.is_inference: + assert index < self.epoch_length + chosen_sequence = self.mapping[self.inference_sequence_idx] + chosen_filenames = [chosen_sequence['filenames'][index]] + else: + # Pick a time step for temporal augmentation. + time_step = random.randint(1, self.augmentor.max_time_step) + required_sequence_length = 1 + \ + (self.sequence_length - 1) * time_step + + # If step is too large, default to step size of 1. + if required_sequence_length > self.sequence_length_max: + required_sequence_length = self.sequence_length + time_step = 1 + + # Find valid sequences. + valid_sequences = [] + for sequence_length, sequences in self.mapping.items(): + if sequence_length >= required_sequence_length: + valid_sequences.extend(sequences) + + # Pick a sequence. + chosen_sequence = random.choice(valid_sequences) + + # Choose filenames. + max_start_idx = len(chosen_sequence['filenames']) - \ + required_sequence_length + start_idx = random.randint(0, max_start_idx) + + chosen_filenames = chosen_sequence['filenames'][ + start_idx:start_idx + required_sequence_length:time_step] + assert len(chosen_filenames) == self.sequence_length + + # Prepre output key. + key = copy.deepcopy(chosen_sequence) + key['filenames'] = chosen_filenames + return key + + def _create_sequence_keys(self, sequence_name, filenames): + r"""Create the LMDB key for this piece of information. + + Args: + sequence_name (str): Which sequence from the chosen dataset. + filenames (list of str): List of filenames in this sequence. + Returns: + keys (list): List of full keys. + """ + assert isinstance(filenames, list), 'Filenames should be a list.' + keys = [] + if sequence_name.endswith('___') and sequence_name[-9:-6] == '___': + sequence_name = sequence_name[:-9] + for filename in filenames: + keys.append('%s/%s' % (sequence_name, filename)) + return keys + + def _getitem(self, index): + r"""Gets selected files. + + Args: + index (int): Index into dataset. + concat (bool): Concatenate all items in labels? + Returns: + data (dict): Dict with all chosen data_types. + """ + # Select a sample from the available data. + keys = self._sample_keys(index) + + # Unpack keys. 
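+        # keys comes from _sample_keys, e.g. (values illustrative):
+        #   {'lmdb_root': ..., 'lmdb_idx': 0, 'sequence_name': 'seq0001',
+        #    'filenames': ['frame_0001', 'frame_0002', ...]}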
+ lmdb_idx = keys['lmdb_idx'] + sequence_name = keys['sequence_name'] + filenames = keys['filenames'] + + # Get key and lmdbs. + keys, lmdbs = {}, {} + for data_type in self.dataset_data_types: + keys[data_type] = self._create_sequence_keys( + sequence_name, filenames) + lmdbs[data_type] = self.lmdbs[data_type][lmdb_idx] + + # Load all data for this index. + data = self.load_from_dataset(keys, lmdbs) + + # Apply ops pre augmentation. + data = self.apply_ops(data, self.pre_aug_ops) + + # If multiple subjects exist in the data, only pick one to synthesize. + data = select_object(data, obj_indices=None) + + # Do augmentations for images. + data, is_flipped = self.perform_augmentation(data, paired=True, augment_ops=self.augmentor.augment_ops) + + # Apply ops post augmentation. + data = self.apply_ops(data, self.post_aug_ops) + data = self.apply_ops(data, self.full_data_post_aug_ops, full_data=True) + + # Convert images to tensor. + data = self.to_tensor(data) + + # Pack the sequence of images. + for data_type in self.image_data_types + self.hdr_image_data_types: + for idx in range(len(data[data_type])): + data[data_type][idx] = data[data_type][idx].unsqueeze(0) + data[data_type] = torch.cat(data[data_type], dim=0) + + if not self.is_video_dataset: + # Remove any extra dimensions. + for data_type in self.data_types: + if data_type in data: + data[data_type] = data[data_type].squeeze(0) + + data['is_flipped'] = is_flipped + data['key'] = keys + data['original_h_w'] = torch.IntTensor([ + self.augmentor.original_h, self.augmentor.original_w]) + + # Apply full data ops. + data = self.apply_ops(data, self.full_data_ops, full_data=True) + + return data + + def __getitem__(self, index): + return self._getitem(index) diff --git a/imaginaire/datasets/unpaired_few_shot_images.py b/imaginaire/datasets/unpaired_few_shot_images.py new file mode 100644 index 0000000000000000000000000000000000000000..e8fa1e5ec2cd528c72effd065338fa693d1319c8 --- /dev/null +++ b/imaginaire/datasets/unpaired_few_shot_images.py @@ -0,0 +1,182 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +import random + +from imaginaire.datasets.base import BaseDataset + + +class Dataset(BaseDataset): + r"""Image dataset for use in FUNIT. + + Args: + cfg (Config): Loaded config object. + is_inference (bool): In train or inference mode? + """ + + def __init__(self, cfg, is_inference=False, is_test=False): + self.paired = False + super(Dataset, self).__init__(cfg, is_inference, is_test) + self.num_content_classes = len(self.class_name_to_idx['images_content']) + self.num_style_classes = len(self.class_name_to_idx['images_style']) + self.sample_class_idx = None + self.content_offset = 8888 + self.content_interval = 100 + + def set_sample_class_idx(self, class_idx=None): + r"""Set sample class idx. + + Args: + class_idx (int): Which class idx to sample from. + """ + self.sample_class_idx = class_idx + if class_idx is None: + self.epoch_length = \ + max([len(lmdb_keys) for _, lmdb_keys in self.mapping.items()]) + else: + self.epoch_length = \ + len(self.mapping_class['images_style'][class_idx]) + + def _create_mapping(self): + r"""Creates mapping from idx to key in LMDB. + + Returns: + (tuple): + - self.mapping (dict): Dict with data type as key mapping idx to + LMDB key. + - self.epoch_length (int): Number of samples in an epoch. 
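+
+        Content and style images are indexed separately: both
+        'images_content' and 'images_style' receive their own list of
+        per-file entries and their own class-name-to-index table.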
+ """ + idx_to_key, class_names = {}, {} + for lmdb_idx, sequence_list in enumerate(self.sequence_lists): + for data_type, data_type_sequence_list in sequence_list.items(): + class_names[data_type] = [] + if data_type not in idx_to_key: + idx_to_key[data_type] = [] + for sequence_name, filenames in data_type_sequence_list.items(): + class_name = sequence_name.split('/')[0] + for filename in filenames: + idx_to_key[data_type].append({ + 'lmdb_root': self.lmdb_roots[lmdb_idx], + 'lmdb_idx': lmdb_idx, + 'sequence_name': sequence_name, + 'filename': filename, + 'class_name': class_name + }) + class_names[data_type].append(class_name) + self.mapping = idx_to_key + self.epoch_length = max([len(lmdb_keys) + for _, lmdb_keys in self.mapping.items()]) + + # Create mapping from class name to class idx. + self.class_name_to_idx = {} + for data_type, class_names_data_type in class_names.items(): + self.class_name_to_idx[data_type] = {} + class_names_data_type = sorted(list(set(class_names_data_type))) + for class_idx, class_name in enumerate(class_names_data_type): + self.class_name_to_idx[data_type][class_name] = class_idx + + # Add class idx to mapping. + for data_type in self.mapping: + for key in self.mapping[data_type]: + key['class_idx'] = \ + self.class_name_to_idx[data_type][key['class_name']] + + # Create a mapping from index to lmdb key for each class. + idx_to_key_class = {} + for data_type in self.mapping: + idx_to_key_class[data_type] = {} + for class_idx, class_name in enumerate(class_names[data_type]): + idx_to_key_class[data_type][class_idx] = [] + for key in self.mapping[data_type]: + idx_to_key_class[data_type][key['class_idx']].append(key) + self.mapping_class = idx_to_key_class + + return self.mapping, self.epoch_length + + def _sample_keys(self, index): + r"""Gets files to load for this sample. + + Args: + index (int): Index in [0, len(dataset)]. + Returns: + (tuple): + - keys (dict): Each key of this dict is a data type. + - lmdb_key (dict): + - lmdb_idx (int): Chosen LMDB dataset root. + - sequence_name (str): Chosen sequence in chosen dataset. + - filename (str): Chosen filename in chosen sequence. + """ + + keys = {} + if self.is_inference: # evaluation mode + lmdb_keys_content = self.mapping['images_content'] + keys['images_content'] = \ + lmdb_keys_content[ + ((index + self.content_offset * self.sample_class_idx) * + self.content_interval) % len(lmdb_keys_content)] + + lmdb_keys_style = \ + self.mapping_class['images_style'][self.sample_class_idx] + keys['images_style'] = lmdb_keys_style[index] + else: + lmdb_keys_content = self.mapping['images_content'] + lmdb_keys_style = self.mapping['images_style'] + keys['images_content'] = random.choice(lmdb_keys_content) + keys['images_style'] = random.choice(lmdb_keys_style) + return keys + + def __getitem__(self, index): + r"""Gets selected files. + + Args: + index (int): Index into dataset. + concat (bool): Concatenate all items in labels? + Returns: + data (dict): Dict with all chosen data_types. + """ + # Select a sample from the available data. + keys_per_data_type = self._sample_keys(index) + + # Get class idx into a list. + class_idxs = [] + for data_type in keys_per_data_type: + class_idxs.append(keys_per_data_type[data_type]['class_idx']) + + # Get keys and lmdbs. + keys, lmdbs = {}, {} + for data_type in self.dataset_data_types: + # Unpack keys. 
+ lmdb_idx = keys_per_data_type[data_type]['lmdb_idx'] + sequence_name = keys_per_data_type[data_type]['sequence_name'] + filename = keys_per_data_type[data_type]['filename'] + keys[data_type] = '%s/%s' % (sequence_name, filename) + lmdbs[data_type] = self.lmdbs[data_type][lmdb_idx] + + # Load all data for this index. + data = self.load_from_dataset(keys, lmdbs) + + # Apply ops pre augmentation. + data = self.apply_ops(data, self.pre_aug_ops) + + # Do augmentations for images. + data, is_flipped = self.perform_augmentation(data, paired=False, augment_ops=self.augmentor.augment_ops) + + # Apply ops post augmentation. + data = self.apply_ops(data, self.post_aug_ops) + data = self.apply_ops(data, self.full_data_post_aug_ops, full_data=True) + + # Convert images to tensor. + data = self.to_tensor(data) + + # Remove any extra dimensions. + for data_type in self.image_data_types: + data[data_type] = data[data_type][0] + + # Package output. + data['is_flipped'] = is_flipped + data['key'] = keys_per_data_type + data['labels_content'] = class_idxs[0] + data['labels_style'] = class_idxs[1] + + return data diff --git a/imaginaire/datasets/unpaired_images.py b/imaginaire/datasets/unpaired_images.py new file mode 100644 index 0000000000000000000000000000000000000000..68a49a876705771a2d7a3f836bdbee2cdd328c10 --- /dev/null +++ b/imaginaire/datasets/unpaired_images.py @@ -0,0 +1,118 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +import random + +from imaginaire.datasets.base import BaseDataset + + +class Dataset(BaseDataset): + r"""Unpaired image dataset for use in MUNIT. + + Args: + cfg (Config): Loaded config object. + is_inference (bool): In train or inference mode? + """ + + def __init__(self, cfg, is_inference=False, is_test=False): + self.paired = False + super(Dataset, self).__init__(cfg, is_inference, is_test) + + def _create_mapping(self): + r"""Creates mapping from idx to key in LMDB. + + Returns: + (tuple): + - self.mapping (dict): Dict with data type as key mapping idx to + LMDB key. + - self.epoch_length (int): Number of samples in an epoch. + """ + idx_to_key = {} + for lmdb_idx, sequence_list in enumerate(self.sequence_lists): + for data_type, data_type_sequence_list in sequence_list.items(): + if data_type not in idx_to_key: + idx_to_key[data_type] = [] + for sequence_name, filenames in data_type_sequence_list.items(): + for filename in filenames: + idx_to_key[data_type].append({ + 'lmdb_root': self.lmdb_roots[lmdb_idx], + 'lmdb_idx': lmdb_idx, + 'sequence_name': sequence_name, + 'filename': filename, + }) + self.mapping = idx_to_key + self.epoch_length = max([len(lmdb_keys) + for _, lmdb_keys in self.mapping.items()]) + return self.mapping, self.epoch_length + + def _sample_keys(self, index): + r"""Gets files to load for this sample. + + Args: + index (int): Index in [0, len(dataset)]. + Returns: + keys (dict): Each key of this dict is a data type. + lmdb_key (dict): + lmdb_idx (int): Chosen LMDB dataset root. + sequence_name (str): Chosen sequence in chosen dataset. + filename (str): Chosen filename in chosen sequence. + """ + keys = {} + for data_type in self.data_types: + lmdb_keys = self.mapping[data_type] + if self.is_inference: + # Modulo ensures valid indexing in case A and B have different + # number of files. 
+ keys[data_type] = lmdb_keys[index % len(lmdb_keys)] + else: + keys[data_type] = random.choice(lmdb_keys) + return keys + + def __getitem__(self, index): + r"""Gets selected files. + + Args: + index (int): Index into dataset. + concat (bool): Concatenate all items in labels? + Returns: + data (dict): Dict with all chosen data_types. + """ + # Select a sample from the available data. + keys_per_data_type = self._sample_keys(index) + + # Get keys and lmdbs. + keys, lmdbs = {}, {} + for data_type in self.dataset_data_types: + # Unpack keys. + lmdb_idx = keys_per_data_type[data_type]['lmdb_idx'] + sequence_name = keys_per_data_type[data_type]['sequence_name'] + filename = keys_per_data_type[data_type]['filename'] + keys[data_type] = '%s/%s' % (sequence_name, filename) + lmdbs[data_type] = self.lmdbs[data_type][lmdb_idx] + + # Load all data for this index. + data = self.load_from_dataset(keys, lmdbs) + + # Apply ops pre augmentation. + data = self.apply_ops(data, self.pre_aug_ops) + + # Do augmentations for images. + data, is_flipped = self.perform_augmentation(data, paired=False, augment_ops=self.augmentor.augment_ops) + + # Apply ops post augmentation. + data = self.apply_ops(data, self.post_aug_ops) + data = self.apply_ops(data, self.full_data_post_aug_ops, full_data=True) + + # Convert images to tensor. + data = self.to_tensor(data) + + # Remove any extra dimensions. + for data_type in self.image_data_types: + data[data_type] = data[data_type][0] + + # Package output. + data['is_flipped'] = is_flipped + data['key'] = keys_per_data_type + + return data diff --git a/imaginaire/discriminators/__init__.py b/imaginaire/discriminators/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..13acefe2181136b1629ec31f9d122fb46bf26780 --- /dev/null +++ b/imaginaire/discriminators/__init__.py @@ -0,0 +1,4 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md diff --git a/imaginaire/discriminators/dummy.py b/imaginaire/discriminators/dummy.py new file mode 100644 index 0000000000000000000000000000000000000000..a345806f6f844fbcf0c9da1915f4db5b2fa3d587 --- /dev/null +++ b/imaginaire/discriminators/dummy.py @@ -0,0 +1,29 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +import torch.nn as nn + +from imaginaire.layers import LinearBlock + + +class Discriminator(nn.Module): + """Dummy Discriminator constructor. + + Args: + dis_cfg (obj): Discriminator definition part of the yaml config file. + data_cfg (obj): Data definition part of the yaml config file + """ + + def __init__(self, dis_cfg, data_cfg): + super(Discriminator, self).__init__() + self.dummy_layer = LinearBlock(1, 1) + pass + + def forward(self, data): + """Dummy discriminator forward. + + Args: + data (dict): + """ + return diff --git a/imaginaire/discriminators/fpse.py b/imaginaire/discriminators/fpse.py new file mode 100644 index 0000000000000000000000000000000000000000..231b666bfd3970c760df7ce5d1174193ad9a7708 --- /dev/null +++ b/imaginaire/discriminators/fpse.py @@ -0,0 +1,132 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. 
+# To view a copy of this license, check out LICENSE.md +import functools + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from imaginaire.layers import Conv2dBlock + + +class FPSEDiscriminator(nn.Module): + r"""# Feature-Pyramid Semantics Embedding Discriminator. This is a copy + of the discriminator in https://arxiv.org/pdf/1910.06809.pdf + """ + + def __init__(self, + num_input_channels, + num_labels, + num_filters, + kernel_size, + weight_norm_type, + activation_norm_type): + super().__init__() + padding = int(np.ceil((kernel_size - 1.0) / 2)) + nonlinearity = 'leakyrelu' + stride1_conv2d_block = \ + functools.partial(Conv2dBlock, + kernel_size=kernel_size, + stride=1, + padding=padding, + weight_norm_type=weight_norm_type, + activation_norm_type=activation_norm_type, + nonlinearity=nonlinearity, + # inplace_nonlinearity=True, + order='CNA') + down_conv2d_block = \ + functools.partial(Conv2dBlock, + kernel_size=kernel_size, + stride=2, + padding=padding, + weight_norm_type=weight_norm_type, + activation_norm_type=activation_norm_type, + nonlinearity=nonlinearity, + # inplace_nonlinearity=True, + order='CNA') + latent_conv2d_block = \ + functools.partial(Conv2dBlock, + kernel_size=1, + stride=1, + weight_norm_type=weight_norm_type, + activation_norm_type=activation_norm_type, + nonlinearity=nonlinearity, + # inplace_nonlinearity=True, + order='CNA') + # bottom-up pathway + + self.enc1 = down_conv2d_block(num_input_channels, num_filters) + self.enc2 = down_conv2d_block(1 * num_filters, 2 * num_filters) + self.enc3 = down_conv2d_block(2 * num_filters, 4 * num_filters) + self.enc4 = down_conv2d_block(4 * num_filters, 8 * num_filters) + self.enc5 = down_conv2d_block(8 * num_filters, 8 * num_filters) + + # top-down pathway + self.lat2 = latent_conv2d_block(2 * num_filters, 4 * num_filters) + self.lat3 = latent_conv2d_block(4 * num_filters, 4 * num_filters) + self.lat4 = latent_conv2d_block(8 * num_filters, 4 * num_filters) + self.lat5 = latent_conv2d_block(8 * num_filters, 4 * num_filters) + + # upsampling + self.upsample2x = nn.Upsample(scale_factor=2, mode='bilinear', + align_corners=False) + + # final layers + self.final2 = stride1_conv2d_block(4 * num_filters, 2 * num_filters) + self.final3 = stride1_conv2d_block(4 * num_filters, 2 * num_filters) + self.final4 = stride1_conv2d_block(4 * num_filters, 2 * num_filters) + + # true/false prediction and semantic alignment prediction + self.output = Conv2dBlock(num_filters * 2, 1, kernel_size=1) + self.seg = Conv2dBlock(num_filters * 2, num_filters * 2, kernel_size=1) + self.embedding = Conv2dBlock(num_labels, num_filters * 2, kernel_size=1) + + def forward(self, images, segmaps): + r""" + + Args: + images: image tensors. + segmaps: segmentation map tensors. 
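+        Returns:
+            (tuple): Patch-wise real/fake predictions (pred2, pred3, pred4) at
+                three feature-pyramid resolutions, each augmented with a
+                semantics-embedding alignment score.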
+ """ + # bottom-up pathway + feat11 = self.enc1(images) + feat12 = self.enc2(feat11) + feat13 = self.enc3(feat12) + feat14 = self.enc4(feat13) + feat15 = self.enc5(feat14) + # top-down pathway and lateral connections + feat25 = self.lat5(feat15) + feat24 = self.upsample2x(feat25) + self.lat4(feat14) + feat23 = self.upsample2x(feat24) + self.lat3(feat13) + feat22 = self.upsample2x(feat23) + self.lat2(feat12) + # final prediction layers + feat32 = self.final2(feat22) + feat33 = self.final3(feat23) + feat34 = self.final4(feat24) + # Patch-based True/False prediction + pred2 = self.output(feat32) + pred3 = self.output(feat33) + pred4 = self.output(feat34) + seg2 = self.seg(feat32) + seg3 = self.seg(feat33) + seg4 = self.seg(feat34) + + # # segmentation map embedding + segembs = self.embedding(segmaps) + segembs = F.avg_pool2d(segembs, kernel_size=2, stride=2) + segembs2 = F.avg_pool2d(segembs, kernel_size=2, stride=2) + segembs3 = F.avg_pool2d(segembs2, kernel_size=2, stride=2) + segembs4 = F.avg_pool2d(segembs3, kernel_size=2, stride=2) + + # semantics embedding discriminator score + pred2 += torch.mul(segembs2, seg2).sum(dim=1, keepdim=True) + pred3 += torch.mul(segembs3, seg3).sum(dim=1, keepdim=True) + pred4 += torch.mul(segembs4, seg4).sum(dim=1, keepdim=True) + + # concat results from multiple resolutions + # results = [pred2, pred3, pred4] + + return pred2, pred3, pred4 diff --git a/imaginaire/discriminators/fs_vid2vid.py b/imaginaire/discriminators/fs_vid2vid.py new file mode 100644 index 0000000000000000000000000000000000000000..78e29f64a301864f03238db7bd1ad444bee9773e --- /dev/null +++ b/imaginaire/discriminators/fs_vid2vid.py @@ -0,0 +1,318 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +import importlib + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from imaginaire.discriminators.multires_patch import NLayerPatchDiscriminator +from imaginaire.model_utils.fs_vid2vid import get_fg_mask, pick_image +from imaginaire.utils.data import (get_paired_input_image_channel_number, + get_paired_input_label_channel_number) +from imaginaire.utils.misc import get_nested_attr + + +class Discriminator(nn.Module): + r"""Image and video discriminator constructor. + + Args: + dis_cfg (obj): Discriminator part of the yaml config file. 
+ data_cfg (obj): Data definition part of the yaml config file + """ + + def __init__(self, dis_cfg, data_cfg): + super().__init__() + self.data_cfg = data_cfg + num_input_channels = get_paired_input_label_channel_number(data_cfg) + if num_input_channels == 0: + num_input_channels = getattr(data_cfg, 'label_channels', 1) + num_img_channels = get_paired_input_image_channel_number(data_cfg) + self.num_frames_D = data_cfg.num_frames_D + self.num_scales = get_nested_attr(dis_cfg, 'temporal.num_scales', 0) + num_netD_input_channels = (num_input_channels + num_img_channels) + self.use_few_shot = 'few_shot' in data_cfg.type + if self.use_few_shot: + num_netD_input_channels *= 2 + self.net_D = MultiPatchDiscriminator(dis_cfg.image, + num_netD_input_channels) + + self.add_dis_cfg = getattr(dis_cfg, 'additional_discriminators', None) + if self.add_dis_cfg is not None: + for name in self.add_dis_cfg: + add_dis_cfg = self.add_dis_cfg[name] + num_ch = num_img_channels * (2 if self.use_few_shot else 1) + setattr(self, 'net_D_' + name, + MultiPatchDiscriminator(add_dis_cfg, num_ch)) + + # Temporal discriminator. + self.num_netDT_input_channels = num_img_channels * self.num_frames_D + for n in range(self.num_scales): + setattr(self, 'net_DT%d' % n, + MultiPatchDiscriminator(dis_cfg.temporal, + self.num_netDT_input_channels)) + self.has_fg = getattr(data_cfg, 'has_foreground', False) + + def forward(self, data, net_G_output, past_frames): + r"""Discriminator forward. + + Args: + data (dict): Input data. + net_G_output (dict): Generator output. + past_frames (list of tensors): Past real frames / generator outputs. + Returns: + (tuple): + - output (dict): Discriminator output. + - past_frames (list of tensors): New past frames by adding + current outputs. + """ + label, real_image = data['label'], data['image'] + # Only operate on the latest output frame. + if label.dim() == 5: + label = label[:, -1] + if self.use_few_shot: + # Pick only one reference image to concat with. + ref_idx = net_G_output['ref_idx'] \ + if 'ref_idx' in net_G_output else 0 + ref_label = pick_image(data['ref_labels'], ref_idx) + ref_image = pick_image(data['ref_images'], ref_idx) + # Concat references with label map as discriminator input. + label = torch.cat([label, ref_label, ref_image], dim=1) + fake_image = net_G_output['fake_images'] + output = dict() + + # Individual frame loss. + pred_real, pred_fake = self.discrminate_image(self.net_D, label, + real_image, fake_image) + output['indv'] = dict() + output['indv']['pred_real'] = pred_real + output['indv']['pred_fake'] = pred_fake + + if 'fake_raw_images' in net_G_output and \ + net_G_output['fake_raw_images'] is not None: + # Raw generator output loss. + fake_raw_image = net_G_output['fake_raw_images'] + fg_mask = get_fg_mask(data['label'], self.has_fg) + pred_real, pred_fake = self.discrminate_image( + self.net_D, label, + real_image * fg_mask, + fake_raw_image * fg_mask) + output['raw'] = dict() + output['raw']['pred_real'] = pred_real + output['raw']['pred_fake'] = pred_fake + + # Additional GAN loss on specific regions. + if self.add_dis_cfg is not None: + for name in self.add_dis_cfg: + # Crop corresponding regions in the image according to the + # crop function. 
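+                # The crop function is specified as 'module::function' in the
+                # config and resolved dynamically with importlib below.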
+ add_dis_cfg = self.add_dis_cfg[name] + file, crop_func = add_dis_cfg.crop_func.split('::') + file = importlib.import_module(file) + crop_func = getattr(file, crop_func) + + real_crop = crop_func(self.data_cfg, real_image, label) + fake_crop = crop_func(self.data_cfg, fake_image, label) + if self.use_few_shot: + ref_crop = crop_func(self.data_cfg, ref_image, label) + if ref_crop is not None: + real_crop = torch.cat([real_crop, ref_crop], dim=1) + fake_crop = torch.cat([fake_crop, ref_crop], dim=1) + + # Feed the crops to specific discriminator. + if fake_crop is not None: + net_D = getattr(self, 'net_D_' + name) + pred_real, pred_fake = \ + self.discrminate_image(net_D, None, + real_crop, fake_crop) + else: + pred_real = pred_fake = None + output[name] = dict() + output[name]['pred_real'] = pred_real + output[name]['pred_fake'] = pred_fake + + # Temporal loss. + past_frames, skipped_frames = \ + get_all_skipped_frames(past_frames, [real_image, fake_image], + self.num_scales, self.num_frames_D) + + for scale in range(self.num_scales): + real_image, fake_image = \ + [skipped_frame[scale] for skipped_frame in skipped_frames] + pred_real, pred_fake = self.discriminate_video(real_image, + fake_image, scale) + output['temporal_%d' % scale] = dict() + output['temporal_%d' % scale]['pred_real'] = pred_real + output['temporal_%d' % scale]['pred_fake'] = pred_fake + + return output, past_frames + + def discrminate_image(self, net_D, real_A, real_B, fake_B): + r"""Discriminate individual images. + + Args: + net_D (obj): Discriminator network. + real_A (NxC1xHxW tensor): Input label map. + real_B (NxC2xHxW tensor): Real image. + fake_B (NxC2xHxW tensor): Fake image. + Returns: + (tuple): + - pred_real (NxC3xH2xW2 tensor): Output of net_D for real images. + - pred_fake (NxC3xH2xW2 tensor): Output of net_D for fake images. + """ + if real_A is not None: + real_AB = torch.cat([real_A, real_B], dim=1) + fake_AB = torch.cat([real_A, fake_B], dim=1) + else: + real_AB, fake_AB = real_B, fake_B + + pred_real = net_D.forward(real_AB) + pred_fake = net_D.forward(fake_AB) + return pred_real, pred_fake + + def discriminate_video(self, real_B, fake_B, scale): + r"""Discriminate a sequence of images. + + Args: + real_B (NxCxHxW tensor): Real image. + fake_B (NxCxHxW tensor): Fake image. + scale (int): Temporal scale. + Returns: + (tuple): + - pred_real (NxC2xH2xW2 tensor): Output of net_D for real images. + - pred_fake (NxC2xH2xW2 tensor): Output of net_D for fake images. + """ + if real_B is None: + return None, None + net_DT = getattr(self, 'net_DT%d' % scale) + height, width = real_B.shape[-2:] + real_B = real_B.view(-1, self.num_netDT_input_channels, height, width) + fake_B = fake_B.view(-1, self.num_netDT_input_channels, height, width) + + pred_real = net_DT.forward(real_B) + pred_fake = net_DT.forward(fake_B) + return pred_real, pred_fake + + +def get_all_skipped_frames(past_frames, new_frames, t_scales, tD): + r"""Get temporally skipped frames from the input frames. + + Args: + past_frames (list of tensors): Past real frames / generator outputs. + new_frames (list of tensors): Current real frame / generated output. + t_scales (int): Temporal scale. + tD (int): Number of frames as input to the temporal discriminator. + Returns: + (tuple): + - new_past_frames (list of tensors): Past + current frames. + - skipped_frames (list of tensors): Temporally skipped frames using + the given t_scales. 
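+              For example, with ``tD = 3`` and ``t_scales = 3``, frames are
+              sampled with temporal strides of 1, 3 and 9.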
+ """ + new_past_frames, skipped_frames = [], [] + for past_frame, new_frame in zip(past_frames, new_frames): + skipped_frame = None + if t_scales > 0: + past_frame, skipped_frame = \ + get_skipped_frames(past_frame, new_frame.unsqueeze(1), + t_scales, tD) + new_past_frames.append(past_frame) + skipped_frames.append(skipped_frame) + return new_past_frames, skipped_frames + + +def get_skipped_frames(all_frames, frame, t_scales, tD): + r"""Get temporally skipped frames from the input frames. + + Args: + all_frames (NxTxCxHxW tensor): All past frames. + frame (Nx1xCxHxW tensor): Current frame. + t_scales (int): Temporal scale. + tD (int): Number of frames as input to the temporal discriminator. + Returns: + (tuple): + - all_frames (NxTxCxHxW tensor): Past + current frames. + - skipped_frames (list of NxTxCxHxW tensors): Temporally skipped + frames. + """ + all_frames = torch.cat([all_frames.detach(), frame], dim=1) \ + if all_frames is not None else frame + skipped_frames = [None] * t_scales + for s in range(t_scales): + # Number of skipped frames between neighboring frames (e.g. 1, 3, 9,...) + t_step = tD ** s + # Number of frames the final triplet frames span before skipping + # (e.g., 2, 6, 18, ...). + t_span = t_step * (tD-1) + if all_frames.size(1) > t_span: + skipped_frames[s] = all_frames[:, -(t_span+1)::t_step].contiguous() + + # Maximum number of past frames we need to keep track of. + max_num_prev_frames = (tD ** (t_scales-1)) * (tD-1) + # Remove past frames that are older than this number. + if all_frames.size()[1] > max_num_prev_frames: + all_frames = all_frames[:, -max_num_prev_frames:] + return all_frames, skipped_frames + + +class MultiPatchDiscriminator(nn.Module): + r"""Multi-resolution patch discriminator. + + Args: + dis_cfg (obj): Discriminator part of the yaml config file. + num_input_channels (int): Number of input channels. + """ + + def __init__(self, dis_cfg, num_input_channels): + super(MultiPatchDiscriminator, self).__init__() + kernel_size = getattr(dis_cfg, 'kernel_size', 4) + num_filters = getattr(dis_cfg, 'num_filters', 64) + max_num_filters = getattr(dis_cfg, 'max_num_filters', 512) + num_discriminators = getattr(dis_cfg, 'num_discriminators', 3) + num_layers = getattr(dis_cfg, 'num_layers', 3) + activation_norm_type = getattr(dis_cfg, 'activation_norm_type', 'none') + weight_norm_type = getattr(dis_cfg, 'weight_norm_type', + 'spectral_norm') + self.nets_discriminator = [] + for i in range(num_discriminators): + net_discriminator = NLayerPatchDiscriminator( + kernel_size, + num_input_channels, + num_filters, + num_layers, + max_num_filters, + activation_norm_type, + weight_norm_type) + self.add_module('discriminator_%d' % i, net_discriminator) + self.nets_discriminator.append(net_discriminator) + + def forward(self, input_x): + r"""Multi-resolution patch discriminator forward. + + Args: + input_x (N x C x H x W tensor) : Concatenation of images and + semantic representations. + Returns: + (dict): + - output (list): list of output tensors produced by individual + patch discriminators. + - features (list): list of lists of features produced by + individual patch discriminators. 
+ """ + output_list = [] + features_list = [] + input_downsampled = input_x + for name, net_discriminator in self.named_children(): + if not name.startswith('discriminator_'): + continue + output, features = net_discriminator(input_downsampled) + output_list.append(output) + features_list.append(features) + input_downsampled = F.interpolate( + input_downsampled, scale_factor=0.5, mode='bilinear', + align_corners=True, recompute_scale_factor=True) + output_x = dict() + output_x['output'] = output_list + output_x['features'] = features_list + return output_x diff --git a/imaginaire/discriminators/funit.py b/imaginaire/discriminators/funit.py new file mode 100644 index 0000000000000000000000000000000000000000..8e74238bf723058dfb680989a8839ce54bf98e61 --- /dev/null +++ b/imaginaire/discriminators/funit.py @@ -0,0 +1,117 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +import warnings + +import torch +from torch import nn + +from imaginaire.layers import Conv2dBlock, Res2dBlock + + +class Discriminator(nn.Module): + r"""Discriminator in the improved FUNIT baseline in the COCO-FUNIT paper. + + Args: + dis_cfg (obj): Discriminator definition part of the yaml config file. + data_cfg (obj): Data definition part of the yaml config file. + """ + + def __init__(self, dis_cfg, data_cfg): + super().__init__() + self.model = ResDiscriminator(**vars(dis_cfg)) + + def forward(self, data, net_G_output, recon=True): + r"""Improved FUNIT discriminator forward function. + + Args: + data (dict): Training data at the current iteration. + net_G_output (dict): Fake data generated at the current iteration. + recon (bool): If ``True``, also classifies reconstructed images. 
+ """ + source_labels = data['labels_content'] + target_labels = data['labels_style'] + fake_out_trans, fake_features_trans = \ + self.model(net_G_output['images_trans'], target_labels) + output = dict(fake_out_trans=fake_out_trans, + fake_features_trans=fake_features_trans) + + real_out_style, real_features_style = \ + self.model(data['images_style'], target_labels) + output.update(dict(real_out_style=real_out_style, + real_features_style=real_features_style)) + if recon: + fake_out_recon, fake_features_recon = \ + self.model(net_G_output['images_recon'], source_labels) + output.update(dict(fake_out_recon=fake_out_recon, + fake_features_recon=fake_features_recon)) + return output + + +class ResDiscriminator(nn.Module): + r"""Residual discriminator architecture used in the FUNIT paper.""" + + def __init__(self, + image_channels=3, + num_classes=119, + num_filters=64, + max_num_filters=1024, + num_layers=6, + padding_mode='reflect', + weight_norm_type='', + **kwargs): + super().__init__() + for key in kwargs: + if key != 'type': + warnings.warn( + "Discriminator argument {} is not used".format(key)) + + conv_params = dict(padding_mode=padding_mode, + activation_norm_type='none', + weight_norm_type=weight_norm_type, + bias=[True, True, True], + nonlinearity='leakyrelu', + order='NACNAC') + + first_kernel_size = 7 + first_padding = (first_kernel_size - 1) // 2 + model = [Conv2dBlock(image_channels, num_filters, + first_kernel_size, 1, first_padding, + padding_mode=padding_mode, + weight_norm_type=weight_norm_type)] + for i in range(num_layers): + num_filters_prev = num_filters + num_filters = min(num_filters * 2, max_num_filters) + model += [Res2dBlock(num_filters_prev, num_filters_prev, + **conv_params), + Res2dBlock(num_filters_prev, num_filters, + **conv_params)] + if i != num_layers - 1: + model += [nn.ReflectionPad2d(1), + nn.AvgPool2d(3, stride=2)] + self.model = nn.Sequential(*model) + self.classifier = Conv2dBlock(num_filters, 1, 1, 1, 0, + nonlinearity='leakyrelu', + weight_norm_type=weight_norm_type, + order='NACNAC') + + self.embedder = nn.Embedding(num_classes, num_filters) + + def forward(self, images, labels=None): + r"""Forward function of the projection discriminator. + + Args: + images (image tensor): Images inputted to the discriminator. + labels (long int tensor): Class labels of the images. + """ + assert (images.size(0) == labels.size(0)) + features = self.model(images) + outputs = self.classifier(features) + features_1x1 = features.mean(3).mean(2) + if labels is None: + return features_1x1 + embeddings = self.embedder(labels) + outputs += torch.sum(embeddings * features_1x1, dim=1, + keepdim=True).view(images.size(0), 1, 1, 1) + return outputs, features_1x1 diff --git a/imaginaire/discriminators/gancraft.py b/imaginaire/discriminators/gancraft.py new file mode 100644 index 0000000000000000000000000000000000000000..0bc070cb46ac5c6ae287231ddd0144bedd6d55a2 --- /dev/null +++ b/imaginaire/discriminators/gancraft.py @@ -0,0 +1,278 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. 
+# To view a copy of this license, check out LICENSE.md +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import functools +from imaginaire.layers import Conv2dBlock + +from imaginaire.utils.data import get_paired_input_label_channel_number, get_paired_input_image_channel_number +from imaginaire.utils.distributed import master_only_print as print + + +class Discriminator(nn.Module): + r"""Multi-resolution patch discriminator. Based on FPSE discriminator but with N+1 labels. + + Args: + dis_cfg (obj): Discriminator definition part of the yaml config file. + data_cfg (obj): Data definition part of the yaml config file. + """ + + def __init__(self, dis_cfg, data_cfg): + super(Discriminator, self).__init__() + # We assume the first datum is the ground truth image. + image_channels = get_paired_input_image_channel_number(data_cfg) + # Calculate number of channels in the input label. + num_labels = get_paired_input_label_channel_number(data_cfg) + + self.use_label = getattr(dis_cfg, 'use_label', True) + # Override number of input channels + if hasattr(dis_cfg, 'image_channels'): + image_channels = dis_cfg.image_channels + if hasattr(dis_cfg, 'num_labels'): + num_labels = dis_cfg.num_labels + else: + # We assume the first datum is the ground truth image. + image_channels = get_paired_input_image_channel_number(data_cfg) + # Calculate number of channels in the input label. + num_labels = get_paired_input_label_channel_number(data_cfg) + + if not self.use_label: + num_labels = 2 # ignore + true + + # Build the discriminator. + num_filters = getattr(dis_cfg, 'num_filters', 128) + weight_norm_type = getattr(dis_cfg, 'weight_norm_type', 'spectral') + + fpse_kernel_size = getattr(dis_cfg, 'fpse_kernel_size', 3) + fpse_activation_norm_type = getattr(dis_cfg, + 'fpse_activation_norm_type', + 'none') + do_multiscale = getattr(dis_cfg, 'do_multiscale', False) + smooth_resample = getattr(dis_cfg, 'smooth_resample', False) + no_label_except_largest_scale = getattr(dis_cfg, 'no_label_except_largest_scale', False) + + self.fpse_discriminator = FPSEDiscriminator( + image_channels, + num_labels, + num_filters, + fpse_kernel_size, + weight_norm_type, + fpse_activation_norm_type, + do_multiscale, + smooth_resample, + no_label_except_largest_scale) + + def _single_forward(self, input_label, input_image, weights): + output_list, features_list = self.fpse_discriminator(input_image, input_label, weights) + return output_list, [features_list] + + def forward(self, data, net_G_output, weights=None, incl_real=False, incl_pseudo_real=False): + r"""GANcraft discriminator forward. + + Args: + data (dict): + - data (N x C1 x H x W tensor) : Ground truth images. + - label (N x C2 x H x W tensor) : Semantic representations. + - z (N x style_dims tensor): Gaussian random noise. + net_G_output (dict): + - fake_images (N x C1 x H x W tensor) : Fake images. + Returns: + output_x (dict): + - real_outputs (list): list of output tensors produced by + individual patch discriminators for real images. + - real_features (list): list of lists of features produced by + individual patch discriminators for real images. + - fake_outputs (list): list of output tensors produced by + individual patch discriminators for fake images. + - fake_features (list): list of lists of features produced by + individual patch discriminators for fake images. + """ + output_x = dict() + + # Fake. 
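+        # When use_label is False, a dummy two-channel mask (second channel set
+        # to one) stands in for the segmentation maps in the branches below.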
+ fake_images = net_G_output['fake_images'] + if self.use_label: + fake_labels = data['fake_masks'] + else: + fake_labels = torch.zeros([fake_images.size(0), 2, fake_images.size( + 2), fake_images.size(3)], device=fake_images.device, dtype=fake_images.dtype) + fake_labels[:, 1, :, :] = 1 + output_x['fake_outputs'], output_x['fake_features'] = \ + self._single_forward(fake_labels, fake_images, None) + + # Real. + if incl_real: + real_images = data['images'] + if self.use_label: + real_labels = data['real_masks'] + else: + real_labels = torch.zeros([real_images.size(0), 2, real_images.size( + 2), real_images.size(3)], device=real_images.device, dtype=real_images.dtype) + real_labels[:, 1, :, :] = 1 + output_x['real_outputs'], output_x['real_features'] = \ + self._single_forward(real_labels, real_images, None) + + # pseudo-Real. + if incl_pseudo_real: + preal_images = data['pseudo_real_img'] + preal_labels = data['fake_masks'] + if not self.use_label: + preal_labels = torch.zeros([preal_images.size(0), 2, preal_images.size( + 2), preal_images.size(3)], device=preal_images.device, dtype=preal_images.dtype) + preal_labels[:, 1, :, :] = 1 + output_x['pseudo_real_outputs'], output_x['pseudo_real_features'] = \ + self._single_forward(preal_labels, preal_images, None) + + return output_x + + +class FPSEDiscriminator(nn.Module): + def __init__(self, + num_input_channels, + num_labels, + num_filters, + kernel_size, + weight_norm_type, + activation_norm_type, + do_multiscale, + smooth_resample, + no_label_except_largest_scale): + super().__init__() + + self.do_multiscale = do_multiscale + self.no_label_except_largest_scale = no_label_except_largest_scale + + padding = int(np.ceil((kernel_size - 1.0) / 2)) + nonlinearity = 'leakyrelu' + stride1_conv2d_block = \ + functools.partial(Conv2dBlock, + kernel_size=kernel_size, + stride=1, + padding=padding, + weight_norm_type=weight_norm_type, + activation_norm_type=activation_norm_type, + nonlinearity=nonlinearity, + # inplace_nonlinearity=True, + order='CNA') + down_conv2d_block = \ + functools.partial(Conv2dBlock, + kernel_size=kernel_size, + stride=2, + padding=padding, + weight_norm_type=weight_norm_type, + activation_norm_type=activation_norm_type, + nonlinearity=nonlinearity, + # inplace_nonlinearity=True, + order='CNA') + latent_conv2d_block = \ + functools.partial(Conv2dBlock, + kernel_size=1, + stride=1, + weight_norm_type=weight_norm_type, + activation_norm_type=activation_norm_type, + nonlinearity=nonlinearity, + # inplace_nonlinearity=True, + order='CNA') + # bottom-up pathway + self.enc1 = down_conv2d_block(num_input_channels, num_filters) # 3 + self.enc2 = down_conv2d_block(1 * num_filters, 2 * num_filters) # 7 + self.enc3 = down_conv2d_block(2 * num_filters, 4 * num_filters) # 15 + self.enc4 = down_conv2d_block(4 * num_filters, 8 * num_filters) # 31 + self.enc5 = down_conv2d_block(8 * num_filters, 8 * num_filters) # 63 + + # top-down pathway + # self.lat1 = latent_conv2d_block(num_filters, 2 * num_filters) # Zekun + self.lat2 = latent_conv2d_block(2 * num_filters, 4 * num_filters) + self.lat3 = latent_conv2d_block(4 * num_filters, 4 * num_filters) + self.lat4 = latent_conv2d_block(8 * num_filters, 4 * num_filters) + self.lat5 = latent_conv2d_block(8 * num_filters, 4 * num_filters) + + # upsampling + self.upsample2x = nn.Upsample(scale_factor=2, mode='bilinear', + align_corners=False) + + # final layers + self.final2 = stride1_conv2d_block(4 * num_filters, 2 * num_filters) + self.output = Conv2dBlock(num_filters * 2, num_labels+1, 
kernel_size=1) + + if self.do_multiscale: + self.final3 = stride1_conv2d_block(4 * num_filters, 2 * num_filters) + self.final4 = stride1_conv2d_block(4 * num_filters, 2 * num_filters) + if self.no_label_except_largest_scale: + self.output3 = Conv2dBlock(num_filters * 2, 2, kernel_size=1) + self.output4 = Conv2dBlock(num_filters * 2, 2, kernel_size=1) + else: + self.output3 = Conv2dBlock(num_filters * 2, num_labels+1, kernel_size=1) + self.output4 = Conv2dBlock(num_filters * 2, num_labels+1, kernel_size=1) + + self.interpolator = functools.partial(F.interpolate, mode='nearest') + if smooth_resample: + self.interpolator = self.smooth_interp + + @staticmethod + def smooth_interp(x, size): + r"""Smooth interpolation of segmentation maps. + + Args: + x (4D tensor): Segmentation maps. + size(2D list): Target size (H, W). + """ + x = F.interpolate(x, size=size, mode='area') + onehot_idx = torch.argmax(x, dim=-3, keepdims=True) + x.fill_(0.0) + x.scatter_(1, onehot_idx, 1.0) + return x + + # Weights: [N C] + def forward(self, images, segmaps, weights=None): + # Assume images 256x256 + # bottom-up pathway + feat11 = self.enc1(images) # 128 + feat12 = self.enc2(feat11) # 64 + feat13 = self.enc3(feat12) # 32 + feat14 = self.enc4(feat13) # 16 + feat15 = self.enc5(feat14) # 8 + # top-down pathway and lateral connections + feat25 = self.lat5(feat15) # 8 + feat24 = self.upsample2x(feat25) + self.lat4(feat14) # 16 + feat23 = self.upsample2x(feat24) + self.lat3(feat13) # 32 + feat22 = self.upsample2x(feat23) + self.lat2(feat12) # 64 + + # final prediction layers + feat32 = self.final2(feat22) + + results = [] + label_map = self.interpolator(segmaps, size=feat32.size()[2:]) + pred2 = self.output(feat32) # N, num_labels+1, H//4, W//4 + + features = [feat11, feat12, feat13, feat14, feat15, feat25, feat24, feat23, feat22] + if weights is not None: + label_map = label_map * weights[..., None, None] + results.append({'pred': pred2, 'label': label_map}) + + if self.do_multiscale: + feat33 = self.final3(feat23) + pred3 = self.output3(feat33) + + feat34 = self.final4(feat24) + pred4 = self.output4(feat34) + + if self.no_label_except_largest_scale: + label_map3 = torch.ones([pred3.size(0), 1, pred3.size(2), pred3.size(3)], device=pred3.device) + label_map4 = torch.ones([pred4.size(0), 1, pred4.size(2), pred4.size(3)], device=pred4.device) + else: + label_map3 = self.interpolator(segmaps, size=pred3.size()[2:]) + label_map4 = self.interpolator(segmaps, size=pred4.size()[2:]) + + if weights is not None: + label_map3 = label_map3 * weights[..., None, None] + label_map4 = label_map4 * weights[..., None, None] + + results.append({'pred': pred3, 'label': label_map3}) + results.append({'pred': pred4, 'label': label_map4}) + + return results, features diff --git a/imaginaire/discriminators/mlp_multiclass.py b/imaginaire/discriminators/mlp_multiclass.py new file mode 100644 index 0000000000000000000000000000000000000000..9f7d1d27e2d21a7b4fd5f23646b545a44a47a783 --- /dev/null +++ b/imaginaire/discriminators/mlp_multiclass.py @@ -0,0 +1,63 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +import functools + +import numpy as np +import torch.nn as nn + +from imaginaire.layers import LinearBlock + + +class Discriminator(nn.Module): + r"""Multi-layer Perceptron Classifier constructor. + + Args: + dis_cfg (obj): Discriminator definition part of the yaml config file. 
+ data_cfg (obj): Data definition part of the yaml config file + """ + + def __init__(self, dis_cfg, data_cfg): + super(Discriminator, self).__init__() + num_input_channels = dis_cfg.input_dims + num_labels = dis_cfg.num_labels + num_layers = getattr(dis_cfg, 'num_layers', 5) + num_filters = getattr(dis_cfg, 'num_filters', 512) + activation_norm_type = getattr(dis_cfg, + 'activation_norm_type', + 'batch_norm') + nonlinearity = getattr(dis_cfg, 'nonlinearity', 'leakyrelu') + base_linear_block = \ + functools.partial(LinearBlock, + activation_norm_type=activation_norm_type, + nonlinearity=nonlinearity, + order='CNA') + dropout_ratio = 0.1 + layers = [base_linear_block(num_input_channels, num_filters), + nn.Dropout(dropout_ratio)] + for n in range(num_layers): + dropout_ratio *= 1.5 + dropout_ratio = np.min([dropout_ratio, 0.5]) + layers += [base_linear_block(num_filters, num_filters), + nn.Dropout(dropout_ratio)] + layers += [LinearBlock(num_filters, num_labels)] + self.model = nn.Sequential(*layers) + + def forward(self, data): + r"""Patch Discriminator forward. + + Args: + data (dict): + - data (N x -1 tensor): We will reshape the tensor to this format. + Returns: + (dict): + - results (N x C tensor): Output scores before softmax. + """ + input_x = data['data'] + bs = input_x.size()[0] + input_x = input_x.view(bs, -1) + pre_softmax_scores = self.model(input_x) + outputs = dict() + outputs['results'] = pre_softmax_scores + return outputs diff --git a/imaginaire/discriminators/multires_patch.py b/imaginaire/discriminators/multires_patch.py new file mode 100644 index 0000000000000000000000000000000000000000..7a31d25c689a3d5b26a3b92216715d3ebd62dc91 --- /dev/null +++ b/imaginaire/discriminators/multires_patch.py @@ -0,0 +1,313 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +# Copyright (C) 2020 NVIDIA Corporation. All rights reserved +import functools +import warnings + +import numpy as np +import torch +import torch.nn as nn + +from imaginaire.layers import Conv2dBlock +from imaginaire.utils.data import (get_paired_input_image_channel_number, + get_paired_input_label_channel_number) +from imaginaire.utils.distributed import master_only_print as print + + +class Discriminator(nn.Module): + r"""Multi-resolution patch discriminator. + + Args: + dis_cfg (obj): Discriminator definition part of the yaml config + file. + data_cfg (obj): Data definition part of the yaml config file. + """ + + def __init__(self, dis_cfg, data_cfg): + super(Discriminator, self).__init__() + print('Multi-resolution patch discriminator initialization.') + # We assume the first datum is the ground truth image. + image_channels = get_paired_input_image_channel_number(data_cfg) + # Calculate number of channels in the input label. + num_labels = get_paired_input_label_channel_number(data_cfg) + + # Build the discriminator. 
+ kernel_size = getattr(dis_cfg, 'kernel_size', 3) + num_filters = getattr(dis_cfg, 'num_filters', 128) + max_num_filters = getattr(dis_cfg, 'max_num_filters', 512) + num_discriminators = getattr(dis_cfg, 'num_discriminators', 2) + num_layers = getattr(dis_cfg, 'num_layers', 5) + activation_norm_type = getattr(dis_cfg, 'activation_norm_type', 'none') + weight_norm_type = getattr(dis_cfg, 'weight_norm_type', 'spectral') + print('\tBase filter number: %d' % num_filters) + print('\tNumber of discriminators: %d' % num_discriminators) + print('\tNumber of layers in a discriminator: %d' % num_layers) + print('\tWeight norm type: %s' % weight_norm_type) + num_input_channels = image_channels + num_labels + self.model = MultiResPatchDiscriminator(num_discriminators, + kernel_size, + num_input_channels, + num_filters, + num_layers, + max_num_filters, + activation_norm_type, + weight_norm_type) + print('Done with the Multi-resolution patch ' + 'discriminator initialization.') + + def forward(self, data, net_G_output, real=True): + r"""SPADE Generator forward. + + Args: + data (dict): + - data (N x C1 x H x W tensor) : Ground truth images. + - label (N x C2 x H x W tensor) : Semantic representations. + - z (N x style_dims tensor): Gaussian random noise. + net_G_output (dict): + fake_images (N x C1 x H x W tensor) : Fake images. + real (bool): If ``True``, also classifies real images. Otherwise it + only classifies generated images to save computation during the + generator update. + Returns: + (tuple): + - real_outputs (list): list of output tensors produced by + - individual patch discriminators for real images. + - real_features (list): list of lists of features produced by + individual patch discriminators for real images. + - fake_outputs (list): list of output tensors produced by + individual patch discriminators for fake images. + - fake_features (list): list of lists of features produced by + individual patch discriminators for fake images. + """ + output_x = dict() + if 'label' in data: + fake_input_x = torch.cat( + (data['label'], net_G_output['fake_images']), 1) + else: + fake_input_x = net_G_output['fake_images'] + output_x['fake_outputs'], output_x['fake_features'], _ = \ + self.model.forward(fake_input_x) + if real: + if 'label' in data: + real_input_x = torch.cat( + (data['label'], data['images']), 1) + else: + real_input_x = data['images'] + output_x['real_outputs'], output_x['real_features'], _ = \ + self.model.forward(real_input_x) + return output_x + + +class MultiResPatchDiscriminator(nn.Module): + r"""Multi-resolution patch discriminator. + + Args: + num_discriminators (int): Num. of discriminators (one per scale). + kernel_size (int): Convolution kernel size. + num_image_channels (int): Num. of channels in the real/fake image. + num_filters (int): Num. of base filters in a layer. + num_layers (int): Num. of layers for the patch discriminator. + max_num_filters (int): Maximum num. of filters in a layer. + activation_norm_type (str): batch_norm/instance_norm/none/.... 
+ weight_norm_type (str): none/spectral_norm/weight_norm + """ + + def __init__(self, + num_discriminators=3, + kernel_size=3, + num_image_channels=3, + num_filters=64, + num_layers=4, + max_num_filters=512, + activation_norm_type='', + weight_norm_type='', + **kwargs): + super().__init__() + for key in kwargs: + if key != 'type' and key != 'patch_wise': + warnings.warn( + "Discriminator argument {} is not used".format(key)) + + self.discriminators = nn.ModuleList() + for i in range(num_discriminators): + net_discriminator = NLayerPatchDiscriminator( + kernel_size, + num_image_channels, + num_filters, + num_layers, + max_num_filters, + activation_norm_type, + weight_norm_type) + self.discriminators.append(net_discriminator) + print('Done with the Multi-resolution patch ' + 'discriminator initialization.') + + def forward(self, input_x): + r"""Multi-resolution patch discriminator forward. + + Args: + input_x (tensor) : Input images. + Returns: + (tuple): + - output_list (list): list of output tensors produced by + individual patch discriminators. + - features_list (list): list of lists of features produced by + individual patch discriminators. + - input_list (list): list of downsampled input images. + """ + input_list = [] + output_list = [] + features_list = [] + input_downsampled = input_x + for net_discriminator in self.discriminators: + input_list.append(input_downsampled) + output, features = net_discriminator(input_downsampled) + output_list.append(output) + features_list.append(features) + input_downsampled = nn.functional.interpolate( + input_downsampled, scale_factor=0.5, mode='bilinear', + align_corners=True, recompute_scale_factor=True) + return output_list, features_list, input_list + + +class WeightSharedMultiResPatchDiscriminator(nn.Module): + r"""Multi-resolution patch discriminator with shared weights. + + Args: + num_discriminators (int): Num. of discriminators (one per scale). + kernel_size (int): Convolution kernel size. + num_image_channels (int): Num. of channels in the real/fake image. + num_filters (int): Num. of base filters in a layer. + num_layers (int): Num. of layers for the patch discriminator. + max_num_filters (int): Maximum num. of filters in a layer. + activation_norm_type (str): batch_norm/instance_norm/none/.... + weight_norm_type (str): none/spectral_norm/weight_norm + """ + + def __init__(self, + num_discriminators=3, + kernel_size=3, + num_image_channels=3, + num_filters=64, + num_layers=4, + max_num_filters=512, + activation_norm_type='', + weight_norm_type='', + **kwargs): + super().__init__() + for key in kwargs: + if key != 'type' and key != 'patch_wise': + warnings.warn( + "Discriminator argument {} is not used".format(key)) + self.num_discriminators = num_discriminators + self.discriminator = NLayerPatchDiscriminator( + kernel_size, + num_image_channels, + num_filters, + num_layers, + max_num_filters, + activation_norm_type, + weight_norm_type) + print('Done with the Weight-Shared Multi-resolution patch ' + 'discriminator initialization.') + + def forward(self, input_x): + r"""Multi-resolution patch discriminator forward. + + Args: + input_x (tensor) : Input images. + Returns: + (tuple): + - output_list (list): list of output tensors produced by + individual patch discriminators. + - features_list (list): list of lists of features produced by + individual patch discriminators. + - input_list (list): list of downsampled input images. 
+ """ + input_list = [] + output_list = [] + features_list = [] + input_downsampled = input_x + for i in range(self.num_discriminators): + input_list.append(input_downsampled) + output, features = self.discriminator(input_downsampled) + output_list.append(output) + features_list.append(features) + input_downsampled = nn.functional.interpolate( + input_downsampled, scale_factor=0.5, mode='bilinear', + align_corners=True) + return output_list, features_list, input_list + + +class NLayerPatchDiscriminator(nn.Module): + r"""Patch Discriminator constructor. + + Args: + kernel_size (int): Convolution kernel size. + num_input_channels (int): Num. of channels in the real/fake image. + num_filters (int): Num. of base filters in a layer. + num_layers (int): Num. of layers for the patch discriminator. + max_num_filters (int): Maximum num. of filters in a layer. + activation_norm_type (str): batch_norm/instance_norm/none/.... + weight_norm_type (str): none/spectral_norm/weight_norm + """ + + def __init__(self, + kernel_size, + num_input_channels, + num_filters, + num_layers, + max_num_filters, + activation_norm_type, + weight_norm_type): + super(NLayerPatchDiscriminator, self).__init__() + self.num_layers = num_layers + padding = int(np.floor((kernel_size - 1.0) / 2)) + nonlinearity = 'leakyrelu' + base_conv2d_block = \ + functools.partial(Conv2dBlock, + kernel_size=kernel_size, + padding=padding, + weight_norm_type=weight_norm_type, + activation_norm_type=activation_norm_type, + nonlinearity=nonlinearity, + # inplace_nonlinearity=True, + order='CNA') + layers = [[base_conv2d_block( + num_input_channels, num_filters, stride=2)]] + for n in range(num_layers): + num_filters_prev = num_filters + num_filters = min(num_filters * 2, max_num_filters) + stride = 2 if n < (num_layers - 1) else 1 + layers += [[base_conv2d_block(num_filters_prev, num_filters, + stride=stride)]] + layers += [[Conv2dBlock(num_filters, 1, + 3, 1, + padding, + weight_norm_type=weight_norm_type)]] + for n in range(len(layers)): + setattr(self, 'layer' + str(n), nn.Sequential(*layers[n])) + + def forward(self, input_x): + r"""Patch Discriminator forward. + + Args: + input_x (N x C x H1 x W2 tensor): Concatenation of images and + semantic representations. + Returns: + (tuple): + - output (N x 1 x H2 x W2 tensor): Discriminator output value. + Before the sigmoid when using NSGAN. + - features (list): lists of tensors of the intermediate + activations. + """ + res = [input_x] + for n in range(self.num_layers + 2): + layer = getattr(self, 'layer' + str(n)) + x = res[-1] + res.append(layer(x)) + output = res[-1] + features = res[1:-1] + return output, features diff --git a/imaginaire/discriminators/multires_patch_pano.py b/imaginaire/discriminators/multires_patch_pano.py new file mode 100644 index 0000000000000000000000000000000000000000..97763da3bbfc990be1755bbb62383869fd6708da --- /dev/null +++ b/imaginaire/discriminators/multires_patch_pano.py @@ -0,0 +1,247 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +# Copyright (C) 2020 NVIDIA Corporation. 
All rights reserved +import functools +import warnings + +import numpy as np +import torch +import torch.nn as nn + +from imaginaire.layers import Conv2dBlock +from imaginaire.utils.data import (get_paired_input_image_channel_number, + get_paired_input_label_channel_number) +from imaginaire.utils.distributed import master_only_print as print +from model.sample import Equirectangular + +class Discriminator(nn.Module): + r"""Multi-resolution patch discriminator. + + Args: + dis_cfg (obj): Discriminator definition part of the yaml config + file. + data_cfg (obj): Data definition part of the yaml config file. + """ + + def __init__(self, dis_cfg): + super(Discriminator, self).__init__() + print('Multi-resolution patch discriminator initialization.') + # We assume the first datum is the ground truth image. + num_input_channels = getattr(dis_cfg, 'input_channels', 3) + # Calculate number of channels in the input label. + + # Build the discriminator. + kernel_size = getattr(dis_cfg, 'kernel_size', 3) + num_filters = getattr(dis_cfg, 'num_filters', 128) + max_num_filters = getattr(dis_cfg, 'max_num_filters', 512) + num_discriminators = getattr(dis_cfg, 'num_discriminators', 2) + num_layers = getattr(dis_cfg, 'num_layers', 5) + activation_norm_type = getattr(dis_cfg, 'activation_norm_type', 'none') + weight_norm_type = getattr(dis_cfg, 'weight_norm_type', 'spectral') + print('\tBase filter number: %d' % num_filters) + print('\tNumber of discriminators: %d' % num_discriminators) + print('\tNumber of layers in a discriminator: %d' % num_layers) + print('\tWeight norm type: %s' % weight_norm_type) + self.condition = getattr(dis_cfg, 'condition', None) + # self.condition = dis_cfg.condition + self.model = MultiResPatchDiscriminator(num_discriminators, + kernel_size, + num_input_channels, + num_filters, + num_layers, + max_num_filters, + activation_norm_type, + weight_norm_type) + print('Done with the Multi-resolution patch ' + 'discriminator initialization.') + + def forward(self, data, net_G_output, real=True): + r"""SPADE Generator forward. + + Args: + data (N x C1 x H x W tensor) : Ground truth images. + net_G_output (dict): + fake_images (N x C1 x H x W tensor) : Fake images. + real (bool): If ``True``, also classifies real images. Otherwise it + only classifies generated images to save computation during the + generator update. + Returns: + (tuple): + - real_outputs (list): list of output tensors produced by + - individual patch discriminators for real images. + - real_features (list): list of lists of features produced by + individual patch discriminators for real images. + - fake_outputs (list): list of output tensors produced by + individual patch discriminators for fake images. + - fake_features (list): list of lists of features produced by + individual patch discriminators for fake images. + """ + output_x = dict() + if self.condition: + fake_input_x = torch.cat([net_G_output['pred'],net_G_output['generator_inputs']],dim=1) + else: + fake_input_x = net_G_output['pred'] + output_x['fake_outputs'], output_x['fake_features'], _ = \ + self.model.forward(fake_input_x) + if real: + if self.condition: + real_input_x = torch.cat([net_G_output['pred'],net_G_output['generator_inputs']],dim=1) + else: + real_input_x = data + output_x['real_outputs'], output_x['real_features'], _ = \ + self.model.forward(real_input_x) + return output_x + + +class MultiResPatchDiscriminator(nn.Module): + r"""Multi-resolution patch discriminator. + + Args: + num_discriminators (int): Num. of discriminators (one per scale). 
+ kernel_size (int): Convolution kernel size. + num_image_channels (int): Num. of channels in the real/fake image. + num_filters (int): Num. of base filters in a layer. + num_layers (int): Num. of layers for the patch discriminator. + max_num_filters (int): Maximum num. of filters in a layer. + activation_norm_type (str): batch_norm/instance_norm/none/.... + weight_norm_type (str): none/spectral_norm/weight_norm + """ + + def __init__(self, + num_discriminators=3, + kernel_size=3, + num_image_channels=3, + num_filters=64, + num_layers=4, + max_num_filters=512, + activation_norm_type='', + weight_norm_type='', + **kwargs): + super().__init__() + for key in kwargs: + if key != 'type' and key != 'patch_wise': + warnings.warn( + "Discriminator argument {} is not used".format(key)) + + self.discriminators = nn.ModuleList() + for i in range(num_discriminators): + net_discriminator = NLayerPatchDiscriminator( + kernel_size, + num_image_channels, + num_filters, + num_layers, + max_num_filters, + activation_norm_type, + weight_norm_type) + self.discriminators.append(net_discriminator) + print('Done with the Multi-resolution patch ' + 'discriminator initialization.') + self.e = Equirectangular(theta=[-40., 40.],width = 128, height = 128,FovX = 100) + + def forward(self, input_x): + r"""Multi-resolution patch discriminator forward. + + Args: + input_x (tensor) : Input images. + Returns: + (tuple): + - output_list (list): list of output tensors produced by + individual patch discriminators. + - features_list (list): list of lists of features produced by + individual patch discriminators. + - input_list (list): list of downsampled input images. + """ + input_list = [] + output_list = [] + features_list = [] + input_N = nn.functional.interpolate( + input_x, scale_factor=0.5, mode='bilinear', + align_corners=True, recompute_scale_factor=True) + equ= self.e(input_x) + for i, net_discriminator in enumerate(self.discriminators): + input_list.append(input_N) + output, features = net_discriminator(input_N) + output_list.append(output) + features_list.append(features) + if i == 0: + input_N = torch.nn.functional.grid_sample(input_x, equ.float(), align_corners = True)*0.99 + elif i == 1: + input_N = nn.functional.interpolate( + input_N, scale_factor=0.5, mode='bilinear', + align_corners=True, recompute_scale_factor=True) + + return output_list, features_list, input_list + +class NLayerPatchDiscriminator(nn.Module): + r"""Patch Discriminator constructor. + + Args: + kernel_size (int): Convolution kernel size. + num_input_channels (int): Num. of channels in the real/fake image. + num_filters (int): Num. of base filters in a layer. + num_layers (int): Num. of layers for the patch discriminator. + max_num_filters (int): Maximum num. of filters in a layer. + activation_norm_type (str): batch_norm/instance_norm/none/.... 
+ weight_norm_type (str): none/spectral_norm/weight_norm + """ + + def __init__(self, + kernel_size, + num_input_channels, + num_filters, + num_layers, + max_num_filters, + activation_norm_type, + weight_norm_type): + super(NLayerPatchDiscriminator, self).__init__() + self.num_layers = num_layers + padding = int(np.floor((kernel_size - 1.0) / 2)) + nonlinearity = 'leakyrelu' + base_conv2d_block = \ + functools.partial(Conv2dBlock, + kernel_size=kernel_size, + padding=padding, + weight_norm_type=weight_norm_type, + activation_norm_type=activation_norm_type, + nonlinearity=nonlinearity, + # inplace_nonlinearity=True, + order='CNA') + layers = [[base_conv2d_block( + num_input_channels, num_filters, stride=2)]] + for n in range(num_layers): + num_filters_prev = num_filters + num_filters = min(num_filters * 2, max_num_filters) + stride = 2 if n < (num_layers - 1) else 1 + layers += [[base_conv2d_block(num_filters_prev, num_filters, + stride=stride)]] + layers += [[Conv2dBlock(num_filters, 1, + 3, 1, + padding, + weight_norm_type=weight_norm_type)]] + for n in range(len(layers)): + setattr(self, 'layer' + str(n), nn.Sequential(*layers[n])) + + + def forward(self, input_x): + r"""Patch Discriminator forward. + + Args: + input_x (N x C x H1 x W2 tensor): Concatenation of images and + semantic representations. + Returns: + (tuple): + - output (N x 1 x H2 x W2 tensor): Discriminator output value. + Before the sigmoid when using NSGAN. + - features (list): lists of tensors of the intermediate + activations. + """ + res = [input_x] + for n in range(self.num_layers + 2): + layer = getattr(self, 'layer' + str(n)) + x = res[-1] + res.append(layer(x)) + output = res[-1] + features = res[1:-1] + return output, features diff --git a/imaginaire/discriminators/munit.py b/imaginaire/discriminators/munit.py new file mode 100644 index 0000000000000000000000000000000000000000..4e407569764dced0962e66c56a1a0b2e9106c683 --- /dev/null +++ b/imaginaire/discriminators/munit.py @@ -0,0 +1,99 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +from torch import nn + +from imaginaire.discriminators.multires_patch import MultiResPatchDiscriminator +from imaginaire.discriminators.residual import ResDiscriminator + + +class Discriminator(nn.Module): + r"""MUNIT discriminator. It can be either a multi-resolution patch + discriminator like in the original implementation, or a + global residual discriminator. + + Args: + dis_cfg (obj): Discriminator definition part of the yaml config file. + data_cfg (obj): Data definition part of the yaml config file + """ + + def __init__(self, dis_cfg, data_cfg): + super().__init__() + if getattr(dis_cfg, 'patch_wise', True): + # Use the multi-resolution patch discriminator. It works better for + # scene images and when you want to preserve pixel-wise + # correspondence during translation. + self.discriminator_a = \ + MultiResPatchDiscriminator(**vars(dis_cfg)) + self.discriminator_b = \ + MultiResPatchDiscriminator(**vars(dis_cfg)) + else: + # Use the global residual discriminator. It works better if images + # have a single centered object (e.g., animal faces, shoes). + self.discriminator_a = ResDiscriminator(**vars(dis_cfg)) + self.discriminator_b = ResDiscriminator(**vars(dis_cfg)) + + def forward(self, data, net_G_output, gan_recon=False, real=True): + r"""Returns the output of the discriminator. 
+ + Args: + data (dict): + - images_a (tensor) : Images in domain A. + - images_b (tensor) : Images in domain B. + net_G_output (dict): + - images_ab (tensor) : Images translated from domain A to B by + the generator. + - images_ba (tensor) : Images translated from domain B to A by + the generator. + - images_aa (tensor) : Reconstructed images in domain A. + - images_bb (tensor) : Reconstructed images in domain B. + gan_recon (bool): If ``True``, also classifies reconstructed images. + real (bool): If ``True``, also classifies real images. Otherwise it + only classifies generated images to save computation during the + generator update. + + Returns: + (dict): + - out_ab (tensor): Output of the discriminator for images + translated from domain A to B by the generator. + - out_ab (tensor): Output of the discriminator for images + translated from domain B to A by the generator. + - fea_ab (tensor): Intermediate features of the discriminator + for images translated from domain B to A by the generator. + - fea_ba (tensor): Intermediate features of the discriminator + for images translated from domain A to B by the generator. + + - out_a (tensor): Output of the discriminator for images + in domain A. + - out_b (tensor): Output of the discriminator for images + in domain B. + - fea_a (tensor): Intermediate features of the discriminator + for images in domain A. + - fea_b (tensor): Intermediate features of the discriminator + for images in domain B. + + - out_aa (tensor): Output of the discriminator for + reconstructed images in domain A. + - out_bb (tensor): Output of the discriminator for + reconstructed images in domain B. + - fea_aa (tensor): Intermediate features of the discriminator + for reconstructed images in domain A. + - fea_bb (tensor): Intermediate features of the discriminator + for reconstructed images in domain B. + """ + out_ab, fea_ab, _ = self.discriminator_b(net_G_output['images_ab']) + out_ba, fea_ba, _ = self.discriminator_a(net_G_output['images_ba']) + output = dict(out_ba=out_ba, out_ab=out_ab, + fea_ba=fea_ba, fea_ab=fea_ab) + if real: + out_a, fea_a, _ = self.discriminator_a(data['images_a']) + out_b, fea_b, _ = self.discriminator_b(data['images_b']) + output.update(dict(out_a=out_a, out_b=out_b, + fea_a=fea_a, fea_b=fea_b)) + if gan_recon: + out_aa, fea_aa, _ = self.discriminator_a(net_G_output['images_aa']) + out_bb, fea_bb, _ = self.discriminator_b(net_G_output['images_bb']) + output.update(dict(out_aa=out_aa, out_bb=out_bb, + fea_aa=fea_aa, fea_bb=fea_bb)) + return output diff --git a/imaginaire/discriminators/residual.py b/imaginaire/discriminators/residual.py new file mode 100644 index 0000000000000000000000000000000000000000..f65b41df96f8ea25e82d18adf33cabef746f863c --- /dev/null +++ b/imaginaire/discriminators/residual.py @@ -0,0 +1,96 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +import warnings + +import torch +import torch.nn as nn + +from imaginaire.layers import Conv2dBlock, Res2dBlock +from imaginaire.third_party.upfirdn2d import BlurDownsample + + +class ResDiscriminator(nn.Module): + r"""Global residual discriminator. + + Args: + image_channels (int): Num. of channels in the real/fake image. + num_filters (int): Num. of base filters in a layer. + max_num_filters (int): Maximum num. of filters in a layer. + first_kernel_size (int): Kernel size in the first layer. + num_layers (int): Num. 
of layers in discriminator. + padding_mode (str): Padding mode. + activation_norm_type (str): Type of activation normalization. + ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``. + weight_norm_type (str): Type of weight normalization. + ``'none'``, ``'spectral'``, or ``'weight'``. + aggregation (str): Method to aggregate features across different + locations in the final layer. ``'conv'``, or ``'pool'``. + order (str): Order of operations in the residual link. + anti_aliased (bool): If ``True``, uses anti-aliased pooling. + """ + + def __init__(self, + image_channels=3, + num_filters=64, + max_num_filters=512, + first_kernel_size=1, + num_layers=4, + padding_mode='zeros', + activation_norm_type='', + weight_norm_type='', + aggregation='conv', + order='pre_act', + anti_aliased=False, + **kwargs): + super().__init__() + for key in kwargs: + if key != 'type' and key != 'patch_wise': + warnings.warn( + "Discriminator argument {} is not used".format(key)) + + conv_params = dict(padding_mode=padding_mode, + activation_norm_type=activation_norm_type, + weight_norm_type=weight_norm_type, + nonlinearity='leakyrelu') + + first_padding = (first_kernel_size - 1) // 2 + model = [Conv2dBlock(image_channels, num_filters, + first_kernel_size, 1, first_padding, + **conv_params)] + for _ in range(num_layers): + num_filters_prev = num_filters + num_filters = min(num_filters * 2, max_num_filters) + model.append(Res2dBlock(num_filters_prev, num_filters, order=order, + **conv_params)) + if anti_aliased: + model.append(BlurDownsample()) + else: + model.append(nn.AvgPool2d(2, stride=2)) + if aggregation == 'pool': + model += [torch.nn.AdaptiveAvgPool2d(1)] + elif aggregation == 'conv': + model += [Conv2dBlock(num_filters, num_filters, 4, 1, 0, + nonlinearity='leakyrelu')] + else: + raise ValueError('The aggregation mode is not recognized' + % self.aggregation) + self.model = nn.Sequential(*model) + self.classifier = nn.Linear(num_filters, 1) + + def forward(self, images): + r"""Multi-resolution patch discriminator forward. + + Args: + images (tensor) : Input images. + Returns: + (tuple): + - outputs (tensor): Output of the discriminator. + - features (tensor): Intermediate features of the discriminator. + - images (tensor): Input images. + """ + batch_size = images.size(0) + features = self.model(images) + outputs = self.classifier(features.view(batch_size, -1)) + return outputs, features, images diff --git a/imaginaire/discriminators/spade.py b/imaginaire/discriminators/spade.py new file mode 100644 index 0000000000000000000000000000000000000000..d85d1c5926ce92e948501c7afaf1d1324e1ea38c --- /dev/null +++ b/imaginaire/discriminators/spade.py @@ -0,0 +1,119 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +import torch +import torch.nn as nn + +from imaginaire.discriminators.fpse import FPSEDiscriminator +from imaginaire.discriminators.multires_patch import NLayerPatchDiscriminator +from imaginaire.utils.data import (get_paired_input_image_channel_number, + get_paired_input_label_channel_number) +from imaginaire.utils.distributed import master_only_print as print + + +class Discriminator(nn.Module): + r"""Multi-resolution patch discriminator. + + Args: + dis_cfg (obj): Discriminator definition part of the yaml config file. + data_cfg (obj): Data definition part of the yaml config file. 
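# Illustrative sketch (not from the original file): running the global
# ResDiscriminator above with its default arguments. With these defaults the
# input is downsampled 16x and then aggregated by a 4x4 conv, so a 64x64 input
# is a convenient example size; larger inputs need a larger num_layers.
import torch
from imaginaire.discriminators.residual import ResDiscriminator

res_dis = ResDiscriminator(image_channels=3, num_filters=64, num_layers=4)
images = torch.randn(4, 3, 64, 64)       # real or generated RGB images in [-1, 1]
logits, features, _ = res_dis(images)    # forward() also echoes the input images
print(logits.shape)                      # torch.Size([4, 1]) -- one logit per image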
+ """ + + def __init__(self, dis_cfg, data_cfg): + super(Discriminator, self).__init__() + print('Multi-resolution patch discriminator initialization.') + image_channels = getattr(dis_cfg, 'image_channels', None) + if image_channels is None: + image_channels = get_paired_input_image_channel_number(data_cfg) + num_labels = getattr(dis_cfg, 'num_labels', None) + if num_labels is None: + # Calculate number of channels in the input label when not specified. + num_labels = get_paired_input_label_channel_number(data_cfg) + + # Build the discriminator. + kernel_size = getattr(dis_cfg, 'kernel_size', 3) + num_filters = getattr(dis_cfg, 'num_filters', 128) + max_num_filters = getattr(dis_cfg, 'max_num_filters', 512) + num_discriminators = getattr(dis_cfg, 'num_discriminators', 2) + num_layers = getattr(dis_cfg, 'num_layers', 5) + activation_norm_type = getattr(dis_cfg, 'activation_norm_type', 'none') + weight_norm_type = getattr(dis_cfg, 'weight_norm_type', 'spectral') + print('\tBase filter number: %d' % num_filters) + print('\tNumber of discriminators: %d' % num_discriminators) + print('\tNumber of layers in a discriminator: %d' % num_layers) + print('\tWeight norm type: %s' % weight_norm_type) + num_input_channels = image_channels + num_labels + self.discriminators = nn.ModuleList() + for i in range(num_discriminators): + net_discriminator = NLayerPatchDiscriminator( + kernel_size, + num_input_channels, + num_filters, + num_layers, + max_num_filters, + activation_norm_type, + weight_norm_type) + self.discriminators.append(net_discriminator) + print('Done with the Multi-resolution patch discriminator initialization.') + self.use_fpse = getattr(dis_cfg, 'use_fpse', True) + if self.use_fpse: + fpse_kernel_size = getattr(dis_cfg, 'fpse_kernel_size', 3) + fpse_activation_norm_type = getattr(dis_cfg, + 'fpse_activation_norm_type', + 'none') + self.fpse_discriminator = FPSEDiscriminator( + image_channels, + num_labels, + num_filters, + fpse_kernel_size, + weight_norm_type, + fpse_activation_norm_type) + + def _single_forward(self, input_label, input_image): + # Compute discriminator outputs and intermediate features from input + # images and semantic labels. + input_x = torch.cat( + (input_label, input_image), 1) + output_list = [] + features_list = [] + if self.use_fpse: + pred2, pred3, pred4 = self.fpse_discriminator(input_image, input_label) + output_list = [pred2, pred3, pred4] + input_downsampled = input_x + for net_discriminator in self.discriminators: + output, features = net_discriminator(input_downsampled) + output_list.append(output) + features_list.append(features) + input_downsampled = nn.functional.interpolate( + input_downsampled, scale_factor=0.5, mode='bilinear', + align_corners=True) + return output_list, features_list + + def forward(self, data, net_G_output): + r"""SPADE discriminator forward. + + Args: + data (dict): + - data (N x C1 x H x W tensor) : Ground truth images. + - label (N x C2 x H x W tensor) : Semantic representations. + - z (N x style_dims tensor): Gaussian random noise. + net_G_output (dict): + fake_images (N x C1 x H x W tensor) : Fake images. + Returns: + (dict): + - real_outputs (list): list of output tensors produced by + individual patch discriminators for real images. + - real_features (list): list of lists of features produced by + individual patch discriminators for real images. + - fake_outputs (list): list of output tensors produced by + individual patch discriminators for fake images. 
+ - fake_features (list): list of lists of features produced by + individual patch discriminators for fake images. + """ + output_x = dict() + output_x['real_outputs'], output_x['real_features'] = \ + self._single_forward(data['label'], data['images']) + output_x['fake_outputs'], output_x['fake_features'] = \ + self._single_forward(data['label'], net_G_output['fake_images']) + return output_x diff --git a/imaginaire/discriminators/unit.py b/imaginaire/discriminators/unit.py new file mode 100644 index 0000000000000000000000000000000000000000..fb537fe47ea17051a012f7e647c5ac6b6ea1b7d9 --- /dev/null +++ b/imaginaire/discriminators/unit.py @@ -0,0 +1,99 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +from torch import nn + +from imaginaire.discriminators.multires_patch import \ + WeightSharedMultiResPatchDiscriminator +from imaginaire.discriminators.residual import ResDiscriminator + + +class Discriminator(nn.Module): + r"""UNIT discriminator. It can be either a multi-resolution patch + discriminator like in the original implementation, or a + global residual discriminator. + + Args: + dis_cfg (obj): Discriminator definition part of the yaml config file. + data_cfg (obj): Data definition part of the yaml config file + """ + + def __init__(self, dis_cfg, data_cfg): + super().__init__() + if getattr(dis_cfg, 'patch_dis', True): + # Use the multi-resolution patch discriminator. It works better for + # scene images and when you want to preserve pixel-wise + # correspondence during translation. + self.discriminator_a = \ + WeightSharedMultiResPatchDiscriminator(**vars(dis_cfg)) + self.discriminator_b = \ + WeightSharedMultiResPatchDiscriminator(**vars(dis_cfg)) + else: + # Use the global residual discriminator. It works better if images + # have a single centered object (e.g., animal faces, shoes). + self.discriminator_a = ResDiscriminator(**vars(dis_cfg)) + self.discriminator_b = ResDiscriminator(**vars(dis_cfg)) + + def forward(self, data, net_G_output, gan_recon=False, real=True): + r"""Returns the output of the discriminator. + + Args: + data (dict): + - images_a (tensor) : Images in domain A. + - images_b (tensor) : Images in domain B. + net_G_output (dict): + - images_ab (tensor) : Images translated from domain A to B by + the generator. + - images_ba (tensor) : Images translated from domain B to A by + the generator. + - images_aa (tensor) : Reconstructed images in domain A. + - images_bb (tensor) : Reconstructed images in domain B. + gan_recon (bool): If ``True``, also classifies reconstructed images. + real (bool): If ``True``, also classifies real images. Otherwise it + only classifies generated images to save computation during the + generator update. + Returns: + (dict): + - out_ab (tensor): Output of the discriminator for images + translated from domain A to B by the generator. + - out_ab (tensor): Output of the discriminator for images + translated from domain B to A by the generator. + - fea_ab (tensor): Intermediate features of the discriminator + for images translated from domain B to A by the generator. + - fea_ba (tensor): Intermediate features of the discriminator + for images translated from domain A to B by the generator. + + - out_a (tensor): Output of the discriminator for images + in domain A. + - out_b (tensor): Output of the discriminator for images + in domain B. 
+ - fea_a (tensor): Intermediate features of the discriminator + for images in domain A. + - fea_b (tensor): Intermediate features of the discriminator + for images in domain B. + + - out_aa (tensor): Output of the discriminator for + reconstructed images in domain A. + - out_bb (tensor): Output of the discriminator for + reconstructed images in domain B. + - fea_aa (tensor): Intermediate features of the discriminator + for reconstructed images in domain A. + - fea_bb (tensor): Intermediate features of the discriminator + for reconstructed images in domain B. + """ + out_ab, fea_ab, _ = self.discriminator_b(net_G_output['images_ab']) + out_ba, fea_ba, _ = self.discriminator_a(net_G_output['images_ba']) + output = dict(out_ba=out_ba, out_ab=out_ab, + fea_ba=fea_ba, fea_ab=fea_ab) + if real: + out_a, fea_a, _ = self.discriminator_a(data['images_a']) + out_b, fea_b, _ = self.discriminator_b(data['images_b']) + output.update(dict(out_a=out_a, out_b=out_b, + fea_a=fea_a, fea_b=fea_b)) + if gan_recon: + out_aa, fea_aa, _ = self.discriminator_a(net_G_output['images_aa']) + out_bb, fea_bb, _ = self.discriminator_b(net_G_output['images_bb']) + output.update(dict(out_aa=out_aa, out_bb=out_bb, + fea_aa=fea_aa, fea_bb=fea_bb)) + return output diff --git a/imaginaire/evaluation/__init__.py b/imaginaire/evaluation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8a379a7be63550921dcd5802e7621196df9de9e1 --- /dev/null +++ b/imaginaire/evaluation/__init__.py @@ -0,0 +1,11 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +from .fid import compute_fid, compute_fid_data +from .kid import compute_kid, compute_kid_data +from .prdc import compute_prdc +from .common import compute_all_metrics, compute_all_metrics_data + +__all__ = ['compute_fid', 'compute_fid_data', 'compute_kid', 'compute_kid_data', + 'compute_prdc', 'compute_all_metrics', 'compute_all_metrics_data'] diff --git a/imaginaire/evaluation/caption/__init__.py b/imaginaire/evaluation/caption/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3197a51d0318c393a6659495d97914b1192f587a --- /dev/null +++ b/imaginaire/evaluation/caption/__init__.py @@ -0,0 +1,8 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +from .r_precision import get_r_precision +from .common import get_image_encoder + +__all__ = ['get_image_encoder', 'get_r_precision'] diff --git a/imaginaire/evaluation/caption/clip.py b/imaginaire/evaluation/caption/clip.py new file mode 100644 index 0000000000000000000000000000000000000000..1cdcd94e80afcb310564d9d670c3f73f5e065707 --- /dev/null +++ b/imaginaire/evaluation/caption/clip.py @@ -0,0 +1,576 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. 
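# Illustrative sketch (not from the original files): the MUNIT/UNIT discriminator
# wrappers above are built from a parsed yaml config object. A SimpleNamespace is
# used here as a stand-in config; the attribute names follow the getattr() calls
# in those classes, and everything else falls back to library defaults.
from types import SimpleNamespace
import torch
from imaginaire.discriminators.munit import Discriminator as MunitDiscriminator

dis_cfg = SimpleNamespace(patch_wise=False,   # False -> one global ResDiscriminator per domain
                          image_channels=3, num_filters=64, num_layers=4)
net_D = MunitDiscriminator(dis_cfg, data_cfg=None)   # data_cfg is never read by this class

data = {'images_a': torch.randn(2, 3, 64, 64), 'images_b': torch.randn(2, 3, 64, 64)}
net_G_output = {'images_ab': torch.randn(2, 3, 64, 64), 'images_ba': torch.randn(2, 3, 64, 64)}
dis_out = net_D(data, net_G_output)   # dict with out_ab/out_ba/fea_ab/fea_ba plus out_a/out_b/fea_a/fea_b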
+# To view a copy of this license, check out LICENSE.md +# flake8: noqa +# https://github.com/openai/CLIP +import hashlib +import os +import urllib +import warnings +from time import sleep +from typing import Union, List + + +from collections import OrderedDict +from typing import Tuple, Union + +import torch +import numpy as np +import torch.nn.functional as F +from torch import nn +from PIL import Image +from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, \ + Normalize +from tqdm import tqdm + +__all__ = ["available_models", "load", 'build_model'] + +from imaginaire.utils.io import download_file_from_google_drive + +_MODELS = { + "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", + "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", + "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", + "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", +} + + +def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")): + os.makedirs(root, exist_ok=True) + filename = os.path.basename(url) + + expected_sha256 = url.split("/")[-2] + download_target = os.path.join(root, filename) + + if os.path.exists(download_target) and not os.path.isfile(download_target): + raise RuntimeError( + f"{download_target} exists and is not a regular file") + + if os.path.isfile(download_target): + if hashlib.sha256(open(download_target, + "rb").read()).hexdigest() == expected_sha256: + return download_target + else: + warnings.warn( + f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") + + with urllib.request.urlopen(url) as source, open(download_target, + "wb") as output: + with tqdm(total=int(source.info().get("Content-Length")), ncols=80, + unit='iB', unit_scale=True) as loop: + while True: + buffer = source.read(8192) + if not buffer: + break + + output.write(buffer) + loop.update(len(buffer)) + + if hashlib.sha256( + open(download_target, "rb").read()).hexdigest() != expected_sha256: + raise RuntimeError( + f"Model has been downloaded but the SHA256 checksum does not not match") + + return download_target + + +def _transform(n_px): + return Compose([ + Resize(n_px, interpolation=Image.BICUBIC), + CenterCrop(n_px), + lambda image: image.convert("RGB"), + ToTensor(), + Normalize((0.48145466, 0.4578275, 0.40821073), + (0.26862954, 0.26130258, 0.27577711)), + ]) + + +def available_models() -> List[str]: + """Returns the names of available CLIP models""" + return list(_MODELS.keys()) + + +def load(model_path): + if not os.path.exists(model_path): + downloaded = False + while not downloaded: + try: + download_file_from_google_drive("1Ri5APYM34A_IjG4F3Admutsf2oUwDjfW", model_path) + downloaded = True + except Exception as e: + print(e) + sleep(30) + continue + model = torch.load(model_path, map_location='cpu') + model = build_model(model).cuda() + return model, _transform(model.visual.input_resolution) + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1): + super().__init__() + + # all conv layers have stride 1. 
an avgpool is performed after the second convolution when stride > 1 + self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + + self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + + self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() + + self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + + self.relu = nn.ReLU(inplace=True) + self.downsample = None + self.stride = stride + + if stride > 1 or inplanes != planes * Bottleneck.expansion: + # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 + self.downsample = nn.Sequential(OrderedDict([ + ("-1", nn.AvgPool2d(stride)), + ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, + bias=False)), + ("1", nn.BatchNorm2d(planes * self.expansion)) + ])) + + def forward(self, x: torch.Tensor): + identity = x + + out = self.relu(self.bn1(self.conv1(x))) + out = self.relu(self.bn2(self.conv2(out))) + out = self.avgpool(out) + out = self.bn3(self.conv3(out)) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + return out + + +class AttentionPool2d(nn.Module): + def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, + output_dim: int = None): + super().__init__() + self.positional_embedding = nn.Parameter( + torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) + self.k_proj = nn.Linear(embed_dim, embed_dim) + self.q_proj = nn.Linear(embed_dim, embed_dim) + self.v_proj = nn.Linear(embed_dim, embed_dim) + self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) + self.num_heads = num_heads + + def forward(self, x): + x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute( + 2, 0, 1) # NCHW -> (HW)NC + x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC + x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC + x, _ = F.multi_head_attention_forward( + query=x, key=x, value=x, + embed_dim_to_check=x.shape[-1], + num_heads=self.num_heads, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + in_proj_weight=None, + in_proj_bias=torch.cat( + [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), + bias_k=None, + bias_v=None, + add_zero_attn=False, + dropout_p=0, + out_proj_weight=self.c_proj.weight, + out_proj_bias=self.c_proj.bias, + use_separate_proj_weight=True, + training=self.training, + need_weights=False + ) + + return x[0] + + +class ModifiedResNet(nn.Module): + """ + A ResNet class that is similar to torchvision's but contains the following changes: + - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 
+ - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 + - The final pooling layer is a QKV attention instead of an average pool + """ + + def __init__(self, layers, output_dim, heads, input_resolution=224, + width=64): + super().__init__() + self.output_dim = output_dim + self.input_resolution = input_resolution + + # the 3-layer stem + self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, + padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(width // 2) + self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, + bias=False) + self.bn2 = nn.BatchNorm2d(width // 2) + self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, + bias=False) + self.bn3 = nn.BatchNorm2d(width) + self.avgpool = nn.AvgPool2d(2) + self.relu = nn.ReLU(inplace=True) + + # residual layers + self._inplanes = width # this is a *mutable* variable used during construction + self.layer1 = self._make_layer(width, layers[0]) + self.layer2 = self._make_layer(width * 2, layers[1], stride=2) + self.layer3 = self._make_layer(width * 4, layers[2], stride=2) + self.layer4 = self._make_layer(width * 8, layers[3], stride=2) + + embed_dim = width * 32 # the ResNet feature dimension + self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, + heads, output_dim) + + def _make_layer(self, planes, blocks, stride=1): + layers = [Bottleneck(self._inplanes, planes, stride)] + + self._inplanes = planes * Bottleneck.expansion + for _ in range(1, blocks): + layers.append(Bottleneck(self._inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + def stem(x): + for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), + (self.conv3, self.bn3)]: + x = self.relu(bn(conv(x))) + x = self.avgpool(x) + return x + + x = x.type(self.conv1.weight.dtype) + x = stem(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + x = self.attnpool(x) + + return x + + +class LayerNorm(nn.LayerNorm): + """Subclass torch's LayerNorm to handle fp16.""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + ret = super().forward(x.type(torch.float32)) + return ret.type(orig_type) + + +class QuickGELU(nn.Module): + def forward(self, x: torch.Tensor): + return x * torch.sigmoid(1.702 * x) + + +class ResidualAttentionBlock(nn.Module): + def __init__(self, d_model: int, n_head: int, + attn_mask: torch.Tensor = None): + super().__init__() + + self.attn = nn.MultiheadAttention(d_model, n_head) + self.ln_1 = LayerNorm(d_model) + self.mlp = nn.Sequential(OrderedDict([ + ("c_fc", nn.Linear(d_model, d_model * 4)), + ("gelu", QuickGELU()), + ("c_proj", nn.Linear(d_model * 4, d_model)) + ])) + self.ln_2 = LayerNorm(d_model) + self.attn_mask = attn_mask + + def attention(self, x: torch.Tensor): + self.attn_mask = self.attn_mask.to(dtype=x.dtype, + device=x.device) if self.attn_mask is not None else None + return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[ + 0] + + def forward(self, x: torch.Tensor): + x = x + self.attention(self.ln_1(x)) + x = x + self.mlp(self.ln_2(x)) + return x + + +class Transformer(nn.Module): + def __init__(self, width: int, layers: int, heads: int, + attn_mask: torch.Tensor = None): + super().__init__() + self.width = width + self.layers = layers + self.resblocks = nn.Sequential( + *[ResidualAttentionBlock(width, heads, attn_mask) for _ in + range(layers)]) + + def forward(self, x: torch.Tensor): + return self.resblocks(x) + + +class VisualTransformer(nn.Module): + 
def __init__(self, input_resolution: int, patch_size: int, width: int, + layers: int, heads: int, output_dim: int): + super().__init__() + self.input_resolution = input_resolution + self.output_dim = output_dim + self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, + kernel_size=patch_size, stride=patch_size, + bias=False) + + scale = width ** -0.5 + self.class_embedding = nn.Parameter(scale * torch.randn(width)) + self.positional_embedding = nn.Parameter( + scale * torch.randn((input_resolution // patch_size) ** 2 + 1, + width)) + self.ln_pre = LayerNorm(width) + + self.transformer = Transformer(width, layers, heads) + + self.ln_post = LayerNorm(width) + self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) + + def forward(self, x: torch.Tensor): + x = self.conv1(x) # shape = [*, width, grid, grid] + x = x.reshape(x.shape[0], x.shape[1], + -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros( + x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], + dim=1) # shape = [*, grid ** 2 + 1, width] + x = x + self.positional_embedding.to(x.dtype) + x = self.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + + x = self.ln_post(x[:, 0, :]) + + if self.proj is not None: + x = x @ self.proj + + return x + + +class CLIP(nn.Module): + def __init__(self, + embed_dim: int, + # vision + image_resolution: int, + vision_layers: Union[Tuple[int, int, int, int], int], + vision_width: int, + vision_patch_size: int, + # text + context_length: int, + vocab_size: int, + transformer_width: int, + transformer_heads: int, + transformer_layers: int + ): + super().__init__() + + self.context_length = context_length + + if isinstance(vision_layers, (tuple, list)): + vision_heads = vision_width * 32 // 64 + self.visual = ModifiedResNet( + layers=vision_layers, + output_dim=embed_dim, + heads=vision_heads, + input_resolution=image_resolution, + width=vision_width + ) + else: + vision_heads = vision_width // 64 + self.visual = VisualTransformer( + input_resolution=image_resolution, + patch_size=vision_patch_size, + width=vision_width, + layers=vision_layers, + heads=vision_heads, + output_dim=embed_dim + ) + + self.transformer = Transformer( + width=transformer_width, + layers=transformer_layers, + heads=transformer_heads, + attn_mask=self.build_attention_mask() + ) + + self.vocab_size = vocab_size + self.token_embedding = nn.Embedding(vocab_size, transformer_width) + self.positional_embedding = nn.Parameter( + torch.empty(self.context_length, transformer_width)) + self.ln_final = LayerNorm(transformer_width) + + self.text_projection = nn.Parameter( + torch.empty(transformer_width, embed_dim)) + self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + + self.initialize_parameters() + + def initialize_parameters(self): + nn.init.normal_(self.token_embedding.weight, std=0.02) + nn.init.normal_(self.positional_embedding, std=0.01) + + if isinstance(self.visual, ModifiedResNet): + if self.visual.attnpool is not None: + std = self.visual.attnpool.c_proj.in_features ** -0.5 + nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std) + + for resnet_block in [self.visual.layer1, self.visual.layer2, + self.visual.layer3, 
self.visual.layer4]: + for name, param in resnet_block.named_parameters(): + if name.endswith("bn3.weight"): + nn.init.zeros_(param) + + proj_std = (self.transformer.width ** -0.5) * ( + (2 * self.transformer.layers) ** -0.5) + attn_std = self.transformer.width ** -0.5 + fc_std = (2 * self.transformer.width) ** -0.5 + for block in self.transformer.resblocks: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + + if self.text_projection is not None: + nn.init.normal_(self.text_projection, + std=self.transformer.width ** -0.5) + + def build_attention_mask(self): + # lazily create causal attention mask, with full attention between the vision tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(self.context_length, self.context_length) + mask.fill_(float("-inf")) + mask.triu_(1) # zero out the lower diagonal + return mask + + @property + def dtype(self): + return self.visual.conv1.weight.dtype + + def encode_image(self, image): + return self.visual(image.type(self.dtype)) + + def encode_text(self, text): + x = self.token_embedding(text).type( + self.dtype) # [batch_size, n_ctx, d_model] + + x = x + self.positional_embedding.type(self.dtype) + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.ln_final(x).type(self.dtype) + + # x.shape = [batch_size, n_ctx, transformer.width] + # take features from the eot embedding (eot_token is the highest number in each sequence) + x = x[torch.arange(x.shape[0]), text.argmax( + dim=-1)] @ self.text_projection + + return x + + def forward(self, image, text): + image_features = self.encode_image(image) + text_features = self.encode_text(text) + + # normalized features + image_features = image_features / image_features.norm(dim=-1, + keepdim=True) + text_features = text_features / text_features.norm(dim=-1, keepdim=True) + + # cosine similarity as logits + logit_scale = self.logit_scale.exp() + logits_per_image = logit_scale * image_features @ text_features.t() + logits_per_text = logit_scale * text_features @ image_features.t() + + # shape = [global_batch_size, global_batch_size] + return logits_per_image, logits_per_text + + +def convert_weights(model: nn.Module): + """Convert applicable model parameters to fp16""" + + def _convert_weights_to_fp16(l): + if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): + l.weight.data = l.weight.data.half() + if l.bias is not None: + l.bias.data = l.bias.data.half() + + if isinstance(l, nn.MultiheadAttention): + for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], + "in_proj_bias", "bias_k", "bias_v"]: + tensor = getattr(l, attr) + if tensor is not None: + tensor.data = tensor.data.half() + + for name in ["text_projection", "proj"]: + if hasattr(l, name): + attr = getattr(l, name) + if attr is not None: + attr.data = attr.data.half() + + model.apply(_convert_weights_to_fp16) + + +def build_model(state_dict: dict): + vit = "visual.proj" in state_dict + + if vit: + vision_width = state_dict["visual.conv1.weight"].shape[0] + vision_layers = len([k for k in state_dict.keys() if + k.startswith("visual.") and k.endswith( + ".attn.in_proj_weight")]) + vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] + grid_size = round( + (state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) + image_resolution = vision_patch_size * grid_size 
+ else: + counts: list = [len(set(k.split(".")[2] for k in state_dict if + k.startswith(f"visual.layer{b}"))) for b in + [1, 2, 3, 4]] + vision_layers = tuple(counts) + vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] + output_width = round((state_dict[ + "visual.attnpool.positional_embedding"].shape[ + 0] - 1) ** 0.5) + vision_patch_size = None + assert output_width ** 2 + 1 == \ + state_dict["visual.attnpool.positional_embedding"].shape[0] + image_resolution = output_width * 32 + + embed_dim = state_dict["text_projection"].shape[1] + context_length = state_dict["positional_embedding"].shape[0] + vocab_size = state_dict["token_embedding.weight"].shape[0] + transformer_width = state_dict["ln_final.weight"].shape[0] + transformer_heads = transformer_width // 64 + transformer_layers = len(set(k.split(".")[2] for k in state_dict if + k.startswith(f"transformer.resblocks"))) + + model = CLIP( + embed_dim, + image_resolution, vision_layers, vision_width, vision_patch_size, + context_length, vocab_size, transformer_width, transformer_heads, + transformer_layers + ) + + for key in ["input_resolution", "context_length", "vocab_size"]: + if key in state_dict: + del state_dict[key] + + convert_weights(model) + model.load_state_dict(state_dict) + return model.eval() diff --git a/imaginaire/evaluation/caption/common.py b/imaginaire/evaluation/caption/common.py new file mode 100644 index 0000000000000000000000000000000000000000..ed54e55335b9b9684eab2fc58cbe14e810597dd6 --- /dev/null +++ b/imaginaire/evaluation/caption/common.py @@ -0,0 +1,57 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +import os + +import boto3 +import torch +from torch import nn, distributed as dist +from torch.nn import functional as F +from torch.distributed import barrier + +from imaginaire.utils.distributed import is_local_master +from .clip import build_model +from ...utils.io import download_file_from_google_drive + + +def get_image_encoder(aws_credentials=None): + if dist.is_initialized() and not is_local_master(): + # Make sure only the first process in distributed training downloads the model, and the others use the cache. + barrier() + + # Load the CLIP image encoder. + print("Loading CLIP image encoder.") + model_path = os.path.join(torch.hub.get_dir(), 'checkpoints', 'ViT-B-32.pt') + if not os.path.exists(model_path): + if aws_credentials is not None: + s3 = boto3.client('s3', **aws_credentials) + s3.download_file('lpi-poe', 'model_zoo/ViT-B-32.pt', model_path) + else: + download_file_from_google_drive("1Ri5APYM34A_IjG4F3Admutsf2oUwDjfW", model_path) + model = torch.load(model_path, map_location='cpu') + + if dist.is_initialized() and is_local_master(): + # Make sure only the first process in distributed training downloads the model, and the others use the cache. 
+ barrier() + + encoder = build_model(model).cuda() + return ImageEncoder(encoder) + + +class ImageEncoder(nn.Module): + def __init__(self, encoder): + super().__init__() + self.model = encoder + self.image_size = self.model.visual.input_resolution + self.mean = torch.tensor([0.48145466, 0.4578275, 0.40821073], device="cuda") + self.std = torch.tensor([0.26862954, 0.26130258, 0.27577711], device="cuda") + + @torch.no_grad() + def forward(self, data, fake_images, align_corners=True): + images = 0.5 * (1 + fake_images) + images = F.interpolate(images, (self.image_size, self.image_size), mode='bicubic', align_corners=align_corners) + images.clamp_(0, 1) + images = (images - self.mean[None, :, None, None]) / (self.std[None, :, None, None]) + image_code = self.model.encode_image(images) + return torch.cat((image_code, data['captions-clip']), dim=1) diff --git a/imaginaire/evaluation/caption/r_precision.py b/imaginaire/evaluation/caption/r_precision.py new file mode 100644 index 0000000000000000000000000000000000000000..ec192e67cf38d24f4f238d151661367f9a5fa753 --- /dev/null +++ b/imaginaire/evaluation/caption/r_precision.py @@ -0,0 +1,27 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +# flake8: noqa + +import torch +import torch.nn.functional as F + + +def get_r_precision(image_text_code, eps=1e-5): + all_image_code, all_text_code = torch.chunk(image_text_code, 2, dim=1) + P_rates = [] + num_samples = len(all_image_code) + assert num_samples >= 100 + for i in range(0, num_samples, 100): + if i + 100 <= num_samples: + cur_image_code = all_image_code[i:i + 100] + cur_text_code = all_text_code[i:i + 100] + cur_image_code = F.normalize(cur_image_code, dim=1, eps=eps) + cur_text_code = F.normalize(cur_text_code, dim=1, eps=eps) + cosine_similarities = cur_image_code @ cur_text_code.T + top1_indices = torch.topk(cosine_similarities, dim=1, k=1)[1][:, 0] + P_rate = torch.sum(top1_indices == torch.arange(100, device=top1_indices.device)).item() + P_rates.append(P_rate) + A_precision = sum(P_rates) * 1.0 / len(P_rates) + return {"caption_rprec": A_precision} diff --git a/imaginaire/evaluation/common.py b/imaginaire/evaluation/common.py new file mode 100644 index 0000000000000000000000000000000000000000..046c47da6692f5bd24543cc0e72da2a86b7ef2a3 --- /dev/null +++ b/imaginaire/evaluation/common.py @@ -0,0 +1,651 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. 
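# Illustrative sketch (not from the original files): R-precision as computed by
# get_r_precision above. Random 512-d vectors stand in for the CLIP image/text
# codes that ImageEncoder concatenates; with random inputs, expect roughly one
# top-1 match per block of 100 samples (near-chance score).
import torch
from imaginaire.evaluation.caption.r_precision import get_r_precision

image_text_code = torch.randn(200, 1024)   # 200 samples of [image_code | text_code]
print(get_r_precision(image_text_code))    # {'caption_rprec': ...}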
+# To view a copy of this license, check out LICENSE.md +import math +import os +from functools import partial +import torch +import torch.distributed as dist +from torch import nn +from torch.nn import functional as F +from torchvision.models import inception_v3 +from cleanfid.features import feature_extractor +from cleanfid.resize import build_resizer + +from imaginaire.evaluation.lpips import get_lpips_model +from imaginaire.evaluation.segmentation import get_segmentation_hist_model, get_miou +from imaginaire.evaluation.caption import get_image_encoder, get_r_precision +from imaginaire.evaluation.pretrained import TFInceptionV3, InceptionV3, Vgg16, SwAV +from imaginaire.utils.distributed import (dist_all_gather_tensor, get_rank, + get_world_size, is_master, + is_local_master) +from imaginaire.utils.distributed import master_only_print +from imaginaire.utils.misc import apply_imagenet_normalization, to_cuda + + +@torch.no_grad() +def compute_all_metrics(act_dir, + data_loader, + net_G, + key_real='images', + key_fake='fake_images', + sample_size=None, + preprocess=None, + is_video=False, + few_shot_video=False, + kid_num_subsets=1, + kid_subset_size=None, + key_prefix='', + prdc_k=5, + metrics=None, + dataset_name='', + aws_credentials=None, + **kwargs): + r""" + Args: + act_dir (string): Path to a directory to temporarily save feature activations. + data_loader (obj): PyTorch dataloader object. + net_G (obj): The generator module. + key_real (str): Dictionary key value for the real data. + key_fake (str): Dictionary key value for the fake data. + sample_size (int or None): How many samples to use for FID. + preprocess (func or None): Pre-processing function to use. + is_video (bool): Whether we are handling video sequences. + few_shot_video (bool): If ``True``, uses few-shot video synthesis. + kid_num_subsets (int): Number of subsets for KID evaluation. + kid_subset_size (int or None): The number of samples in each subset for KID evaluation. + key_prefix (string): Add this string before all keys of the output dictionary. + prdc_k (int): The K used for computing K-NN when evaluating precision/recall/density/coverage. + metrics (list of strings): Which metrics we want to evaluate. + dataset_name (string): The name of the dataset, currently only used to determine which segmentation network to + use for segmentation evaluation. + Returns: + batch_y (tensor): Inception features of the current batch. Note that + only the master gpu will get it. + """ + + from imaginaire.evaluation.fid import _calculate_frechet_distance + from imaginaire.evaluation.kid import _polynomial_mmd_averages + from imaginaire.evaluation.prdc import _get_prdc + from imaginaire.evaluation.msid import _get_msid + from imaginaire.evaluation.knn import _get_1nn_acc + if metrics is None: + metrics = [] + act_path = os.path.join(act_dir, 'activations_real.pt') + + # Get feature activations and other outputs computed from fake images. + output_module_dict = nn.ModuleDict() + if "seg_mIOU" in metrics: + output_module_dict["seg_mIOU"] = get_segmentation_hist_model(dataset_name, aws_credentials) + if "caption_rprec" in metrics: + output_module_dict["caption_rprec"] = get_image_encoder(aws_credentials) + if "LPIPS" in metrics: + output_module_dict["LPIPS"] = get_lpips_model() + + fake_outputs = get_outputs( + data_loader, key_real, key_fake, net_G, sample_size, preprocess, + output_module_dict=output_module_dict, **kwargs + ) + fake_act = fake_outputs["activations"] + + # Get feature activations computed from real images. 
+ real_act = load_or_compute_activations( + act_path, data_loader, key_real, key_fake, None, + sample_size, preprocess, is_video=is_video, + few_shot_video=few_shot_video, **kwargs + ) + + metrics_from_activations = { + "1NN": _get_1nn_acc, + "MSID": _get_msid, + "FID": _calculate_frechet_distance, + "KID": partial(_polynomial_mmd_averages, + n_subsets=kid_num_subsets, + subset_size=kid_subset_size, + ret_var=True), + "PRDC": partial(_get_prdc, nearest_k=prdc_k) + } + + other_metrics = { + "seg_mIOU": get_miou, + "caption_rprec": get_r_precision, + "LPIPS": lambda x: {"LPIPS": torch.mean(x).item()} + } + + all_metrics = {} + if is_master(): + for metric in metrics: + if metric in metrics_from_activations: + metric_function = metrics_from_activations[metric] + metric_dict = metric_function(real_act, fake_act) + elif metric in other_metrics: + metric_function = other_metrics[metric] + if fake_outputs[metric] is not None: + metric_dict = metric_function(fake_outputs[metric]) + else: + print(f"{metric} is not implemented!") + raise NotImplementedError + for k, v in metric_dict.items(): + all_metrics.update({key_prefix + k: v}) + if dist.is_initialized(): + dist.barrier() + return all_metrics + + +@torch.no_grad() +def compute_all_metrics_data(data_loader_a, + data_loader_b, + key_a='images', + key_b='images', + sample_size=None, + preprocess=None, + kid_num_subsets=1, + kid_subset_size=None, + key_prefix='', + prdc_k=5, + metrics=None, + dataset_name='', + aws_credentials=None, + **kwargs): + r""" + Args: + act_dir (string): Path to a directory to temporarily save feature activations. + data_loader (obj): PyTorch dataloader object. + net_G (obj): The generator module. + key_a (str): Dictionary key value for the real data. + key_b (str): Dictionary key value for the fake data. + sample_size (int or None): How many samples to use for FID. + preprocess (func or None): Pre-processing function to use. + is_video (bool): Whether we are handling video sequences. + few_shot_video (bool): If ``True``, uses few-shot video synthesis. + kid_num_subsets (int): Number of subsets for KID evaluation. + kid_subset_size (int or None): The number of samples in each subset for KID evaluation. + key_prefix (string): Add this string before all keys of the output dictionary. + prdc_k (int): The K used for computing K-NN when evaluating precision/recall/density/coverage. + metrics (list of strings): Which metrics we want to evaluate. + dataset_name (string): The name of the dataset, currently only used to determine which segmentation network to + use for segmentation evaluation. + Returns: + batch_y (tensor): Inception features of the current batch. Note that + only the master gpu will get it. + """ + + from imaginaire.evaluation.fid import _calculate_frechet_distance + from imaginaire.evaluation.kid import _polynomial_mmd_averages + from imaginaire.evaluation.prdc import _get_prdc + from imaginaire.evaluation.msid import _get_msid + from imaginaire.evaluation.knn import _get_1nn_acc + if metrics is None: + metrics = [] + + min_data_size = min(len(data_loader_a.dataset), + len(data_loader_b.dataset)) + if sample_size is None: + sample_size = min_data_size + else: + sample_size = min(sample_size, min_data_size) + + # Get feature activations and other outputs computed from fake images. 
+ output_module_dict = nn.ModuleDict() + if "seg_mIOU" in metrics: + output_module_dict["seg_mIOU"] = get_segmentation_hist_model(dataset_name, aws_credentials) + if "caption_rprec" in metrics: + output_module_dict["caption_rprec"] = get_image_encoder(aws_credentials) + if "LPIPS" in metrics: + output_module_dict["LPIPS"] = get_lpips_model() + + fake_outputs = get_outputs( + data_loader_b, key_a, key_b, None, sample_size, preprocess, + output_module_dict=output_module_dict, **kwargs + ) + act_b = fake_outputs["activations"] + + act_a = load_or_compute_activations( + None, data_loader_a, key_a, key_b, None, sample_size, preprocess, + output_module_dict=output_module_dict, **kwargs + ) + + # act_b = load_or_compute_activations( + # None, data_loader_b, key_a, key_b, None, sample_size, preprocess, + # output_module_dict=output_module_dict, generate_twice=generate_twice, **kwargs + # ) + + metrics_from_activations = { + "1NN": _get_1nn_acc, + "MSID": _get_msid, + "FID": _calculate_frechet_distance, + "KID": partial(_polynomial_mmd_averages, + n_subsets=kid_num_subsets, + subset_size=kid_subset_size, + ret_var=True), + "PRDC": partial(_get_prdc, nearest_k=prdc_k) + } + + other_metrics = { + "seg_mIOU": get_miou, + "caption_rprec": get_r_precision, + "LPIPS": lambda x: {"LPIPS": torch.mean(x).item()} + } + + all_metrics = {} + if is_master(): + for metric in metrics: + if metric in metrics_from_activations: + metric_function = metrics_from_activations[metric] + metric_dict = metric_function(act_a, act_b) + elif metric in other_metrics: + metric_function = other_metrics[metric] + if fake_outputs[metric] is not None: + metric_dict = metric_function(fake_outputs[metric]) + else: + print(f"{metric} is not implemented!") + raise NotImplementedError + for k, v in metric_dict.items(): + all_metrics.update({key_prefix + k: v}) + if dist.is_initialized(): + dist.barrier() + return all_metrics + + +@torch.no_grad() +def get_activations(data_loader, key_real, key_fake, + generator=None, sample_size=None, preprocess=None, + align_corners=True, network='inception', **kwargs): + r"""Compute activation values and pack them in a list. + + Args: + data_loader (obj): PyTorch dataloader object. + key_real (str): Dictionary key value for the real data. + key_fake (str): Dictionary key value for the fake data. + generator (obj): PyTorch trainer network. + sample_size (int): How many samples to use for FID. + preprocess (func): Pre-processing function to use. + align_corners (bool): The ``'align_corners'`` parameter to be used for + `torch.nn.functional.interpolate`. + Returns: + batch_y (tensor): Inception features of the current batch. Note that + only the master gpu will get it. 
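# Illustrative sketch (not from the original files): comparing two image sets with
# compute_all_metrics_data, which is exported from imaginaire.evaluation. Each
# DataLoader must yield dicts holding an 'images' tensor in [-1, 1]; the random
# tensors below are placeholders, and a CUDA device plus the pretrained
# InceptionV3 weights (downloaded on first use) are required.
import torch
from torch.utils.data import DataLoader
from imaginaire.evaluation import compute_all_metrics_data

def make_loader(n=200):
    # A list of dicts is a valid map-style dataset for the default collate_fn.
    return DataLoader([{'images': torch.rand(3, 256, 256) * 2 - 1} for _ in range(n)],
                      batch_size=8)

metrics = compute_all_metrics_data(make_loader(), make_loader(),
                                   key_a='images', key_b='images',
                                   sample_size=200, metrics=['FID', 'KID'],
                                   kid_subset_size=100)
print(metrics)   # e.g. {'FID': ..., 'KID': ...}; values are meaningless for random data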
+ """ + if dist.is_initialized() and not is_local_master(): + # Make sure only the first process in distributed training downloads + # the model, and the others will use the cache + # noinspection PyUnresolvedReferences + torch.distributed.barrier() + + if network == 'tf_inception': + model = TFInceptionV3() + elif network == 'inception': + model = InceptionV3() + elif network == 'vgg16': + model = Vgg16() + elif network == 'swav': + model = SwAV() + elif network == 'clean_inception': + model = CleanInceptionV3() + else: + raise NotImplementedError(f'Network "{network}" is not supported!') + + if dist.is_initialized() and is_local_master(): + # Make sure only the first process in distributed training downloads + # the model, and the others will use the cache + # noinspection PyUnresolvedReferences + dist.barrier() + + model = model.to('cuda').eval() + world_size = get_world_size() + batch_y = [] + + # Iterate through the dataset to compute the activation. + for it, data in enumerate(data_loader): + data = to_cuda(data) + # Preprocess the data. + if preprocess is not None: + data = preprocess(data) + # Load real data if the generator is not specified. + if generator is None: + images = data[key_real] + else: + # Compute the generated image. + net_G_output = generator(data, **kwargs) + images = net_G_output[key_fake] + # Clamp the image for models that do not set the output to between + # -1, 1. For models that employ tanh, this has no effect. + images.clamp_(-1, 1) + y = model(images, align_corners=align_corners) + batch_y.append(y) + if sample_size is not None and \ + data_loader.batch_size * world_size * (it + 1) >= sample_size: + # Reach the number of samples we need. + break + + batch_y = torch.cat(dist_all_gather_tensor(torch.cat(batch_y))) + if sample_size is not None: + batch_y = batch_y[:sample_size] + print(f"Computed feature activations of size {batch_y.shape}") + return batch_y + + +class CleanInceptionV3(nn.Module): + def __init__(self): + super().__init__() + self.model = feature_extractor(name="torchscript_inception", resize_inside=False) + + def forward(self, img_batch, transform=True, **_kwargs): + if transform: + # Assume the input is (-1, 1). We transform it to (0, 255) and round it to the closest integer. + img_batch = torch.round(255 * (0.5 * img_batch + 0.5)) + resized_batch = clean_resize(img_batch) + return self.model(resized_batch) + + +def clean_resize(img_batch): + # Resize images from arbitrary resolutions to 299x299. + batch_size = img_batch.size(0) + img_batch = img_batch.cpu().numpy() + fn_resize = build_resizer('clean') + resized_batch = torch.zeros(batch_size, 3, 299, 299, device='cuda') + for idx in range(batch_size): + curr_img = img_batch[idx] + img_np = curr_img.transpose((1, 2, 0)) + img_resize = fn_resize(img_np) + resized_batch[idx] = torch.tensor(img_resize.transpose((2, 0, 1)), device='cuda') + resized_batch = resized_batch.cuda() + return resized_batch + + +@torch.no_grad() +def get_outputs(data_loader, key_real, key_fake, + generator=None, sample_size=None, preprocess=None, + align_corners=True, network='inception', + output_module_dict=None, **kwargs): + r"""Compute activation values and pack them in a list. + + Args: + data_loader (obj): PyTorch dataloader object. + key_real (str): Dictionary key value for the real data. + key_fake (str): Dictionary key value for the fake data. + generator (obj): PyTorch trainer network. + sample_size (int): How many samples to use for FID. + preprocess (func): Pre-processing function to use. 
+ align_corners (bool): The ``'align_corners'`` parameter to be used for `torch.nn.functional.interpolate`. + Returns: + batch_y (tensor): Inception features of the current batch. Note that + only the master gpu will get it. + """ + if output_module_dict is None: + output_module_dict = nn.ModuleDict() + if dist.is_initialized() and not is_local_master(): + # Make sure only the first process in distributed training downloads + # the model, and the others will use the cache + # noinspection PyUnresolvedReferences + torch.distributed.barrier() + + if network == 'tf_inception': + model = TFInceptionV3() + elif network == 'inception': + model = InceptionV3() + elif network == 'vgg16': + model = Vgg16() + elif network == 'swav': + model = SwAV() + elif network == 'clean_inception': + model = CleanInceptionV3() + else: + raise NotImplementedError(f'Network "{network}" is not supported!') + + if dist.is_initialized() and is_local_master(): + # Make sure only the first process in distributed training downloads + # the model, and the others will use the cache + # noinspection PyUnresolvedReferences + dist.barrier() + + model = model.to('cuda').eval() + world_size = get_world_size() + output = {} + for k in output_module_dict.keys(): + output[k] = [] + output["activations"] = [] + + # Iterate through the dataset to compute the activation. + for it, data in enumerate(data_loader): + data = to_cuda(data) + # Preprocess the data. + if preprocess is not None: + data = preprocess(data) + # Load real data if the generator is not specified. + if generator is None: + images = data[key_real] + else: + # Compute the generated image. + net_G_output = generator(data, **kwargs) + images = net_G_output[key_fake] + for metric_name, metric_module in output_module_dict.items(): + if metric_module is not None: + if metric_name == 'LPIPS': + assert generator is not None + net_G_output_another = generator(data, **kwargs) + images_another = net_G_output_another[key_fake] + output[metric_name].append(metric_module(images, images_another)) + else: + output[metric_name].append(metric_module(data, images, align_corners=align_corners)) + # Clamp the image for models that do not set the output to between + # -1, 1. For models that employ tanh, this has no effect. + images.clamp_(-1, 1) + y = model(images, align_corners=align_corners) + output["activations"].append(y) + if sample_size is not None and data_loader.batch_size * world_size * (it + 1) >= sample_size: + # Reach the number of samples we need. + break + + for k, v in output.items(): + if len(v) > 0: + output[k] = torch.cat(dist_all_gather_tensor(torch.cat(v)))[:sample_size] + else: + output[k] = None + return output + + +@torch.no_grad() +def get_video_activations(data_loader, key_real, key_fake, trainer=None, + sample_size=None, preprocess=None, few_shot=False): + r"""Compute activation values and pack them in a list. We do not do all + reduce here. + + Args: + data_loader (obj): PyTorch dataloader object. + key_real (str): Dictionary key value for the real data. + key_fake (str): Dictionary key value for the fake data. + trainer (obj): Trainer. Video generation is more involved, we rely on + the "reset" and "test" function to conduct the evaluation. + sample_size (int): For computing video activation, we will use . + preprocess (func): The preprocess function to be applied to the data. + few_shot (bool): If ``True``, uses the few-shot setting. + Returns: + batch_y (tensor): Inception features of the current batch. Note that + only the master gpu will get it. 
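# Illustrative sketch (not from the original files): extracting backbone features
# directly with get_activations (generator=None means the 'images' entry of each
# batch is scored as-is). Assumes a CUDA device; network can be 'inception',
# 'tf_inception', 'vgg16', 'swav' or 'clean_inception' as handled above.
import torch
from torch.utils.data import DataLoader
from imaginaire.evaluation.common import get_activations

loader = DataLoader([{'images': torch.rand(3, 256, 256) * 2 - 1} for _ in range(64)],
                    batch_size=8)
feats = get_activations(loader, key_real='images', key_fake='fake_images',
                        generator=None, sample_size=64, network='inception')
print(feats.shape)   # (sample_size, feature_dim) on the master process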
+ """ + inception = inception_init() + batch_y = [] + + # We divide video sequences to different GPUs for testing. + num_sequences = data_loader.dataset.num_inference_sequences() + if sample_size is None: + num_videos_to_test = 10 + num_frames_per_video = 5 + else: + num_videos_to_test, num_frames_per_video = sample_size + if num_videos_to_test == -1: + num_videos_to_test = num_sequences + else: + num_videos_to_test = min(num_videos_to_test, num_sequences) + master_only_print('Number of videos used for evaluation: {}'.format(num_videos_to_test)) + master_only_print('Number of frames per video used for evaluation: {}'.format(num_frames_per_video)) + + world_size = get_world_size() + if num_videos_to_test < world_size: + seq_to_run = [get_rank() % num_videos_to_test] + else: + num_videos_to_test = num_videos_to_test // world_size * world_size + seq_to_run = range(get_rank(), num_videos_to_test, world_size) + + for sequence_idx in seq_to_run: + data_loader = set_sequence_idx(few_shot, data_loader, sequence_idx) + if trainer is not None: + trainer.reset() + for it, data in enumerate(data_loader): + if few_shot and it == 0: + continue + if it >= num_frames_per_video: + break + + # preprocess the data is preprocess is not none. + if trainer is not None: + data = trainer.pre_process(data) + elif preprocess is not None: + data = preprocess(data) + data = to_cuda(data) + + if trainer is None: + images = data[key_real][:, -1] + else: + net_G_output = trainer.test_single(data) + images = net_G_output[key_fake] + y = inception_forward(inception, images) + batch_y += [y] + + batch_y = torch.cat(batch_y) + batch_y = dist_all_gather_tensor(batch_y) + if is_local_master(): + batch_y = torch.cat(batch_y) + return batch_y + + +def inception_init(): + inception = inception_v3(pretrained=True, transform_input=False) + inception = inception.to('cuda') + inception.eval() + inception.fc = torch.nn.Sequential() + return inception + + +def inception_forward(inception, images): + images.clamp_(-1, 1) + images = apply_imagenet_normalization(images) + images = F.interpolate(images, size=(299, 299), + mode='bicubic', align_corners=True) + return inception(images) + + +def gather_tensors(batch_y): + batch_y = torch.cat(batch_y) + batch_y = dist_all_gather_tensor(batch_y) + if is_local_master(): + batch_y = torch.cat(batch_y) + return batch_y + + +def set_sequence_idx(few_shot, data_loader, sequence_idx): + r"""Get sequence index + + Args: + few_shot (bool): If ``True``, uses the few-shot setting. + data_loader: dataloader object + sequence_idx (int): which sequence to use. + """ + if few_shot: + data_loader.dataset.set_inference_sequence_idx(sequence_idx, + sequence_idx, + 0) + else: + data_loader.dataset.set_inference_sequence_idx(sequence_idx) + return data_loader + + +def load_or_compute_activations(act_path, data_loader, key_real, key_fake, + generator=None, sample_size=None, + preprocess=None, + is_video=False, few_shot_video=False, + **kwargs): + r"""Load mean and covariance from saved npy file if exists. Otherwise, + compute the mean and covariance. + + Args: + act_path (str or None): Location for the numpy file to store or to load + the activations. + data_loader (obj): PyTorch dataloader object. + key_real (str): Dictionary key value for the real data. + key_fake (str): Dictionary key value for the fake data. + generator (obj): PyTorch trainer network. + sample_size (int): How many samples to be used for computing the KID. + preprocess (func): The preprocess function to be applied to the data. 
+ is_video (bool): Whether we are handling video sequences. + few_shot_video (bool): If ``True``, uses few-shot video synthesis. + Returns: + (torch.Tensor) Feature activations. + """ + if act_path is not None and os.path.exists(act_path): + # Loading precomputed activations. + print('Load activations from {}'.format(act_path)) + act = torch.load(act_path, map_location='cpu').cuda() + else: + # Compute activations. + if is_video: + act = get_video_activations( + data_loader, key_real, key_fake, generator, + sample_size, preprocess, few_shot_video, **kwargs + ) + else: + act = get_activations( + data_loader, key_real, key_fake, generator, + sample_size, preprocess, **kwargs + ) + if act_path is not None and is_local_master(): + print('Save activations to {}'.format(act_path)) + if not os.path.exists(os.path.dirname(act_path)): + os.makedirs(os.path.dirname(act_path), exist_ok=True) + torch.save(act, act_path) + return act + + +def compute_pairwise_distance(data_x, data_y=None, num_splits=10): + r""" + + Args: + data_x: numpy.ndarray([N, feature_dim], dtype=np.float32) + data_y: numpy.ndarray([N, feature_dim], dtype=np.float32) + Returns: + numpy.ndarray([N, N], dtype=np.float32) of pairwise distances. + """ + if data_y is None: + data_y = data_x + num_samples = data_x.shape[0] + assert data_x.shape[0] == data_y.shape[0] + dists = [] + for i in range(num_splits): + batch_size = math.ceil(num_samples / num_splits) + start_idx = i * batch_size + end_idx = min((i + 1) * batch_size, num_samples) + dists.append(torch.cdist(data_x[start_idx:end_idx], + data_y).cpu()) + dists = torch.cat(dists, dim=0) + return dists + + +def compute_nn(input_features, k, num_splits=50): + num_samples = input_features.shape[0] + all_indices = [] + all_values = [] + for i in range(num_splits): + batch_size = math.ceil(num_samples / num_splits) + start_idx = i * batch_size + end_idx = min((i + 1) * batch_size, num_samples) + dist = torch.cdist(input_features[start_idx:end_idx], + input_features) + dist[:, start_idx:end_idx] += torch.diag( + float('inf') * torch.ones(dist.size(0), device=dist.device) + ) + k_smallests, indices = torch.topk(dist, k, dim=-1, largest=False) + all_indices.append(indices) + all_values.append(k_smallests) + return torch.cat(all_values, dim=0), torch.cat(all_indices, dim=0) diff --git a/imaginaire/evaluation/fid.py b/imaginaire/evaluation/fid.py new file mode 100644 index 0000000000000000000000000000000000000000..793f4b77dd50381abffc82d03fee8bc36e746dab --- /dev/null +++ b/imaginaire/evaluation/fid.py @@ -0,0 +1,143 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +import os +import numpy as np +import torch +from scipy import linalg + +from imaginaire.evaluation.common import load_or_compute_activations +from imaginaire.utils.distributed import is_master +from imaginaire.utils.distributed import master_only_print as print + + +@torch.no_grad() +def compute_fid(fid_path, data_loader, net_G, + key_real='images', key_fake='fake_images', + sample_size=None, preprocess=None, return_act=False, + is_video=False, few_shot_video=False, **kwargs): + r"""Compute the fid score. + + Args: + fid_path (str): Location for the numpy file to store or to load the + statistics. + data_loader (obj): PyTorch dataloader object. + net_G (obj): For image generation modes, net_G is the generator network. + For video generation models, net_G is the trainer. 
+ key_real (str): Dictionary key value for the real data. + key_fake (str): Dictionary key value for the fake data. + sample_size (int or tuple): How many samples to be used. + preprocess (func): The preprocess function to be applied to the data. + return_act (bool): If ``True``, also returns feature activations of + real and fake data. + is_video (bool): Whether we are handling video sequences. + few_shot_video (bool): If ``True``, uses few-shot video synthesis. + Returns: + (float): FID value. + """ + print('Computing FID.') + act_path = os.path.join(os.path.dirname(fid_path), + 'activations_real.npy') + # Get the fake mean and covariance. + fake_act = load_or_compute_activations( + None, data_loader, key_real, key_fake, net_G, + sample_size, preprocess, is_video=is_video, + few_shot_video=few_shot_video, **kwargs + ) + + # Get the ground truth mean and covariance. + real_act = load_or_compute_activations( + act_path, data_loader, key_real, key_fake, None, + sample_size, preprocess, is_video=is_video, + few_shot_video=few_shot_video, **kwargs + ) + + if is_master(): + fid = _calculate_frechet_distance( + fake_act, real_act)["FID"] + if return_act: + return fid, real_act, fake_act + else: + return fid + elif return_act: + return None, None, None + else: + return None + + +@torch.no_grad() +def compute_fid_data(fid_path, data_loader_a, data_loader_b, + key_a='images', key_b='images', sample_size=None, + is_video=False, few_shot_video=False, **kwargs): + r"""Compute the fid score between two datasets. + + Args: + fid_path (str): Location for the numpy file to store or to load the + statistics. + data_loader_a (obj): PyTorch dataloader object for dataset a. + data_loader_b (obj): PyTorch dataloader object for dataset b. + key_a (str): Dictionary key value for images in the dataset a. + key_b (str): Dictionary key value for images in the dataset b. + sample_size (int): How many samples to be used for computing the FID. + is_video (bool): Whether we are handling video sequences. + few_shot_video (bool): If ``True``, uses few-shot video synthesis. + Returns: + (float): FID value. 
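+ Note: the score is the Frechet distance between Gaussians fitted to the Inception activations of the two datasets, i.e. ||mu_a - mu_b||^2 + Tr(S_a + S_b - 2 (S_a S_b)^(1/2)), as computed by _calculate_frechet_distance below.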
+ """ + print('Computing FID.') + path_a = os.path.join(os.path.dirname(fid_path), + 'activations_a.npy') + min_data_size = min(len(data_loader_a.dataset), + len(data_loader_b.dataset)) + if sample_size is None: + sample_size = min_data_size + else: + sample_size = min(sample_size, min_data_size) + + act_a = load_or_compute_activations( + path_a, data_loader_a, key_a, key_b, None, + sample_size=sample_size, is_video=is_video, + few_shot_video=few_shot_video, **kwargs + ) + act_b = load_or_compute_activations( + None, data_loader_b, key_a, key_b, None, + sample_size=sample_size, is_video=is_video, + few_shot_video=few_shot_video, **kwargs + ) + + if is_master(): + return _calculate_frechet_distance(act_a, act_b)["FID"] + + +def _calculate_frechet_distance(act_1, act_2, eps=1e-6): + mu1 = np.mean(act_1.cpu().numpy(), axis=0) + sigma1 = np.cov(act_1.cpu().numpy(), rowvar=False) + mu2 = np.mean(act_2.cpu().numpy(), axis=0) + sigma2 = np.cov(act_2.cpu().numpy(), rowvar=False) + mu1 = np.atleast_1d(mu1) + mu2 = np.atleast_1d(mu2) + sigma1 = np.atleast_2d(sigma1) + sigma2 = np.atleast_2d(sigma2) + assert mu1.shape == mu2.shape, 'Training and test mean vectors have different lengths' + assert sigma1.shape == sigma2.shape, 'Training and test covariances have different dimensions' + diff = mu1 - mu2 + # Product might be almost singular + covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) + if not np.isfinite(covmean).all(): + msg = ('fid calculation produces singular product; ' + 'adding %s to diagonal of cov estimates') % eps + print(msg) + offset = np.eye(sigma1.shape[0]) * eps + covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) + + # Numerical error might give slight imaginary component + if np.iscomplexobj(covmean): + if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): + m = np.max(np.abs(covmean.imag)) + print('Imaginary component {}'.format(m)) + # raise ValueError('Imaginary component {}'.format(m)) + covmean = covmean.real + tr_covmean = np.trace(covmean) + return {"FID": (diff.dot(diff) + np.trace(sigma1) + np.trace( + sigma2) - 2 * tr_covmean)} diff --git a/imaginaire/evaluation/kid.py b/imaginaire/evaluation/kid.py new file mode 100644 index 0000000000000000000000000000000000000000..675b93015b03a8c1f4557e5687de0887b1f5a0a4 --- /dev/null +++ b/imaginaire/evaluation/kid.py @@ -0,0 +1,317 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md + +""" +Modified from https://github.com/abdulfatir/gan-metrics-pytorch +Copyright 2018 Institute of Bioinformatics, JKU Linz +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" +import os +import warnings + +import numpy as np +import torch + +from imaginaire.evaluation.common import get_activations, \ + load_or_compute_activations +from imaginaire.utils.distributed import is_master +from imaginaire.utils.distributed import master_only_print as print + + +@torch.no_grad() +def compute_kid(kid_path, data_loader, net_G, + key_real='images', key_fake='fake_images', + real_act=None, fake_act=None, + sample_size=None, preprocess=None, is_video=False, + save_act=True, num_subsets=1, subset_size=None, **kwargs): + r"""Compute the kid score. + + Args: + kid_path (str): Location for store feature activations. + data_loader (obj): PyTorch dataloader object. + net_G (obj): For image generation modes, net_G is the PyTorch trainer + network. For video generation models, net_G is the trainer + because video generation requires more complicated processing. + key_real (str): Dictionary key value for the real data. + key_fake (str): Dictionary key value for the fake data. + real_act (torch.Tensor or None): Feature activations of real data. + fake_act (torch.Tensor or None): Feature activations of fake data. + sample_size (int): How many samples to be used for computing feature + activations. + preprocess (func): The preprocess function to be applied to the data. + is_video (bool): Whether we are handling video sequences. + save_act (bool): If ``True``, saves real activations to the disk and + reload them in the future. It might save some computation but will + cost storage. + num_subsets (int): Number of subsets to sample from all the samples. + subset_size (int): Number of samples in each subset. + Returns: + kid (float): KID value. + """ + print('Computing KID.') + act_path = os.path.join( + os.path.dirname(kid_path), 'activations_real.npy' + ) if save_act else None + + # Get the fake activations. + if fake_act is None: + fake_act = load_or_compute_activations(None, + data_loader, + key_real, key_fake, net_G, + sample_size, preprocess, + is_video=is_video, **kwargs) + else: + print(f"Using precomputed activations of size {fake_act.shape}.") + + # Get the ground truth activations. + if real_act is None: + real_act = load_or_compute_activations(act_path, + data_loader, + key_real, key_fake, None, + sample_size, preprocess, + is_video=is_video, **kwargs) + else: + print(f"Using precomputed activations of size {real_act.shape}.") + + if is_master(): + return _polynomial_mmd_averages(fake_act, real_act, + num_subsets, + subset_size, + ret_var=True)["KID"] + + +@torch.no_grad() +def compute_kid_data(kid_path, data_loader_a, data_loader_b, + key_a='images', key_b='images', sample_size=None, + is_video=False, num_subsets=1, subset_size=None, + **kwargs): + r"""Compute the kid score between two datasets. + + Args: + kid_path (str): Location for store feature activations. + data_loader_a (obj): PyTorch dataloader object for dataset a. + data_loader_b (obj): PyTorch dataloader object for dataset b. + key_a (str): Dictionary key value for images in the dataset a. + key_b (str): Dictionary key value for images in the dataset b. + sample_size (int): How many samples to be used for computing the KID. + is_video (bool): Whether we are handling video sequences. + num_subsets (int): Number of subsets to sample from the whole data. + subset_size (int): Number of samples in each subset. + Returns: + kid (float): KID value. 
+ """ + min_data_size = min(len(data_loader_a.dataset), + len(data_loader_b.dataset)) + if sample_size is None: + sample_size = min_data_size + else: + sample_size = min(sample_size, min_data_size) + print('Computing KID using {} images from both distributions.'. + format(sample_size)) + path_a = os.path.join(os.path.dirname(kid_path), + 'activations_a.npy') + act_a = load_or_compute_activations(path_a, data_loader_a, + key_a, key_a, + sample_size=sample_size, + is_video=is_video, **kwargs) + act_b = get_activations(data_loader_b, key_b, key_b, + None, sample_size, None, **kwargs) + + if is_master(): + return _polynomial_mmd_averages(act_a, act_b, + num_subsets, + subset_size, + ret_var=True)["KID"] + + +def _polynomial_mmd_averages(codes_g, codes_r, n_subsets, subset_size, + ret_var=True, **kernel_args): + r"""Computes MMD between two sets of features using polynomial kernels. It + performs a number of repetitions of subset sampling without replacement. + + Args: + codes_g (Tensor): Feature activations of generated images. + codes_r (Tensor): Feature activations of real images. + n_subsets (int): The number of subsets. + subset_size (int): The number of samples in each subset. + ret_var (bool): If ``True``, returns both mean and variance of MMDs, + otherwise only returns the mean. + Returns: + (tuple): + - mmds (Tensor): Mean of MMDs. + - mmd_vars (Tensor): Variance of MMDs. + """ + mmds = np.zeros(n_subsets) + if ret_var: + mmd_vars = np.zeros(n_subsets) + choice = np.random.choice + + if subset_size is None: + subset_size = min(len(codes_r), len(codes_r)) + print("Subset size not provided, " + "setting it to the data size ({}).".format(subset_size)) + if subset_size > len(codes_g) or subset_size > len(codes_r): + subset_size = min(len(codes_r), len(codes_r)) + warnings.warn( + "Subset size is large than the actual data size, " + "setting it to the data size ({}).".format(subset_size)) + + for i in range(n_subsets): + g = codes_g[choice(len(codes_g), subset_size, replace=False)] + r = codes_r[choice(len(codes_r), subset_size, replace=False)] + o = _polynomial_mmd(g, r, **kernel_args, ret_var=ret_var) + if ret_var: + # noinspection PyUnboundLocalVariable + mmds[i], mmd_vars[i] = o + else: + mmds[i] = o + return {'KID': mmds.mean()} + + +def _polynomial_kernel(X, Y=None, degree=3, gamma=None, coef0=1.): + r"""Compute the polynomial kernel between X and Y""" + if gamma is None: + gamma = 1.0 / X.shape[1] + if Y is None: + Y = X + + # K = safe_sparse_dot(X, Y.T, dense_output=True) + K = torch.matmul(X, Y.t()) + K *= gamma + K += coef0 + K = K ** degree + return K + + +def _polynomial_mmd(codes_g, codes_r, degree=3, gamma=None, coef0=1, + ret_var=True): + r"""Computes MMD between two sets of features using polynomial kernels. It + performs a number of repetitions of subset sampling without replacement. + + Args: + codes_g (torch.Tensor): Feature activations of generated images. + codes_r (torch.Tensor): Feature activations of real images. + degree (int): The degree of the polynomial kernel. + gamma (float or None): Scale of the polynomial kernel. + coef0 (float or None): Bias of the polynomial kernel. + ret_var (bool): If ``True``, returns both mean and variance of MMDs, + otherwise only returns the mean. + Returns: + (tuple): + - mmds (torch.Tensor): Mean of MMDs. + - mmd_vars (torch.Tensor): Variance of MMDs. 
+ """ + # use k(x, y) = (gamma + coef0)^degree + # default gamma is 1 / dim + X = codes_g + Y = codes_r + + # with warnings.catch_warnings(): + # warnings.simplefilter('ignore') + K_XX = _polynomial_kernel(X, degree=degree, gamma=gamma, coef0=coef0) + K_YY = _polynomial_kernel(Y, degree=degree, gamma=gamma, coef0=coef0) + K_XY = _polynomial_kernel(X, Y, degree=degree, gamma=gamma, coef0=coef0) + + return _mmd2_and_variance(K_XX, K_XY, K_YY, ret_var=ret_var) + + +def _mmd2_and_variance(K_XX, K_XY, K_YY, unit_diagonal=False, + mmd_est='unbiased', ret_var=True): + r"""Based on + https://github.com/dougalsutherland/opt-mmd/blob/master/two_sample/mmd.py + but changed to not compute the full kernel matrix at once + """ + + m = K_XX.shape[0] + assert K_XX.shape == (m, m) + assert K_XY.shape == (m, m) + assert K_YY.shape == (m, m) + var_at_m = m + + # Get the various sums of kernels that we'll use + # Kts drop the diagonal, but we don't need to compute them explicitly + if unit_diagonal: + diag_X = diag_Y = 1 + sum_diag_X = sum_diag_Y = m + sum_diag2_X = sum_diag2_Y = m + else: + diag_X = torch.diagonal(K_XX) + diag_Y = torch.diagonal(K_YY) + + sum_diag_X = diag_X.sum() + sum_diag_Y = diag_Y.sum() + + sum_diag2_X = _sqn(diag_X) + sum_diag2_Y = _sqn(diag_Y) + + Kt_XX_sums = K_XX.sum(dim=1) - diag_X + Kt_YY_sums = K_YY.sum(dim=1) - diag_Y + K_XY_sums_0 = K_XY.sum(dim=0) + K_XY_sums_1 = K_XY.sum(dim=1) + + Kt_XX_sum = Kt_XX_sums.sum() + Kt_YY_sum = Kt_YY_sums.sum() + K_XY_sum = K_XY_sums_0.sum() + + if mmd_est == 'biased': + mmd2 = ((Kt_XX_sum + sum_diag_X) / (m * m) + + (Kt_YY_sum + sum_diag_Y) / (m * m) + - 2 * K_XY_sum / (m * m)) + else: + assert mmd_est in {'unbiased', 'u-statistic'} + mmd2 = (Kt_XX_sum + Kt_YY_sum) / (m * (m - 1)) + if mmd_est == 'unbiased': + mmd2 -= 2 * K_XY_sum / (m * m) + else: + mmd2 -= 2 * (K_XY_sum - torch.trace(K_XY)) / (m * (m - 1)) + + if not ret_var: + return mmd2 + + Kt_XX_2_sum = _sqn(K_XX) - sum_diag2_X + Kt_YY_2_sum = _sqn(K_YY) - sum_diag2_Y + K_XY_2_sum = _sqn(K_XY) + + dot_XX_XY = Kt_XX_sums.dot(K_XY_sums_1) + dot_YY_YX = Kt_YY_sums.dot(K_XY_sums_0) + + m1 = m - 1 + m2 = m - 2 + zeta1_est = ( + 1 / (m * m1 * m2) * + (_sqn(Kt_XX_sums) - Kt_XX_2_sum + _sqn(Kt_YY_sums) - Kt_YY_2_sum) + - 1 / (m * m1) ** 2 * (Kt_XX_sum ** 2 + Kt_YY_sum ** 2) + + 1 / (m * m * m1) * ( + _sqn(K_XY_sums_1) + _sqn(K_XY_sums_0) - 2 * K_XY_2_sum) + - 2 / m ** 4 * K_XY_sum ** 2 + - 2 / (m * m * m1) * (dot_XX_XY + dot_YY_YX) + + 2 / (m ** 3 * m1) * (Kt_XX_sum + Kt_YY_sum) * K_XY_sum + ) + zeta2_est = ( + 1 / (m * m1) * (Kt_XX_2_sum + Kt_YY_2_sum) + - 1 / (m * m1) ** 2 * (Kt_XX_sum ** 2 + Kt_YY_sum ** 2) + + 2 / (m * m) * K_XY_2_sum + - 2 / m ** 4 * K_XY_sum ** 2 + - 4 / (m * m * m1) * (dot_XX_XY + dot_YY_YX) + + 4 / (m ** 3 * m1) * (Kt_XX_sum + Kt_YY_sum) * K_XY_sum + ) + var_est = (4 * (var_at_m - 2) / (var_at_m * (var_at_m - 1)) * zeta1_est + + 2 / (var_at_m * (var_at_m - 1)) * zeta2_est) + + return mmd2.cpu().numpy(), var_est.cpu().numpy() + + +def _sqn(arr): + r"""Squared norm.""" + flat = arr.view(-1) + return flat.dot(flat) diff --git a/imaginaire/evaluation/knn.py b/imaginaire/evaluation/knn.py new file mode 100644 index 0000000000000000000000000000000000000000..48b1cf502dc01e89e67f021572df9e94c71938ef --- /dev/null +++ b/imaginaire/evaluation/knn.py @@ -0,0 +1,35 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. 
+# To view a copy of this license, check out LICENSE.md +import torch + +from imaginaire.evaluation.common import compute_nn + + +def _get_1nn_acc(data_x, data_y, k=1): + device = data_x.device + n0 = data_x.size(0) + n1 = data_y.size(0) + data_all = torch.cat((data_x, data_y), dim=0) + val, idx = compute_nn(data_all, k) + label = torch.cat((torch.ones(n0, device=device), + torch.zeros(n1, device=device))) + + count = torch.zeros(n0 + n1, device=device) + for i in range(0, k): + count = count + label.index_select(0, idx[:, i]) + pred = torch.ge(count, (float(k) / 2) * + torch.ones(n0 + n1, device=device)).float() + + tp = (pred * label).sum() + fp = (pred * (1 - label)).sum() + fn = ((1 - pred) * label).sum() + tn = ((1 - pred) * (1 - label)).sum() + acc_r = (tp / (tp + fn)).item() + acc_f = (tn / (tn + fp)).item() + acc = torch.eq(label, pred).float().mean().item() + + return {'1NN_acc': acc, + '1NN_acc_real': acc_r, + '1NN_acc_fake': acc_f} diff --git a/imaginaire/evaluation/lpips.py b/imaginaire/evaluation/lpips.py new file mode 100644 index 0000000000000000000000000000000000000000..37b10d77ad739151347ee2c45e6417145206cdbc --- /dev/null +++ b/imaginaire/evaluation/lpips.py @@ -0,0 +1,153 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +from collections import namedtuple + +import torch +from torch import nn, distributed as dist +import torchvision.models as tv +from torch.distributed import barrier + +from imaginaire.utils.distributed import is_local_master + + +def get_lpips_model(): + if dist.is_initialized() and not is_local_master(): + # Make sure only the first process in distributed training downloads the model, and the others use the cache. + barrier() + + model = LPIPSNet().cuda() + + if dist.is_initialized() and is_local_master(): + # Make sure only the first process in distributed training downloads the model, and the others use the cache. 
+ barrier() + return model + + +# Learned perceptual network, modified from https://github.com/richzhang/PerceptualSimilarity + +def normalize_tensor(in_feat, eps=1e-5): + norm_factor = torch.sqrt(torch.sum(in_feat ** 2, dim=1, keepdim=True) + eps) + return in_feat / (norm_factor + eps) + + +class NetLinLayer(nn.Module): + """ A single linear layer used as placeholder for LPIPS learnt weights """ + + def __init__(self, dim): + super(NetLinLayer, self).__init__() + self.weight = nn.Parameter(torch.zeros(1, dim, 1, 1)) + + def forward(self, inp): + out = self.weight * inp + return out + + +class ScalingLayer(nn.Module): + # For rescaling the input to vgg16 + def __init__(self): + super(ScalingLayer, self).__init__() + self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) + self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None]) + + def forward(self, inp): + return (inp - self.shift) / self.scale + + +class LPIPSNet(nn.Module): + def __init__(self): + super().__init__() + self.model = LPNet() + + @torch.no_grad() + def forward(self, fake_images, fake_images_another, align_corners=True): + features, shape = self._forward_single(fake_images) + features_another, _ = self._forward_single(fake_images_another) + result = 0 + for i, g_feat in enumerate(features): + cur_diff = torch.sum((g_feat - features_another[i]) ** 2, dim=1) / (shape[i] ** 2) + result += cur_diff + return result + + def _forward_single(self, images): + return self.model(torch.clamp(images, 0, 1)) + + +class LPNet(nn.Module): + def __init__(self): + super(LPNet, self).__init__() + + self.scaling_layer = ScalingLayer() + self.net = vgg16(pretrained=True, requires_grad=False) + self.L = 5 + dims = [64, 128, 256, 512, 512] + self.lins = nn.ModuleList([NetLinLayer(dims[i]) for i in range(self.L)]) + + weights = torch.hub.load_state_dict_from_url( + 'https://github.com/niopeng/CAM-Net/raw/main/code/models/weights/v0.1/vgg.pth' + ) + for i in range(self.L): + self.lins[i].weight.data = torch.sqrt(weights["lin%d.model.1.weight" % i]) + + def forward(self, in0, avg=False): + in0 = 2 * in0 - 1 + in0_input = self.scaling_layer(in0) + outs0 = self.net.forward(in0_input) + feats0 = {} + shapes = [] + res = [] + + for kk in range(self.L): + feats0[kk] = normalize_tensor(outs0[kk]) + + if avg: + res = [self.lins[kk](feats0[kk]).mean([2, 3], keepdim=False) for kk in range(self.L)] + else: + for kk in range(self.L): + cur_res = self.lins[kk](feats0[kk]) + shapes.append(cur_res.shape[-1]) + res.append(cur_res.reshape(cur_res.shape[0], -1)) + + return res, shapes + + +class vgg16(torch.nn.Module): + def __init__(self, requires_grad=False, pretrained=True): + super(vgg16, self).__init__() + vgg_pretrained_features = tv.vgg16(pretrained=pretrained).features + self.slice1 = torch.nn.Sequential() + self.slice2 = torch.nn.Sequential() + self.slice3 = torch.nn.Sequential() + self.slice4 = torch.nn.Sequential() + self.slice5 = torch.nn.Sequential() + self.N_slices = 5 + for x in range(4): + self.slice1.add_module(str(x), vgg_pretrained_features[x]) + for x in range(4, 9): + self.slice2.add_module(str(x), vgg_pretrained_features[x]) + for x in range(9, 16): + self.slice3.add_module(str(x), vgg_pretrained_features[x]) + for x in range(16, 23): + self.slice4.add_module(str(x), vgg_pretrained_features[x]) + for x in range(23, 30): + self.slice5.add_module(str(x), vgg_pretrained_features[x]) + if not requires_grad: + for param in self.parameters(): + param.requires_grad = False + + def 
forward(self, x): + h = self.slice1(x) + h_relu1_2 = h + h = self.slice2(h) + h_relu2_2 = h + h = self.slice3(h) + h_relu3_3 = h + h = self.slice4(h) + h_relu4_3 = h + h = self.slice5(h) + h_relu5_3 = h + vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']) + out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) + + return out diff --git a/imaginaire/evaluation/msid.py b/imaginaire/evaluation/msid.py new file mode 100644 index 0000000000000000000000000000000000000000..c59a39dd38fe29d8d463de8348eb3749189c6b2b --- /dev/null +++ b/imaginaire/evaluation/msid.py @@ -0,0 +1,375 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md + +# flake8: noqa +from scipy.sparse import lil_matrix, diags, eye +import math + +import numpy as np +import torch + +EPSILON = 1e-6 +NORMALIZATION = 1e6 + + +def _get_msid(x, y, ts=np.logspace(-1, 1, 256), k=5, m=10, niters=100, + rademacher=False, graph_builder='full', + msid_mode='max', normalized_laplacian=True, normalize='empty'): + """ + Compute the msid score between two samples, x and y + Arguments: + x: x samples + y: y samples + ts: temperature values + k: number of neighbours for graph construction + m: Lanczos steps in SLQ + niters: number of starting random vectors for SLQ + rademacher: if True, sample random vectors from Rademacher + distributions, else sample standard normal distribution + graph_builder: if 'kgraph', uses faster graph construction (options: + 'sparse', 'kgraph') + msid_mode: 'l2' to compute the l2 norm of the distance between `msid1` + and `msid2`; 'max' to find the maximum abosulute difference between + two descriptors over temperature + normalized_laplacian: if True, use normalized Laplacian + normalize: 'empty' for average heat kernel (corresponds to the empty + graph normalization of NetLSD), 'complete' for the complete, 'er' + for erdos-renyi normalization, 'none' for no normalization + Returns: + msid_score: the scalar value of the distance between discriptors + """ + normed_msidx = msid_descriptor(x, ts, k, m, niters, rademacher, + graph_builder, normalized_laplacian, + normalize) + normed_msidy = msid_descriptor(y, ts, k, m, niters, rademacher, + graph_builder, normalized_laplacian, + normalize) + + c = np.exp(-2 * (ts + 1 / ts)) + + if msid_mode == 'l2': + score = np.linalg.norm(normed_msidx - normed_msidy) + elif msid_mode == 'max': + score = np.amax(c * np.abs(normed_msidx - normed_msidy)) + else: + raise Exception('Use either l2 or max mode.') + + return {'IMD': score} + + +def msid_descriptor(x, ts=np.logspace(-1, 1, 256), k=5, m=10, niters=100, + rademacher=False, graph_builder='full', + normalized_laplacian=True, normalize='empty'): + """ + Compute the msid descriptor for a single sample x + Arguments: + x: x samples + ts: temperature values + k: number of neighbours for graph construction + m: Lanczos steps in SLQ + niters: number of starting random vectors for SLQ + rademacher: if True, sample random vectors from Rademacher + distributions, else sample standard normal distribution + graph_builder: if 'kgraph', uses faster graph construction + (options: ' + rse', 'kgraph') + normalized_laplacian: if True, use normalized Laplacian + normalize: 'empty' for average heat kernel (corresponds to the empty + graph normalization of NetLSD), 'complete' for the complete, 'er' + for erdos-renyi normalization, 'none' 
for no normalization + Returns: + normed_msidx: normalized msid descriptor + """ + Lx = _build_graph(x, k, graph_builder, normalized_laplacian) + + nx = Lx.shape[0] + msidx = slq_red_var(Lx, m, niters, ts, rademacher) + + normed_msidx = _normalize_msid(msidx, normalize, nx, k, ts) * NORMALIZATION + + return normed_msidx + + +def _build_graph(data, k=5, graph_builder='full', normalized=True): + """ + Return Laplacian from data or load preconstructed from path + + Arguments: + data: samples + k: number of neighbours for graph construction + graph_builder: if 'kgraph', use faster graph construction + normalized: if True, use nnormalized Laplacian + Returns: + L: Laplacian of the graph constructed with data + """ + if graph_builder == 'sparse': + A = construct_graph_sparse(data.cpu().numpy(), k) + elif graph_builder == 'kgraph': + A = construct_graph_kgraph(data.cpu().numpy(), k) + elif graph_builder == 'full': + A = lil_matrix(construct_graph(data, k).cpu().numpy()).tocsr() + else: + raise Exception('Please specify graph builder: sparse or kgraph.') + A = (A + A.T) / 2 + A.data = np.ones(A.data.shape) + L = _laplacian_sparse(A, normalized) + return L + + +def _normalize_msid(msid, normalization, n, k, ts): + normed_msid = msid.copy() + if normalization == 'empty': + normed_msid /= n + elif normalization == 'complete': + normed_msid /= (1 + (n - 1) * np.exp(-(1 + 1 / (n - 1)) * ts)) + elif normalization == 'er': + xs = np.linspace(0, 1, n) + er_spectrum = 4 / np.sqrt(k) * xs + 1 - 2 / np.sqrt(k) + er_msid = np.exp(-np.outer(ts, er_spectrum)).sum(-1) + normed_msid = normed_msid / (er_msid + EPSILON) + elif normalization == 'none' or normalization is None: + pass + else: + raise ValueError('Unknown normalization parameter!') + return normed_msid + + +def _lanczos_m(A, m, nv, rademacher, SV=None): + """ + Lanczos algorithm computes symmetric m x m tridiagonal matrix T and + matrix V with orthogonal rows constituting the basis of the Krylov + subspace K_m(A, x), where x is an arbitrary starting unit vector. This + implementation parallelizes `nv` starting vectors. 
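+ The resulting tridiagonal matrices feed the SLQ helpers below, which estimate tr(f(A)) (e.g. the heat-kernel trace tr(exp(-t*A))) via Gauss quadrature on the eigendecomposition of T, averaged over the nv random probe vectors.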
+ + Arguments: + m: number of Lanczos steps + nv: number of random vectors + rademacher: True to use Rademacher distribution, + False - standard normal for random vectors + SV: specified starting vectors + + Returns: T: a nv x m x m tensor, T[i, :, :] is the ith symmetric + tridiagonal matrix V: a n x m x nv tensor, V[:, :, i] is the ith matrix + with orthogonal rows + """ + orthtol = 1e-5 + if type(SV) != np.ndarray: + if rademacher: + SV = np.sign(np.random.randn(A.shape[0], nv)) + else: + SV = np.random.randn(A.shape[0], + nv) # init random vectors in columns: n x nv + V = np.zeros((SV.shape[0], m, nv)) + T = np.zeros((nv, m, m)) + + np.divide(SV, np.linalg.norm(SV, axis=0), out=SV) # normalize each column + V[:, 0, :] = SV + + w = A.dot(SV) + alpha = np.einsum('ij,ij->j', w, SV) + w -= alpha[None, :] * SV + beta = np.einsum('ij,ij->j', w, w) + np.sqrt(beta, beta) + + T[:, 0, 0] = alpha + T[:, 0, 1] = beta + T[:, 1, 0] = beta + + np.divide(w, beta[None, :], out=w) + V[:, 1, :] = w + t = np.zeros((m, nv)) + + for i in range(1, m): + SVold = V[:, i - 1, :] + SV = V[:, i, :] + + w = A.dot(SV) # sparse @ dense + w -= beta[None, :] * SVold # n x nv + np.einsum('ij,ij->j', w, SV, out=alpha) + + T[:, i, i] = alpha + + if i < m - 1: + w -= alpha[None, :] * SV # n x nv + # reortho + np.einsum('ijk,ik->jk', V, w, out=t) + w -= np.einsum('ijk,jk->ik', V, t) + np.einsum('ij,ij->j', w, w, out=beta) + np.sqrt(beta, beta) + np.divide(w, beta[None, :], out=w) + + T[:, i, i + 1] = beta + T[:, i + 1, i] = beta + + # more reotho + innerprod = np.einsum('ijk,ik->jk', V, w) + reortho = False + for _ in range(100): + if not (innerprod > orthtol).sum(): + reortho = True + break + np.einsum('ijk,ik->jk', V, w, out=t) + w -= np.einsum('ijk,jk->ik', V, t) + np.divide(w, np.linalg.norm(w, axis=0)[None, :], out=w) + innerprod = np.einsum('ijk,ik->jk', V, w) + + V[:, i + 1, :] = w + + if (np.abs(beta) > 1e-6).sum() == 0 or not reortho: + break + return T, V + + +def _slq(A, m, niters, rademacher): + """ + Compute the trace of matrix exponential + + Arguments: + A: square matrix in trace(exp(A)) + m: number of Lanczos steps + niters: number of quadratures (also, the number of random vectors in the + hutchinson trace estimator) + rademacher: True to use Rademacher distribution, False - standard normal + for random vectors in Hutchinson + Returns: trace: estimate of trace of matrix exponential + """ + T, _ = _lanczos_m(A, m, niters, rademacher) + eigvals, eigvecs = np.linalg.eigh(T) + expeig = np.exp(eigvals) + sqeigv1 = np.power(eigvecs[:, 0, :], 2) + trace = A.shape[-1] * (expeig * sqeigv1).sum() / niters + return trace + + +def _slq_ts(A, m, niters, ts, rademacher): + """ + Compute the trace of matrix exponential + + Arguments: + A: square matrix in trace(exp(-t*A)), where t is temperature + m: number of Lanczos steps + niters: number of quadratures (also, the number of random vectors in the + hutchinson trace estimator) + ts: an array with temperatures + rademacher: True to use Rademacher distribution, False - standard normal + for random vectors in Hutchinson + Returns: + trace: estimate of trace of matrix exponential across temperatures `ts` + """ + T, _ = _lanczos_m(A, m, niters, rademacher) + eigvals, eigvecs = np.linalg.eigh(T) + expeig = np.exp(-np.outer(ts, eigvals)).reshape(ts.shape[0], niters, m) + sqeigv1 = np.power(eigvecs[:, 0, :], 2) + traces = A.shape[-1] * (expeig * sqeigv1).sum(-1).mean(-1) + return traces + + +def _slq_ts_fs(A, m, niters, ts, rademacher, fs): + """ + Compute the trace of matrix 
functions + + Arguments: + A: square matrix in trace(exp(-t*A)), where t is temperature + m: number of Lanczos steps + niters: number of quadratures (also, the number of random vectors in the + hutchinson trace estimator) + ts: an array with temperatures + rademacher: True to use Rademacher distribution, else - standard normal + for random vectors in Hutchinson + fs: a list of functions + Returns: + traces: estimate of traces for each of the functions in fs + """ + T, _ = _lanczos_m(A, m, niters, rademacher) + eigvals, eigvecs = np.linalg.eigh(T) + traces = np.zeros((len(fs), len(ts))) + for i, f in enumerate(fs): + expeig = f(-np.outer(ts, eigvals)).reshape(ts.shape[0], niters, m) + sqeigv1 = np.power(eigvecs[:, 0, :], 2) + traces[i, :] = A.shape[-1] * (expeig * sqeigv1).sum(-1).mean(-1) + return traces + + +def slq_red_var(A, m, niters, ts, rademacher): + """ + Compute the trace of matrix exponential with reduced variance + + Arguments: + A: square matrix in trace(exp(-t*A)), where t is temperature + m: number of Lanczos steps + niters: number of quadratures (also, the number of random vectors in the + hutchinson trace estimator) + ts: an array with temperatures + Returns: + traces: estimate of trace for each temperature value in `ts` + """ + fs = [np.exp, lambda x: x] + + traces = _slq_ts_fs(A, m, niters, ts, rademacher, fs) + subee = traces[0, :] - traces[1, :] / np.exp(ts) + sub = - ts * A.shape[0] / np.exp(ts) + return subee + sub + + +def np_euc_cdist(data): + dd = np.sum(data * data, axis=1) + dist = -2 * np.dot(data, data.T) + dist += dd + dd[:, np.newaxis] + np.fill_diagonal(dist, 0) + np.sqrt(dist, dist) + return dist + + +def construct_graph_sparse(data, k): + n = len(data) + spmat = lil_matrix((n, n)) + dd = np.sum(data * data, axis=1) + + for i in range(n): + dists = dd - 2 * data[i, :].dot(data.T) + inds = np.argpartition(dists, k + 1)[:k + 1] + inds = inds[inds != i] + spmat[i, inds] = 1 + + return spmat.tocsr() + + +def construct_graph_kgraph(data, k): + raise NotImplementedError + # + # n = len(data) + # spmat = lil_matrix((n, n)) + # index = pykgraph.KGraph(data, 'euclidean') + # index.build(reverse=0, K=2 * k + 1, L=2 * k + 50) + # result = index.search(data, K=k + 1)[:, 1:] + # spmat[np.repeat(np.arange(n), k, 0), result.ravel()] = 1 + # return spmat.tocsr() + + +def construct_graph(input_features, k, num_splits=10): + num_samples = input_features.shape[0] + indices = [] + for i in range(num_splits): + batch_size = math.ceil(num_samples / num_splits) + start_idx = i * batch_size + end_idx = min((i + 1) * batch_size, num_samples) + dist = torch.cdist(input_features[start_idx:end_idx], + input_features) + indices.append(torch.topk(dist, k + 1, dim=-1, largest=False)[1].cpu()) + indices = torch.cat(indices, dim=0) + graph = torch.zeros(num_samples, num_samples, device=indices.device) + graph.scatter_(1, indices, 1) + graph -= torch.eye(num_samples, device=graph.device) + return graph + + +def _laplacian_sparse(A, normalized=True): + D = A.sum(1).A1 + if normalized: + Dsqrt = diags(1 / np.sqrt(D)) + L = eye(A.shape[0]) - Dsqrt.dot(A).dot(Dsqrt) + else: + L = diags(D) - A + return L diff --git a/imaginaire/evaluation/prdc.py b/imaginaire/evaluation/prdc.py new file mode 100644 index 0000000000000000000000000000000000000000..3ba2abba01f6e4b8e49b4b979ebe0aa93e4f9b00 --- /dev/null +++ b/imaginaire/evaluation/prdc.py @@ -0,0 +1,124 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +""" +Modified from https://github.com/clovaai/generative-evaluation-prdc +Copyright (c) 2020-present NAVER Corp. +MIT license +""" +import os + +import torch + +from imaginaire.utils.distributed import is_master +from imaginaire.utils.distributed import master_only_print as print + +from .common import load_or_compute_activations, compute_pairwise_distance, \ + compute_nn + + +@torch.no_grad() +def compute_prdc(prdc_path, data_loader, net_G, + key_real='images', key_fake='fake_images', + real_act=None, fake_act=None, + sample_size=None, save_act=True, k=10, **kwargs): + r"""Compute precision diversity curve + + Args: + + """ + print('Computing PRDC.') + act_path = os.path.join( + os.path.dirname(prdc_path), 'activations_real.npy' + ) if save_act else None + + # Get the fake activations. + if fake_act is None: + fake_act = load_or_compute_activations(None, + data_loader, + key_real, key_fake, net_G, + sample_size=sample_size, + **kwargs) + else: + print(f"Using precomputed activations of size {fake_act.shape}.") + + # Get the ground truth activations. + if real_act is None: + real_act = load_or_compute_activations(act_path, + data_loader, + key_real, key_fake, None, + sample_size=sample_size, + **kwargs) + else: + print(f"Using precomputed activations of size {real_act.shape}.") + + if is_master(): + prdc_data = _get_prdc(real_act, fake_act, k) + return \ + prdc_data['precision'], prdc_data['recall'], \ + prdc_data['density'], prdc_data['coverage'] + else: + return None, None, None, None + + +def get_kth_value(unsorted, k, dim=-1): + r""" + + Args: + unsorted: numpy.ndarray of any dimensionality. + k: int + Returns: + kth values along the designated axis. + """ + indices = torch.topk(unsorted, k, dim=dim, largest=False)[1] + k_smallests = torch.gather(unsorted, dim=dim, index=indices) + kth_values = k_smallests.max(dim=dim)[0] + return kth_values + + +def _get_prdc(real_features, fake_features, nearest_k): + r""" + Computes precision, recall, density, and coverage given two manifolds. + + Args: + real_features: numpy.ndarray([N, feature_dim], dtype=np.float32) + fake_features: numpy.ndarray([N, feature_dim], dtype=np.float32) + nearest_k: int. + Returns: + dict of precision, recall, density, and coverage. + """ + real_nearest_neighbour_distances, _ = compute_nn( + real_features, nearest_k) + real_nearest_neighbour_distances = \ + real_nearest_neighbour_distances.max(dim=-1)[0].cpu() + fake_nearest_neighbour_distances, _ = compute_nn( + fake_features, nearest_k) + fake_nearest_neighbour_distances = \ + fake_nearest_neighbour_distances.max(dim=-1)[0].cpu() + distance_real_fake = compute_pairwise_distance( + real_features, fake_features) + + precision = ( + distance_real_fake < + torch.unsqueeze(real_nearest_neighbour_distances, dim=1) + ).any(dim=0).float().mean().item() + + recall = ( + distance_real_fake < + torch.unsqueeze(fake_nearest_neighbour_distances, dim=0) + ).any(dim=1).float().mean().item() + + density = (1. 
/ float(nearest_k)) * ( + distance_real_fake < + torch.unsqueeze(real_nearest_neighbour_distances, dim=1) + ).sum(dim=0).float().mean().item() + + # noinspection PyUnresolvedReferences + coverage = ( + distance_real_fake.min(dim=1)[0] < + real_nearest_neighbour_distances + ).float().mean().item() + + return dict(precision=precision, recall=recall, + density=density, coverage=coverage) diff --git a/imaginaire/evaluation/pretrained.py b/imaginaire/evaluation/pretrained.py new file mode 100644 index 0000000000000000000000000000000000000000..b99c8f19cbb7687ac5e88b28939e83475e458971 --- /dev/null +++ b/imaginaire/evaluation/pretrained.py @@ -0,0 +1,232 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md + +""" +Modified from +https://github.com/mseitzer/pytorch-fid + +Code adapted from https://github.com/bioinf-jku/TTUR to use PyTorch instead +of Tensorflow +Copyright 2018 Institute of Bioinformatics, JKU Linz +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +import torch +import torch.nn.functional as F +from torch import nn + +try: + from torchvision.models.utils import load_state_dict_from_url +except ImportError: + from torch.utils.model_zoo import load_url as load_state_dict_from_url + +from torchvision.models import inception, inception_v3, vgg16 + +# Inception weights ported to Pytorch from +# http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz +FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases' \ + '/download/fid_weights/pt_inception-2015-12-05-6726825d.pth' + + +class SwAV(nn.Module): + def __init__(self): + super().__init__() + self.model = torch.hub.load('facebookresearch/swav', 'resnet50', + pretrained=True) + self.model.fc = torch.nn.Sequential() + + def forward(self, x, align_corners=True): + y = self.model(F.interpolate( + x, size=(224, 224), mode='bicubic', align_corners=align_corners)) + return y + + +class Vgg16(nn.Module): + def __init__(self): + super().__init__() + self.model = vgg16(pretrained=True, init_weights=False) + self.model.classifier = torch.nn.Sequential( + *[self.model.classifier[i] for i in range(4)] + ) + + def forward(self, x, align_corners=True): + y = self.model(F.interpolate( + x, size=(224, 224), mode='bicubic', align_corners=align_corners)) + return y + + +class InceptionV3(nn.Module): + def __init__(self): + super().__init__() + self.model = inception_v3(transform_input=False, + pretrained=True, + init_weights=False) + self.model.fc = torch.nn.Sequential() + + def forward(self, x, align_corners=True): + y = self.model(F.interpolate( + x, size=(299, 299), mode='bicubic', align_corners=align_corners)) + return y + + +class TFInceptionV3(nn.Module): + def __init__(self): + super().__init__() + self.model = inception_v3(transform_input=False, + num_classes=1008, + aux_logits=False, + pretrained=False, + init_weights=False) + self.model.Mixed_5b = FIDInceptionA(192, pool_features=32) + 
self.model.Mixed_5c = FIDInceptionA(256, pool_features=64) + self.model.Mixed_5d = FIDInceptionA(288, pool_features=64) + self.model.Mixed_6b = FIDInceptionC(768, channels_7x7=128) + self.model.Mixed_6c = FIDInceptionC(768, channels_7x7=160) + self.model.Mixed_6d = FIDInceptionC(768, channels_7x7=160) + self.model.Mixed_6e = FIDInceptionC(768, channels_7x7=192) + self.model.Mixed_7b = FIDInceptionE_1(1280) + self.model.Mixed_7c = FIDInceptionE_2(2048) + + state_dict = load_state_dict_from_url( + FID_WEIGHTS_URL, progress=True, map_location='cpu' + ) + self.model.load_state_dict(state_dict) + self.model.fc = torch.nn.Sequential() + + def forward(self, x, align_corners=True): + y = self.model(F.interpolate( + x, size=(299, 299), mode='bicubic', align_corners=align_corners)) + return y + + +class FIDInceptionA(inception.InceptionA): + """InceptionA block patched for FID computation""" + + def __init__(self, in_channels, pool_features): + super(FIDInceptionA, self).__init__(in_channels, pool_features) + + def forward(self, x): + branch1x1 = self.branch1x1(x) + + branch5x5 = self.branch5x5_1(x) + branch5x5 = self.branch5x5_2(branch5x5) + + branch3x3dbl = self.branch3x3dbl_1(x) + branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) + branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) + + # Patch: Tensorflow's average pool does not use the padded zero's in + # its average calculation + branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, + count_include_pad=False) + branch_pool = self.branch_pool(branch_pool) + + outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool] + return torch.cat(outputs, 1) + + +class FIDInceptionC(inception.InceptionC): + """InceptionC block patched for FID computation""" + + def __init__(self, in_channels, channels_7x7): + super(FIDInceptionC, self).__init__(in_channels, channels_7x7) + + def forward(self, x): + branch1x1 = self.branch1x1(x) + + branch7x7 = self.branch7x7_1(x) + branch7x7 = self.branch7x7_2(branch7x7) + branch7x7 = self.branch7x7_3(branch7x7) + + branch7x7dbl = self.branch7x7dbl_1(x) + branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl) + branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl) + branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl) + branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl) + + # Patch: Tensorflow's average pool does not use the padded zero's in + # its average calculation + branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, + count_include_pad=False) + branch_pool = self.branch_pool(branch_pool) + + outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool] + return torch.cat(outputs, 1) + + +class FIDInceptionE_1(inception.InceptionE): + """First InceptionE block patched for FID computation""" + + def __init__(self, in_channels): + super(FIDInceptionE_1, self).__init__(in_channels) + + def forward(self, x): + branch1x1 = self.branch1x1(x) + + branch3x3 = self.branch3x3_1(x) + branch3x3 = [ + self.branch3x3_2a(branch3x3), + self.branch3x3_2b(branch3x3), + ] + branch3x3 = torch.cat(branch3x3, 1) + + branch3x3dbl = self.branch3x3dbl_1(x) + branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) + branch3x3dbl = [ + self.branch3x3dbl_3a(branch3x3dbl), + self.branch3x3dbl_3b(branch3x3dbl), + ] + branch3x3dbl = torch.cat(branch3x3dbl, 1) + + # Patch: Tensorflow's average pool does not use the padded zero's in + # its average calculation + branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, + count_include_pad=False) + branch_pool = self.branch_pool(branch_pool) + + outputs = [branch1x1, branch3x3, branch3x3dbl, 
branch_pool] + return torch.cat(outputs, 1) + + +class FIDInceptionE_2(inception.InceptionE): + """Second InceptionE block patched for FID computation""" + + def __init__(self, in_channels): + super(FIDInceptionE_2, self).__init__(in_channels) + + def forward(self, x): + branch1x1 = self.branch1x1(x) + + branch3x3 = self.branch3x3_1(x) + branch3x3 = [ + self.branch3x3_2a(branch3x3), + self.branch3x3_2b(branch3x3), + ] + branch3x3 = torch.cat(branch3x3, 1) + + branch3x3dbl = self.branch3x3dbl_1(x) + branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) + branch3x3dbl = [ + self.branch3x3dbl_3a(branch3x3dbl), + self.branch3x3dbl_3b(branch3x3dbl), + ] + branch3x3dbl = torch.cat(branch3x3dbl, 1) + + # Patch: The FID Inception model uses max pooling instead of average + # pooling. This is likely an error in this specific Inception + # implementation, as other Inception models use average pooling here + # (which matches the description in the paper). + branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1) + branch_pool = self.branch_pool(branch_pool) + + outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] + return torch.cat(outputs, 1) diff --git a/imaginaire/evaluation/segmentation/__init__.py b/imaginaire/evaluation/segmentation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..63a71a20ebb0ac2940c091f064de5a029126af42 --- /dev/null +++ b/imaginaire/evaluation/segmentation/__init__.py @@ -0,0 +1,7 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +from .common import get_segmentation_hist_model, get_miou,compute_hist + +__all__ = ['get_segmentation_hist_model', 'get_miou','compute_hist'] diff --git a/imaginaire/evaluation/segmentation/celebamask_hq.py b/imaginaire/evaluation/segmentation/celebamask_hq.py new file mode 100644 index 0000000000000000000000000000000000000000..ab95f325eb681bb0d80bac2c626494ee9e81b51b --- /dev/null +++ b/imaginaire/evaluation/segmentation/celebamask_hq.py @@ -0,0 +1,130 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. 
+# To view a copy of this license, check out LICENSE.md +# https://github.com/switchablenorms/CelebAMask-HQ/tree/master/face_parsing +import torch +from torch import nn +from torch.nn import functional as F + + +class Unet(nn.Module): + def __init__( + self, + feature_scale=4, + n_classes=19, + is_deconv=True, + in_channels=3, + is_batchnorm=True, + image_size=512, + use_dont_care=False + ): + super(Unet, self).__init__() + self.is_deconv = is_deconv + self.in_channels = in_channels + self.is_batchnorm = is_batchnorm + self.feature_scale = feature_scale + self.image_size = image_size + self.n_classes = n_classes + self.use_dont_care = use_dont_care + + filters = [64, 128, 256, 512, 1024] + filters = [int(x / self.feature_scale) for x in filters] + + # downsampling + self.conv1 = unetConv2(self.in_channels, filters[0], self.is_batchnorm) + self.maxpool1 = nn.MaxPool2d(kernel_size=2) + + self.conv2 = unetConv2(filters[0], filters[1], self.is_batchnorm) + self.maxpool2 = nn.MaxPool2d(kernel_size=2) + + self.conv3 = unetConv2(filters[1], filters[2], self.is_batchnorm) + self.maxpool3 = nn.MaxPool2d(kernel_size=2) + + self.conv4 = unetConv2(filters[2], filters[3], self.is_batchnorm) + self.maxpool4 = nn.MaxPool2d(kernel_size=2) + + self.center = unetConv2(filters[3], filters[4], self.is_batchnorm) + + # upsampling + self.up_concat4 = unetUp( + filters[4], filters[3], self.is_deconv, self.is_batchnorm) + self.up_concat3 = unetUp( + filters[3], filters[2], self.is_deconv, self.is_batchnorm) + self.up_concat2 = unetUp( + filters[2], filters[1], self.is_deconv, self.is_batchnorm) + self.up_concat1 = unetUp( + filters[1], filters[0], self.is_deconv, self.is_batchnorm) + + # final conv (without any concat) + self.final = nn.Conv2d(filters[0], n_classes, 1) + + def forward(self, images, align_corners=True): + images = F.interpolate( + images, size=(self.image_size, self.image_size), mode='bicubic', + align_corners=align_corners + ) + + conv1 = self.conv1(images) + maxpool1 = self.maxpool1(conv1) + conv2 = self.conv2(maxpool1) + maxpool2 = self.maxpool2(conv2) + conv3 = self.conv3(maxpool2) + maxpool3 = self.maxpool3(conv3) + conv4 = self.conv4(maxpool3) + maxpool4 = self.maxpool4(conv4) + center = self.center(maxpool4) + up4 = self.up_concat4(conv4, center) + up3 = self.up_concat3(conv3, up4) + up2 = self.up_concat2(conv2, up3) + up1 = self.up_concat1(conv1, up2) + probs = self.final(up1) + pred = torch.argmax(probs, dim=1) + return pred + + +class unetConv2(nn.Module): + def __init__(self, in_size, out_size, is_batchnorm): + super(unetConv2, self).__init__() + + if is_batchnorm: + self.conv1 = nn.Sequential( + nn.Conv2d(in_size, out_size, 3, 1, 1), + nn.BatchNorm2d(out_size), + nn.ReLU(), + ) + self.conv2 = nn.Sequential( + nn.Conv2d(out_size, out_size, 3, 1, 1), + nn.BatchNorm2d(out_size), + nn.ReLU(), + ) + else: + self.conv1 = nn.Sequential(nn.Conv2d(in_size, out_size, 3, 1, 1), + nn.ReLU()) + self.conv2 = nn.Sequential( + nn.Conv2d(out_size, out_size, 3, 1, 1), nn.ReLU() + ) + + def forward(self, inputs): + outputs = self.conv1(inputs) + outputs = self.conv2(outputs) + return outputs + + +class unetUp(nn.Module): + def __init__(self, in_size, out_size, is_deconv, is_batchnorm): + super(unetUp, self).__init__() + self.conv = unetConv2(in_size, out_size, is_batchnorm) + if is_deconv: + self.up = nn.ConvTranspose2d( + in_size, out_size, kernel_size=2, stride=2) + else: + self.up = nn.UpsamplingBilinear2d(scale_factor=2) + + def forward(self, inputs1, inputs2): + outputs2 = self.up(inputs2) + offset 
= outputs2.size()[2] - inputs1.size()[2] + padding = 2 * [offset // 2, offset // 2] + outputs1 = F.pad(inputs1, padding) + + return self.conv(torch.cat([outputs1, outputs2], 1)) diff --git a/imaginaire/evaluation/segmentation/cocostuff.py b/imaginaire/evaluation/segmentation/cocostuff.py new file mode 100644 index 0000000000000000000000000000000000000000..601f0c28a3811f994df9a9375bd6c5fc08509d9d --- /dev/null +++ b/imaginaire/evaluation/segmentation/cocostuff.py @@ -0,0 +1,48 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +from torch import nn +from torch.nn import functional as F +import torch.hub + + +class DeepLabV2(nn.Module): + def __init__(self, n_classes=182, image_size=512, use_dont_care=True): + super(DeepLabV2, self).__init__() + self.model = torch.hub.load( + "kazuto1011/deeplab-pytorch", "deeplabv2_resnet101", + pretrained=False, n_classes=182 + ) + state_dict = torch.hub.load_state_dict_from_url( + 'https://github.com/kazuto1011/deeplab-pytorch/releases/download/' + 'v1.0/deeplabv2_resnet101_msc-cocostuff164k-100000.pth', + map_location="cpu" + ) + self.model.load_state_dict(state_dict) + + self.image_size = image_size + # self.mean = torch.tensor([122.675, 116.669, 104.008], device="cuda") + self.mean = torch.tensor([104.008, 116.669, 122.675], device="cuda") + self.n_classes = n_classes + self.use_dont_care = use_dont_care + + def forward(self, images, align_corners=True): + scale = self.image_size / max(images.shape[2:]) + images = F.interpolate( + images, scale_factor=scale, mode='bilinear', + align_corners=align_corners + ) + images = 255 * 0.5 * (images + 1) # (-1, 1) -> (0, 255) + images = images.flip(1) # RGB to BGR + images -= self.mean[None, :, None, None] + _, _, H, W = images.shape + + logits = self.model(images) + logits = F.interpolate( + logits, size=(H, W), mode="bilinear", + align_corners=align_corners + ) + probs = F.softmax(logits, dim=1) + pred = torch.argmax(probs, dim=1) + return pred diff --git a/imaginaire/evaluation/segmentation/common.py b/imaginaire/evaluation/segmentation/common.py new file mode 100644 index 0000000000000000000000000000000000000000..78d90ae26db4747e5be859a8c1237a288e3146b7 --- /dev/null +++ b/imaginaire/evaluation/segmentation/common.py @@ -0,0 +1,92 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +import os + +import boto3 +import torch +from torch import nn, distributed as dist +from torch.nn import functional as F + +from imaginaire.utils.distributed import is_local_master +from imaginaire.utils.io import download_file_from_google_drive + + +def get_segmentation_hist_model(dataset_name, aws_credentials=None): + if dist.is_initialized() and not is_local_master(): + # Make sure only the first process in distributed training downloads + # the model, and the others will use the cache + # noinspection PyUnresolvedReferences + torch.distributed.barrier() + + # Load the segmentation network. 
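+ # CelebAMask-HQ uses a U-Net face parser (weights fetched from S3 or Google Drive), while cocostuff/getty use a pretrained DeepLabV2; any other dataset name returns None.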
+ if dataset_name == "celebamask_hq": + from imaginaire.evaluation.segmentation.celebamask_hq import Unet + seg_network = Unet() + os.makedirs(os.path.join(torch.hub.get_dir(), 'checkpoints'), exist_ok=True) + model_path = os.path.join(os.path.join(torch.hub.get_dir(), 'checkpoints'), "celebamask_hq.pt") + if not os.path.exists(model_path): + if aws_credentials is not None: + s3 = boto3.client('s3', **aws_credentials) + s3.download_file('lpi-poe', 'model_zoo/celebamask_hq.pt', model_path) + else: + download_file_from_google_drive("1o1m-eT38zNCIFldcRaoWcLvvBtY8S4W3", model_path) + state_dict = torch.load(model_path, map_location='cpu') + seg_network.load_state_dict(state_dict) + elif dataset_name == "cocostuff" or dataset_name == "getty": + from imaginaire.evaluation.segmentation.cocostuff import DeepLabV2 + seg_network = DeepLabV2() + else: + print(f"No segmentation network for {dataset_name} was found.") + return None + + if dist.is_initialized() and is_local_master(): + # Make sure only the first process in distributed training downloads + # the model, and the others will use the cache + # noinspection PyUnresolvedReferences + torch.distributed.barrier() + + if seg_network is not None: + seg_network = seg_network.to('cuda').eval() + + return SegmentationHistModel(seg_network) + + +class SegmentationHistModel(nn.Module): + def __init__(self, seg_network): + super().__init__() + self.seg_network = seg_network + + def forward(self, data, fake_images, align_corners=True): + pred = self.seg_network(fake_images, align_corners=align_corners) + gt = data["segmaps"] + gt = gt * 255.0 + gt = gt.long() + # print(fake_images.shape, fake_images.min(), fake_images.max()) + # print(gt.shape, gt.min(), gt.max()) + # exit() + return compute_hist(pred, gt, self.seg_network.n_classes, self.seg_network.use_dont_care) + + +def compute_hist(pred, gt, n_classes, use_dont_care): + _, H, W = pred.size() + gt = F.interpolate(gt.float(), (H, W), mode="nearest").long().squeeze(1) + ignore_idx = n_classes if use_dont_care else -1 + all_hist = [] + for cur_pred, cur_gt in zip(pred, gt): + keep = torch.logical_not(cur_gt == ignore_idx) + merge = cur_pred[keep] * n_classes + cur_gt[keep] + hist = torch.bincount(merge, minlength=n_classes ** 2) + hist = hist.view((n_classes, n_classes)) + all_hist.append(hist) + all_hist = torch.stack(all_hist) + return all_hist + + +def get_miou(hist, eps=1e-8): + hist = hist.sum(0) + IOUs = torch.diag(hist) / ( + torch.sum(hist, dim=0, keepdim=False) + torch.sum(hist, dim=1, keepdim=False) - torch.diag(hist) + eps) + mIOU = 100 * torch.mean(IOUs).item() + return {"seg_mIOU": mIOU} diff --git a/imaginaire/generators/__init__.py b/imaginaire/generators/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/imaginaire/generators/__pycache__/__init__.cpython-38.pyc b/imaginaire/generators/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c29051fe61ab456c65828f39f1e1130d389be734 Binary files /dev/null and b/imaginaire/generators/__pycache__/__init__.cpython-38.pyc differ diff --git a/imaginaire/generators/__pycache__/craft_2stage_add_style.cpython-38.pyc b/imaginaire/generators/__pycache__/craft_2stage_add_style.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5fa5fc572212d24f822b7bcdc1d936b848ea4adb Binary files /dev/null and b/imaginaire/generators/__pycache__/craft_2stage_add_style.cpython-38.pyc differ diff --git 
a/imaginaire/generators/__pycache__/craft_base.cpython-38.pyc b/imaginaire/generators/__pycache__/craft_base.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..205291ec52f52992a82b459886ef1b3455ee39f0 Binary files /dev/null and b/imaginaire/generators/__pycache__/craft_base.cpython-38.pyc differ diff --git a/imaginaire/generators/coco_funit.py b/imaginaire/generators/coco_funit.py new file mode 100644 index 0000000000000000000000000000000000000000..e1f555fb91c9d8580ccf69e1e785c5b6c5a54aef --- /dev/null +++ b/imaginaire/generators/coco_funit.py @@ -0,0 +1,194 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +import torch +from torch import nn + +from imaginaire.generators.funit import (MLP, ContentEncoder, Decoder, + StyleEncoder) + + +class Generator(nn.Module): + r"""COCO-FUNIT Generator. + """ + + def __init__(self, gen_cfg, data_cfg): + r"""COCO-FUNIT Generator constructor. + + Args: + gen_cfg (obj): Generator definition part of the yaml config file. + data_cfg (obj): Data definition part of the yaml config file. + """ + super().__init__() + self.generator = COCOFUNITTranslator(**vars(gen_cfg)) + + def forward(self, data): + r"""In the FUNIT's forward pass, it generates a content embedding and + a style code from the content image, and a style code from the style + image. By mixing the content code and the style code from the content + image, we reconstruct the input image. By mixing the content code and + the style code from the style image, we have a translation output. + + Args: + data (dict): Training data at the current iteration. + """ + content_a = self.generator.content_encoder(data['images_content']) + style_a = self.generator.style_encoder(data['images_content']) + style_b = self.generator.style_encoder(data['images_style']) + images_trans = self.generator.decode(content_a, style_b) + images_recon = self.generator.decode(content_a, style_a) + + net_G_output = dict(images_trans=images_trans, + images_recon=images_recon) + return net_G_output + + def inference(self, data, keep_original_size=True): + r"""COCO-FUNIT inference. + + Args: + data (dict): Training data at the current iteration. + - images_content (tensor): Content images. + - images_style (tensor): Style images. + a2b (bool): If ``True``, translates images from domain A to B, + otherwise from B to A. + keep_original_size (bool): If ``True``, output image is resized + to the input content image size. + """ + content_a = self.generator.content_encoder(data['images_content']) + style_b = self.generator.style_encoder(data['images_style']) + output_images = self.generator.decode(content_a, style_b) + if keep_original_size: + height = data['original_h_w'][0][0] + width = data['original_h_w'][0][1] + # print('( H, W) = ( %d, %d)' % (height, width)) + output_images = torch.nn.functional.interpolate( + output_images, size=[height, width]) + file_names = data['key']['images_content'][0] + return output_images, file_names + + +class COCOFUNITTranslator(nn.Module): + r"""COCO-FUNIT Generator architecture. + + Args: + num_filters (int): Base filter numbers. + num_filters_mlp (int): Base filter number in the MLP module. + style_dims (int): Dimension of the style code. + usb_dims (int): Dimension of the universal style bias code. + num_res_blocks (int): Number of residual blocks at the end of the + content encoder. 
+ num_mlp_blocks (int): Number of layers in the MLP module. + num_downsamples_content (int): Number of times we reduce + resolution by 2x2 for the content image. + num_downsamples_style (int): Number of times we reduce + resolution by 2x2 for the style image. + num_image_channels (int): Number of input image channels. + weight_norm_type (str): Type of weight normalization. + ``'none'``, ``'spectral'``, or ``'weight'``. + """ + + def __init__(self, + num_filters=64, + num_filters_mlp=256, + style_dims=64, + usb_dims=1024, + num_res_blocks=2, + num_mlp_blocks=3, + num_downsamples_style=4, + num_downsamples_content=2, + num_image_channels=3, + weight_norm_type='', + **kwargs): + super().__init__() + + self.style_encoder = StyleEncoder(num_downsamples_style, + num_image_channels, + num_filters, + style_dims, + 'reflect', + 'none', + weight_norm_type, + 'relu') + + self.content_encoder = ContentEncoder(num_downsamples_content, + num_res_blocks, + num_image_channels, + num_filters, + 'reflect', + 'instance', + weight_norm_type, + 'relu') + + self.decoder = Decoder(self.content_encoder.output_dim, + num_filters_mlp, + num_image_channels, + num_downsamples_content, + 'reflect', + weight_norm_type, + 'relu') + + self.usb = torch.nn.Parameter(torch.randn(1, usb_dims)) + + self.mlp = MLP(style_dims, + num_filters_mlp, + num_filters_mlp, + num_mlp_blocks, + 'none', + 'relu') + + num_content_mlp_blocks = 2 + num_style_mlp_blocks = 2 + self.mlp_content = MLP(self.content_encoder.output_dim, + style_dims, + num_filters_mlp, + num_content_mlp_blocks, + 'none', + 'relu') + + self.mlp_style = MLP(style_dims + usb_dims, + style_dims, + num_filters_mlp, + num_style_mlp_blocks, + 'none', + 'relu') + + def forward(self, images): + r"""Reconstruct the input image by combining the computer content and + style code. + + Args: + images (tensor): Input image tensor. + """ + # reconstruct an image + content, style = self.encode(images) + images_recon = self.decode(content, style) + return images_recon + + def encode(self, images): + r"""Encoder images to get their content and style codes. + + Args: + images (tensor): Input image tensor. + """ + style = self.style_encoder(images) + content = self.content_encoder(images) + return content, style + + def decode(self, content, style): + r"""Generate images by combining their content and style codes. + + Args: + content (tensor): Content code tensor. + style (tensor): Style code tensor. 
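+        Returns:
+            images (tensor): Images synthesized from the given content and style codes.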
+ """ + content_style_code = content.mean(3).mean(2) + content_style_code = self.mlp_content(content_style_code) + batch_size = style.size(0) + usb = self.usb.repeat(batch_size, 1) + style = style.view(batch_size, -1) + style_in = self.mlp_style(torch.cat([style, usb], 1)) + coco_style = style_in * content_style_code + coco_style = self.mlp(coco_style) + images = self.decoder(content, coco_style) + return images diff --git a/imaginaire/generators/craft_2stage.py b/imaginaire/generators/craft_2stage.py new file mode 100644 index 0000000000000000000000000000000000000000..ee9c8bd42feb3fae32a96625f885a6752e5254d1 --- /dev/null +++ b/imaginaire/generators/craft_2stage.py @@ -0,0 +1,65 @@ +import torch +from torch import nn +import functools +import torch.nn.functional as F + +import sys +sys.path.append(".") +from model import geometry_transform +from imaginaire.utils.distributed import master_only_print as print +from model.graphs.decoder import DeepLab +from imaginaire.generators.craft_base import * + + + +class Generator(nn.Module): + def __init__(self,opt): + super(Generator, self).__init__() + gen_cfg = opt.arch.gen + data_cfg = opt.data + # self.gen_model = gen_model + self.gen_cfg = opt.arch.gen + if gen_cfg.transform_mode in ['project_RGB','volum_rendering','proj_like_radus']: + self.pano_direction = torch.from_numpy(geometry_transform.get_original_coord(opt)).unsqueeze(0).to(opt.device) + if gen_cfg.transform_mode == 'volum_rendering': + last_act = 'relu' + else: + last_act = 'softmax' + self.depth_model = inner_Generator(gen_cfg,gen_cfg.depth_arch,data_cfg,num_input_channels=3,last_act=last_act) + render_input_channel = 3 + if gen_cfg.cat_opa: + render_input_channel = render_input_channel+1 + self.denoise_model = inner_Generator(gen_cfg,gen_cfg.render_arch,data_cfg,render_input_channel,last_act='sigmoid') + + self.PE = None + + + + def forward(self, inputs, style_img=None,opt=None): + estimated_height = self.depth_model(inputs) + + if self.gen_cfg.transform_mode in ['project_RGB','volum_rendering','proj_like_radus']: + geo_outputs = geometry_transform.render(opt,inputs,estimated_height,self.pano_direction,PE=self.PE) + generator_inputs,opacity,depth = geo_outputs['rgb'],geo_outputs['opacity'],geo_outputs['depth'] + if 'voxel' in geo_outputs.keys(): + voxel = geo_outputs['voxel'] + # mu, logvar, z = self.style_encode(style_img) + # z = self.style_model(z) + if self.gen_cfg.cat_opa: + generator_inputs = torch.cat((generator_inputs,opacity),dim=1) + output_RGB = self.denoise_model(generator_inputs) + out_put = { + 'pred': output_RGB, + # 'inter_RGB': generator_inputs, ### out_feature not for show + # 'mu' :mu, + # 'logvar' : logvar, + } + if self.gen_cfg.transform_mode in ['volum_rendering']: + out_put['opacity'] = opacity + if self.gen_cfg.transform_mode: + out_put['estimated_height'] = estimated_height + out_put['generator_inputs'] = generator_inputs + out_put['voxel'] = voxel + out_put['depth'] = depth + return out_put + diff --git a/imaginaire/generators/craft_2stage_add_style.py b/imaginaire/generators/craft_2stage_add_style.py new file mode 100644 index 0000000000000000000000000000000000000000..7920abdf110783ea2921e5f8faa98c3a7c284ec9 --- /dev/null +++ b/imaginaire/generators/craft_2stage_add_style.py @@ -0,0 +1,75 @@ +import torch +from torch import nn +import sys +sys.path.append(".") +from model import geometry_transform +from imaginaire.generators.craft_base import * + + + +class Generator(nn.Module): + def __init__(self,opt): + super(Generator, self).__init__() + gen_cfg = 
opt.arch.gen + data_cfg = opt.data + style_enc_cfg = opt.arch.gen.style_enc_cfg + # self.gen_model = gen_model + self.style_inject = getattr(gen_cfg, 'style_inject', + None) + self.gen_cfg = opt.arch.gen + self.pano_direction = torch.from_numpy(geometry_transform.get_original_coord(opt)).unsqueeze(0).to(opt.device) + last_act = 'relu' + self.depth_model = inner_Generator_split(gen_cfg,gen_cfg.depth_arch,data_cfg,num_input_channels=3,last_act=last_act) + + + render_input_channel = 3 + if gen_cfg.cat_opa: + render_input_channel +=1 + if gen_cfg.cat_depth: + render_input_channel +=1 + + self.denoise_model = inner_Generator_split(gen_cfg,gen_cfg.render_arch,data_cfg,render_input_channel,last_act='sigmoid') + if self.style_inject: + if self.style_inject=='histo': + self.style_encode = histo_process(style_enc_cfg) + elif self.style_inject=='perspective': + self.style_encode = StyleEncoder(style_enc_cfg) + else: + raise Exception('Unknown style inject') + self.style_model = StyleMLP(style_dim=style_enc_cfg.style_dims, out_dim=style_enc_cfg.interm_style_dims, hidden_channels=style_enc_cfg.hidden_channel, leaky_relu=True, num_layers=5, normalize_input=True, + output_act=True) + + self.PE = geometry_transform.position_produce(opt) if gen_cfg.cat_PE else None + + + + def forward(self, inputs, style_img=None,opt=None): + # predicted height of satellite images + estimated_height = self.depth_model(inputs) + geo_outputs = geometry_transform.render(opt,inputs,estimated_height,self.pano_direction,PE=self.PE) + generator_inputs,opacity,depth = geo_outputs['rgb'],geo_outputs['opacity'],geo_outputs['depth'] + if 'voxel' in geo_outputs.keys(): + voxel = geo_outputs['voxel'] + + if self.gen_cfg.cat_opa: + generator_inputs = torch.cat((generator_inputs,opacity),dim=1) + if self.gen_cfg.cat_depth: + generator_inputs = torch.cat((generator_inputs,depth),dim=1) + if self.style_inject: + mu, logvar, z = self.style_encode(style_img) + z = self.style_model(z) + else: + z = None + # merge multiple sources(rgb,opacity,depth and sky) and denoise redundancy + output_RGB = self.denoise_model(generator_inputs,z) + out_put = {'pred': output_RGB} + if self.style_inject: + out_put['mu'] = mu + out_put['logvar'] = logvar + out_put['estimated_height'] = estimated_height + out_put['generator_inputs'] = generator_inputs + out_put['voxel'] = voxel + out_put['depth'] = depth + out_put['opacity'] = opacity + return out_put + diff --git a/imaginaire/generators/craft_base.py b/imaginaire/generators/craft_base.py new file mode 100644 index 0000000000000000000000000000000000000000..f6f283c701269ad8eaa656036ade597d37393ebc --- /dev/null +++ b/imaginaire/generators/craft_base.py @@ -0,0 +1,483 @@ +import numpy as np +import torch +import torch.nn as nn +from torch.nn import Upsample as NearestUpsample +import torch.nn.functional as F +from functools import partial + +import sys +sys.path.append(".") +from imaginaire.layers import Conv2dBlock, LinearBlock, Res2dBlock + + +class StyleMLP(nn.Module): + r"""MLP converting style code to intermediate style representation.""" + + def __init__(self, style_dim, out_dim, hidden_channels=256, leaky_relu=True, num_layers=5, normalize_input=True, + output_act=True): + super(StyleMLP, self).__init__() + + self.normalize_input = normalize_input + self.output_act = output_act + fc_layers = [] + fc_layers.append(nn.Linear(style_dim, hidden_channels, bias=True)) + for i in range(num_layers-1): + fc_layers.append(nn.Linear(hidden_channels, hidden_channels, bias=True)) + self.fc_layers = 
nn.ModuleList(fc_layers) + + self.fc_out = nn.Linear(hidden_channels, out_dim, bias=True) + + if leaky_relu: + self.act = nn.LeakyReLU(negative_slope=0.2, inplace=True) + else: + self.act = partial(F.relu, inplace=True) + + def forward(self, z): + r""" Forward network + + Args: + z (N x style_dim tensor): Style codes. + """ + if self.normalize_input: + z = F.normalize(z, p=2, dim=-1,eps=1e-6) + for fc_layer in self.fc_layers: + z = self.act(fc_layer(z)) + z = self.fc_out(z) + if self.output_act: + z = self.act(z) + return z + +class histo_process(nn.Module): + r"""Histo process to replace Style Encoder constructor. + + Args: + style_enc_cfg (obj): Style encoder definition file. + """ + def __init__(self,style_enc_cfg): + super().__init__() + # if style_enc_cfg.histo.mode in ['RGB','rgb']: + input_channel=270 + # else: + # input_channel=90 + style_dims = style_enc_cfg.style_dims + self.no_vae = getattr(style_enc_cfg, 'no_vae', False) + num_filters = getattr(style_enc_cfg, 'num_filters', 180) + self.process_model = nn.ModuleList() + self.layer1 = LinearBlock(input_channel,num_filters) + self.layer4 = LinearBlock(num_filters, num_filters) + self.fc_mu = LinearBlock(num_filters, style_dims,nonlinearity='tanh') + if not self.no_vae: + self.fc_var = LinearBlock(num_filters, style_dims,nonlinearity='tanh') + + + def forward(self,histo): + x = self.layer1(histo) + x = self.layer4(x) + mu = self.fc_mu(x) #[-1,1] + if not self.no_vae: + logvar = self.fc_var(x) # [-1,1] + std = torch.exp(0.5 * logvar) # [0.607,1.624] + eps = torch.randn_like(std) + z = eps.mul(std) + mu + else: + z = mu + logvar = torch.zeros_like(mu) + return mu, logvar, z + + + +class StyleEncoder(nn.Module): + r"""Style Encoder constructor. + + Args: + style_enc_cfg (obj): Style encoder definition file. + """ + + def __init__(self, style_enc_cfg): + super(StyleEncoder, self).__init__() + input_image_channels = style_enc_cfg.input_image_channels + num_filters = style_enc_cfg.num_filters + kernel_size = style_enc_cfg.kernel_size + padding = int(np.ceil((kernel_size - 1.0) / 2)) + style_dims = style_enc_cfg.style_dims + weight_norm_type = style_enc_cfg.weight_norm_type + self.no_vae = getattr(style_enc_cfg, 'no_vae', False) + activation_norm_type = 'none' + nonlinearity = 'leakyrelu' + base_conv2d_block = \ + partial(Conv2dBlock, + kernel_size=kernel_size, + stride=2, + padding=padding, + weight_norm_type=weight_norm_type, + activation_norm_type=activation_norm_type, + # inplace_nonlinearity=True, + nonlinearity=nonlinearity) + self.layer1 = base_conv2d_block(input_image_channels, num_filters) + self.layer2 = base_conv2d_block(num_filters * 1, num_filters * 2) + self.layer3 = base_conv2d_block(num_filters * 2, num_filters * 4) + self.layer4 = base_conv2d_block(num_filters * 4, num_filters * 8) + self.layer5 = base_conv2d_block(num_filters * 8, num_filters * 8) + self.layer6 = base_conv2d_block(num_filters * 8, num_filters * 8) + self.fc_mu = LinearBlock(num_filters * 8 * 4 * 4, style_dims,nonlinearity='tanh') + if not self.no_vae: + self.fc_var = LinearBlock(num_filters * 8 * 4 * 4, style_dims,nonlinearity='tanh') + + def forward(self, input_x): + r"""SPADE Style Encoder forward. + + Args: + input_x (N x 3 x H x W tensor): input images. + Returns: + mu (N x C tensor): Mean vectors. + logvar (N x C tensor): Log-variance vectors. + z (N x C tensor): Style code vectors. 
+ """ + if input_x.size(2) != 256 or input_x.size(3) != 256: + input_x = F.interpolate(input_x, size=(256, 256), mode='bilinear') + x = self.layer1(input_x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + x = self.layer5(x) + x = self.layer6(x) + x = x.view(x.size(0), -1) + mu = self.fc_mu(x) + if not self.no_vae: + logvar = self.fc_var(x) + std = torch.exp(0.5 * logvar) + eps = torch.randn_like(std) + z = eps.mul(std) + mu + else: + z = mu + logvar = torch.zeros_like(mu) + return mu, logvar, z + + +class RenderCNN(nn.Module): + r"""CNN converting intermediate feature map to final image.""" + + def __init__(self, in_channels, style_dim, hidden_channels=256, + leaky_relu=True): + super(RenderCNN, self).__init__() + self.fc_z_cond = nn.Linear(style_dim, 2 * 2 * hidden_channels) + + self.conv1 = nn.Conv2d(in_channels, hidden_channels, 1, stride=1, padding=0) + self.conv2a = nn.Conv2d(hidden_channels, hidden_channels, 3, stride=1, padding=1) + self.conv2b = nn.Conv2d(hidden_channels, hidden_channels, 3, stride=1, padding=1, bias=False) + + self.conv3a = nn.Conv2d(hidden_channels, hidden_channels, 3, stride=1, padding=1) + self.conv3b = nn.Conv2d(hidden_channels, hidden_channels, 3, stride=1, padding=1, bias=False) + + self.conv4a = nn.Conv2d(hidden_channels, hidden_channels, 1, stride=1, padding=0) + self.conv4b = nn.Conv2d(hidden_channels, hidden_channels, 1, stride=1, padding=0) + + self.conv4 = nn.Conv2d(hidden_channels, 3, 1, stride=1, padding=0) + + if leaky_relu: + self.act = nn.LeakyReLU(negative_slope=0.2, inplace=True) + else: + self.act = partial(F.relu, inplace=True) + + def modulate(self, x, w_, b_): + w_ = w_[..., None, None] + b_ = b_[..., None, None] + return x * (w_+1) + b_ +1e-9 + + def forward(self, x, z): + r"""Forward network. + + Args: + x (N x in_channels x H x W tensor): Intermediate feature map + z (N x style_dim tensor): Style codes. + """ + z = self.fc_z_cond(z) + adapt = torch.chunk(z, 2 * 2, dim=-1) + y = self.act(self.conv1(x)) + + y = y + self.conv2b(self.act(self.conv2a(y))) + y = self.act(self.modulate(y, adapt[0], adapt[1])) + + y = y + self.conv3b(self.act(self.conv3a(y))) + y = self.act(self.modulate(y, adapt[2], adapt[3])) + + y = y + self.conv4b(self.act(self.conv4a(y))) + y = self.act(y) + + y = self.conv4(y) + y = torch.sigmoid(y) + return y + + +class inner_Generator(nn.Module): + r"""Pix2pixHD coarse-to-fine generator constructor. + + Args: + gen_cfg (obj): Generator definition part of the yaml config file. + data_cfg (obj): Data definition part of the yaml config file. + last_act: ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``, + ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``,default is 'relu'. + """ + + def __init__(self, gen_cfg,inner_cfg, data_cfg,num_input_channels=3,last_act='relu'): + super().__init__() + assert last_act in ['none', 'relu', 'leakyrelu', 'prelu', + 'tanh' , 'sigmoid' , 'softmax'] + # pix2pixHD has a global generator. + global_gen_cfg = inner_cfg + # By default, pix2pixHD using instance normalization. 
+ activation_norm_type = getattr(gen_cfg, 'activation_norm_type', + 'instance') + activation_norm_params = getattr(gen_cfg, 'activation_norm_params', + None) + weight_norm_type = getattr(gen_cfg, 'weight_norm_type', '') + padding_mode = getattr(gen_cfg, 'padding_mode', 'reflect') + base_conv_block = partial(Conv2dBlock, + padding_mode=padding_mode, + weight_norm_type=weight_norm_type, + activation_norm_type=activation_norm_type, + activation_norm_params=activation_norm_params, + nonlinearity='relu') + base_res_block = partial(Res2dBlock, + padding_mode=padding_mode, + weight_norm_type=weight_norm_type, + activation_norm_type=activation_norm_type, + activation_norm_params=activation_norm_params, + nonlinearity='relu', order='CNACN') + # Know what is the number of available segmentation labels. + + # Global generator model. + global_model = GlobalGenerator(global_gen_cfg, data_cfg, + num_input_channels, padding_mode, + base_conv_block, base_res_block,last_act=last_act) + self.global_model = global_model + + + def forward(self, input, random_style=False): + r"""Coarse-to-fine generator forward. + + Args: + data (dict) : Dictionary of input data. + random_style (bool): Always set to false for the pix2pixHD model. + Returns: + output (dict) : Dictionary of output data. + """ + return self.global_model(input) + + + + def load_pretrained_network(self, pretrained_dict): + r"""Load a pretrained network.""" + # print(pretrained_dict.keys()) + model_dict = self.state_dict() + print('Pretrained network has fewer layers; The following are ' + 'not initialized:') + + not_initialized = set() + for k, v in model_dict.items(): + kp = 'module.' + k.replace('global_model.', 'global_model.model.') + if kp in pretrained_dict and v.size() == pretrained_dict[kp].size(): + model_dict[k] = pretrained_dict[kp] + else: + not_initialized.add('.'.join(k.split('.')[:2])) + print(sorted(not_initialized)) + self.load_state_dict(model_dict) + + def inference(self, data, **kwargs): + r"""Generator inference. + + Args: + data (dict) : Dictionary of input data. + Returns: + fake_images (tensor): Output fake images. + file_names (str): Data file name. + """ + output = self.forward(data, **kwargs) + return output['fake_images'], data['key']['seg_maps'][0] + + +class GlobalGenerator(nn.Module): + r"""Coarse generator constructor. This is the main generator in the + pix2pixHD architecture. + + Args: + gen_cfg (obj): Generator definition part of the yaml config file. + data_cfg (obj): Data definition part of the yaml config file. + num_input_channels (int): Number of segmentation labels. + padding_mode (str): zero | reflect | ... + base_conv_block (obj): Conv block with preset attributes. + base_res_block (obj): Residual block with preset attributes. + last_act (str, optional, default='relu'): + Type of nonlinear activation function. + ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``, + ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``. + """ + + def __init__(self, gen_cfg, data_cfg, num_input_channels, padding_mode, + base_conv_block, base_res_block,last_act='relu'): + super(GlobalGenerator, self).__init__() + + # num_img_channels = get_paired_input_image_channel_number(data_cfg) + num_out_put_channels = getattr(gen_cfg, 'output_nc', 64) + num_filters = getattr(gen_cfg, 'num_filters', 64) + num_downsamples = getattr(gen_cfg, 'num_downsamples', 4) + num_res_blocks = getattr(gen_cfg, 'num_res_blocks', 9) + # First layer. + model = [base_conv_block(num_input_channels, num_filters, + kernel_size=7, padding=3)] + # Downsample. 
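+        # Each stage halves the spatial resolution and doubles the channels, e.g. 64 -> 128 -> 256 -> 512 -> 1024 with the default num_filters=64 and num_downsamples=4.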
+ for i in range(num_downsamples): + ch = num_filters * (2 ** i) + model += [base_conv_block(ch, ch * 2, 3, padding=1, stride=2)] + # ResNet blocks. + ch = num_filters * (2 ** num_downsamples) + for i in range(num_res_blocks): + model += [base_res_block(ch, ch, 3, padding=1)] + # Upsample. + num_upsamples = num_downsamples + for i in reversed(range(num_upsamples)): + ch = num_filters * (2 ** i) + model += \ + [NearestUpsample(scale_factor=2), + base_conv_block(ch * 2, ch, 3, padding=1)] + model += [Conv2dBlock(num_filters, num_out_put_channels, 7, padding=3, + padding_mode=padding_mode, nonlinearity=last_act)] + + self.model = nn.Sequential(*model) + + def forward(self, input): + r"""Coarse-to-fine generator forward. + + Args: + input (4D tensor) : Input semantic representations. + Returns: + output (4D tensor) : Synthesized image by generator. + """ + return self.model(input) + +class inner_Generator_split(nn.Module): + r"""Pix2pixHD coarse-to-fine generator constructor. + + Args: + gen_cfg (obj): Generator definition part of the yaml config file. + data_cfg (obj): Data definition part of the yaml config file. + last_act: ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``, + ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``,default is 'relu'. + """ + + def __init__(self, gen_cfg,inner_cfg, data_cfg,num_input_channels=3,last_act='relu'): + super().__init__() + assert last_act in ['none', 'relu', 'leakyrelu', 'prelu', + 'tanh' , 'sigmoid' , 'softmax'] + # pix2pixHD has a global generator. + # By default, pix2pixHD using instance normalization. + print(inner_cfg) + style_dim = gen_cfg.style_enc_cfg.interm_style_dims + activation_norm_type = getattr(gen_cfg, 'activation_norm_type', + 'instance') + activation_norm_params = getattr(gen_cfg, 'activation_norm_params', + None) + weight_norm_type = getattr(gen_cfg, 'weight_norm_type', '') + padding_mode = getattr(gen_cfg, 'padding_mode', 'reflect') + # num_input_channels = get_paired_input_label_channel_number(data_cfg) + # num_input_channels = 3 + base_conv_block = partial(Conv2dBlock, + padding_mode=padding_mode, + weight_norm_type=weight_norm_type, + activation_norm_type=activation_norm_type, + activation_norm_params=activation_norm_params, + ) + base_res_block = partial(Res2dBlock, + padding_mode=padding_mode, + weight_norm_type=weight_norm_type, + activation_norm_type=activation_norm_type, + activation_norm_params=activation_norm_params, + nonlinearity='relu', order='CNACN') + # Know what is the number of available segmentation labels. + + # Global generator model. + + num_out_put_channels = getattr(inner_cfg, 'output_nc', 64) + num_filters = getattr(inner_cfg, 'num_filters', 64) + num_downsamples = 4 + num_res_blocks = getattr(inner_cfg, 'num_res_blocks', 9) + # First layer. + model = [base_conv_block(num_input_channels, num_filters, + kernel_size=7, padding=3)] + model += [nn.PReLU()] + # Downsample. + for i in range(num_downsamples): + ch = num_filters * (2 ** i) + model += [base_conv_block(ch, ch * 2, 3, padding=1, stride=2)] + model += [nn.PReLU()] + # ResNet blocks. + ch = num_filters * (2 ** num_downsamples) + for i in range(num_res_blocks): + model += [base_res_block(ch, ch, 3, padding=1)] + self.model = nn.Sequential(*model) + # Upsample. 
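+        # Channel widths for the four upsampling stages: plain halving (16x -> 8x -> 4x -> 2x -> 1x num_filters), except in the render branch with style injection, which keeps 6x num_filters so the chunks of the style modulation vector match the features modulated after up1 and up2.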
+ assert num_downsamples == 4 + if not (inner_cfg.name =='render' and gen_cfg.style_inject): + list = [16,8,4,2] + else: + list = [16,6,6,6] + + self.up0_a = NearestUpsample(scale_factor=2) + self.up0_b = base_conv_block(num_filters * list[0], num_filters*list[1], 3, padding=1) + self.up1_a = NearestUpsample(scale_factor=2) + self.up1_b = base_conv_block(num_filters * list[1], num_filters*list[2], 3, padding=1) + self.up2_a = NearestUpsample(scale_factor=2) + self.up2_b = base_conv_block(num_filters * list[2], num_filters*list[3], 3, padding=1) + self.up3_a = NearestUpsample(scale_factor=2) + self.up3_b = base_conv_block(num_filters * list[3], num_filters, 3, padding=1) + self.up_end = Conv2dBlock(num_filters, num_out_put_channels, 7, padding=3, + padding_mode=padding_mode, nonlinearity=last_act) + if inner_cfg.name =='render' and gen_cfg.style_inject: + self.fc_z_cond = nn.Linear(style_dim, 4* list[-1] * num_filters) + + def modulate(self, x, w, b): + w = w[..., None, None] + b = b[..., None, None] + return x * (w+1) + b + + def forward(self, input,z=None): + r"""Coarse-to-fine generator forward. + + Args: + input (4D tensor) : Input semantic representations. + Returns: + output (4D tensor) : Synthesized image by generator. + """ + if z is not None: + z = self.fc_z_cond(z) + adapt = torch.chunk(z, 2 * 2, dim=-1) + input = self.model(input) + input = self.up0_a(input) + input = self.up0_b(input) + input = F.leaky_relu(input,negative_slope=0.2, inplace=True) + input = self.up1_a(input) + input = self.up1_b(input) + if z is not None: + input = self.modulate(input, adapt[0], adapt[1]) + input = F.leaky_relu(input,negative_slope=0.2, inplace=True) + + input = self.up2_a(input) + input = self.up2_b(input) + if z is not None: + input = self.modulate(input, adapt[2], adapt[3]) + input = F.leaky_relu(input,negative_slope=0.2, inplace=True) + + input = self.up3_a(input) + input = self.up3_b(input) + input = F.leaky_relu(input,negative_slope=0.2, inplace=True) + + input = self.up_end(input) + + return input + +if __name__=='__main__': + from easydict import EasyDict as edict + style_enc_cfg = edict() + style_enc_cfg.histo.mode = 'RGB' + style_enc_cfg.histo.num_filters = 180 + model = histo_process(style_enc_cfg) \ No newline at end of file diff --git a/imaginaire/generators/dummy.py b/imaginaire/generators/dummy.py new file mode 100644 index 0000000000000000000000000000000000000000..5b9a2f1edec286be8751d6a188ebc4f47875c437 --- /dev/null +++ b/imaginaire/generators/dummy.py @@ -0,0 +1,29 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +import torch.nn as nn + +from imaginaire.layers import LinearBlock + + +class Generator(nn.Module): + r"""Dummy generator. + + Args: + gen_cfg (obj): Generator definition part of the yaml config file. + data_cfg (obj): Data definition part of the yaml config file. + """ + + def __init__(self, gen_cfg, data_cfg): + super(Generator, self).__init__() + self.dummy_layer = LinearBlock(1, 1) + pass + + def forward(self, data): + r"""Dummy Generator forward. 
+ + Args: + data (dict): + """ + return diff --git a/imaginaire/generators/fs_vid2vid.py b/imaginaire/generators/fs_vid2vid.py new file mode 100644 index 0000000000000000000000000000000000000000..93c2c20048d0e371d9393302cc858fe738a4d53b --- /dev/null +++ b/imaginaire/generators/fs_vid2vid.py @@ -0,0 +1,1176 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +import copy +from functools import partial + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from imaginaire.layers import (Conv2dBlock, HyperConv2dBlock, HyperRes2dBlock, + LinearBlock, Res2dBlock) +from imaginaire.model_utils.fs_vid2vid import (extract_valid_pose_labels, + pick_image, resample) +from imaginaire.utils.data import (get_paired_input_image_channel_number, + get_paired_input_label_channel_number) +from imaginaire.utils.distributed import master_only_print as print +from imaginaire.utils.init_weight import weights_init +from imaginaire.utils.misc import get_and_setattr, get_nested_attr + + +class Generator(nn.Module): + r"""Few-shot vid2vid generator constructor. + + Args: + gen_cfg (obj): Generator definition part of the yaml config file. + data_cfg (obj): Data definition part of the yaml config file. + """ + + def __init__(self, gen_cfg, data_cfg): + super().__init__() + self.gen_cfg = gen_cfg + self.data_cfg = data_cfg + self.num_frames_G = data_cfg.num_frames_G + self.flow_cfg = flow_cfg = gen_cfg.flow + + # For pose dataset. + self.is_pose_data = hasattr(data_cfg, 'for_pose_dataset') + if self.is_pose_data: + pose_cfg = data_cfg.for_pose_dataset + self.pose_type = getattr(pose_cfg, 'pose_type', 'both') + self.remove_face_labels = getattr(pose_cfg, 'remove_face_labels', + False) + + num_img_channels = get_paired_input_image_channel_number(data_cfg) + self.num_downsamples = num_downsamples = \ + get_and_setattr(gen_cfg, 'num_downsamples', 5) + conv_kernel_size = get_and_setattr(gen_cfg, 'kernel_size', 3) + num_filters = get_and_setattr(gen_cfg, 'num_filters', 32) + + max_num_filters = getattr(gen_cfg, 'max_num_filters', 1024) + self.max_num_filters = gen_cfg.max_num_filters = \ + min(max_num_filters, num_filters * (2 ** num_downsamples)) + # Get number of filters at each layer in the main branch. + num_filters_each_layer = [min(self.max_num_filters, + num_filters * (2 ** i)) + for i in range(num_downsamples + 2)] + + # Hyper normalization / convolution. + hyper_cfg = gen_cfg.hyper + # Use adaptive weight generation for SPADE. + self.use_hyper_spade = hyper_cfg.is_hyper_spade + # Use adaptive for convolutional layers in the main branch. + self.use_hyper_conv = hyper_cfg.is_hyper_conv + # Number of hyper layers. + self.num_hyper_layers = getattr(hyper_cfg, 'num_hyper_layers', 4) + if self.num_hyper_layers == -1: + self.num_hyper_layers = num_downsamples + gen_cfg.hyper.num_hyper_layers = self.num_hyper_layers + # Network weight generator. + self.weight_generator = WeightGenerator(gen_cfg, data_cfg) + + # Number of layers to perform multi-spade combine. + self.num_multi_spade_layers = getattr(flow_cfg.multi_spade_combine, + 'num_layers', 3) + # Whether to generate raw output for additional losses. + self.generate_raw_output = getattr(flow_cfg, 'generate_raw_output', + False) + + # Main branch image generation. 
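+        # Build the coarse-to-fine decoder: one HyperRes2dBlock per scale. The finest num_multi_spade_layers scales take three SPADE conditioning inputs (label embedding plus the warped reference and previous frames); the coarser ones use the label embedding only.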
+ padding = conv_kernel_size // 2 + activation_norm_type = get_and_setattr(gen_cfg, 'activation_norm_type', + 'sync_batch') + weight_norm_type = get_and_setattr(gen_cfg, 'weight_norm_type', + 'spectral') + activation_norm_params = get_and_setattr(gen_cfg, + 'activation_norm_params', + None) + spade_in_channels = [] # Input channel size in SPADE module. + for i in range(num_downsamples + 1): + spade_in_channels += [[num_filters_each_layer[i]]] \ + if i >= self.num_multi_spade_layers \ + else [[num_filters_each_layer[i]] * 3] + + order = getattr(gen_cfg.hyper, 'hyper_block_order', 'NAC') + for i in reversed(range(num_downsamples + 1)): + activation_norm_params.cond_dims = spade_in_channels[i] + is_hyper_conv = self.use_hyper_conv and i < self.num_hyper_layers + is_hyper_norm = self.use_hyper_spade and i < self.num_hyper_layers + setattr(self, 'up_%d' % i, HyperRes2dBlock( + num_filters_each_layer[i + 1], num_filters_each_layer[i], + conv_kernel_size, padding=padding, + weight_norm_type=weight_norm_type, + activation_norm_type=activation_norm_type, + activation_norm_params=activation_norm_params, + order=order * 2, + is_hyper_conv=is_hyper_conv, is_hyper_norm=is_hyper_norm)) + + self.conv_img = Conv2dBlock(num_filters, num_img_channels, + conv_kernel_size, padding=padding, + nonlinearity='leakyrelu', order='AC') + self.upsample = partial(F.interpolate, scale_factor=2) + + # Flow estimation module. + # Whether to warp reference image and combine with the synthesized. + self.warp_ref = getattr(flow_cfg, 'warp_ref', True) + if self.warp_ref: + self.flow_network_ref = FlowGenerator(flow_cfg, data_cfg, 2) + self.ref_image_embedding = \ + LabelEmbedder(flow_cfg.multi_spade_combine.embed, + num_img_channels + 1) + # At beginning of training, only train an image generator. + self.temporal_initialized = False + if getattr(gen_cfg, 'init_temporal', True): + self.init_temporal_network() + + def forward(self, data): + r"""few-shot vid2vid generator forward. + + Args: + data (dict) : Dictionary of input data. + Returns: + output (dict) : Dictionary of output data. + """ + label = data['label'] + ref_labels, ref_images = data['ref_labels'], data['ref_images'] + prev_labels, prev_images = data['prev_labels'], data['prev_images'] + is_first_frame = prev_labels is None + + if self.is_pose_data: + label, prev_labels = extract_valid_pose_labels( + [label, prev_labels], self.pose_type, self.remove_face_labels) + ref_labels = extract_valid_pose_labels( + ref_labels, self.pose_type, self.remove_face_labels, + do_remove=False) + + # Weight generation. + x, encoded_label, conv_weights, norm_weights, atn, atn_vis, ref_idx = \ + self.weight_generator(ref_images, ref_labels, label, is_first_frame) + + # Flow estimation. + flow, flow_mask, img_warp, cond_inputs = \ + self.flow_generation(label, ref_labels, ref_images, + prev_labels, prev_images, ref_idx) + + for i in range(len(encoded_label)): + encoded_label[i] = [encoded_label[i]] + if self.generate_raw_output: + encoded_label_raw = [encoded_label[i] for i in + range(self.num_multi_spade_layers)] + x_raw = None + encoded_label = self.SPADE_combine(encoded_label, cond_inputs) + + # Main branch image generation. + for i in range(self.num_downsamples, -1, -1): + conv_weight = norm_weight = [None] * 3 + if self.use_hyper_conv and i < self.num_hyper_layers: + conv_weight = conv_weights[i] + if self.use_hyper_spade and i < self.num_hyper_layers: + norm_weight = norm_weights[i] + + # Main branch residual blocks. 
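+            # Scales at or beyond num_hyper_layers keep conv_weight / norm_weight as None and fall back to the blocks' own parameters.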
+ x = self.one_up_conv_layer(x, encoded_label, + conv_weight, norm_weight, i) + + # For raw output generation. + if self.generate_raw_output and i < self.num_multi_spade_layers: + x_raw = self.one_up_conv_layer(x_raw, encoded_label_raw, + conv_weight, norm_weight, i) + else: + x_raw = x + + # Final conv layer. + if self.generate_raw_output: + img_raw = torch.tanh(self.conv_img(x_raw)) + else: + img_raw = None + img_final = torch.tanh(self.conv_img(x)) + + output = dict() + output['fake_images'] = img_final + output['fake_flow_maps'] = flow + output['fake_occlusion_masks'] = flow_mask + output['fake_raw_images'] = img_raw + output['warped_images'] = img_warp + output['attention_visualization'] = atn_vis + output['ref_idx'] = ref_idx + return output + + def one_up_conv_layer(self, x, encoded_label, conv_weight, norm_weight, i): + r"""One residual block layer in the main branch. + + Args: + x (4D tensor) : Current feature map. + encoded_label (list of tensors) : Encoded input label maps. + conv_weight (list of tensors) : Hyper conv weights. + norm_weight (list of tensors) : Hyper norm weights. + i (int) : Layer index. + Returns: + x (4D tensor) : Output feature map. + """ + layer = getattr(self, 'up_' + str(i)) + x = layer(x, *encoded_label[i], conv_weights=conv_weight, + norm_weights=norm_weight) + if i != 0: + x = self.upsample(x) + return x + + def init_temporal_network(self, cfg_init=None): + r"""When starting training multiple frames, initialize the flow network. + + Args: + cfg_init (dict) : Weight initialization config. + """ + flow_cfg = self.flow_cfg + emb_cfg = flow_cfg.multi_spade_combine.embed + num_frames_G = self.num_frames_G + self.temporal_initialized = True + + self.sep_prev_flownet = flow_cfg.sep_prev_flow or (num_frames_G != 2) \ + or not flow_cfg.warp_ref + if self.sep_prev_flownet: + self.flow_network_temp = FlowGenerator(flow_cfg, self.data_cfg, + num_frames_G) + if cfg_init is not None: + self.flow_network_temp.apply(weights_init(cfg_init.type, + cfg_init.gain)) + else: + self.flow_network_temp = self.flow_network_ref + + self.sep_prev_embedding = emb_cfg.sep_warp_embed or \ + not flow_cfg.warp_ref + if self.sep_prev_embedding: + num_img_channels = get_paired_input_image_channel_number( + self.data_cfg) + self.prev_image_embedding = \ + LabelEmbedder(emb_cfg, num_img_channels + 1) + if cfg_init is not None: + self.prev_image_embedding.apply( + weights_init(cfg_init.type, cfg_init.gain)) + else: + self.prev_image_embedding = self.ref_image_embedding + + if self.warp_ref: + if self.sep_prev_flownet: + self.init_network_weights(self.flow_network_ref, + self.flow_network_temp) + print('Initialized temporal flow network with the reference ' + 'one.') + if self.sep_prev_embedding: + self.init_network_weights(self.ref_image_embedding, + self.prev_image_embedding) + print('Initialized temporal embedding network with the ' + 'reference one.') + self.flow_temp_is_initalized = True + + def init_network_weights(self, net_src, net_dst): + r"""Initialize weights in net_dst with those in net_src.""" + source_weights = net_src.state_dict() + target_weights = net_dst.state_dict() + + for k, v in source_weights.items(): + if k in target_weights and target_weights[k].size() == v.size(): + target_weights[k] = v + net_dst.load_state_dict(target_weights) + + def load_pretrained_network(self, pretrained_dict, prefix='module.'): + r"""Load the pretrained network into self network. + + Args: + pretrained_dict (dict): Pretrained network weights. 
+ prefix (str): Prefix to the network weights name. + """ + # print(pretrained_dict.keys()) + model_dict = self.state_dict() + print('Pretrained network has fewer layers; The following are ' + 'not initialized:') + + not_initialized = set() + for k, v in model_dict.items(): + kp = prefix + k + if kp in pretrained_dict and v.size() == pretrained_dict[kp].size(): + model_dict[k] = pretrained_dict[kp] + else: + not_initialized.add('.'.join(k.split('.')[:2])) + print(sorted(not_initialized)) + self.load_state_dict(model_dict) + + def reset(self): + r"""Reset the network at the beginning of a sequence.""" + self.weight_generator.reset() + + def flow_generation(self, label, ref_labels, ref_images, prev_labels, + prev_images, ref_idx): + r"""Generates flows and masks for warping reference / previous images. + + Args: + label (NxCxHxW tensor): Target label map. + ref_labels (NxKxCxHxW tensor): Reference label maps. + ref_images (NxKx3xHxW tensor): Reference images. + prev_labels (NxTxCxHxW tensor): Previous label maps. + prev_images (NxTx3xHxW tensor): Previous images. + ref_idx (Nx1 tensor): Index for which image to use from the + reference images. + Returns: + (tuple): + - flow (list of Nx2xHxW tensor): Optical flows. + - occ_mask (list of Nx1xHxW tensor): Occlusion masks. + - img_warp (list of Nx3xHxW tensor): Warped reference / previous + images. + - cond_inputs (list of Nx4xHxW tensor): Conditional inputs for + SPADE combination. + """ + # Pick an image in the reference images using ref_idx. + ref_label, ref_image = pick_image([ref_labels, ref_images], ref_idx) + # Only start using prev frames when enough prev frames are generated. + has_prev = prev_labels is not None and \ + prev_labels.shape[1] == (self.num_frames_G - 1) + flow, occ_mask, img_warp, cond_inputs = [None] * 2, [None] * 2, \ + [None] * 2, [None] * 2 + if self.warp_ref: + # Generate flows/masks for warping the reference image. + flow_ref, occ_mask_ref = \ + self.flow_network_ref(label, ref_label, ref_image) + ref_image_warp = resample(ref_image, flow_ref) + flow[0], occ_mask[0], img_warp[0] = \ + flow_ref, occ_mask_ref, ref_image_warp[:, :3] + # Concat warped image and occlusion mask to form the conditional + # input. + cond_inputs[0] = torch.cat([img_warp[0], occ_mask[0]], dim=1) + + if self.temporal_initialized and has_prev: + # Generate flows/masks for warping the previous image. + b, t, c, h, w = prev_labels.shape + prev_labels_concat = prev_labels.view(b, -1, h, w) + prev_images_concat = prev_images.view(b, -1, h, w) + flow_prev, occ_mask_prev = \ + self.flow_network_temp(label, prev_labels_concat, + prev_images_concat) + img_prev_warp = resample(prev_images[:, -1], flow_prev) + flow[1], occ_mask[1], img_warp[1] = \ + flow_prev, occ_mask_prev, img_prev_warp + cond_inputs[1] = torch.cat([img_warp[1], occ_mask[1]], dim=1) + + return flow, occ_mask, img_warp, cond_inputs + + def SPADE_combine(self, encoded_label, cond_inputs): + r"""Using Multi-SPADE to combine raw synthesized image with warped + images. + + Args: + encoded_label (list of tensors): Original label map embeddings. + cond_inputs (list of tensors): New SPADE conditional inputs from the + warped images. + Returns: + encoded_label (list of tensors): Combined conditional inputs. + """ + # Generate the conditional embeddings from inputs. 
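+        # cond_inputs[0] is the warped reference image concatenated with its occlusion mask, cond_inputs[1] the warped previous frame with its mask; either may be None.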
+ embedded_img_feat = [None, None] + if cond_inputs[0] is not None: + embedded_img_feat[0] = self.ref_image_embedding(cond_inputs[0]) + if cond_inputs[1] is not None: + embedded_img_feat[1] = self.prev_image_embedding(cond_inputs[1]) + + # Combine the original encoded label maps with new conditional + # embeddings. + for i in range(self.num_multi_spade_layers): + encoded_label[i] += [w[i] if w is not None else None + for w in embedded_img_feat] + return encoded_label + + def custom_init(self): + r"""This function is for dealing with the numerical issue that might + occur when doing mixed precision training. + """ + print('Use custom initialization for the generator.') + for k, m in self.named_modules(): + if 'weight_generator.ref_label_' in k and 'norm' in k: + m.eps = 1e-1 + + +class WeightGenerator(nn.Module): + r"""Weight generator constructor. + + Args: + gen_cfg (obj): Generator definition part of the yaml config file. + data_cfg (obj): Data definition part of the yaml config file + """ + + def __init__(self, gen_cfg, data_cfg): + super().__init__() + self.data_cfg = data_cfg + self.embed_cfg = embed_cfg = gen_cfg.embed + self.embed_arch = embed_cfg.arch + + num_filters = gen_cfg.num_filters + self.max_num_filters = gen_cfg.max_num_filters + self.num_downsamples = num_downsamples = gen_cfg.num_downsamples + self.num_filters_each_layer = num_filters_each_layer = \ + [min(self.max_num_filters, num_filters * (2 ** i)) + for i in range(num_downsamples + 2)] + if getattr(embed_cfg, 'num_filters', 32) != num_filters: + raise ValueError('Embedding network must have the same number of ' + 'filters as generator.') + + # Normalization params. + hyper_cfg = gen_cfg.hyper + kernel_size = getattr(hyper_cfg, 'kernel_size', 3) + activation_norm_type = getattr(hyper_cfg, 'activation_norm_type', + 'sync_batch') + weight_norm_type = getattr(hyper_cfg, 'weight_norm_type', 'spectral') + # Conv kernel size in main branch. + self.conv_kernel_size = conv_kernel_size = gen_cfg.kernel_size + # Conv kernel size in embedding network. + self.embed_kernel_size = embed_kernel_size = \ + getattr(gen_cfg.embed, 'kernel_size', 3) + # Conv kernel size in SPADE. + self.kernel_size = kernel_size = \ + getattr(gen_cfg.activation_norm_params, 'kernel_size', 1) + # Input channel size in SPADE module. + self.spade_in_channels = [] + for i in range(num_downsamples + 1): + self.spade_in_channels += [num_filters_each_layer[i]] + + # Hyper normalization / convolution. + # Use adaptive weight generation for SPADE. + self.use_hyper_spade = hyper_cfg.is_hyper_spade + # Use adaptive for the label embedding network. + self.use_hyper_embed = hyper_cfg.is_hyper_embed + # Use adaptive for convolutional layers in the main branch. + self.use_hyper_conv = hyper_cfg.is_hyper_conv + # Number of hyper layers. + self.num_hyper_layers = hyper_cfg.num_hyper_layers + # Order of operations in the conv block. + order = getattr(gen_cfg.hyper, 'hyper_block_order', 'NAC') + self.conv_before_norm = order.find('C') < order.find('N') + + # For reference image encoding. + # How to utilize the reference label map: concat | mul. + self.concat_ref_label = 'concat' in hyper_cfg.method_to_use_ref_labels + self.mul_ref_label = 'mul' in hyper_cfg.method_to_use_ref_labels + # Output spatial size for adaptive pooling layer. + self.sh_fix = self.sw_fix = 32 + # Number of fc layers in weight generation. + self.num_fc_layers = getattr(hyper_cfg, 'num_fc_layers', 2) + + # Reference image encoding network. 
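+        # Encoder-decoder over the reference image (optionally combined with its label map, either concatenated or multiplied in feature space); its per-scale features drive the fc weight generators below.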
+ num_input_channels = get_paired_input_label_channel_number(data_cfg) + if num_input_channels == 0: + num_input_channels = getattr(data_cfg, 'label_channels', 1) + elif get_nested_attr(data_cfg, 'for_pose_dataset.pose_type', + 'both') == 'open': + num_input_channels -= 3 + data_cfg.num_input_channels = num_input_channels + num_img_channels = get_paired_input_image_channel_number(data_cfg) + num_ref_channels = num_img_channels + (num_input_channels + if self.concat_ref_label else 0) + conv_2d_block = partial( + Conv2dBlock, kernel_size=kernel_size, + padding=(kernel_size // 2), weight_norm_type=weight_norm_type, + activation_norm_type=activation_norm_type, + nonlinearity='leakyrelu') + + self.ref_img_first = conv_2d_block(num_ref_channels, num_filters) + if self.mul_ref_label: + self.ref_label_first = conv_2d_block(num_input_channels, + num_filters) + + for i in range(num_downsamples): + in_ch, out_ch = num_filters_each_layer[i], \ + num_filters_each_layer[i + 1] + setattr(self, 'ref_img_down_%d' % i, + conv_2d_block(in_ch, out_ch, stride=2)) + setattr(self, 'ref_img_up_%d' % i, conv_2d_block(out_ch, in_ch)) + if self.mul_ref_label: + setattr(self, 'ref_label_down_%d' % i, + conv_2d_block(in_ch, out_ch, stride=2)) + setattr(self, 'ref_label_up_%d' % i, + conv_2d_block(out_ch, in_ch)) + + # Normalization / main branch conv weight generation. + if self.use_hyper_spade or self.use_hyper_conv: + for i in range(self.num_hyper_layers): + ch_in, ch_out = num_filters_each_layer[i], \ + num_filters_each_layer[i + 1] + conv_ks2 = conv_kernel_size ** 2 + embed_ks2 = embed_kernel_size ** 2 + spade_ks2 = kernel_size ** 2 + spade_in_ch = self.spade_in_channels[i] + + fc_names, fc_ins, fc_outs = [], [], [] + if self.use_hyper_spade: + fc0_out = fcs_out = (spade_in_ch * spade_ks2 + 1) * ( + 1 if self.conv_before_norm else 2) + fc1_out = (spade_in_ch * spade_ks2 + 1) * ( + 1 if ch_in != ch_out else 2) + fc_names += ['fc_spade_0', 'fc_spade_1', 'fc_spade_s'] + fc_ins += [ch_out] * 3 + fc_outs += [fc0_out, fc1_out, fcs_out] + if self.use_hyper_embed: + fc_names += ['fc_spade_e'] + fc_ins += [ch_out] + fc_outs += [ch_in * embed_ks2 + 1] + if self.use_hyper_conv: + fc0_out = ch_out * conv_ks2 + 1 + fc1_out = ch_in * conv_ks2 + 1 + fcs_out = ch_out + 1 + fc_names += ['fc_conv_0', 'fc_conv_1', 'fc_conv_s'] + fc_ins += [ch_in] * 3 + fc_outs += [fc0_out, fc1_out, fcs_out] + + linear_block = partial(LinearBlock, + weight_norm_type='spectral', + nonlinearity='leakyrelu') + for n, l in enumerate(fc_names): + fc_in = fc_ins[n] if self.mul_ref_label \ + else self.sh_fix * self.sw_fix + fc_layer = [linear_block(fc_in, ch_out)] + for k in range(1, self.num_fc_layers): + fc_layer += [linear_block(ch_out, ch_out)] + fc_layer += [LinearBlock(ch_out, fc_outs[n], + weight_norm_type='spectral')] + setattr(self, '%s_%d' % (l, i), nn.Sequential(*fc_layer)) + + # Label embedding network. + num_hyper_layers = self.num_hyper_layers if self.use_hyper_embed else 0 + self.label_embedding = LabelEmbedder(self.embed_cfg, + num_input_channels, + num_hyper_layers=num_hyper_layers) + + # For multiple reference images. 
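+        # When more than one reference image is given (initial_few_shot_K > 1), an attention module selects and blends the per-reference features.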
+ if hasattr(hyper_cfg, 'attention'): + self.num_downsample_atn = get_and_setattr(hyper_cfg.attention, + 'num_downsamples', 2) + if data_cfg.initial_few_shot_K > 1: + self.attention_module = AttentionModule(hyper_cfg, data_cfg, + conv_2d_block, + num_filters_each_layer) + else: + self.num_downsample_atn = 0 + + def forward(self, ref_image, ref_label, label, is_first_frame): + r"""Generate network weights based on the reference images. + + Args: + ref_image (NxKx3xHxW tensor): Reference images. + ref_label (NxKxCxHxW tensor): Reference labels. + label (NxCxHxW tensor): Target label. + is_first_frame (bool): Whether the current frame is the first frame. + + Returns: + (tuple): + - x (NxC2xH2xW2 tensor): Encoded features from reference images + for the main branch (as input to the decoder). + - encoded_label (list of tensors): Encoded target label map for + SPADE. + - conv_weights (list of tensors): Network weights for conv + layers in the main network. + - norm_weights (list of tensors): Network weights for SPADE + layers in the main network. + - attention (Nx(KxH1xW1)x(H1xW1) tensor): Attention maps. + - atn_vis (1x1xH1xW1 tensor): Visualization for attention + scores. + - ref_idx (Nx1 tensor): Index for which image to use from the + reference images. + """ + b, k, c, h, w = ref_image.size() + ref_image = ref_image.view(b * k, -1, h, w) + if ref_label is not None: + ref_label = ref_label.view(b * k, -1, h, w) + + # Encode the reference images to get the features. + x, encoded_ref, atn, atn_vis, ref_idx = \ + self.encode_reference(ref_image, ref_label, label, k) + + # If the reference image has changed, recompute the network weights. + if self.training or is_first_frame or k > 1: + embedding_weights, norm_weights, conv_weights = [], [], [] + for i in range(self.num_hyper_layers): + if self.use_hyper_spade: + feat = encoded_ref[min(len(encoded_ref) - 1, i + 1)] + embedding_weight, norm_weight = \ + self.get_norm_weights(feat, i) + embedding_weights.append(embedding_weight) + norm_weights.append(norm_weight) + if self.use_hyper_conv: + feat = encoded_ref[min(len(encoded_ref) - 1, i)] + conv_weights.append(self.get_conv_weights(feat, i)) + + if not self.training: + self.embedding_weights, self.conv_weights, self.norm_weights \ + = embedding_weights, conv_weights, norm_weights + else: + # print('Reusing network weights.') + embedding_weights, conv_weights, norm_weights \ + = self.embedding_weights, self.conv_weights, self.norm_weights + + # Encode the target label to get the encoded features. + encoded_label = self.label_embedding(label, weights=( + embedding_weights if self.use_hyper_embed else None)) + + return x, encoded_label, conv_weights, norm_weights, \ + atn, atn_vis, ref_idx + + def encode_reference(self, ref_image, ref_label, label, k): + r"""Encode the reference image to get features for weight generation. + + Args: + ref_image ((NxK)x3xHxW tensor): Reference images. + ref_label ((NxK)xCxHxW tensor): Reference labels. + label (NxCxHxW tensor): Target label. + k (int): Number of reference images. + Returns: + (tuple): + - x (NxC2xH2xW2 tensor): Encoded features from reference images + for the main branch (as input to the decoder). + - encoded_ref (list of tensors): Encoded features from reference + images for the weight generation branch. + - attention (Nx(KxH1xW1)x(H1xW1) tensor): Attention maps. + - atn_vis (1x1xH1xW1 tensor): Visualization for attention scores. + - ref_idx (Nx1 tensor): Index for which image to use from the + reference images. 
+ """ + if self.concat_ref_label: + # Concat reference label map and image together for encoding. + concat_ref = torch.cat([ref_image, ref_label], dim=1) + x = self.ref_img_first(concat_ref) + elif self.mul_ref_label: + # Apply conv to both reference label and image, then multiply them + # together for encoding. + x = self.ref_img_first(ref_image) + x_label = self.ref_label_first(ref_label) + else: + x = self.ref_img_first(ref_image) + + # Attention map and the index of the most similar reference image. + atn = atn_vis = ref_idx = None + for i in range(self.num_downsamples): + x = getattr(self, 'ref_img_down_' + str(i))(x) + if self.mul_ref_label: + x_label = getattr(self, 'ref_label_down_' + str(i))(x_label) + + # Combine different reference images at a particular layer. + if k > 1 and i == self.num_downsample_atn - 1: + x, atn, atn_vis = self.attention_module(x, label, ref_label) + if self.mul_ref_label: + x_label, _, _ = self.attention_module(x_label, None, None, + atn) + + atn_sum = atn.view(label.shape[0], k, -1).sum(2) + ref_idx = torch.argmax(atn_sum, dim=1) + + # Get all corresponding layers in the encoder output for generating + # weights in corresponding layers. + encoded_image_ref = [x] + if self.mul_ref_label: + encoded_ref_label = [x_label] + + for i in reversed(range(self.num_downsamples)): + conv = getattr(self, 'ref_img_up_' + str(i))( + encoded_image_ref[-1]) + encoded_image_ref.append(conv) + if self.mul_ref_label: + conv_label = getattr(self, 'ref_label_up_' + str(i))( + encoded_ref_label[-1]) + encoded_ref_label.append(conv_label) + + if self.mul_ref_label: + encoded_ref = [] + for i in range(len(encoded_image_ref)): + conv, conv_label \ + = encoded_image_ref[i], encoded_ref_label[i] + b, c, h, w = conv.size() + conv_label = nn.Softmax(dim=1)(conv_label) + conv_prod = (conv.view(b, c, 1, h * w) * + conv_label.view(b, 1, c, + h * w)).sum(3, keepdim=True) + encoded_ref.append(conv_prod) + else: + encoded_ref = encoded_image_ref + encoded_ref = encoded_ref[::-1] + + return x, encoded_ref, atn, atn_vis, ref_idx + + def get_norm_weights(self, x, i): + r"""Adaptively generate weights for SPADE in layer i of generator. + + Args: + x (NxCxHxW tensor): Input features. + i (int): Layer index. + Returns: + (tuple): + - embedding_weights (list of tensors): Weights for the label + embedding network. + - norm_weights (list of tensors): Weights for the SPADE layers. + """ + if not self.mul_ref_label: + # Get fixed output size for fc layers. + x = nn.AdaptiveAvgPool2d((self.sh_fix, self.sw_fix))(x) + + in_ch = self.num_filters_each_layer[i] + out_ch = self.num_filters_each_layer[i + 1] + spade_ch = self.spade_in_channels[i] + eks, sks = self.embed_kernel_size, self.kernel_size + + b = x.size(0) + weight_reshaper = WeightReshaper() + x = weight_reshaper.reshape_embed_input(x) + + # Weights for the label embedding network. + embedding_weights = None + if self.use_hyper_embed: + fc_e = getattr(self, 'fc_spade_e_' + str(i))(x).view(b, -1) + if 'decoder' in self.embed_arch: + weight_shape = [in_ch, out_ch, eks, eks] + fc_e = fc_e[:, :-in_ch] + else: + weight_shape = [out_ch, in_ch, eks, eks] + embedding_weights = weight_reshaper.reshape_weight(fc_e, + weight_shape) + + # Weights for the 3 layers in SPADE module: conv_0, conv_1, + # and shortcut. 
+ fc_0 = getattr(self, 'fc_spade_0_' + str(i))(x).view(b, -1) + fc_1 = getattr(self, 'fc_spade_1_' + str(i))(x).view(b, -1) + fc_s = getattr(self, 'fc_spade_s_' + str(i))(x).view(b, -1) + if self.conv_before_norm: + out_ch = in_ch + weight_0 = weight_reshaper.reshape_weight(fc_0, [out_ch * 2, spade_ch, + sks, sks]) + weight_1 = weight_reshaper.reshape_weight(fc_1, [in_ch * 2, spade_ch, + sks, sks]) + weight_s = weight_reshaper.reshape_weight(fc_s, [out_ch * 2, spade_ch, + sks, sks]) + norm_weights = [weight_0, weight_1, weight_s] + + return embedding_weights, norm_weights + + def get_conv_weights(self, x, i): + r"""Adaptively generate weights for layer i in main branch convolutions. + + Args: + x (NxCxHxW tensor): Input features. + i (int): Layer index. + Returns: + (tuple): + - conv_weights (list of tensors): Weights for the conv layers in + the main branch. + """ + if not self.mul_ref_label: + x = nn.AdaptiveAvgPool2d((self.sh_fix, self.sw_fix))(x) + in_ch = self.num_filters_each_layer[i] + out_ch = self.num_filters_each_layer[i + 1] + cks = self.conv_kernel_size + b = x.size()[0] + weight_reshaper = WeightReshaper() + x = weight_reshaper.reshape_embed_input(x) + + fc_0 = getattr(self, 'fc_conv_0_' + str(i))(x).view(b, -1) + fc_1 = getattr(self, 'fc_conv_1_' + str(i))(x).view(b, -1) + fc_s = getattr(self, 'fc_conv_s_' + str(i))(x).view(b, -1) + weight_0 = weight_reshaper.reshape_weight(fc_0, [in_ch, out_ch, + cks, cks]) + weight_1 = weight_reshaper.reshape_weight(fc_1, [in_ch, in_ch, + cks, cks]) + weight_s = weight_reshaper.reshape_weight(fc_s, [in_ch, out_ch, 1, 1]) + return [weight_0, weight_1, weight_s] + + def reset(self): + r"""Reset the network at the beginning of a sequence.""" + self.embedding_weights = self.conv_weights = self.norm_weights = None + + +class WeightReshaper(): + r"""Handles all weight reshape related tasks.""" + def reshape_weight(self, x, weight_shape): + r"""Reshape input x to the desired weight shape. + + Args: + x (tensor or list of tensors): Input features. + weight_shape (list of int): Desired shape of the weight. + Returns: + (tuple): + - weight (tensor): Network weights + - bias (tensor): Network bias. + """ + # If desired shape is a list, first divide x into the target list of + # features. + if type(weight_shape[0]) == list and type(x) != list: + x = self.split_weights(x, self.sum_mul(weight_shape)) + + if type(x) == list: + return [self.reshape_weight(xi, wi) + for xi, wi in zip(x, weight_shape)] + + # Get output shape, and divide x into either weight + bias or + # just weight. + weight_shape = [x.size(0)] + weight_shape + bias_size = weight_shape[1] + try: + weight = x[:, :-bias_size].view(weight_shape) + bias = x[:, -bias_size:] + except Exception: + weight = x.view(weight_shape) + bias = None + return [weight, bias] + + def split_weights(self, weight, sizes): + r"""When the desired shape is a list, first divide the input to each + corresponding weight shape in the list. + + Args: + weight (tensor): Input weight. + sizes (int or list of int): Target sizes. + Returns: + weight (list of tensors): Divided weights. + """ + if isinstance(sizes, list): + weights = [] + cur_size = 0 + for i in range(len(sizes)): + # For each target size in sizes, get the number of elements + # needed. + next_size = cur_size + self.sum(sizes[i]) + # Recursively divide the weights. 
+ weights.append(self.split_weights( + weight[:, cur_size:next_size], sizes[i])) + cur_size = next_size + assert (next_size == weight.size(1)) + return weights + return weight + + def reshape_embed_input(self, x): + r"""Reshape input to be (B x C) X H X W. + + Args: + x (tensor or list of tensors): Input features. + Returns: + x (tensor or list of tensors): Reshaped features. + """ + if isinstance(x, list): + return [self.reshape_embed_input(xi) for xi in zip(x)] + b, c, _, _ = x.size() + x = x.view(b * c, -1) + return x + + def sum(self, x): + r"""Sum all elements recursively in a nested list. + + Args: + x (nested list of int): Input list of elements. + Returns: + out (int): Sum of all elements. + """ + if type(x) != list: + return x + return sum([self.sum(xi) for xi in x]) + + def sum_mul(self, x): + r"""Given a weight shape, compute the number of elements needed for + weight + bias. If input is a list of shapes, sum all the elements. + + Args: + x (list of int): Input list of elements. + Returns: + out (int or list of int): Summed number of elements. + """ + assert (type(x) == list) + if type(x[0]) != list: + return np.prod(x) + x[0] # x[0] accounts for bias. + return [self.sum_mul(xi) for xi in x] + + +class AttentionModule(nn.Module): + r"""Attention module constructor. + + Args: + atn_cfg (obj): Generator definition part of the yaml config file. + data_cfg (obj): Data definition part of the yaml config file + conv_2d_block: Conv2DBlock constructor. + num_filters_each_layer (int): The number of filters in each layer. + """ + + def __init__(self, atn_cfg, data_cfg, conv_2d_block, + num_filters_each_layer): + super().__init__() + self.initial_few_shot_K = data_cfg.initial_few_shot_K + num_input_channels = data_cfg.num_input_channels + num_filters = getattr(atn_cfg, 'num_filters', 32) + + self.num_downsample_atn = getattr(atn_cfg, 'num_downsamples', 2) + self.atn_query_first = conv_2d_block(num_input_channels, num_filters) + self.atn_key_first = conv_2d_block(num_input_channels, num_filters) + for i in range(self.num_downsamples_atn): + f_in, f_out = num_filters_each_layer[i], \ + num_filters_each_layer[i + 1] + setattr(self, 'atn_key_%d' % i, + conv_2d_block(f_in, f_out, stride=2)) + setattr(self, 'atn_query_%d' % i, + conv_2d_block(f_in, f_out, stride=2)) + + def forward(self, in_features, label, ref_label, attention=None): + r"""Get the attention map to combine multiple image features in the + case of multiple reference images. + + Args: + in_features ((NxK)xC1xH1xW1 tensor): Input feaures. + label (NxC2xH2xW2 tensor): Target label. + ref_label (NxC2xH2xW2 tensor): Reference label. + attention (Nx(KxH1xW1)x(H1xW1) tensor): Attention maps. + Returns: + (tuple): + - out_features (NxC1xH1xW1 tensor): Attention-combined features. + - attention (Nx(KxH1xW1)x(H1xW1) tensor): Attention maps. + - atn_vis (1x1xH1xW1 tensor): Visualization for attention scores. + """ + b, c, h, w = in_features.size() + k = self.initial_few_shot_K + b = b // k + + if attention is None: + # Compute the attention map by encoding ref_label and label as + # key and query. The map represents how much energy for the k-th + # map at location (h_i, w_j) can contribute to the final map at + # location (h_i2, w_j2). 
+ atn_key = self.attention_encode(ref_label, 'atn_key') + atn_query = self.attention_encode(label, 'atn_query') + + atn_key = atn_key.view(b, k, c, -1).permute( + 0, 1, 3, 2).contiguous().view(b, -1, c) # B X KHW X C + atn_query = atn_query.view(b, c, -1) # B X C X HW + energy = torch.bmm(atn_key, atn_query) # B X KHW X HW + attention = nn.Softmax(dim=1)(energy) + + # Combine the K features from different ref images into one by using + # the attention map. + in_features = in_features.view(b, k, c, h * w).permute( + 0, 2, 1, 3).contiguous().view(b, c, -1) # B X C X KHW + out_features = torch.bmm(in_features, attention).view(b, c, h, w) + + # Get a slice of the attention map for visualization. + atn_vis = attention.view(b, k, h * w, h * w).sum(2).view(b, k, h, w) + return out_features, attention, atn_vis[-1:, 0:1] + + def attention_encode(self, img, net_name): + r"""Encode the input image to get the attention map. + + Args: + img (NxCxHxW tensor): Input image. + net_name (str): Name for attention network. + Returns: + x (NxC2xH2xW2 tensor): Encoded feature. + """ + x = getattr(self, net_name + '_first')(img) + for i in range(self.num_downsample_atn): + x = getattr(self, net_name + '_' + str(i))(x) + return x + + +class FlowGenerator(nn.Module): + r"""flow generator constructor. + + Args: + flow_cfg (obj): Flow definition part of the yaml config file. + data_cfg (obj): Data definition part of the yaml config file. + num_frames (int): Number of input frames. + """ + + def __init__(self, flow_cfg, data_cfg, num_frames): + super().__init__() + num_input_channels = data_cfg.num_input_channels + if num_input_channels == 0: + num_input_channels = 1 + num_prev_img_channels = get_paired_input_image_channel_number(data_cfg) + num_downsamples = getattr(flow_cfg, 'num_downsamples', 3) + kernel_size = getattr(flow_cfg, 'kernel_size', 3) + padding = kernel_size // 2 + num_blocks = getattr(flow_cfg, 'num_blocks', 6) + num_filters = getattr(flow_cfg, 'num_filters', 32) + max_num_filters = getattr(flow_cfg, 'max_num_filters', 1024) + num_filters_each_layer = [min(max_num_filters, num_filters * (2 ** i)) + for i in range(num_downsamples + 1)] + + self.flow_output_multiplier = getattr(flow_cfg, + 'flow_output_multiplier', 20) + self.sep_up_mask = getattr(flow_cfg, 'sep_up_mask', False) + activation_norm_type = getattr(flow_cfg, 'activation_norm_type', + 'sync_batch') + weight_norm_type = getattr(flow_cfg, 'weight_norm_type', 'spectral') + + base_conv_block = partial(Conv2dBlock, kernel_size=kernel_size, + padding=padding, + weight_norm_type=weight_norm_type, + activation_norm_type=activation_norm_type, + nonlinearity='leakyrelu') + + num_input_channels = num_input_channels * num_frames + \ + num_prev_img_channels * (num_frames - 1) + # First layer. + down_flow = [base_conv_block(num_input_channels, num_filters)] + + # Downsamples. + for i in range(num_downsamples): + down_flow += [base_conv_block(num_filters_each_layer[i], + num_filters_each_layer[i + 1], + stride=2)] + + # Resnet blocks. + res_flow = [] + ch = num_filters_each_layer[num_downsamples] + for i in range(num_blocks): + res_flow += [ + Res2dBlock(ch, ch, kernel_size, padding=padding, + weight_norm_type=weight_norm_type, + activation_norm_type=activation_norm_type, + order='NACNAC')] + + # Upsamples. 
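+        # The decoder mirrors the downsampling path: nn.Upsample (nearest
+        # neighbour by default) followed by a conv block, rather than a
+        # transposed convolution.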
+ up_flow = [] + for i in reversed(range(num_downsamples)): + up_flow += [nn.Upsample(scale_factor=2), + base_conv_block(num_filters_each_layer[i + 1], + num_filters_each_layer[i])] + + conv_flow = [Conv2dBlock(num_filters, 2, kernel_size, padding=padding)] + conv_mask = [Conv2dBlock(num_filters, 1, kernel_size, padding=padding, + nonlinearity='sigmoid')] + + self.down_flow = nn.Sequential(*down_flow) + self.res_flow = nn.Sequential(*res_flow) + self.up_flow = nn.Sequential(*up_flow) + if self.sep_up_mask: + self.up_mask = nn.Sequential(*copy.deepcopy(up_flow)) + self.conv_flow = nn.Sequential(*conv_flow) + self.conv_mask = nn.Sequential(*conv_mask) + + def forward(self, label, ref_label, ref_image): + r"""Flow generator forward. + + Args: + label (4D tensor) : Input label tensor. + ref_label (4D tensor) : Reference label tensors. + ref_image (4D tensor) : Reference image tensors. + Returns: + (tuple): + - flow (4D tensor) : Generated flow map. + - mask (4D tensor) : Generated occlusion mask. + """ + label_concat = torch.cat([label, ref_label, ref_image], dim=1) + downsample = self.down_flow(label_concat) + res = self.res_flow(downsample) + flow_feat = self.up_flow(res) + flow = self.conv_flow(flow_feat) * self.flow_output_multiplier + + mask_feat = self.up_mask(res) if self.sep_up_mask else flow_feat + mask = self.conv_mask(mask_feat) + return flow, mask + + +class LabelEmbedder(nn.Module): + r"""Embed the input label map to get embedded features. + + Args: + emb_cfg (obj): Embed network configuration. + num_input_channels (int): Number of input channels. + num_hyper_layers (int): Number of hyper layers. + """ + + def __init__(self, emb_cfg, num_input_channels, num_hyper_layers=0): + super().__init__() + num_filters = getattr(emb_cfg, 'num_filters', 32) + max_num_filters = getattr(emb_cfg, 'max_num_filters', 1024) + self.arch = getattr(emb_cfg, 'arch', 'encoderdecoder') + self.num_downsamples = num_downsamples = \ + getattr(emb_cfg, 'num_downsamples', 5) + kernel_size = getattr(emb_cfg, 'kernel_size', 3) + weight_norm_type = getattr(emb_cfg, 'weight_norm_type', 'spectral') + activation_norm_type = getattr(emb_cfg, 'activation_norm_type', 'none') + + self.unet = 'unet' in self.arch + self.has_decoder = 'decoder' in self.arch or self.unet + self.num_hyper_layers = num_hyper_layers \ + if num_hyper_layers != -1 else num_downsamples + + base_conv_block = partial(HyperConv2dBlock, kernel_size=kernel_size, + padding=(kernel_size // 2), + weight_norm_type=weight_norm_type, + activation_norm_type=activation_norm_type, + nonlinearity='leakyrelu') + + ch = [min(max_num_filters, num_filters * (2 ** i)) + for i in range(num_downsamples + 1)] + + self.conv_first = base_conv_block(num_input_channels, num_filters, + activation_norm_type='none') + + # Downsample. + for i in range(num_downsamples): + is_hyper_conv = (i < num_hyper_layers) and not self.has_decoder + setattr(self, 'down_%d' % i, + base_conv_block(ch[i], ch[i + 1], stride=2, + is_hyper_conv=is_hyper_conv)) + + # Upsample. + if self.has_decoder: + self.upsample = nn.Upsample(scale_factor=2) + for i in reversed(range(num_downsamples)): + ch_i = ch[i + 1] * ( + 2 if self.unet and i != num_downsamples - 1 else 1) + setattr(self, 'up_%d' % i, + base_conv_block(ch_i, ch[i], + is_hyper_conv=(i < num_hyper_layers))) + + def forward(self, input, weights=None): + r"""Embedding network forward. + + Args: + input (NxCxHxW tensor): Network input. + weights (list of tensors): Conv weights if using hyper network. 
+ Returns: + output (list of tensors): Network outputs at different layers. + """ + if input is None: + return None + output = [self.conv_first(input)] + + for i in range(self.num_downsamples): + layer = getattr(self, 'down_%d' % i) + # For hyper networks, the hyper layers are at the last few layers + # of decoder (if the network has a decoder). Otherwise, the hyper + # layers will be at the first few layers of the network. + if i >= self.num_hyper_layers or self.has_decoder: + conv = layer(output[-1]) + else: + conv = layer(output[-1], conv_weights=weights[i]) + # We will use outputs from different layers as input to different + # SPADE layers in the main branch. + output.append(conv) + + if not self.has_decoder: + return output + + # If the network has a decoder, will use outputs from the decoder + # layers instead of the encoding layers. + if not self.unet: + output = [output[-1]] + + for i in reversed(range(self.num_downsamples)): + input_i = output[-1] + if self.unet and i != self.num_downsamples - 1: + input_i = torch.cat([input_i, output[i + 1]], dim=1) + + input_i = self.upsample(input_i) + layer = getattr(self, 'up_%d' % i) + # The last few layers will be hyper layers if necessary. + if i >= self.num_hyper_layers: + conv = layer(input_i) + else: + conv = layer(input_i, conv_weights=weights[i]) + output.append(conv) + + if self.unet: + output = output[self.num_downsamples:] + return output[::-1] diff --git a/imaginaire/generators/funit.py b/imaginaire/generators/funit.py new file mode 100644 index 0000000000000000000000000000000000000000..6520166a4f906afe3c5cd2fda09a6fbc11502213 --- /dev/null +++ b/imaginaire/generators/funit.py @@ -0,0 +1,400 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +from functools import partial +from types import SimpleNamespace + +import torch +from torch import nn + +from imaginaire.layers import \ + (Conv2dBlock, LinearBlock, Res2dBlock, UpRes2dBlock) + + +class Generator(nn.Module): + r"""Generator of the improved FUNIT baseline in the COCO-FUNIT paper. + """ + + def __init__(self, gen_cfg, data_cfg): + super().__init__() + self.generator = FUNITTranslator(**vars(gen_cfg)) + + def forward(self, data): + r"""In the FUNIT's forward pass, it generates a content embedding and + a style code from the content image, and a style code from the style + image. By mixing the content code and the style code from the content + image, we reconstruct the input image. By mixing the content code and + the style code from the style image, we have a translation output. + + Args: + data (dict): Training data at the current iteration. + """ + content_a = self.generator.content_encoder(data['images_content']) + style_a = self.generator.style_encoder(data['images_content']) + style_b = self.generator.style_encoder(data['images_style']) + images_trans = self.generator.decode(content_a, style_b) + images_recon = self.generator.decode(content_a, style_a) + + net_G_output = dict(images_trans=images_trans, + images_recon=images_recon) + return net_G_output + + def inference(self, data, keep_original_size=True): + r"""COCO-FUNIT inference. + + Args: + data (dict): Training data at the current iteration. + - images_content (tensor): Content images. + - images_style (tensor): Style images. + a2b (bool): If ``True``, translates images from domain A to B, + otherwise from B to A. 
+ keep_original_size (bool): If ``True``, output image is resized + to the input content image size. + """ + content_a = self.generator.content_encoder(data['images_content']) + style_b = self.generator.style_encoder(data['images_style']) + output_images = self.generator.decode(content_a, style_b) + if keep_original_size: + height = data['original_h_w'][0][0] + width = data['original_h_w'][0][1] + # print('( H, W) = ( %d, %d)' % (height, width)) + output_images = torch.nn.functional.interpolate( + output_images, size=[height, width]) + file_names = data['key']['images_content'][0] + return output_images, file_names + + +class FUNITTranslator(nn.Module): + r""" + + Args: + num_filters (int): Base filter numbers. + num_filters_mlp (int): Base filter number in the MLP module. + style_dims (int): Dimension of the style code. + num_res_blocks (int): Number of residual blocks at the end of the + content encoder. + num_mlp_blocks (int): Number of layers in the MLP module. + num_downsamples_content (int): Number of times we reduce + resolution by 2x2 for the content image. + num_downsamples_style (int): Number of times we reduce + resolution by 2x2 for the style image. + num_image_channels (int): Number of input image channels. + weight_norm_type (str): Type of weight normalization. + ``'none'``, ``'spectral'``, or ``'weight'``. + """ + + def __init__(self, + num_filters=64, + num_filters_mlp=256, + style_dims=64, + num_res_blocks=2, + num_mlp_blocks=3, + num_downsamples_style=4, + num_downsamples_content=2, + num_image_channels=3, + weight_norm_type='', + **kwargs): + super().__init__() + + self.style_encoder = StyleEncoder(num_downsamples_style, + num_image_channels, + num_filters, + style_dims, + 'reflect', + 'none', + weight_norm_type, + 'relu') + + self.content_encoder = ContentEncoder(num_downsamples_content, + num_res_blocks, + num_image_channels, + num_filters, + 'reflect', + 'instance', + weight_norm_type, + 'relu') + + self.decoder = Decoder(self.content_encoder.output_dim, + num_filters_mlp, + num_image_channels, + num_downsamples_content, + 'reflect', + weight_norm_type, + 'relu') + + self.mlp = MLP(style_dims, + num_filters_mlp, + num_filters_mlp, + num_mlp_blocks, + 'none', + 'relu') + + def forward(self, images): + r"""Reconstruct the input image by combining the computer content and + style code. + + Args: + images (tensor): Input image tensor. + """ + # reconstruct an image + content, style = self.encode(images) + images_recon = self.decode(content, style) + return images_recon + + def encode(self, images): + r"""Encoder images to get their content and style codes. + + Args: + images (tensor): Input image tensor. + """ + style = self.style_encoder(images) + content = self.content_encoder(images) + return content, style + + def decode(self, content, style): + r"""Generate images by combining their content and style codes. + + Args: + content (tensor): Content code tensor. + style (tensor): Style code tensor. + """ + style = self.mlp(style) + images = self.decoder(content, style) + return images + + +class Decoder(nn.Module): + r"""Improved FUNIT decoder. + + Args: + num_enc_output_channels (int): Number of content feature channels. + style_channels (int): Dimension of the style code. + num_image_channels (int): Number of image channels. + num_upsamples (int): How many times we are going to apply + upsample residual block. 
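+        padding_type (str): Padding mode for the convolution layers.
+        weight_norm_type (str): Type of weight normalization.
+        nonlinearity (str): Nonlinearity type.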
+ """ + + def __init__(self, + num_enc_output_channels, + style_channels, + num_image_channels=3, + num_upsamples=4, + padding_type='reflect', + weight_norm_type='none', + nonlinearity='relu'): + super(Decoder, self).__init__() + adain_params = SimpleNamespace( + activation_norm_type='instance', + activation_norm_params=SimpleNamespace(affine=False), + cond_dims=style_channels) + + base_res_block = partial(Res2dBlock, + kernel_size=3, + padding=1, + padding_mode=padding_type, + nonlinearity=nonlinearity, + activation_norm_type='adaptive', + activation_norm_params=adain_params, + weight_norm_type=weight_norm_type, + learn_shortcut=False) + + base_up_res_block = partial(UpRes2dBlock, + kernel_size=5, + padding=2, + padding_mode=padding_type, + weight_norm_type=weight_norm_type, + activation_norm_type='adaptive', + activation_norm_params=adain_params, + skip_activation_norm='instance', + skip_nonlinearity=nonlinearity, + nonlinearity=nonlinearity, + hidden_channels_equal_out_channels=True, + learn_shortcut=True) + + dims = num_enc_output_channels + + # Residual blocks with AdaIN. + self.decoder = nn.ModuleList() + self.decoder += [base_res_block(dims, dims)] + self.decoder += [base_res_block(dims, dims)] + for _ in range(num_upsamples): + self.decoder += [base_up_res_block(dims, dims // 2)] + dims = dims // 2 + self.decoder += [Conv2dBlock(dims, + num_image_channels, + kernel_size=7, + stride=1, + padding=3, + padding_mode='reflect', + nonlinearity='tanh')] + + def forward(self, x, style): + r""" + + Args: + x (tensor): Content embedding of the content image. + style (tensor): Style embedding of the style image. + """ + for block in self.decoder: + if getattr(block, 'conditional', False): + x = block(x, style) + else: + x = block(x) + return x + + +class StyleEncoder(nn.Module): + r"""Improved FUNIT Style Encoder. This is basically the same as the + original FUNIT Style Encoder. + + Args: + num_downsamples (int): Number of times we reduce resolution by + 2x2. + image_channels (int): Number of input image channels. + num_filters (int): Base filter number. + style_channels (int): Style code dimension. + padding_mode (str): Padding mode. + activation_norm_type (str): Type of activation normalization. + weight_norm_type (str): Type of weight normalization. + ``'none'``, ``'spectral'``, or ``'weight'``. + nonlinearity (str): Nonlinearity. + """ + + def __init__(self, + num_downsamples, + image_channels, + num_filters, + style_channels, + padding_mode, + activation_norm_type, + weight_norm_type, + nonlinearity): + super().__init__() + conv_params = dict(padding_mode=padding_mode, + activation_norm_type=activation_norm_type, + weight_norm_type=weight_norm_type, + nonlinearity=nonlinearity, + inplace_nonlinearity=True) + model = [] + model += [Conv2dBlock(image_channels, num_filters, 7, 1, 3, + **conv_params)] + for i in range(2): + model += [Conv2dBlock(num_filters, 2 * num_filters, 4, 2, 1, + **conv_params)] + num_filters *= 2 + for i in range(num_downsamples - 2): + model += [Conv2dBlock(num_filters, num_filters, 4, 2, 1, + **conv_params)] + model += [nn.AdaptiveAvgPool2d(1)] + model += [nn.Conv2d(num_filters, style_channels, 1, 1, 0)] + self.model = nn.Sequential(*model) + self.output_dim = num_filters + + def forward(self, x): + r""" + + Args: + x (tensor): Input image. + """ + return self.model(x) + + +class ContentEncoder(nn.Module): + r"""Improved FUNIT Content Encoder. This is basically the same as the + original FUNIT content encoder. 
+ + Args: + num_downsamples (int): Number of times we reduce resolution by + 2x2. + num_res_blocks (int): Number of times we append residual block + after all the downsampling modules. + image_channels (int): Number of input image channels. + num_filters (int): Base filter number. + padding_mode (str): Padding mode + activation_norm_type (str): Type of activation normalization. + weight_norm_type (str): Type of weight normalization. + ``'none'``, ``'spectral'``, or ``'weight'``. + nonlinearity (str): Nonlinearity. + """ + + def __init__(self, + num_downsamples, + num_res_blocks, + image_channels, + num_filters, + padding_mode, + activation_norm_type, + weight_norm_type, + nonlinearity): + super().__init__() + conv_params = dict(padding_mode=padding_mode, + activation_norm_type=activation_norm_type, + weight_norm_type=weight_norm_type, + nonlinearity=nonlinearity, + inplace_nonlinearity=True, + order='CNACNA') + model = [] + model += [Conv2dBlock(image_channels, num_filters, 7, 1, 3, + **conv_params)] + dims = num_filters + for i in range(num_downsamples): + model += [Conv2dBlock(dims, dims * 2, 4, 2, 1, **conv_params)] + dims *= 2 + + for _ in range(num_res_blocks): + model += [Res2dBlock(dims, dims, learn_shortcut=False, **conv_params)] + self.model = nn.Sequential(*model) + self.output_dim = dims + + def forward(self, x): + r""" + + Args: + x (tensor): Input image. + """ + return self.model(x) + + +class MLP(nn.Module): + r"""Improved FUNIT style decoder. + + Args: + input_dim (int): Input dimension (style code dimension). + output_dim (int): Output dimension (to be fed into the AdaIN + layer). + latent_dim (int): Latent dimension. + num_layers (int): Number of layers in the MLP. + activation_norm_type (str): Activation type. + nonlinearity (str): Nonlinearity type. + """ + + def __init__(self, + input_dim, + output_dim, + latent_dim, + num_layers, + activation_norm_type, + nonlinearity): + super().__init__() + model = [] + model += [LinearBlock(input_dim, latent_dim, + activation_norm_type=activation_norm_type, + nonlinearity=nonlinearity)] + # changed from num_layers - 2 to num_layers - 3. + for i in range(num_layers - 3): + model += [LinearBlock(latent_dim, latent_dim, + activation_norm_type=activation_norm_type, + nonlinearity=nonlinearity)] + model += [LinearBlock(latent_dim, output_dim, + activation_norm_type=activation_norm_type, + nonlinearity=nonlinearity)] + self.model = nn.Sequential(*model) + + def forward(self, x): + r""" + + Args: + x (tensor): Input tensor. + """ + return self.model(x.view(x.size(0), -1)) diff --git a/imaginaire/generators/gancraft.py b/imaginaire/generators/gancraft.py new file mode 100644 index 0000000000000000000000000000000000000000..94fc34bee88e31fcdcf48f715f4d17f3de5bc37b --- /dev/null +++ b/imaginaire/generators/gancraft.py @@ -0,0 +1,538 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. 
+# To view a copy of this license, check out LICENSE.md +import os + +import cv2 +import imageio +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +import imaginaire.model_utils.gancraft.camctl as camctl +import imaginaire.model_utils.gancraft.mc_utils as mc_utils +import imaginaire.model_utils.gancraft.voxlib as voxlib +from imaginaire.utils.distributed import master_only_print as print +from imaginaire.generators.gancraft_base import Base3DGenerator, RenderMLP # noqa + + +class Generator(Base3DGenerator): + r"""GANcraft generator constructor. + + Args: + gen_cfg (obj): Generator definition part of the yaml config file. + data_cfg (obj): Data definition part of the yaml config file. + """ + + def __init__(self, gen_cfg, data_cfg): + super(Generator, self).__init__(gen_cfg, data_cfg) + print('GANcraft generator initialization.') + + # Load voxels of the input world. + # The loaded voxel tensor has a shape of [X, Y, Z], dtype==torch.int32 + # 0 means empty (air). + print('[Generator] Loading voxel world: ', gen_cfg.voxel_path) + if gen_cfg.voxel_path.endswith('.npy'): + voxel_t = np.load(gen_cfg.voxel_path) + voxel_t = torch.from_numpy(voxel_t.astype(np.int32)) + else: + voxel_t = mc_utils.load_voxel_new(gen_cfg.voxel_path, shape=gen_cfg.voxel_shape) + print('[Generator] Loaded voxel world.') + self.voxel = mc_utils.McVoxel(voxel_t, preproc_ver=gen_cfg.voxel_preproc_ver) + blk_feats = torch.empty([self.voxel.nfilledvox, gen_cfg.blk_feat_dim], requires_grad=True) + self.blk_feats = nn.Parameter(blk_feats) # Feature per voxel corner. + + # Minecraft -> SPADE label translator. + self.label_trans = mc_utils.MCLabelTranslator() + self.num_reduced_labels = self.label_trans.get_num_reduced_lbls() + self.reduced_label_set = getattr(gen_cfg, 'reduced_label_set', False) + self.use_label_smooth = getattr(gen_cfg, 'use_label_smooth', False) + self.use_label_smooth_real = getattr(gen_cfg, 'use_label_smooth_real', self.use_label_smooth) + self.use_label_smooth_pgt = getattr(gen_cfg, 'use_label_smooth_pgt', False) + self.label_smooth_dia = getattr(gen_cfg, 'label_smooth_dia', 11) + + # Load MLP model. + self.render_net = globals()[gen_cfg.mlp_model]( + self.input_dim, viewdir_dim=self.input_dim_viewdir, style_dim=self.interm_style_dims, + mask_dim=self.num_reduced_labels, out_channels_s=1, out_channels_c=self.final_feat_dim, + **self.mlp_model_kwargs) + + # Camera sampler. + self.camera_sampler_type = getattr(gen_cfg, 'camera_sampler_type', "random") + assert self.camera_sampler_type in ['random', 'traditional'] + self.camera_min_entropy = getattr(gen_cfg, 'camera_min_entropy', -1) + self.camera_rej_avg_depth = getattr(gen_cfg, 'camera_rej_avg_depth', -1) + self.cam_res = gen_cfg.cam_res + self.crop_size = gen_cfg.crop_size + + print('Done with the GANcraft generator initialization.') + + def custom_init(self): + r"""Weight initialization of GANcraft components.""" + + self.blk_feats.data.uniform_(-1, 1) + + def init_func(m): + if hasattr(m, 'weight'): + nn.init.kaiming_normal_(m.weight.data, a=0.2, nonlinearity='leaky_relu') + m.weight.data *= 0.5 + if hasattr(m, 'bias') and m.bias is not None: + m.bias.data.fill_(0.0) + self.apply(init_func) + + def _get_batch(self, batch_size, device): + r"""Sample camera poses and perform ray-voxel intersection. 
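+        Candidate poses are drawn by rejection sampling: views that are too
+        close to geometry (low average depth) or too uniform (low label
+        entropy) are discarded; see the checks below.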
+ + Args: + batch_size (int): Expected batch size of the current batch + device (torch.device): Device on which the tensors should be stored + """ + with torch.no_grad(): + voxel_id_batch = [] + depth2_batch = [] + raydirs_batch = [] + cam_ori_t_batch = [] + for b in range(batch_size): + while True: # Rejection sampling. + # Sample camera pose. + if self.camera_sampler_type == 'random': + cam_res = self.cam_res + cam_ori_t, cam_dir_t, cam_up_t = camctl.rand_camera_pose_thridperson2(self.voxel) + # ~24mm fov horizontal. + cam_f = 0.5/np.tan(np.deg2rad(73/2) * (np.random.rand(1)*0.5+0.5)) * (cam_res[1]-1) + cam_c = [(cam_res[0]-1)/2, (cam_res[1]-1)/2] + cam_res_crop = [self.crop_size[0] + self.pad, self.crop_size[1] + self.pad] + cam_c = mc_utils.rand_crop(cam_c, cam_res, cam_res_crop) + elif self.camera_sampler_type == 'traditional': + cam_res = self.cam_res + cam_c = [(cam_res[0]-1)/2, (cam_res[1]-1)/2] + dice = torch.rand(1).item() + if dice > 0.5: + cam_ori_t, cam_dir_t, cam_up_t, cam_f = \ + camctl.rand_camera_pose_tour(self.voxel) + cam_f = cam_f * (cam_res[1]-1) + else: + cam_ori_t, cam_dir_t, cam_up_t = \ + camctl.rand_camera_pose_thridperson2(self.voxel) + # ~24mm fov horizontal. + cam_f = 0.5 / np.tan(np.deg2rad(73/2) * (np.random.rand(1)*0.5+0.5)) * (cam_res[1]-1) + + cam_res_crop = [self.crop_size[0] + self.pad, self.crop_size[1] + self.pad] + cam_c = mc_utils.rand_crop(cam_c, cam_res, cam_res_crop) + else: + raise NotImplementedError( + 'Unknown self.camera_sampler_type: {}'.format(self.camera_sampler_type)) + # Run ray-voxel intersection test + r"""Ray-voxel intersection CUDA kernel. + Note: voxel_id = 0 and depth2 = NaN if there is no intersection along the ray + + Args: + voxel_t (Y x 512 x 512 tensor, int32): Full 3D voxel of MC block IDs. + cam_ori_t (3 tensor): Camera origin. + cam_dir_t (3 tensor): Camera direction. + cam_up_t (3 tensor): Camera up vector. + cam_f (float): Camera focal length (in pixels). + cam_c (list of 2 floats [x, y]): Camera optical center. + img_dims (list of 2 ints [H, W]): Camera resolution. + max_samples (int): Maximum number of blocks intersected along the ray before stopping. + Returns: + voxel_id ( img_dims[0] x img_dims[1] x max_samples x 1 tensor): IDs of intersected tensors + along each ray + depth2 (2 x img_dims[0] x img_dims[1] x max_samples x 1 tensor): Depths of entrance and exit + points for each ray-voxel intersection. + raydirs ( img_dims[0] x img_dims[1] x 1 x 3 tensor): The direction of each ray. + + """ + voxel_id, depth2, raydirs = voxlib.ray_voxel_intersection_perspective( + self.voxel.voxel_t, cam_ori_t, cam_dir_t, cam_up_t, cam_f, cam_c, cam_res_crop, + self.num_blocks_early_stop) + + if self.camera_rej_avg_depth > 0: + depth_map = depth2[0, :, :, 0, :] + avg_depth = torch.mean(depth_map[~torch.isnan(depth_map)]) + if avg_depth < self.camera_rej_avg_depth: + continue + + # Reject low entropy. + if self.camera_min_entropy > 0: + # Check entropy. 
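+                        # The entropy of the visible voxel-ID histogram measures
+                        # how varied the view is; near-uniform views are rejected
+                        # and a new pose is sampled.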
+ maskcnt = torch.bincount( + torch.flatten(voxel_id[:, :, 0, 0]), weights=None, minlength=680).float() / \ + (voxel_id.size(0)*voxel_id.size(1)) + maskentropy = -torch.sum(maskcnt * torch.log(maskcnt+1e-10)) + if maskentropy < self.camera_min_entropy: + continue + break + + voxel_id_batch.append(voxel_id) + depth2_batch.append(depth2) + raydirs_batch.append(raydirs) + cam_ori_t_batch.append(cam_ori_t) + voxel_id = torch.stack(voxel_id_batch, dim=0) + depth2 = torch.stack(depth2_batch, dim=0) + raydirs = torch.stack(raydirs_batch, dim=0) + cam_ori_t = torch.stack(cam_ori_t_batch, dim=0).to(device) + cam_poses = None + return voxel_id, depth2, raydirs, cam_ori_t, cam_poses + + def get_pseudo_gt(self, pseudo_gen, voxel_id, z=None, style_img=None, resize_512=True, deterministic=False): + r"""Evaluating img2img network to obtain pseudo-ground truth images. + + Args: + pseudo_gen (callable): Function converting mask to image using img2img network. + voxel_id (N x img_dims[0] x img_dims[1] x max_samples x 1 tensor): IDs of intersected tensors along + each ray. + z (N x C tensor): Optional style code passed to pseudo_gen. + style_img (N x 3 x H x W tensor): Optional style image passed to pseudo_gen. + resize_512 (bool): If True, evaluate pseudo_gen at 512x512 regardless of input resolution. + deterministic (bool): If True, disable stochastic label mapping. + """ + with torch.no_grad(): + mc_mask = voxel_id[:, :, :, 0, :].permute(0, 3, 1, 2).long() + coco_mask = self.label_trans.mc2coco(mc_mask) - 1 + coco_mask[coco_mask < 0] = 183 + + if not deterministic: + # Stochastic mapping + dice = torch.rand(1).item() + if dice > 0.5 and dice < 0.9: + coco_mask[coco_mask == self.label_trans.gglbl2ggid('sky')] = self.label_trans.gglbl2ggid('clouds') + elif dice >= 0.9: + coco_mask[coco_mask == self.label_trans.gglbl2ggid('sky')] = self.label_trans.gglbl2ggid('fog') + dice = torch.rand(1).item() + if dice > 0.33 and dice < 0.66: + coco_mask[coco_mask == self.label_trans.gglbl2ggid('water')] = self.label_trans.gglbl2ggid('sea') + elif dice >= 0.66: + coco_mask[coco_mask == self.label_trans.gglbl2ggid('water')] = self.label_trans.gglbl2ggid('river') + + fake_masks = torch.zeros([coco_mask.size(0), 185, coco_mask.size(2), coco_mask.size(3)], + dtype=torch.half, device=voxel_id.device) + fake_masks.scatter_(1, coco_mask, 1.0) + + if self.use_label_smooth_pgt: + fake_masks = mc_utils.segmask_smooth(fake_masks, kernel_size=self.label_smooth_dia) + if self.pad > 0: + fake_masks = fake_masks[:, :, self.pad//2:-self.pad//2, self.pad//2:-self.pad//2] + + # Generate pseudo GT using GauGAN. + if resize_512: + fake_masks_512 = F.interpolate(fake_masks, size=[512, 512], mode='nearest') + else: + fake_masks_512 = fake_masks + pseudo_real_img = pseudo_gen(fake_masks_512, z=z, style_img=style_img) + + # NaN Inf Guard. NaN can occure on Volta GPUs. + nan_mask = torch.isnan(pseudo_real_img) + inf_mask = torch.isinf(pseudo_real_img) + pseudo_real_img[nan_mask | inf_mask] = 0.0 + if resize_512: + pseudo_real_img = F.interpolate( + pseudo_real_img, size=[fake_masks.size(2), fake_masks.size(3)], mode='area') + pseudo_real_img = torch.clamp(pseudo_real_img, -1, 1) + + return pseudo_real_img, fake_masks + + def sample_camera(self, data, pseudo_gen): + r"""Sample camera randomly and precompute everything used by both Gen and Dis. + + Args: + data (dict): + images (N x 3 x H x W tensor) : Real images + label (N x C2 x H x W tensor) : Segmentation map + pseudo_gen (callable): Function converting mask to image using img2img network. 
+ Returns: + ret (dict): + voxel_id (N x H x W x max_samples x 1 tensor): IDs of intersected tensors along each ray. + depth2 (N x 2 x H x W x max_samples x 1 tensor): Depths of entrance and exit points for each ray-voxel + intersection. + raydirs (N x H x W x 1 x 3 tensor): The direction of each ray. + cam_ori_t (N x 3 tensor): Camera origins. + pseudo_real_img (N x 3 x H x W tensor): Pseudo-ground truth image. + real_masks (N x C3 x H x W tensor): One-hot segmentation map for real images, with translated labels. + fake_masks (N x C3 x H x W tensor): One-hot segmentation map for sampled camera views. + """ + device = torch.device('cuda') + batch_size = data['images'].size(0) + # ================ Assemble a batch ================== + # Requires: voxel_id, depth2, raydirs, cam_ori_t. + voxel_id, depth2, raydirs, cam_ori_t, _ = self._get_batch(batch_size, device) + ret = {'voxel_id': voxel_id, 'depth2': depth2, 'raydirs': raydirs, 'cam_ori_t': cam_ori_t} + + if pseudo_gen is not None: + pseudo_real_img, _ = self.get_pseudo_gt(pseudo_gen, voxel_id) + ret['pseudo_real_img'] = pseudo_real_img.float() + + # =============== Mask translation ================ + real_masks = data['label'] + if self.reduced_label_set: + # Translate fake mask (directly from mcid). + # convert unrecognized labels to 'dirt'. + # N C H W [1 1 80 80] + reduce_fake_mask = self.label_trans.mc2reduced( + voxel_id[:, :, :, 0, :].permute(0, 3, 1, 2).long(), ign2dirt=True) + reduce_fake_mask_onehot = torch.zeros([ + reduce_fake_mask.size(0), self.num_reduced_labels, reduce_fake_mask.size(2), reduce_fake_mask.size(3)], + dtype=torch.float, device=device) + reduce_fake_mask_onehot.scatter_(1, reduce_fake_mask, 1.0) + fake_masks = reduce_fake_mask_onehot + if self.pad != 0: + fake_masks = fake_masks[:, :, self.pad//2:-self.pad//2, self.pad//2:-self.pad//2] + + # Translate real mask (data['label']), which is onehot. + real_masks_idx = torch.argmax(real_masks, dim=1, keepdim=True) + real_masks_idx[real_masks_idx > 182] = 182 + + reduced_real_mask = self.label_trans.coco2reduced(real_masks_idx) + reduced_real_mask_onehot = torch.zeros([ + reduced_real_mask.size(0), self.num_reduced_labels, reduced_real_mask.size(2), + reduced_real_mask.size(3)], dtype=torch.float, device=device) + reduced_real_mask_onehot.scatter_(1, reduced_real_mask, 1.0) + real_masks = reduced_real_mask_onehot + + # Mask smoothing. + if self.use_label_smooth: + fake_masks = mc_utils.segmask_smooth(fake_masks, kernel_size=self.label_smooth_dia) + if self.use_label_smooth_real: + real_masks = mc_utils.segmask_smooth(real_masks, kernel_size=self.label_smooth_dia) + + ret['real_masks'] = real_masks + ret['fake_masks'] = fake_masks + + return ret + + def forward(self, data, random_style=False): + r"""GANcraft Generator forward. + + Args: + data (dict): + images (N x 3 x H x W tensor) : Real images + voxel_id (N x H x W x max_samples x 1 tensor): IDs of intersected tensors along each ray. + depth2 (N x 2 x H x W x max_samples x 1 tensor): Depths of entrance and exit points for each ray-voxel + intersection. + raydirs (N x H x W x 1 x 3 tensor): The direction of each ray. + cam_ori_t (N x 3 tensor): Camera origins. + random_style (bool): Whether to sample a random style vector. 
+ Returns: + output (dict): + fake_images (N x 3 x H x W tensor): fake images + mu (N x C1 tensor): mean vectors + logvar (N x C1 tensor): log-variance vectors + """ + device = torch.device('cuda') + batch_size = data['images'].size(0) + + # ================ Assemble a batch ================== + # Requires: voxel_id, depth2, raydirs, cam_ori_t. + voxel_id, depth2, raydirs, cam_ori_t = data['voxel_id'], data['depth2'], data['raydirs'], data['cam_ori_t'] + if 'pseudo_real_img' in data: + pseudo_real_img = data['pseudo_real_img'] + + z, mu, logvar = None, None, None + if random_style: + if self.style_dims > 0: + z = torch.randn(batch_size, self.style_dims, dtype=torch.float32, device=device) + else: + if self.style_encoder is None: + # ================ Get Style Code ================= + if self.style_dims > 0: + z = torch.randn(batch_size, self.style_dims, dtype=torch.float32, device=device) + else: + mu, logvar, z = self.style_encoder(pseudo_real_img) + + # ================ Network Forward ================ + # Forward StyleNet + if self.style_net is not None: + z = self.style_net(z) + + # Forward per-pixel net. + net_out, new_dists, weights, total_weights_raw, rand_depth, net_out_s, net_out_c, skynet_out_c, nosky_mask, \ + sky_mask, sky_only_mask, new_idx = self._forward_perpix( + self.blk_feats, voxel_id, depth2, raydirs, cam_ori_t, z) + + # Forward global net. + fake_images, fake_images_raw = self._forward_global(net_out, z) + if self.pad != 0: + fake_images = fake_images[:, :, self.pad//2:-self.pad//2, self.pad//2:-self.pad//2] + + # =============== Arrange Return Values ================ + output = {} + output['fake_images'] = fake_images + output['mu'] = mu + output['logvar'] = logvar + return output + + def inference(self, + output_dir, + camera_mode, + style_img_path=None, + seed=1, + pad=30, + num_samples=40, + num_blocks_early_stop=6, + sample_depth=3, + tile_size=128, + resolution_hw=[540, 960], + cam_ang=72, + cam_maxstep=10): + r"""Compute result images according to the provided camera trajectory and save the results in the specified + folder. The full image is evaluated in multiple tiles to save memory. + + Args: + output_dir (str): Where should the results be stored. + camera_mode (int): Which camera trajectory to use. + style_img_path (str): Path to the style-conditioning image. + seed (int): Random seed (controls style when style_image_path is not specified). + pad (int): Pixels to remove from the image tiles before stitching. Should be equal or larger than the + receptive field of the CNN to avoid border artifact. + num_samples (int): Number of samples per ray (different from training). + num_blocks_early_stop (int): Max number of intersected boxes per ray before stopping + (different from training). + sample_depth (float): Max distance traveled through boxes before stopping (different from training). + tile_size (int): Max size of a tile in pixels. + resolution_hw (list [H, W]): Resolution of the output image. + cam_ang (float): Horizontal FOV of the camera (may be adjusted by the camera controller). + cam_maxstep (int): Number of frames sampled from the camera trajectory. 
+        """
+
+        def write_img(path, img, rgb_input=False):
+            img = ((img*0.5+0.5)*255).detach().cpu().numpy().astype(np.uint8)
+            img = img[0].transpose(1, 2, 0)
+            if rgb_input:
+                img = img[..., [2, 1, 0]]
+            cv2.imwrite(path, img, [cv2.IMWRITE_PNG_COMPRESSION, 4])
+            return img[..., ::-1]
+
+        def read_img(path):
+            img = cv2.imread(path).astype(np.float32)[..., [2, 1, 0]].transpose(2, 0, 1) / 255
+            img = img * 2 - 1
+            img = torch.from_numpy(img)
+            return img
+
+        print('Saving to', output_dir)
+
+        # Use provided random seed.
+        device = torch.device('cuda')
+        rng_cuda = torch.Generator(device=device)
+        rng_cuda = rng_cuda.manual_seed(seed)
+        torch.manual_seed(seed)
+        torch.cuda.manual_seed(seed)
+
+        self.pad = pad
+        self.num_samples = num_samples
+        self.num_blocks_early_stop = num_blocks_early_stop
+        self.sample_depth = sample_depth
+
+        self.coarse_deterministic_sampling = True
+        self.crop_size = resolution_hw
+        self.cam_res = [self.crop_size[0]+self.pad, self.crop_size[1]+self.pad]
+        self.use_label_smooth_pgt = False
+
+        # Make output dirs.
+        gancraft_outputs_dir = os.path.join(output_dir, 'gancraft_outputs')
+        os.makedirs(gancraft_outputs_dir, exist_ok=True)
+        vis_masks_dir = os.path.join(output_dir, 'vis_masks')
+        os.makedirs(vis_masks_dir, exist_ok=True)
+        fout = imageio.get_writer(gancraft_outputs_dir + '.mp4', fps=1)
+        fout_cat = imageio.get_writer(gancraft_outputs_dir + '-vis_masks.mp4', fps=1)
+
+        evalcamctl = camctl.EvalCameraController(
+            self.voxel, maxstep=cam_maxstep, pattern=camera_mode, cam_ang=cam_ang,
+            smooth_decay_multiplier=150/cam_maxstep)
+
+        # Get output style.
+        if style_img_path is None:
+            z = torch.empty(1, self.style_dims, dtype=torch.float32, device=device)
+            z.normal_(generator=rng_cuda)
+        else:
+            style_img = read_img(style_img_path)
+            style_img = style_img.to(device).unsqueeze(0)
+            mu, logvar, z = self.style_encoder(style_img)
+        z = self.style_net(z)
+
+        # Generate required output images.
+        for id, (cam_ori_t, cam_dir_t, cam_up_t, cam_f) in enumerate(evalcamctl):
+            print('Rendering frame', id)
+            cam_f = cam_f * (self.crop_size[1]-1)  # So that the view does not depend on the padding.
+            cam_c = [(self.cam_res[0]-1)/2, (self.cam_res[1]-1)/2]
+
+            voxel_id, depth2, raydirs = voxlib.ray_voxel_intersection_perspective(
+                self.voxel.voxel_t, cam_ori_t, cam_dir_t, cam_up_t, cam_f, cam_c, self.cam_res,
+                self.num_blocks_early_stop)
+
+            voxel_id = voxel_id.unsqueeze(0)
+            depth2 = depth2.unsqueeze(0)
+            raydirs = raydirs.unsqueeze(0)
+            cam_ori_t = cam_ori_t.unsqueeze(0).to(device)
+
+            # Save 3D voxel rendering.
+            mc_rgb = self.label_trans.mc_color(voxel_id[0, :, :, 0, 0].cpu().numpy())
+            # Diffused shading, co-located light.
+            first_intersection_depth = depth2[:, 0, :, :, 0, None, :]  # [1, 542, 542, 1, 1].
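+            # The shading below is only for the voxel-label visualization: the
+            # fractional coordinate of the first hit point tells which axis-aligned
+            # voxel face was hit, and that face normal dotted with the ray
+            # direction gives a simple head-light diffuse term.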
+ first_intersection_point = raydirs * first_intersection_depth + cam_ori_t[:, None, None, None, :] + fip_local_coords = torch.remainder(first_intersection_point, 1.0) + fip_wall_proximity = torch.minimum(fip_local_coords, 1.0-fip_local_coords) + fip_wall_orientation = torch.argmin(fip_wall_proximity, dim=-1, keepdim=False) + # 0: [1,0,0]; 1: [0,1,0]; 2: [0,0,1] + lut = torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=torch.float32, + device=fip_wall_orientation.device) + fip_normal = lut[fip_wall_orientation] # [1, 542, 542, 1, 3] + diffuse_shade = torch.abs(torch.sum(fip_normal * raydirs, dim=-1)) + + mc_rgb = (mc_rgb.astype(np.float) / 255) ** 2.2 + mc_rgb = mc_rgb * diffuse_shade[0, :, :, :].cpu().numpy() + mc_rgb = (mc_rgb ** (1/2.2)) * 255 + mc_rgb = mc_rgb.astype(np.uint8) + if self.pad > 0: + mc_rgb = mc_rgb[self.pad//2:-self.pad//2, self.pad//2:-self.pad//2] + cv2.imwrite(os.path.join(vis_masks_dir, '{:05d}.png'.format(id)), mc_rgb, [cv2.IMWRITE_PNG_COMPRESSION, 4]) + + # Tiled eval of GANcraft. + voxel_id_all = voxel_id + depth2_all = depth2 + raydirs_all = raydirs + + # Evaluate sky in advance to get a consistent sky in the semi-transparent region. + if self.sky_global_avgpool: + sky_raydirs_in = raydirs.expand(-1, -1, -1, 1, -1).contiguous() + sky_raydirs_in = voxlib.positional_encoding( + sky_raydirs_in, self.pe_params_sky[0], -1, self.pe_params_sky[1]) + skynet_out_c = self.sky_net(sky_raydirs_in, z) + sky_avg = torch.mean(skynet_out_c, dim=[1, 2], keepdim=True) + self.sky_avg = sky_avg + + num_strips_h = (self.cam_res[0]-self.pad+tile_size-1)//tile_size + num_strips_w = (self.cam_res[1]-self.pad+tile_size-1)//tile_size + + fake_images_chunks_v = [] + # For each horizontal strip. + for strip_id_h in range(num_strips_h): + strip_begin_h = strip_id_h * tile_size + strip_end_h = np.minimum(strip_id_h * tile_size + tile_size + self.pad, self.cam_res[0]) + # For each vertical strip. 
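+            # Tiles overlap by self.pad pixels; each rendered tile is cropped by
+            # pad // 2 on every side before stitching, which hides CNN border
+            # artifacts at the tile seams.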
+ fake_images_chunks_h = [] + for strip_id_w in range(num_strips_w): + strip_begin_w = strip_id_w * tile_size + strip_end_w = np.minimum(strip_id_w * tile_size + tile_size + self.pad, self.cam_res[1]) + + voxel_id = voxel_id_all[:, strip_begin_h:strip_end_h, strip_begin_w:strip_end_w, :, :] + depth2 = depth2_all[:, :, strip_begin_h:strip_end_h, strip_begin_w:strip_end_w, :, :] + raydirs = raydirs_all[:, strip_begin_h:strip_end_h, strip_begin_w:strip_end_w, :, :] + + net_out, new_dists, weights, total_weights_raw, rand_depth, net_out_s, net_out_c, skynet_out_c, \ + nosky_mask, sky_mask, sky_only_mask, new_idx = self._forward_perpix( + self.blk_feats, voxel_id, depth2, raydirs, cam_ori_t, z) + fake_images, _ = self._forward_global(net_out, z) + + if self.pad != 0: + fake_images = fake_images[:, :, self.pad//2:-self.pad//2, self.pad//2:-self.pad//2] + fake_images_chunks_h.append(fake_images) + fake_images_h = torch.cat(fake_images_chunks_h, dim=-1) + fake_images_chunks_v.append(fake_images_h) + fake_images = torch.cat(fake_images_chunks_v, dim=-2) + rgb = write_img(os.path.join(gancraft_outputs_dir, + '{:05d}.png'.format(id)), fake_images, rgb_input=True) + fout.append_data(rgb) + fout_cat.append_data(np.concatenate((mc_rgb[..., ::-1], rgb), axis=1)) + fout.close() + fout_cat.close() diff --git a/imaginaire/generators/gancraft_base.py b/imaginaire/generators/gancraft_base.py new file mode 100644 index 0000000000000000000000000000000000000000..ef164b67053af5d228bda6c9aea75c261d7b114f --- /dev/null +++ b/imaginaire/generators/gancraft_base.py @@ -0,0 +1,603 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +import functools +import re + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from imaginaire.layers import Conv2dBlock, LinearBlock +from imaginaire.model_utils.gancraft.layers import AffineMod, ModLinear +import imaginaire.model_utils.gancraft.mc_utils as mc_utils +import imaginaire.model_utils.gancraft.voxlib as voxlib +from imaginaire.utils.distributed import master_only_print as print + + +class RenderMLP(nn.Module): + r""" MLP with affine modulation.""" + + def __init__(self, in_channels, style_dim, viewdir_dim, mask_dim=680, + out_channels_s=1, out_channels_c=3, hidden_channels=256, + use_seg=True): + super(RenderMLP, self).__init__() + + self.use_seg = use_seg + if self.use_seg: + self.fc_m_a = nn.Linear(mask_dim, hidden_channels, bias=False) + + self.fc_viewdir = None + if viewdir_dim > 0: + self.fc_viewdir = nn.Linear(viewdir_dim, hidden_channels, bias=False) + + self.fc_1 = nn.Linear(in_channels, hidden_channels) + + self.fc_2 = ModLinear(hidden_channels, hidden_channels, style_dim, bias=False, mod_bias=True, output_mode=True) + self.fc_3 = ModLinear(hidden_channels, hidden_channels, style_dim, bias=False, mod_bias=True, output_mode=True) + self.fc_4 = ModLinear(hidden_channels, hidden_channels, style_dim, bias=False, mod_bias=True, output_mode=True) + + self.fc_sigma = nn.Linear(hidden_channels, out_channels_s) + + if viewdir_dim > 0: + self.fc_5 = nn.Linear(hidden_channels, hidden_channels, bias=False) + self.mod_5 = AffineMod(hidden_channels, style_dim, mod_bias=True) + else: + self.fc_5 = ModLinear(hidden_channels, hidden_channels, style_dim, + bias=False, mod_bias=True, output_mode=True) + self.fc_6 = ModLinear(hidden_channels, hidden_channels, style_dim, bias=False, 
mod_bias=True, output_mode=True) + self.fc_out_c = nn.Linear(hidden_channels, out_channels_c) + + self.act = nn.LeakyReLU(negative_slope=0.2) + + def forward(self, x, raydir, z, m): + r""" Forward network + + Args: + x (N x H x W x M x in_channels tensor): Projected features. + raydir (N x H x W x 1 x viewdir_dim tensor): Ray directions. + z (N x style_dim tensor): Style codes. + m (N x H x W x M x mask_dim tensor): One-hot segmentation maps. + """ + b, h, w, n, _ = x.size() + z = z[:, None, None, None, :] + + f = self.fc_1(x) + if self.use_seg: + f = f + self.fc_m_a(m) + # Common MLP + f = self.act(f) + f = self.act(self.fc_2(f, z)) + f = self.act(self.fc_3(f, z)) + f = self.act(self.fc_4(f, z)) + + # Sigma MLP + sigma = self.fc_sigma(f) + + # Color MLP + if self.fc_viewdir is not None: + f = self.fc_5(f) + f = f + self.fc_viewdir(raydir) + f = self.act(self.mod_5(f, z)) + else: + f = self.act(self.fc_5(f, z)) + f = self.act(self.fc_6(f, z)) + c = self.fc_out_c(f) + return sigma, c + + +class StyleMLP(nn.Module): + r"""MLP converting style code to intermediate style representation.""" + + def __init__(self, style_dim, out_dim, hidden_channels=256, leaky_relu=True, num_layers=5, normalize_input=True, + output_act=True): + super(StyleMLP, self).__init__() + + self.normalize_input = normalize_input + self.output_act = output_act + fc_layers = [] + fc_layers.append(nn.Linear(style_dim, hidden_channels, bias=True)) + for i in range(num_layers-1): + fc_layers.append(nn.Linear(hidden_channels, hidden_channels, bias=True)) + self.fc_layers = nn.ModuleList(fc_layers) + + self.fc_out = nn.Linear(hidden_channels, out_dim, bias=True) + + if leaky_relu: + self.act = nn.LeakyReLU(negative_slope=0.2, inplace=True) + else: + self.act = functools.partial(F.relu, inplace=True) + + def forward(self, z): + r""" Forward network + + Args: + z (N x style_dim tensor): Style codes. + """ + if self.normalize_input: + z = F.normalize(z, p=2, dim=-1) + for fc_layer in self.fc_layers: + z = self.act(fc_layer(z)) + z = self.fc_out(z) + if self.output_act: + z = self.act(z) + return z + + +class SKYMLP(nn.Module): + r"""MLP converting ray directions to sky features.""" + + def __init__(self, in_channels, style_dim, out_channels_c=3, + hidden_channels=256, leaky_relu=True): + super(SKYMLP, self).__init__() + self.fc_z_a = nn.Linear(style_dim, hidden_channels, bias=False) + + self.fc1 = nn.Linear(in_channels, hidden_channels) + self.fc2 = nn.Linear(hidden_channels, hidden_channels) + self.fc3 = nn.Linear(hidden_channels, hidden_channels) + self.fc4 = nn.Linear(hidden_channels, hidden_channels) + self.fc5 = nn.Linear(hidden_channels, hidden_channels) + + self.fc_out_c = nn.Linear(hidden_channels, out_channels_c) + + if leaky_relu: + self.act = nn.LeakyReLU(negative_slope=0.2, inplace=True) + else: + self.act = functools.partial(F.relu, inplace=True) + + def forward(self, x, z): + r"""Forward network + + Args: + x (... x in_channels tensor): Ray direction embeddings. + z (... x style_dim tensor): Style codes. 
+ """ + + z = self.fc_z_a(z) + while z.dim() < x.dim(): + z = z.unsqueeze(1) + + y = self.act(self.fc1(x) + z) + y = self.act(self.fc2(y)) + y = self.act(self.fc3(y)) + y = self.act(self.fc4(y)) + y = self.act(self.fc5(y)) + c = self.fc_out_c(y) + + return c + + +class RenderCNN(nn.Module): + r"""CNN converting intermediate feature map to final image.""" + + def __init__(self, in_channels, style_dim, hidden_channels=256, + leaky_relu=True): + super(RenderCNN, self).__init__() + self.fc_z_cond = nn.Linear(style_dim, 2 * 2 * hidden_channels) + + self.conv1 = nn.Conv2d(in_channels, hidden_channels, 1, stride=1, padding=0) + self.conv2a = nn.Conv2d(hidden_channels, hidden_channels, 3, stride=1, padding=1) + self.conv2b = nn.Conv2d(hidden_channels, hidden_channels, 3, stride=1, padding=1, bias=False) + + self.conv3a = nn.Conv2d(hidden_channels, hidden_channels, 3, stride=1, padding=1) + self.conv3b = nn.Conv2d(hidden_channels, hidden_channels, 3, stride=1, padding=1, bias=False) + + self.conv4a = nn.Conv2d(hidden_channels, hidden_channels, 1, stride=1, padding=0) + self.conv4b = nn.Conv2d(hidden_channels, hidden_channels, 1, stride=1, padding=0) + + self.conv4 = nn.Conv2d(hidden_channels, 3, 1, stride=1, padding=0) + + if leaky_relu: + self.act = nn.LeakyReLU(negative_slope=0.2, inplace=True) + else: + self.act = functools.partial(F.relu, inplace=True) + + def modulate(self, x, w, b): + w = w[..., None, None] + b = b[..., None, None] + return x * (w+1) + b + + def forward(self, x, z): + r"""Forward network. + + Args: + x (N x in_channels x H x W tensor): Intermediate feature map + z (N x style_dim tensor): Style codes. + """ + z = self.fc_z_cond(z) + adapt = torch.chunk(z, 2 * 2, dim=-1) + + y = self.act(self.conv1(x)) + + y = y + self.conv2b(self.act(self.conv2a(y))) + y = self.act(self.modulate(y, adapt[0], adapt[1])) + + y = y + self.conv3b(self.act(self.conv3a(y))) + y = self.act(self.modulate(y, adapt[2], adapt[3])) + + y = y + self.conv4b(self.act(self.conv4a(y))) + y = self.act(y) + + y = self.conv4(y) + + return y + + +class StyleEncoder(nn.Module): + r"""Style Encoder constructor. + + Args: + style_enc_cfg (obj): Style encoder definition file. 
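+            Expected fields: input_image_channels, num_filters, kernel_size,
+            style_dims, weight_norm_type, and the optional flag no_vae.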
+ """ + + def __init__(self, style_enc_cfg): + super(StyleEncoder, self).__init__() + input_image_channels = style_enc_cfg.input_image_channels + num_filters = style_enc_cfg.num_filters + kernel_size = style_enc_cfg.kernel_size + padding = int(np.ceil((kernel_size - 1.0) / 2)) + style_dims = style_enc_cfg.style_dims + weight_norm_type = style_enc_cfg.weight_norm_type + self.no_vae = getattr(style_enc_cfg, 'no_vae', False) + activation_norm_type = 'none' + nonlinearity = 'leakyrelu' + base_conv2d_block = \ + functools.partial(Conv2dBlock, + kernel_size=kernel_size, + stride=2, + padding=padding, + weight_norm_type=weight_norm_type, + activation_norm_type=activation_norm_type, + # inplace_nonlinearity=True, + nonlinearity=nonlinearity) + self.layer1 = base_conv2d_block(input_image_channels, num_filters) + self.layer2 = base_conv2d_block(num_filters * 1, num_filters * 2) + self.layer3 = base_conv2d_block(num_filters * 2, num_filters * 4) + self.layer4 = base_conv2d_block(num_filters * 4, num_filters * 8) + self.layer5 = base_conv2d_block(num_filters * 8, num_filters * 8) + self.layer6 = base_conv2d_block(num_filters * 8, num_filters * 8) + self.fc_mu = LinearBlock(num_filters * 8 * 4 * 4, style_dims) + if not self.no_vae: + self.fc_var = LinearBlock(num_filters * 8 * 4 * 4, style_dims) + + def forward(self, input_x): + r"""SPADE Style Encoder forward. + + Args: + input_x (N x 3 x H x W tensor): input images. + Returns: + mu (N x C tensor): Mean vectors. + logvar (N x C tensor): Log-variance vectors. + z (N x C tensor): Style code vectors. + """ + if input_x.size(2) != 256 or input_x.size(3) != 256: + input_x = F.interpolate(input_x, size=(256, 256), mode='bilinear') + x = self.layer1(input_x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + x = self.layer5(x) + x = self.layer6(x) + x = x.view(x.size(0), -1) + mu = self.fc_mu(x) + if not self.no_vae: + logvar = self.fc_var(x) + std = torch.exp(0.5 * logvar) + eps = torch.randn_like(std) + z = eps.mul(std) + mu + else: + z = mu + logvar = torch.zeros_like(mu) + return mu, logvar, z + + +class Base3DGenerator(nn.Module): + r"""Minecraft 3D generator constructor. + + Args: + gen_cfg (obj): Generator definition part of the yaml config file. + data_cfg (obj): Data definition part of the yaml config file. 
+ """ + + def __init__(self, gen_cfg, data_cfg): + super(Base3DGenerator, self).__init__() + print('Base3DGenerator initialization.') + + # ---------------------- Main Network ------------------------ + # Exclude some of the features from positional encoding + self.pe_no_pe_feat_dim = getattr(gen_cfg, 'pe_no_pe_feat_dim', 0) + + # blk_feat passes through PE + input_dim = (gen_cfg.blk_feat_dim-self.pe_no_pe_feat_dim)*(gen_cfg.pe_lvl_feat*2) + self.pe_no_pe_feat_dim + if (gen_cfg.pe_incl_orig_feat): + input_dim += (gen_cfg.blk_feat_dim-self.pe_no_pe_feat_dim) + print('[Base3DGenerator] Expected input dimensions: ', input_dim) + self.input_dim = input_dim + + self.mlp_model_kwargs = gen_cfg.mlp_model_kwargs + self.pe_lvl_localcoords = getattr(gen_cfg, 'pe_lvl_localcoords', 0) + if self.pe_lvl_localcoords > 0: + self.mlp_model_kwargs['poscode_dim'] = self.pe_lvl_localcoords * 2 * 3 + + # Set pe_lvl_raydir=0 and pe_incl_orig_raydir=False to disable view direction input + input_dim_viewdir = 3*(gen_cfg.pe_lvl_raydir*2) + if (gen_cfg.pe_incl_orig_raydir): + input_dim_viewdir += 3 + print('[Base3DGenerator] Expected viewdir input dimensions: ', input_dim_viewdir) + self.input_dim_viewdir = input_dim_viewdir + + self.pe_params = [gen_cfg.pe_lvl_feat, gen_cfg.pe_incl_orig_feat, + gen_cfg.pe_lvl_raydir, gen_cfg.pe_incl_orig_raydir] + + # Style input dimension + style_dims = gen_cfg.style_dims + self.style_dims = style_dims + interm_style_dims = getattr(gen_cfg, 'interm_style_dims', style_dims) + self.interm_style_dims = interm_style_dims + # ---------------------- Style MLP -------------------------- + self.style_net = globals()[gen_cfg.stylenet_model]( + style_dims, interm_style_dims, **gen_cfg.stylenet_model_kwargs) + + # number of output channels for MLP (before blending) + final_feat_dim = getattr(gen_cfg, 'final_feat_dim', 16) + self.final_feat_dim = final_feat_dim + + # ----------------------- Sky Network ------------------------- + sky_input_dim_base = 3 + # Dedicated sky network input dimensions + sky_input_dim = sky_input_dim_base*(gen_cfg.pe_lvl_raydir_sky*2) + if (gen_cfg.pe_incl_orig_raydir_sky): + sky_input_dim += sky_input_dim_base + print('[Base3DGenerator] Expected sky input dimensions: ', sky_input_dim) + self.pe_params_sky = [gen_cfg.pe_lvl_raydir_sky, gen_cfg.pe_incl_orig_raydir_sky] + self.sky_net = SKYMLP(sky_input_dim, style_dim=interm_style_dims, out_channels_c=final_feat_dim) + + # ----------------------- Style Encoder ------------------------- + style_enc_cfg = getattr(gen_cfg, 'style_enc', None) + setattr(style_enc_cfg, 'input_image_channels', 3) + setattr(style_enc_cfg, 'style_dims', gen_cfg.style_dims) + self.style_encoder = StyleEncoder(style_enc_cfg) + + # ---------------------- Ray Caster ------------------------- + self.num_blocks_early_stop = gen_cfg.num_blocks_early_stop + self.num_samples = gen_cfg.num_samples + self.sample_depth = gen_cfg.sample_depth + self.coarse_deterministic_sampling = getattr(gen_cfg, 'coarse_deterministic_sampling', True) + self.sample_use_box_boundaries = getattr(gen_cfg, 'sample_use_box_boundaries', True) + + # ---------------------- Blender ------------------------- + self.raw_noise_std = getattr(gen_cfg, 'raw_noise_std', 0.0) + self.dists_scale = getattr(gen_cfg, 'dists_scale', 0.25) + self.clip_feat_map = getattr(gen_cfg, 'clip_feat_map', True) + self.keep_sky_out = getattr(gen_cfg, 'keep_sky_out', False) + self.keep_sky_out_avgpool = getattr(gen_cfg, 'keep_sky_out_avgpool', False) + keep_sky_out_learnbg = getattr(gen_cfg, 
'keep_sky_out_learnbg', False) + self.sky_global_avgpool = getattr(gen_cfg, 'sky_global_avgpool', False) + if self.keep_sky_out: + self.sky_replace_color = None + if keep_sky_out_learnbg: + sky_replace_color = torch.zeros([final_feat_dim]) + sky_replace_color.requires_grad = True + self.sky_replace_color = torch.nn.Parameter(sky_replace_color) + # ---------------------- render_cnn ------------------------- + self.denoiser = RenderCNN(final_feat_dim, style_dim=interm_style_dims) + self.pad = gen_cfg.pad + + def get_param_groups(self, cfg_opt): + print('[Generator] get_param_groups') + + if hasattr(cfg_opt, 'ignore_parameters'): + print('[Generator::get_param_groups] [x]: ignored.') + optimize_parameters = [] + for k, x in self.named_parameters(): + match = False + for m in cfg_opt.ignore_parameters: + if re.match(m, k) is not None: + match = True + print(' [x]', k) + break + if match is False: + print(' [v]', k) + optimize_parameters.append(x) + else: + optimize_parameters = self.parameters() + + param_groups = [] + param_groups.append({'params': optimize_parameters}) + + if hasattr(cfg_opt, 'param_groups'): + optimized_param_names = [] + all_param_names = [k for k, v in self.named_parameters()] + param_groups = [] + for k, v in cfg_opt.param_groups.items(): + print('[Generator::get_param_groups] Adding param group from config:', k, v) + params = getattr(self, k) + named_parameters = [k] + if issubclass(type(params), nn.Module): + named_parameters = [k+'.'+pname for pname, _ in params.named_parameters()] + params = params.parameters() + param_groups.append({'params': params, **v}) + optimized_param_names.extend(named_parameters) + + print('[Generator::get_param_groups] UNOPTIMIZED PARAMETERS:\n ', + set(all_param_names) - set(optimized_param_names)) + + return param_groups + + def _forward_perpix_sub(self, blk_feats, worldcoord2, raydirs_in, z, mc_masks_onehot=None): + r"""Forwarding the MLP. + + Args: + blk_feats (K x C1 tensor): Sparse block features. + worldcoord2 (N x H x W x L x 3 tensor): 3D world coordinates of sampled points. + raydirs_in (N x H x W x 1 x C2 tensor or None): ray direction embeddings. + z (N x C3 tensor): Intermediate style vectors. + mc_masks_onehot (N x H x W x L x C4): One-hot segmentation maps. + Returns: + net_out_s (N x H x W x L x 1 tensor): Opacities. + net_out_c (N x H x W x L x C5 tensor): Color embeddings. 
+ """ + proj_feature = voxlib.sparse_trilinear_interp_worldcoord( + blk_feats, self.voxel.corner_t, worldcoord2, ign_zero=True) + + render_net_extra_kwargs = {} + if self.pe_lvl_localcoords > 0: + local_coords = torch.remainder(worldcoord2, 1.0) * 2.0 + # Scale to [0, 2], as the positional encoding function doesn't have internal x2 + local_coords[torch.isnan(local_coords)] = 0.0 + local_coords = local_coords.contiguous() + poscode = voxlib.positional_encoding(local_coords, self.pe_lvl_localcoords, -1, False) + render_net_extra_kwargs['poscode'] = poscode + + if self.pe_params[0] == 0 and self.pe_params[1] is True: # no PE shortcut, saves ~400MB + feature_in = proj_feature + else: + if self.pe_no_pe_feat_dim > 0: + feature_in = voxlib.positional_encoding( + proj_feature[..., :-self.pe_no_pe_feat_dim].contiguous(), self.pe_params[0], -1, self.pe_params[1]) + feature_in = torch.cat([feature_in, proj_feature[..., -self.pe_no_pe_feat_dim:]], dim=-1) + else: + feature_in = voxlib.positional_encoding( + proj_feature.contiguous(), self.pe_params[0], -1, self.pe_params[1]) + + net_out_s, net_out_c = self.render_net(feature_in, raydirs_in, z, mc_masks_onehot, **render_net_extra_kwargs) + + if self.raw_noise_std > 0.: + noise = torch.randn_like(net_out_s) * self.raw_noise_std + net_out_s = net_out_s + noise + + return net_out_s, net_out_c + + def _forward_perpix(self, blk_feats, voxel_id, depth2, raydirs, cam_ori_t, z): + r"""Sample points along rays, forwarding the per-point MLP and aggregate pixel features + + Args: + blk_feats (K x C1 tensor): Sparse block features. + voxel_id (N x H x W x M x 1 tensor): Voxel ids from ray-voxel intersection test. M: num intersected voxels + depth2 (N x 2 x H x W x M x 1 tensor): Depths of entrance and exit points for each ray-voxel intersection. + raydirs (N x H x W x 1 x 3 tensor): The direction of each ray. + cam_ori_t (N x 3 tensor): Camera origins. + z (N x C3 tensor): Intermediate style vectors. + """ + # Generate sky_mask; PE transform on ray direction. 
+ with torch.no_grad(): + raydirs_in = raydirs.expand(-1, -1, -1, 1, -1).contiguous() + if self.pe_params[2] == 0 and self.pe_params[3] is True: + raydirs_in = raydirs_in + elif self.pe_params[2] == 0 and self.pe_params[3] is False: # Not using raydir at all + raydirs_in = None + else: + raydirs_in = voxlib.positional_encoding(raydirs_in, self.pe_params[2], -1, self.pe_params[3]) + + # sky_mask: when True, ray finally hits sky + sky_mask = voxel_id[:, :, :, [-1], :] == 0 + # sky_only_mask: when True, ray hits nothing but sky + sky_only_mask = voxel_id[:, :, :, [0], :] == 0 + + with torch.no_grad(): + # Random sample points along the ray + num_samples = self.num_samples + 1 + if self.sample_use_box_boundaries: + num_samples = self.num_samples - self.num_blocks_early_stop + + # 10 samples per ray + 4 intersections - 2 + rand_depth, new_dists, new_idx = mc_utils.sample_depth_batched( + depth2, num_samples, deterministic=self.coarse_deterministic_sampling, + use_box_boundaries=self.sample_use_box_boundaries, sample_depth=self.sample_depth) + + worldcoord2 = raydirs * rand_depth + cam_ori_t[:, None, None, None, :] + + # Generate per-sample segmentation label + voxel_id_reduced = self.label_trans.mc2reduced(voxel_id, ign2dirt=True) + mc_masks = torch.gather(voxel_id_reduced, -2, new_idx) # B 256 256 N 1 + mc_masks = mc_masks.long() + mc_masks_onehot = torch.zeros([mc_masks.size(0), mc_masks.size(1), mc_masks.size( + 2), mc_masks.size(3), self.num_reduced_labels], dtype=torch.float, device=voxel_id.device) + # mc_masks_onehot: [B H W Nlayer 680] + mc_masks_onehot.scatter_(-1, mc_masks, 1.0) + + net_out_s, net_out_c = self._forward_perpix_sub(blk_feats, worldcoord2, raydirs_in, z, mc_masks_onehot) + + # Handle sky + sky_raydirs_in = raydirs.expand(-1, -1, -1, 1, -1).contiguous() + sky_raydirs_in = voxlib.positional_encoding(sky_raydirs_in, self.pe_params_sky[0], -1, self.pe_params_sky[1]) + skynet_out_c = self.sky_net(sky_raydirs_in, z) + + # Blending + weights = mc_utils.volum_rendering_relu(net_out_s, new_dists * self.dists_scale, dim=-2) + + # If a ray exclusively hits the sky (no intersection with the voxels), set its weight to zero. 
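`mc_utils.volum_rendering_relu` above turns the predicted opacities and per-sample segment lengths into compositing weights. Its implementation is not shown in this file; the sketch below assumes the standard volume-rendering formulation, in which the residual `1 - sum(weights)` is exactly the sky weight used a few lines further down:

```python
import torch
import torch.nn.functional as F

def volume_rendering_weights(sigma: torch.Tensor, dists: torch.Tensor, dim: int = -2) -> torch.Tensor:
    """Standard volume-rendering weights along the sample dimension `dim`.

    alpha_i = 1 - exp(-relu(sigma_i) * dist_i)
    w_i     = alpha_i * prod_{j<i} (1 - alpha_j)   # transmittance times alpha
    """
    alpha = 1.0 - torch.exp(-F.relu(sigma) * dists)
    # Exclusive cumulative product of (1 - alpha) gives the transmittance T_i.
    ones = torch.ones_like(alpha.narrow(dim, 0, 1))
    transmittance = torch.cumprod(torch.cat([ones, 1.0 - alpha], dim=dim), dim=dim)
    transmittance = transmittance.narrow(dim, 0, alpha.size(dim))
    return alpha * transmittance
```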
+ weights = weights * torch.logical_not(sky_only_mask).float() + total_weights_raw = torch.sum(weights, dim=-2, keepdim=True) # 256 256 1 1 + total_weights = total_weights_raw + + is_gnd = worldcoord2[..., [0]] <= 1.0 # Y X Z, [256, 256, 4, 3], nan < 1.0 == False + is_gnd = is_gnd.any(dim=-2, keepdim=True) + nosky_mask = torch.logical_or(torch.logical_not(sky_mask), is_gnd) + nosky_mask = nosky_mask.float() + + # Avoid sky leakage + sky_weight = 1.0-total_weights + if self.keep_sky_out: + # keep_sky_out_avgpool overrides sky_replace_color + if self.sky_replace_color is None or self.keep_sky_out_avgpool: + if self.keep_sky_out_avgpool: + if hasattr(self, 'sky_avg'): + sky_avg = self.sky_avg + else: + if self.sky_global_avgpool: + sky_avg = torch.mean(skynet_out_c, dim=[1, 2], keepdim=True) + else: + skynet_out_c_nchw = skynet_out_c.permute(0, 4, 1, 2, 3).squeeze(-1) + sky_avg = F.avg_pool2d(skynet_out_c_nchw, 31, stride=1, padding=15, count_include_pad=False) + sky_avg = sky_avg.permute(0, 2, 3, 1).unsqueeze(-2) + # print(sky_avg.shape) + skynet_out_c = skynet_out_c * (1.0-nosky_mask) + sky_avg*(nosky_mask) + else: + sky_weight = sky_weight * (1.0-nosky_mask) + else: + skynet_out_c = skynet_out_c * (1.0-nosky_mask) + self.sky_replace_color*(nosky_mask) + + if self.clip_feat_map is True: # intermediate feature before blending & CNN + rgbs = torch.clamp(net_out_c, -1, 1) + 1 + rgbs_sky = torch.clamp(skynet_out_c, -1, 1) + 1 + net_out = torch.sum(weights*rgbs, dim=-2, keepdim=True) + sky_weight * \ + rgbs_sky # 576, 768, 4, 3 -> 576, 768, 3 + net_out = net_out.squeeze(-2) + net_out = net_out - 1 + elif self.clip_feat_map is False: + rgbs = net_out_c + rgbs_sky = skynet_out_c + net_out = torch.sum(weights*rgbs, dim=-2, keepdim=True) + sky_weight * \ + rgbs_sky # 576, 768, 4, 3 -> 576, 768, 3 + net_out = net_out.squeeze(-2) + elif self.clip_feat_map == 'tanh': + rgbs = torch.tanh(net_out_c) + rgbs_sky = torch.tanh(skynet_out_c) + net_out = torch.sum(weights*rgbs, dim=-2, keepdim=True) + sky_weight * \ + rgbs_sky # 576, 768, 4, 3 -> 576, 768, 3 + net_out = net_out.squeeze(-2) + else: + raise NotImplementedError + + return net_out, new_dists, weights, total_weights_raw, rand_depth, net_out_s, net_out_c, skynet_out_c, \ + nosky_mask, sky_mask, sky_only_mask, new_idx + + def _forward_global(self, net_out, z): + r"""Forward the CNN + + Args: + net_out (N x C5 x H x W tensor): Intermediate feature maps. + z (N x C3 tensor): Intermediate style vectors. + + Returns: + fake_images (N x 3 x H x W tensor): Output image. + fake_images_raw (N x 3 x H x W tensor): Output image before TanH. + """ + fake_images = net_out.permute(0, 3, 1, 2) + fake_images_raw = self.denoiser(fake_images, z) + fake_images = torch.tanh(fake_images_raw) + + return fake_images, fake_images_raw diff --git a/imaginaire/generators/munit.py b/imaginaire/generators/munit.py new file mode 100644 index 0000000000000000000000000000000000000000..55cb066bb51e6f8e96d623d2224af06af017c086 --- /dev/null +++ b/imaginaire/generators/munit.py @@ -0,0 +1,465 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. 
+# To view a copy of this license, check out LICENSE.md +import warnings +from types import SimpleNamespace + +import torch +from torch import nn +from torch.nn import Upsample as NearestUpsample + +from imaginaire.layers import Conv2dBlock, LinearBlock, Res2dBlock +from imaginaire.generators.unit import ContentEncoder + + +class Generator(nn.Module): + r"""Improved MUNIT generator. + + Args: + gen_cfg (obj): Generator definition part of the yaml config file. + data_cfg (obj): Data definition part of the yaml config file. + """ + + def __init__(self, gen_cfg, data_cfg): + super().__init__() + self.autoencoder_a = AutoEncoder(**vars(gen_cfg)) + self.autoencoder_b = AutoEncoder(**vars(gen_cfg)) + + def forward(self, data, random_style=True, image_recon=True, + latent_recon=True, cycle_recon=True, within_latent_recon=False): + r"""In MUNIT's forward pass, it generates a content code and a style + code from images in both domain. It then performs a within-domain + reconstruction step and a cross-domain translation step. + In within-domain reconstruction, it reconstructs an image using the + content and style from the same image and optionally encodes the image + back to the latent space. + In cross-domain translation, it generates an translated image by mixing + the content and style from images in different domains, and optionally + encodes the image back to the latent space. + + Args: + data (dict): Training data at the current iteration. + - images_a (tensor): Images from domain A. + - images_b (tensor): Images from domain B. + random_style (bool): If ``True``, samples the style code from the + prior distribution, otherwise uses the style code encoded from + the input images in the other domain. + image_recon (bool): If ``True``, also returns reconstructed images. + latent_recon (bool): If ``True``, also returns reconstructed latent + code during cross-domain translation. + cycle_recon (bool): If ``True``, also returns cycle + reconstructed images. + within_latent_recon (bool): If ``True``, also returns reconstructed + latent code during within-domain reconstruction. 
+ """ + + images_a = data['images_a'] + images_b = data['images_b'] + net_G_output = dict() + + # encode input images into content and style code + content_a, style_a = self.autoencoder_a.encode(images_a) + content_b, style_b = self.autoencoder_b.encode(images_b) + + # decode (within domain) + if image_recon: + images_aa = self.autoencoder_a.decode(content_a, style_a) + images_bb = self.autoencoder_b.decode(content_b, style_b) + net_G_output.update(dict(images_aa=images_aa, images_bb=images_bb)) + + # decode (cross domain) + if random_style: # use randomly sampled style code + style_a_rand = torch.randn_like(style_a) + style_b_rand = torch.randn_like(style_b) + else: # use style code encoded from the other domain + style_a_rand = style_a + style_b_rand = style_b + images_ba = self.autoencoder_a.decode(content_b, style_a_rand) + images_ab = self.autoencoder_b.decode(content_a, style_b_rand) + + # encode translated images into content and style code + if latent_recon or cycle_recon: + content_ba, style_ba = self.autoencoder_a.encode(images_ba) + content_ab, style_ab = self.autoencoder_b.encode(images_ab) + net_G_output.update(dict(content_ba=content_ba, style_ba=style_ba, + content_ab=content_ab, style_ab=style_ab)) + + # encode reconstructed images into content and style code + if image_recon and within_latent_recon: + content_aa, style_aa = self.autoencoder_a.encode(images_aa) + content_bb, style_bb = self.autoencoder_b.encode(images_bb) + net_G_output.update(dict(content_aa=content_aa, style_aa=style_aa, + content_bb=content_bb, style_bb=style_bb)) + + # cycle reconstruction + if cycle_recon: + images_aba = self.autoencoder_a.decode(content_ab, style_a) + images_bab = self.autoencoder_b.decode(content_ba, style_b) + net_G_output.update( + dict(images_aba=images_aba, images_bab=images_bab)) + + # required outputs + net_G_output.update(dict(content_a=content_a, content_b=content_b, + style_a=style_a, style_b=style_b, + style_a_rand=style_a_rand, + style_b_rand=style_b_rand, + images_ba=images_ba, images_ab=images_ab)) + + return net_G_output + + def inference(self, data, a2b=True, random_style=True): + r"""MUNIT inference. + + Args: + data (dict): Training data at the current iteration. + - images_a (tensor): Images from domain A. + - images_b (tensor): Images from domain B. + a2b (bool): If ``True``, translates images from domain A to B, + otherwise from B to A. + random_style (bool): If ``True``, samples the style code from the + prior distribution, otherwise uses the style code encoded from + the input images in the other domain. 
+ """ + if a2b: + input_key = 'images_a' + content_encode = self.autoencoder_a.content_encoder + style_encode = self.autoencoder_b.style_encoder + decode = self.autoencoder_b.decode + else: + input_key = 'images_b' + content_encode = self.autoencoder_b.content_encoder + style_encode = self.autoencoder_a.style_encoder + decode = self.autoencoder_a.decode + + content_images = data[input_key] + content = content_encode(content_images) + if random_style: + style_channels = self.autoencoder_a.style_channels + style = torch.randn(content.size(0), style_channels, 1, 1, + device=torch.device('cuda')) + file_names = data['key'][input_key]['filename'] + else: + style_key = 'images_b' if a2b else 'images_a' + assert style_key in data.keys(), \ + "{} must be provided when 'random_style' " \ + "is set to False".format(style_key) + style_images = data[style_key] + style = style_encode(style_images) + file_names = \ + [content_name + '_style_' + style_name + for content_name, style_name in + zip(data['key'][input_key]['filename'], + data['key'][style_key]['filename'])] + + output_images = decode(content, style) + return output_images, file_names + + +class AutoEncoder(nn.Module): + r"""Improved MUNIT autoencoder. + + Args: + num_filters (int): Base filter numbers. + max_num_filters (int): Maximum number of filters in the encoder. + num_filters_mlp (int): Base filter number in the MLP module. + latent_dim (int): Dimension of the style code. + num_res_blocks (int): Number of residual blocks at the end of the + content encoder. + num_mlp_blocks (int): Number of layers in the MLP module. + num_downsamples_style (int): Number of times we reduce + resolution by 2x2 for the style image. + num_downsamples_content (int): Number of times we reduce + resolution by 2x2 for the content image. + num_image_channels (int): Number of input image channels. + content_norm_type (str): Type of activation normalization in the + content encoder. + style_norm_type (str): Type of activation normalization in the + style encoder. + decoder_norm_type (str): Type of activation normalization in the + decoder. + weight_norm_type (str): Type of weight normalization. + decoder_norm_params (obj): Parameters of activation normalization in the + decoder. If not ``None``, decoder_norm_params.__dict__ will be used + as keyword arguments when initializing activation normalization. + output_nonlinearity (str): Type of nonlinearity before final output, + ``'tanh'`` or ``'none'``. + pre_act (bool): If ``True``, uses pre-activation residual blocks. + apply_noise (bool): If ``True``, injects Gaussian noise in the decoder. 
+ """ + + def __init__(self, + num_filters=64, + max_num_filters=256, + num_filters_mlp=256, + latent_dim=8, + num_res_blocks=4, + num_mlp_blocks=2, + num_downsamples_style=4, + num_downsamples_content=2, + num_image_channels=3, + content_norm_type='instance', + style_norm_type='', + decoder_norm_type='instance', + weight_norm_type='', + decoder_norm_params=SimpleNamespace(affine=False), + output_nonlinearity='', + pre_act=False, + apply_noise=False, + **kwargs): + super().__init__() + for key in kwargs: + if key != 'type': + warnings.warn( + "Generator argument '{}' is not used.".format(key)) + self.style_encoder = StyleEncoder(num_downsamples_style, + num_image_channels, + num_filters, + latent_dim, + 'reflect', + style_norm_type, + weight_norm_type, + 'relu') + self.content_encoder = ContentEncoder(num_downsamples_content, + num_res_blocks, + num_image_channels, + num_filters, + max_num_filters, + 'reflect', + content_norm_type, + weight_norm_type, + 'relu', + pre_act) + self.decoder = Decoder(num_downsamples_content, + num_res_blocks, + self.content_encoder.output_dim, + num_image_channels, + num_filters_mlp, + 'reflect', + decoder_norm_type, + decoder_norm_params, + weight_norm_type, + 'relu', + output_nonlinearity, + pre_act, + apply_noise) + self.mlp = MLP(latent_dim, + num_filters_mlp, + num_filters_mlp, + num_mlp_blocks, + 'none', + 'relu') + self.style_channels = latent_dim + + def forward(self, images): + r"""Reconstruct an image. + + Args: + images (Tensor): Input images. + Returns: + images_recon (Tensor): Reconstructed images. + """ + content, style = self.encode(images) + images_recon = self.decode(content, style) + return images_recon + + def encode(self, images): + r"""Encode an image to content and style code. + + Args: + images (Tensor): Input images. + Returns: + (tuple): + - content (Tensor): Content code. + - style (Tensor): Style code. + """ + style = self.style_encoder(images) + content = self.content_encoder(images) + return content, style + + def decode(self, content, style): + r"""Decode content and style code to an image. + + Args: + content (Tensor): Content code. + style (Tensor): Style code. + Returns: + images (Tensor): Output images. + """ + style = self.mlp(style) + images = self.decoder(content, style) + return images + + +class StyleEncoder(nn.Module): + r"""MUNIT style encoder. + + Args: + num_downsamples (int): Number of times we reduce + resolution by 2x2. + num_image_channels (int): Number of input image channels. + num_filters (int): Base filter numbers. + style_channels (int): Dimension of the style code. + padding_mode (string): Type of padding. + activation_norm_type (str): Type of activation normalization. + weight_norm_type (str): Type of weight normalization. + nonlinearity (str): Type of nonlinear activation function. 
+ """ + + def __init__(self, num_downsamples, num_image_channels, num_filters, + style_channels, padding_mode, activation_norm_type, + weight_norm_type, nonlinearity): + super().__init__() + conv_params = dict(padding_mode=padding_mode, + activation_norm_type=activation_norm_type, + weight_norm_type=weight_norm_type, + nonlinearity=nonlinearity, + inplace_nonlinearity=True) + model = [] + model += [Conv2dBlock(num_image_channels, num_filters, 7, 1, 3, + **conv_params)] + for i in range(2): + model += [Conv2dBlock(num_filters, 2 * num_filters, 4, 2, 1, + **conv_params)] + num_filters *= 2 + for i in range(num_downsamples - 2): + model += [Conv2dBlock(num_filters, num_filters, 4, 2, 1, + **conv_params)] + model += [nn.AdaptiveAvgPool2d(1)] + model += [nn.Conv2d(num_filters, style_channels, 1, 1, 0)] + self.model = nn.Sequential(*model) + self.output_dim = num_filters + + def forward(self, x): + r""" + + Args: + x (tensor): Input image. + """ + return self.model(x) + + +class Decoder(nn.Module): + r"""Improved MUNIT decoder. The network consists of + + - $(num_res_blocks) residual blocks. + - $(num_upsamples) residual blocks or convolutional blocks + - output layer. + + Args: + num_upsamples (int): Number of times we increase resolution by 2x2. + num_res_blocks (int): Number of residual blocks. + num_filters (int): Base filter numbers. + num_image_channels (int): Number of input image channels. + style_channels (int): Dimension of the style code. + padding_mode (string): Type of padding. + activation_norm_type (str): Type of activation normalization. + activation_norm_params (obj): Parameters of activation normalization. + If not ``None``, decoder_norm_params.__dict__ will be used + as keyword arguments when initializing activation normalization. + weight_norm_type (str): Type of weight normalization. + nonlinearity (str): Type of nonlinear activation function. + output_nonlinearity (str): Type of nonlinearity before final output, + ``'tanh'`` or ``'none'``. + pre_act (bool): If ``True``, uses pre-activation residual blocks. + apply_noise (bool): If ``True``, injects Gaussian noise. + """ + + def __init__(self, + num_upsamples, + num_res_blocks, + num_filters, + num_image_channels, + style_channels, + padding_mode, + activation_norm_type, + activation_norm_params, + weight_norm_type, + nonlinearity, + output_nonlinearity, + pre_act=False, + apply_noise=False): + super().__init__() + adain_params = SimpleNamespace( + activation_norm_type=activation_norm_type, + activation_norm_params=activation_norm_params, + cond_dims=style_channels) + conv_params = dict(padding_mode=padding_mode, + nonlinearity=nonlinearity, + inplace_nonlinearity=True, + apply_noise=apply_noise, + weight_norm_type=weight_norm_type, + activation_norm_type='adaptive', + activation_norm_params=adain_params) + + # The order of operations in residual blocks. + order = 'pre_act' if pre_act else 'CNACNA' + + # Residual blocks with AdaIN. + self.decoder = nn.ModuleList() + for _ in range(num_res_blocks): + self.decoder += [Res2dBlock(num_filters, num_filters, + **conv_params, + order=order)] + + # Convolutional blocks with upsampling. 
+ for i in range(num_upsamples): + self.decoder += [NearestUpsample(scale_factor=2)] + self.decoder += [Conv2dBlock(num_filters, num_filters // 2, + 5, 1, 2, **conv_params)] + num_filters //= 2 + self.decoder += [Conv2dBlock(num_filters, num_image_channels, 7, 1, 3, + nonlinearity=output_nonlinearity, + padding_mode=padding_mode)] + + def forward(self, x, style): + r""" + + Args: + x (tensor): Content embedding of the content image. + style (tensor): Style embedding of the style image. + """ + for block in self.decoder: + if getattr(block, 'conditional', False): + x = block(x, style) + else: + x = block(x) + return x + + +class MLP(nn.Module): + r"""The multi-layer perceptron (MLP) that maps Gaussian style code to a + feature vector that is given as the conditional input to AdaIN. + + Args: + input_dim (int): Number of channels in the input tensor. + output_dim (int): Number of channels in the output tensor. + latent_dim (int): Number of channels in the latent features. + num_layers (int): Number of layers in the MLP. + norm (str): Type of activation normalization. + nonlinearity (str): Type of nonlinear activation function. + """ + + def __init__(self, input_dim, output_dim, latent_dim, num_layers, + norm, nonlinearity): + super().__init__() + model = [] + model += [LinearBlock(input_dim, latent_dim, + activation_norm_type=norm, + nonlinearity=nonlinearity)] + for i in range(num_layers - 2): + model += [LinearBlock(latent_dim, latent_dim, + activation_norm_type=norm, + nonlinearity=nonlinearity)] + model += [LinearBlock(latent_dim, output_dim, + activation_norm_type=norm, + nonlinearity=nonlinearity)] + self.model = nn.Sequential(*model) + + def forward(self, x): + r""" + + Args: + x (tensor): Input image. + """ + return self.model(x.view(x.size(0), -1)) diff --git a/imaginaire/generators/pix2pixHD.py b/imaginaire/generators/pix2pixHD.py new file mode 100644 index 0000000000000000000000000000000000000000..cd2e36b31b2b045594d3dd1d7db7cb4ee336d6f8 --- /dev/null +++ b/imaginaire/generators/pix2pixHD.py @@ -0,0 +1,348 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +from functools import partial + +import numpy as np +import torch +import torch.nn as nn +from torch.nn import Upsample as NearestUpsample + +from imaginaire.layers import Conv2dBlock, Res2dBlock +from imaginaire.utils.data import (get_paired_input_image_channel_number, + get_paired_input_label_channel_number) +from imaginaire.utils.distributed import master_only_print as print + + +class Generator(nn.Module): + r"""Pix2pixHD coarse-to-fine generator constructor. + + Args: + gen_cfg (obj): Generator definition part of the yaml config file. + data_cfg (obj): Data definition part of the yaml config file. + """ + + def __init__(self, gen_cfg, data_cfg): + super().__init__() + # pix2pixHD has a global generator. + global_gen_cfg = gen_cfg.global_generator + num_filters_global = getattr(global_gen_cfg, 'num_filters', 64) + # Optionally, it can have several local enhancers. They are useful + # for generating high resolution images. + local_gen_cfg = gen_cfg.local_enhancer + self.num_local_enhancers = num_local_enhancers = \ + getattr(local_gen_cfg, 'num_enhancers', 1) + # By default, pix2pixHD using instance normalization. 
+ activation_norm_type = getattr(gen_cfg, 'activation_norm_type', + 'instance') + activation_norm_params = getattr(gen_cfg, 'activation_norm_params', + None) + weight_norm_type = getattr(gen_cfg, 'weight_norm_type', '') + padding_mode = getattr(gen_cfg, 'padding_mode', 'reflect') + base_conv_block = partial(Conv2dBlock, + padding_mode=padding_mode, + weight_norm_type=weight_norm_type, + activation_norm_type=activation_norm_type, + activation_norm_params=activation_norm_params, + nonlinearity='relu') + base_res_block = partial(Res2dBlock, + padding_mode=padding_mode, + weight_norm_type=weight_norm_type, + activation_norm_type=activation_norm_type, + activation_norm_params=activation_norm_params, + nonlinearity='relu', order='CNACN') + # Know what is the number of available segmentation labels. + num_input_channels = get_paired_input_label_channel_number(data_cfg) + self.concat_features = False + # Check whether label input contains specific type of data (e.g. + # instance_maps). + self.contain_instance_map = False + if data_cfg.input_labels[-1] == 'instance_maps': + self.contain_instance_map = True + # The feature encoder is only useful when the instance map is provided. + if hasattr(gen_cfg, 'enc') and self.contain_instance_map: + num_feat_channels = getattr(gen_cfg.enc, 'num_feat_channels', 0) + if num_feat_channels > 0: + num_input_channels += num_feat_channels + self.concat_features = True + self.encoder = Encoder(gen_cfg.enc, data_cfg) + + # Global generator model. + global_model = GlobalGenerator(global_gen_cfg, data_cfg, + num_input_channels, padding_mode, + base_conv_block, base_res_block) + if num_local_enhancers == 0: + self.global_model = global_model + else: + # Get rid of the last layer. + global_model = global_model.model + global_model = [global_model[i] + for i in range(len(global_model) - 1)] + # global_model = [global_model[i] + # for i in range(len(global_model) - 2)] + self.global_model = nn.Sequential(*global_model) + + # Local enhancer model. + for n in range(num_local_enhancers): + # num_filters = num_filters_global // (2 ** n) + num_filters = num_filters_global // (2 ** (n + 1)) + output_img = (n == num_local_enhancers - 1) + setattr(self, 'enhancer_%d' % n, + LocalEnhancer(local_gen_cfg, data_cfg, + num_input_channels, num_filters, + padding_mode, base_conv_block, + base_res_block, output_img)) + + self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], + count_include_pad=False) + + def forward(self, data, random_style=False): + r"""Coarse-to-fine generator forward. + + Args: + data (dict) : Dictionary of input data. + random_style (bool): Always set to false for the pix2pixHD model. + Returns: + output (dict) : Dictionary of output data. + """ + label = data['label'] + + output = dict() + if self.concat_features: + features = self.encoder(data['images'], data['instance_maps']) + label = torch.cat([label, features], dim=1) + output['feature_maps'] = features + + # Create input pyramid. + input_downsampled = [label] + for i in range(self.num_local_enhancers): + input_downsampled.append(self.downsample(input_downsampled[-1])) + + # Output at coarsest level. + x = self.global_model(input_downsampled[-1]) + + # Coarse-to-fine: build up one layer at a time. 
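The enhancer loop that follows implements the coarse-to-fine composition: an input pyramid is built by repeated average-pool downsampling, the global generator runs on the coarsest level, and each local enhancer refines the result using the next-finer input. A condensed sketch of that control flow with the module wiring abstracted away (names here are illustrative):

```python
import torch
import torch.nn as nn

def coarse_to_fine(label: torch.Tensor, global_net: nn.Module,
                   enhancers: list, downsample: nn.Module) -> torch.Tensor:
    """Schematic pix2pixHD-style coarse-to-fine generation."""
    # Input pyramid: pyramid[0] is full resolution, pyramid[-1] is coarsest.
    pyramid = [label]
    for _ in range(len(enhancers)):
        pyramid.append(downsample(pyramid[-1]))

    x = global_net(pyramid[-1])               # output at the coarsest level
    for n, enhancer in enumerate(enhancers):
        fine_input = pyramid[len(enhancers) - n - 1]
        x = enhancer(x, fine_input)           # refine with the next-finer input
    return x
```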
+ for n in range(self.num_local_enhancers): + input_n = input_downsampled[self.num_local_enhancers - n - 1] + enhancer = getattr(self, 'enhancer_%d' % n) + x = enhancer(x, input_n) + + output['fake_images'] = x + return output + + def load_pretrained_network(self, pretrained_dict): + r"""Load a pretrained network.""" + # print(pretrained_dict.keys()) + model_dict = self.state_dict() + print('Pretrained network has fewer layers; The following are ' + 'not initialized:') + + not_initialized = set() + for k, v in model_dict.items(): + kp = 'module.' + k.replace('global_model.', 'global_model.model.') + if kp in pretrained_dict and v.size() == pretrained_dict[kp].size(): + model_dict[k] = pretrained_dict[kp] + else: + not_initialized.add('.'.join(k.split('.')[:2])) + print(sorted(not_initialized)) + self.load_state_dict(model_dict) + + def inference(self, data, **kwargs): + r"""Generator inference. + + Args: + data (dict) : Dictionary of input data. + Returns: + fake_images (tensor): Output fake images. + file_names (str): Data file name. + """ + output = self.forward(data, **kwargs) + return output['fake_images'], data['key']['seg_maps'][0] + + +class LocalEnhancer(nn.Module): + r"""Local enhancer constructor. These are sub-networks that are useful + when aiming to produce high-resolution outputs. + + Args: + gen_cfg (obj): local generator definition part of the yaml config + file. + data_cfg (obj): Data definition part of the yaml config file. + num_input_channels (int): Number of segmentation labels. + num_filters (int): Number of filters for the first layer. + padding_mode (str): zero | reflect | ... + base_conv_block (obj): Conv block with preset attributes. + base_res_block (obj): Residual block with preset attributes. + output_img (bool): Output is image or feature map. + """ + + def __init__(self, gen_cfg, data_cfg, num_input_channels, num_filters, + padding_mode, base_conv_block, base_res_block, + output_img=False): + super(LocalEnhancer, self).__init__() + num_res_blocks = getattr(gen_cfg, 'num_res_blocks', 3) + num_img_channels = get_paired_input_image_channel_number(data_cfg) + # Downsample. + model_downsample = \ + [base_conv_block(num_input_channels, num_filters, 7, padding=3), + base_conv_block(num_filters, num_filters * 2, 3, stride=2, + padding=1)] + # Residual blocks. + model_upsample = [] + for i in range(num_res_blocks): + model_upsample += [base_res_block(num_filters * 2, num_filters * 2, + 3, padding=1)] + # Upsample. + model_upsample += \ + [NearestUpsample(scale_factor=2), + base_conv_block(num_filters * 2, num_filters, 3, padding=1)] + + # Final convolution. + if output_img: + model_upsample += [Conv2dBlock(num_filters, num_img_channels, 7, + padding=3, padding_mode=padding_mode, + nonlinearity='tanh')] + + self.model_downsample = nn.Sequential(*model_downsample) + self.model_upsample = nn.Sequential(*model_upsample) + + def forward(self, output_coarse, input_fine): + r"""Local enhancer forward. + + Args: + output_coarse (4D tensor) : Coarse output from previous layer. + input_fine (4D tensor) : Fine input from current layer. + Returns: + output (4D tensor) : Refined output. + """ + output = self.model_upsample(self.model_downsample(input_fine) + output_coarse) + return output + + +class GlobalGenerator(nn.Module): + r"""Coarse generator constructor. This is the main generator in the + pix2pixHD architecture. + + Args: + gen_cfg (obj): Generator definition part of the yaml config file. + data_cfg (obj): Data definition part of the yaml config file. 
+ num_input_channels (int): Number of segmentation labels. + padding_mode (str): zero | reflect | ... + base_conv_block (obj): Conv block with preset attributes. + base_res_block (obj): Residual block with preset attributes. + """ + + def __init__(self, gen_cfg, data_cfg, num_input_channels, padding_mode, + base_conv_block, base_res_block): + super(GlobalGenerator, self).__init__() + num_img_channels = get_paired_input_image_channel_number(data_cfg) + num_filters = getattr(gen_cfg, 'num_filters', 64) + num_downsamples = getattr(gen_cfg, 'num_downsamples', 4) + num_res_blocks = getattr(gen_cfg, 'num_res_blocks', 9) + # First layer. + model = [base_conv_block(num_input_channels, num_filters, + kernel_size=7, padding=3)] + # Downsample. + for i in range(num_downsamples): + ch = num_filters * (2 ** i) + model += [base_conv_block(ch, ch * 2, 3, padding=1, stride=2)] + # ResNet blocks. + ch = num_filters * (2 ** num_downsamples) + for i in range(num_res_blocks): + model += [base_res_block(ch, ch, 3, padding=1)] + # Upsample. + num_upsamples = num_downsamples + for i in reversed(range(num_upsamples)): + ch = num_filters * (2 ** i) + model += \ + [NearestUpsample(scale_factor=2), + base_conv_block(ch * 2, ch, 3, padding=1)] + model += [Conv2dBlock(num_filters, num_img_channels, 7, padding=3, + padding_mode=padding_mode, nonlinearity='tanh')] + self.model = nn.Sequential(*model) + + def forward(self, input): + r"""Coarse-to-fine generator forward. + + Args: + input (4D tensor) : Input semantic representations. + Returns: + output (4D tensor) : Synthesized image by generator. + """ + return self.model(input) + + +class Encoder(nn.Module): + r"""Encoder for getting region-wise features for style control. + + Args: + enc_cfg (obj): Encoder definition part of the yaml config file. + data_cfg (obj): Data definition part of the yaml config file + """ + + def __init__(self, enc_cfg, data_cfg): + super(Encoder, self).__init__() + label_nc = get_paired_input_label_channel_number(data_cfg) + feat_nc = enc_cfg.num_feat_channels + n_clusters = getattr(enc_cfg, 'num_clusters', 10) + for i in range(label_nc): + dummy_arr = np.zeros((n_clusters, feat_nc), dtype=np.float32) + self.register_buffer('cluster_%d' % i, + torch.tensor(dummy_arr, dtype=torch.float32)) + num_img_channels = get_paired_input_image_channel_number(data_cfg) + self.num_feat_channels = getattr(enc_cfg, 'num_feat_channels', 3) + num_filters = getattr(enc_cfg, 'num_filters', 64) + num_downsamples = getattr(enc_cfg, 'num_downsamples', 4) + weight_norm_type = getattr(enc_cfg, 'weight_norm_type', 'none') + activation_norm_type = getattr(enc_cfg, 'activation_norm_type', + 'instance') + padding_mode = getattr(enc_cfg, 'padding_mode', 'reflect') + base_conv_block = partial(Conv2dBlock, + padding_mode=padding_mode, + weight_norm_type=weight_norm_type, + activation_norm_type=activation_norm_type, + nonlinearity='relu') + model = [base_conv_block(num_img_channels, num_filters, 7, padding=3)] + # Downsample. + for i in range(num_downsamples): + ch = num_filters * (2**i) + model += [base_conv_block(ch, ch * 2, 3, stride=2, padding=1)] + # Upsample. 
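The `Encoder.forward` defined a little further below performs instance-wise average pooling: every pixel's feature is replaced by the mean feature of the instance it belongs to, which is what makes the maps usable for region-wise style control. A compact sketch of the same operation (the real version indexes feature channels explicitly and may differ in layout):

```python
import torch

def instance_average_pool(feats: torch.Tensor, inst_map: torch.Tensor) -> torch.Tensor:
    """Replace each pixel's feature with the mean feature of its instance.

    feats:    N x C x H x W feature maps.
    inst_map: N x 1 x H x W instance labels.
    """
    out = torch.zeros_like(feats)
    for b in range(feats.size(0)):
        for i in torch.unique(inst_map[b]):
            mask = (inst_map[b, 0] == i)          # H x W boolean mask
            region = feats[b][:, mask]            # C x P features of this instance
            out[b][:, mask] = region.mean(dim=1, keepdim=True)
    return out
```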
+ for i in reversed(range(num_downsamples)): + ch = num_filters * (2 ** i) + model += [NearestUpsample(scale_factor=2), + base_conv_block(ch * 2, ch, 3, padding=1)] + + model += [Conv2dBlock(num_filters, self.num_feat_channels, 7, + padding=3, padding_mode=padding_mode, + nonlinearity='tanh')] + self.model = nn.Sequential(*model) + + def forward(self, input, instance_map): + r"""Extracting region-wise features + + Args: + input (4D tensor): Real RGB images. + instance_map (4D tensor): Instance label mask. + Returns: + outputs_mean (4D tensor): Instance-wise average-pooled + feature maps. + """ + outputs = self.model(input) + # Instance-wise average pooling. + outputs_mean = torch.zeros_like(outputs) + # Find all the unique labels in this batch. + inst_list = np.unique(instance_map.cpu().numpy().astype(int)) + for i in inst_list: + for b in range(input.size(0)): + # Find the pixels in this instance map have this instance label. + indices = (instance_map[b:b+1] == int(i)).nonzero() # n x 4 + # Scan through the feature channels. + for j in range(self.num_feat_channels): + output_ins = outputs[indices[:, 0] + b, indices[:, 1] + j, + indices[:, 2], indices[:, 3]] + mean_feat = torch.mean(output_ins).expand_as(output_ins) + outputs_mean[indices[:, 0] + b, indices[:, 1] + j, + indices[:, 2], indices[:, 3]] = mean_feat + return outputs_mean diff --git a/imaginaire/generators/spade.py b/imaginaire/generators/spade.py new file mode 100644 index 0000000000000000000000000000000000000000..dc69630304ccb2ce3fab707ca2e7de5f7aeec55a --- /dev/null +++ b/imaginaire/generators/spade.py @@ -0,0 +1,571 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +import functools +import math +import types + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import Upsample as NearestUpsample + +from imaginaire.layers import Conv2dBlock, LinearBlock, Res2dBlock +from imaginaire.utils.data import (get_crop_h_w, + get_paired_input_image_channel_number, + get_paired_input_label_channel_number) +from imaginaire.utils.distributed import master_only_print as print + + +class Generator(nn.Module): + r"""SPADE generator constructor. + + Args: + gen_cfg (obj): Generator definition part of the yaml config file. + data_cfg (obj): Data definition part of the yaml config file. + """ + + def __init__(self, gen_cfg, data_cfg): + super(Generator, self).__init__() + print('SPADE generator initialization.') + # We assume the first datum is the ground truth image. + image_channels = getattr(gen_cfg, 'image_channels', None) + if image_channels is None: + image_channels = get_paired_input_image_channel_number(data_cfg) + num_labels = getattr(gen_cfg, 'num_labels', None) + if num_labels is None: + # Calculate number of channels in the input label when not specified. + num_labels = get_paired_input_label_channel_number(data_cfg) + crop_h, crop_w = get_crop_h_w(data_cfg.train.augmentations) + # Build the generator + out_image_small_side_size = crop_w if crop_w < crop_h else crop_h + num_filters = getattr(gen_cfg, 'num_filters', 128) + kernel_size = getattr(gen_cfg, 'kernel_size', 3) + weight_norm_type = getattr(gen_cfg, 'weight_norm_type', 'spectral') + + cond_dims = 0 + # Check whether we use the style code. 
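Most normalization layers built in this generator use `activation_norm_type='spatially_adaptive'` (configured further below via `activation_norm_params`), i.e. SPADE: parameter-free normalization whose per-pixel scale and shift are predicted from a resized segmentation map. A minimal sketch of that layer in its standard form — the imaginaire implementation takes `num_filters`, `kernel_size`, the `sync_batch` norm type, etc. from the config and may differ in detail:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SPADENorm(nn.Module):
    """Spatially-adaptive (de)normalization conditioned on a segmentation map."""

    def __init__(self, num_features: int, label_channels: int, hidden: int = 128):
        super().__init__()
        self.norm = nn.BatchNorm2d(num_features, affine=False)
        self.shared = nn.Sequential(
            nn.Conv2d(label_channels, hidden, 3, padding=1), nn.ReLU())
        self.gamma = nn.Conv2d(hidden, num_features, 3, padding=1)
        self.beta = nn.Conv2d(hidden, num_features, 3, padding=1)

    def forward(self, x: torch.Tensor, segmap: torch.Tensor) -> torch.Tensor:
        # Resize the label map to the current feature resolution, then predict
        # per-pixel modulation parameters from it.
        seg = F.interpolate(segmap, size=x.shape[2:], mode='nearest')
        h = self.shared(seg)
        return self.norm(x) * (1 + self.gamma(h)) + self.beta(h)
```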
+ style_dims = getattr(gen_cfg, 'style_dims', None) + self.style_dims = style_dims + if style_dims is not None: + print('\tStyle code dimensions: %d' % style_dims) + cond_dims += style_dims + self.use_style = True + else: + self.use_style = False + # Check whether we use the attribute code. + if hasattr(gen_cfg, 'attribute_dims'): + self.use_attribute = True + self.attribute_dims = gen_cfg.attribute_dims + cond_dims += gen_cfg.attribute_dims + else: + self.use_attribute = False + + if not self.use_style and not self.use_attribute: + self.use_style_encoder = False + else: + self.use_style_encoder = True + print('\tBase filter number: %d' % num_filters) + print('\tConvolution kernel size: %d' % kernel_size) + print('\tWeight norm type: %s' % weight_norm_type) + skip_activation_norm = \ + getattr(gen_cfg, 'skip_activation_norm', True) + activation_norm_params = getattr(gen_cfg, 'activation_norm_params', None) + if activation_norm_params is None: + activation_norm_params = types.SimpleNamespace() + if not hasattr(activation_norm_params, 'num_filters'): + setattr(activation_norm_params, 'num_filters', 128) + if not hasattr(activation_norm_params, 'kernel_size'): + setattr(activation_norm_params, 'kernel_size', 3) + if not hasattr(activation_norm_params, 'activation_norm_type'): + setattr(activation_norm_params, 'activation_norm_type', 'sync_batch') + if not hasattr(activation_norm_params, 'separate_projection'): + setattr(activation_norm_params, 'separate_projection', False) + if not hasattr(activation_norm_params, 'activation_norm_params'): + activation_norm_params.activation_norm_params = types.SimpleNamespace() + activation_norm_params.activation_norm_params.affine = True + setattr(activation_norm_params, 'cond_dims', num_labels) + if not hasattr(activation_norm_params, 'weight_norm_type'): + setattr(activation_norm_params, 'weight_norm_type', weight_norm_type) + global_adaptive_norm_type = getattr(gen_cfg, 'global_adaptive_norm_type', 'sync_batch') + use_posenc_in_input_layer = getattr(gen_cfg, 'use_posenc_in_input_layer', True) + output_multiplier = getattr(gen_cfg, 'output_multiplier', 1.0) + print(activation_norm_params) + self.spade_generator = SPADEGenerator(num_labels, + out_image_small_side_size, + image_channels, + num_filters, + kernel_size, + cond_dims, + activation_norm_params, + weight_norm_type, + global_adaptive_norm_type, + skip_activation_norm, + use_posenc_in_input_layer, + self.use_style_encoder, + output_multiplier) + if self.use_style: + # Build the encoder. + style_enc_cfg = getattr(gen_cfg, 'style_enc', None) + if style_enc_cfg is None: + style_enc_cfg = types.SimpleNamespace() + if not hasattr(style_enc_cfg, 'num_filters'): + setattr(style_enc_cfg, 'num_filters', 128) + if not hasattr(style_enc_cfg, 'kernel_size'): + setattr(style_enc_cfg, 'kernel_size', 3) + if not hasattr(style_enc_cfg, 'weight_norm_type'): + setattr(style_enc_cfg, 'weight_norm_type', weight_norm_type) + setattr(style_enc_cfg, 'input_image_channels', image_channels) + setattr(style_enc_cfg, 'style_dims', style_dims) + self.style_encoder = StyleEncoder(style_enc_cfg) + + self.z = None + print('Done with the SPADE generator initialization.') + + def forward(self, data, random_style=False): + r"""SPADE Generator forward. + + Args: + data (dict): + - images (N x C1 x H x W tensor) : Ground truth images + - label (N x C2 x H x W tensor) : Semantic representations + - z (N x style_dims tensor): Gaussian random noise + - random_style (bool): Whether to sample a random style vector. 
+ Returns: + (dict): + - fake_images (N x 3 x H x W tensor): fake images + - mu (N x C1 tensor): mean vectors + - logvar (N x C1 tensor): log-variance vectors + """ + if self.use_style_encoder: + if random_style: + bs = data['label'].size(0) + z = torch.randn( + bs, self.style_dims, dtype=torch.float32).cuda() + if (data['label'].dtype == + data['label'].dtype == torch.float16): + z = z.half() + mu = None + logvar = None + else: + mu, logvar, z = self.style_encoder(data['images']) + if self.use_attribute: + data['z'] = torch.cat((z, data['attributes'].squeeze(1)), dim=1) + else: + data['z'] = z + output = self.spade_generator(data) + if self.use_style_encoder: + output['mu'] = mu + output['logvar'] = logvar + return output + + def inference(self, + data, + random_style=False, + use_fixed_random_style=False, + keep_original_size=False): + r"""Compute results images for a batch of input data and save the + results in the specified folder. + + Args: + data (dict): + - images (N x C1 x H x W tensor) : Ground truth images + - label (N x C2 x H x W tensor) : Semantic representations + - z (N x style_dims tensor): Gaussian random noise + random_style (bool): Whether to sample a random style vector. + use_fixed_random_style (bool): Sample random style once and use it + for all the remaining inference. + keep_original_size (bool): Keep original size of the input. + Returns: + (dict): + - fake_images (N x 3 x H x W tensor): fake images + - mu (N x C1 tensor): mean vectors + - logvar (N x C1 tensor): log-variance vectors + """ + self.eval() + self.spade_generator.eval() + + if self.use_style_encoder: + if random_style and self.use_style_encoder: + if self.z is None or not use_fixed_random_style: + bs = data['label'].size(0) + z = torch.randn( + bs, self.style_dims, dtype=torch.float32).to('cuda') + if (data['label'].dtype == + data['label'].dtype == + torch.float16): + z = z.half() + self.z = z + else: + z = self.z + else: + mu, logvar, z = self.style_encoder(data['images']) + data['z'] = z + + output = self.spade_generator(data) + output_images = output['fake_images'] + + if keep_original_size: + height = data['original_h_w'][0][0] + width = data['original_h_w'][0][1] + output_images = torch.nn.functional.interpolate( + output_images, size=[height, width]) + + for key in data['key'].keys(): + if 'segmaps' in key or 'seg_maps' in key: + file_names = data['key'][key][0] + break + for key in data['key'].keys(): + if 'edgemaps' in key or 'edge_maps' in key: + file_names = data['key'][key][0] + break + + return output_images, file_names + + +class SPADEGenerator(nn.Module): + r"""SPADE Image Generator constructor. + + Args: + num_labels (int): Number of different labels. + out_image_small_side_size (int): min(width, height) + image_channels (int): Num. of channels of the output image. + num_filters (int): Base filter numbers. + kernel_size (int): Convolution kernel size. + style_dims (int): Dimensions of the style code. + activation_norm_params (obj): Spatially adaptive normalization param. + weight_norm_type (str): Type of weight normalization. + ``'none'``, ``'spectral'``, or ``'weight'``. + global_adaptive_norm_type (str): Type of normalization in SPADE. + skip_activation_norm (bool): If ``True``, applies activation norm to the + shortcut connection in residual blocks. + use_style_encoder (bool): Whether to use global adaptive norm + like conditional batch norm or adaptive instance norm. 
+ output_multiplier (float): A positive number multiplied to the output + """ + + def __init__(self, + num_labels, + out_image_small_side_size, + image_channels, + num_filters, + kernel_size, + style_dims, + activation_norm_params, + weight_norm_type, + global_adaptive_norm_type, + skip_activation_norm, + use_posenc_in_input_layer, + use_style_encoder, + output_multiplier): + super(SPADEGenerator, self).__init__() + self.output_multiplier = output_multiplier + self.use_style_encoder = use_style_encoder + self.use_posenc_in_input_layer = use_posenc_in_input_layer + self.out_image_small_side_size = out_image_small_side_size + self.num_filters = num_filters + padding = int(np.ceil((kernel_size - 1.0) / 2)) + nonlinearity = 'leakyrelu' + activation_norm_type = 'spatially_adaptive' + base_res2d_block = \ + functools.partial(Res2dBlock, + kernel_size=kernel_size, + padding=padding, + bias=[True, True, False], + weight_norm_type=weight_norm_type, + activation_norm_type=activation_norm_type, + activation_norm_params=activation_norm_params, + skip_activation_norm=skip_activation_norm, + nonlinearity=nonlinearity, + order='NACNAC') + if self.use_style_encoder: + self.fc_0 = LinearBlock(style_dims, 2 * style_dims, + weight_norm_type=weight_norm_type, + nonlinearity='relu', + order='CAN') + self.fc_1 = LinearBlock(2 * style_dims, 2 * style_dims, + weight_norm_type=weight_norm_type, + nonlinearity='relu', + order='CAN') + + adaptive_norm_params = types.SimpleNamespace() + if not hasattr(adaptive_norm_params, 'cond_dims'): + setattr(adaptive_norm_params, 'cond_dims', 2 * style_dims) + if not hasattr(adaptive_norm_params, 'activation_norm_type'): + setattr(adaptive_norm_params, 'activation_norm_type', global_adaptive_norm_type) + if not hasattr(adaptive_norm_params, 'weight_norm_type'): + setattr(adaptive_norm_params, 'weight_norm_type', activation_norm_params.weight_norm_type) + if not hasattr(adaptive_norm_params, 'separate_projection'): + setattr(adaptive_norm_params, 'separate_projection', activation_norm_params.separate_projection) + adaptive_norm_params.activation_norm_params = types.SimpleNamespace() + setattr(adaptive_norm_params.activation_norm_params, 'affine', + activation_norm_params.activation_norm_params.affine) + base_cbn2d_block = \ + functools.partial(Conv2dBlock, + kernel_size=kernel_size, + stride=1, + padding=padding, + bias=True, + weight_norm_type=weight_norm_type, + activation_norm_type='adaptive', + activation_norm_params=adaptive_norm_params, + nonlinearity=nonlinearity, + order='NAC') + else: + base_conv2d_block = \ + functools.partial(Conv2dBlock, + kernel_size=kernel_size, + stride=1, + padding=padding, + bias=True, + weight_norm_type=weight_norm_type, + nonlinearity=nonlinearity, + order='NAC') + in_num_labels = num_labels + in_num_labels += 2 if self.use_posenc_in_input_layer else 0 + self.head_0 = Conv2dBlock(in_num_labels, 8 * num_filters, + kernel_size=kernel_size, stride=1, + padding=padding, + weight_norm_type=weight_norm_type, + activation_norm_type='none', + nonlinearity=nonlinearity) + if self.use_style_encoder: + self.cbn_head_0 = base_cbn2d_block( + 8 * num_filters, 16 * num_filters) + else: + self.conv_head_0 = base_conv2d_block( + 8 * num_filters, 16 * num_filters) + self.head_1 = base_res2d_block(16 * num_filters, 16 * num_filters) + self.head_2 = base_res2d_block(16 * num_filters, 16 * num_filters) + + self.up_0a = base_res2d_block(16 * num_filters, 8 * num_filters) + if self.use_style_encoder: + self.cbn_up_0a = base_cbn2d_block( + 8 * num_filters, 8 * 
num_filters) + else: + self.conv_up_0a = base_conv2d_block( + 8 * num_filters, 8 * num_filters) + self.up_0b = base_res2d_block(8 * num_filters, 8 * num_filters) + + self.up_1a = base_res2d_block(8 * num_filters, 4 * num_filters) + if self.use_style_encoder: + self.cbn_up_1a = base_cbn2d_block( + 4 * num_filters, 4 * num_filters) + else: + self.conv_up_1a = base_conv2d_block( + 4 * num_filters, 4 * num_filters) + self.up_1b = base_res2d_block(4 * num_filters, 4 * num_filters) + self.up_2a = base_res2d_block(4 * num_filters, 4 * num_filters) + if self.use_style_encoder: + self.cbn_up_2a = base_cbn2d_block( + 4 * num_filters, 4 * num_filters) + else: + self.conv_up_2a = base_conv2d_block( + 4 * num_filters, 4 * num_filters) + self.up_2b = base_res2d_block(4 * num_filters, 2 * num_filters) + self.conv_img256 = Conv2dBlock(2 * num_filters, image_channels, + 5, stride=1, padding=2, + weight_norm_type=weight_norm_type, + activation_norm_type='none', + nonlinearity=nonlinearity, + order='ANC') + self.base = 16 + if self.out_image_small_side_size == 512: + self.up_3a = base_res2d_block(2 * num_filters, 1 * num_filters) + self.up_3b = base_res2d_block(1 * num_filters, 1 * num_filters) + self.conv_img512 = Conv2dBlock(1 * num_filters, image_channels, + 5, stride=1, padding=2, + weight_norm_type=weight_norm_type, + activation_norm_type='none', + nonlinearity=nonlinearity, + order='ANC') + self.base = 32 + if self.out_image_small_side_size == 1024: + self.up_3a = base_res2d_block(2 * num_filters, 1 * num_filters) + self.up_3b = base_res2d_block(1 * num_filters, 1 * num_filters) + self.conv_img512 = Conv2dBlock(1 * num_filters, image_channels, + 5, stride=1, padding=2, + weight_norm_type=weight_norm_type, + activation_norm_type='none', + nonlinearity=nonlinearity, + order='ANC') + self.up_4a = base_res2d_block(num_filters, num_filters // 2) + self.up_4b = base_res2d_block(num_filters // 2, num_filters // 2) + self.conv_img1024 = Conv2dBlock(num_filters // 2, image_channels, + 5, stride=1, padding=2, + weight_norm_type=weight_norm_type, + activation_norm_type='none', + nonlinearity=nonlinearity, + order='ANC') + self.nearest_upsample4x = NearestUpsample(scale_factor=4, mode='nearest') + self.base = 64 + if self.out_image_small_side_size != 256 and self.out_image_small_side_size != 512 \ + and self.out_image_small_side_size != 1024: + raise ValueError('Generation image size (%d, %d) not supported' % + (self.out_image_small_side_size, + self.out_image_small_side_size)) + self.nearest_upsample2x = NearestUpsample(scale_factor=2, mode='nearest') + + xv, yv = torch.meshgrid( + [torch.arange(-1, 1.1, 2. / 15), torch.arange(-1, 1.1, 2. / 15)]) + self.xy = torch.cat((xv.unsqueeze(0), yv.unsqueeze(0)), 0).unsqueeze(0) + self.xy = self.xy.cuda() + + def forward(self, data): + r"""SPADE Generator forward. + + Args: + data (dict): + - data (N x C1 x H x W tensor) : Ground truth images. + - label (N x C2 x H x W tensor) : Semantic representations. + - z (N x style_dims tensor): Gaussian random noise. + Returns: + output (dict): + - fake_images (N x 3 x H x W tensor): Fake images. 
+ """ + seg = data['label'] + + if self.use_style_encoder: + z = data['z'] + z = self.fc_0(z) + z = self.fc_1(z) + + # The code piece below makes sure that the input size is always 16x16 + sy = math.floor(seg.size()[2] * 1.0 / self.base) + sx = math.floor(seg.size()[3] * 1.0 / self.base) + + in_seg = F.interpolate(seg, size=[sy, sx], mode='nearest') + if self.use_posenc_in_input_layer: + in_xy = F.interpolate(self.xy, size=[sy, sx], mode='bicubic') + in_seg_xy = torch.cat( + (in_seg, in_xy.expand(in_seg.size()[0], 2, sy, sx)), 1) + else: + in_seg_xy = in_seg + # 16x16 + x = self.head_0(in_seg_xy) + if self.use_style_encoder: + x = self.cbn_head_0(x, z) + else: + x = self.conv_head_0(x) + x = self.head_1(x, seg) + x = self.head_2(x, seg) + x = self.nearest_upsample2x(x) + # 32x32 + x = self.up_0a(x, seg) + if self.use_style_encoder: + x = self.cbn_up_0a(x, z) + else: + x = self.conv_up_0a(x) + x = self.up_0b(x, seg) + x = self.nearest_upsample2x(x) + # 64x64 + x = self.up_1a(x, seg) + if self.use_style_encoder: + x = self.cbn_up_1a(x, z) + else: + x = self.conv_up_1a(x) + x = self.up_1b(x, seg) + x = self.nearest_upsample2x(x) + # 128x128 + x = self.up_2a(x, seg) + if self.use_style_encoder: + x = self.cbn_up_2a(x, z) + else: + x = self.conv_up_2a(x) + x = self.up_2b(x, seg) + x = self.nearest_upsample2x(x) + # 256x256 + if self.out_image_small_side_size == 256: + x256 = self.conv_img256(x) + x = torch.tanh(self.output_multiplier * x256) + # 512x512 + elif self.out_image_small_side_size == 512: + x256 = self.conv_img256(x) + x256 = self.nearest_upsample2x(x256) + x = self.up_3a(x, seg) + x = self.up_3b(x, seg) + x = self.nearest_upsample2x(x) + x512 = self.conv_img512(x) + x = torch.tanh(self.output_multiplier * (x256 + x512)) + # 1024x1024 + elif self.out_image_small_side_size == 1024: + x256 = self.conv_img256(x) + x256 = self.nearest_upsample4x(x256) + x = self.up_3a(x, seg) + x = self.up_3b(x, seg) + x = self.nearest_upsample2x(x) + x512 = self.conv_img512(x) + x512 = self.nearest_upsample2x(x512) + x = self.up_4a(x, seg) + x = self.up_4b(x, seg) + x = self.nearest_upsample2x(x) + x1024 = self.conv_img1024(x) + x = torch.tanh(self.output_multiplier * (x256 + x512 + x1024)) + output = dict() + output['fake_images'] = x + return output + + +class StyleEncoder(nn.Module): + r"""Style Encode constructor. + + Args: + style_enc_cfg (obj): Style encoder definition file. 
+ """ + + def __init__(self, style_enc_cfg): + super(StyleEncoder, self).__init__() + input_image_channels = style_enc_cfg.input_image_channels + num_filters = style_enc_cfg.num_filters + kernel_size = style_enc_cfg.kernel_size + padding = int(np.ceil((kernel_size - 1.0) / 2)) + style_dims = style_enc_cfg.style_dims + weight_norm_type = style_enc_cfg.weight_norm_type + activation_norm_type = 'none' + nonlinearity = 'leakyrelu' + base_conv2d_block = \ + functools.partial(Conv2dBlock, + kernel_size=kernel_size, + stride=2, + padding=padding, + weight_norm_type=weight_norm_type, + activation_norm_type=activation_norm_type, + # inplace_nonlinearity=True, + nonlinearity=nonlinearity) + self.layer1 = base_conv2d_block(input_image_channels, num_filters) + self.layer2 = base_conv2d_block(num_filters * 1, num_filters * 2) + self.layer3 = base_conv2d_block(num_filters * 2, num_filters * 4) + self.layer4 = base_conv2d_block(num_filters * 4, num_filters * 8) + self.layer5 = base_conv2d_block(num_filters * 8, num_filters * 8) + self.layer6 = base_conv2d_block(num_filters * 8, num_filters * 8) + self.fc_mu = LinearBlock(num_filters * 8 * 4 * 4, style_dims) + self.fc_var = LinearBlock(num_filters * 8 * 4 * 4, style_dims) + + def forward(self, input_x): + r"""SPADE Style Encoder forward. + + Args: + input_x (N x 3 x H x W tensor): input images. + Returns: + (tuple): + - mu (N x C tensor): Mean vectors. + - logvar (N x C tensor): Log-variance vectors. + - z (N x C tensor): Style code vectors. + """ + if input_x.size(2) != 256 or input_x.size(3) != 256: + input_x = F.interpolate(input_x, size=(256, 256), mode='bilinear') + x = self.layer1(input_x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + x = self.layer5(x) + x = self.layer6(x) + x = x.view(x.size(0), -1) + mu = self.fc_mu(x) + logvar = self.fc_var(x) + std = torch.exp(0.5 * logvar) + eps = torch.randn_like(std) + z = eps.mul(std) + mu + return mu, logvar, z diff --git a/imaginaire/generators/unit.py b/imaginaire/generators/unit.py new file mode 100644 index 0000000000000000000000000000000000000000..c09f1b050d59de4940d49c7336ddd2f5928c55c6 --- /dev/null +++ b/imaginaire/generators/unit.py @@ -0,0 +1,312 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +import warnings + +from torch import nn +from torch.nn import Upsample as NearestUpsample + +from imaginaire.layers import Conv2dBlock, Res2dBlock + + +class Generator(nn.Module): + r"""Improved UNIT generator. + + Args: + gen_cfg (obj): Generator definition part of the yaml config file. + data_cfg (obj): Data definition part of the yaml config file. 
+ """ + + def __init__(self, gen_cfg, data_cfg): + super().__init__() + self.autoencoder_a = AutoEncoder(**vars(gen_cfg)) + self.autoencoder_b = AutoEncoder(**vars(gen_cfg)) + + def forward(self, data, image_recon=True, cycle_recon=True): + r"""UNIT forward function""" + images_a = data['images_a'] + images_b = data['images_b'] + net_G_output = dict() + + # encode input images into latent code + content_a = self.autoencoder_a.content_encoder(images_a) + content_b = self.autoencoder_b.content_encoder(images_b) + + # decode (within domain) + if image_recon: + images_aa = self.autoencoder_a.decoder(content_a) + images_bb = self.autoencoder_b.decoder(content_b) + net_G_output.update(dict(images_aa=images_aa, images_bb=images_bb)) + + # decode (cross domain) + images_ba = self.autoencoder_a.decoder(content_b) + images_ab = self.autoencoder_b.decoder(content_a) + + # cycle reconstruction + if cycle_recon: + content_ba = self.autoencoder_a.content_encoder(images_ba) + content_ab = self.autoencoder_b.content_encoder(images_ab) + images_aba = self.autoencoder_a.decoder(content_ab) + images_bab = self.autoencoder_b.decoder(content_ba) + net_G_output.update( + dict(content_ba=content_ba, content_ab=content_ab, + images_aba=images_aba, images_bab=images_bab)) + + # required outputs + net_G_output.update(dict(content_a=content_a, content_b=content_b, + images_ba=images_ba, images_ab=images_ab)) + + return net_G_output + + def inference(self, data, a2b=True): + r"""UNIT inference. + + Args: + data (dict): Training data at the current iteration. + - images_a (tensor): Images from domain A. + - images_b (tensor): Images from domain B. + a2b (bool): If ``True``, translates images from domain A to B, + otherwise from B to A. + """ + if a2b: + input_key = 'images_a' + content_encode = self.autoencoder_a.content_encoder + decode = self.autoencoder_b.decoder + else: + input_key = 'images_b' + content_encode = self.autoencoder_b.content_encoder + decode = self.autoencoder_a.decoder + + content_images = data[input_key] + content = content_encode(content_images) + output_images = decode(content) + filename = '%s/%s' % ( + data['key'][input_key]['sequence_name'][0], + data['key'][input_key]['filename'][0]) + filenames = [filename] + return output_images, filenames + + +class AutoEncoder(nn.Module): + r"""Improved UNIT autoencoder. + + Args: + num_filters (int): Base filter numbers. + max_num_filters (int): Maximum number of filters in the encoder. + num_res_blocks (int): Number of residual blocks at the end of the + content encoder. + num_downsamples_content (int): Number of times we reduce + resolution by 2x2 for the content image. + num_image_channels (int): Number of input image channels. + content_norm_type (str): Type of activation normalization in the + content encoder. + decoder_norm_type (str): Type of activation normalization in the + decoder. + weight_norm_type (str): Type of weight normalization. + output_nonlinearity (str): Type of nonlinearity before final output, + ``'tanh'`` or ``'none'``. + pre_act (bool): If ``True``, uses pre-activation residual blocks. + apply_noise (bool): If ``True``, injects Gaussian noise in the decoder. 
+ """ + + def __init__(self, + num_filters=64, + max_num_filters=256, + num_res_blocks=4, + num_downsamples_content=2, + num_image_channels=3, + content_norm_type='instance', + decoder_norm_type='instance', + weight_norm_type='', + output_nonlinearity='', + pre_act=False, + apply_noise=False, + **kwargs): + super().__init__() + for key in kwargs: + if key != 'type': + warnings.warn( + "Generator argument '{}' is not used.".format(key)) + self.content_encoder = ContentEncoder(num_downsamples_content, + num_res_blocks, + num_image_channels, + num_filters, + max_num_filters, + 'reflect', + content_norm_type, + weight_norm_type, + 'relu', + pre_act) + self.decoder = Decoder(num_downsamples_content, + num_res_blocks, + self.content_encoder.output_dim, + num_image_channels, + 'reflect', + decoder_norm_type, + weight_norm_type, + 'relu', + output_nonlinearity, + pre_act, + apply_noise) + + def forward(self, images): + r"""Reconstruct an image. + + Args: + images (Tensor): Input images. + Returns: + images_recon (Tensor): Reconstructed images. + """ + content = self.content_encoder(images) + images_recon = self.decoder(content) + return images_recon + + +class ContentEncoder(nn.Module): + r"""Improved UNIT encoder. The network consists of: + + - input layers + - $(num_downsamples) convolutional blocks + - $(num_res_blocks) residual blocks. + - output layer. + + Args: + num_downsamples (int): Number of times we reduce + resolution by 2x2. + num_res_blocks (int): Number of residual blocks at the end of the + content encoder. + num_image_channels (int): Number of input image channels. + num_filters (int): Base filter numbers. + max_num_filters (int): Maximum number of filters in the encoder. + padding_mode (string): Type of padding. + activation_norm_type (str): Type of activation normalization. + weight_norm_type (str): Type of weight normalization. + nonlinearity (str): Type of nonlinear activation function. + pre_act (bool): If ``True``, uses pre-activation residual blocks. + """ + + def __init__(self, + num_downsamples, + num_res_blocks, + num_image_channels, + num_filters, + max_num_filters, + padding_mode, + activation_norm_type, + weight_norm_type, + nonlinearity, + pre_act=False): + super().__init__() + conv_params = dict(padding_mode=padding_mode, + activation_norm_type=activation_norm_type, + weight_norm_type=weight_norm_type, + nonlinearity=nonlinearity) + # Whether or not it is safe to use inplace nonlinear activation. + if not pre_act or (activation_norm_type != '' and + activation_norm_type != 'none'): + conv_params['inplace_nonlinearity'] = True + + # The order of operations in residual blocks. + order = 'pre_act' if pre_act else 'CNACNA' + + model = [] + model += [Conv2dBlock(num_image_channels, num_filters, 7, 1, 3, + **conv_params)] + + # Downsampling blocks. + for i in range(num_downsamples): + num_filters_prev = num_filters + num_filters = min(num_filters * 2, max_num_filters) + model += [Conv2dBlock(num_filters_prev, num_filters, 4, 2, 1, + **conv_params)] + + # Residual blocks. + for _ in range(num_res_blocks): + model += [Res2dBlock(num_filters, num_filters, + **conv_params, + order=order)] + self.model = nn.Sequential(*model) + self.output_dim = num_filters + + def forward(self, x): + r""" + + Args: + x (tensor): Input image. + """ + return self.model(x) + + +class Decoder(nn.Module): + r"""Improved UNIT decoder. The network consists of: + + - $(num_res_blocks) residual blocks. + - $(num_upsamples) residual blocks or convolutional blocks + - output layer. 
+ + Args: + num_upsamples (int): Number of times we increase resolution by 2x2. + num_res_blocks (int): Number of residual blocks. + num_filters (int): Base filter numbers. + num_image_channels (int): Number of input image channels. + padding_mode (string): Type of padding. + activation_norm_type (str): Type of activation normalization. + weight_norm_type (str): Type of weight normalization. + nonlinearity (str): Type of nonlinear activation function. + output_nonlinearity (str): Type of nonlinearity before final output, + ``'tanh'`` or ``'none'``. + pre_act (bool): If ``True``, uses pre-activation residual blocks. + apply_noise (bool): If ``True``, injects Gaussian noise. + """ + + def __init__(self, + num_upsamples, + num_res_blocks, + num_filters, + num_image_channels, + padding_mode, + activation_norm_type, + weight_norm_type, + nonlinearity, + output_nonlinearity, + pre_act=False, + apply_noise=False): + super().__init__() + + conv_params = dict(padding_mode=padding_mode, + nonlinearity=nonlinearity, + inplace_nonlinearity=True, + apply_noise=apply_noise, + weight_norm_type=weight_norm_type, + activation_norm_type=activation_norm_type) + + # The order of operations in residual blocks. + order = 'pre_act' if pre_act else 'CNACNA' + + # Residual blocks. + self.decoder = nn.ModuleList() + for _ in range(num_res_blocks): + self.decoder += [Res2dBlock(num_filters, num_filters, + **conv_params, + order=order)] + + # Convolutional blocks with upsampling. + for i in range(num_upsamples): + self.decoder += [NearestUpsample(scale_factor=2)] + self.decoder += [Conv2dBlock(num_filters, num_filters // 2, + 5, 1, 2, **conv_params)] + num_filters //= 2 + self.decoder += [Conv2dBlock(num_filters, num_image_channels, 7, 1, 3, + nonlinearity=output_nonlinearity, + padding_mode=padding_mode)] + + def forward(self, x): + r""" + + Args: + x (tensor): Content embedding of the content image. + """ + for block in self.decoder: + x = block(x) + return x diff --git a/imaginaire/generators/vid2vid.py b/imaginaire/generators/vid2vid.py new file mode 100644 index 0000000000000000000000000000000000000000..78262debc9d2ea6106a0816b53c4c73c5a5a6053 --- /dev/null +++ b/imaginaire/generators/vid2vid.py @@ -0,0 +1,481 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +from functools import partial + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from imaginaire.generators.fs_vid2vid import LabelEmbedder +from imaginaire.layers import Conv2dBlock, LinearBlock, Res2dBlock +from imaginaire.model_utils.fs_vid2vid import (extract_valid_pose_labels, + resample) +from imaginaire.utils.data import (get_paired_input_image_channel_number, + get_paired_input_label_channel_number) +from imaginaire.utils.init_weight import weights_init + + +class BaseNetwork(nn.Module): + r"""vid2vid generator.""" + + def __init__(self): + super(BaseNetwork, self).__init__() + + def get_num_filters(self, num_downsamples): + r"""Get the number of filters at current layer. + + Args: + num_downsamples (int) : How many downsamples at current layer. + Returns: + output (int) : Number of filters. + """ + return min(self.max_num_filters, + self.num_filters * (2 ** num_downsamples)) + + +class Generator(BaseNetwork): + r"""vid2vid generator constructor. + + Args: + gen_cfg (obj): Generator definition part of the yaml config file. 
+ data_cfg (obj): Data definition part of the yaml config file. + """ + + def __init__(self, gen_cfg, data_cfg): + super().__init__() + self.gen_cfg = gen_cfg + self.data_cfg = data_cfg + self.num_frames_G = data_cfg.num_frames_G + # Number of residual blocks in generator. + self.num_layers = num_layers = getattr(gen_cfg, 'num_layers', 7) + # Number of downsamplings for previous frame. + self.num_downsamples_img = getattr(gen_cfg, 'num_downsamples_img', 4) + # Number of filters in the first layer. + self.num_filters = num_filters = getattr(gen_cfg, 'num_filters', 32) + self.max_num_filters = getattr(gen_cfg, 'max_num_filters', 1024) + self.kernel_size = kernel_size = getattr(gen_cfg, 'kernel_size', 3) + padding = kernel_size // 2 + + # For pose dataset. + self.is_pose_data = hasattr(data_cfg, 'for_pose_dataset') + if self.is_pose_data: + pose_cfg = data_cfg.for_pose_dataset + self.pose_type = getattr(pose_cfg, 'pose_type', 'both') + self.remove_face_labels = getattr(pose_cfg, 'remove_face_labels', + False) + + # Input data params. + num_input_channels = get_paired_input_label_channel_number(data_cfg) + num_img_channels = get_paired_input_image_channel_number(data_cfg) + aug_cfg = data_cfg.val.augmentations + if hasattr(aug_cfg, 'center_crop_h_w'): + crop_h_w = aug_cfg.center_crop_h_w + elif hasattr(aug_cfg, 'resize_h_w'): + crop_h_w = aug_cfg.resize_h_w + else: + raise ValueError('Need to specify output size.') + crop_h, crop_w = crop_h_w.split(',') + crop_h, crop_w = int(crop_h), int(crop_w) + # Spatial size at the bottle neck of generator. + self.sh = crop_h // (2 ** num_layers) + self.sw = crop_w // (2 ** num_layers) + + # Noise vector dimension. + self.z_dim = getattr(gen_cfg, 'style_dims', 256) + self.use_segmap_as_input = \ + getattr(gen_cfg, 'use_segmap_as_input', False) + + # Label / image embedding network. + self.emb_cfg = emb_cfg = getattr(gen_cfg, 'embed', None) + self.use_embed = getattr(emb_cfg, 'use_embed', 'True') + self.num_downsamples_embed = getattr(emb_cfg, 'num_downsamples', 5) + if self.use_embed: + self.label_embedding = LabelEmbedder(emb_cfg, num_input_channels) + + # Flow network. + self.flow_cfg = flow_cfg = gen_cfg.flow + # Use SPADE to combine warped and hallucinated frames instead of + # linear combination. + self.spade_combine = getattr(flow_cfg, 'multi_spade_combine', True) + # Number of layers to perform multi-spade combine. + self.num_multi_spade_layers = getattr(flow_cfg.multi_spade_combine, + 'num_layers', 3) + # At beginning of training, only train an image generator. + self.temporal_initialized = False + # Whether to output hallucinated frame (when training temporal network) + # for additional loss. + self.generate_raw_output = False + + # Image generation network. + weight_norm_type = getattr(gen_cfg, 'weight_norm_type', 'spectral') + activation_norm_type = gen_cfg.activation_norm_type + activation_norm_params = gen_cfg.activation_norm_params + if self.use_embed and \ + not hasattr(activation_norm_params, 'num_filters'): + activation_norm_params.num_filters = 0 + nonlinearity = 'leakyrelu' + + self.base_res_block = base_res_block = partial( + Res2dBlock, kernel_size=kernel_size, padding=padding, + weight_norm_type=weight_norm_type, + activation_norm_type=activation_norm_type, + activation_norm_params=activation_norm_params, + nonlinearity=nonlinearity, order='NACNAC') + + # Upsampling residual blocks. 
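+        # Each up_i block maps get_num_filters(i + 1) -> get_num_filters(i)
+        # channels and is conditioned on SPADE maps whose dimensions are
+        # reported by get_cond_dims(i).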
+ for i in range(num_layers, -1, -1): + activation_norm_params.cond_dims = self.get_cond_dims(i) + activation_norm_params.partial = self.get_partial( + i) if hasattr(self, 'get_partial') else False + layer = base_res_block(self.get_num_filters(i + 1), + self.get_num_filters(i)) + setattr(self, 'up_%d' % i, layer) + + # Final conv layer. + self.conv_img = Conv2dBlock(num_filters, num_img_channels, + kernel_size, padding=padding, + nonlinearity=nonlinearity, order='AC') + + num_filters = min(self.max_num_filters, + num_filters * (2 ** (self.num_layers + 1))) + if self.use_segmap_as_input: + self.fc = Conv2dBlock(num_input_channels, num_filters, + kernel_size=3, padding=1) + else: + self.fc = LinearBlock(self.z_dim, num_filters * self.sh * self.sw) + + # Misc. + self.downsample = nn.AvgPool2d(kernel_size=3, stride=2, padding=1) + self.upsample = partial(F.interpolate, scale_factor=2) + self.init_temporal_network() + + def forward(self, data): + r"""vid2vid generator forward. + + Args: + data (dict) : Dictionary of input data. + Returns: + output (dict) : Dictionary of output data. + """ + label = data['label'] + label_prev, img_prev = data['prev_labels'], data['prev_images'] + is_first_frame = img_prev is None + z = getattr(data, 'z', None) + bs, _, h, w = label.size() + + if self.is_pose_data: + label, label_prev = extract_valid_pose_labels( + [label, label_prev], self.pose_type, self.remove_face_labels) + + # Get SPADE conditional maps by embedding current label input. + cond_maps_now = self.get_cond_maps(label, self.label_embedding) + + # Input to the generator will either be noise/segmentation map (for + # first frame) or encoded previous frame (for subsequent frames). + if is_first_frame: + # First frame in the sequence, start from scratch. + if self.use_segmap_as_input: + x_img = F.interpolate(label, size=(self.sh, self.sw)) + x_img = self.fc(x_img) + else: + if z is None: + z = torch.randn(bs, self.z_dim, dtype=label.dtype, + device=label.get_device()).fill_(0) + x_img = self.fc(z).view(bs, -1, self.sh, self.sw) + + # Upsampling layers. + for i in range(self.num_layers, self.num_downsamples_img, -1): + j = min(self.num_downsamples_embed, i) + x_img = getattr(self, 'up_' + str(i))(x_img, *cond_maps_now[j]) + x_img = self.upsample(x_img) + else: + # Not the first frame, will encode the previous frame and feed to + # the generator. + x_img = self.down_first(img_prev[:, -1]) + + # Get label embedding for the previous frame. + cond_maps_prev = self.get_cond_maps(label_prev[:, -1], + self.label_embedding) + + # Downsampling layers. + for i in range(self.num_downsamples_img + 1): + j = min(self.num_downsamples_embed, i) + x_img = getattr(self, 'down_' + str(i))(x_img, + *cond_maps_prev[j]) + if i != self.num_downsamples_img: + x_img = self.downsample(x_img) + + # Resnet blocks. + j = min(self.num_downsamples_embed, self.num_downsamples_img + 1) + for i in range(self.num_res_blocks): + cond_maps = cond_maps_prev[j] if i < self.num_res_blocks // 2 \ + else cond_maps_now[j] + x_img = getattr(self, 'res_' + str(i))(x_img, *cond_maps) + + flow = mask = img_warp = None + + num_frames_G = self.num_frames_G + # Whether to warp the previous frame or not. + warp_prev = self.temporal_initialized and not is_first_frame and \ + label_prev.shape[1] == num_frames_G - 1 + if warp_prev: + # Estimate flow & mask. 
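+            # The flow network sees the label maps of all num_frames_G frames
+            # (previous plus current) and the previously generated frames,
+            # flattened along the channel dimension.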
+ label_concat = torch.cat([label_prev.view(bs, -1, h, w), + label], dim=1) + img_prev_concat = img_prev.view(bs, -1, h, w) + flow, mask = self.flow_network_temp(label_concat, img_prev_concat) + img_warp = resample(img_prev[:, -1], flow) + if self.spade_combine: + # if using SPADE combine, integrate the warped image (and + # occlusion mask) into conditional inputs for SPADE. + img_embed = torch.cat([img_warp, mask], dim=1) + cond_maps_img = self.get_cond_maps(img_embed, + self.img_prev_embedding) + x_raw_img = None + + # Main image generation branch. + for i in range(self.num_downsamples_img, -1, -1): + # Get SPADE conditional inputs. + j = min(i, self.num_downsamples_embed) + cond_maps = cond_maps_now[j] + + # For raw output generation. + if self.generate_raw_output: + if i >= self.num_multi_spade_layers - 1: + x_raw_img = x_img + if i < self.num_multi_spade_layers: + x_raw_img = self.one_up_conv_layer(x_raw_img, cond_maps, i) + + # For final output. + if warp_prev and i < self.num_multi_spade_layers: + cond_maps += cond_maps_img[j] + x_img = self.one_up_conv_layer(x_img, cond_maps, i) + + # Final conv layer. + img_final = torch.tanh(self.conv_img(x_img)) + + img_raw = None + if self.spade_combine and self.generate_raw_output: + img_raw = torch.tanh(self.conv_img(x_raw_img)) + if warp_prev and not self.spade_combine: + img_raw = img_final + img_final = img_final * mask + img_warp * (1 - mask) + + output = dict() + output['fake_images'] = img_final + output['fake_flow_maps'] = flow + output['fake_occlusion_masks'] = mask + output['fake_raw_images'] = img_raw + output['warped_images'] = img_warp + return output + + def one_up_conv_layer(self, x, encoded_label, i): + r"""One residual block layer in the main branch. + + Args: + x (4D tensor) : Current feature map. + encoded_label (list of tensors) : Encoded input label maps. + i (int) : Layer index. + Returns: + x (4D tensor) : Output feature map. + """ + layer = getattr(self, 'up_' + str(i)) + x = layer(x, *encoded_label) + if i != 0: + x = self.upsample(x) + return x + + def init_temporal_network(self, cfg_init=None): + r"""When starting training multiple frames, initialize the + downsampling network and flow network. + + Args: + cfg_init (dict) : Weight initialization config. + """ + # Number of image downsamplings for the previous frame. + num_downsamples_img = self.num_downsamples_img + # Number of residual blocks for the previous frame. + self.num_res_blocks = int( + np.ceil((self.num_layers - num_downsamples_img) / 2.0) * 2) + + # First conv layer. + num_img_channels = get_paired_input_image_channel_number(self.data_cfg) + self.down_first = \ + Conv2dBlock(num_img_channels, + self.num_filters, self.kernel_size, + padding=self.kernel_size // 2) + if cfg_init is not None: + self.down_first.apply(weights_init(cfg_init.type, cfg_init.gain)) + + # Downsampling residual blocks. + activation_norm_params = self.gen_cfg.activation_norm_params + for i in range(num_downsamples_img + 1): + activation_norm_params.cond_dims = self.get_cond_dims(i) + layer = self.base_res_block(self.get_num_filters(i), + self.get_num_filters(i + 1)) + if cfg_init is not None: + layer.apply(weights_init(cfg_init.type, cfg_init.gain)) + setattr(self, 'down_%d' % i, layer) + + # Additional residual blocks. 
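+        # These blocks run at the lowest resolution of the previous-frame
+        # encoder; in forward(), the first half is conditioned on the previous
+        # label embedding and the second half on the current one.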
+ res_ch = self.get_num_filters(num_downsamples_img + 1) + activation_norm_params.cond_dims = \ + self.get_cond_dims(num_downsamples_img + 1) + for i in range(self.num_res_blocks): + layer = self.base_res_block(res_ch, res_ch) + if cfg_init is not None: + layer.apply(weights_init(cfg_init.type, cfg_init.gain)) + setattr(self, 'res_%d' % i, layer) + + # Flow network. + flow_cfg = self.flow_cfg + self.temporal_initialized = True + self.generate_raw_output = getattr(flow_cfg, 'generate_raw_output', + False) and self.spade_combine + self.flow_network_temp = FlowGenerator(flow_cfg, self.data_cfg) + if cfg_init is not None: + self.flow_network_temp.apply(weights_init(cfg_init.type, + cfg_init.gain)) + + self.spade_combine = getattr(flow_cfg, 'multi_spade_combine', True) + if self.spade_combine: + emb_cfg = flow_cfg.multi_spade_combine.embed + num_img_channels = get_paired_input_image_channel_number( + self.data_cfg) + self.img_prev_embedding = LabelEmbedder(emb_cfg, + num_img_channels + 1) + if cfg_init is not None: + self.img_prev_embedding.apply(weights_init(cfg_init.type, + cfg_init.gain)) + + def get_cond_dims(self, num_downs=0): + r"""Get the dimensions of conditional inputs. + + Args: + num_downs (int) : How many downsamples at current layer. + Returns: + ch (list) : List of dimensions. + """ + if not self.use_embed: + ch = [self.num_input_channels] + else: + num_filters = getattr(self.emb_cfg, 'num_filters', 32) + num_downs = min(num_downs, self.num_downsamples_embed) + ch = [min(self.max_num_filters, num_filters * (2 ** num_downs))] + if (num_downs < self.num_multi_spade_layers): + ch = ch * 2 + return ch + + def get_cond_maps(self, label, embedder): + r"""Get the conditional inputs. + + Args: + label (4D tensor) : Input label tensor. + embedder (obj) : Embedding network. + Returns: + cond_maps (list) : List of conditional inputs. + """ + if not self.use_embed: + return [label] * (self.num_layers + 1) + embedded_label = embedder(label) + cond_maps = [embedded_label] + cond_maps = [[m[i] for m in cond_maps] for i in + range(len(cond_maps[0]))] + return cond_maps + + +class FlowGenerator(BaseNetwork): + r"""Flow generator constructor. + + Args: + flow_cfg (obj): Flow definition part of the yaml config file. + data_cfg (obj): Data definition part of the yaml config file. + """ + + def __init__(self, flow_cfg, data_cfg): + super().__init__() + num_input_channels = get_paired_input_label_channel_number(data_cfg) + num_prev_img_channels = get_paired_input_image_channel_number(data_cfg) + num_frames = data_cfg.num_frames_G # Num. of input frames. + + self.num_filters = num_filters = getattr(flow_cfg, 'num_filters', 32) + self.max_num_filters = getattr(flow_cfg, 'max_num_filters', 1024) + num_downsamples = getattr(flow_cfg, 'num_downsamples', 5) + kernel_size = getattr(flow_cfg, 'kernel_size', 3) + padding = kernel_size // 2 + self.num_res_blocks = getattr(flow_cfg, 'num_res_blocks', 6) + # Multiplier on the flow output. + self.flow_output_multiplier = getattr(flow_cfg, + 'flow_output_multiplier', 20) + + activation_norm_type = getattr(flow_cfg, 'activation_norm_type', + 'sync_batch') + weight_norm_type = getattr(flow_cfg, 'weight_norm_type', 'spectral') + + base_conv_block = partial(Conv2dBlock, kernel_size=kernel_size, + padding=padding, + weight_norm_type=weight_norm_type, + activation_norm_type=activation_norm_type, + nonlinearity='leakyrelu') + + # Will downsample the labels and prev frames separately, then combine. 
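+        # Both encoders follow the same channel schedule (get_num_filters);
+        # their outputs are summed in forward() before the residual and
+        # upsampling stages that predict the flow map and occlusion mask.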
+ down_lbl = [base_conv_block(num_input_channels * num_frames, + num_filters)] + down_img = [base_conv_block(num_prev_img_channels * (num_frames - 1), + num_filters)] + for i in range(num_downsamples): + down_lbl += [base_conv_block(self.get_num_filters(i), + self.get_num_filters(i + 1), + stride=2)] + down_img += [base_conv_block(self.get_num_filters(i), + self.get_num_filters(i + 1), + stride=2)] + + # Resnet blocks. + res_flow = [] + ch = self.get_num_filters(num_downsamples) + for i in range(self.num_res_blocks): + res_flow += [ + Res2dBlock(ch, ch, kernel_size, padding=padding, + weight_norm_type=weight_norm_type, + activation_norm_type=activation_norm_type, + order='CNACN')] + + # Upsample. + up_flow = [] + for i in reversed(range(num_downsamples)): + up_flow += [nn.Upsample(scale_factor=2), + base_conv_block(self.get_num_filters(i + 1), + self.get_num_filters(i))] + + conv_flow = [Conv2dBlock(num_filters, 2, kernel_size, padding=padding)] + conv_mask = [Conv2dBlock(num_filters, 1, kernel_size, padding=padding, + nonlinearity='sigmoid')] + + self.down_lbl = nn.Sequential(*down_lbl) + self.down_img = nn.Sequential(*down_img) + self.res_flow = nn.Sequential(*res_flow) + self.up_flow = nn.Sequential(*up_flow) + self.conv_flow = nn.Sequential(*conv_flow) + self.conv_mask = nn.Sequential(*conv_mask) + + def forward(self, label, img_prev): + r"""Flow generator forward. + + Args: + label (4D tensor) : Input label tensor. + img_prev (4D tensor) : Previously generated image tensors. + Returns: + (tuple): + - flow (4D tensor) : Generated flow map. + - mask (4D tensor) : Generated occlusion mask. + """ + downsample = self.down_lbl(label) + self.down_img(img_prev) + res = self.res_flow(downsample) + flow_feat = self.up_flow(res) + flow = self.conv_flow(flow_feat) * self.flow_output_multiplier + mask = self.conv_mask(flow_feat) + return flow, mask diff --git a/imaginaire/generators/wc_vid2vid.py b/imaginaire/generators/wc_vid2vid.py new file mode 100644 index 0000000000000000000000000000000000000000..4d8139a3c05140a567ef202eccd1d9209ba681b4 --- /dev/null +++ b/imaginaire/generators/wc_vid2vid.py @@ -0,0 +1,354 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +import numpy as np +import torch +import torch.nn.functional as F +from torchvision import transforms + +from imaginaire.config import Config +from imaginaire.generators.vid2vid import Generator as Vid2VidGenerator +from imaginaire.model_utils.fs_vid2vid import resample +from imaginaire.model_utils.wc_vid2vid.render import SplatRenderer +from imaginaire.utils.trainer import (get_model_optimizer_and_scheduler, + get_trainer) +from imaginaire.utils.visualization import tensor2im + + +class Generator(Vid2VidGenerator): + r"""world consistent vid2vid generator constructor. + + Args: + gen_cfg (obj): Generator definition part of the yaml config file. + data_cfg (obj): Data definition part of the yaml config file + """ + + def __init__(self, gen_cfg, data_cfg): + # Guidance options. + self.guidance_cfg = gen_cfg.guidance + self.guidance_only_with_flow = getattr( + self.guidance_cfg, 'only_with_flow', False) + self.guidance_partial_conv = getattr( + self.guidance_cfg, 'partial_conv', False) + + # Splatter for guidance. + self.renderer = SplatRenderer() + self.reset_renderer() + + # Single image model. + self.single_image_model = None + + # Initialize the rest same as vid2vid. 
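+        # The guidance attributes above must be set before calling the parent
+        # constructor, because Vid2VidGenerator.__init__ builds its SPADE
+        # blocks through get_cond_dims() / get_partial(), which this subclass
+        # overrides and which read those attributes.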
+ super().__init__(gen_cfg, data_cfg) + + def _init_single_image_model(self, load_weights=True): + r"""Load single image model, if any.""" + if self.single_image_model is None and \ + hasattr(self.gen_cfg, 'single_image_model'): + print('Using single image model...') + single_image_cfg = Config(self.gen_cfg.single_image_model.config) + + # Init model. + net_G, net_D, opt_G, opt_D, sch_G, sch_D = \ + get_model_optimizer_and_scheduler(single_image_cfg) + + # Init trainer and load checkpoint. + trainer = get_trainer(single_image_cfg, net_G, net_D, + opt_G, opt_D, + sch_G, sch_D, + None, None) + if load_weights: + print('Loading single image model checkpoint') + single_image_ckpt = self.gen_cfg.single_image_model.checkpoint + trainer.load_checkpoint(single_image_cfg, single_image_ckpt) + print('Loaded single image model checkpoint') + + self.single_image_model = net_G.module + self.single_image_model_z = None + + def reset_renderer(self, is_flipped_input=False): + r"""Reset the renderer. + Args: + is_flipped_input (bool): Is the input sequence left-right flipped? + """ + self.renderer.reset() + self.is_flipped_input = is_flipped_input + self.renderer_num_forwards = 0 + self.single_image_model_z = None + + def renderer_update_point_cloud(self, image, point_info): + r"""Update the renderer's color dictionary.""" + if point_info is None or len(point_info) == 0: + return + # print('Updating the renderer.') + _, _, h, w = image.size() + + # Renderer expects (h, w, c) [0-255] RGB image. + if isinstance(image, torch.Tensor): + image = tensor2im(image.detach())[0] + + # Flip this image to correspond to SfM camera pose. + if self.is_flipped_input: + image = np.fliplr(image).copy() + + self.renderer.update_point_cloud(image, point_info) + self.renderer_num_forwards += 1 + + def get_guidance_images_and_masks(self, unprojection): + r"""Do stuff.""" + + resolution = 'w1024xh512' + point_info = unprojection[resolution] + + w, h = resolution.split('x') + w, h = int(w[1:]), int(h[1:]) + + # This returns guidance image in [0-255] RGB. + # We will convert it into Tensor repr. below. + guidance_image, guidance_mask = self.renderer.render_image( + point_info, w, h, return_mask=True) + + # If mask is None, there is no guidance. + # print(np.sum(guidance_mask), guidance_mask.size) + # if np.sum(guidance_mask) == 0: + # return None, point_info + + # Flip guidance image and guidance mask if needed. + if self.is_flipped_input: + guidance_image = np.fliplr(guidance_image).copy() + guidance_mask = np.fliplr(guidance_mask).copy() + + # Go from (h, w, c) to (1, c, h, w). + # Convert guidance image to Tensor. + guidance_image = (transforms.ToTensor()(guidance_image) - 0.5) * 2 + guidance_mask = transforms.ToTensor()(guidance_mask) + guidance = torch.cat((guidance_image, guidance_mask), dim=0) + guidance = guidance.unsqueeze(0).cuda() + + # Save guidance at all resolutions. + guidance_images_and_masks = guidance + + return guidance_images_and_masks, point_info + + def forward(self, data): + r"""vid2vid generator forward. + Args: + data (dict) : Dictionary of input data. + Returns: + output (dict) : Dictionary of output data. + """ + self._init_single_image_model() + + label = data['label'] + unprojection = data['unprojection'] + label_prev, img_prev = data['prev_labels'], data['prev_images'] + is_first_frame = img_prev is None + z = getattr(data, 'z', None) + bs, _, h, w = label.size() + + # Whether to warp the previous frame or not. 
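+        # Warping the previous frame is only possible once the temporal
+        # network has been initialized and a full window of num_frames_G - 1
+        # previous labels/frames is available.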
+ flow = mask = img_warp = None + warp_prev = self.temporal_initialized and not is_first_frame and \ + label_prev.shape[1] == self.num_frames_G - 1 + + # Get guidance images and masks. + guidance_images_and_masks, point_info = None, None + if unprojection is not None: + guidance_images_and_masks, point_info = \ + self.get_guidance_images_and_masks(unprojection) + + # Get SPADE conditional maps by embedding current label input. + cond_maps_now = self.get_cond_maps(label, self.label_embedding) + + # Use single image model, if flow features are not available. + # Guidance features are used whenever flow features are available. + if self.single_image_model is not None and not warp_prev: + # Get z vector for single image model. + if self.single_image_model_z is None: + bs = data['label'].size(0) + z = torch.randn(bs, self.single_image_model.style_dims, + dtype=torch.float32).cuda() + if data['label'].dtype == torch.float16: + z = z.half() + self.single_image_model_z = z + + # Get output image. + data['z'] = self.single_image_model_z + self.single_image_model.eval() + with torch.no_grad(): + output = self.single_image_model.spade_generator(data) + img_final = output['fake_images'].detach() + fake_images_source = 'pretrained' + else: + # Input to the generator will either be noise/segmentation map (for + # first frame) or encoded previous frame (for subsequent frames). + if is_first_frame: + # First frame in the sequence, start from scratch. + if self.use_segmap_as_input: + x_img = F.interpolate(label, size=(self.sh, self.sw)) + x_img = self.fc(x_img) + else: + if z is None: + z = torch.randn(bs, self.z_dim, dtype=label.dtype, + device=label.get_device()).fill_(0) + x_img = self.fc(z).view(bs, -1, self.sh, self.sw) + + # Upsampling layers. + for i in range(self.num_layers, self.num_downsamples_img, -1): + j = min(self.num_downsamples_embed, i) + x_img = getattr(self, 'up_' + str(i) + )(x_img, *cond_maps_now[j]) + x_img = self.upsample(x_img) + else: + # Not the first frame, will encode the previous frame and feed + # to the generator. + x_img = self.down_first(img_prev[:, -1]) + + # Get label embedding for the previous frame. + cond_maps_prev = self.get_cond_maps(label_prev[:, -1], + self.label_embedding) + + # Downsampling layers. + for i in range(self.num_downsamples_img + 1): + j = min(self.num_downsamples_embed, i) + x_img = getattr(self, 'down_' + str(i))(x_img, + *cond_maps_prev[j]) + if i != self.num_downsamples_img: + x_img = self.downsample(x_img) + + # Resnet blocks. + j = min(self.num_downsamples_embed, + self.num_downsamples_img + 1) + for i in range(self.num_res_blocks): + cond_maps = cond_maps_prev[j] if \ + i < self.num_res_blocks // 2 else cond_maps_now[j] + x_img = getattr(self, 'res_' + str(i))(x_img, *cond_maps) + + # Optical flow warped image features. + if warp_prev: + # Estimate flow & mask. + label_concat = torch.cat([label_prev.view(bs, -1, h, w), + label], dim=1) + img_prev_concat = img_prev.view(bs, -1, h, w) + flow, mask = self.flow_network_temp( + label_concat, img_prev_concat) + img_warp = resample(img_prev[:, -1], flow) + if self.spade_combine: + # if using SPADE combine, integrate the warped image (and + # occlusion mask) into conditional inputs for SPADE. + img_embed = torch.cat([img_warp, mask], dim=1) + cond_maps_img = self.get_cond_maps(img_embed, + self.img_prev_embedding) + x_raw_img = None + + # Main image generation branch. + for i in range(self.num_downsamples_img, -1, -1): + # Get SPADE conditional inputs. 
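+                # cond_maps starts from the current label embedding at this
+                # resolution; flow-warped features and the splatted guidance
+                # image/mask are appended below, depending on warp_prev and
+                # the guidance settings.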
+ j = min(i, self.num_downsamples_embed) + cond_maps = cond_maps_now[j] + + # For raw output generation. + if self.generate_raw_output: + if i >= self.num_multi_spade_layers - 1: + x_raw_img = x_img + if i < self.num_multi_spade_layers: + x_raw_img = self.one_up_conv_layer( + x_raw_img, cond_maps, i) + + # Add flow and guidance features. + if warp_prev: + if i < self.num_multi_spade_layers: + # Add flow. + cond_maps += cond_maps_img[j] + # Add guidance. + if guidance_images_and_masks is not None: + cond_maps += [guidance_images_and_masks] + elif not self.guidance_only_with_flow: + # Add guidance if it is to be applied to every layer. + if guidance_images_and_masks is not None: + cond_maps += [guidance_images_and_masks] + + x_img = self.one_up_conv_layer(x_img, cond_maps, i) + + # Final conv layer. + img_final = torch.tanh(self.conv_img(x_img)) + fake_images_source = 'in_training' + + # Update the point cloud color dict of renderer. + self.renderer_update_point_cloud(img_final, point_info) + + output = dict() + output['fake_images'] = img_final + output['fake_flow_maps'] = flow + output['fake_occlusion_masks'] = mask + output['fake_raw_images'] = None + output['warped_images'] = img_warp + output['guidance_images_and_masks'] = guidance_images_and_masks + output['fake_images_source'] = fake_images_source + return output + + def get_cond_dims(self, num_downs=0): + r"""Get the dimensions of conditional inputs. + Args: + num_downs (int) : How many downsamples at current layer. + Returns: + ch (list) : List of dimensions. + """ + if not self.use_embed: + ch = [self.num_input_channels] + else: + num_filters = getattr(self.emb_cfg, 'num_filters', 32) + num_downs = min(num_downs, self.num_downsamples_embed) + ch = [min(self.max_num_filters, num_filters * (2 ** num_downs))] + if (num_downs < self.num_multi_spade_layers): + ch = ch * 2 + # Also add guidance (RGB + mask = 4 channels, or 3 if partial). + if self.guidance_partial_conv: + ch.append(3) + else: + ch.append(4) + elif not self.guidance_only_with_flow: + if self.guidance_partial_conv: + ch.append(3) + else: + ch.append(4) + return ch + + def get_partial(self, num_downs=0): + r"""Get if convs should be partial or not. + Args: + num_downs (int) : How many downsamples at current layer. + Returns: + partial (list) : List of boolean partial or not. + """ + partial = [False] + if (num_downs < self.num_multi_spade_layers): + partial = partial * 2 + # Also add guidance (RGB + mask = 4 channels, or 3 if partial). + if self.guidance_partial_conv: + partial.append(True) + else: + partial.append(False) + elif not self.guidance_only_with_flow: + if self.guidance_partial_conv: + partial.append(True) + else: + partial.append(False) + return partial + + def get_cond_maps(self, label, embedder): + r"""Get the conditional inputs. + Args: + label (4D tensor) : Input label tensor. + embedder (obj) : Embedding network. + Returns: + cond_maps (list) : List of conditional inputs. + """ + if not self.use_embed: + return [label] * (self.num_layers + 1) + embedded_label = embedder(label) + cond_maps = [embedded_label] + cond_maps = [[m[i] for m in cond_maps] for i in + range(len(cond_maps[0]))] + return cond_maps diff --git a/imaginaire/layers/__init__.py b/imaginaire/layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9e3f93c154678b630de93ab3d6b199204c4fd8fb --- /dev/null +++ b/imaginaire/layers/__init__.py @@ -0,0 +1,27 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +from .conv import LinearBlock, Conv1dBlock, Conv2dBlock, Conv3dBlock, \ + HyperConv2dBlock, MultiOutConv2dBlock, \ + PartialConv2dBlock, PartialConv3dBlock +from .residual import ResLinearBlock, Res1dBlock, Res2dBlock, Res3dBlock, \ + HyperRes2dBlock, MultiOutRes2dBlock, UpRes2dBlock, DownRes2dBlock, \ + PartialRes2dBlock, PartialRes3dBlock +from .non_local import NonLocal2dBlock + +__all__ = ['Conv1dBlock', 'Conv2dBlock', 'Conv3dBlock', 'LinearBlock', + 'HyperConv2dBlock', 'MultiOutConv2dBlock', + 'PartialConv2dBlock', 'PartialConv3dBlock', + 'Res1dBlock', 'Res2dBlock', 'Res3dBlock', + 'UpRes2dBlock', 'DownRes2dBlock', + 'ResLinearBlock', 'HyperRes2dBlock', 'MultiOutRes2dBlock', + 'PartialRes2dBlock', 'PartialRes3dBlock', + 'NonLocal2dBlock'] + +try: + from .repvgg import RepVGG1dBlock, RepVGG2dBlock, RepVGG3dBlock + from .attn import MultiheadAttention + __all__.extend(['RepVGG1dBlock', 'RepVGG2dBlock', 'RepVGG3dBlock']) +except: # noqa + pass diff --git a/imaginaire/layers/__pycache__/__init__.cpython-38.pyc b/imaginaire/layers/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..528061b6276196a6fdf6c254450dfc5f9f9f6a89 Binary files /dev/null and b/imaginaire/layers/__pycache__/__init__.cpython-38.pyc differ diff --git a/imaginaire/layers/__pycache__/activation_norm.cpython-38.pyc b/imaginaire/layers/__pycache__/activation_norm.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..754bbd93d5840bc721eaaede366f15c5903d286b Binary files /dev/null and b/imaginaire/layers/__pycache__/activation_norm.cpython-38.pyc differ diff --git a/imaginaire/layers/__pycache__/conv.cpython-38.pyc b/imaginaire/layers/__pycache__/conv.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..da61b20f0c5a954d150297691572cfcae2b75bee Binary files /dev/null and b/imaginaire/layers/__pycache__/conv.cpython-38.pyc differ diff --git a/imaginaire/layers/__pycache__/misc.cpython-38.pyc b/imaginaire/layers/__pycache__/misc.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e0cf730aee80241b8f869fd876e4e43102a25790 Binary files /dev/null and b/imaginaire/layers/__pycache__/misc.cpython-38.pyc differ diff --git a/imaginaire/layers/__pycache__/non_local.cpython-38.pyc b/imaginaire/layers/__pycache__/non_local.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9333890d786503983cdaf4be1fb7458fbc215ea0 Binary files /dev/null and b/imaginaire/layers/__pycache__/non_local.cpython-38.pyc differ diff --git a/imaginaire/layers/__pycache__/nonlinearity.cpython-38.pyc b/imaginaire/layers/__pycache__/nonlinearity.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4a3a57d74453c76df61eaed8c0a4294479611dc2 Binary files /dev/null and b/imaginaire/layers/__pycache__/nonlinearity.cpython-38.pyc differ diff --git a/imaginaire/layers/__pycache__/residual.cpython-38.pyc b/imaginaire/layers/__pycache__/residual.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a64ea7db7e342783ae70a6348e6278f3fc030fc5 Binary files /dev/null and b/imaginaire/layers/__pycache__/residual.cpython-38.pyc differ diff --git a/imaginaire/layers/__pycache__/weight_norm.cpython-38.pyc b/imaginaire/layers/__pycache__/weight_norm.cpython-38.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..c4fffd5e3452b5dbedacfeada43c7c969a7026b7 Binary files /dev/null and b/imaginaire/layers/__pycache__/weight_norm.cpython-38.pyc differ diff --git a/imaginaire/layers/activation_norm.py b/imaginaire/layers/activation_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..709a928eed76f456c6284f575ba47f2dff581b39 --- /dev/null +++ b/imaginaire/layers/activation_norm.py @@ -0,0 +1,629 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +# flake8: noqa E722 +from types import SimpleNamespace + +import torch + +try: + from torch.nn import SyncBatchNorm +except ImportError: + from torch.nn import BatchNorm2d as SyncBatchNorm +from torch import nn +from torch.nn import functional as F +from .conv import LinearBlock, Conv2dBlock, HyperConv2d, PartialConv2dBlock +from .misc import PartialSequential, ApplyNoise + + +class AdaptiveNorm(nn.Module): + r"""Adaptive normalization layer. The layer first normalizes the input, then + performs an affine transformation using parameters computed from the + conditional inputs. + + Args: + num_features (int): Number of channels in the input tensor. + cond_dims (int): Number of channels in the conditional inputs. + weight_norm_type (str): Type of weight normalization. + ``'none'``, ``'spectral'``, ``'weight'``, or ``'weight_demod'``. + projection (bool): If ``True``, project the conditional input to gamma + and beta using a fully connected layer, otherwise directly use + the conditional input as gamma and beta. + projection_bias (bool) If ``True``, use bias in the fully connected + projection layer. + separate_projection (bool): If ``True``, we will use two different + layers for gamma and beta. Otherwise, we will use one layer. It + matters only if you apply any weight norms to this layer. + input_dim (int): Number of dimensions of the input tensor. + activation_norm_type (str): + Type of activation normalization. + ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``, + ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``, + ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``. + activation_norm_params (obj, optional, default=None): + Parameters of activation normalization. + If not ``None``, ``activation_norm_params.__dict__`` will be used as + keyword arguments when initializing activation normalization. 
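+
+    Minimal illustrative sketch (the argument values and tensor shapes below
+    are assumptions, not from the original documentation):
+
+        norm = AdaptiveNorm(num_features=64, cond_dims=128)
+        x = torch.randn(4, 64, 32, 32)   # N x C1 x H x W input
+        y = torch.randn(4, 128)          # N x C2 conditional code
+        out = norm(x, y)                 # same shape as x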
+ """ + + def __init__(self, num_features, cond_dims, weight_norm_type='', + projection=True, + projection_bias=True, + separate_projection=False, + input_dim=2, + activation_norm_type='instance', + activation_norm_params=None, + apply_noise=False, + add_bias=True, + input_scale=1.0, + init_gain=1.0): + super().__init__() + if activation_norm_params is None: + activation_norm_params = SimpleNamespace(affine=False) + self.norm = get_activation_norm_layer(num_features, + activation_norm_type, + input_dim, + **vars(activation_norm_params)) + if apply_noise: + self.noise_layer = ApplyNoise() + else: + self.noise_layer = None + + if projection: + if separate_projection: + self.fc_gamma = \ + LinearBlock(cond_dims, num_features, + weight_norm_type=weight_norm_type, + bias=projection_bias) + self.fc_beta = \ + LinearBlock(cond_dims, num_features, + weight_norm_type=weight_norm_type, + bias=projection_bias) + else: + self.fc = LinearBlock(cond_dims, num_features * 2, + weight_norm_type=weight_norm_type, + bias=projection_bias) + + self.projection = projection + self.separate_projection = separate_projection + self.input_scale = input_scale + self.add_bias = add_bias + self.conditional = True + self.init_gain = init_gain + + def forward(self, x, y, noise=None, **_kwargs): + r"""Adaptive Normalization forward. + + Args: + x (N x C1 x * tensor): Input tensor. + y (N x C2 tensor): Conditional information. + Returns: + out (N x C1 x * tensor): Output tensor. + """ + y = y * self.input_scale + if self.projection: + if self.separate_projection: + gamma = self.fc_gamma(y) + beta = self.fc_beta(y) + for _ in range(x.dim() - gamma.dim()): + gamma = gamma.unsqueeze(-1) + beta = beta.unsqueeze(-1) + else: + y = self.fc(y) + for _ in range(x.dim() - y.dim()): + y = y.unsqueeze(-1) + gamma, beta = y.chunk(2, 1) + else: + for _ in range(x.dim() - y.dim()): + y = y.unsqueeze(-1) + gamma, beta = y.chunk(2, 1) + if self.norm is not None: + x = self.norm(x) + if self.noise_layer is not None: + x = self.noise_layer(x, noise=noise) + if self.add_bias: + x = torch.addcmul(beta, x, 1 + gamma) + return x + else: + return x * (1 + gamma), beta.squeeze(3).squeeze(2) + + +class SpatiallyAdaptiveNorm(nn.Module): + r"""Spatially Adaptive Normalization (SPADE) initialization. + + Args: + num_features (int) : Number of channels in the input tensor. + cond_dims (int or list of int) : List of numbers of channels + in the input. + num_filters (int): Number of filters in SPADE. + kernel_size (int): Kernel size of the convolutional filters in + the SPADE layer. + weight_norm_type (str): Type of weight normalization. + ``'none'``, ``'spectral'``, or ``'weight'``. + separate_projection (bool): If ``True``, we will use two different + layers for gamma and beta. Otherwise, we will use one layer. It + matters only if you apply any weight norms to this layer. + activation_norm_type (str): + Type of activation normalization. + ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``, + ``'layer'``, ``'layer_2d'``, ``'group'``. + activation_norm_params (obj, optional, default=None): + Parameters of activation normalization. + If not ``None``, ``activation_norm_params.__dict__`` will be used as + keyword arguments when initializing activation normalization. 
+ """ + + def __init__(self, + num_features, + cond_dims, + num_filters=128, + kernel_size=3, + weight_norm_type='', + separate_projection=False, + activation_norm_type='sync_batch', + activation_norm_params=None, + bias_only=False, + partial=False, + interpolation='nearest'): + super().__init__() + if activation_norm_params is None: + activation_norm_params = SimpleNamespace(affine=False) + padding = kernel_size // 2 + self.separate_projection = separate_projection + self.mlps = nn.ModuleList() + self.gammas = nn.ModuleList() + self.betas = nn.ModuleList() + self.bias_only = bias_only + self.interpolation = interpolation + + # Make cond_dims a list. + if type(cond_dims) != list: + cond_dims = [cond_dims] + + # Make num_filters a list. + if not isinstance(num_filters, list): + num_filters = [num_filters] * len(cond_dims) + else: + assert len(num_filters) >= len(cond_dims) + + # Make partial a list. + if not isinstance(partial, list): + partial = [partial] * len(cond_dims) + else: + assert len(partial) >= len(cond_dims) + + for i, cond_dim in enumerate(cond_dims): + mlp = [] + conv_block = PartialConv2dBlock if partial[i] else Conv2dBlock + sequential = PartialSequential if partial[i] else nn.Sequential + + if num_filters[i] > 0: + mlp += [conv_block(cond_dim, + num_filters[i], + kernel_size, + padding=padding, + weight_norm_type=weight_norm_type, + nonlinearity='relu')] + mlp_ch = cond_dim if num_filters[i] == 0 else num_filters[i] + + if self.separate_projection: + if partial[i]: + raise NotImplementedError( + 'Separate projection not yet implemented for ' + + 'partial conv') + self.mlps.append(nn.Sequential(*mlp)) + self.gammas.append( + conv_block(mlp_ch, num_features, + kernel_size, + padding=padding, + weight_norm_type=weight_norm_type)) + self.betas.append( + conv_block(mlp_ch, num_features, + kernel_size, + padding=padding, + weight_norm_type=weight_norm_type)) + else: + mlp += [conv_block(mlp_ch, num_features * 2, kernel_size, + padding=padding, + weight_norm_type=weight_norm_type)] + self.mlps.append(sequential(*mlp)) + + self.norm = get_activation_norm_layer(num_features, + activation_norm_type, + 2, + **vars(activation_norm_params)) + self.conditional = True + + def forward(self, x, *cond_inputs, **_kwargs): + r"""Spatially Adaptive Normalization (SPADE) forward. + + Args: + x (N x C1 x H x W tensor) : Input tensor. + cond_inputs (list of tensors) : Conditional maps for SPADE. + Returns: + output (4D tensor) : Output tensor. 
+ """ + output = self.norm(x) if self.norm is not None else x + for i in range(len(cond_inputs)): + if cond_inputs[i] is None: + continue + label_map = F.interpolate(cond_inputs[i], size=x.size()[2:], mode=self.interpolation) + if self.separate_projection: + hidden = self.mlps[i](label_map) + gamma = self.gammas[i](hidden) + beta = self.betas[i](hidden) + else: + affine_params = self.mlps[i](label_map) + gamma, beta = affine_params.chunk(2, dim=1) + if self.bias_only: + output = output + beta + else: + output = output * (1 + gamma) + beta + return output + + +class DualAdaptiveNorm(nn.Module): + def __init__(self, + num_features, + cond_dims, + projection_bias=True, + weight_norm_type='', + activation_norm_type='instance', + activation_norm_params=None, + apply_noise=False, + bias_only=False, + init_gain=1.0, + fc_scale=None, + is_spatial=None): + super().__init__() + if activation_norm_params is None: + activation_norm_params = SimpleNamespace(affine=False) + self.mlps = nn.ModuleList() + self.gammas = nn.ModuleList() + self.betas = nn.ModuleList() + self.bias_only = bias_only + + # Make cond_dims a list. + if type(cond_dims) != list: + cond_dims = [cond_dims] + + if is_spatial is None: + is_spatial = [False for _ in range(len(cond_dims))] + self.is_spatial = is_spatial + + for cond_dim, this_is_spatial in zip(cond_dims, is_spatial): + kwargs = dict(weight_norm_type=weight_norm_type, + bias=projection_bias, + init_gain=init_gain, + output_scale=fc_scale) + if this_is_spatial: + self.gammas.append(Conv2dBlock(cond_dim, num_features, 1, 1, 0, **kwargs)) + self.betas.append(Conv2dBlock(cond_dim, num_features, 1, 1, 0, **kwargs)) + else: + self.gammas.append(LinearBlock(cond_dim, num_features, **kwargs)) + self.betas.append(LinearBlock(cond_dim, num_features, **kwargs)) + + self.norm = get_activation_norm_layer(num_features, + activation_norm_type, + 2, + **vars(activation_norm_params)) + self.conditional = True + + def forward(self, x, *cond_inputs, **_kwargs): + assert len(cond_inputs) == len(self.gammas) + output = self.norm(x) if self.norm is not None else x + for cond, gamma_layer, beta_layer in zip(cond_inputs, self.gammas, self.betas): + if cond is None: + continue + gamma = gamma_layer(cond) + beta = beta_layer(cond) + if cond.dim() == 4 and gamma.shape != x.shape: + gamma = F.interpolate(gamma, size=x.size()[2:], mode='bilinear') + beta = F.interpolate(beta, size=x.size()[2:], mode='bilinear') + elif cond.dim() == 2: + gamma = gamma[:, :, None, None] + beta = beta[:, :, None, None] + if self.bias_only: + output = output + beta + else: + output = output * (1 + gamma) + beta + return output + + +class HyperSpatiallyAdaptiveNorm(nn.Module): + r"""Spatially Adaptive Normalization (SPADE) initialization. + + Args: + num_features (int) : Number of channels in the input tensor. + cond_dims (int or list of int) : List of numbers of channels + in the conditional input. + num_filters (int): Number of filters in SPADE. + kernel_size (int): Kernel size of the convolutional filters in + the SPADE layer. + weight_norm_type (str): Type of weight normalization. + ``'none'``, ``'spectral'``, or ``'weight'``. + activation_norm_type (str): + Type of activation normalization. + ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``, + ``'layer'``, ``'layer_2d'``, ``'group'``. + is_hyper (bool): Whether to use hyper SPADE. 
+ """ + + def __init__(self, num_features, cond_dims, + num_filters=0, kernel_size=3, + weight_norm_type='', + activation_norm_type='sync_batch', is_hyper=True): + super().__init__() + padding = kernel_size // 2 + self.mlps = nn.ModuleList() + if type(cond_dims) != list: + cond_dims = [cond_dims] + + for i, cond_dim in enumerate(cond_dims): + mlp = [] + if not is_hyper or (i != 0): + if num_filters > 0: + mlp += [Conv2dBlock(cond_dim, num_filters, kernel_size, + padding=padding, + weight_norm_type=weight_norm_type, + nonlinearity='relu')] + mlp_ch = cond_dim if num_filters == 0 else num_filters + mlp += [Conv2dBlock(mlp_ch, num_features * 2, kernel_size, + padding=padding, + weight_norm_type=weight_norm_type)] + mlp = nn.Sequential(*mlp) + else: + if num_filters > 0: + raise ValueError('Multi hyper layer not supported yet.') + mlp = HyperConv2d(padding=padding) + self.mlps.append(mlp) + + self.norm = get_activation_norm_layer(num_features, + activation_norm_type, + 2, + affine=False) + + self.conditional = True + + def forward(self, x, *cond_inputs, + norm_weights=(None, None), **_kwargs): + r"""Spatially Adaptive Normalization (SPADE) forward. + + Args: + x (4D tensor) : Input tensor. + cond_inputs (list of tensors) : Conditional maps for SPADE. + norm_weights (5D tensor or list of tensors): conv weights or + [weights, biases]. + Returns: + output (4D tensor) : Output tensor. + """ + output = self.norm(x) + for i in range(len(cond_inputs)): + if cond_inputs[i] is None: + continue + if type(cond_inputs[i]) == list: + cond_input, mask = cond_inputs[i] + mask = F.interpolate(mask, size=x.size()[2:], mode='bilinear', align_corners=False) + else: + cond_input = cond_inputs[i] + mask = None + label_map = F.interpolate(cond_input, size=x.size()[2:]) + if norm_weights is None or norm_weights[0] is None or i != 0: + affine_params = self.mlps[i](label_map) + else: + affine_params = self.mlps[i](label_map, + conv_weights=norm_weights) + gamma, beta = affine_params.chunk(2, dim=1) + if mask is not None: + gamma = gamma * (1 - mask) + beta = beta * (1 - mask) + output = output * (1 + gamma) + beta + return output + + +class LayerNorm2d(nn.Module): + r"""Layer Normalization as introduced in + https://arxiv.org/abs/1607.06450. + This is the usual way to apply layer normalization in CNNs. + Note that unlike the pytorch implementation which applies per-element + scale and bias, here it applies per-channel scale and bias, similar to + batch/instance normalization. + + Args: + num_features (int): Number of channels in the input tensor. + eps (float, optional, default=1e-5): a value added to the + denominator for numerical stability. + affine (bool, optional, default=False): If ``True``, performs + affine transformation after normalization. + """ + + def __init__(self, num_features, eps=1e-5, channel_only=False, affine=True): + super(LayerNorm2d, self).__init__() + self.num_features = num_features + self.affine = affine + self.eps = eps + self.channel_only = channel_only + + if self.affine: + self.gamma = nn.Parameter(torch.Tensor(num_features).fill_(1.0)) + self.beta = nn.Parameter(torch.zeros(num_features)) + + def forward(self, x): + r""" + + Args: + x (tensor): Input tensor. 
+ """ + shape = [-1] + [1] * (x.dim() - 1) + if self.channel_only: + mean = x.mean(1, keepdim=True) + std = x.std(1, keepdim=True) + else: + mean = x.view(x.size(0), -1).mean(1).view(*shape) + std = x.view(x.size(0), -1).std(1).view(*shape) + + x = (x - mean) / (std + self.eps) + + if self.affine: + shape = [1, -1] + [1] * (x.dim() - 2) + x = x * self.gamma.view(*shape) + self.beta.view(*shape) + return x + + +class ScaleNorm(nn.Module): + r"""Scale normalization: + "Transformers without Tears: Improving the Normalization of Self-Attention" + Modified from: + https://github.com/tnq177/transformers_without_tears + """ + + def __init__(self, dim=-1, learned_scale=True, eps=1e-5): + super().__init__() + # scale = num_features ** 0.5 + if learned_scale: + self.scale = nn.Parameter(torch.tensor(1.)) + else: + self.scale = 1. + # self.num_features = num_features + self.dim = dim + self.eps = eps + self.learned_scale = learned_scale + + def forward(self, x): + # noinspection PyArgumentList + scale = self.scale * torch.rsqrt(torch.mean(x ** 2, dim=self.dim, keepdim=True) + self.eps) + return x * scale + + def extra_repr(self): + s = 'learned_scale={learned_scale}' + return s.format(**self.__dict__) + + +class PixelNorm(ScaleNorm): + def __init__(self, learned_scale=False, eps=1e-5, **_kwargs): + super().__init__(1, learned_scale, eps) + + +class SplitMeanStd(nn.Module): + def __init__(self, num_features, eps=1e-5, **kwargs): + super().__init__() + self.num_features = num_features + self.eps = eps + self.multiple_outputs = True + + def forward(self, x): + b, c, h, w = x.size() + mean = x.view(b, c, -1).mean(-1)[:, :, None, None] + var = x.view(b, c, -1).var(-1)[:, :, None, None] + std = torch.sqrt(var + self.eps) + + # x = (x - mean) / std + return x, torch.cat((mean, std), dim=1) + + +class ScaleNorm(nn.Module): + r"""Scale normalization: + "Transformers without Tears: Improving the Normalization of Self-Attention" + Modified from: + https://github.com/tnq177/transformers_without_tears + """ + + def __init__(self, dim=-1, learned_scale=True, eps=1e-5): + super().__init__() + # scale = num_features ** 0.5 + if learned_scale: + self.scale = nn.Parameter(torch.tensor(1.)) + else: + self.scale = 1. + # self.num_features = num_features + self.dim = dim + self.eps = eps + self.learned_scale = learned_scale + + def forward(self, x): + # noinspection PyArgumentList + scale = self.scale * torch.rsqrt( + torch.mean(x ** 2, dim=self.dim, keepdim=True) + self.eps) + return x * scale + + def extra_repr(self): + s = 'learned_scale={learned_scale}' + return s.format(**self.__dict__) + + +class PixelLayerNorm(nn.Module): + def __init__(self, *args, **kwargs): + super().__init__() + self.norm = nn.LayerNorm(*args, **kwargs) + + def forward(self, x): + if x.dim() == 4: + b, c, h, w = x.shape + return self.norm(x.permute(0, 2, 3, 1).view(-1, c)).view(b, h, w, c).permute(0, 3, 1, 2) + else: + return self.norm(x) + + +def get_activation_norm_layer(num_features, norm_type, input_dim, **norm_params): + r"""Return an activation normalization layer. + + Args: + num_features (int): Number of feature channels. + norm_type (str): + Type of activation normalization. + ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``, + ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``, + ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``. + input_dim (int): Number of input dimensions. + norm_params: Arbitrary keyword arguments that will be used to + initialize the activation normalization. 
+ """ + input_dim = max(input_dim, 1) # Norm1d works with both 0d and 1d inputs + + if norm_type == 'none' or norm_type == '': + norm_layer = None + elif norm_type == 'batch': + norm = getattr(nn, 'BatchNorm%dd' % input_dim) + norm_layer = norm(num_features, **norm_params) + elif norm_type == 'instance': + affine = norm_params.pop('affine', True) # Use affine=True by default + norm = getattr(nn, 'InstanceNorm%dd' % input_dim) + norm_layer = norm(num_features, affine=affine, **norm_params) + elif norm_type == 'sync_batch': + norm_layer = SyncBatchNorm(num_features, **norm_params) + elif norm_type == 'layer': + norm_layer = nn.LayerNorm(num_features, **norm_params) + elif norm_type == 'layer_2d': + norm_layer = LayerNorm2d(num_features, **norm_params) + elif norm_type == 'pixel_layer': + elementwise_affine = norm_params.pop('affine', True) # Use affine=True by default + norm_layer = PixelLayerNorm(num_features, elementwise_affine=elementwise_affine, **norm_params) + elif norm_type == 'scale': + norm_layer = ScaleNorm(**norm_params) + elif norm_type == 'pixel': + norm_layer = PixelNorm(**norm_params) + import imaginaire.config + if imaginaire.config.USE_JIT: + norm_layer = torch.jit.script(norm_layer) + elif norm_type == 'group': + num_groups = norm_params.pop('num_groups', 4) + norm_layer = nn.GroupNorm(num_channels=num_features, num_groups=num_groups, **norm_params) + elif norm_type == 'adaptive': + norm_layer = AdaptiveNorm(num_features, **norm_params) + elif norm_type == 'dual_adaptive': + norm_layer = DualAdaptiveNorm(num_features, **norm_params) + elif norm_type == 'spatially_adaptive': + if input_dim != 2: + raise ValueError('Spatially adaptive normalization layers ' + 'only supports 2D input') + norm_layer = SpatiallyAdaptiveNorm(num_features, **norm_params) + elif norm_type == 'hyper_spatially_adaptive': + if input_dim != 2: + raise ValueError('Spatially adaptive normalization layers ' + 'only supports 2D input') + norm_layer = HyperSpatiallyAdaptiveNorm(num_features, **norm_params) + else: + raise ValueError('Activation norm layer %s ' + 'is not recognized' % norm_type) + return norm_layer diff --git a/imaginaire/layers/conv.py b/imaginaire/layers/conv.py new file mode 100644 index 0000000000000000000000000000000000000000..499fc0442b77e3183225c3529a4e3590dab0bc57 --- /dev/null +++ b/imaginaire/layers/conv.py @@ -0,0 +1,1377 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +import warnings +from types import SimpleNamespace + +import torch +from torch import nn +from torch.nn import functional as F + +from .misc import ApplyNoise +from imaginaire.third_party.upfirdn2d.upfirdn2d import Blur + + +class _BaseConvBlock(nn.Module): + r"""An abstract wrapper class that wraps a torch convolution or linear layer + with normalization and nonlinearity. 
+ """ + + def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode, + weight_norm_type, weight_norm_params, activation_norm_type, activation_norm_params, nonlinearity, + inplace_nonlinearity, apply_noise, blur, order, input_dim, clamp, blur_kernel, output_scale, + init_gain): + super().__init__() + from .nonlinearity import get_nonlinearity_layer + from .weight_norm import get_weight_norm_layer + from .activation_norm import get_activation_norm_layer + self.weight_norm_type = weight_norm_type + self.stride = stride + self.clamp = clamp + self.init_gain = init_gain + + # Nonlinearity layer. + if 'fused' in nonlinearity: + # Fusing nonlinearity with bias. + lr_mul = getattr(weight_norm_params, 'lr_mul', 1) + conv_before_nonlinearity = order.find('C') < order.find('A') + if conv_before_nonlinearity: + assert bias is True + bias = False + channel = out_channels if conv_before_nonlinearity else in_channels + nonlinearity_layer = get_nonlinearity_layer( + nonlinearity, inplace=inplace_nonlinearity, + num_channels=channel, lr_mul=lr_mul) + else: + nonlinearity_layer = get_nonlinearity_layer( + nonlinearity, inplace=inplace_nonlinearity) + + # Noise injection layer. + if apply_noise: + order = order.replace('C', 'CG') + noise_layer = ApplyNoise() + else: + noise_layer = None + + # Convolutional layer. + if blur: + assert blur_kernel is not None + if stride == 2: + # Blur - Conv - Noise - Activate + p = (len(blur_kernel) - 2) + (kernel_size - 1) + pad0, pad1 = (p + 1) // 2, p // 2 + padding = 0 + blur_layer = Blur( + blur_kernel, pad=(pad0, pad1), padding_mode=padding_mode + ) + order = order.replace('C', 'BC') + elif stride == 0.5: + # Conv - Blur - Noise - Activate + padding = 0 + p = (len(blur_kernel) - 2) - (kernel_size - 1) + pad0, pad1 = (p + 1) // 2 + 1, p // 2 + 1 + blur_layer = Blur( + blur_kernel, pad=(pad0, pad1), padding_mode=padding_mode + ) + order = order.replace('C', 'CB') + elif stride == 1: + # No blur for now + blur_layer = nn.Identity() + else: + raise NotImplementedError + else: + blur_layer = nn.Identity() + + if weight_norm_params is None: + weight_norm_params = SimpleNamespace() + weight_norm = get_weight_norm_layer( + weight_norm_type, **vars(weight_norm_params)) + conv_layer = weight_norm(self._get_conv_layer( + in_channels, out_channels, kernel_size, stride, padding, dilation, + groups, bias, padding_mode, input_dim)) + + # Normalization layer. + conv_before_norm = order.find('C') < order.find('N') + norm_channels = out_channels if conv_before_norm else in_channels + if activation_norm_params is None: + activation_norm_params = SimpleNamespace() + activation_norm_layer = get_activation_norm_layer( + norm_channels, + activation_norm_type, + input_dim, + **vars(activation_norm_params)) + + # Mapping from operation names to layers. + mappings = {'C': {'conv': conv_layer}, + 'N': {'norm': activation_norm_layer}, + 'A': {'nonlinearity': nonlinearity_layer}} + mappings.update({'B': {'blur': blur_layer}}) + mappings.update({'G': {'noise': noise_layer}}) + + # All layers in order. + self.layers = nn.ModuleDict() + for op in order: + if list(mappings[op].values())[0] is not None: + self.layers.update(mappings[op]) + + # Whether this block expects conditional inputs. + self.conditional = \ + getattr(conv_layer, 'conditional', False) or \ + getattr(activation_norm_layer, 'conditional', False) + + # Scale the output by a learnable scaler parameter. 
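+        # When `output_scale` is None, the parameter is registered as None and
+        # the forward pass applies no extra scaling after the convolution.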
+ if output_scale is not None: + self.output_scale = nn.Parameter(torch.tensor(output_scale)) + else: + self.register_parameter("output_scale", None) + + def forward(self, x, *cond_inputs, **kw_cond_inputs): + r""" + + Args: + x (tensor): Input tensor. + cond_inputs (list of tensors) : Conditional input tensors. + kw_cond_inputs (dict) : Keyword conditional inputs. + """ + for key, layer in self.layers.items(): + if getattr(layer, 'conditional', False): + # Layers that require conditional inputs. + x = layer(x, *cond_inputs, **kw_cond_inputs) + else: + x = layer(x) + if self.clamp is not None and isinstance(layer, nn.Conv2d): + x.clamp_(max=self.clamp) + if key == 'conv': + if self.output_scale is not None: + x = x * self.output_scale + return x + + def _get_conv_layer(self, in_channels, out_channels, kernel_size, stride, + padding, dilation, groups, bias, padding_mode, + input_dim): + # Returns the convolutional layer. + if input_dim == 0: + layer = nn.Linear(in_channels, out_channels, bias) + else: + if stride < 1: # Fractionally-strided convolution. + padding_mode = 'zeros' + assert padding == 0 + layer_type = getattr(nn, f'ConvTranspose{input_dim}d') + stride = round(1 / stride) + else: + layer_type = getattr(nn, f'Conv{input_dim}d') + layer = layer_type( + in_channels, out_channels, kernel_size, stride, padding, + dilation=dilation, groups=groups, bias=bias, + padding_mode=padding_mode + ) + + return layer + + def __repr__(self): + main_str = self._get_name() + '(' + child_lines = [] + for name, layer in self.layers.items(): + mod_str = repr(layer) + if name == 'conv' and self.weight_norm_type != 'none' and \ + self.weight_norm_type != '': + mod_str = mod_str[:-1] + \ + ', weight_norm={}'.format(self.weight_norm_type) + ')' + if name == 'conv' and getattr(layer, 'base_lr_mul', 1) != 1: + mod_str = mod_str[:-1] + \ + ', lr_mul={}'.format(layer.base_lr_mul) + ')' + mod_str = self._addindent(mod_str, 2) + child_lines.append(mod_str) + if len(child_lines) == 1: + main_str += child_lines[0] + else: + main_str += '\n ' + '\n '.join(child_lines) + '\n' + + main_str += ')' + return main_str + + @staticmethod + def _addindent(s_, numSpaces): + s = s_.split('\n') + # don't do anything for single-line stuff + if len(s) == 1: + return s_ + first = s.pop(0) + s = [(numSpaces * ' ') + line for line in s] + s = '\n'.join(s) + s = first + '\n' + s + return s + + +class ModulatedConv2dBlock(_BaseConvBlock): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, + padding=0, dilation=1, groups=1, bias=True, + padding_mode='zeros', + weight_norm_type='none', weight_norm_params=None, + activation_norm_type='none', activation_norm_params=None, + nonlinearity='none', inplace_nonlinearity=False, + apply_noise=True, blur=True, order='CNA', demodulate=True, + eps=True, style_dim=None, clamp=None, blur_kernel=(1, 3, 3, 1), output_scale=None, init_gain=1.0): + self.eps = eps + self.demodulate = demodulate + assert style_dim is not None + + super().__init__(in_channels, out_channels, kernel_size, stride, + padding, dilation, groups, bias, padding_mode, + weight_norm_type, weight_norm_params, + activation_norm_type, activation_norm_params, + nonlinearity, inplace_nonlinearity, apply_noise, blur, + order, 2, clamp, blur_kernel, output_scale, init_gain) + self.modulation = LinearBlock(style_dim, in_channels, + weight_norm_type=weight_norm_type, + weight_norm_params=weight_norm_params) + + def _get_conv_layer(self, in_channels, out_channels, kernel_size, stride, + padding, dilation, groups, bias, 
padding_mode, + input_dim): + assert input_dim == 2 + layer = ModulatedConv2d( + in_channels, out_channels, kernel_size, stride, padding, + dilation, groups, bias, padding_mode, self.demodulate, self.eps) + return layer + + def forward(self, x, *cond_inputs, **kw_cond_inputs): + for layer in self.layers.values(): + if getattr(layer, 'conditional', False): + # Layers that require conditional inputs. + assert len(cond_inputs) == 1 + style = cond_inputs[0] + x = layer( + x, self.modulation(style), **kw_cond_inputs + ) + else: + x = layer(x) + if self.clamp is not None and isinstance(layer, ModulatedConv2d): + x.clamp_(max=self.clamp) + return x + + def __repr__(self): + main_str = self._get_name() + '(' + child_lines = [] + for name, layer in self.layers.items(): + mod_str = repr(layer) + if name == 'conv' and self.weight_norm_type != 'none' and \ + self.weight_norm_type != '': + mod_str = mod_str[:-1] + \ + ', weight_norm={}'.format(self.weight_norm_type) + \ + ', demodulate={}'.format(self.demodulate) + ')' + mod_str = self._addindent(mod_str, 2) + child_lines.append(mod_str) + child_lines.append( + self._addindent('Modulation(' + repr(self.modulation) + ')', 2) + ) + if len(child_lines) == 1: + main_str += child_lines[0] + else: + main_str += '\n ' + '\n '.join(child_lines) + '\n' + + main_str += ')' + return main_str + + +class ModulatedConv2d(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride, padding, + dilation, groups, bias, padding_mode, demodulate=True, + eps=1e-8): + # in_channels, out_channels, kernel_size, stride, padding, + # dilation, groups, bias, padding_mode + assert dilation == 1 and groups == 1 + + super().__init__() + + self.eps = eps + self.kernel_size = kernel_size + self.in_channels = in_channels + self.out_channels = out_channels + self.padding = padding + self.stride = stride + self.padding_mode = padding_mode + # kernel_size // 2 + # assert self.padding == padding + + self.weight = nn.Parameter( + torch.randn(out_channels, in_channels, kernel_size, kernel_size) + ) + + if bias: + self.bias = nn.Parameter(torch.Tensor(out_channels)) + else: + # noinspection PyTypeChecker + self.register_parameter('bias', None) + + # self.modulation = LinearBlock(style_dim, in_channels, + # weight_norm_type=weight_norm_type) + self.demodulate = demodulate + self.conditional = True + + def forward(self, x, style, **_kwargs): + batch, in_channel, height, width = x.shape + + # style = self.modulation(style).view(batch, 1, in_channel, 1, 1) + # We assume the modulation layer is outside this module. 
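+        # The per-sample style vector rescales the input channels of the shared
+        # weight; when `demodulate` is set, every output filter is then divided
+        # by its (approximate) L2 norm, i.e. StyleGAN2-style weight demodulation.
+        # The batch dimension is folded into the channel dimension below so that
+        # a single grouped convolution (groups=batch) applies a different
+        # modulated kernel to every sample.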
+ style = style.view(batch, 1, in_channel, 1, 1) + weight = self.weight.unsqueeze(0) * style + + if self.demodulate: + demod = torch.rsqrt( + weight.pow(2).sum([2, 3, 4]) + self.eps) + weight = weight * demod.view(batch, self.out_channels, 1, 1, 1) + + weight = weight.view( + batch * self.out_channels, + in_channel, self.kernel_size, self.kernel_size + ) + if self.bias is not None: + bias = self.bias.repeat(batch) + else: + bias = self.bias + + x = x.view(1, batch * in_channel, height, width) + + if self.padding_mode != 'zeros': + x = F.pad(x, self._reversed_padding_repeated_twice, + mode=self.padding_mode) + padding = (0, 0) + else: + padding = self.padding + + if self.stride == 0.5: + weight = weight.view( + batch, self.out_channels, in_channel, + self.kernel_size, self.kernel_size + ) + weight = weight.transpose(1, 2).reshape( + batch * in_channel, self.out_channels, + self.kernel_size, self.kernel_size + ) + out = F.conv_transpose2d( + x, weight, bias, padding=padding, stride=2, groups=batch + ) + + elif self.stride == 2: + out = F.conv2d( + x, weight, bias, padding=padding, stride=2, groups=batch + ) + + else: + out = F.conv2d(x, weight, bias, padding=padding, groups=batch) + + _, _, height, width = out.shape + out = out.view(batch, self.out_channels, height, width) + + return out + + def extra_repr(self): + s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}' + ', stride={stride}') + if self.bias is None: + s += ', bias=False' + if self.padding_mode != 'zeros': + s += ', padding_mode={padding_mode}' + return s.format(**self.__dict__) + + +class LinearBlock(_BaseConvBlock): + r"""A Wrapper class that wraps ``torch.nn.Linear`` with normalization and + nonlinearity. + + Args: + in_features (int): Number of channels in the input tensor. + out_features (int): Number of channels in the output tensor. + bias (bool, optional, default=True): + If ``True``, adds a learnable bias to the output. + weight_norm_type (str, optional, default='none'): + Type of weight normalization. + ``'none'``, ``'spectral'``, ``'weight'`` + or ``'weight_demod'``. + weight_norm_params (obj, optional, default=None): + Parameters of weight normalization. + If not ``None``, ``weight_norm_params.__dict__`` will be used as + keyword arguments when initializing weight normalization. + activation_norm_type (str, optional, default='none'): + Type of activation normalization. + ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``, + ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``, + ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``. + activation_norm_params (obj, optional, default=None): + Parameters of activation normalization. + If not ``None``, ``activation_norm_params.__dict__`` will be used as + keyword arguments when initializing activation normalization. + nonlinearity (str, optional, default='none'): + Type of nonlinear activation function. + ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``, + ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``. + inplace_nonlinearity (bool, optional, default=False): If ``True``, + set ``inplace=True`` when initializing the nonlinearity layer. + apply_noise (bool, optional, default=False): If ``True``, add + Gaussian noise with learnable magnitude after the + fully-connected layer. + order (str, optional, default='CNA'): Order of operations. + ``'C'``: fully-connected, + ``'N'``: normalization, + ``'A'``: nonlinear activation. + For example, a block initialized with ``order='CNA'`` will + do convolution first, then normalization, then nonlinearity. 
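+
+    Example:
+        A minimal illustrative usage:
+
+        >>> fc = LinearBlock(256, 512, nonlinearity='relu')
+        >>> y = fc(torch.randn(8, 256))  # y has shape (8, 512)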
+ """ + + def __init__(self, in_features, out_features, bias=True, + weight_norm_type='none', weight_norm_params=None, + activation_norm_type='none', activation_norm_params=None, + nonlinearity='none', inplace_nonlinearity=False, + apply_noise=False, order='CNA', clamp=None, blur_kernel=(1, 3, 3, 1), output_scale=None, + init_gain=1.0, **_kwargs): + if bool(_kwargs): + warnings.warn(f"Unused keyword arguments {_kwargs}") + super().__init__(in_features, out_features, None, None, + None, None, None, bias, + None, weight_norm_type, weight_norm_params, + activation_norm_type, activation_norm_params, + nonlinearity, inplace_nonlinearity, apply_noise, + False, order, 0, clamp, blur_kernel, output_scale, + init_gain) + + +class EmbeddingBlock(_BaseConvBlock): + def __init__(self, in_features, out_features, bias=True, + weight_norm_type='none', weight_norm_params=None, + activation_norm_type='none', activation_norm_params=None, + nonlinearity='none', inplace_nonlinearity=False, + apply_noise=False, order='CNA', clamp=None, output_scale=None, + init_gain=1.0, **_kwargs): + if bool(_kwargs): + warnings.warn(f"Unused keyword arguments {_kwargs}") + super().__init__(in_features, out_features, None, None, + None, None, None, bias, + None, weight_norm_type, weight_norm_params, + activation_norm_type, activation_norm_params, + nonlinearity, inplace_nonlinearity, apply_noise, + False, order, 0, clamp, None, output_scale, + init_gain) + + def _get_conv_layer(self, in_channels, out_channels, kernel_size, stride, + padding, dilation, groups, bias, padding_mode, + input_dim): + assert input_dim == 0 + return nn.Embedding(in_channels, out_channels) + + +class Embedding2dBlock(_BaseConvBlock): + def __init__(self, in_features, out_features, bias=True, + weight_norm_type='none', weight_norm_params=None, + activation_norm_type='none', activation_norm_params=None, + nonlinearity='none', inplace_nonlinearity=False, + apply_noise=False, order='CNA', clamp=None, output_scale=None, + init_gain=1.0, **_kwargs): + if bool(_kwargs): + warnings.warn(f"Unused keyword arguments {_kwargs}") + super().__init__(in_features, out_features, None, None, + None, None, None, bias, + None, weight_norm_type, weight_norm_params, + activation_norm_type, activation_norm_params, + nonlinearity, inplace_nonlinearity, apply_noise, + False, order, 0, clamp, None, output_scale, + init_gain) + + def _get_conv_layer(self, in_channels, out_channels, kernel_size, stride, + padding, dilation, groups, bias, padding_mode, + input_dim): + assert input_dim == 0 + return Embedding2d(in_channels, out_channels) + + +class Conv1dBlock(_BaseConvBlock): + r"""A Wrapper class that wraps ``torch.nn.Conv1d`` with normalization and + nonlinearity. + + Args: + in_channels (int): Number of channels in the input tensor. + out_channels (int): Number of channels in the output tensor. + kernel_size (int or tuple): Size of the convolving kernel. + stride (int or float or tuple, optional, default=1): + Stride of the convolution. + padding (int or tuple, optional, default=0): + Zero-padding added to both sides of the input. + dilation (int or tuple, optional, default=1): + Spacing between kernel elements. + groups (int, optional, default=1): Number of blocked connections + from input channels to output channels. + bias (bool, optional, default=True): + If ``True``, adds a learnable bias to the output. + padding_mode (string, optional, default='zeros'): Type of padding: + ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. 
+ weight_norm_type (str, optional, default='none'): + Type of weight normalization. + ``'none'``, ``'spectral'``, ``'weight'`` + or ``'weight_demod'``. + weight_norm_params (obj, optional, default=None): + Parameters of weight normalization. + If not ``None``, ``weight_norm_params.__dict__`` will be used as + keyword arguments when initializing weight normalization. + activation_norm_type (str, optional, default='none'): + Type of activation normalization. + ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``, + ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``, + ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``. + activation_norm_params (obj, optional, default=None): + Parameters of activation normalization. + If not ``None``, ``activation_norm_params.__dict__`` will be used as + keyword arguments when initializing activation normalization. + nonlinearity (str, optional, default='none'): + Type of nonlinear activation function. + ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``, + ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``. + inplace_nonlinearity (bool, optional, default=False): If ``True``, + set ``inplace=True`` when initializing the nonlinearity layer. + apply_noise (bool, optional, default=False): If ``True``, adds + Gaussian noise with learnable magnitude to the convolution output. + order (str, optional, default='CNA'): Order of operations. + ``'C'``: convolution, + ``'N'``: normalization, + ``'A'``: nonlinear activation. + For example, a block initialized with ``order='CNA'`` will + do convolution first, then normalization, then nonlinearity. + """ + + def __init__(self, in_channels, out_channels, kernel_size, stride=1, + padding=0, dilation=1, groups=1, bias=True, + padding_mode='zeros', + weight_norm_type='none', weight_norm_params=None, + activation_norm_type='none', activation_norm_params=None, + nonlinearity='none', inplace_nonlinearity=False, + apply_noise=False, blur=False, order='CNA', clamp=None, output_scale=None, init_gain=1.0, **_kwargs): + super().__init__(in_channels, out_channels, kernel_size, stride, + padding, dilation, groups, bias, padding_mode, + weight_norm_type, weight_norm_params, + activation_norm_type, activation_norm_params, + nonlinearity, inplace_nonlinearity, apply_noise, + blur, order, 1, clamp, None, output_scale, init_gain) + + +class Conv2dBlock(_BaseConvBlock): + r"""A Wrapper class that wraps ``torch.nn.Conv2d`` with normalization and + nonlinearity. + + Args: + in_channels (int): Number of channels in the input tensor. + out_channels (int): Number of channels in the output tensor. + kernel_size (int or tuple): Size of the convolving kernel. + stride (int or float or tuple, optional, default=1): + Stride of the convolution. + padding (int or tuple, optional, default=0): + Zero-padding added to both sides of the input. + dilation (int or tuple, optional, default=1): + Spacing between kernel elements. + groups (int, optional, default=1): Number of blocked connections + from input channels to output channels. + bias (bool, optional, default=True): + If ``True``, adds a learnable bias to the output. + padding_mode (string, optional, default='zeros'): Type of padding: + ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. + weight_norm_type (str, optional, default='none'): + Type of weight normalization. + ``'none'``, ``'spectral'``, ``'weight'`` + or ``'weight_demod'``. + weight_norm_params (obj, optional, default=None): + Parameters of weight normalization. 
+ If not ``None``, ``weight_norm_params.__dict__`` will be used as + keyword arguments when initializing weight normalization. + activation_norm_type (str, optional, default='none'): + Type of activation normalization. + ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``, + ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``, + ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``. + activation_norm_params (obj, optional, default=None): + Parameters of activation normalization. + If not ``None``, ``activation_norm_params.__dict__`` will be used as + keyword arguments when initializing activation normalization. + nonlinearity (str, optional, default='none'): + Type of nonlinear activation function. + ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``, + ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``. + inplace_nonlinearity (bool, optional, default=False): If ``True``, + set ``inplace=True`` when initializing the nonlinearity layer. + apply_noise (bool, optional, default=False): If ``True``, adds + Gaussian noise with learnable magnitude to the convolution output. + order (str, optional, default='CNA'): Order of operations. + ``'C'``: convolution, + ``'N'``: normalization, + ``'A'``: nonlinear activation. + For example, a block initialized with ``order='CNA'`` will + do convolution first, then normalization, then nonlinearity. + """ + + def __init__(self, in_channels, out_channels, kernel_size, stride=1, + padding=0, dilation=1, groups=1, bias=True, + padding_mode='zeros', + weight_norm_type='none', weight_norm_params=None, + activation_norm_type='none', activation_norm_params=None, + nonlinearity='none', inplace_nonlinearity=False, + apply_noise=False, blur=False, order='CNA', clamp=None, blur_kernel=(1, 3, 3, 1), + output_scale=None, init_gain=1.0): + super().__init__(in_channels, out_channels, kernel_size, stride, + padding, dilation, groups, bias, padding_mode, + weight_norm_type, weight_norm_params, + activation_norm_type, activation_norm_params, + nonlinearity, inplace_nonlinearity, + apply_noise, blur, order, 2, clamp, blur_kernel, output_scale, init_gain) + + +class Conv3dBlock(_BaseConvBlock): + r"""A Wrapper class that wraps ``torch.nn.Conv3d`` with normalization and + nonlinearity. + + Args: + in_channels (int): Number of channels in the input tensor. + out_channels (int): Number of channels in the output tensor. + kernel_size (int or tuple): Size of the convolving kernel. + stride (int or float or tuple, optional, default=1): + Stride of the convolution. + padding (int or tuple, optional, default=0): + Zero-padding added to both sides of the input. + dilation (int or tuple, optional, default=1): + Spacing between kernel elements. + groups (int, optional, default=1): Number of blocked connections + from input channels to output channels. + bias (bool, optional, default=True): + If ``True``, adds a learnable bias to the output. + padding_mode (string, optional, default='zeros'): Type of padding: + ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. + weight_norm_type (str, optional, default='none'): + Type of weight normalization. + ``'none'``, ``'spectral'``, ``'weight'`` + or ``'weight_demod'``. + weight_norm_params (obj, optional, default=None): + Parameters of weight normalization. + If not ``None``, ``weight_norm_params.__dict__`` will be used as + keyword arguments when initializing weight normalization. + activation_norm_type (str, optional, default='none'): + Type of activation normalization. 
+ ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``, + ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``, + ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``. + activation_norm_params (obj, optional, default=None): + Parameters of activation normalization. + If not ``None``, ``activation_norm_params.__dict__`` will be used as + keyword arguments when initializing activation normalization. + nonlinearity (str, optional, default='none'): + Type of nonlinear activation function. + ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``, + ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``. + inplace_nonlinearity (bool, optional, default=False): If ``True``, + set ``inplace=True`` when initializing the nonlinearity layer. + apply_noise (bool, optional, default=False): If ``True``, adds + Gaussian noise with learnable magnitude to the convolution output. + order (str, optional, default='CNA'): Order of operations. + ``'C'``: convolution, + ``'N'``: normalization, + ``'A'``: nonlinear activation. + For example, a block initialized with ``order='CNA'`` will + do convolution first, then normalization, then nonlinearity. + """ + + def __init__(self, in_channels, out_channels, kernel_size, stride=1, + padding=0, dilation=1, groups=1, bias=True, + padding_mode='zeros', + weight_norm_type='none', weight_norm_params=None, + activation_norm_type='none', activation_norm_params=None, + nonlinearity='none', inplace_nonlinearity=False, + apply_noise=False, blur=False, order='CNA', clamp=None, blur_kernel=(1, 3, 3, 1), output_scale=None, + init_gain=1.0): + super().__init__(in_channels, out_channels, kernel_size, stride, + padding, dilation, groups, bias, padding_mode, + weight_norm_type, weight_norm_params, + activation_norm_type, activation_norm_params, + nonlinearity, inplace_nonlinearity, + apply_noise, blur, order, 3, clamp, blur_kernel, output_scale, init_gain) + + +class _BaseHyperConvBlock(_BaseConvBlock): + r"""An abstract wrapper class that wraps a hyper convolutional layer + with normalization and nonlinearity. + """ + + def __init__(self, in_channels, out_channels, kernel_size, stride, + padding, dilation, groups, bias, + padding_mode, + weight_norm_type, weight_norm_params, + activation_norm_type, activation_norm_params, + nonlinearity, inplace_nonlinearity, apply_noise, blur, + is_hyper_conv, is_hyper_norm, order, input_dim, clamp=None, blur_kernel=(1, 3, 3, 1), + output_scale=None, init_gain=1.0): + self.is_hyper_conv = is_hyper_conv + if is_hyper_conv: + weight_norm_type = 'none' + if is_hyper_norm: + activation_norm_type = 'hyper_' + activation_norm_type + super().__init__(in_channels, out_channels, kernel_size, stride, + padding, dilation, groups, bias, padding_mode, + weight_norm_type, weight_norm_params, + activation_norm_type, activation_norm_params, + nonlinearity, inplace_nonlinearity, apply_noise, blur, + order, input_dim, clamp, blur_kernel, output_scale, init_gain) + + def _get_conv_layer(self, in_channels, out_channels, kernel_size, stride, + padding, dilation, groups, bias, padding_mode, + input_dim): + if input_dim == 0: + raise ValueError('HyperLinearBlock is not supported.') + else: + name = 'HyperConv' if self.is_hyper_conv else 'nn.Conv' + layer_type = eval(name + '%dd' % input_dim) + layer = layer_type( + in_channels, out_channels, kernel_size, stride, padding, + dilation, groups, bias, padding_mode) + return layer + + +class HyperConv2dBlock(_BaseHyperConvBlock): + r"""A Wrapper class that wraps ``HyperConv2d`` with normalization and + nonlinearity. 
+ + Args: + in_channels (int): Number of channels in the input tensor. + out_channels (int): Number of channels in the output tensor. + kernel_size (int or tuple): Size of the convolving kernel. + stride (int or float or tuple, optional, default=1): + Stride of the convolution. + padding (int or tuple, optional, default=0): + Zero-padding added to both sides of the input. + dilation (int or tuple, optional, default=1): + Spacing between kernel elements. + groups (int, optional, default=1): Number of blocked connections + from input channels to output channels. + bias (bool, optional, default=True): + If ``True``, adds a learnable bias to the output. + padding_mode (string, optional, default='zeros'): Type of padding: + ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. + weight_norm_type (str, optional, default='none'): + Type of weight normalization. + ``'none'``, ``'spectral'``, ``'weight'`` + or ``'weight_demod'``. + weight_norm_params (obj, optional, default=None): + Parameters of weight normalization. + If not ``None``, ``weight_norm_params.__dict__`` will be used as + keyword arguments when initializing weight normalization. + activation_norm_type (str, optional, default='none'): + Type of activation normalization. + ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``, + ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``, + ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``. + activation_norm_params (obj, optional, default=None): + Parameters of activation normalization. + If not ``None``, ``activation_norm_params.__dict__`` will be used as + keyword arguments when initializing activation normalization. + is_hyper_conv (bool, optional, default=False): If ``True``, use + ``HyperConv2d``, otherwise use ``torch.nn.Conv2d``. + is_hyper_norm (bool, optional, default=False): If ``True``, use + hyper normalizations. + nonlinearity (str, optional, default='none'): + Type of nonlinear activation function. + ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``, + ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``. + inplace_nonlinearity (bool, optional, default=False): If ``True``, + set ``inplace=True`` when initializing the nonlinearity layer. + apply_noise (bool, optional, default=False): If ``True``, adds + Gaussian noise with learnable magnitude to the convolution output. + order (str, optional, default='CNA'): Order of operations. + ``'C'``: convolution, + ``'N'``: normalization, + ``'A'``: nonlinear activation. + For example, a block initialized with ``order='CNA'`` will + do convolution first, then normalization, then nonlinearity. + """ + + def __init__(self, in_channels, out_channels, kernel_size, stride=1, + padding=0, dilation=1, groups=1, bias=True, + padding_mode='zeros', + weight_norm_type='none', weight_norm_params=None, + activation_norm_type='none', activation_norm_params=None, + is_hyper_conv=False, is_hyper_norm=False, + nonlinearity='none', inplace_nonlinearity=False, + apply_noise=False, blur=False, order='CNA', clamp=None): + super().__init__(in_channels, out_channels, kernel_size, stride, + padding, dilation, groups, bias, padding_mode, + weight_norm_type, weight_norm_params, + activation_norm_type, activation_norm_params, + nonlinearity, inplace_nonlinearity, apply_noise, blur, + is_hyper_conv, is_hyper_norm, order, 2, clamp) + + +class HyperConv2d(nn.Module): + r"""Hyper Conv2d initialization. + + Args: + in_channels (int): Dummy parameter. + out_channels (int): Dummy parameter. + kernel_size (int or tuple): Dummy parameter. 
+ stride (int or float or tuple, optional, default=1): + Stride of the convolution. Default: 1 + padding (int or tuple, optional, default=0): + Zero-padding added to both sides of the input. + padding_mode (string, optional, default='zeros'): + ``'zeros'``, ``'reflect'``, ``'replicate'`` + or ``'circular'``. + dilation (int or tuple, optional, default=1): + Spacing between kernel elements. + groups (int, optional, default=1): Number of blocked connections + from input channels to output channels. + bias (bool, optional, default=True): If ``True``, + adds a learnable bias to the output. + """ + + def __init__(self, in_channels=0, out_channels=0, kernel_size=3, + stride=1, padding=1, dilation=1, groups=1, bias=True, + padding_mode='zeros'): + super().__init__() + self.stride = stride + self.padding = padding + self.dilation = dilation + self.groups = groups + self.use_bias = bias + self.padding_mode = padding_mode + self.conditional = True + + def forward(self, x, *args, conv_weights=(None, None), **kwargs): + r"""Hyper Conv2d forward. Convolve x using the provided weight and bias. + + Args: + x (N x C x H x W tensor): Input tensor. + conv_weights (N x C2 x C1 x k x k tensor or list of tensors): + Convolution weights or [weight, bias]. + Returns: + y (N x C2 x H x W tensor): Output tensor. + """ + if conv_weights is None: + conv_weight, conv_bias = None, None + elif isinstance(conv_weights, torch.Tensor): + conv_weight, conv_bias = conv_weights, None + else: + conv_weight, conv_bias = conv_weights + + if conv_weight is None: + return x + if conv_bias is None: + if self.use_bias: + raise ValueError('bias not provided but set to true during ' + 'initialization') + conv_bias = [None] * x.size(0) + if self.padding_mode != 'zeros': + x = F.pad(x, [self.padding] * 4, mode=self.padding_mode) + padding = 0 + else: + padding = self.padding + + y = None + # noinspection PyArgumentList + for i in range(x.size(0)): + if self.stride >= 1: + yi = F.conv2d(x[i: i + 1], + weight=conv_weight[i], bias=conv_bias[i], + stride=self.stride, padding=padding, + dilation=self.dilation, groups=self.groups) + else: + yi = F.conv_transpose2d(x[i: i + 1], weight=conv_weight[i], + bias=conv_bias[i], padding=self.padding, + stride=int(1 / self.stride), + dilation=self.dilation, + output_padding=self.padding, + groups=self.groups) + y = torch.cat([y, yi]) if y is not None else yi + return y + + +class _BasePartialConvBlock(_BaseConvBlock): + r"""An abstract wrapper class that wraps a partial convolutional layer + with normalization and nonlinearity. 
+ """ + + def __init__(self, in_channels, out_channels, kernel_size, stride, + padding, dilation, groups, bias, padding_mode, + weight_norm_type, weight_norm_params, + activation_norm_type, activation_norm_params, + nonlinearity, inplace_nonlinearity, + multi_channel, return_mask, + apply_noise, order, input_dim, clamp=None, blur_kernel=(1, 3, 3, 1), output_scale=None, init_gain=1.0): + self.multi_channel = multi_channel + self.return_mask = return_mask + self.partial_conv = True + super().__init__(in_channels, out_channels, kernel_size, stride, + padding, dilation, groups, bias, padding_mode, + weight_norm_type, weight_norm_params, + activation_norm_type, activation_norm_params, + nonlinearity, inplace_nonlinearity, apply_noise, + False, order, input_dim, clamp, blur_kernel, output_scale, init_gain) + + def _get_conv_layer(self, in_channels, out_channels, kernel_size, stride, + padding, dilation, groups, bias, padding_mode, + input_dim): + if input_dim == 2: + layer_type = PartialConv2d + elif input_dim == 3: + layer_type = PartialConv3d + else: + raise ValueError('Partial conv only supports 2D and 3D conv now.') + layer = layer_type( + in_channels, out_channels, kernel_size, stride, padding, + dilation, groups, bias, padding_mode, + multi_channel=self.multi_channel, return_mask=self.return_mask) + return layer + + def forward(self, x, *cond_inputs, mask_in=None, **kw_cond_inputs): + r""" + + Args: + x (tensor): Input tensor. + cond_inputs (list of tensors) : Conditional input tensors. + mask_in (tensor, optional, default=``None``) If not ``None``, + it masks the valid input region. + kw_cond_inputs (dict) : Keyword conditional inputs. + Returns: + (tuple): + - x (tensor): Output tensor. + - mask_out (tensor, optional): Masks the valid output region. + """ + mask_out = None + for layer in self.layers.values(): + if getattr(layer, 'conditional', False): + x = layer(x, *cond_inputs, **kw_cond_inputs) + elif getattr(layer, 'partial_conv', False): + x = layer(x, mask_in=mask_in, **kw_cond_inputs) + if type(x) == tuple: + x, mask_out = x + else: + x = layer(x) + + if mask_out is not None: + return x, mask_out + return x + + +class PartialConv2dBlock(_BasePartialConvBlock): + r"""A Wrapper class that wraps ``PartialConv2d`` with normalization and + nonlinearity. + + Args: + in_channels (int): Number of channels in the input tensor. + out_channels (int): Number of channels in the output tensor. + kernel_size (int or tuple): Size of the convolving kernel. + stride (int or float or tuple, optional, default=1): + Stride of the convolution. + padding (int or tuple, optional, default=0): + Zero-padding added to both sides of the input. + dilation (int or tuple, optional, default=1): + Spacing between kernel elements. + groups (int, optional, default=1): Number of blocked connections + from input channels to output channels. + bias (bool, optional, default=True): + If ``True``, adds a learnable bias to the output. + padding_mode (string, optional, default='zeros'): Type of padding: + ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. + weight_norm_type (str, optional, default='none'): + Type of weight normalization. + ``'none'``, ``'spectral'``, ``'weight'`` + or ``'weight_demod'``. + weight_norm_params (obj, optional, default=None): + Parameters of weight normalization. + If not ``None``, ``weight_norm_params.__dict__`` will be used as + keyword arguments when initializing weight normalization. + activation_norm_type (str, optional, default='none'): + Type of activation normalization. 
+ ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``, + ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``, + ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``. + activation_norm_params (obj, optional, default=None): + Parameters of activation normalization. + If not ``None``, ``activation_norm_params.__dict__`` will be used as + keyword arguments when initializing activation normalization. + nonlinearity (str, optional, default='none'): + Type of nonlinear activation function. + ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``, + ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``. + inplace_nonlinearity (bool, optional, default=False): If ``True``, + set ``inplace=True`` when initializing the nonlinearity layer. + apply_noise (bool, optional, default=False): If ``True``, adds + Gaussian noise with learnable magnitude to the convolution output. + order (str, optional, default='CNA'): Order of operations. + ``'C'``: convolution, + ``'N'``: normalization, + ``'A'``: nonlinear activation. + For example, a block initialized with ``order='CNA'`` will + do convolution first, then normalization, then nonlinearity. + multi_channel (bool, optional, default=False): If ``True``, use + different masks for different channels. + return_mask (bool, optional, default=True): If ``True``, the + forward call also returns a new mask. + """ + + def __init__(self, in_channels, out_channels, kernel_size, stride=1, + padding=0, dilation=1, groups=1, bias=True, + padding_mode='zeros', + weight_norm_type='none', weight_norm_params=None, + activation_norm_type='none', activation_norm_params=None, + nonlinearity='none', inplace_nonlinearity=False, + multi_channel=False, return_mask=True, + apply_noise=False, order='CNA', clamp=None): + super().__init__(in_channels, out_channels, kernel_size, stride, + padding, dilation, groups, bias, padding_mode, + weight_norm_type, weight_norm_params, + activation_norm_type, activation_norm_params, + nonlinearity, inplace_nonlinearity, + multi_channel, return_mask, apply_noise, order, 2, + clamp) + + +class PartialConv3dBlock(_BasePartialConvBlock): + r"""A Wrapper class that wraps ``PartialConv3d`` with normalization and + nonlinearity. + + Args: + in_channels (int): Number of channels in the input tensor. + out_channels (int): Number of channels in the output tensor. + kernel_size (int or tuple): Size of the convolving kernel. + stride (int or float or tuple, optional, default=1): + Stride of the convolution. + padding (int or tuple, optional, default=0): + Zero-padding added to both sides of the input. + dilation (int or tuple, optional, default=1): + Spacing between kernel elements. + groups (int, optional, default=1): Number of blocked connections + from input channels to output channels. + bias (bool, optional, default=True): + If ``True``, adds a learnable bias to the output. + padding_mode (string, optional, default='zeros'): Type of padding: + ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. + weight_norm_type (str, optional, default='none'): + Type of weight normalization. + ``'none'``, ``'spectral'``, ``'weight'`` + or ``'weight_demod'``. + weight_norm_params (obj, optional, default=None): + Parameters of weight normalization. + If not ``None``, ``weight_norm_params.__dict__`` will be used as + keyword arguments when initializing weight normalization. + activation_norm_type (str, optional, default='none'): + Type of activation normalization. 
+ ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``, + ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``, + ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``. + activation_norm_params (obj, optional, default=None): + Parameters of activation normalization. + If not ``None``, ``activation_norm_params.__dict__`` will be used as + keyword arguments when initializing activation normalization. + nonlinearity (str, optional, default='none'): + Type of nonlinear activation function. + ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``, + ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``. + inplace_nonlinearity (bool, optional, default=False): If ``True``, + set ``inplace=True`` when initializing the nonlinearity layer. + apply_noise (bool, optional, default=False): If ``True``, adds + Gaussian noise with learnable magnitude to the convolution output. + order (str, optional, default='CNA'): Order of operations. + ``'C'``: convolution, + ``'N'``: normalization, + ``'A'``: nonlinear activation. + For example, a block initialized with ``order='CNA'`` will + do convolution first, then normalization, then nonlinearity. + multi_channel (bool, optional, default=False): If ``True``, use + different masks for different channels. + return_mask (bool, optional, default=True): If ``True``, the + forward call also returns a new mask. + """ + + def __init__(self, in_channels, out_channels, kernel_size, stride=1, + padding=0, dilation=1, groups=1, bias=True, + padding_mode='zeros', + weight_norm_type='none', weight_norm_params=None, + activation_norm_type='none', activation_norm_params=None, + nonlinearity='none', inplace_nonlinearity=False, + multi_channel=False, return_mask=True, + apply_noise=False, order='CNA', clamp=None): + super().__init__(in_channels, out_channels, kernel_size, stride, + padding, dilation, groups, bias, padding_mode, + weight_norm_type, weight_norm_params, + activation_norm_type, activation_norm_params, + nonlinearity, inplace_nonlinearity, + multi_channel, return_mask, apply_noise, order, 3, + clamp) + + +class _MultiOutBaseConvBlock(_BaseConvBlock): + r"""An abstract wrapper class that wraps a hyper convolutional layer with + normalization and nonlinearity. It can return multiple outputs, if some + layers in the block return more than one output. + """ + + def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode, + weight_norm_type, weight_norm_params, activation_norm_type, activation_norm_params, nonlinearity, + inplace_nonlinearity, apply_noise, blur, order, input_dim, clamp=None, blur_kernel=(1, 3, 3, 1), + output_scale=None, init_gain=1.0): + super().__init__(in_channels, out_channels, kernel_size, stride, + padding, dilation, groups, bias, padding_mode, + weight_norm_type, weight_norm_params, + activation_norm_type, activation_norm_params, + nonlinearity, inplace_nonlinearity, + apply_noise, blur, order, input_dim, clamp, blur_kernel, output_scale, init_gain) + self.multiple_outputs = True + + def forward(self, x, *cond_inputs, **kw_cond_inputs): + r""" + + Args: + x (tensor): Input tensor. + cond_inputs (list of tensors) : Conditional input tensors. + kw_cond_inputs (dict) : Keyword conditional inputs. + Returns: + (tuple): + - x (tensor): Main output tensor. + - other_outputs (list of tensors): Other output tensors. 
+ """ + other_outputs = [] + for layer in self.layers.values(): + if getattr(layer, 'conditional', False): + x = layer(x, *cond_inputs, **kw_cond_inputs) + if getattr(layer, 'multiple_outputs', False): + x, other_output = layer(x) + other_outputs.append(other_output) + else: + x = layer(x) + return (x, *other_outputs) + + +class MultiOutConv2dBlock(_MultiOutBaseConvBlock): + r"""A Wrapper class that wraps ``torch.nn.Conv2d`` with normalization and + nonlinearity. It can return multiple outputs, if some layers in the block + return more than one output. + + Args: + in_channels (int): Number of channels in the input tensor. + out_channels (int): Number of channels in the output tensor. + kernel_size (int or tuple): Size of the convolving kernel. + stride (int or float or tuple, optional, default=1): + Stride of the convolution. + padding (int or tuple, optional, default=0): + Zero-padding added to both sides of the input. + dilation (int or tuple, optional, default=1): + Spacing between kernel elements. + groups (int, optional, default=1): Number of blocked connections + from input channels to output channels. + bias (bool, optional, default=True): + If ``True``, adds a learnable bias to the output. + padding_mode (string, optional, default='zeros'): Type of padding: + ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. + weight_norm_type (str, optional, default='none'): + Type of weight normalization. + ``'none'``, ``'spectral'``, ``'weight'`` + or ``'weight_demod'``. + weight_norm_params (obj, optional, default=None): + Parameters of weight normalization. + If not ``None``, ``weight_norm_params.__dict__`` will be used as + keyword arguments when initializing weight normalization. + activation_norm_type (str, optional, default='none'): + Type of activation normalization. + ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``, + ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``, + ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``. + activation_norm_params (obj, optional, default=None): + Parameters of activation normalization. + If not ``None``, ``activation_norm_params.__dict__`` will be used as + keyword arguments when initializing activation normalization. + nonlinearity (str, optional, default='none'): + Type of nonlinear activation function. + ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``, + ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``. + inplace_nonlinearity (bool, optional, default=False): If ``True``, + set ``inplace=True`` when initializing the nonlinearity layer. + apply_noise (bool, optional, default=False): If ``True``, adds + Gaussian noise with learnable magnitude to the convolution output. + order (str, optional, default='CNA'): Order of operations. + ``'C'``: convolution, + ``'N'``: normalization, + ``'A'``: nonlinear activation. + For example, a block initialized with ``order='CNA'`` will + do convolution first, then normalization, then nonlinearity. 
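+
+    Example:
+        A minimal illustrative usage; the forward call always returns a tuple
+        whose first element is the main output:
+
+        >>> block = MultiOutConv2dBlock(3, 64, 3, padding=1, nonlinearity='relu')
+        >>> out, *extras = block(torch.randn(1, 3, 32, 32))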
+ """ + + def __init__(self, in_channels, out_channels, kernel_size, stride=1, + padding=0, dilation=1, groups=1, bias=True, + padding_mode='zeros', + weight_norm_type='none', weight_norm_params=None, + activation_norm_type='none', activation_norm_params=None, + nonlinearity='none', inplace_nonlinearity=False, + apply_noise=False, blur=False, order='CNA', clamp=None): + super().__init__(in_channels, out_channels, kernel_size, stride, + padding, dilation, groups, bias, padding_mode, + weight_norm_type, weight_norm_params, + activation_norm_type, activation_norm_params, + nonlinearity, inplace_nonlinearity, + apply_noise, blur, order, 2, clamp) + + +############################################################################### +# BSD 3-Clause License +# +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Author & Contact: Guilin Liu (guilinl@nvidia.com) +############################################################################### +class PartialConv2d(nn.Conv2d): + r"""Partial 2D convolution in + "Image inpainting for irregular holes using partial convolutions." + Liu et al., ECCV 2018 + """ + + def __init__(self, *args, multi_channel=False, return_mask=True, **kwargs): + # whether the mask is multi-channel or not + self.multi_channel = multi_channel + self.return_mask = return_mask + super(PartialConv2d, self).__init__(*args, **kwargs) + + if self.multi_channel: + self.weight_maskUpdater = torch.ones(self.out_channels, + self.in_channels, + self.kernel_size[0], + self.kernel_size[1]) + else: + self.weight_maskUpdater = torch.ones(1, 1, self.kernel_size[0], + self.kernel_size[1]) + + shape = self.weight_maskUpdater.shape + self.slide_winsize = shape[1] * shape[2] * shape[3] + + self.last_size = (None, None, None, None) + self.update_mask = None + self.mask_ratio = None + self.partial_conv = True + + def forward(self, x, mask_in=None): + r""" + + Args: + x (tensor): Input tensor. + mask_in (tensor, optional, default=``None``) If not ``None``, + it masks the valid input region. + """ + assert len(x.shape) == 4 + if mask_in is not None or self.last_size != tuple(x.shape): + self.last_size = tuple(x.shape) + + with torch.no_grad(): + if self.weight_maskUpdater.type() != x.type(): + self.weight_maskUpdater = self.weight_maskUpdater.to(x) + + if mask_in is None: + # If mask is not provided, create a mask. + if self.multi_channel: + mask = torch.ones(x.data.shape[0], + x.data.shape[1], + x.data.shape[2], + x.data.shape[3]).to(x) + else: + mask = torch.ones(1, 1, x.data.shape[2], + x.data.shape[3]).to(x) + else: + mask = mask_in + + self.update_mask = F.conv2d(mask, self.weight_maskUpdater, + bias=None, stride=self.stride, + padding=self.padding, + dilation=self.dilation, groups=1) + + # For mixed precision training, eps from 1e-8 to 1e-6. 
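+                # `update_mask` counts the valid (unmasked) input pixels inside
+                # each sliding window; `mask_ratio` renormalizes the raw output
+                # by window_size / valid_count, so partially masked windows are
+                # not attenuated and fully masked windows are zeroed out.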
+ eps = 1e-6 + self.mask_ratio = self.slide_winsize / (self.update_mask + eps) + self.update_mask = torch.clamp(self.update_mask, 0, 1) + self.mask_ratio = torch.mul(self.mask_ratio, self.update_mask) + + raw_out = super(PartialConv2d, self).forward( + torch.mul(x, mask) if mask_in is not None else x) + + if self.bias is not None: + bias_view = self.bias.view(1, self.out_channels, 1, 1) + output = torch.mul(raw_out - bias_view, self.mask_ratio) + bias_view + output = torch.mul(output, self.update_mask) + else: + output = torch.mul(raw_out, self.mask_ratio) + + if self.return_mask: + return output, self.update_mask + else: + return output + + +class PartialConv3d(nn.Conv3d): + r"""Partial 3D convolution in + "Image inpainting for irregular holes using partial convolutions." + Liu et al., ECCV 2018 + """ + + def __init__(self, *args, multi_channel=False, return_mask=True, **kwargs): + # whether the mask is multi-channel or not + self.multi_channel = multi_channel + self.return_mask = return_mask + super(PartialConv3d, self).__init__(*args, **kwargs) + + if self.multi_channel: + self.weight_maskUpdater = \ + torch.ones(self.out_channels, self.in_channels, + self.kernel_size[0], self.kernel_size[1], + self.kernel_size[2]) + else: + self.weight_maskUpdater = torch.ones(1, 1, self.kernel_size[0], + self.kernel_size[1], + self.kernel_size[2]) + self.weight_maskUpdater = self.weight_maskUpdater.to('cuda') + + shape = self.weight_maskUpdater.shape + self.slide_winsize = shape[1] * shape[2] * shape[3] * shape[4] + self.partial_conv = True + + def forward(self, x, mask_in=None): + r""" + + Args: + x (tensor): Input tensor. + mask_in (tensor, optional, default=``None``) If not ``None``, it + masks the valid input region. + """ + assert len(x.shape) == 5 + + with torch.no_grad(): + mask = mask_in + update_mask = F.conv3d(mask, self.weight_maskUpdater, bias=None, + stride=self.stride, padding=self.padding, + dilation=self.dilation, groups=1) + + mask_ratio = self.slide_winsize / (update_mask + 1e-8) + update_mask = torch.clamp(update_mask, 0, 1) + mask_ratio = torch.mul(mask_ratio, update_mask) + + raw_out = super(PartialConv3d, self).forward(torch.mul(x, mask_in)) + + if self.bias is not None: + bias_view = self.bias.view(1, self.out_channels, 1, 1, 1) + output = torch.mul(raw_out - bias_view, mask_ratio) + bias_view + if mask_in is not None: + output = torch.mul(output, update_mask) + else: + output = torch.mul(raw_out, mask_ratio) + + if self.return_mask: + return output, update_mask + else: + return output + + +class Embedding2d(nn.Embedding): + def __init__(self, in_channels, out_channels): + super().__init__(in_channels, out_channels) + + def forward(self, x): + return F.embedding( + x.squeeze(1).long(), self.weight, self.padding_idx, self.max_norm, + self.norm_type, self.scale_grad_by_freq, self.sparse).permute(0, 3, 1, 2).contiguous() diff --git a/imaginaire/layers/misc.py b/imaginaire/layers/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..7731bd2fa855939a2866b211c33eb3ccce00c480 --- /dev/null +++ b/imaginaire/layers/misc.py @@ -0,0 +1,61 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. 
+# To view a copy of this license, check out LICENSE.md +import torch +from torch import nn + + +class ApplyNoise(nn.Module): + r"""Add Gaussian noise to the input tensor.""" + + def __init__(self): + super().__init__() + # scale of the noise + self.scale = nn.Parameter(torch.zeros(1)) + self.conditional = True + + def forward(self, x, *_args, noise=None, **_kwargs): + r""" + + Args: + x (tensor): Input tensor. + noise (tensor, optional, default=``None``) : Noise tensor to be + added to the input. + """ + if noise is None: + sz = x.size() + noise = x.new_empty(sz[0], 1, *sz[2:]).normal_() + + return x + self.scale * noise + + +class PartialSequential(nn.Sequential): + r"""Sequential block for partial convolutions.""" + def __init__(self, *modules): + super(PartialSequential, self).__init__(*modules) + + def forward(self, x): + r""" + + Args: + x (tensor): Input tensor. + """ + act = x[:, :-1] + mask = x[:, -1].unsqueeze(1) + for module in self: + act, mask = module(act, mask_in=mask) + return act + + +class ConstantInput(nn.Module): + def __init__(self, channel, size=4): + super().__init__() + if isinstance(size, int): + h, w = size, size + else: + h, w = size + self.input = nn.Parameter(torch.randn(1, channel, h, w)) + + def forward(self): + return self.input diff --git a/imaginaire/layers/non_local.py b/imaginaire/layers/non_local.py new file mode 100644 index 0000000000000000000000000000000000000000..1d1a8b36d668377ef4c9c4d897cd3b55fe0363e7 --- /dev/null +++ b/imaginaire/layers/non_local.py @@ -0,0 +1,88 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +from functools import partial + +import torch +import torch.nn as nn + +from imaginaire.layers import Conv2dBlock + + +class NonLocal2dBlock(nn.Module): + r"""Self attention Layer + + Args: + in_channels (int): Number of channels in the input tensor. + scale (bool, optional, default=True): If ``True``, scale the + output by a learnable parameter. + clamp (bool, optional, default=``False``): If ``True``, clamp the + scaling parameter to (-1, 1). + weight_norm_type (str, optional, default='none'): + Type of weight normalization. + ``'none'``, ``'spectral'``, ``'weight'``. + weight_norm_params (obj, optional, default=None): + Parameters of weight normalization. + If not ``None``, weight_norm_params.__dict__ will be used as + keyword arguments when initializing weight normalization. + bias (bool, optional, default=True): If ``True``, adds bias in the + convolutional blocks. 
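+
+    Example:
+        A minimal illustrative usage (spatial sizes should be even because the
+        key/value branches are max-pooled by a factor of 2):
+
+        >>> attn = NonLocal2dBlock(256)
+        >>> y = attn(torch.randn(2, 256, 32, 32))  # same shape as the input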
+ """ + + def __init__(self, + in_channels, + scale=True, + clamp=False, + weight_norm_type='none', + weight_norm_params=None, + bias=True): + super(NonLocal2dBlock, self).__init__() + self.clamp = clamp + self.gamma = nn.Parameter(torch.zeros(1)) if scale else 1.0 + self.in_channels = in_channels + base_conv2d_block = partial(Conv2dBlock, + kernel_size=1, + stride=1, + padding=0, + weight_norm_type=weight_norm_type, + weight_norm_params=weight_norm_params, + bias=bias) + self.theta = base_conv2d_block(in_channels, in_channels // 8) + self.phi = base_conv2d_block(in_channels, in_channels // 8) + self.g = base_conv2d_block(in_channels, in_channels // 2) + self.out_conv = base_conv2d_block(in_channels // 2, in_channels) + self.softmax = nn.Softmax(dim=-1) + self.max_pool = nn.MaxPool2d(2) + + def forward(self, x): + r""" + + Args: + x (tensor) : input feature maps (B X C X W X H) + Returns: + (tuple): + - out (tensor) : self attention value + input feature + - attention (tensor): B x N x N (N is Width*Height) + """ + n, c, h, w = x.size() + theta = self.theta(x).view(n, -1, h * w).permute(0, 2, 1) + + phi = self.phi(x) + phi = self.max_pool(phi).view(n, -1, h * w // 4) + + energy = torch.bmm(theta, phi) + attention = self.softmax(energy) + + g = self.g(x) + g = self.max_pool(g).view(n, -1, h * w // 4) + + out = torch.bmm(g, attention.permute(0, 2, 1)) + out = out.view(n, c // 2, h, w) + out = self.out_conv(out) + + if self.clamp: + out = self.gamma.clamp(-1, 1) * out + x + else: + out = self.gamma * out + x + return out diff --git a/imaginaire/layers/nonlinearity.py b/imaginaire/layers/nonlinearity.py new file mode 100644 index 0000000000000000000000000000000000000000..5fc172c74323e707e5a19f94e466c1bf0dae4418 --- /dev/null +++ b/imaginaire/layers/nonlinearity.py @@ -0,0 +1,65 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +import torch +from torch import nn +import torch.nn.functional as F + +from imaginaire.third_party.bias_act.bias_act import FusedNonlinearity + + +class ScaledLeakyReLU(nn.Module): + def __init__(self, negative_slope=0.2, scale=2 ** 0.5, inplace=False): + super().__init__() + + self.negative_slope = negative_slope + self.scale = scale + self.inplace = inplace + + def forward(self, x): + return F.leaky_relu(x, self.negative_slope, inplace=self.inplace) * self.scale + # return _fused_scaled_leakyrelu(x, self.negative_slope, self.inplace, self.scale) + + +# @torch.jit.script +# def _fused_scaled_leakyrelu(x: torch.Tensor, negative_slope: float, inplace: bool, scale: float): +# return F.leaky_relu(x, negative_slope, inplace=inplace) * scale + + +def get_nonlinearity_layer(nonlinearity_type, inplace, **kwargs): + r"""Return a nonlinearity layer. + + Args: + nonlinearity_type (str): + Type of nonlinear activation function. + ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``, + ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``. + inplace (bool): If ``True``, set ``inplace=True`` when initializing + the nonlinearity layer. 
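+
+    Example:
+        A minimal illustrative usage:
+
+        >>> act = get_nonlinearity_layer('leakyrelu', inplace=False)
+        >>> # equivalent to nn.LeakyReLU(0.2)
+        >>> act = get_nonlinearity_layer('softmax,2', inplace=False)
+        >>> # equivalent to nn.Softmax(dim=2)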
+ """ + if nonlinearity_type.startswith('fused'): + nonlinearity = FusedNonlinearity(nonlinearity=nonlinearity_type[6:], **kwargs) + elif nonlinearity_type == 'relu': + nonlinearity = nn.ReLU(inplace=inplace) + elif nonlinearity_type == 'leakyrelu': + nonlinearity = nn.LeakyReLU(0.2, inplace=inplace) + elif nonlinearity_type == 'scaled_leakyrelu': + nonlinearity = ScaledLeakyReLU(0.2, inplace=inplace) + import imaginaire.config + if imaginaire.config.USE_JIT: + nonlinearity = torch.jit.script(nonlinearity) + elif nonlinearity_type == 'prelu': + nonlinearity = nn.PReLU() + elif nonlinearity_type == 'tanh': + nonlinearity = nn.Tanh() + elif nonlinearity_type == 'sigmoid': + nonlinearity = nn.Sigmoid() + elif nonlinearity_type.startswith('softmax'): + dim = nonlinearity_type.split(',')[1] if ',' in nonlinearity_type else 1 + nonlinearity = nn.Softmax(dim=int(dim)) + elif nonlinearity_type == 'none' or nonlinearity_type == '': + nonlinearity = None + else: + raise ValueError('Nonlinearity %s is not recognized' % nonlinearity_type) + return nonlinearity diff --git a/imaginaire/layers/residual.py b/imaginaire/layers/residual.py new file mode 100644 index 0000000000000000000000000000000000000000..5e1bda4dd30922f694302803b7d606af7f3c0c21 --- /dev/null +++ b/imaginaire/layers/residual.py @@ -0,0 +1,1411 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +import functools + +import torch +from torch import nn +from torch.nn import Upsample as NearestUpsample +from torch.utils.checkpoint import checkpoint + +from .conv import (Conv1dBlock, Conv2dBlock, Conv3dBlock, HyperConv2dBlock, + LinearBlock, MultiOutConv2dBlock, PartialConv2dBlock, + PartialConv3dBlock, ModulatedConv2dBlock) +from imaginaire.third_party.upfirdn2d.upfirdn2d import BlurUpsample + + +class _BaseResBlock(nn.Module): + r"""An abstract class for residual blocks. + """ + + def __init__(self, in_channels, out_channels, kernel_size, + stride, padding, dilation, groups, bias, padding_mode, + weight_norm_type, weight_norm_params, + activation_norm_type, activation_norm_params, + skip_activation_norm, skip_nonlinearity, + nonlinearity, inplace_nonlinearity, apply_noise, + hidden_channels_equal_out_channels, + order, block, learn_shortcut, clamp, output_scale, + skip_block=None, blur=False, upsample_first=True, skip_weight_norm=True): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.output_scale = output_scale + self.upsample_first = upsample_first + self.stride = stride + self.blur = blur + if skip_block is None: + skip_block = block + + if order == 'pre_act': + order = 'NACNAC' + if isinstance(bias, bool): + # The bias for conv_block_0, conv_block_1, and conv_block_s. + biases = [bias, bias, bias] + elif isinstance(bias, list): + if len(bias) == 3: + biases = bias + else: + raise ValueError('Bias list must be 3.') + else: + raise ValueError('Bias must be either an integer or s list.') + if learn_shortcut is None: + self.learn_shortcut = (in_channels != out_channels) + else: + self.learn_shortcut = learn_shortcut + if len(order) > 6 or len(order) < 5: + raise ValueError('order must be either 5 or 6 characters') + if hidden_channels_equal_out_channels: + hidden_channels = out_channels + else: + hidden_channels = min(in_channels, out_channels) + + # Parameters. 
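+        # ``residual_params`` collects the keyword arguments shared by the two
+        # main convolutions of the residual branch; ``shortcut_params``
+        # configures the optional 1x1 convolution on the shortcut branch.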
+ residual_params = {} + shortcut_params = {} + base_params = dict(dilation=dilation, + groups=groups, + padding_mode=padding_mode, + clamp=clamp) + residual_params.update(base_params) + residual_params.update( + dict(activation_norm_type=activation_norm_type, + activation_norm_params=activation_norm_params, + weight_norm_type=weight_norm_type, + weight_norm_params=weight_norm_params, + padding=padding, + apply_noise=apply_noise)) + shortcut_params.update(base_params) + shortcut_params.update(dict(kernel_size=1)) + if skip_activation_norm: + shortcut_params.update( + dict(activation_norm_type=activation_norm_type, + activation_norm_params=activation_norm_params, + apply_noise=False)) + if skip_weight_norm: + shortcut_params.update( + dict(weight_norm_type=weight_norm_type, + weight_norm_params=weight_norm_params)) + + # Residual branch. + if order.find('A') < order.find('C') and \ + (activation_norm_type == '' or activation_norm_type == 'none'): + # Nonlinearity is the first operation in the residual path. + # In-place nonlinearity will modify the input variable and cause + # backward error. + first_inplace = False + else: + first_inplace = inplace_nonlinearity + + (first_stride, second_stride, shortcut_stride, + first_blur, second_blur, shortcut_blur) = self._get_stride_blur() + self.conv_block_0 = block( + in_channels, hidden_channels, + kernel_size=kernel_size, + bias=biases[0], + nonlinearity=nonlinearity, + order=order[0:3], + inplace_nonlinearity=first_inplace, + stride=first_stride, + blur=first_blur, + **residual_params + ) + self.conv_block_1 = block( + hidden_channels, out_channels, + kernel_size=kernel_size, + bias=biases[1], + nonlinearity=nonlinearity, + order=order[3:], + inplace_nonlinearity=inplace_nonlinearity, + stride=second_stride, + blur=second_blur, + **residual_params + ) + + # Shortcut branch. + if self.learn_shortcut: + if skip_nonlinearity: + skip_nonlinearity_type = nonlinearity + else: + skip_nonlinearity_type = '' + self.conv_block_s = skip_block(in_channels, out_channels, + bias=biases[2], + nonlinearity=skip_nonlinearity_type, + order=order[0:3], + stride=shortcut_stride, + blur=shortcut_blur, + **shortcut_params) + elif in_channels < out_channels: + if skip_nonlinearity: + skip_nonlinearity_type = nonlinearity + else: + skip_nonlinearity_type = '' + self.conv_block_s = skip_block(in_channels, + out_channels - in_channels, + bias=biases[2], + nonlinearity=skip_nonlinearity_type, + order=order[0:3], + stride=shortcut_stride, + blur=shortcut_blur, + **shortcut_params) + + # Whether this block expects conditional inputs. + self.conditional = \ + getattr(self.conv_block_0, 'conditional', False) or \ + getattr(self.conv_block_1, 'conditional', False) + + def _get_stride_blur(self): + if self.stride > 1: + # Downsampling. + first_stride, second_stride = 1, self.stride + first_blur, second_blur = False, self.blur + shortcut_stride = self.stride + shortcut_blur = self.blur + self.upsample = None + elif self.stride < 1: + # Upsampling. 
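+            # A fractional stride marks this block as upsampling: the first
+            # convolution carries the fractional stride, the second keeps
+            # stride 1, and the shortcut uses the upsample module chosen below.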
+ first_stride, second_stride = self.stride, 1 + first_blur, second_blur = self.blur, False + shortcut_blur = False + shortcut_stride = 1 + if self.blur: + # The shortcut branch uses blur_upsample + stride-1 conv + self.upsample = BlurUpsample() + else: + shortcut_stride = self.stride + self.upsample = nn.Upsample(scale_factor=2) + else: + first_stride = second_stride = 1 + first_blur = second_blur = False + shortcut_stride = 1 + shortcut_blur = False + self.upsample = None + return (first_stride, second_stride, shortcut_stride, + first_blur, second_blur, shortcut_blur) + + def conv_blocks( + self, x, *cond_inputs, separate_cond=False, **kw_cond_inputs + ): + r"""Returns the output of the residual branch. + + Args: + x (tensor): Input tensor. + cond_inputs (list of tensors) : Conditional input tensors. + kw_cond_inputs (dict) : Keyword conditional inputs. + Returns: + dx (tensor): Output tensor. + """ + if separate_cond: + dx = self.conv_block_0(x, cond_inputs[0], + **kw_cond_inputs.get('kwargs_0', {})) + dx = self.conv_block_1(dx, cond_inputs[1], + **kw_cond_inputs.get('kwargs_1', {})) + else: + dx = self.conv_block_0(x, *cond_inputs, **kw_cond_inputs) + dx = self.conv_block_1(dx, *cond_inputs, **kw_cond_inputs) + return dx + + def forward(self, x, *cond_inputs, do_checkpoint=False, separate_cond=False, + **kw_cond_inputs): + r""" + + Args: + x (tensor): Input tensor. + cond_inputs (list of tensors) : Conditional input tensors. + do_checkpoint (bool, optional, default=``False``) If ``True``, + trade compute for memory by checkpointing the model. + kw_cond_inputs (dict) : Keyword conditional inputs. + Returns: + output (tensor): Output tensor. + """ + if do_checkpoint: + dx = checkpoint(self.conv_blocks, x, *cond_inputs, + separate_cond=separate_cond, **kw_cond_inputs) + else: + dx = self.conv_blocks(x, *cond_inputs, + separate_cond=separate_cond, **kw_cond_inputs) + + if self.upsample_first and self.upsample is not None: + x = self.upsample(x) + if self.learn_shortcut: + if separate_cond: + x_shortcut = self.conv_block_s( + x, cond_inputs[2], **kw_cond_inputs.get('kwargs_2', {}) + ) + else: + x_shortcut = self.conv_block_s( + x, *cond_inputs, **kw_cond_inputs + ) + elif self.in_channels < self.out_channels: + if separate_cond: + x_shortcut_pad = self.conv_block_s( + x, cond_inputs[2], **kw_cond_inputs.get('kwargs_2', {}) + ) + else: + x_shortcut_pad = self.conv_block_s( + x, *cond_inputs, **kw_cond_inputs + ) + x_shortcut = torch.cat((x, x_shortcut_pad), dim=1) + elif self.in_channels > self.out_channels: + x_shortcut = x[:, :self.out_channels, :, :] + else: + x_shortcut = x + if not self.upsample_first and self.upsample is not None: + x_shortcut = self.upsample(x_shortcut) + + output = x_shortcut + dx + return self.output_scale * output + + def extra_repr(self): + s = 'output_scale={output_scale}' + return s.format(**self.__dict__) + + +class ModulatedRes2dBlock(_BaseResBlock): + def __init__(self, in_channels, out_channels, style_dim, kernel_size=3, + stride=1, padding=1, dilation=1, groups=1, bias=True, + padding_mode='zeros', + weight_norm_type='none', weight_norm_params=None, + activation_norm_type='none', activation_norm_params=None, + skip_activation_norm=True, skip_nonlinearity=False, + nonlinearity='leakyrelu', inplace_nonlinearity=False, + apply_noise=True, hidden_channels_equal_out_channels=False, + order='CNACNA', learn_shortcut=None, clamp=None, output_scale=1, + demodulate=True, eps=1e-8): + block = functools.partial(ModulatedConv2dBlock, + style_dim=style_dim, + 
demodulate=demodulate, eps=eps) + skip_block = Conv2dBlock + super().__init__(in_channels, out_channels, kernel_size, stride, + padding, dilation, groups, bias, padding_mode, + weight_norm_type, weight_norm_params, + activation_norm_type, activation_norm_params, + skip_activation_norm, skip_nonlinearity, nonlinearity, + inplace_nonlinearity, apply_noise, + hidden_channels_equal_out_channels, order, block, + learn_shortcut, clamp, output_scale, skip_block=skip_block) + + def conv_blocks(self, x, *cond_inputs, **kw_cond_inputs): + assert len(list(cond_inputs)) == 2 + dx = self.conv_block_0(x, cond_inputs[0], **kw_cond_inputs) + dx = self.conv_block_1(dx, cond_inputs[1], **kw_cond_inputs) + return dx + + +class ResLinearBlock(_BaseResBlock): + r"""Residual block with full-connected layers. + + Args: + in_channels (int) : Number of channels in the input tensor. + out_channels (int) : Number of channels in the output tensor. + weight_norm_type (str, optional, default='none'): + Type of weight normalization. + ``'none'``, ``'spectral'``, ``'weight'`` + or ``'weight_demod'``. + weight_norm_params (obj, optional, default=None): + Parameters of weight normalization. + If not ``None``, ``weight_norm_params.__dict__`` will be used as + keyword arguments when initializing weight normalization. + activation_norm_type (str, optional, default='none'): + Type of activation normalization. + ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``, + ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``, + ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``. + activation_norm_params (obj, optional, default=None): + Parameters of activation normalization. + If not ``None``, ``activation_norm_params.__dict__`` will be used as + keyword arguments when initializing activation normalization. + skip_activation_norm (bool, optional, default=True): If ``True`` and + ``learn_shortcut`` is also ``True``, applies activation norm to the + learned shortcut connection. + skip_nonlinearity (bool, optional, default=True): If ``True`` and + ``learn_shortcut`` is also ``True``, applies nonlinearity to the + learned shortcut connection. + nonlinearity (str, optional, default='none'): + Type of nonlinear activation function in the residual link. + ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``, + ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``. + inplace_nonlinearity (bool, optional, default=False): If ``True``, + set ``inplace=True`` when initializing the nonlinearity layers. + apply_noise (bool, optional, default=False): If ``True``, add + Gaussian noise with learnable magnitude after the + fully-connected layer. + hidden_channels_equal_out_channels (bool, optional, default=False): + If ``True``, set the hidden channel number to be equal to the + output channel number. If ``False``, the hidden channel number + equals to the smaller of the input channel number and the + output channel number. + order (str, optional, default='CNACNA'): Order of operations + in the residual link. + ``'C'``: fully-connected, + ``'N'``: normalization, + ``'A'``: nonlinear activation. + learn_shortcut (bool, optional, default=False): If ``True``, always use + a convolutional shortcut instead of an identity one, otherwise only + use a convolutional one if input and output have different number of + channels. 
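+
+    Example:
+        A minimal sketch (assumes ``LinearBlock`` accepts the arguments that
+        ``_BaseResBlock`` forwards to it); a learned shortcut is created
+        automatically because the feature counts differ::
+
+            import torch
+
+            block = ResLinearBlock(in_channels=64, out_channels=128)
+            y = block(torch.randn(8, 64))  # y has shape (8, 128)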
+ """ + + def __init__(self, in_channels, out_channels, bias=True, + weight_norm_type='none', weight_norm_params=None, + activation_norm_type='none', activation_norm_params=None, + skip_activation_norm=True, skip_nonlinearity=False, + nonlinearity='leakyrelu', inplace_nonlinearity=False, + apply_noise=False, hidden_channels_equal_out_channels=False, + order='CNACNA', learn_shortcut=None, clamp=None, + output_scale=1): + super().__init__(in_channels, out_channels, None, 1, None, None, + None, bias, None, weight_norm_type, weight_norm_params, + activation_norm_type, activation_norm_params, + skip_activation_norm, skip_nonlinearity, nonlinearity, + inplace_nonlinearity, apply_noise, + hidden_channels_equal_out_channels, order, LinearBlock, + learn_shortcut, clamp, output_scale) + + +class Res1dBlock(_BaseResBlock): + r"""Residual block for 1D input. + + Args: + in_channels (int) : Number of channels in the input tensor. + out_channels (int) : Number of channels in the output tensor. + kernel_size (int, optional, default=3): Kernel size for the + convolutional filters in the residual link. + padding (int, optional, default=1): Padding size. + dilation (int, optional, default=1): Dilation factor. + groups (int, optional, default=1): Number of convolutional/linear + groups. + padding_mode (string, optional, default='zeros'): Type of padding: + ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. + weight_norm_type (str, optional, default='none'): + Type of weight normalization. + ``'none'``, ``'spectral'``, ``'weight'`` + or ``'weight_demod'``. + weight_norm_params (obj, optional, default=None): + Parameters of weight normalization. + If not ``None``, ``weight_norm_params.__dict__`` will be used as + keyword arguments when initializing weight normalization. + activation_norm_type (str, optional, default='none'): + Type of activation normalization. + ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``, + ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``, + ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``. + activation_norm_params (obj, optional, default=None): + Parameters of activation normalization. + If not ``None``, ``activation_norm_params.__dict__`` will be used as + keyword arguments when initializing activation normalization. + skip_activation_norm (bool, optional, default=True): If ``True`` and + ``learn_shortcut`` is also ``True``, applies activation norm to the + learned shortcut connection. + skip_nonlinearity (bool, optional, default=True): If ``True`` and + ``learn_shortcut`` is also ``True``, applies nonlinearity to the + learned shortcut connection. + nonlinearity (str, optional, default='none'): + Type of nonlinear activation function in the residual link. + ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``, + ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``. + inplace_nonlinearity (bool, optional, default=False): If ``True``, + set ``inplace=True`` when initializing the nonlinearity layers. + apply_noise (bool, optional, default=False): If ``True``, adds + Gaussian noise with learnable magnitude to the convolution output. + hidden_channels_equal_out_channels (bool, optional, default=False): + If ``True``, set the hidden channel number to be equal to the + output channel number. If ``False``, the hidden channel number + equals to the smaller of the input channel number and the + output channel number. + order (str, optional, default='CNACNA'): Order of operations + in the residual link. 
+ ``'C'``: convolution, + ``'N'``: normalization, + ``'A'``: nonlinear activation. + learn_shortcut (bool, optional, default=False): If ``True``, always use + a convolutional shortcut instead of an identity one, otherwise only + use a convolutional one if input and output have different number of + channels. + """ + + def __init__(self, in_channels, out_channels, kernel_size=3, + stride=1, padding=1, dilation=1, groups=1, bias=True, + padding_mode='zeros', + weight_norm_type='none', weight_norm_params=None, + activation_norm_type='none', activation_norm_params=None, + skip_activation_norm=True, skip_nonlinearity=False, + nonlinearity='leakyrelu', inplace_nonlinearity=False, + apply_noise=False, hidden_channels_equal_out_channels=False, + order='CNACNA', learn_shortcut=None, clamp=None, + output_scale=1): + super().__init__(in_channels, out_channels, kernel_size, stride, + padding, dilation, groups, bias, padding_mode, + weight_norm_type, weight_norm_params, + activation_norm_type, activation_norm_params, + skip_activation_norm, skip_nonlinearity, nonlinearity, + inplace_nonlinearity, apply_noise, + hidden_channels_equal_out_channels, order, Conv1dBlock, + learn_shortcut, clamp, output_scale) + + +class Res2dBlock(_BaseResBlock): + r"""Residual block for 2D input. + + Args: + in_channels (int) : Number of channels in the input tensor. + out_channels (int) : Number of channels in the output tensor. + kernel_size (int, optional, default=3): Kernel size for the + convolutional filters in the residual link. + padding (int, optional, default=1): Padding size. + dilation (int, optional, default=1): Dilation factor. + groups (int, optional, default=1): Number of convolutional/linear + groups. + padding_mode (string, optional, default='zeros'): Type of padding: + ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. + weight_norm_type (str, optional, default='none'): + Type of weight normalization. + ``'none'``, ``'spectral'``, ``'weight'`` + or ``'weight_demod'``. + weight_norm_params (obj, optional, default=None): + Parameters of weight normalization. + If not ``None``, ``weight_norm_params.__dict__`` will be used as + keyword arguments when initializing weight normalization. + activation_norm_type (str, optional, default='none'): + Type of activation normalization. + ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``, + ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``, + ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``. + activation_norm_params (obj, optional, default=None): + Parameters of activation normalization. + If not ``None``, ``activation_norm_params.__dict__`` will be used as + keyword arguments when initializing activation normalization. + skip_activation_norm (bool, optional, default=True): If ``True`` and + ``learn_shortcut`` is also ``True``, applies activation norm to the + learned shortcut connection. + skip_nonlinearity (bool, optional, default=True): If ``True`` and + ``learn_shortcut`` is also ``True``, applies nonlinearity to the + learned shortcut connection. + nonlinearity (str, optional, default='none'): + Type of nonlinear activation function in the residual link. + ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``, + ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``. + inplace_nonlinearity (bool, optional, default=False): If ``True``, + set ``inplace=True`` when initializing the nonlinearity layers. + apply_noise (bool, optional, default=False): If ``True``, adds + Gaussian noise with learnable magnitude to the convolution output. 
+ hidden_channels_equal_out_channels (bool, optional, default=False): + If ``True``, set the hidden channel number to be equal to the + output channel number. If ``False``, the hidden channel number + equals to the smaller of the input channel number and the + output channel number. + order (str, optional, default='CNACNA'): Order of operations + in the residual link. + ``'C'``: convolution, + ``'N'``: normalization, + ``'A'``: nonlinear activation. + learn_shortcut (bool, optional, default=False): If ``True``, always use + a convolutional shortcut instead of an identity one, otherwise only + use a convolutional one if input and output have different number of + channels. + """ + + def __init__(self, in_channels, out_channels, kernel_size=3, + stride=1, padding=1, dilation=1, groups=1, bias=True, + padding_mode='zeros', + weight_norm_type='none', weight_norm_params=None, + activation_norm_type='none', activation_norm_params=None, + skip_activation_norm=True, skip_nonlinearity=False, + skip_weight_norm=True, + nonlinearity='leakyrelu', inplace_nonlinearity=False, + apply_noise=False, hidden_channels_equal_out_channels=False, + order='CNACNA', learn_shortcut=None, clamp=None, + output_scale=1, blur=False, upsample_first=True): + super().__init__(in_channels, out_channels, kernel_size, stride, + padding, dilation, groups, bias, padding_mode, + weight_norm_type, weight_norm_params, + activation_norm_type, activation_norm_params, + skip_activation_norm, skip_nonlinearity, nonlinearity, + inplace_nonlinearity, apply_noise, + hidden_channels_equal_out_channels, order, Conv2dBlock, + learn_shortcut, clamp, output_scale, blur=blur, + upsample_first=upsample_first, + skip_weight_norm=skip_weight_norm) + + +class Res3dBlock(_BaseResBlock): + r"""Residual block for 3D input. + + Args: + in_channels (int) : Number of channels in the input tensor. + out_channels (int) : Number of channels in the output tensor. + kernel_size (int, optional, default=3): Kernel size for the + convolutional filters in the residual link. + padding (int, optional, default=1): Padding size. + dilation (int, optional, default=1): Dilation factor. + groups (int, optional, default=1): Number of convolutional/linear + groups. + padding_mode (string, optional, default='zeros'): Type of padding: + ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. + weight_norm_type (str, optional, default='none'): + Type of weight normalization. + ``'none'``, ``'spectral'``, ``'weight'`` + or ``'weight_demod'``. + weight_norm_params (obj, optional, default=None): + Parameters of weight normalization. + If not ``None``, ``weight_norm_params.__dict__`` will be used as + keyword arguments when initializing weight normalization. + activation_norm_type (str, optional, default='none'): + Type of activation normalization. + ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``, + ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``, + ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``. + activation_norm_params (obj, optional, default=None): + Parameters of activation normalization. + If not ``None``, ``activation_norm_params.__dict__`` will be used as + keyword arguments when initializing activation normalization. + skip_activation_norm (bool, optional, default=True): If ``True`` and + ``learn_shortcut`` is also ``True``, applies activation norm to the + learned shortcut connection. 
+ skip_nonlinearity (bool, optional, default=True): If ``True`` and + ``learn_shortcut`` is also ``True``, applies nonlinearity to the + learned shortcut connection. + nonlinearity (str, optional, default='none'): + Type of nonlinear activation function in the residual link. + ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``, + ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``. + inplace_nonlinearity (bool, optional, default=False): If ``True``, + set ``inplace=True`` when initializing the nonlinearity layers. + apply_noise (bool, optional, default=False): If ``True``, adds + Gaussian noise with learnable magnitude to the convolution output. + hidden_channels_equal_out_channels (bool, optional, default=False): + If ``True``, set the hidden channel number to be equal to the + output channel number. If ``False``, the hidden channel number + equals to the smaller of the input channel number and the + output channel number. + order (str, optional, default='CNACNA'): Order of operations + in the residual link. + ``'C'``: convolution, + ``'N'``: normalization, + ``'A'``: nonlinear activation. + learn_shortcut (bool, optional, default=False): If ``True``, always use + a convolutional shortcut instead of an identity one, otherwise only + use a convolutional one if input and output have different number of + channels. + """ + + def __init__(self, in_channels, out_channels, kernel_size=3, + stride=1, padding=1, dilation=1, groups=1, bias=True, + padding_mode='zeros', + weight_norm_type='none', weight_norm_params=None, + activation_norm_type='none', activation_norm_params=None, + skip_activation_norm=True, skip_nonlinearity=False, + nonlinearity='leakyrelu', inplace_nonlinearity=False, + apply_noise=False, hidden_channels_equal_out_channels=False, + order='CNACNA', learn_shortcut=None, clamp=None, + output_scale=1): + super().__init__(in_channels, out_channels, kernel_size, stride, + padding, dilation, groups, bias, padding_mode, + weight_norm_type, weight_norm_params, + activation_norm_type, activation_norm_params, + skip_activation_norm, skip_nonlinearity, nonlinearity, + inplace_nonlinearity, apply_noise, + hidden_channels_equal_out_channels, order, Conv3dBlock, + learn_shortcut, clamp, output_scale) + + +class _BaseHyperResBlock(_BaseResBlock): + r"""An abstract class for hyper residual blocks. + """ + + def __init__(self, in_channels, out_channels, kernel_size, + stride, padding, dilation, groups, bias, padding_mode, + weight_norm_type, weight_norm_params, + activation_norm_type, activation_norm_params, + skip_activation_norm, skip_nonlinearity, + nonlinearity, inplace_nonlinearity, apply_noise, + hidden_channels_equal_out_channels, + order, is_hyper_conv, is_hyper_norm, block, learn_shortcut, + clamp=None, output_scale=1): + block = functools.partial(block, + is_hyper_conv=is_hyper_conv, + is_hyper_norm=is_hyper_norm) + super().__init__(in_channels, out_channels, kernel_size, stride, + padding, dilation, groups, bias, padding_mode, + weight_norm_type, weight_norm_params, + activation_norm_type, activation_norm_params, + skip_activation_norm, skip_nonlinearity, nonlinearity, + inplace_nonlinearity, apply_noise, + hidden_channels_equal_out_channels, order, block, + learn_shortcut, clamp, output_scale) + + def forward(self, x, *cond_inputs, conv_weights=(None,) * 3, + norm_weights=(None,) * 3, **kw_cond_inputs): + r""" + + Args: + x (tensor): Input tensor. + cond_inputs (list of tensors) : Conditional input tensors. 
+ conv_weights (list of tensors): Convolution weights for + three convolutional layers respectively. + norm_weights (list of tensors): Normalization weights for + three convolutional layers respectively. + kw_cond_inputs (dict) : Keyword conditional inputs. + Returns: + output (tensor): Output tensor. + """ + dx = self.conv_block_0(x, *cond_inputs, conv_weights=conv_weights[0], + norm_weights=norm_weights[0]) + dx = self.conv_block_1(dx, *cond_inputs, conv_weights=conv_weights[1], + norm_weights=norm_weights[1]) + if self.learn_shortcut: + x_shortcut = self.conv_block_s(x, *cond_inputs, + conv_weights=conv_weights[2], + norm_weights=norm_weights[2]) + else: + x_shortcut = x + output = x_shortcut + dx + return self.output_scale * output + + +class HyperRes2dBlock(_BaseHyperResBlock): + r"""Hyper residual block for 2D input. + + Args: + in_channels (int) : Number of channels in the input tensor. + out_channels (int) : Number of channels in the output tensor. + kernel_size (int, optional, default=3): Kernel size for the + convolutional filters in the residual link. + padding (int, optional, default=1): Padding size. + dilation (int, optional, default=1): Dilation factor. + groups (int, optional, default=1): Number of convolutional/linear + groups. + padding_mode (string, optional, default='zeros'): Type of padding: + ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. + weight_norm_type (str, optional, default='none'): + Type of weight normalization. + ``'none'``, ``'spectral'``, ``'weight'`` + or ``'weight_demod'``. + weight_norm_params (obj, optional, default=None): + Parameters of weight normalization. + If not ``None``, ``weight_norm_params.__dict__`` will be used as + keyword arguments when initializing weight normalization. + activation_norm_type (str, optional, default='none'): + Type of activation normalization. + ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``, + ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``, + ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``. + activation_norm_params (obj, optional, default=None): + Parameters of activation normalization. + If not ``None``, ``activation_norm_params.__dict__`` will be used as + keyword arguments when initializing activation normalization. + skip_activation_norm (bool, optional, default=True): If ``True`` and + ``learn_shortcut`` is also ``True``, applies activation norm to the + learned shortcut connection. + skip_nonlinearity (bool, optional, default=True): If ``True`` and + ``learn_shortcut`` is also ``True``, applies nonlinearity to the + learned shortcut connection. + nonlinearity (str, optional, default='none'): + Type of nonlinear activation function in the residual link. + ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``, + ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``. + inplace_nonlinearity (bool, optional, default=False): If ``True``, + set ``inplace=True`` when initializing the nonlinearity layers. + apply_noise (bool, optional, default=False): If ``True``, adds + Gaussian noise with learnable magnitude to the convolution output. + hidden_channels_equal_out_channels (bool, optional, default=False): + If ``True``, set the hidden channel number to be equal to the + output channel number. If ``False``, the hidden channel number + equals to the smaller of the input channel number and the + output channel number. + order (str, optional, default='CNACNA'): Order of operations + in the residual link. + ``'C'``: convolution, + ``'N'``: normalization, + ``'A'``: nonlinear activation. 
+ is_hyper_conv (bool, optional, default=False): If ``True``, use + ``HyperConv2d``, otherwise use ``torch.nn.Conv2d``. + is_hyper_norm (bool, optional, default=False): If ``True``, use + hyper normalizations. + learn_shortcut (bool, optional, default=False): If ``True``, always use + a convolutional shortcut instead of an identity one, otherwise only + use a convolutional one if input and output have different number of + channels. + """ + + def __init__(self, in_channels, out_channels, kernel_size=3, + stride=1, padding=1, dilation=1, groups=1, bias=True, + padding_mode='zeros', + weight_norm_type='', weight_norm_params=None, + activation_norm_type='', activation_norm_params=None, + skip_activation_norm=True, skip_nonlinearity=False, + nonlinearity='leakyrelu', inplace_nonlinearity=False, + apply_noise=False, hidden_channels_equal_out_channels=False, + order='CNACNA', is_hyper_conv=False, is_hyper_norm=False, + learn_shortcut=None, clamp=None, output_scale=1): + super().__init__(in_channels, out_channels, kernel_size, + stride, padding, dilation, groups, bias, padding_mode, + weight_norm_type, weight_norm_params, + activation_norm_type, activation_norm_params, + skip_activation_norm, skip_nonlinearity, + nonlinearity, inplace_nonlinearity, apply_noise, + hidden_channels_equal_out_channels, + order, is_hyper_conv, is_hyper_norm, + HyperConv2dBlock, learn_shortcut, clamp, output_scale) + + +class _BaseDownResBlock(_BaseResBlock): + r"""An abstract class for residual blocks with downsampling. + """ + + def __init__(self, in_channels, out_channels, kernel_size, + stride, padding, dilation, groups, bias, padding_mode, + weight_norm_type, weight_norm_params, + activation_norm_type, activation_norm_params, + skip_activation_norm, skip_nonlinearity, + nonlinearity, inplace_nonlinearity, + apply_noise, hidden_channels_equal_out_channels, + order, block, pooling, down_factor, learn_shortcut, + clamp=None, output_scale=1): + super().__init__(in_channels, out_channels, kernel_size, + stride, padding, dilation, groups, bias, padding_mode, + weight_norm_type, weight_norm_params, + activation_norm_type, activation_norm_params, + skip_activation_norm, skip_nonlinearity, nonlinearity, + inplace_nonlinearity, apply_noise, + hidden_channels_equal_out_channels, order, block, + learn_shortcut, clamp, output_scale) + self.pooling = pooling(down_factor) + + def forward(self, x, *cond_inputs): + r""" + + Args: + x (tensor) : Input tensor. + cond_inputs (list of tensors) : conditional input. + Returns: + output (tensor) : Output tensor. + """ + dx = self.conv_block_0(x, *cond_inputs) + dx = self.conv_block_1(dx, *cond_inputs) + dx = self.pooling(dx) + if self.learn_shortcut: + x_shortcut = self.conv_block_s(x, *cond_inputs) + else: + x_shortcut = x + x_shortcut = self.pooling(x_shortcut) + output = x_shortcut + dx + return self.output_scale * output + + +class DownRes2dBlock(_BaseDownResBlock): + r"""Residual block for 2D input with downsampling. + + Args: + in_channels (int) : Number of channels in the input tensor. + out_channels (int) : Number of channels in the output tensor. + kernel_size (int, optional, default=3): Kernel size for the + convolutional filters in the residual link. + padding (int, optional, default=1): Padding size. + dilation (int, optional, default=1): Dilation factor. + groups (int, optional, default=1): Number of convolutional/linear + groups. + padding_mode (string, optional, default='zeros'): Type of padding: + ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. 
+ weight_norm_type (str, optional, default='none'): + Type of weight normalization. + ``'none'``, ``'spectral'``, ``'weight'`` + or ``'weight_demod'``. + weight_norm_params (obj, optional, default=None): + Parameters of weight normalization. + If not ``None``, ``weight_norm_params.__dict__`` will be used as + keyword arguments when initializing weight normalization. + activation_norm_type (str, optional, default='none'): + Type of activation normalization. + ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``, + ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``, + ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``. + activation_norm_params (obj, optional, default=None): + Parameters of activation normalization. + If not ``None``, ``activation_norm_params.__dict__`` will be used as + keyword arguments when initializing activation normalization. + skip_activation_norm (bool, optional, default=True): If ``True`` and + ``learn_shortcut`` is also ``True``, applies activation norm to the + learned shortcut connection. + skip_nonlinearity (bool, optional, default=True): If ``True`` and + ``learn_shortcut`` is also ``True``, applies nonlinearity to the + learned shortcut connection. + nonlinearity (str, optional, default='none'): + Type of nonlinear activation function in the residual link. + ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``, + ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``. + inplace_nonlinearity (bool, optional, default=False): If ``True``, + set ``inplace=True`` when initializing the nonlinearity layers. + apply_noise (bool, optional, default=False): If ``True``, adds + Gaussian noise with learnable magnitude to the convolution output. + hidden_channels_equal_out_channels (bool, optional, default=False): + If ``True``, set the hidden channel number to be equal to the + output channel number. If ``False``, the hidden channel number + equals to the smaller of the input channel number and the + output channel number. + order (str, optional, default='CNACNA'): Order of operations + in the residual link. + ``'C'``: convolution, + ``'N'``: normalization, + ``'A'``: nonlinear activation. + pooling (class, optional, default=nn.AvgPool2d): Pytorch pooling + layer to be used. + down_factor (int, optional, default=2): Downsampling factor. + learn_shortcut (bool, optional, default=False): If ``True``, always use + a convolutional shortcut instead of an identity one, otherwise only + use a convolutional one if input and output have different number of + channels. 
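+
+    Example:
+        A minimal sketch; both the residual and the shortcut branches go
+        through the pooling layer, so the spatial size is divided by
+        ``down_factor``::
+
+            import torch
+
+            block = DownRes2dBlock(in_channels=32, out_channels=64)
+            y = block(torch.randn(1, 32, 64, 64))  # y has shape (1, 64, 32, 32)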
+ """ + + def __init__(self, in_channels, out_channels, kernel_size=3, + stride=1, padding=1, dilation=1, groups=1, bias=True, + padding_mode='zeros', + weight_norm_type='none', weight_norm_params=None, + activation_norm_type='none', activation_norm_params=None, + skip_activation_norm=True, skip_nonlinearity=False, + nonlinearity='leakyrelu', inplace_nonlinearity=False, + apply_noise=False, hidden_channels_equal_out_channels=False, + order='CNACNA', pooling=nn.AvgPool2d, down_factor=2, + learn_shortcut=None, clamp=None, output_scale=1): + super().__init__(in_channels, out_channels, kernel_size, + stride, padding, dilation, groups, bias, padding_mode, + weight_norm_type, weight_norm_params, + activation_norm_type, activation_norm_params, + skip_activation_norm, skip_nonlinearity, + nonlinearity, inplace_nonlinearity, apply_noise, + hidden_channels_equal_out_channels, + order, Conv2dBlock, pooling, + down_factor, learn_shortcut, clamp, output_scale) + + +class _BaseUpResBlock(_BaseResBlock): + r"""An abstract class for residual blocks with upsampling. + """ + + def __init__(self, in_channels, out_channels, kernel_size, + stride, padding, dilation, groups, bias, padding_mode, + weight_norm_type, weight_norm_params, + activation_norm_type, activation_norm_params, + skip_activation_norm, skip_nonlinearity, + nonlinearity, inplace_nonlinearity, + apply_noise, hidden_channels_equal_out_channels, + order, block, upsample, up_factor, learn_shortcut, clamp=None, + output_scale=1): + super().__init__(in_channels, out_channels, kernel_size, + stride, padding, dilation, groups, bias, padding_mode, + weight_norm_type, weight_norm_params, + activation_norm_type, activation_norm_params, + skip_activation_norm, skip_nonlinearity, nonlinearity, + inplace_nonlinearity, apply_noise, + hidden_channels_equal_out_channels, order, block, + learn_shortcut, clamp, output_scale) + self.order = order + self.upsample = upsample(scale_factor=up_factor) + + def _get_stride_blur(self): + # Upsampling. + first_stride, second_stride = self.stride, 1 + first_blur, second_blur = self.blur, False + shortcut_blur = False + shortcut_stride = 1 + # if self.upsample == 'blur_deconv': + + if self.blur: + # The shortcut branch uses blur_upsample + stride-1 conv + self.upsample = BlurUpsample() + else: + shortcut_stride = self.stride + self.upsample = nn.Upsample(scale_factor=2) + + return (first_stride, second_stride, shortcut_stride, + first_blur, second_blur, shortcut_blur) + + def forward(self, x, *cond_inputs): + r"""Implementation of the up residual block forward function. + If the order is 'NAC' for the first residual block, we will first + do the activation norm and nonlinearity, in the original resolution. + We will then upsample the activation map to a higher resolution. We + then do the convolution. + It is is other orders, then we first do the whole processing and + then upsample. + + Args: + x (tensor) : Input tensor. + cond_inputs (list of tensors) : Conditional input. + Returns: + output (tensor) : Output tensor. + """ + # In this particular upsample residual block operation, we first + # upsample the skip connection. 
+        if self.learn_shortcut:
+            x_shortcut = self.upsample(x)
+            x_shortcut = self.conv_block_s(x_shortcut, *cond_inputs)
+        else:
+            x_shortcut = self.upsample(x)
+
+        if self.order[0:3] == 'NAC':
+            for ix, layer in enumerate(self.conv_block_0.layers.values()):
+                if getattr(layer, 'conditional', False):
+                    x = layer(x, *cond_inputs)
+                else:
+                    x = layer(x)
+                if ix == 1:
+                    x = self.upsample(x)
+        else:
+            x = self.conv_block_0(x, *cond_inputs)
+            x = self.upsample(x)
+        x = self.conv_block_1(x, *cond_inputs)
+
+        output = x_shortcut + x
+        return self.output_scale * output
+
+
+class UpRes2dBlock(_BaseUpResBlock):
+    r"""Residual block for 2D input with upsampling.
+
+    Args:
+        in_channels (int) : Number of channels in the input tensor.
+        out_channels (int) : Number of channels in the output tensor.
+        kernel_size (int, optional, default=3): Kernel size for the
+            convolutional filters in the residual link.
+        padding (int, optional, default=1): Padding size.
+        dilation (int, optional, default=1): Dilation factor.
+        groups (int, optional, default=1): Number of convolutional/linear
+            groups.
+        padding_mode (string, optional, default='zeros'): Type of padding:
+            ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
+        weight_norm_type (str, optional, default='none'):
+            Type of weight normalization.
+            ``'none'``, ``'spectral'``, ``'weight'``
+            or ``'weight_demod'``.
+        weight_norm_params (obj, optional, default=None):
+            Parameters of weight normalization.
+            If not ``None``, ``weight_norm_params.__dict__`` will be used as
+            keyword arguments when initializing weight normalization.
+        activation_norm_type (str, optional, default='none'):
+            Type of activation normalization.
+            ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
+            ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``,
+            ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
+        activation_norm_params (obj, optional, default=None):
+            Parameters of activation normalization.
+            If not ``None``, ``activation_norm_params.__dict__`` will be used as
+            keyword arguments when initializing activation normalization.
+        skip_activation_norm (bool, optional, default=True): If ``True`` and
+            ``learn_shortcut`` is also ``True``, applies activation norm to the
+            learned shortcut connection.
+        skip_nonlinearity (bool, optional, default=True): If ``True`` and
+            ``learn_shortcut`` is also ``True``, applies nonlinearity to the
+            learned shortcut connection.
+        nonlinearity (str, optional, default='none'):
+            Type of nonlinear activation function in the residual link.
+            ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
+            ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``.
+        inplace_nonlinearity (bool, optional, default=False): If ``True``,
+            set ``inplace=True`` when initializing the nonlinearity layers.
+        apply_noise (bool, optional, default=False): If ``True``, adds
+            Gaussian noise with learnable magnitude to the convolution output.
+        hidden_channels_equal_out_channels (bool, optional, default=False):
+            If ``True``, set the hidden channel number to be equal to the
+            output channel number. If ``False``, the hidden channel number
+            equals to the smaller of the input channel number and the
+            output channel number.
+        order (str, optional, default='CNACNA'): Order of operations
+            in the residual link.
+            ``'C'``: convolution,
+            ``'N'``: normalization,
+            ``'A'``: nonlinear activation.
+        upsample (class, optional, default=NearestUpsample): Pytorch
+            upsampling layer to be used.
+        up_factor (int, optional, default=2): Upsampling factor.
+ learn_shortcut (bool, optional, default=False): If ``True``, always use + a convolutional shortcut instead of an identity one, otherwise only + use a convolutional one if input and output have different number of + channels. + """ + + def __init__(self, in_channels, out_channels, kernel_size=3, + stride=1, padding=1, dilation=1, groups=1, bias=True, + padding_mode='zeros', + weight_norm_type='none', weight_norm_params=None, + activation_norm_type='none', activation_norm_params=None, + skip_activation_norm=True, skip_nonlinearity=False, + nonlinearity='leakyrelu', inplace_nonlinearity=False, + apply_noise=False, hidden_channels_equal_out_channels=False, + order='CNACNA', upsample=NearestUpsample, up_factor=2, + learn_shortcut=None, clamp=None, output_scale=1): + super().__init__(in_channels, out_channels, kernel_size, + stride, padding, dilation, groups, bias, padding_mode, + weight_norm_type, weight_norm_params, + activation_norm_type, activation_norm_params, + skip_activation_norm, skip_nonlinearity, + nonlinearity, inplace_nonlinearity, + apply_noise, hidden_channels_equal_out_channels, + order, Conv2dBlock, + upsample, up_factor, learn_shortcut, clamp, + output_scale) + + +class _BasePartialResBlock(_BaseResBlock): + r"""An abstract class for residual blocks with partial convolution. + """ + + def __init__(self, in_channels, out_channels, kernel_size, + stride, padding, dilation, groups, bias, padding_mode, + weight_norm_type, weight_norm_params, + activation_norm_type, activation_norm_params, + skip_activation_norm, skip_nonlinearity, + nonlinearity, inplace_nonlinearity, + multi_channel, return_mask, + apply_noise, hidden_channels_equal_out_channels, + order, block, learn_shortcut, clamp=None, output_scale=1): + block = functools.partial(block, + multi_channel=multi_channel, + return_mask=return_mask) + self.partial_conv = True + super().__init__(in_channels, out_channels, kernel_size, stride, + padding, dilation, groups, bias, padding_mode, + weight_norm_type, weight_norm_params, + activation_norm_type, activation_norm_params, + skip_activation_norm, skip_nonlinearity, nonlinearity, + inplace_nonlinearity, apply_noise, + hidden_channels_equal_out_channels, order, block, + learn_shortcut, clamp, output_scale) + + def forward(self, x, *cond_inputs, mask_in=None, **kw_cond_inputs): + r""" + + Args: + x (tensor): Input tensor. + cond_inputs (list of tensors) : Conditional input tensors. + mask_in (tensor, optional, default=``None``) If not ``None``, + it masks the valid input region. + kw_cond_inputs (dict) : Keyword conditional inputs. + Returns: + (tuple): + - output (tensor): Output tensor. + - mask_out (tensor, optional): Masks the valid output region. 
+ """ + if self.conv_block_0.layers.conv.return_mask: + dx, mask_out = self.conv_block_0(x, *cond_inputs, + mask_in=mask_in, **kw_cond_inputs) + dx, mask_out = self.conv_block_1(dx, *cond_inputs, + mask_in=mask_out, **kw_cond_inputs) + else: + dx = self.conv_block_0(x, *cond_inputs, + mask_in=mask_in, **kw_cond_inputs) + dx = self.conv_block_1(dx, *cond_inputs, + mask_in=mask_in, **kw_cond_inputs) + mask_out = None + + if self.learn_shortcut: + x_shortcut = self.conv_block_s(x, mask_in=mask_in, *cond_inputs, + **kw_cond_inputs) + if type(x_shortcut) == tuple: + x_shortcut, _ = x_shortcut + else: + x_shortcut = x + output = x_shortcut + dx + + if mask_out is not None: + return output, mask_out + return self.output_scale * output + + +class PartialRes2dBlock(_BasePartialResBlock): + r"""Residual block for 2D input with partial convolution. + + Args: + in_channels (int) : Number of channels in the input tensor. + out_channels (int) : Number of channels in the output tensor. + kernel_size (int, optional, default=3): Kernel size for the + convolutional filters in the residual link. + padding (int, optional, default=1): Padding size. + dilation (int, optional, default=1): Dilation factor. + groups (int, optional, default=1): Number of convolutional/linear + groups. + padding_mode (string, optional, default='zeros'): Type of padding: + ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. + weight_norm_type (str, optional, default='none'): + Type of weight normalization. + ``'none'``, ``'spectral'``, ``'weight'`` + or ``'weight_demod'``. + weight_norm_params (obj, optional, default=None): + Parameters of weight normalization. + If not ``None``, ``weight_norm_params.__dict__`` will be used as + keyword arguments when initializing weight normalization. + activation_norm_type (str, optional, default='none'): + Type of activation normalization. + ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``, + ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``, + ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``. + activation_norm_params (obj, optional, default=None): + Parameters of activation normalization. + If not ``None``, ``activation_norm_params.__dict__`` will be used as + keyword arguments when initializing activation normalization. + skip_activation_norm (bool, optional, default=True): If ``True`` and + ``learn_shortcut`` is also ``True``, applies activation norm to the + learned shortcut connection. + skip_nonlinearity (bool, optional, default=True): If ``True`` and + ``learn_shortcut`` is also ``True``, applies nonlinearity to the + learned shortcut connection. + nonlinearity (str, optional, default='none'): + Type of nonlinear activation function in the residual link. + ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``, + ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``. + inplace_nonlinearity (bool, optional, default=False): If ``True``, + set ``inplace=True`` when initializing the nonlinearity layers. + apply_noise (bool, optional, default=False): If ``True``, adds + Gaussian noise with learnable magnitude to the convolution output. + hidden_channels_equal_out_channels (bool, optional, default=False): + If ``True``, set the hidden channel number to be equal to the + output channel number. If ``False``, the hidden channel number + equals to the smaller of the input channel number and the + output channel number. + order (str, optional, default='CNACNA'): Order of operations + in the residual link. 
+ ``'C'``: convolution, + ``'N'``: normalization, + ``'A'``: nonlinear activation. + learn_shortcut (bool, optional, default=False): If ``True``, always use + a convolutional shortcut instead of an identity one, otherwise only + use a convolutional one if input and output have different number of + channels. + """ + + def __init__(self, in_channels, out_channels, kernel_size=3, + stride=1, padding=1, dilation=1, groups=1, bias=True, + padding_mode='zeros', + weight_norm_type='none', weight_norm_params=None, + activation_norm_type='none', activation_norm_params=None, + skip_activation_norm=True, skip_nonlinearity=False, + nonlinearity='leakyrelu', inplace_nonlinearity=False, + multi_channel=False, return_mask=True, + apply_noise=False, + hidden_channels_equal_out_channels=False, + order='CNACNA', learn_shortcut=None, clamp=None, + output_scale=1): + super().__init__(in_channels, out_channels, kernel_size, + stride, padding, dilation, groups, bias, + padding_mode, weight_norm_type, weight_norm_params, + activation_norm_type, activation_norm_params, + skip_activation_norm, skip_nonlinearity, nonlinearity, + inplace_nonlinearity, multi_channel, return_mask, + apply_noise, hidden_channels_equal_out_channels, + order, PartialConv2dBlock, learn_shortcut, clamp, + output_scale) + + +class PartialRes3dBlock(_BasePartialResBlock): + r"""Residual block for 3D input with partial convolution. + + Args: + in_channels (int) : Number of channels in the input tensor. + out_channels (int) : Number of channels in the output tensor. + kernel_size (int, optional, default=3): Kernel size for the + convolutional filters in the residual link. + padding (int, optional, default=1): Padding size. + dilation (int, optional, default=1): Dilation factor. + groups (int, optional, default=1): Number of convolutional/linear + groups. + padding_mode (string, optional, default='zeros'): Type of padding: + ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. + weight_norm_type (str, optional, default='none'): + Type of weight normalization. + ``'none'``, ``'spectral'``, ``'weight'`` + or ``'weight_demod'``. + weight_norm_params (obj, optional, default=None): + Parameters of weight normalization. + If not ``None``, ``weight_norm_params.__dict__`` will be used as + keyword arguments when initializing weight normalization. + activation_norm_type (str, optional, default='none'): + Type of activation normalization. + ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``, + ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``, + ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``. + activation_norm_params (obj, optional, default=None): + Parameters of activation normalization. + If not ``None``, ``activation_norm_params.__dict__`` will be used as + keyword arguments when initializing activation normalization. + skip_activation_norm (bool, optional, default=True): If ``True`` and + ``learn_shortcut`` is also ``True``, applies activation norm to the + learned shortcut connection. + skip_nonlinearity (bool, optional, default=True): If ``True`` and + ``learn_shortcut`` is also ``True``, applies nonlinearity to the + learned shortcut connection. + nonlinearity (str, optional, default='none'): + Type of nonlinear activation function in the residual link. + ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``, + ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``. + inplace_nonlinearity (bool, optional, default=False): If ``True``, + set ``inplace=True`` when initializing the nonlinearity layers. 
+ apply_noise (bool, optional, default=False): If ``True``, adds + Gaussian noise with learnable magnitude to the convolution output. + hidden_channels_equal_out_channels (bool, optional, default=False): + If ``True``, set the hidden channel number to be equal to the + output channel number. If ``False``, the hidden channel number + equals to the smaller of the input channel number and the + output channel number. + order (str, optional, default='CNACNA'): Order of operations + in the residual link. + ``'C'``: convolution, + ``'N'``: normalization, + ``'A'``: nonlinear activation. + learn_shortcut (bool, optional, default=False): If ``True``, always use + a convolutional shortcut instead of an identity one, otherwise only + use a convolutional one if input and output have different number of + channels. + """ + + def __init__(self, in_channels, out_channels, kernel_size=3, + stride=1, padding=1, dilation=1, groups=1, bias=True, + padding_mode='zeros', + weight_norm_type='none', weight_norm_params=None, + activation_norm_type='none', activation_norm_params=None, + skip_activation_norm=True, skip_nonlinearity=False, + nonlinearity='leakyrelu', inplace_nonlinearity=False, + multi_channel=False, return_mask=True, + apply_noise=False, hidden_channels_equal_out_channels=False, + order='CNACNA', learn_shortcut=None, clamp=None, + output_scale=1): + super().__init__(in_channels, out_channels, kernel_size, + stride, padding, dilation, groups, bias, + padding_mode, weight_norm_type, weight_norm_params, + activation_norm_type, activation_norm_params, + skip_activation_norm, skip_nonlinearity, + nonlinearity, inplace_nonlinearity, multi_channel, + return_mask, apply_noise, + hidden_channels_equal_out_channels, + order, PartialConv3dBlock, learn_shortcut, clamp, + output_scale) + + +class _BaseMultiOutResBlock(_BaseResBlock): + r"""An abstract class for residual blocks that can returns multiple outputs. + """ + + def __init__(self, in_channels, out_channels, kernel_size, + stride, padding, dilation, groups, bias, padding_mode, + weight_norm_type, weight_norm_params, + activation_norm_type, activation_norm_params, + skip_activation_norm, skip_nonlinearity, + nonlinearity, inplace_nonlinearity, + apply_noise, hidden_channels_equal_out_channels, + order, block, learn_shortcut, clamp=None, output_scale=1, + blur=False, upsample_first=True): + self.multiple_outputs = True + super().__init__(in_channels, out_channels, kernel_size, stride, + padding, dilation, groups, bias, padding_mode, + weight_norm_type, weight_norm_params, + activation_norm_type, activation_norm_params, + skip_activation_norm, skip_nonlinearity, nonlinearity, + inplace_nonlinearity, apply_noise, + hidden_channels_equal_out_channels, order, block, + learn_shortcut, clamp, output_scale, blur=blur, + upsample_first=upsample_first) + + def forward(self, x, *cond_inputs): + r""" + + Args: + x (tensor): Input tensor. + cond_inputs (list of tensors) : Conditional input tensors. + Returns: + (tuple): + - output (tensor): Output tensor. + - aux_outputs_0 (tensor): Auxiliary output of the first block. + - aux_outputs_1 (tensor): Auxiliary output of the second block. + """ + dx, aux_outputs_0 = self.conv_block_0(x, *cond_inputs) + dx, aux_outputs_1 = self.conv_block_1(dx, *cond_inputs) + if self.learn_shortcut: + # We are not using the auxiliary outputs of self.conv_block_s. 
+ x_shortcut, _ = self.conv_block_s(x, *cond_inputs) + else: + x_shortcut = x + output = x_shortcut + dx + return self.output_scale * output, aux_outputs_0, aux_outputs_1 + + +class MultiOutRes2dBlock(_BaseMultiOutResBlock): + r"""Residual block for 2D input. It can return multiple outputs, if some + layers in the block return more than one output. + + Args: + in_channels (int) : Number of channels in the input tensor. + out_channels (int) : Number of channels in the output tensor. + kernel_size (int, optional, default=3): Kernel size for the + convolutional filters in the residual link. + padding (int, optional, default=1): Padding size. + dilation (int, optional, default=1): Dilation factor. + groups (int, optional, default=1): Number of convolutional/linear + groups. + padding_mode (string, optional, default='zeros'): Type of padding: + ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. + weight_norm_type (str, optional, default='none'): + Type of weight normalization. + ``'none'``, ``'spectral'``, ``'weight'`` + or ``'weight_demod'``. + weight_norm_params (obj, optional, default=None): + Parameters of weight normalization. + If not ``None``, ``weight_norm_params.__dict__`` will be used as + keyword arguments when initializing weight normalization. + activation_norm_type (str, optional, default='none'): + Type of activation normalization. + ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``, + ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``, + ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``. + activation_norm_params (obj, optional, default=None): + Parameters of activation normalization. + If not ``None``, ``activation_norm_params.__dict__`` will be used as + keyword arguments when initializing activation normalization. + skip_activation_norm (bool, optional, default=True): If ``True`` and + ``learn_shortcut`` is also ``True``, applies activation norm to the + learned shortcut connection. + skip_nonlinearity (bool, optional, default=True): If ``True`` and + ``learn_shortcut`` is also ``True``, applies nonlinearity to the + learned shortcut connection. + nonlinearity (str, optional, default='none'): + Type of nonlinear activation function in the residual link. + ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``, + ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``. + inplace_nonlinearity (bool, optional, default=False): If ``True``, + set ``inplace=True`` when initializing the nonlinearity layers. + apply_noise (bool, optional, default=False): If ``True``, adds + Gaussian noise with learnable magnitude to the convolution output. + hidden_channels_equal_out_channels (bool, optional, default=False): + If ``True``, set the hidden channel number to be equal to the + output channel number. If ``False``, the hidden channel number + equals to the smaller of the input channel number and the + output channel number. + order (str, optional, default='CNACNA'): Order of operations + in the residual link. + ``'C'``: convolution, + ``'N'``: normalization, + ``'A'``: nonlinear activation. + learn_shortcut (bool, optional, default=False): If ``True``, always use + a convolutional shortcut instead of an identity one, otherwise only + use a convolutional one if input and output have different number of + channels. 
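+
+    Example:
+        A sketch of the calling convention only; the two auxiliary outputs are
+        whatever ``MultiOutConv2dBlock`` returns alongside its features::
+
+            import torch
+
+            block = MultiOutRes2dBlock(in_channels=32, out_channels=32)
+            out, aux_0, aux_1 = block(torch.randn(1, 32, 16, 16))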
+ """ + + def __init__(self, in_channels, out_channels, kernel_size=3, + stride=1, padding=1, dilation=1, groups=1, bias=True, + padding_mode='zeros', + weight_norm_type='none', weight_norm_params=None, + activation_norm_type='none', activation_norm_params=None, + skip_activation_norm=True, skip_nonlinearity=False, + nonlinearity='leakyrelu', inplace_nonlinearity=False, + apply_noise=False, hidden_channels_equal_out_channels=False, + order='CNACNA', learn_shortcut=None, clamp=None, + output_scale=1, blur=False, upsample_first=True): + super().__init__(in_channels, out_channels, kernel_size, stride, + padding, dilation, groups, bias, padding_mode, + weight_norm_type, weight_norm_params, + activation_norm_type, activation_norm_params, + skip_activation_norm, skip_nonlinearity, nonlinearity, + inplace_nonlinearity, apply_noise, + hidden_channels_equal_out_channels, order, + MultiOutConv2dBlock, learn_shortcut, clamp, + output_scale, blur=blur, upsample_first=upsample_first) diff --git a/imaginaire/layers/residual_deep.py b/imaginaire/layers/residual_deep.py new file mode 100644 index 0000000000000000000000000000000000000000..b0bbcd497f4689bed4faf20e8e47c0fc4e282812 --- /dev/null +++ b/imaginaire/layers/residual_deep.py @@ -0,0 +1,346 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +import torch +from torch import nn +from torch.utils.checkpoint import checkpoint + +from imaginaire.third_party.upfirdn2d import BlurDownsample, BlurUpsample +from .conv import Conv2dBlock + + +class _BaseDeepResBlock(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, + stride, padding, dilation, groups, bias, padding_mode, + weight_norm_type, weight_norm_params, + activation_norm_type, activation_norm_params, + skip_activation_norm, skip_nonlinearity, + nonlinearity, inplace_nonlinearity, apply_noise, + hidden_channels_equal_out_channels, + order, block, learn_shortcut, output_scale, skip_block=None, + blur=True, border_free=True, resample_first=True, + skip_weight_norm=True, hidden_channel_ratio=4): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.output_scale = output_scale + self.resample_first = resample_first + self.stride = stride + self.blur = blur + self.border_free = border_free + assert not border_free + if skip_block is None: + skip_block = block + + if order == 'pre_act': + order = 'NACNAC' + if isinstance(bias, bool): + # The bias for conv_block_0, conv_block_1, and conv_block_s. + biases = [bias, bias, bias] + elif isinstance(bias, list): + if len(bias) == 3: + biases = bias + else: + raise ValueError('Bias list must be 3.') + else: + raise ValueError('Bias must be either an integer or s list.') + self.learn_shortcut = learn_shortcut + if len(order) > 6 or len(order) < 5: + raise ValueError('order must be either 5 or 6 characters') + hidden_channels = in_channels // hidden_channel_ratio + + # Parameters. 
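+        # residual_params are shared by every convolution on the bottlenecked
+        # residual path; shortcut_params configure the optional 1x1 shortcut
+        # convolution, which only receives activation/weight normalization
+        # when the corresponding skip_* flags are set.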
+ residual_params = {} + shortcut_params = {} + base_params = dict(dilation=dilation, + groups=groups, + padding_mode=padding_mode) + residual_params.update(base_params) + residual_params.update( + dict(activation_norm_type=activation_norm_type, + activation_norm_params=activation_norm_params, + weight_norm_type=weight_norm_type, + weight_norm_params=weight_norm_params, + apply_noise=apply_noise) + ) + shortcut_params.update(base_params) + shortcut_params.update(dict(kernel_size=1)) + if skip_activation_norm: + shortcut_params.update( + dict(activation_norm_type=activation_norm_type, + activation_norm_params=activation_norm_params, + apply_noise=False)) + if skip_weight_norm: + shortcut_params.update( + dict(weight_norm_type=weight_norm_type, + weight_norm_params=weight_norm_params)) + + # Residual branch. + if order.find('A') < order.find('C') and \ + (activation_norm_type == '' or activation_norm_type == 'none'): + # Nonlinearity is the first operation in the residual path. + # In-place nonlinearity will modify the input variable and cause + # backward error. + first_inplace = False + else: + first_inplace = inplace_nonlinearity + + (first_stride, second_stride, shortcut_stride, + first_blur, second_blur, shortcut_blur) = self._get_stride_blur() + + self.conv_block_1x1_in = block( + in_channels, hidden_channels, + 1, 1, 0, + bias=biases[0], + nonlinearity=nonlinearity, + order=order[0:3], + inplace_nonlinearity=first_inplace, + **residual_params + ) + + self.conv_block_0 = block( + hidden_channels, hidden_channels, + kernel_size=2 if self.border_free and first_stride < 1 else + kernel_size, + padding=padding, + bias=biases[0], + nonlinearity=nonlinearity, + order=order[0:3], + inplace_nonlinearity=inplace_nonlinearity, + stride=first_stride, + blur=first_blur, + **residual_params + ) + self.conv_block_1 = block( + hidden_channels, hidden_channels, + kernel_size=kernel_size, + padding=padding, + bias=biases[1], + nonlinearity=nonlinearity, + order=order[3:], + inplace_nonlinearity=inplace_nonlinearity, + stride=second_stride, + blur=second_blur, + **residual_params + ) + + self.conv_block_1x1_out = block( + hidden_channels, out_channels, + 1, 1, 0, + bias=biases[1], + nonlinearity=nonlinearity, + order=order[0:3], + inplace_nonlinearity=inplace_nonlinearity, + **residual_params + ) + + # Shortcut branch. + if self.learn_shortcut: + if skip_nonlinearity: + skip_nonlinearity_type = nonlinearity + else: + skip_nonlinearity_type = '' + self.conv_block_s = skip_block(in_channels, out_channels, + bias=biases[2], + nonlinearity=skip_nonlinearity_type, + order=order[0:3], + stride=shortcut_stride, + blur=shortcut_blur, + **shortcut_params) + elif in_channels < out_channels: + if skip_nonlinearity: + skip_nonlinearity_type = nonlinearity + else: + skip_nonlinearity_type = '' + self.conv_block_s = skip_block(in_channels, + out_channels - in_channels, + bias=biases[2], + nonlinearity=skip_nonlinearity_type, + order=order[0:3], + stride=shortcut_stride, + blur=shortcut_blur, + **shortcut_params) + + # Whether this block expects conditional inputs. + self.conditional = \ + getattr(self.conv_block_0, 'conditional', False) or \ + getattr(self.conv_block_1, 'conditional', False) or \ + getattr(self.conv_block_1x1_in, 'conditional', False) or \ + getattr(self.conv_block_1x1_out, 'conditional', False) + + def _get_stride_blur(self): + if self.stride > 1: + # Downsampling. 
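+            # The first residual conv keeps stride 1 and the second conv
+            # applies the stride; the shortcut path is downsampled with
+            # BlurDownsample (antialiased) when blur is enabled, otherwise
+            # with AvgPool2d.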
+ first_stride, second_stride = 1, self.stride + first_blur, second_blur = False, self.blur + shortcut_blur = False + shortcut_stride = 1 + if self.blur: + # The shortcut branch uses blur_downsample + stride-1 conv + if self.border_free: + self.resample = nn.AvgPool2d(2) + else: + self.resample = BlurDownsample() + else: + shortcut_stride = self.stride + self.resample = nn.AvgPool2d(2) + elif self.stride < 1: + # Upsampling. + first_stride, second_stride = self.stride, 1 + first_blur, second_blur = self.blur, False + shortcut_blur = False + shortcut_stride = 1 + if self.blur: + # The shortcut branch uses blur_upsample + stride-1 conv + if self.border_free: + self.resample = nn.Upsample(scale_factor=2, + mode='bilinear') + else: + self.resample = BlurUpsample() + else: + shortcut_stride = self.stride + self.resample = nn.Upsample(scale_factor=2) + else: + first_stride = second_stride = 1 + first_blur = second_blur = False + shortcut_stride = 1 + shortcut_blur = False + self.resample = None + return (first_stride, second_stride, shortcut_stride, + first_blur, second_blur, shortcut_blur) + + def conv_blocks( + self, x, *cond_inputs, separate_cond=False, **kw_cond_inputs + ): + if separate_cond: + assert len(list(cond_inputs)) == 4 + dx = self.conv_block_1x1_in(x, cond_inputs[0], + **kw_cond_inputs.get('kwargs_0', {})) + dx = self.conv_block_0(dx, cond_inputs[1], + **kw_cond_inputs.get('kwargs_1', {})) + dx = self.conv_block_1(dx, cond_inputs[2], + **kw_cond_inputs.get('kwargs_2', {})) + dx = self.conv_block_1x1_out(dx, cond_inputs[3], + **kw_cond_inputs.get('kwargs_3', {})) + else: + dx = self.conv_block_1x1_in(x, *cond_inputs, **kw_cond_inputs) + dx = self.conv_block_0(dx, *cond_inputs, **kw_cond_inputs) + dx = self.conv_block_1(dx, *cond_inputs, **kw_cond_inputs) + dx = self.conv_block_1x1_out(dx, *cond_inputs, **kw_cond_inputs) + return dx + + def forward(self, x, *cond_inputs, do_checkpoint=False, **kw_cond_inputs): + if do_checkpoint: + dx = checkpoint(self.conv_blocks, x, *cond_inputs, **kw_cond_inputs) + else: + dx = self.conv_blocks(x, *cond_inputs, **kw_cond_inputs) + + if self.resample_first and self.resample is not None: + x = self.resample(x) + if self.learn_shortcut: + x_shortcut = self.conv_block_s( + x, *cond_inputs, **kw_cond_inputs + ) + elif self.in_channels < self.out_channels: + x_shortcut_pad = self.conv_block_s( + x, *cond_inputs, **kw_cond_inputs + ) + x_shortcut = torch.cat((x, x_shortcut_pad), dim=1) + elif self.in_channels > self.out_channels: + x_shortcut = x[:, :self.out_channels, :, :] + else: + x_shortcut = x + if not self.resample_first and self.resample is not None: + x_shortcut = self.resample(x_shortcut) + + output = x_shortcut + dx + return self.output_scale * output + + def extra_repr(self): + s = 'output_scale={output_scale}' + return s.format(**self.__dict__) + + +class DeepRes2dBlock(_BaseDeepResBlock): + r"""Residual block for 2D input. + + Args: + in_channels (int) : Number of channels in the input tensor. + out_channels (int) : Number of channels in the output tensor. + kernel_size (int, optional, default=3): Kernel size for the + convolutional filters in the residual link. + padding (int, optional, default=1): Padding size. + dilation (int, optional, default=1): Dilation factor. + groups (int, optional, default=1): Number of convolutional/linear + groups. + padding_mode (string, optional, default='zeros'): Type of padding: + ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. 
+ weight_norm_type (str, optional, default='none'): + Type of weight normalization. + ``'none'``, ``'spectral'``, ``'weight'`` + or ``'weight_demod'``. + weight_norm_params (obj, optional, default=None): + Parameters of weight normalization. + If not ``None``, ``weight_norm_params.__dict__`` will be used as + keyword arguments when initializing weight normalization. + activation_norm_type (str, optional, default='none'): + Type of activation normalization. + ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``, + ``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``, + ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``. + activation_norm_params (obj, optional, default=None): + Parameters of activation normalization. + If not ``None``, ``activation_norm_params.__dict__`` will be used as + keyword arguments when initializing activation normalization. + skip_activation_norm (bool, optional, default=True): If ``True`` and + ``learn_shortcut`` is also ``True``, applies activation norm to the + learned shortcut connection. + skip_nonlinearity (bool, optional, default=True): If ``True`` and + ``learn_shortcut`` is also ``True``, applies nonlinearity to the + learned shortcut connection. + nonlinearity (str, optional, default='none'): + Type of nonlinear activation function in the residual link. + ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``, + ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``. + inplace_nonlinearity (bool, optional, default=False): If ``True``, + set ``inplace=True`` when initializing the nonlinearity layers. + apply_noise (bool, optional, default=False): If ``True``, adds + Gaussian noise with learnable magnitude to the convolution output. + hidden_channels_equal_out_channels (bool, optional, default=False): + If ``True``, set the hidden channel number to be equal to the + output channel number. If ``False``, the hidden channel number + equals to the smaller of the input channel number and the + output channel number. + order (str, optional, default='CNACNA'): Order of operations + in the residual link. + ``'C'``: convolution, + ``'N'``: normalization, + ``'A'``: nonlinear activation. + learn_shortcut (bool, optional, default=False): If ``True``, always use + a convolutional shortcut instead of an identity one, otherwise only + use a convolutional one if input and output have different number of + channels. 
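+
+    Example:
+        A minimal usage sketch, assuming the block is imported from
+        ``imaginaire.layers.residual_deep``::
+
+            block = DeepRes2dBlock(256, 256, kernel_size=3, padding=1,
+                                   weight_norm_type='spectral')
+            y = block(torch.randn(2, 256, 64, 64))  # same shape as the input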
+ """ + + def __init__(self, in_channels, out_channels, kernel_size=3, + stride=1, padding=1, dilation=1, groups=1, bias=True, + padding_mode='zeros', + weight_norm_type='none', weight_norm_params=None, + activation_norm_type='none', activation_norm_params=None, + skip_activation_norm=True, skip_nonlinearity=False, + skip_weight_norm=True, + nonlinearity='leakyrelu', inplace_nonlinearity=False, + apply_noise=False, hidden_channels_equal_out_channels=False, + order='CNACNA', learn_shortcut=False, output_scale=1, + blur=True, resample_first=True, border_free=False): + super().__init__(in_channels, out_channels, kernel_size, stride, + padding, dilation, groups, bias, padding_mode, + weight_norm_type, weight_norm_params, + activation_norm_type, activation_norm_params, + skip_activation_norm, skip_nonlinearity, nonlinearity, + inplace_nonlinearity, apply_noise, + hidden_channels_equal_out_channels, order, Conv2dBlock, + learn_shortcut, output_scale, blur=blur, + resample_first=resample_first, border_free=border_free, + skip_weight_norm=skip_weight_norm) diff --git a/imaginaire/layers/vit.py b/imaginaire/layers/vit.py new file mode 100644 index 0000000000000000000000000000000000000000..abd0d039d715444efc3c5a1e9889330bfa5b4c4f --- /dev/null +++ b/imaginaire/layers/vit.py @@ -0,0 +1,204 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +from types import SimpleNamespace + +import torch +from torch import nn + +from .misc import ApplyNoise +from imaginaire.third_party.upfirdn2d.upfirdn2d import Blur + + +class ViT2dBlock(nn.Module): + r"""An abstract wrapper class that wraps a torch convolution or linear layer + with normalization and nonlinearity. + """ + + def __init__(self, in_channels, out_channels, kernel_size, stride, + padding, dilation, groups, bias, padding_mode, + weight_norm_type, weight_norm_params, + activation_norm_type, activation_norm_params, + nonlinearity, inplace_nonlinearity, + apply_noise, blur, order, input_dim, clamp, + blur_kernel=(1, 3, 3, 1), output_scale=None, + init_gain=1.0): + super().__init__() + from .nonlinearity import get_nonlinearity_layer + from .weight_norm import get_weight_norm_layer + from .activation_norm import get_activation_norm_layer + self.weight_norm_type = weight_norm_type + self.stride = stride + self.clamp = clamp + self.init_gain = init_gain + + # Nonlinearity layer. + if 'fused' in nonlinearity: + # Fusing nonlinearity with bias. + lr_mul = getattr(weight_norm_params, 'lr_mul', 1) + conv_before_nonlinearity = order.find('C') < order.find('A') + if conv_before_nonlinearity: + assert bias + bias = False + channel = out_channels if conv_before_nonlinearity else in_channels + nonlinearity_layer = get_nonlinearity_layer( + nonlinearity, inplace=inplace_nonlinearity, + num_channels=channel, lr_mul=lr_mul) + else: + nonlinearity_layer = get_nonlinearity_layer( + nonlinearity, inplace=inplace_nonlinearity) + + # Noise injection layer. + if apply_noise: + order = order.replace('C', 'CG') + noise_layer = ApplyNoise() + else: + noise_layer = None + + # Convolutional layer. 
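+        # When blur is requested, a Blur layer is placed before the
+        # convolution for stride-2 downsampling ('BC') or after it for
+        # stride-0.5 upsampling ('CB'), in the spirit of StyleGAN2-style
+        # antialiased resampling; the order string is extended accordingly.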
+ if blur: + if stride == 2: + # Blur - Conv - Noise - Activate + p = (len(blur_kernel) - 2) + (kernel_size - 1) + pad0, pad1 = (p + 1) // 2, p // 2 + padding = 0 + blur_layer = Blur( + blur_kernel, pad=(pad0, pad1), padding_mode=padding_mode + ) + order = order.replace('C', 'BC') + elif stride == 0.5: + # Conv - Blur - Noise - Activate + padding = 0 + p = (len(blur_kernel) - 2) - (kernel_size - 1) + pad0, pad1 = (p + 1) // 2 + 1, p // 2 + 1 + blur_layer = Blur( + blur_kernel, pad=(pad0, pad1), padding_mode=padding_mode + ) + order = order.replace('C', 'CB') + elif stride == 1: + # No blur for now + blur_layer = nn.Identity() + else: + raise NotImplementedError + else: + blur_layer = nn.Identity() + + if weight_norm_params is None: + weight_norm_params = SimpleNamespace() + weight_norm = get_weight_norm_layer( + weight_norm_type, **vars(weight_norm_params)) + conv_layer = weight_norm(self._get_conv_layer( + in_channels, out_channels, kernel_size, stride, padding, dilation, + groups, bias, padding_mode, input_dim)) + + # Normalization layer. + conv_before_norm = order.find('C') < order.find('N') + norm_channels = out_channels if conv_before_norm else in_channels + if activation_norm_params is None: + activation_norm_params = SimpleNamespace() + activation_norm_layer = get_activation_norm_layer( + norm_channels, + activation_norm_type, + input_dim, + **vars(activation_norm_params)) + + # Mapping from operation names to layers. + mappings = {'C': {'conv': conv_layer}, + 'N': {'norm': activation_norm_layer}, + 'A': {'nonlinearity': nonlinearity_layer}} + mappings.update({'B': {'blur': blur_layer}}) + mappings.update({'G': {'noise': noise_layer}}) + + # All layers in order. + self.layers = nn.ModuleDict() + for op in order: + if list(mappings[op].values())[0] is not None: + self.layers.update(mappings[op]) + + # Whether this block expects conditional inputs. + self.conditional = \ + getattr(conv_layer, 'conditional', False) or \ + getattr(activation_norm_layer, 'conditional', False) + + if output_scale is not None: + self.output_scale = nn.Parameter(torch.tensor(output_scale)) + else: + self.register_parameter("output_scale", None) + + def forward(self, x, *cond_inputs, **kw_cond_inputs): + r""" + + Args: + x (tensor): Input tensor. + cond_inputs (list of tensors) : Conditional input tensors. + kw_cond_inputs (dict) : Keyword conditional inputs. + """ + for key, layer in self.layers.items(): + if getattr(layer, 'conditional', False): + # Layers that require conditional inputs. + x = layer(x, *cond_inputs, **kw_cond_inputs) + else: + x = layer(x) + if self.clamp is not None and isinstance(layer, nn.Conv2d): + x.clamp_(max=self.clamp) + if key == 'conv': + if self.output_scale is not None: + x = x * self.output_scale + return x + + def _get_conv_layer(self, in_channels, out_channels, kernel_size, stride, + padding, dilation, groups, bias, padding_mode, + input_dim): + # Returns the convolutional layer. + if input_dim == 0: + layer = nn.Linear(in_channels, out_channels, bias) + else: + if stride < 1: # Fractionally-strided convolution. 
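+                # A fractional stride (e.g. 0.5) is realized as a transposed
+                # convolution with stride round(1 / stride); spatial padding
+                # is assumed to be handled by the accompanying Blur layer
+                # (when enabled), hence the assert below.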
+ padding_mode = 'zeros' + assert padding == 0 + layer_type = getattr(nn, f'ConvTranspose{input_dim}d') + stride = round(1 / stride) + else: + layer_type = getattr(nn, f'Conv{input_dim}d') + layer = layer_type( + in_channels, out_channels, kernel_size, stride, padding, + dilation=dilation, groups=groups, bias=bias, + padding_mode=padding_mode + ) + + return layer + + def __repr__(self): + main_str = self._get_name() + '(' + child_lines = [] + for name, layer in self.layers.items(): + mod_str = repr(layer) + if name == 'conv' and self.weight_norm_type != 'none' and \ + self.weight_norm_type != '': + mod_str = mod_str[:-1] + \ + ', weight_norm={}'.format(self.weight_norm_type) + ')' + if name == 'conv' and getattr(layer, 'base_lr_mul', 1) != 1: + mod_str = mod_str[:-1] + \ + ', lr_mul={}'.format(layer.base_lr_mul) + ')' + mod_str = self._addindent(mod_str, 2) + child_lines.append(mod_str) + if len(child_lines) == 1: + main_str += child_lines[0] + else: + main_str += '\n ' + '\n '.join(child_lines) + '\n' + + main_str += ')' + return main_str + + @staticmethod + def _addindent(s_, numSpaces): + s = s_.split('\n') + # don't do anything for single-line stuff + if len(s) == 1: + return s_ + first = s.pop(0) + s = [(numSpaces * ' ') + line for line in s] + s = '\n'.join(s) + s = first + '\n' + s + return s diff --git a/imaginaire/layers/weight_norm.py b/imaginaire/layers/weight_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..e15ca2d21cea70062fa24ffdcd5adab51c8dcb25 --- /dev/null +++ b/imaginaire/layers/weight_norm.py @@ -0,0 +1,267 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +import collections +import functools + +import torch +from torch import nn +from torch.nn.utils import spectral_norm, weight_norm +from torch.nn.utils.spectral_norm import SpectralNorm, \ + SpectralNormStateDictHook, SpectralNormLoadStateDictPreHook + +from .conv import LinearBlock + + +class WeightDemodulation(nn.Module): + r"""Weight demodulation in + "Analyzing and Improving the Image Quality of StyleGAN", Karras et al. + + Args: + conv (torch.nn.Modules): Convolutional layer. + cond_dims (int): The number of channels in the conditional input. + eps (float, optional, default=1e-8): a value added to the + denominator for numerical stability. + adaptive_bias (bool, optional, default=False): If ``True``, adaptively + predicts bias from the conditional input. + demod (bool, optional, default=False): If ``True``, performs + weight demodulation. 
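+
+    Note:
+        For a conditioning code ``y``, each sample's kernel is modulated as
+        ``w' = w * gamma(y)`` and, when ``demod`` is enabled, every output
+        filter is rescaled by ``1 / sqrt(sum(w' ** 2) + eps)``, following the
+        modulation/demodulation scheme of StyleGAN2.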
+ """ + + def __init__(self, conv, cond_dims, eps=1e-8, + adaptive_bias=False, demod=True): + super().__init__() + self.conv = conv + self.adaptive_bias = adaptive_bias + if adaptive_bias: + self.conv.register_parameter('bias', None) + self.fc_beta = LinearBlock(cond_dims, self.conv.out_channels) + self.fc_gamma = LinearBlock(cond_dims, self.conv.in_channels) + self.eps = eps + self.demod = demod + self.conditional = True + + def forward(self, x, y, **_kwargs): + r"""Weight demodulation forward""" + b, c, h, w = x.size() + self.conv.groups = b + gamma = self.fc_gamma(y) + gamma = gamma[:, None, :, None, None] + weight = self.conv.weight[None, :, :, :, :] * gamma + + if self.demod: + d = torch.rsqrt( + (weight ** 2).sum( + dim=(2, 3, 4), keepdim=True) + self.eps) + weight = weight * d + + x = x.reshape(1, -1, h, w) + _, _, *ws = weight.shape + weight = weight.reshape(b * self.conv.out_channels, *ws) + x = self.conv._conv_forward(x, weight) + + x = x.reshape(-1, self.conv.out_channels, h, w) + if self.adaptive_bias: + x += self.fc_beta(y)[:, :, None, None] + return x + + +def weight_demod( + conv, cond_dims=256, eps=1e-8, adaptive_bias=False, demod=True): + r"""Weight demodulation.""" + return WeightDemodulation(conv, cond_dims, eps, adaptive_bias, demod) + + +class ScaledLR(object): + def __init__(self, weight_name, bias_name): + self.weight_name = weight_name + self.bias_name = bias_name + + def compute_weight(self, module): + weight = getattr(module, self.weight_name + '_ori') + return weight * module.weight_scale + + def compute_bias(self, module): + bias = getattr(module, self.bias_name + '_ori') + if bias is not None: + return bias * module.bias_scale + else: + return None + + @staticmethod + def apply(module, weight_name, bias_name, lr_mul, equalized): + assert weight_name == 'weight' + assert bias_name == 'bias' + fn = ScaledLR(weight_name, bias_name) + module.register_forward_pre_hook(fn) + + if hasattr(module, bias_name): + # module.bias is a parameter (can be None). + bias = getattr(module, bias_name) + delattr(module, bias_name) + module.register_parameter(bias_name + '_ori', bias) + else: + # module.bias does not exist. + bias = None + setattr(module, bias_name + '_ori', bias) + if bias is not None: + setattr(module, bias_name, bias.data) + else: + setattr(module, bias_name, None) + module.register_buffer('bias_scale', torch.tensor(lr_mul)) + + if hasattr(module, weight_name + '_orig'): + # The module has been wrapped with spectral normalization. + # We only want to keep a single weight parameter. + weight = getattr(module, weight_name + '_orig') + delattr(module, weight_name + '_orig') + module.register_parameter(weight_name + '_ori', weight) + setattr(module, weight_name + '_orig', weight.data) + # Put this hook before the spectral norm hook. + module._forward_pre_hooks = collections.OrderedDict( + reversed(list(module._forward_pre_hooks.items())) + ) + module.use_sn = True + else: + weight = getattr(module, weight_name) + delattr(module, weight_name) + module.register_parameter(weight_name + '_ori', weight) + setattr(module, weight_name, weight.data) + module.use_sn = False + + # assert weight.dim() == 4 or weight.dim() == 2 + if equalized: + fan_in = weight.data.size(1) * weight.data[0][0].numel() + # Theoretically, the gain should be sqrt(2) instead of 1. + # The official StyleGAN2 uses 1 for some reason. 
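+            # Example: a 3x3 conv with 512 input channels gives
+            # fan_in = 512 * 9 = 4608, so weight_scale = lr_mul / sqrt(4608),
+            # i.e. about 0.0147 * lr_mul.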
+ module.register_buffer( + 'weight_scale', torch.tensor(lr_mul * ((1 / fan_in) ** 0.5)) + ) + else: + module.register_buffer('weight_scale', torch.tensor(lr_mul)) + + module.lr_mul = module.weight_scale + module.base_lr_mul = lr_mul + + return fn + + def remove(self, module): + with torch.no_grad(): + weight = self.compute_weight(module) + delattr(module, self.weight_name + '_ori') + + if module.use_sn: + setattr(module, self.weight_name + '_orig', weight.detach()) + else: + delattr(module, self.weight_name) + module.register_parameter(self.weight_name, + torch.nn.Parameter(weight.detach())) + + with torch.no_grad(): + bias = self.compute_bias(module) + delattr(module, self.bias_name) + delattr(module, self.bias_name + '_ori') + if bias is not None: + module.register_parameter(self.bias_name, + torch.nn.Parameter(bias.detach())) + else: + module.register_parameter(self.bias_name, None) + + module.lr_mul = 1.0 + module.base_lr_mul = 1.0 + + def __call__(self, module, input): + weight = self.compute_weight(module) + if module.use_sn: + # The following spectral norm hook will compute the SN of + # "module.weight_orig" and store the normalized weight in + # "module.weight". + setattr(module, self.weight_name + '_orig', weight) + else: + setattr(module, self.weight_name, weight) + bias = self.compute_bias(module) + setattr(module, self.bias_name, bias) + + +def remove_weight_norms(module, weight_name='weight', bias_name='bias'): + if hasattr(module, 'weight_ori') or hasattr(module, 'weight_orig'): + for k in list(module._forward_pre_hooks.keys()): + hook = module._forward_pre_hooks[k] + if (isinstance(hook, ScaledLR) or isinstance(hook, SpectralNorm)): + hook.remove(module) + del module._forward_pre_hooks[k] + + for k, hook in module._state_dict_hooks.items(): + if isinstance(hook, SpectralNormStateDictHook) and \ + hook.fn.name == weight_name: + del module._state_dict_hooks[k] + break + + for k, hook in module._load_state_dict_pre_hooks.items(): + if isinstance(hook, SpectralNormLoadStateDictPreHook) and \ + hook.fn.name == weight_name: + del module._load_state_dict_pre_hooks[k] + break + + return module + + +def remove_equalized_lr(module, weight_name='weight', bias_name='bias'): + for k, hook in module._forward_pre_hooks.items(): + if isinstance(hook, ScaledLR) and hook.weight_name == weight_name: + hook.remove(module) + del module._forward_pre_hooks[k] + break + else: + raise ValueError("Equalized learning rate not found") + + return module + + +def scaled_lr( + module, weight_name='weight', bias_name='bias', lr_mul=1., + equalized=False, +): + ScaledLR.apply(module, weight_name, bias_name, lr_mul, equalized) + return module + + +def get_weight_norm_layer(norm_type, **norm_params): + r"""Return weight normalization. + + Args: + norm_type (str): + Type of weight normalization. + ``'none'``, ``'spectral'``, ``'weight'`` + or ``'weight_demod'``. + norm_params: Arbitrary keyword arguments that will be used to + initialize the weight normalization. 
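+
+    Example:
+        A short sketch of how the returned factory is applied (the wrapped
+        layer is only illustrative)::
+
+            weight_norm = get_weight_norm_layer('spectral')
+            conv = weight_norm(nn.Conv2d(64, 64, 3, padding=1))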
+ """ + if norm_type == 'none' or norm_type == '': # no normalization + return lambda x: x + elif norm_type == 'spectral': # spectral normalization + return functools.partial(spectral_norm, **norm_params) + elif norm_type == 'weight': # weight normalization + return functools.partial(weight_norm, **norm_params) + elif norm_type == 'weight_demod': # weight demodulation + return functools.partial(weight_demod, **norm_params) + elif norm_type == 'equalized_lr': # equalized learning rate + return functools.partial(scaled_lr, equalized=True, **norm_params) + elif norm_type == 'scaled_lr': # equalized learning rate + return functools.partial(scaled_lr, **norm_params) + elif norm_type == 'equalized_lr_spectral': + lr_mul = norm_params.pop('lr_mul', 1.0) + return lambda x: functools.partial( + scaled_lr, equalized=True, lr_mul=lr_mul)( + functools.partial(spectral_norm, **norm_params)(x) + ) + elif norm_type == 'scaled_lr_spectral': + lr_mul = norm_params.pop('lr_mul', 1.0) + return lambda x: functools.partial( + scaled_lr, lr_mul=lr_mul)( + functools.partial(spectral_norm, **norm_params)(x) + ) + else: + raise ValueError( + 'Weight norm layer %s is not recognized' % norm_type) diff --git a/imaginaire/losses/TVloss.py b/imaginaire/losses/TVloss.py new file mode 100644 index 0000000000000000000000000000000000000000..158b12519e90c580df1c71658c86b47ce8e71a6e --- /dev/null +++ b/imaginaire/losses/TVloss.py @@ -0,0 +1,15 @@ +import torch +import torch.nn as nn + + +class TV_loss(nn.Module): + def __init__(self): + super().__init__() + + def forward(self,input): + B,D1,D2,D3 = input.size() + tv_d1 = torch.pow(input[:,1:,:,:]-input[:,:-1,:,:], 2).sum() + tv_d2 = torch.pow(input[:,:,1:,:]-input[:,:,:-1,:], 2).sum() + tv_d3 = torch.pow(input[:,:,:,1:]-input[:,:,:,:-1], 2).sum() + return (tv_d1+tv_d2+tv_d3)/(B*D1*D2*D3) + diff --git a/imaginaire/losses/__init__.py b/imaginaire/losses/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d3b7d8e381cd8fe0771dd2ce2478af0986a9be87 --- /dev/null +++ b/imaginaire/losses/__init__.py @@ -0,0 +1,22 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. 
+# To view a copy of this license, check out LICENSE.md +from .gan import GANLoss +from .perceptual import PerceptualLoss +from .feature_matching import FeatureMatchingLoss +from .kl import GaussianKLLoss +from .flow import MaskedL1Loss, FlowLoss +from .dict import DictLoss +from .weighted_mse import WeightedMSELoss +from .TVloss import TV_loss + +__all__ = ['GANLoss', 'PerceptualLoss', 'FeatureMatchingLoss', 'GaussianKLLoss', + 'MaskedL1Loss', 'FlowLoss', 'DictLoss', + 'WeightedMSELoss','TV_loss'] + +try: + from .gradient_penalty import GradientPenaltyLoss + __all__.extend(['GradientPenaltyLoss']) +except: # noqa + pass diff --git a/imaginaire/losses/__pycache__/TVloss.cpython-38.pyc b/imaginaire/losses/__pycache__/TVloss.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2e9cc54635ab39fc68ce1bf7fa42f3880f15ca82 Binary files /dev/null and b/imaginaire/losses/__pycache__/TVloss.cpython-38.pyc differ diff --git a/imaginaire/losses/__pycache__/__init__.cpython-38.pyc b/imaginaire/losses/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a283a2113362dc366af1fa2376349246babbd54b Binary files /dev/null and b/imaginaire/losses/__pycache__/__init__.cpython-38.pyc differ diff --git a/imaginaire/losses/__pycache__/dict.cpython-38.pyc b/imaginaire/losses/__pycache__/dict.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..899655408ca317238e65846264a0fcbadd04a4d1 Binary files /dev/null and b/imaginaire/losses/__pycache__/dict.cpython-38.pyc differ diff --git a/imaginaire/losses/__pycache__/feature_matching.cpython-38.pyc b/imaginaire/losses/__pycache__/feature_matching.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bbeb6414e6aeb71b2ec64402751478cb786495c4 Binary files /dev/null and b/imaginaire/losses/__pycache__/feature_matching.cpython-38.pyc differ diff --git a/imaginaire/losses/__pycache__/flow.cpython-38.pyc b/imaginaire/losses/__pycache__/flow.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1043723e126ca15751d7651e0e848b69ea7df80e Binary files /dev/null and b/imaginaire/losses/__pycache__/flow.cpython-38.pyc differ diff --git a/imaginaire/losses/__pycache__/gan.cpython-38.pyc b/imaginaire/losses/__pycache__/gan.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9fa848211b872fdf205c9c30b31d19bb2f3aae0e Binary files /dev/null and b/imaginaire/losses/__pycache__/gan.cpython-38.pyc differ diff --git a/imaginaire/losses/__pycache__/info_nce.cpython-38.pyc b/imaginaire/losses/__pycache__/info_nce.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5e56e1ed11ac004f1aa741f38c930f033457feca Binary files /dev/null and b/imaginaire/losses/__pycache__/info_nce.cpython-38.pyc differ diff --git a/imaginaire/losses/__pycache__/kl.cpython-38.pyc b/imaginaire/losses/__pycache__/kl.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..140c611c8fda1667885013dfd8f35225b00bd8fb Binary files /dev/null and b/imaginaire/losses/__pycache__/kl.cpython-38.pyc differ diff --git a/imaginaire/losses/__pycache__/perceptual.cpython-38.pyc b/imaginaire/losses/__pycache__/perceptual.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3aa1b81d9bc18105b7f3c3940b4b49831f19a22b Binary files /dev/null and b/imaginaire/losses/__pycache__/perceptual.cpython-38.pyc differ diff --git 
a/imaginaire/losses/__pycache__/weighted_mse.cpython-38.pyc b/imaginaire/losses/__pycache__/weighted_mse.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0e0a751c997154fbaec8d66e1527da42ff45bb73 Binary files /dev/null and b/imaginaire/losses/__pycache__/weighted_mse.cpython-38.pyc differ diff --git a/imaginaire/losses/dict.py b/imaginaire/losses/dict.py new file mode 100644 index 0000000000000000000000000000000000000000..9d0e5ee9516013036b92c564fa349755fc9fc1c9 --- /dev/null +++ b/imaginaire/losses/dict.py @@ -0,0 +1,36 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +import torch.nn as nn + + +class DictLoss(nn.Module): + def __init__(self, criterion='l1'): + super(DictLoss, self).__init__() + if criterion == 'l1': + self.criterion = nn.L1Loss() + elif criterion == 'l2' or criterion == 'mse': + self.criterion = nn.MSELoss() + else: + raise ValueError('Criterion %s is not recognized' % criterion) + + def forward(self, fake, real): + """Return the target vector for the l1/l2 loss computation. + + Args: + fake (dict, list or tuple): Discriminator features of fake images. + real (dict, list or tuple): Discriminator features of real images. + Returns: + loss (tensor): Loss value. + """ + loss = 0 + if type(fake) == dict: + for key in fake.keys(): + loss += self.criterion(fake[key], real[key].detach()) + elif type(fake) == list or type(fake) == tuple: + for f, r in zip(fake, real): + loss += self.criterion(f, r.detach()) + else: + loss += self.criterion(fake, real.detach()) + return loss diff --git a/imaginaire/losses/feature_matching.py b/imaginaire/losses/feature_matching.py new file mode 100644 index 0000000000000000000000000000000000000000..f70034b3c3afba5a914261b55cf0abeab832391c --- /dev/null +++ b/imaginaire/losses/feature_matching.py @@ -0,0 +1,38 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +import torch.nn as nn + + +class FeatureMatchingLoss(nn.Module): + r"""Compute feature matching loss""" + def __init__(self, criterion='l1'): + super(FeatureMatchingLoss, self).__init__() + if criterion == 'l1': + self.criterion = nn.L1Loss() + elif criterion == 'l2' or criterion == 'mse': + self.criterion = nn.MSELoss() + else: + raise ValueError('Criterion %s is not recognized' % criterion) + + def forward(self, fake_features, real_features): + r"""Return the target vector for the binary cross entropy loss + computation. + + Args: + fake_features (list of lists): Discriminator features of fake images. + real_features (list of lists): Discriminator features of real images. + + Returns: + (tensor): Loss value. + """ + num_d = len(fake_features) + dis_weight = 1.0 / num_d + loss = fake_features[0][0].new_tensor(0) + for i in range(num_d): + for j in range(len(fake_features[i])): + tmp_loss = self.criterion(fake_features[i][j], + real_features[i][j].detach()) + loss += dis_weight * tmp_loss + return loss diff --git a/imaginaire/losses/flow.py b/imaginaire/losses/flow.py new file mode 100644 index 0000000000000000000000000000000000000000..3464d949bba8cf148f7bc5c8fa092995f66c663a --- /dev/null +++ b/imaginaire/losses/flow.py @@ -0,0 +1,313 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +# flake8: noqa +import importlib +import warnings + +import torch +import torch.nn as nn + +from imaginaire.model_utils.fs_vid2vid import (get_face_mask, get_fg_mask, + get_part_mask, pick_image, + resample) + + +class MaskedL1Loss(nn.Module): + r"""Masked L1 loss constructor.""" + + def __init__(self, normalize_over_valid=False): + super(MaskedL1Loss, self).__init__() + self.criterion = nn.L1Loss() + self.normalize_over_valid = normalize_over_valid + + def forward(self, input, target, mask): + r"""Masked L1 loss computation. + + Args: + input (tensor): Input tensor. + target (tensor): Target tensor. + mask (tensor): Mask to be applied to the output loss. + + Returns: + (tensor): Loss value. + """ + mask = mask.expand_as(input) + loss = self.criterion(input * mask, target * mask) + if self.normalize_over_valid: + # The loss has been averaged over all pixels. + # Only average over regions which are valid. + loss = loss * torch.numel(mask) / (torch.sum(mask) + 1e-6) + return loss + + +class FlowLoss(nn.Module): + r"""Flow loss constructor. + + Args: + cfg (obj): Configuration. + """ + + def __init__(self, cfg): + super(FlowLoss, self).__init__() + self.cfg = cfg + self.data_cfg = cfg.data + self.criterion = nn.L1Loss() + self.criterionMasked = MaskedL1Loss() + flow_module = importlib.import_module(cfg.flow_network.type) + self.flowNet = flow_module.FlowNet(pretrained=True) + self.warp_ref = getattr(cfg.gen.flow, 'warp_ref', False) + self.pose_cfg = pose_cfg = getattr(cfg.data, 'for_pose_dataset', None) + self.for_pose_dataset = pose_cfg is not None + self.has_fg = getattr(cfg.data, 'has_foreground', False) + + def forward(self, data, net_G_output, current_epoch): + r"""Compute losses on the output flow and occlusion mask. + + Args: + data (dict): Input data. + net_G_output (dict): Generator output. + current_epoch (int): Current training epoch number. + Returns: + (dict): + - loss_flow_L1 (tensor): L1 loss compared to ground truth flow. + - loss_flow_warp (tensor): L1 loss between the warped image and the + target image when using the flow to warp. + - loss_mask (tensor): Loss for the occlusion mask. + """ + tgt_label, tgt_image = data['label'], data['image'] + + fake_image = net_G_output['fake_images'] + warped_images = net_G_output['warped_images'] + flow = net_G_output['fake_flow_maps'] + occ_mask = net_G_output['fake_occlusion_masks'] + + if self.warp_ref: + # Pick the most similar reference image to warp. + ref_labels, ref_images = data['ref_labels'], data['ref_images'] + ref_idx = net_G_output['ref_idx'] + ref_label, ref_image = pick_image([ref_labels, ref_images], ref_idx) + else: + ref_label = ref_image = None + + # Compute the ground truth flows and confidence maps. + flow_gt_prev = flow_gt_ref = conf_gt_prev = conf_gt_ref = None + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + if self.warp_ref: + # Compute GT for warping reference -> target. + if self.for_pose_dataset: + # Use DensePose maps to compute flows for pose dataset. + flow_gt_ref, conf_gt_ref = self.flowNet(tgt_label[:, :3], + ref_label[:, :3]) + else: + # Use RGB images for other datasets. + flow_gt_ref, conf_gt_ref = self.flowNet(tgt_image, + ref_image) + + if current_epoch >= self.cfg.single_frame_epoch and \ + data['real_prev_image'] is not None: + # Compute GT for warping previous -> target. 
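+                # After the single-frame warm-up epochs, a pseudo ground
+                # truth for the flow that warps the previous frame to the
+                # current target is also estimated with the pretrained flow
+                # network.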
+ tgt_image_prev = data['real_prev_image'] + flow_gt_prev, conf_gt_prev = self.flowNet(tgt_image, + tgt_image_prev) + + flow_gt = [flow_gt_ref, flow_gt_prev] + flow_conf_gt = [conf_gt_ref, conf_gt_prev] + # Get the foreground masks. + fg_mask, ref_fg_mask = get_fg_mask([tgt_label, ref_label], self.has_fg) + + # Compute losses for flow maps and masks. + loss_flow_L1, loss_flow_warp, body_mask_diff = \ + self.compute_flow_losses(flow, warped_images, tgt_image, flow_gt, + flow_conf_gt, fg_mask, tgt_label, + ref_label) + + loss_mask = self.compute_mask_losses( + occ_mask, fake_image, warped_images, tgt_label, tgt_image, + fg_mask, ref_fg_mask, body_mask_diff) + + return loss_flow_L1, loss_flow_warp, loss_mask + + def compute_flow_losses(self, flow, warped_images, tgt_image, flow_gt, + flow_conf_gt, fg_mask, tgt_label, ref_label): + r"""Compute losses on the generated flow maps. + + Args: + flow (tensor or list of tensors): Generated flow maps. + warped_images (tensor or list of tensors): Warped images using the + flow maps. + tgt_image (tensor): Target image for the warped image. + flow_gt (tensor or list of tensors): Ground truth flow maps. + flow_conf_gt (tensor or list of tensors): Confidence for the ground + truth flow maps. + fg_mask (tensor): Foreground mask for the target image. + tgt_label (tensor): Target label map. + ref_label (tensor): Reference label map. + Returns: + (dict): + - loss_flow_L1 (tensor): L1 loss compared to ground truth flow. + - loss_flow_warp (tensor): L1 loss between the warped image and the + target image when using the flow to warp. + - body_mask_diff (tensor): Difference between warped body part map + and target body part map. Used for pose dataset only. + """ + loss_flow_L1 = torch.tensor(0., device=torch.device('cuda')) + loss_flow_warp = torch.tensor(0., device=torch.device('cuda')) + if isinstance(flow, list): + # Compute flow losses for both warping reference -> target and + # previous -> target. + for i in range(len(flow)): + loss_flow_L1_i, loss_flow_warp_i = \ + self.compute_flow_loss(flow[i], warped_images[i], tgt_image, + flow_gt[i], flow_conf_gt[i], fg_mask) + loss_flow_L1 += loss_flow_L1_i + loss_flow_warp += loss_flow_warp_i + else: + # Compute loss for warping either reference or previous images. + loss_flow_L1, loss_flow_warp = \ + self.compute_flow_loss(flow, warped_images, tgt_image, + flow_gt[-1], flow_conf_gt[-1], fg_mask) + + # For pose dataset only. + body_mask_diff = None + if self.warp_ref: + if self.for_pose_dataset: + # Warped reference body part map should be similar to target + # body part map. + body_mask = get_part_mask(tgt_label[:, 2]) + ref_body_mask = get_part_mask(ref_label[:, 2]) + warped_ref_body_mask = resample(ref_body_mask, flow[0]) + loss_flow_warp += self.criterion(warped_ref_body_mask, + body_mask) + body_mask_diff = torch.sum( + abs(warped_ref_body_mask - body_mask), dim=1, keepdim=True) + + if self.has_fg: + # Warped reference foreground map should be similar to target + # foreground map. + fg_mask, ref_fg_mask = \ + get_fg_mask([tgt_label, ref_label], True) + warped_ref_fg_mask = resample(ref_fg_mask, flow[0]) + loss_flow_warp += self.criterion(warped_ref_fg_mask, fg_mask) + + return loss_flow_L1, loss_flow_warp, body_mask_diff + + def compute_flow_loss(self, flow, warped_image, tgt_image, flow_gt, + flow_conf_gt, fg_mask): + r"""Compute losses on the generated flow map. + + Args: + flow (tensor): Generated flow map. + warped_image (tensor): Warped image using the flow map. 
+ tgt_image (tensor): Target image for the warped image. + flow_gt (tensor): Ground truth flow map. + flow_conf_gt (tensor): Confidence for the ground truth flow map. + fg_mask (tensor): Foreground mask for the target image. + Returns: + (dict): + - loss_flow_L1 (tensor): L1 loss compared to ground truth flow. + - loss_flow_warp (tensor): L1 loss between the warped image and + the target image when using the flow to warp. + """ + loss_flow_L1 = torch.tensor(0., device=torch.device('cuda')) + loss_flow_warp = torch.tensor(0., device=torch.device('cuda')) + if flow is not None and flow_gt is not None: + # L1 loss compared to flow ground truth. + loss_flow_L1 = self.criterionMasked(flow, flow_gt, + flow_conf_gt * fg_mask) + if warped_image is not None: + # L1 loss between warped image and target image. + loss_flow_warp = self.criterion(warped_image, tgt_image) + return loss_flow_L1, loss_flow_warp + + def compute_mask_losses(self, occ_mask, fake_image, warped_image, + tgt_label, tgt_image, fg_mask, ref_fg_mask, + body_mask_diff): + r"""Compute losses on the generated occlusion masks. + + Args: + occ_mask (tensor or list of tensors): Generated occlusion masks. + fake_image (tensor): Generated image. + warped_image (tensor or list of tensors): Warped images using the + flow maps. + tgt_label (tensor): Target label map. + tgt_image (tensor): Target image for the warped image. + fg_mask (tensor): Foreground mask for the target image. + ref_fg_mask (tensor): Foreground mask for the reference image. + body_mask_diff (tensor): Difference between warped body part map + and target body part map. Used for pose dataset only. + Returns: + (tensor): Loss for the mask. + """ + loss_mask = torch.tensor(0., device=torch.device('cuda')) + + if isinstance(occ_mask, list): + # Compute occlusion mask losses for both warping reference -> target + # and previous -> target. + for i in range(len(occ_mask)): + loss_mask += self.compute_mask_loss(occ_mask[i], + warped_image[i], + tgt_image) + else: + # Compute loss for warping either reference or previous images. + loss_mask += self.compute_mask_loss(occ_mask, warped_image, + tgt_image) + + if self.warp_ref: + ref_occ_mask = occ_mask[0] + dummy0 = torch.zeros_like(ref_occ_mask) + dummy1 = torch.ones_like(ref_occ_mask) + if self.for_pose_dataset: + # Enforce output to use more warped reference image for + # face region. + face_mask = get_face_mask(tgt_label[:, 2]).unsqueeze(1) + AvgPool = torch.nn.AvgPool2d(15, padding=7, stride=1) + face_mask = AvgPool(face_mask) + loss_mask += self.criterionMasked(ref_occ_mask, dummy0, + face_mask) + loss_mask += self.criterionMasked(fake_image, warped_image[0], + face_mask) + # Enforce output to use more hallucinated image for discrepancy + # regions of body part masks between warped reference and + # target image. + loss_mask += self.criterionMasked(ref_occ_mask, dummy1, + body_mask_diff) + + if self.has_fg: + # Enforce output to use more hallucinated image for discrepancy + # regions of foreground masks between reference and target + # image. + fg_mask_diff = ((ref_fg_mask - fg_mask) > 0).float() + loss_mask += self.criterionMasked(ref_occ_mask, dummy1, + fg_mask_diff) + return loss_mask + + def compute_mask_loss(self, occ_mask, warped_image, tgt_image): + r"""Compute losses on the generated occlusion mask. + + Args: + occ_mask (tensor): Generated occlusion mask. + warped_image (tensor): Warped image using the flow map. + tgt_image (tensor): Target image for the warped image. + Returns: + (tensor): Loss for the mask. 
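+
+        Note:
+            The supervision target comes from a confidence map
+            ``conf = clamp(1 - sum(|warped - target|), 0, 1)``: the mask is
+            pushed toward 0 where the warped image already matches the
+            target and toward 1 where it does not.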
+ """ + loss_mask = torch.tensor(0., device=torch.device('cuda')) + if occ_mask is not None: + dummy0 = torch.zeros_like(occ_mask) + dummy1 = torch.ones_like(occ_mask) + + # Compute the confidence map based on L1 distance between warped + # and GT image. + img_diff = torch.sum(abs(warped_image - tgt_image), dim=1, + keepdim=True) + conf = torch.clamp(1 - img_diff, 0, 1) + + # Force mask value to be small if warped image is similar to GT, + # and vice versa. + loss_mask = self.criterionMasked(occ_mask, dummy0, conf) + loss_mask += self.criterionMasked(occ_mask, dummy1, 1 - conf) + + return loss_mask diff --git a/imaginaire/losses/gan.py b/imaginaire/losses/gan.py new file mode 100644 index 0000000000000000000000000000000000000000..aaa9c30dd51887b25b439b90fa728e94fe2b03a9 --- /dev/null +++ b/imaginaire/losses/gan.py @@ -0,0 +1,173 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +import math +import torch +import torch.nn as nn +import torch.nn.functional as F + +from imaginaire.utils.distributed import master_only_print as print + + +@torch.jit.script +def fuse_math_min_mean_pos(x): + r"""Fuse operation min mean for hinge loss computation of positive + samples""" + minval = torch.min(x - 1, x * 0) + loss = -torch.mean(minval) + return loss + + +@torch.jit.script +def fuse_math_min_mean_neg(x): + r"""Fuse operation min mean for hinge loss computation of negative + samples""" + minval = torch.min(-x - 1, x * 0) + loss = -torch.mean(minval) + return loss + + +class GANLoss(nn.Module): + r"""GAN loss constructor. + + Args: + gan_mode (str): Type of GAN loss. ``'hinge'``, ``'least_square'``, + ``'non_saturated'``, ``'wasserstein'``. + target_real_label (float): The desired output label for real images. + target_fake_label (float): The desired output label for fake images. + decay_k (float): The decay factor per epoch for top-k training. + min_k (float): The minimum percentage of samples to select. + separate_topk (bool): If ``True``, selects top-k for each sample + separately, otherwise selects top-k among all samples. + """ + def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0, + decay_k=1., min_k=1., separate_topk=False): + super(GANLoss, self).__init__() + self.real_label = target_real_label + self.fake_label = target_fake_label + self.real_label_tensor = None + self.fake_label_tensor = None + self.gan_mode = gan_mode + self.decay_k = decay_k + self.min_k = min_k + self.separate_topk = separate_topk + self.register_buffer('k', torch.tensor(1.0)) + print('GAN mode: %s' % gan_mode) + + def forward(self, dis_output, t_real, dis_update=True, reduce=True): + r"""GAN loss computation. + + Args: + dis_output (tensor or list of tensors): Discriminator outputs. + t_real (bool): If ``True``, uses the real label as target, otherwise uses the fake label as target. + dis_update (bool): If ``True``, the loss will be used to update the discriminator, otherwise the generator. + reduce (bool): If ``True``, when a list of discriminator outputs are provided, it will return the average + of all losses, otherwise it will return a list of losses. + Returns: + loss (tensor): Loss value. + """ + if isinstance(dis_output, list): + # For multi-scale discriminators. 
+ # In this implementation, the loss is first averaged for each scale + # (batch size and number of locations) then averaged across scales, + # so that the gradient is not dominated by the discriminator that + # has the most output values (highest resolution). + losses = [] + for dis_output_i in dis_output: + assert isinstance(dis_output_i, torch.Tensor) + losses.append(self.loss(dis_output_i, t_real, dis_update)) + if reduce: + return torch.mean(torch.stack(losses)) + else: + return losses + else: + return self.loss(dis_output, t_real, dis_update) + + def loss(self, dis_output, t_real, dis_update=True): + r"""GAN loss computation. + + Args: + dis_output (tensor): Discriminator outputs. + t_real (bool): If ``True``, uses the real label as target, otherwise + uses the fake label as target. + dis_update (bool): Updating the discriminator or the generator. + Returns: + loss (tensor): Loss value. + """ + if not dis_update: + assert t_real, \ + "The target should be real when updating the generator." + + if not dis_update and self.k < 1: + r""" + Use top-k training: + "Top-k Training of GANs: Improving GAN Performance by Throwing + Away Bad Samples" + Here, each sample may have multiple discriminator output values + (patch discriminator). We could either select top-k for each sample + separately (when ``self.separate_topk=True``), or collect values + from all samples and then select top-k (default, when + ``self.separate_topk=False``). + """ + if self.separate_topk: + dis_output = dis_output.view(dis_output.size(0), -1) + else: + dis_output = dis_output.view(-1) + k = math.ceil(self.k * dis_output.size(-1)) + dis_output, _ = torch.topk(dis_output, k) + + if self.gan_mode == 'non_saturated': + target_tensor = self.get_target_tensor(dis_output, t_real) + loss = F.binary_cross_entropy_with_logits(dis_output, + target_tensor) + elif self.gan_mode == 'least_square': + target_tensor = self.get_target_tensor(dis_output, t_real) + loss = 0.5 * F.mse_loss(dis_output, target_tensor) + elif self.gan_mode == 'hinge': + if dis_update: + if t_real: + loss = fuse_math_min_mean_pos(dis_output) + else: + loss = fuse_math_min_mean_neg(dis_output) + else: + loss = -torch.mean(dis_output) + elif self.gan_mode == 'wasserstein': + if t_real: + loss = -torch.mean(dis_output) + else: + loss = torch.mean(dis_output) + elif self.gan_mode == 'softplus': + target_tensor = self.get_target_tensor(dis_output, t_real) + loss = F.binary_cross_entropy_with_logits(dis_output, + target_tensor) + else: + raise ValueError('Unexpected gan_mode {}'.format(self.gan_mode)) + return loss + + def get_target_tensor(self, dis_output, t_real): + r"""Return the target vector for the binary cross entropy loss + computation. + + Args: + dis_output (tensor): Discriminator outputs. + t_real (bool): If ``True``, uses the real label as target, otherwise + uses the fake label as target. + Returns: + target (tensor): Target tensor vector. 
+ """ + if t_real: + if self.real_label_tensor is None: + self.real_label_tensor = dis_output.new_tensor(self.real_label) + return self.real_label_tensor.expand_as(dis_output) + else: + if self.fake_label_tensor is None: + self.fake_label_tensor = dis_output.new_tensor(self.fake_label) + return self.fake_label_tensor.expand_as(dis_output) + + def topk_anneal(self): + r"""Anneal k after each epoch.""" + if self.decay_k < 1: + # noinspection PyAttributeOutsideInit + self.k.fill_(max(self.decay_k * self.k, self.min_k)) + print("Top-k training: update k to {}.".format(self.k)) diff --git a/imaginaire/losses/info_nce.py b/imaginaire/losses/info_nce.py new file mode 100644 index 0000000000000000000000000000000000000000..8033e828f0b99d12d6e8f8b71811982d0ab568f6 --- /dev/null +++ b/imaginaire/losses/info_nce.py @@ -0,0 +1,87 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.distributed as dist + +from imaginaire.utils.distributed import get_world_size, get_rank, \ + dist_all_reduce_tensor + + +class GatherLayer(torch.autograd.Function): + @staticmethod + def forward(ctx, input): + ctx.save_for_backward(input) + output = [torch.zeros_like(input) for _ in range(dist.get_world_size())] + dist.all_gather(output, input) + return tuple(output) + + @staticmethod + def backward(ctx, *grads): + input, = ctx.saved_tensors + grad_out = torch.zeros_like(input) + all_grads = torch.stack(grads) + all_grads = dist_all_reduce_tensor(all_grads, reduce='sum') + grad_out[:] = all_grads[get_rank()] + return grad_out + + +class InfoNCELoss(nn.Module): + def __init__(self, + temperature=0.07, + gather_distributed=True, + learn_temperature=True, + single_direction=False, + flatten=True): + super(InfoNCELoss, self).__init__() + self.logit_scale = nn.Parameter(torch.tensor([math.log(1/temperature)])) + self.logit_scale.requires_grad = learn_temperature + self.gather_distributed = gather_distributed + self.single_direction = single_direction + self.flatten = flatten + + def forward(self, features_a, features_b, gather_distributed=None, eps=1e-8): + if gather_distributed is None: + gather_distributed = self.gather_distributed + + if features_a is None or features_b is None: + return torch.tensor(0, device='cuda'), torch.tensor(0, device='cuda') + + bs_a, bs_b = features_a.size(0), features_b.size(0) + if self.flatten: + features_a, features_b = features_a.reshape(bs_a, -1), features_b.reshape(bs_b, -1) + else: + features_a = features_a.reshape(bs_a, features_a.size(1), -1).mean(-1) + features_b = features_b.reshape(bs_b, features_b.size(1), -1).mean(-1) + + # Temperature clipping. 
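+        # 4.6052 is roughly ln(100), so the learned inverse temperature
+        # exp(logit_scale) is capped at about 100, i.e. the effective
+        # temperature never drops below 0.01 (the same cap used by CLIP).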
+ self.logit_scale.data = torch.clamp(self.logit_scale.data, 0, 4.6052) + + # normalized features + features_a = features_a / (features_a.norm(dim=1, keepdim=True) + eps) + features_b = features_b / (features_b.norm(dim=1, keepdim=True) + eps) + + loss_a = self._forward_single_direction(features_a, features_b, gather_distributed) + if self.single_direction: + return loss_a + else: + loss_b = self._forward_single_direction(features_b, features_a, gather_distributed) + return loss_a + loss_b + + def _forward_single_direction( + self, features_a, features_b, gather_distributed): + bs_a = features_a.shape[0] + logit_scale = self.logit_scale.exp() + if get_world_size() > 1 and gather_distributed: + gather_features_b = torch.cat(GatherLayer.apply(features_b)) + gather_labels_a = torch.arange(bs_a, device='cuda') + get_rank() * bs_a + logits_a = logit_scale * features_a @ gather_features_b.t() + else: + gather_labels_a = torch.arange(bs_a, device='cuda') + logits_a = logit_scale * features_a @ features_b.t() + loss_a = F.cross_entropy(logits_a, gather_labels_a) + return loss_a diff --git a/imaginaire/losses/kl.py b/imaginaire/losses/kl.py new file mode 100644 index 0000000000000000000000000000000000000000..dc9a2da14b06ccfef143312acd20ccd3784bdb34 --- /dev/null +++ b/imaginaire/losses/kl.py @@ -0,0 +1,22 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +import torch +import torch.nn as nn + +class GaussianKLLoss(nn.Module): + r"""Compute KL loss in VAE for Gaussian distributions""" + def __init__(self): + super(GaussianKLLoss, self).__init__() + + def forward(self, mu, logvar=None): + r"""Compute loss + + Args: + mu (tensor): mean + logvar (tensor): logarithm of variance + """ + if logvar is None: + logvar = torch.zeros_like(mu) + return -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp()) diff --git a/imaginaire/losses/perceptual.py b/imaginaire/losses/perceptual.py new file mode 100644 index 0000000000000000000000000000000000000000..424656fa09b65333e4fa28cae2de7114de69ebfa --- /dev/null +++ b/imaginaire/losses/perceptual.py @@ -0,0 +1,395 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +import torch +import torch.nn.functional as F +import torchvision +from torch import nn, distributed as dist + +from imaginaire.losses.info_nce import InfoNCELoss +from imaginaire.utils.distributed import master_only_print as print, \ + is_local_master +from imaginaire.utils.misc import apply_imagenet_normalization, to_float + + +class PerceptualLoss(nn.Module): + r"""Perceptual loss initialization. + + Args: + network (str) : The name of the loss network: 'vgg16' | 'vgg19'. + layers (str or list of str) : The layers used to compute the loss. + weights (float or list of float : The loss weights of each layer. + criterion (str): The type of distance function: 'l1' | 'l2'. + resize (bool) : If ``True``, resize the input images to 224x224. + resize_mode (str): Algorithm used for resizing. + num_scales (int): The loss will be evaluated at original size and + this many times downsampled sizes. + per_sample_weight (bool): Output loss for individual samples in the + batch instead of mean loss. 
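+
+    Example:
+        A minimal usage sketch (the layer choice and weights are
+        illustrative; ``fake`` and ``real`` are 4D image batches)::
+
+            loss_fn = PerceptualLoss(network='vgg19',
+                                     layers=['relu_3_1', 'relu_4_1'],
+                                     weights=[0.5, 1.0])
+            loss = loss_fn(fake, real)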
+ """ + + def __init__(self, network='vgg19', layers='relu_4_1', weights=None, + criterion='l1', resize=False, resize_mode='bilinear', + num_scales=1, per_sample_weight=False, + info_nce_temperature=0.07, + info_nce_gather_distributed=True, + info_nce_learn_temperature=True, + info_nce_flatten=True): + super().__init__() + if isinstance(layers, str): + layers = [layers] + if weights is None: + weights = [1.] * len(layers) + elif isinstance(layers, float) or isinstance(layers, int): + weights = [weights] + + if dist.is_initialized() and not is_local_master(): + # Make sure only the first process in distributed training downloads + # the model, and the others will use the cache + # noinspection PyUnresolvedReferences + torch.distributed.barrier() + + assert len(layers) == len(weights), \ + 'The number of layers (%s) must be equal to ' \ + 'the number of weights (%s).' % (len(layers), len(weights)) + if network == 'vgg19': + self.model = _vgg19(layers) + elif network == 'vgg16': + self.model = _vgg16(layers) + elif network == 'alexnet': + self.model = _alexnet(layers) + elif network == 'inception_v3': + self.model = _inception_v3(layers) + elif network == 'resnet50': + self.model = _resnet50(layers) + elif network == 'robust_resnet50': + self.model = _robust_resnet50(layers) + elif network == 'vgg_face_dag': + self.model = _vgg_face_dag(layers) + else: + raise ValueError('Network %s is not recognized' % network) + + if dist.is_initialized() and is_local_master(): + # Make sure only the first process in distributed training downloads + # the model, and the others will use the cache + # noinspection PyUnresolvedReferences + torch.distributed.barrier() + + self.num_scales = num_scales + self.layers = layers + self.weights = weights + reduction = 'mean' if not per_sample_weight else 'none' + if criterion == 'l1': + self.criterion = nn.L1Loss(reduction=reduction) + elif criterion == 'l2' or criterion == 'mse': + self.criterion = nn.MSELoss(reduction=reduction) + elif criterion == 'info_nce': + self.criterion = InfoNCELoss( + temperature=info_nce_temperature, + gather_distributed=info_nce_gather_distributed, + learn_temperature=info_nce_learn_temperature, + flatten=info_nce_flatten, + single_direction=True + ) + else: + raise ValueError('Criterion %s is not recognized' % criterion) + self.resize = resize + self.resize_mode = resize_mode + print('Perceptual loss:') + print('\tMode: {}'.format(network)) + + def forward(self, inp, target, per_sample_weights=None): + r"""Perceptual loss forward. + + Args: + inp (4D tensor) : Input tensor. + target (4D tensor) : Ground truth tensor, same shape as the input. + per_sample_weight (bool): Output loss for individual samples in the + batch instead of mean loss. + Returns: + (scalar tensor) : The perceptual loss. + """ + if not torch.is_autocast_enabled(): + inp, target = to_float([inp, target]) + + # Perceptual loss should operate in eval mode by default. + self.model.eval() + inp, target = apply_imagenet_normalization(inp), apply_imagenet_normalization(target) + if self.resize: + inp = F.interpolate(inp, mode=self.resize_mode, size=(224, 224), align_corners=False) + target = F.interpolate(target, mode=self.resize_mode, size=(224, 224), align_corners=False) + + # Evaluate perceptual loss at each scale. 
+ loss = 0 + for scale in range(self.num_scales): + input_features, target_features = self.model(inp), self.model(target) + + for layer, weight in zip(self.layers, self.weights): + # Example per-layer VGG19 loss values after applying + # [0.03125, 0.0625, 0.125, 0.25, 1.0] weighting. + # relu_1_1, 0.014698 + # relu_2_1, 0.085817 + # relu_3_1, 0.349977 + # relu_4_1, 0.544188 + # relu_5_1, 0.906261 + # print('%s, %f' % ( + # layer, + # weight * self.criterion( + # input_features[layer], + # target_features[ + # layer].detach()).item())) + l_tmp = self.criterion(input_features[layer], target_features[layer].detach()) + if per_sample_weights is not None: + l_tmp = l_tmp.mean(1).mean(1).mean(1) + loss += weight * l_tmp + # Downsample the input and target. + if scale != self.num_scales - 1: + inp = F.interpolate( + inp, mode=self.resize_mode, scale_factor=0.5, + align_corners=False, recompute_scale_factor=True) + target = F.interpolate( + target, mode=self.resize_mode, scale_factor=0.5, + align_corners=False, recompute_scale_factor=True) + + return loss.float() + + +class _PerceptualNetwork(nn.Module): + r"""The network that extracts features to compute the perceptual loss. + + Args: + network (nn.Sequential) : The network that extracts features. + layer_name_mapping (dict) : The dictionary that + maps a layer's index to its name. + layers (list of str): The list of layer names that we are using. + """ + + def __init__(self, network, layer_name_mapping, layers): + super().__init__() + assert isinstance(network, nn.Sequential), \ + 'The network needs to be of type "nn.Sequential".' + self.network = network + self.layer_name_mapping = layer_name_mapping + self.layers = layers + for param in self.parameters(): + param.requires_grad = False + + def forward(self, x): + r"""Extract perceptual features.""" + output = {} + for i, layer in enumerate(self.network): + x = layer(x) + layer_name = self.layer_name_mapping.get(i, None) + if layer_name in self.layers: + # If the current layer is used by the perceptual loss. 
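+                # Store the activation under its readable name (e.g. 'relu_4_1')
+                # so the loss can look up features by layer name.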
+ output[layer_name] = x + return output + + +def _vgg19(layers): + r"""Get vgg19 layers""" + vgg = torchvision.models.vgg19(pretrained=True) + # network = vgg.features + network = torch.nn.Sequential(*(list(vgg.features) + [vgg.avgpool] + [nn.Flatten()] + list(vgg.classifier))) + layer_name_mapping = {1: 'relu_1_1', + 3: 'relu_1_2', + 6: 'relu_2_1', + 8: 'relu_2_2', + 11: 'relu_3_1', + 13: 'relu_3_2', + 15: 'relu_3_3', + 17: 'relu_3_4', + 20: 'relu_4_1', + 22: 'relu_4_2', + 24: 'relu_4_3', + 26: 'relu_4_4', + 29: 'relu_5_1', + 31: 'relu_5_2', + 33: 'relu_5_3', + 35: 'relu_5_4', + 36: 'pool_5', + 42: 'fc_2'} + return _PerceptualNetwork(network, layer_name_mapping, layers) + + +def _vgg16(layers): + r"""Get vgg16 layers""" + network = torchvision.models.vgg16(pretrained=True).features + layer_name_mapping = {1: 'relu_1_1', + 3: 'relu_1_2', + 6: 'relu_2_1', + 8: 'relu_2_2', + 11: 'relu_3_1', + 13: 'relu_3_2', + 15: 'relu_3_3', + 18: 'relu_4_1', + 20: 'relu_4_2', + 22: 'relu_4_3', + 25: 'relu_5_1'} + return _PerceptualNetwork(network, layer_name_mapping, layers) + + +def _alexnet(layers): + r"""Get alexnet layers""" + network = torchvision.models.alexnet(pretrained=True).features + layer_name_mapping = {0: 'conv_1', + 1: 'relu_1', + 3: 'conv_2', + 4: 'relu_2', + 6: 'conv_3', + 7: 'relu_3', + 8: 'conv_4', + 9: 'relu_4', + 10: 'conv_5', + 11: 'relu_5'} + return _PerceptualNetwork(network, layer_name_mapping, layers) + + +def _inception_v3(layers): + r"""Get inception v3 layers""" + inception = torchvision.models.inception_v3(pretrained=True) + network = nn.Sequential(inception.Conv2d_1a_3x3, + inception.Conv2d_2a_3x3, + inception.Conv2d_2b_3x3, + nn.MaxPool2d(kernel_size=3, stride=2), + inception.Conv2d_3b_1x1, + inception.Conv2d_4a_3x3, + nn.MaxPool2d(kernel_size=3, stride=2), + inception.Mixed_5b, + inception.Mixed_5c, + inception.Mixed_5d, + inception.Mixed_6a, + inception.Mixed_6b, + inception.Mixed_6c, + inception.Mixed_6d, + inception.Mixed_6e, + inception.Mixed_7a, + inception.Mixed_7b, + inception.Mixed_7c, + nn.AdaptiveAvgPool2d(output_size=(1, 1))) + layer_name_mapping = {3: 'pool_1', + 6: 'pool_2', + 14: 'mixed_6e', + 18: 'pool_3'} + return _PerceptualNetwork(network, layer_name_mapping, layers) + + +def _resnet50(layers): + r"""Get resnet50 layers""" + resnet50 = torchvision.models.resnet50(pretrained=True) + network = nn.Sequential(resnet50.conv1, + resnet50.bn1, + resnet50.relu, + resnet50.maxpool, + resnet50.layer1, + resnet50.layer2, + resnet50.layer3, + resnet50.layer4, + resnet50.avgpool) + layer_name_mapping = {4: 'layer_1', + 5: 'layer_2', + 6: 'layer_3', + 7: 'layer_4'} + return _PerceptualNetwork(network, layer_name_mapping, layers) + + +def _robust_resnet50(layers): + r"""Get robust resnet50 layers""" + resnet50 = torchvision.models.resnet50(pretrained=False) + state_dict = torch.utils.model_zoo.load_url( + 'http://andrewilyas.com/ImageNet.pt') + new_state_dict = {} + for k, v in state_dict['model'].items(): + if k.startswith('module.model.'): + new_state_dict[k[13:]] = v + resnet50.load_state_dict(new_state_dict) + network = nn.Sequential(resnet50.conv1, + resnet50.bn1, + resnet50.relu, + resnet50.maxpool, + resnet50.layer1, + resnet50.layer2, + resnet50.layer3, + resnet50.layer4, + resnet50.avgpool) + layer_name_mapping = {4: 'layer_1', + 5: 'layer_2', + 6: 'layer_3', + 7: 'layer_4'} + return _PerceptualNetwork(network, layer_name_mapping, layers) + + +def _vgg_face_dag(layers): + network = torchvision.models.vgg16(num_classes=2622) + state_dict = 
torch.utils.model_zoo.load_url( + 'http://www.robots.ox.ac.uk/~albanie/models/pytorch-mcn/' + 'vgg_face_dag.pth') + feature_layer_name_mapping = { + 0: 'conv1_1', + 2: 'conv1_2', + 5: 'conv2_1', + 7: 'conv2_2', + 10: 'conv3_1', + 12: 'conv3_2', + 14: 'conv3_3', + 17: 'conv4_1', + 19: 'conv4_2', + 21: 'conv4_3', + 24: 'conv5_1', + 26: 'conv5_2', + 28: 'conv5_3'} + new_state_dict = {} + for k, v in feature_layer_name_mapping.items(): + new_state_dict['features.' + str(k) + '.weight'] = \ + state_dict[v + '.weight'] + new_state_dict['features.' + str(k) + '.bias'] = \ + state_dict[v + '.bias'] + + classifier_layer_name_mapping = { + 0: 'fc6', + 3: 'fc7', + 6: 'fc8'} + for k, v in classifier_layer_name_mapping.items(): + new_state_dict['classifier.' + str(k) + '.weight'] = \ + state_dict[v + '.weight'] + new_state_dict['classifier.' + str(k) + '.bias'] = \ + state_dict[v + '.bias'] + + network.load_state_dict(new_state_dict) + + class Flatten(nn.Module): + def forward(self, x): + return x.view(x.shape[0], -1) + + layer_name_mapping = { + 0: 'conv_1_1', + 1: 'relu_1_1', + 2: 'conv_1_2', + 5: 'conv_2_1', # 1/2 + 6: 'relu_2_1', + 7: 'conv_2_2', + 10: 'conv_3_1', # 1/4 + 11: 'relu_3_1', + 12: 'conv_3_2', + 14: 'conv_3_3', + 17: 'conv_4_1', # 1/8 + 18: 'relu_4_1', + 19: 'conv_4_2', + 21: 'conv_4_3', + 24: 'conv_5_1', # 1/16 + 25: 'relu_5_1', + 26: 'conv_5_2', + 28: 'conv_5_3', + 33: 'fc6', + 36: 'fc7', + 39: 'fc8' + } + seq_layers = [] + for feature in network.features: + seq_layers += [feature] + seq_layers += [network.avgpool, Flatten()] + for classifier in network.classifier: + seq_layers += [classifier] + network = nn.Sequential(*seq_layers) + return _PerceptualNetwork(network, layer_name_mapping, layers) diff --git a/imaginaire/losses/weighted_mse.py b/imaginaire/losses/weighted_mse.py new file mode 100644 index 0000000000000000000000000000000000000000..b4e49989a5c3ee8576dcf4dea8a98a16c1911cc9 --- /dev/null +++ b/imaginaire/losses/weighted_mse.py @@ -0,0 +1,28 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +import torch +import torch.nn as nn + + +class WeightedMSELoss(nn.Module): + r"""Compute Weighted MSE loss""" + def __init__(self, reduction='mean'): + super(WeightedMSELoss, self).__init__() + self.reduction = reduction + + def forward(self, input, target, weight): + r"""Return weighted MSE Loss. + Args: + input (tensor): + target (tensor): + weight (tensor): + Returns: + (tensor): Loss value. + """ + if self.reduction == 'mean': + loss = torch.mean(weight * (input - target) ** 2) + else: + loss = torch.sum(weight * (input - target) ** 2) + return loss diff --git a/imaginaire/model_utils/__init__.py b/imaginaire/model_utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..13acefe2181136b1629ec31f9d122fb46bf26780 --- /dev/null +++ b/imaginaire/model_utils/__init__.py @@ -0,0 +1,4 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. 
+# To view a copy of this license, check out LICENSE.md diff --git a/imaginaire/model_utils/__pycache__/__init__.cpython-38.pyc b/imaginaire/model_utils/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6c45aa8fec0f4180fe7efd93161a67e545e3fb71 Binary files /dev/null and b/imaginaire/model_utils/__pycache__/__init__.cpython-38.pyc differ diff --git a/imaginaire/model_utils/__pycache__/fs_vid2vid.cpython-38.pyc b/imaginaire/model_utils/__pycache__/fs_vid2vid.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aa52fd3e68a006d480e174ee47d5066eaf79b98f Binary files /dev/null and b/imaginaire/model_utils/__pycache__/fs_vid2vid.cpython-38.pyc differ diff --git a/imaginaire/model_utils/fs_vid2vid.py b/imaginaire/model_utils/fs_vid2vid.py new file mode 100644 index 0000000000000000000000000000000000000000..b52faf73d3f37221e9d0f089d1f4a85d5378877d --- /dev/null +++ b/imaginaire/model_utils/fs_vid2vid.py @@ -0,0 +1,865 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +"""Utils for the few shot vid2vid model.""" +import random +import numpy as np +import torch +import torch.nn.functional as F + + +def resample(image, flow): + r"""Resamples an image using the provided flow. + + Args: + image (NxCxHxW tensor) : Image to resample. + flow (Nx2xHxW tensor) : Optical flow to resample the image. + Returns: + output (NxCxHxW tensor) : Resampled image. + """ + assert flow.shape[1] == 2 + b, c, h, w = image.size() + grid = get_grid(b, (h, w)) + flow = torch.cat([flow[:, 0:1, :, :] / ((w - 1.0) / 2.0), + flow[:, 1:2, :, :] / ((h - 1.0) / 2.0)], dim=1) + final_grid = (grid + flow).permute(0, 2, 3, 1) + try: + output = F.grid_sample(image, final_grid, mode='bilinear', + padding_mode='border', align_corners=True) + except Exception: + output = F.grid_sample(image, final_grid, mode='bilinear', + padding_mode='border') + return output + + +def get_grid(batchsize, size, minval=-1.0, maxval=1.0): + r"""Get a grid ranging [-1, 1] of 2D/3D coordinates. + + Args: + batchsize (int) : Batch size. + size (tuple) : (height, width) or (depth, height, width). + minval (float) : minimum value in returned grid. + maxval (float) : maximum value in returned grid. + Returns: + t_grid (4D tensor) : Grid of coordinates. + """ + if len(size) == 2: + rows, cols = size + elif len(size) == 3: + deps, rows, cols = size + else: + raise ValueError('Dimension can only be 2 or 3.') + x = torch.linspace(minval, maxval, cols) + x = x.view(1, 1, 1, cols) + x = x.expand(batchsize, 1, rows, cols) + + y = torch.linspace(minval, maxval, rows) + y = y.view(1, 1, rows, 1) + y = y.expand(batchsize, 1, rows, cols) + + t_grid = torch.cat([x, y], dim=1) + + if len(size) == 3: + z = torch.linspace(minval, maxval, deps) + z = z.view(1, 1, deps, 1, 1) + z = z.expand(batchsize, 1, deps, rows, cols) + + t_grid = t_grid.unsqueeze(2).expand(batchsize, 2, deps, rows, cols) + t_grid = torch.cat([t_grid, z], dim=1) + + t_grid.requires_grad = False + return t_grid.to('cuda') + + +def pick_image(images, idx): + r"""Pick the image among images according to idx. + + Args: + images (B x N x C x H x W tensor or list of tensors) : N images. + idx (B tensor) : indices to select. + Returns: + image (B x C x H x W) : Selected images. 
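+
+    Example (illustrative; shapes and indices are hypothetical)::
+
+        images = torch.randn(4, 3, 3, 256, 256)  # 4 samples, 3 candidates each
+        idx = torch.tensor([0, 2, 1, 0])          # candidate to keep per sample
+        image = pick_image(images, idx)           # -> 4 x 3 x 256 x 256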
+ """ + if type(images) == list: + return [pick_image(r, idx) for r in images] + if idx is None: + return images[:, 0] + elif type(idx) == int: + return images[:, idx] + idx = idx.long().view(-1, 1, 1, 1, 1) + image = images.gather(1, idx.expand_as(images)[:, 0:1])[:, 0] + return image + + +def crop_face_from_data(cfg, is_inference, data): + r"""Crop the face regions in input data and resize to the target size. + This is for training face datasets. + + Args: + cfg (obj): Data configuration. + is_inference (bool): Is doing inference or not. + data (dict): Input data. + Returns: + data (dict): Cropped data. + """ + label = data['label'] if 'label' in data else None + image = data['images'] + landmarks = data['landmarks-dlib68_xy'] + ref_labels = data['few_shot_label'] if 'few_shot_label' in data else None + ref_images = data['few_shot_images'] + ref_landmarks = data['few_shot_landmarks-dlib68_xy'] + img_size = image.shape[-2:] + h, w = cfg.output_h_w.split(',') + h, w = int(h), int(w) + + # When doing inference, need to sync common attributes like crop coodinates + # between different workers, so all workers crop the same region. + if 'common_attr' in data and 'crop_coords' in data['common_attr']: + # Has been computed before, reusing the previous one. + crop_coords, ref_crop_coords = data['common_attr']['crop_coords'] + else: + # Is the first frame, need to compute the bbox. + ref_crop_coords, scale = get_face_bbox_for_data( + ref_landmarks[0], img_size, None, is_inference) + crop_coords, _ = get_face_bbox_for_data( + landmarks[0], img_size, scale, is_inference) + + # Crop the images according to the bbox and resize them to target size. + label, image = crop_and_resize([label, image], crop_coords, (h, w)) + ref_labels, ref_images = crop_and_resize([ref_labels, ref_images], + ref_crop_coords, (h, w)) + + data['images'], data['few_shot_images'] = image, ref_images + if label is not None: + data['label'], data['few_shot_label'] = label, ref_labels + if is_inference: + if 'common_attr' not in data: + data['common_attr'] = dict() + data['common_attr']['crop_coords'] = crop_coords, ref_crop_coords + return data + + +def get_face_bbox_for_data(keypoints, orig_img_size, scale, is_inference): + r"""Get the bbox coordinates for face region. + + Args: + keypoints (Nx2 tensor): Facial landmarks. + orig_img_size (int tuple): Height and width of the input image size. + scale (float): When training, randomly scale the crop size for + augmentation. + is_inference (bool): Is doing inference or not. + Returns: + crop_coords (list of int): bbox for face region. + scale (float): Also returns scale to ensure reference and target frames + are croppped using the same scale. + """ + min_y, max_y = int(keypoints[:, 1].min()), int(keypoints[:, 1].max()) + min_x, max_x = int(keypoints[:, 0].min()), int(keypoints[:, 0].max()) + x_cen, y_cen = (min_x + max_x) // 2, (min_y + max_y) // 2 + H, W = orig_img_size + w = h = (max_x - min_x) + if not is_inference: + # During training, randomly jitter the cropping position by offset + # amount for augmentation. + offset_max = 0.2 + offset = [np.random.uniform(-offset_max, offset_max), + np.random.uniform(-offset_max, offset_max)] + # Also augment the crop size. + if scale is None: + scale_max = 0.2 + scale = [np.random.uniform(1 - scale_max, 1 + scale_max), + np.random.uniform(1 - scale_max, 1 + scale_max)] + w *= scale[0] + h *= scale[1] + x_cen += int(offset[0] * w) + y_cen += int(offset[1] * h) + + # Get the cropping coordinates. 
+ x_cen = max(w, min(W - w, x_cen)) + y_cen = max(h * 1.25, min(H - h * 0.75, y_cen)) + + min_x = x_cen - w + min_y = y_cen - h * 1.25 + max_x = min_x + w * 2 + max_y = min_y + h * 2 + + crop_coords = [min_y, max_y, min_x, max_x] + return [int(x) for x in crop_coords], scale + + +def crop_person_from_data(cfg, is_inference, data): + r"""Crop the person regions in data and resize to the target size. + This is for training full body datasets. + + Args: + cfg (obj): Data configuration. + is_inference (bool): Is doing inference or not. + data (dict): Input data. + Returns: + data (dict): Cropped data. + """ + label = data['label'] + image = data['images'] + use_few_shot = 'few_shot_label' in data + if use_few_shot: + ref_labels = data['few_shot_label'] + ref_images = data['few_shot_images'] + + img_size = image.shape[-2:] + output_h, output_w = cfg.output_h_w.split(',') + output_h, output_w = int(output_h), int(output_w) + output_aspect_ratio = output_w / output_h + + if 'human_instance_maps' in data: + # Remove other people in the DensePose map except for the current + # target. + label = remove_other_ppl(label, data['human_instance_maps']) + if use_few_shot: + ref_labels = remove_other_ppl(ref_labels, + data['few_shot_human_instance_maps']) + + # Randomly jitter the crop position by offset amount for augmentation. + offset = ref_offset = None + if not is_inference: + offset = np.random.randn(2) * 0.05 + offset = np.minimum(1, np.maximum(-1, offset)) + ref_offset = np.random.randn(2) * 0.02 + ref_offset = np.minimum(1, np.maximum(-1, ref_offset)) + + # Randomly scale the crop size for augmentation. + # Final cropped size = person height * scale. + scale = ref_scale = 1.5 + if not is_inference: + scale = min(2, max(1, scale + np.random.randn() * 0.05)) + ref_scale = min(2, max(1, ref_scale + np.random.randn() * 0.02)) + + # When doing inference, need to sync common attributes like crop coodinates + # between different workers, so all workers crop the same region. + if 'common_attr' in data: + # Has been computed before, reusing the previous one. + crop_coords, ref_crop_coords = data['common_attr']['crop_coords'] + else: + # Is the first frame, need to compute the bbox. + crop_coords = get_person_bbox_for_data(label, img_size, scale, + output_aspect_ratio, offset) + if use_few_shot: + ref_crop_coords = get_person_bbox_for_data( + ref_labels, img_size, ref_scale, + output_aspect_ratio, ref_offset) + else: + ref_crop_coords = None + + # Crop the images according to the bbox and resize them to target size. + label = crop_and_resize(label, crop_coords, (output_h, output_w), 'nearest') + image = crop_and_resize(image, crop_coords, (output_h, output_w)) + if use_few_shot: + ref_labels = crop_and_resize(ref_labels, ref_crop_coords, + (output_h, output_w), 'nearest') + ref_images = crop_and_resize(ref_images, ref_crop_coords, + (output_h, output_w)) + + data['label'], data['images'] = label, image + if use_few_shot: + data['few_shot_label'], data['few_shot_images'] = ref_labels, ref_images + if 'human_instance_maps' in data: + del data['human_instance_maps'] + if 'few_shot_human_instance_maps' in data: + del data['few_shot_human_instance_maps'] + if is_inference: + data['common_attr'] = dict() + data['common_attr']['crop_coords'] = crop_coords, ref_crop_coords + + return data + + +def get_person_bbox_for_data(pose_map, orig_img_size, scale=1.5, + crop_aspect_ratio=1, offset=None): + r"""Get the bbox (pixel coordinates) to crop for person body region. 
+ + Args: + pose_map (NxCxHxW tensor): Input pose map. + orig_img_size (int tuple): Height and width of the input image size. + scale (float): When training, randomly scale the crop size for + augmentation. + crop_aspect_ratio (float): Output aspect ratio, + offset (list of float): Offset for crop position. + Returns: + crop_coords (list of int): bbox for body region. + """ + H, W = orig_img_size + assert pose_map.dim() == 4 + nonzero_indices = (pose_map[:, :3] > 0).nonzero(as_tuple=False) + if nonzero_indices.size(0) == 0: + bw = int(H * crop_aspect_ratio // 2) + return [0, H, W // 2 - bw, W // 2 + bw] + + y_indices, x_indices = nonzero_indices[:, 2], nonzero_indices[:, 3] + y_min, y_max = y_indices.min().item(), y_indices.max().item() + x_min, x_max = x_indices.min().item(), x_indices.max().item() + y_cen = int(y_min + y_max) // 2 + x_cen = int(x_min + x_max) // 2 + y_len = y_max - y_min + x_len = x_max - x_min + + # bh, bw: half of height / width of final cropped size. + bh = int(min(H, max(H // 2, y_len * scale))) // 2 + bh = max(bh, int(x_len * scale / crop_aspect_ratio) // 2) + bw = int(bh * crop_aspect_ratio) + + # Randomly offset the cropped position for augmentation. + if offset is not None: + x_cen += int(offset[0] * bw) + y_cen += int(offset[1] * bh) + x_cen = max(bw, min(W - bw, x_cen)) + y_cen = max(bh, min(H - bh, y_cen)) + + return [(y_cen - bh), (y_cen + bh), (x_cen - bw), (x_cen + bw)] + + +def crop_and_resize(img, coords, size=None, method='bilinear'): + r"""Crop the image using the given coordinates and resize to target size. + + Args: + img (tensor or list of tensors): Input image. + coords (list of int): Pixel coordinates to crop. + size (list of int): Output size. + method (str): Interpolation method. + Returns: + img (tensor or list of tensors): Output image. + """ + if isinstance(img, list): + return [crop_and_resize(x, coords, size, method) for x in img] + if img is None: + return None + min_y, max_y, min_x, max_x = coords + + img = img[:, :, min_y:max_y, min_x:max_x] + if size is not None: + if method == 'nearest': + img = F.interpolate(img, size=size, mode=method) + else: + img = F.interpolate(img, size=size, mode=method, + align_corners=False) + return img + + +def remove_other_ppl(labels, densemasks): + r"""Remove other people in the label map except for the current target + by looking at the id in the densemask map. + + Args: + labels (NxCxHxW tensor): Input labels. + densemasks (Nx1xHxW tensor): Densemask maps. + Returns: + labels (NxCxHxW tensor): Output labels. + """ + densemasks = densemasks[:, 0:1] * 255 + for idx in range(labels.shape[0]): + label, densemask = labels[idx], densemasks[idx] + # Get OpenPose and find the person id in Densemask that has the most + # overlap with the person in OpenPose result. + openpose = label[3:] + valid = (openpose[0] > 0) | (openpose[1] > 0) | (openpose[2] > 0) + dp_valid = densemask[valid.unsqueeze(0)] + if dp_valid.shape[0]: + ind = np.bincount(dp_valid).argmax() + # Remove all other people that have different indices. + label = label * (densemask == ind).float() + labels[idx] = label + return labels + + +def select_object(data, obj_indices=None): + r"""Select the object/person in the dict according to the object index. + Currently it's used to select the target person in OpenPose dict. + + Args: + data (dict): Input data. + obj_indices (list of int): Indices for the objects to select. + Returns: + data (dict): Output data. 
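+
+    Example (illustrative; the nested layout of the OpenPose entries is assumed)::
+
+        data = {'poses-openpose': [[person_a, person_b],   # frame 0
+                                   [person_a, person_b]]}  # frame 1
+        data = select_object(data, obj_indices=[1, 1])
+        # data['poses-openpose'] is now [person_b, person_b]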
+ """ + op_keys = ['poses-openpose', 'captions-clip'] + for op_key in op_keys: + if op_key in data: + for i in range(len(data[op_key])): + # data[op_key] is a list of dicts for different frames. + # people = data[op_key][i]['people'] + people = data[op_key][i] + # "people" is a list of people dicts found by OpenPose. We will + # use the obj_index to get the target person from the list, and + # write it back to the dict. + # data[op_key][i]['people'] = [people[obj_indices[i]]] + if obj_indices is not None: + data[op_key][i] = people[obj_indices[i]] + else: + if op_key == 'poses-openpose': + data[op_key][i] = people[0] + else: + idx = random.randint(0, len(people) - 1) + data[op_key][i] = people[idx] + return data + + +def concat_frames(prev, now, n_frames): + r"""Concat previous and current frames and only keep the latest $(n_frames). + If concatenated frames are longer than $(n_frames), drop the oldest one. + + Args: + prev (NxTxCxHxW tensor): Tensor for previous frames. + now (NxCxHxW tensor): Tensor for current frame. + n_frames (int): Max number of frames to store. + Returns: + result (NxTxCxHxW tensor): Updated tensor. + """ + now = now.unsqueeze(1) + if prev is None: + return now + if prev.shape[1] == n_frames: + prev = prev[:, 1:] + return torch.cat([prev, now], dim=1) + + +def combine_fg_mask(fg_mask, ref_fg_mask, has_fg): + r"""Get the union of target and reference foreground masks. + Args: + fg_mask (tensor): Foreground mask for target image. + ref_fg_mask (tensor): Foreground mask for reference image. + has_fg (bool): Whether the image can be classified into fg/bg. + Returns: + output (tensor or int): Combined foreground mask. + """ + return ((fg_mask > 0) | (ref_fg_mask > 0)).float() if has_fg else 1 + + +def get_fg_mask(densepose_map, has_fg): + r"""Obtain the foreground mask for pose sequences, which only includes + the human. This is done by looking at the body part map from DensePose. + + Args: + densepose_map (NxCxHxW tensor): DensePose map. + has_fg (bool): Whether data has foreground or not. + Returns: + mask (Nx1xHxW tensor): fg mask. + """ + if type(densepose_map) == list: + return [get_fg_mask(label, has_fg) for label in densepose_map] + if not has_fg or densepose_map is None: + return 1 + if densepose_map.dim() == 5: + densepose_map = densepose_map[:, 0] + # Get the body part map from DensePose. + mask = densepose_map[:, 2:3] + + # Make the mask slightly larger. + mask = torch.nn.MaxPool2d(15, padding=7, stride=1)(mask) + mask = (mask > -1).float() + return mask + + +def get_part_mask(densepose_map): + r"""Obtain mask of different body parts of humans. This is done by + looking at the body part map from DensePose. + + Args: + densepose_map (NxCxHxW tensor): DensePose map. + Returns: + mask (NxKxHxW tensor): Body part mask, where K is the number of parts. + """ + # Groups of body parts. Each group contains IDs of body part labels in + # DensePose. The 9 groups here are: background, torso, hands, feet, + # upper legs, lower legs, upper arms, lower arms, head. 
+ part_groups = [[0], [1, 2], [3, 4], [5, 6], [7, 9, 8, 10], [11, 13, 12, 14], + [15, 17, 16, 18], [19, 21, 20, 22], [23, 24]] + n_parts = len(part_groups) + + need_reshape = densepose_map.dim() == 4 + if need_reshape: + bo, t, h, w = densepose_map.size() + densepose_map = densepose_map.view(-1, h, w) + b, h, w = densepose_map.size() + part_map = (densepose_map / 2 + 0.5) * 24 + assert (part_map >= 0).all() and (part_map < 25).all() + + mask = torch.cuda.ByteTensor(b, n_parts, h, w).fill_(0) + for i in range(n_parts): + for j in part_groups[i]: + # Account for numerical errors. + mask[:, i] = mask[:, i] | ( + (part_map > j - 0.1) & (part_map < j + 0.1)).byte() + if need_reshape: + mask = mask.view(bo, t, -1, h, w) + return mask.float() + + +def get_face_mask(densepose_map): + r"""Obtain mask of faces. + Args: + densepose_map (3D or 4D tensor): DensePose map. + Returns: + mask (3D or 4D tensor): Face mask. + """ + need_reshape = densepose_map.dim() == 4 + if need_reshape: + bo, t, h, w = densepose_map.size() + densepose_map = densepose_map.view(-1, h, w) + + b, h, w = densepose_map.size() + part_map = (densepose_map / 2 + 0.5) * 24 + assert (part_map >= 0).all() and (part_map < 25).all() + if densepose_map.is_cuda: + mask = torch.cuda.ByteTensor(b, h, w).fill_(0) + else: + mask = torch.ByteTensor(b, h, w).fill_(0) + for j in [23, 24]: + mask = mask | ((part_map > j - 0.1) & (part_map < j + 0.1)).byte() + if need_reshape: + mask = mask.view(bo, t, h, w) + return mask.float() + + +def extract_valid_pose_labels(pose_map, pose_type, remove_face_labels, + do_remove=True): + r"""Remove some labels (e.g. face regions) in the pose map if necessary. + + Args: + pose_map (3D, 4D or 5D tensor): Input pose map. + pose_type (str): 'both' or 'open'. + remove_face_labels (bool): Whether to remove labels for the face region. + do_remove (bool): Do remove face labels. + Returns: + pose_map (3D, 4D or 5D tensor): Output pose map. + """ + if pose_map is None: + return pose_map + if type(pose_map) == list: + return [extract_valid_pose_labels(p, pose_type, remove_face_labels, + do_remove) for p in pose_map] + + orig_dim = pose_map.dim() + assert (orig_dim >= 3 and orig_dim <= 5) + if orig_dim == 3: + pose_map = pose_map.unsqueeze(0).unsqueeze(0) + elif orig_dim == 4: + pose_map = pose_map.unsqueeze(0) + + if pose_type == 'open': + # If input is only openpose, remove densepose part. + pose_map = pose_map[:, :, 3:] + + elif remove_face_labels and do_remove: + # Remove face part for densepose input. + densepose, openpose = pose_map[:, :, :3], pose_map[:, :, 3:] + face_mask = get_face_mask(pose_map[:, :, 2]).unsqueeze(2) + pose_map = torch.cat([densepose * (1 - face_mask) - face_mask, + openpose], dim=2) + + if orig_dim == 3: + pose_map = pose_map[0, 0] + elif orig_dim == 4: + pose_map = pose_map[0] + return pose_map + + +def normalize_faces(keypoints, ref_keypoints, + dist_scale_x=None, dist_scale_y=None): + r"""Normalize face keypoints w.r.t. the reference face keypoints. + + Args: + keypoints (Kx2 numpy array): target facial keypoints. + ref_keypoints (Kx2 numpy array): reference facial keypoints. + Returns: + keypoints (Kx2 numpy array): normalized facial keypoints. 
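+        dist_scales (list): The x/y scale factors applied to each part, returned
+            so they can be passed back in to reuse the same normalization.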
+ """ + if keypoints.shape[0] == 68: + central_keypoints = [8] + add_upper_face = False + part_list = [[0, 16], [1, 15], [2, 14], [3, 13], [4, 12], + [5, 11], [6, 10], [7, 9, 8], + [17, 26], [18, 25], [19, 24], [20, 23], [21, 22], + [27], [28], [29], [30], [31, 35], [32, 34], [33], + [36, 45], [37, 44], [38, 43], [39, 42], [40, 47], [41, 46], + [48, 54], [49, 53], [50, 52], [51], [55, 59], [56, 58], + [57], + [60, 64], [61, 63], [62], [65, 67], [66] + ] + if add_upper_face: + part_list += [[68, 82], [69, 81], [70, 80], [71, 79], [72, 78], + [73, 77], [74, 76, 75]] + elif keypoints.shape[0] == 126: + central_keypoints = [16] + part_list = [[i] for i in range(126)] + else: + raise ValueError('Input keypoints type not supported.') + + face_cen = np.mean(keypoints[central_keypoints, :], axis=0) + ref_face_cen = np.mean(ref_keypoints[central_keypoints, :], axis=0) + + def get_mean_dists(pts, face_cen): + r"""Get the mean xy distances of keypoints wrt face center.""" + mean_dists_x, mean_dists_y = [], [] + pts_cen = np.mean(pts, axis=0) + for p, pt in enumerate(pts): + mean_dists_x.append(np.linalg.norm(pt - pts_cen)) + mean_dists_y.append(np.linalg.norm(pts_cen - face_cen)) + mean_dist_x = sum(mean_dists_x) / len(mean_dists_x) + 1e-3 + mean_dist_y = sum(mean_dists_y) / len(mean_dists_y) + 1e-3 + return mean_dist_x, mean_dist_y + + if dist_scale_x is None: + dist_scale_x, dist_scale_y = [None] * len(part_list), \ + [None] * len(part_list) + + for i, pts_idx in enumerate(part_list): + pts = keypoints[pts_idx] + if dist_scale_x[i] is None: + ref_pts = ref_keypoints[pts_idx] + mean_dist_x, mean_dist_y = get_mean_dists(pts, face_cen) + ref_dist_x, ref_dist_y = get_mean_dists(ref_pts, ref_face_cen) + + dist_scale_x[i] = ref_dist_x / mean_dist_x + dist_scale_y[i] = ref_dist_y / mean_dist_y + + pts_cen = np.mean(pts, axis=0) + pts = (pts - pts_cen) * dist_scale_x[i] + \ + (pts_cen - face_cen) * dist_scale_y[i] + face_cen + keypoints[pts_idx] = pts + return keypoints, [dist_scale_x, dist_scale_y] + + +def crop_face_from_output(data_cfg, image, input_label, crop_smaller=0): + r"""Crop out the face region of the image (and resize if necessary to feed + into generator/discriminator). + + Args: + data_cfg (obj): Data configuration. + image (NxC1xHxW tensor or list of tensors): Image to crop. + input_label (NxC2xHxW tensor): Input label map. + crop_smaller (int): Number of pixels to crop slightly smaller region. + Returns: + output (NxC1xHxW tensor or list of tensors): Cropped image. + """ + if type(image) == list: + return [crop_face_from_output(data_cfg, im, input_label, crop_smaller) + for im in image] + + output = None + face_size = image.shape[-2] // 32 * 8 + for i in range(input_label.size(0)): + ys, ye, xs, xe = get_face_bbox_for_output(data_cfg, + input_label[i:i + 1], + crop_smaller=crop_smaller) + output_i = F.interpolate(image[i:i + 1, -3:, ys:ye, xs:xe], + size=(face_size, face_size), mode='bilinear', + align_corners=True) + # output_i = image[i:i + 1, -3:, ys:ye, xs:xe] + output = torch.cat([output, output_i]) if i != 0 else output_i + return output + + +def get_face_bbox_for_output(data_cfg, pose, crop_smaller=0): + r"""Get pixel coordinates of the face bounding box. + + Args: + data_cfg (obj): Data configuration. + pose (NxCxHxW tensor): Pose label map. + crop_smaller (int): Number of pixels to crop slightly smaller region. + Returns: + output (list of int): Face bbox. 
+ """ + if pose.dim() == 3: + pose = pose.unsqueeze(0) + elif pose.dim() == 5: + pose = pose[-1, -1:] + _, _, h, w = pose.size() + + use_openpose = 'pose_maps-densepose' not in data_cfg.input_labels + if use_openpose: # Use openpose face keypoints to identify face region. + for input_type in data_cfg.input_types: + if 'poses-openpose' in input_type: + num_ch = input_type['poses-openpose'].num_channels + if num_ch > 3: + face = (pose[:, -1] > 0).nonzero(as_tuple=False) + else: + raise ValueError('Not implemented yet.') + else: # Use densepose labels. + face = (pose[:, 2] > 0.9).nonzero(as_tuple=False) + + ylen = xlen = h // 32 * 8 + if face.size(0): + y, x = face[:, 1], face[:, 2] + ys, ye = y.min().item(), y.max().item() + xs, xe = x.min().item(), x.max().item() + if use_openpose: + xc, yc = (xs + xe) // 2, (ys * 3 + ye * 2) // 5 + ylen = int((xe - xs) * 2.5) + else: + xc, yc = (xs + xe) // 2, (ys + ye) // 2 + ylen = int((ye - ys) * 1.25) + ylen = xlen = min(w, max(32, ylen)) + yc = max(ylen // 2, min(h - 1 - ylen // 2, yc)) + xc = max(xlen // 2, min(w - 1 - xlen // 2, xc)) + else: + yc = h // 4 + xc = w // 2 + + ys, ye = yc - ylen // 2, yc + ylen // 2 + xs, xe = xc - xlen // 2, xc + xlen // 2 + if crop_smaller != 0: # Crop slightly smaller region inside face. + ys += crop_smaller + xs += crop_smaller + ye -= crop_smaller + xe -= crop_smaller + return [ys, ye, xs, xe] + + +def crop_hand_from_output(data_cfg, image, input_label): + r"""Crop out the hand region of the image. + + Args: + data_cfg (obj): Data configuration. + image (NxC1xHxW tensor or list of tensors): Image to crop. + input_label (NxC2xHxW tensor): Input label map. + Returns: + output (NxC1xHxW tensor or list of tensors): Cropped image. + """ + if type(image) == list: + return [crop_hand_from_output(data_cfg, im, input_label) + for im in image] + + output = None + for i in range(input_label.size(0)): + coords = get_hand_bbox_for_output(data_cfg, input_label[i:i + 1]) + if coords: + for coord in coords: + ys, ye, xs, xe = coord + output_i = image[i:i + 1, -3:, ys:ye, xs:xe] + output = torch.cat([output, output_i]) \ + if output is not None else output_i + return output + + +def get_hand_bbox_for_output(data_cfg, pose): + r"""Get coordinates of the hand bounding box. + + Args: + data_cfg (obj): Data configuration. + pose (NxCxHxW tensor): Pose label map. + Returns: + output (list of int): Hand bbox. + """ + if pose.dim() == 3: + pose = pose.unsqueeze(0) + elif pose.dim() == 5: + pose = pose[-1, -1:] + _, _, h, w = pose.size() + ylen = xlen = h // 64 * 8 + + coords = [] + colors = [[0.95, 0.5, 0.95], [0.95, 0.95, 0.5]] + for i, color in enumerate(colors): + if pose.shape[1] > 6: # Using one-hot encoding for openpose. + idx = -3 if i == 0 else -2 + hand = (pose[:, idx] == 1).nonzero(as_tuple=False) + else: + raise ValueError('Not implemented yet.') + if hand.size(0): + y, x = hand[:, 1], hand[:, 2] + ys, ye, xs, xe = y.min().item(), y.max().item(), \ + x.min().item(), x.max().item() + xc, yc = (xs + xe) // 2, (ys + ye) // 2 + yc = max(ylen // 2, min(h - 1 - ylen // 2, yc)) + xc = max(xlen // 2, min(w - 1 - xlen // 2, xc)) + ys, ye, xs, xe = yc - ylen // 2, yc + ylen // 2, \ + xc - xlen // 2, xc + xlen // 2 + coords.append([ys, ye, xs, xe]) + return coords + + +def pre_process_densepose(pose_cfg, pose_map, is_infer=False): + r"""Pre-process the DensePose part of input label map. + + Args: + pose_cfg (obj): Pose data configuration. + pose_map (NxCxHxW tensor): Pose label map. + is_infer (bool): Is doing inference. 
+ Returns: + pose_map (NxCxHxW tensor): Processed pose label map. + """ + part_map = pose_map[:, :, 2] * 255 # should be within [0-24] + assert (part_map >= 0).all() and (part_map < 25).all() + + # Randomly drop some body part during training. + if not is_infer: + random_drop_prob = getattr(pose_cfg, 'random_drop_prob', 0) + else: + random_drop_prob = 0 + if random_drop_prob > 0: + densepose_map = pose_map[:, :, :3] + for part_id in range(1, 25): + if (random.random() < random_drop_prob): + part_mask = abs(part_map - part_id) < 0.1 + densepose_map[part_mask.unsqueeze(2).expand_as( + densepose_map)] = 0 + pose_map[:, :, :3] = densepose_map + + # Renormalize the DensePose channel from [0, 24] to [0, 255]. + pose_map[:, :, 2] = pose_map[:, :, 2] * (255 / 24) + # Normalize from [0, 1] to [-1, 1]. + pose_map = pose_map * 2 - 1 + return pose_map + + +def random_roll(tensors): + r"""Randomly roll the input tensors along x and y dimensions. Also randomly + flip the tensors. + + Args: + tensors (list of 4D tensors): Input tensors. + Returns: + output (list of 4D tensors): Rolled tensors. + """ + h, w = tensors[0].shape[2:] + ny = np.random.choice([np.random.randint(h//16), + h-np.random.randint(h//16)]) + nx = np.random.choice([np.random.randint(w//16), + w-np.random.randint(w//16)]) + flip = np.random.rand() > 0.5 + return [roll(t, ny, nx, flip) for t in tensors] + + +def roll(t, ny, nx, flip=False): + r"""Roll and flip the tensor by specified amounts. + + Args: + t (4D tensor): Input tensor. + ny (int): Amount to roll along y dimension. + nx (int): Amount to roll along x dimension. + flip (bool): Whether to flip input. + Returns: + t (4D tensor): Output tensor. + """ + t = torch.cat([t[:, :, -ny:], t[:, :, :-ny]], dim=2) + t = torch.cat([t[:, :, :, -nx:], t[:, :, :, :-nx]], dim=3) + if flip: + t = torch.flip(t, dims=[3]) + return t + + +def detach(output): + r"""Detach tensors in the dict. + + Args: + output (dict): Output dict. + Returns: + output (dict): Detached output dict. + """ + if type(output) == dict: + new_dict = dict() + for k, v in output.items(): + new_dict[k] = detach(v) + return new_dict + elif type(output) == torch.Tensor: + return output.detach() + return output diff --git a/imaginaire/model_utils/gancraft/camctl.py b/imaginaire/model_utils/gancraft/camctl.py new file mode 100644 index 0000000000000000000000000000000000000000..26e4ab674b7f6b44d484d06c661267ed7ce69d56 --- /dev/null +++ b/imaginaire/model_utils/gancraft/camctl.py @@ -0,0 +1,640 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +import numpy as np +import torch + + +class EvalCameraController: + def __init__(self, voxel, maxstep=128, pattern=0, cam_ang=73, smooth_decay_multiplier=1.0): + self.voxel = voxel + self.maxstep = maxstep + self.camera_poses = [] # ori, dir, up, f + circle = torch.linspace(0, 2*np.pi, steps=maxstep) + size = min(voxel.voxel_t.size(1), voxel.voxel_t.size(2)) / 2 + # Shrink the circle a bit. + shift = size * 0.2 + size = size * 0.8 + + if pattern == 0: + height_history = [] + # Calculate smooth height. 
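+            # First pass: sample the terrain height along the circular path, then
+            # low-pass filter it (forward/backward decay in filtfilt) so the camera
+            # height does not jitter; the smoothed heights are applied in the
+            # second pass below.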
+ for i in range(maxstep): + farpoint = torch.tensor([ + 70, + torch.sin(circle[i])*size + voxel.voxel_t.size(1)/2 + shift, + torch.cos(circle[i])*size + voxel.voxel_t.size(2)/2 + shift]) + height_history.append(self._get_height(farpoint[1], farpoint[2], farpoint[0])) + + # Filtfilt + height_history = self.filtfilt(height_history, decay=0.2*smooth_decay_multiplier) + + for i in range(maxstep): + farpoint = torch.tensor([ + 70, + torch.sin(circle[i])*size + voxel.voxel_t.size(1)/2 + shift, + torch.cos(circle[i])*size + voxel.voxel_t.size(2)/2 + shift]) + + farpoint[0] = height_history[i] + + nearpoint = torch.tensor([ + 60, + torch.sin(circle[i]+0.5*np.pi)*size*0.5 + voxel.voxel_t.size(1)/2 + shift, + torch.cos(circle[i]+0.5*np.pi)*size*0.5 + voxel.voxel_t.size(2)/2 + shift]) + cam_ori = self.voxel.world2local(farpoint) + cam_dir = self.voxel.world2local(nearpoint - farpoint, is_vec=True) + cam_up = self.voxel.world2local(torch.tensor([1, 0, 0], dtype=torch.float32), is_vec=True) + cam_f = 0.5/np.tan(np.deg2rad(cam_ang/2)) # about 24mm fov + + self.camera_poses.append((cam_ori, cam_dir, cam_up, cam_f)) + + elif pattern == 1: + zoom = torch.linspace(1.0, 0.25, steps=maxstep) + height_history = [] + for i in range(maxstep): + farpoint = torch.tensor([ + 90, + torch.sin(circle[i])*size + voxel.voxel_t.size(1)/2 + shift, + torch.cos(circle[i])*size + voxel.voxel_t.size(2)/2 + shift]) + + height_history.append(self._get_height(farpoint[1], farpoint[2], farpoint[0])) + + height_history = self.filtfilt(height_history, decay=0.2*smooth_decay_multiplier) + + for i in range(maxstep): + farpoint = torch.tensor([ + 90, + torch.sin(circle[i])*size + voxel.voxel_t.size(1)/2 + shift, + torch.cos(circle[i])*size + voxel.voxel_t.size(2)/2 + shift]) + + farpoint[0] = height_history[i] + + nearpoint = torch.tensor([ + 60, + torch.sin(circle[i]-0.3*np.pi)*size*0.3 + voxel.voxel_t.size(1)/2 + shift, + torch.cos(circle[i]-0.3*np.pi)*size*0.3 + voxel.voxel_t.size(2)/2 + shift]) + cam_ori = self.voxel.world2local(farpoint) + cam_dir = self.voxel.world2local(nearpoint - farpoint, is_vec=True) + cam_up = self.voxel.world2local(torch.tensor([1, 0, 0], dtype=torch.float32), is_vec=True) + cam_f = 0.5/np.tan(np.deg2rad(cam_ang/2)*zoom[i]) # about 24mm fov + + self.camera_poses.append((cam_ori, cam_dir, cam_up, cam_f)) + + elif pattern == 2: + move = torch.linspace(1.0, 0.2, steps=maxstep) + height_history = [] + for i in range(maxstep): + farpoint = torch.tensor([ + 90, + torch.sin(circle[i])*size*move[i] + voxel.voxel_t.size(1)/2 + shift, + torch.cos(circle[i])*size*move[i] + voxel.voxel_t.size(2)/2 + shift]) + + height_history.append(self._get_height(farpoint[1], farpoint[2], farpoint[0])) + + height_history = self.filtfilt(height_history, decay=0.2*smooth_decay_multiplier) + + for i in range(maxstep): + farpoint = torch.tensor([ + 90, + torch.sin(circle[i])*size*move[i] + voxel.voxel_t.size(1)/2 + shift, + torch.cos(circle[i])*size*move[i] + voxel.voxel_t.size(2)/2 + shift]) + + farpoint[0] = height_history[i] + + nearpoint = torch.tensor([ + 60, + torch.sin(circle[i]+0.5*np.pi)*size*0.3*move[i] + voxel.voxel_t.size(1)/2 + shift, + torch.cos(circle[i]+0.5*np.pi)*size*0.3*move[i] + voxel.voxel_t.size(2)/2 + shift]) + cam_ori = self.voxel.world2local(farpoint) + cam_dir = self.voxel.world2local(nearpoint - farpoint, is_vec=True) + cam_up = self.voxel.world2local(torch.tensor([1, 0, 0], dtype=torch.float32), is_vec=True) + cam_f = 0.5/np.tan(np.deg2rad(cam_ang/2)) # about 24mm fov + + self.camera_poses.append((cam_ori, 
cam_dir, cam_up, cam_f)) + + elif pattern == 3: + move = torch.linspace(0.75, 0.2, steps=maxstep) + height_history = [] + for i in range(maxstep): + farpoint = torch.tensor([ + 70, + torch.sin(-circle[i])*size*move[i] + voxel.voxel_t.size(1)/2 + shift, + torch.cos(-circle[i])*size*move[i] + voxel.voxel_t.size(2)/2 + shift]) + + height_history.append(self._get_height(farpoint[1], farpoint[2], farpoint[0])) + + height_history = self.filtfilt(height_history, decay=0.2*smooth_decay_multiplier) + + for i in range(maxstep): + farpoint = torch.tensor([ + 70, + torch.sin(-circle[i])*size*move[i] + voxel.voxel_t.size(1)/2 + shift, + torch.cos(-circle[i])*size*move[i] + voxel.voxel_t.size(2)/2 + shift]) + + farpoint[0] = height_history[i] + + nearpoint = torch.tensor([ + 60, + torch.sin(-circle[i]-0.4*np.pi)*size*0.9*move[i] + voxel.voxel_t.size(1)/2 + shift, + torch.cos(-circle[i]-0.4*np.pi)*size*0.9*move[i] + voxel.voxel_t.size(2)/2 + shift]) + cam_ori = self.voxel.world2local(farpoint) + cam_dir = self.voxel.world2local(nearpoint - farpoint, is_vec=True) + cam_up = self.voxel.world2local(torch.tensor([1, 0, 0], dtype=torch.float32), is_vec=True) + cam_f = 0.5/np.tan(np.deg2rad(cam_ang/2)) # about 24mm fov + + self.camera_poses.append((cam_ori, cam_dir, cam_up, cam_f)) + + elif pattern == 4: + move = torch.linspace(1.0, 0.5, steps=maxstep) + height_history = [] + for i in range(maxstep): + farpoint = torch.tensor([ + 90, + torch.sin(circle[i])*size*move[i] + voxel.voxel_t.size(1)/2 + shift, + torch.cos(circle[i])*size*move[i] + voxel.voxel_t.size(2)/2 + shift]) + + height_history.append(self._get_height(farpoint[1], farpoint[2], farpoint[0])) + + height_history = self.filtfilt(height_history, decay=0.2*smooth_decay_multiplier) + + for i in range(maxstep): + farpoint = torch.tensor([ + 90, + torch.sin(circle[i])*size*move[i] + voxel.voxel_t.size(1)/2 + shift, + torch.cos(circle[i])*size*move[i] + voxel.voxel_t.size(2)/2 + shift]) + + farpoint[0] = height_history[i] + + nearpoint = torch.tensor([ + 60, + torch.sin(circle[i]+0.5*np.pi)*size*0.3*move[i] + voxel.voxel_t.size(1)/2 + shift, + torch.cos(circle[i]+0.5*np.pi)*size*0.3*move[i] + voxel.voxel_t.size(2)/2 + shift]) + cam_ori = self.voxel.world2local(farpoint) + cam_dir = self.voxel.world2local(nearpoint - farpoint, is_vec=True) + cam_up = self.voxel.world2local(torch.tensor([1, 0, 0], dtype=torch.float32), is_vec=True) + cam_f = 0.5/np.tan(np.deg2rad(cam_ang/2)) # about 24mm fov + + self.camera_poses.append((cam_ori, cam_dir, cam_up, cam_f)) + + # look outward + elif pattern == 5: + move = torch.linspace(1.0, 0.5, steps=maxstep) + height_history = [] + for i in range(maxstep): + nearpoint = torch.tensor([ + 60, + torch.sin(circle[i]+0.5*np.pi)*size*0.3*move[i] + voxel.voxel_t.size(1)/2 + shift, + torch.cos(circle[i]+0.5*np.pi)*size*0.3*move[i] + voxel.voxel_t.size(2)/2 + shift]) + + height_history.append(self._get_height(nearpoint[1], nearpoint[2], nearpoint[0])) + + height_history = self.filtfilt(height_history, decay=0.2*smooth_decay_multiplier) + + for i in range(maxstep): + nearpoint = torch.tensor([ + 60, + torch.sin(circle[i]+0.5*np.pi)*size*0.3*move[i] + voxel.voxel_t.size(1)/2 + shift, + torch.cos(circle[i]+0.5*np.pi)*size*0.3*move[i] + voxel.voxel_t.size(2)/2 + shift]) + + nearpoint[0] = height_history[i] + + farpoint = torch.tensor([ + 60, + torch.sin(circle[i])*size*move[i] + voxel.voxel_t.size(1)/2 + shift, + torch.cos(circle[i])*size*move[i] + voxel.voxel_t.size(2)/2 + shift]) + + cam_ori = self.voxel.world2local(nearpoint) + 
cam_dir = self.voxel.world2local(farpoint - nearpoint, is_vec=True) + cam_up = self.voxel.world2local(torch.tensor([1, 0, 0], dtype=torch.float32), is_vec=True) + cam_f = 0.5/np.tan(np.deg2rad(cam_ang/2)) # about 24mm fov + + self.camera_poses.append((cam_ori, cam_dir, cam_up, cam_f)) + # Rise + elif pattern == 6: + shift = 0 + lift = torch.linspace(0.0, 200.0, steps=maxstep) + zoom = torch.linspace(0.8, 1.6, steps=maxstep) + for i in range(maxstep): + farpoint = torch.tensor([ + 80+lift[i], + torch.sin(circle[i]/4)*size*0.2 + voxel.voxel_t.size(1)/2 + shift, + torch.cos(circle[i]/4)*size*0.2 + voxel.voxel_t.size(2)/2 + shift]) + + farpoint[0] = self._get_height(farpoint[1], farpoint[2], farpoint[0]) + + nearpoint = torch.tensor([ + 65, + torch.sin(circle[i]/4+0.5*np.pi)*size*0.1 + voxel.voxel_t.size(1)/2 + shift, + torch.cos(circle[i]/4+0.5*np.pi)*size*0.1 + voxel.voxel_t.size(2)/2 + shift]) + cam_ori = self.voxel.world2local(farpoint) + cam_dir = self.voxel.world2local(nearpoint - farpoint, is_vec=True) + cam_up = self.voxel.world2local(torch.tensor([1, 0, 0], dtype=torch.float32), is_vec=True) + cam_f = 0.5/np.tan(np.deg2rad(73/2)*zoom[i]) # about 24mm fov + + self.camera_poses.append((cam_ori, cam_dir, cam_up, cam_f)) + # 45deg + elif pattern == 7: + rad = torch.tensor([np.deg2rad(45).astype(np.float32)]) + size = 1536 + for i in range(maxstep): + farpoint = torch.tensor([ + 61+size, + torch.sin(rad)*size + voxel.voxel_t.size(1)/2, + torch.cos(rad)*size + voxel.voxel_t.size(2)/2]) + + nearpoint = torch.tensor([ + 61, + voxel.voxel_t.size(1)/2, + voxel.voxel_t.size(2)/2]) + cam_ori = self.voxel.world2local(farpoint) + cam_dir = self.voxel.world2local(nearpoint - farpoint, is_vec=True) + cam_up = self.voxel.world2local(torch.tensor([1, 0, 0], dtype=torch.float32), is_vec=True) + cam_f = 0.5/np.tan(np.deg2rad(19.5/2)) # about 50mm fov + + self.camera_poses.append((cam_ori, cam_dir, cam_up, cam_f)) + + def _get_height(self, loc0, loc1, minheight): + loc0 = int(loc0) + loc1 = int(loc1) + height = minheight + for dx in range(-3, 4): + for dy in range(-3, 4): + if (loc0+dx) < 0 or (loc0+dx) >= self.voxel.heightmap.shape[0] or (loc1+dy) < 0 or \ + (loc1+dy) >= self.voxel.heightmap.shape[1]: + height = max(height, minheight) + else: + height = max(height, self.voxel.heightmap[loc0+dx, loc1+dy] + 2) + return height + + def filtfilt(self, height_history, decay=0.2): + # Filtfilt + height_history2 = [] + maxstep = len(height_history) + prev_height = height_history[0] + for i in range(maxstep): + prev_height = prev_height - decay + if prev_height < height_history[i]: + prev_height = height_history[i] + height_history2.append(prev_height) + prev_height = height_history[-1] + for i in range(maxstep-1, -1, -1): + prev_height = prev_height - decay + if prev_height < height_history[i]: + prev_height = height_history[i] + height_history2[i] = max(prev_height, height_history2[i]) + return height_history2 + + def __len__(self): + return len(self.camera_poses) + + def __getitem__(self, idx): + return self.camera_poses[idx] + + +class TourCameraController: + def __init__(self, voxel, maxstep=128): + self.voxel = voxel + self.maxstep = maxstep + self.camera_poses = [] # ori, dir, up, f + circle = torch.linspace(0, 2*np.pi, steps=maxstep//4) + size = min(voxel.voxel_t.size(1), voxel.voxel_t.size(2)) / 2 + # Shrink the circle a bit + shift = size * 0.2 + size = size * 0.8 + + for i in range(maxstep//4): + farpoint = torch.tensor([ + 70, + torch.sin(circle[i])*size + voxel.voxel_t.size(1)/2 + shift, + 
torch.cos(circle[i])*size + voxel.voxel_t.size(2)/2 + shift]) + + farpoint[0] = self._get_height(farpoint[1], farpoint[2], farpoint[0]) + + nearpoint = torch.tensor([ + 60, + torch.sin(circle[i]+0.5*np.pi)*size*0.5 + voxel.voxel_t.size(1)/2 + shift, + torch.cos(circle[i]+0.5*np.pi)*size*0.5 + voxel.voxel_t.size(2)/2 + shift]) + cam_ori = self.voxel.world2local(farpoint) + cam_dir = self.voxel.world2local(nearpoint - farpoint, is_vec=True) + cam_up = self.voxel.world2local(torch.tensor([1, 0, 0], dtype=torch.float32), is_vec=True) + cam_f = 0.5/np.tan(np.deg2rad(73/2)) # about 24mm fov + + self.camera_poses.append((cam_ori, cam_dir, cam_up, cam_f)) + + zoom = torch.linspace(1.0, 0.25, steps=maxstep//4) + for i in range(maxstep//4): + farpoint = torch.tensor([ + 90, + torch.sin(circle[i])*size + voxel.voxel_t.size(1)/2 + shift, + torch.cos(circle[i])*size + voxel.voxel_t.size(2)/2 + shift]) + + farpoint[0] = self._get_height(farpoint[1], farpoint[2], farpoint[0]) + + nearpoint = torch.tensor([ + 60, + torch.sin(circle[i]-0.3*np.pi)*size*0.3 + voxel.voxel_t.size(1)/2 + shift, + torch.cos(circle[i]-0.3*np.pi)*size*0.3 + voxel.voxel_t.size(2)/2 + shift]) + cam_ori = self.voxel.world2local(farpoint) + cam_dir = self.voxel.world2local(nearpoint - farpoint, is_vec=True) + cam_up = self.voxel.world2local(torch.tensor([1, 0, 0], dtype=torch.float32), is_vec=True) + cam_f = 0.5/np.tan(np.deg2rad(73/2)*zoom[i]) # about 24mm fov + + self.camera_poses.append((cam_ori, cam_dir, cam_up, cam_f)) + + move = torch.linspace(1.0, 0.2, steps=maxstep//4) + for i in range(maxstep//4): + farpoint = torch.tensor([ + 90, + torch.sin(circle[i])*size*move[i] + voxel.voxel_t.size(1)/2 + shift, + torch.cos(circle[i])*size*move[i] + voxel.voxel_t.size(2)/2 + shift]) + + farpoint[0] = self._get_height(farpoint[1], farpoint[2], farpoint[0]) + + nearpoint = torch.tensor([ + 60, + torch.sin(circle[i]+0.5*np.pi)*size*0.3*move[i] + voxel.voxel_t.size(1)/2 + shift, + torch.cos(circle[i]+0.5*np.pi)*size*0.3*move[i] + voxel.voxel_t.size(2)/2 + shift]) + cam_ori = self.voxel.world2local(farpoint) + cam_dir = self.voxel.world2local(nearpoint - farpoint, is_vec=True) + cam_up = self.voxel.world2local(torch.tensor([1, 0, 0], dtype=torch.float32), is_vec=True) + cam_f = 0.5/np.tan(np.deg2rad(73/2)) # about 24mm fov + + self.camera_poses.append((cam_ori, cam_dir, cam_up, cam_f)) + + lift = torch.linspace(0.0, 200.0, steps=maxstep//4) + zoom = torch.linspace(0.6, 1.2, steps=maxstep//4) + for i in range(maxstep//4): + farpoint = torch.tensor([ + 80+lift[i], + torch.sin(circle[i])*size*0.2 + voxel.voxel_t.size(1)/2 + shift, + torch.cos(circle[i])*size*0.2 + voxel.voxel_t.size(2)/2 + shift]) + + farpoint[0] = self._get_height(farpoint[1], farpoint[2], farpoint[0]) + + nearpoint = torch.tensor([ + 60, + torch.sin(circle[i]+0.5*np.pi)*size*0.1 + voxel.voxel_t.size(1)/2 + shift, + torch.cos(circle[i]+0.5*np.pi)*size*0.1 + voxel.voxel_t.size(2)/2 + shift]) + cam_ori = self.voxel.world2local(farpoint) + cam_dir = self.voxel.world2local(nearpoint - farpoint, is_vec=True) + cam_up = self.voxel.world2local(torch.tensor([1, 0, 0], dtype=torch.float32), is_vec=True) + cam_f = 0.5/np.tan(np.deg2rad(73/2)*zoom[i]) # about 24mm fov + + self.camera_poses.append((cam_ori, cam_dir, cam_up, cam_f)) + + def _get_height(self, loc0, loc1, minheight): + loc0 = int(loc0) + loc1 = int(loc1) + height = minheight + for dx in range(-3, 4): + for dy in range(-3, 4): + if (loc0+dx) < 0 or (loc0+dx) >= self.voxel.heightmap.shape[0] or (loc1+dy) < 0 or \ + (loc1+dy) >= 
self.voxel.heightmap.shape[1]: + height = max(height, minheight) + else: + height = max(height, self.voxel.heightmap[loc0+dx, loc1+dy] + 2) + return height + + def __len__(self): + return len(self.camera_poses) + + def __getitem__(self, idx): + return self.camera_poses[idx] + + +def rand_camera_pose_birdseye(voxel, border=128): + r"""Generating random camera pose in the upper hemisphere, in the format of origin-direction-up + Assuming [Y X Z] coordinate. Y is negative gravity direction. + The camera pose is converted into the voxel coordinate system so that it can be used directly for rendering + 1. Uniformly sample a point on the upper hemisphere of a unit sphere, as cam_ori. + 2. Set cam_dir to be from cam_ori to the origin + 3. cam_up is always pointing towards sky + 4. move cam_ori to random place according to voxel size + """ + cam_dir = torch.randn(3, dtype=torch.float32) + cam_dir = cam_dir / torch.sqrt(torch.sum(cam_dir*cam_dir)) + cam_dir[0] = -torch.abs(cam_dir[0]) + cam_up = torch.tensor([1, 0, 0], dtype=torch.float32) + + # generate camera lookat target + r = np.random.rand(2) + r[0] *= voxel.voxel_t.size(1)-border-border + r[1] *= voxel.voxel_t.size(2)-border-border + r = r + border + y = voxel.heightmap[int(r[0]+0.5), int(r[1]+0.5)] + (np.random.rand(1)-0.5) * 5 + cam_target = torch.tensor([y, r[0], r[1]], dtype=torch.float32) + cam_ori = cam_target - cam_dir * (np.random.rand(1).item() * 100) + cam_ori[0] = max(voxel.heightmap[int(cam_ori[1]+0.5), int(cam_ori[2]+0.5)]+2, cam_ori[0]) + # Translate to voxel coordinate + cam_ori = voxel.world2local(cam_ori) + cam_dir = voxel.world2local(cam_dir, is_vec=True) + cam_up = voxel.world2local(cam_up, is_vec=True) + + return cam_ori, cam_dir, cam_up + + +def get_neighbor_height(heightmap, loc0, loc1, minheight, neighbor_size=7): + loc0 = int(loc0) + loc1 = int(loc1) + height = 0 + for dx in range(-neighbor_size//2, neighbor_size//2+1): + for dy in range(-neighbor_size//2, neighbor_size//2+1): + if (loc0+dx) < 0 or (loc0+dx) >= heightmap.shape[0] or (loc1+dy) < 0 or (loc1+dy) >= heightmap.shape[1]: + height = max(height, minheight) + else: + height = max(minheight, heightmap[loc0+dx, loc1+dy] + 2) + return height + + +def rand_camera_pose_firstperson(voxel, border=128): + r"""Generating random camera pose in the upper hemisphere, in the format of origin-direction-up + """ + r = np.random.rand(5) + r[0] *= voxel.voxel_t.size(1)-border-border + r[1] *= voxel.voxel_t.size(2)-border-border + r[0] = r[0] + border + r[1] = r[1] + border + + y = get_neighbor_height(voxel.heightmap, r[0], r[1], 0) + np.random.rand(1) * 15 + + cam_ori = torch.tensor([y, r[0], r[1]], dtype=torch.float32) + + rand_ang_h = r[2] * 2 * np.pi + cam_target = torch.tensor([0, cam_ori[1]+np.sin(rand_ang_h)*border*r[4], cam_ori[2] + + np.cos(rand_ang_h)*border*r[4]], dtype=torch.float32) + cam_target[0] = get_neighbor_height(voxel.heightmap, cam_target[1], + cam_target[2], 0, neighbor_size=1) - 2 + r[3] * 10 + + cam_dir = cam_target - cam_ori + + cam_up = torch.tensor([1, 0, 0], dtype=torch.float32) + + cam_ori = voxel.world2local(cam_ori) + cam_dir = voxel.world2local(cam_dir, is_vec=True) + cam_up = voxel.world2local(cam_up, is_vec=True) + + return cam_ori, cam_dir, cam_up + + +def rand_camera_pose_thridperson(voxel, border=96): + r = torch.rand(2) + r[0] *= voxel.voxel_t.size(1) + r[1] *= voxel.voxel_t.size(2) + rand_height = 60 + torch.rand(1) * 40 + rand_height = get_neighbor_height(voxel.heightmap, r[0], r[1], rand_height, neighbor_size=5) + farpoint = 
torch.tensor([rand_height, r[0], r[1]], dtype=torch.float32) + + r = torch.rand(2) + r[0] *= voxel.voxel_t.size(1) - border - border + r[1] *= voxel.voxel_t.size(2) - border - border + r[0] = r[0] + border + r[1] = r[1] + border + rand_height = get_neighbor_height(voxel.heightmap, r[0], r[1], 65, neighbor_size=1) - 5 + nearpoint = torch.tensor([rand_height, r[0], r[1]], dtype=torch.float32) + + cam_ori = voxel.world2local(farpoint) + cam_dir = voxel.world2local(nearpoint - farpoint, is_vec=True) + cam_up = voxel.world2local(torch.tensor([1, 0, 0], dtype=torch.float32), is_vec=True) + + return cam_ori, cam_dir, cam_up + + +def rand_camera_pose_thridperson2(voxel, border=48): + r = torch.rand(2) + r[0] *= voxel.voxel_t.size(1) - border - border + r[1] *= voxel.voxel_t.size(2) - border - border + r[0] = r[0] + border + r[1] = r[1] + border + rand_height = 60 + torch.rand(1) * 40 + rand_height = get_neighbor_height(voxel.heightmap, r[0], r[1], rand_height, neighbor_size=5) + farpoint = torch.tensor([rand_height, r[0], r[1]], dtype=torch.float32) + + r = torch.rand(2) + r[0] *= voxel.voxel_t.size(1) - border - border + r[1] *= voxel.voxel_t.size(2) - border - border + r[0] = r[0] + border + r[1] = r[1] + border + rand_height = get_neighbor_height(voxel.heightmap, r[0], r[1], 65, neighbor_size=1) - 5 + nearpoint = torch.tensor([rand_height, r[0], r[1]], dtype=torch.float32) + + # Random Up vector (tilt a little bit) + # up = torch.randn(3) * 0.05 # cutoff +-0.1, Tan(10deg) = 0.176 + up = torch.randn(3) * 0.02 + up[0] = 1.0 + up = up / up.norm(p=2) + cam_ori = voxel.world2local(farpoint) + cam_dir = voxel.world2local(nearpoint - farpoint, is_vec=True) + cam_up = voxel.world2local(up, is_vec=True) + + return cam_ori, cam_dir, cam_up + + +def rand_camera_pose_thridperson3(voxel, border=64): + r"""Attempting to solve the camera too close to wall problem and the lack of aerial poses.""" + r = torch.rand(2) + r[0] *= voxel.voxel_t.size(1) - border - border + r[1] *= voxel.voxel_t.size(2) - border - border + r[0] = r[0] + border + r[1] = r[1] + border + rand_height = 60 + torch.rand(1) * 40 + if torch.rand(1) > 0.8: + rand_height = 60 + torch.rand(1) * 60 + rand_height = get_neighbor_height(voxel.heightmap, r[0], r[1], rand_height, neighbor_size=7) + farpoint = torch.tensor([rand_height, r[0], r[1]], dtype=torch.float32) + + r = torch.rand(2) + r[0] *= voxel.voxel_t.size(1) - border - border + r[1] *= voxel.voxel_t.size(2) - border - border + r[0] = r[0] + border + r[1] = r[1] + border + rand_height = get_neighbor_height(voxel.heightmap, r[0], r[1], 65, neighbor_size=3) - 5 + nearpoint = torch.tensor([rand_height, r[0], r[1]], dtype=torch.float32) + + # Random Up vector (tilt a little bit) + # up = torch.randn(3) * 0.05 # cutoff +-0.1, Tan(10deg) = 0.176 + up = torch.randn(3) * 0.02 + up[0] = 1.0 + up = up / up.norm(p=2) + # print(up) + cam_ori = voxel.world2local(farpoint) + cam_dir = voxel.world2local(nearpoint - farpoint, is_vec=True) + cam_up = voxel.world2local(up, is_vec=True) + + return cam_ori, cam_dir, cam_up + + +def rand_camera_pose_tour(voxel): + size = min(voxel.voxel_t.size(1), voxel.voxel_t.size(2)) / 2 + center = [voxel.voxel_t.size(1)/2, voxel.voxel_t.size(2)/2] + + rnd = torch.rand(8) + + rnd_deg = torch.rand(1) * 2 * np.pi + far_radius = rnd[0]*0.8+0.2 + far_height = rnd[1]*30 + 60 + farpoint = torch.tensor([ + far_height, + torch.sin(rnd_deg)*size*far_radius + center[0], + torch.cos(rnd_deg)*size*far_radius + center[1]]) + + farpoint[0] = get_neighbor_height(voxel.heightmap, 
farpoint[1], farpoint[2], farpoint[0], neighbor_size=7) + + near_radius = far_radius * rnd[2] + near_shift_rad = np.pi*(rnd[3]-0.5) + near_height = 60 + rnd[4] * 10 + nearpoint = torch.tensor([ + near_height, + torch.sin(rnd_deg+near_shift_rad)*size*near_radius + center[0], + torch.cos(rnd_deg+near_shift_rad)*size*near_radius + center[1]]) + + # Random Up vector (tilt a little bit) + # up = torch.randn(3) * 0.05 # cutoff +-0.1, Tan(10deg) = 0.176 + up = torch.randn(3) * 0.02 + up[0] = 1.0 + up = up / up.norm(p=2) + cam_ori = voxel.world2local(farpoint) + cam_dir = voxel.world2local(nearpoint - farpoint, is_vec=True) + cam_up = voxel.world2local(up, is_vec=True) + cam_f = 0.5/np.tan(np.deg2rad(73/2)*(rnd[5]*0.75+0.25)) # about 24mm fov + + return cam_ori, cam_dir, cam_up, cam_f + +# Look from center to outward + + +def rand_camera_pose_insideout(voxel): + size = min(voxel.voxel_t.size(1), voxel.voxel_t.size(2)) / 2 + center = [voxel.voxel_t.size(1)/2, voxel.voxel_t.size(2)/2] + + rnd = torch.rand(8) + + rnd_deg = torch.rand(1) * 2 * np.pi + far_radius = rnd[0]*0.8+0.2 + far_height = rnd[1]*10 + 60 + farpoint = torch.tensor([ + far_height, + torch.sin(rnd_deg)*size*far_radius + center[0], + torch.cos(rnd_deg)*size*far_radius + center[1]]) + + near_radius = far_radius * rnd[2] + near_shift_rad = np.pi*(rnd[3]-0.5) + near_height = 60 + rnd[4] * 30 + nearpoint = torch.tensor([ + near_height, + torch.sin(rnd_deg+near_shift_rad)*size*near_radius + center[0], + torch.cos(rnd_deg+near_shift_rad)*size*near_radius + center[1]]) + + nearpoint[0] = get_neighbor_height(voxel.heightmap, nearpoint[1], nearpoint[2], nearpoint[0], neighbor_size=7) + + # Random Up vector (tilt a little bit) + # up = torch.randn(3) * 0.05 # cutoff +-0.1, Tan(10deg) = 0.176 + up = torch.randn(3) * 0.02 + up[0] = 1.0 + up = up / up.norm(p=2) + cam_ori = voxel.world2local(nearpoint) + cam_dir = voxel.world2local(farpoint-nearpoint, is_vec=True) + cam_up = voxel.world2local(up, is_vec=True) + cam_f = 0.5/np.tan(np.deg2rad(73/2)*(rnd[5]*0.75+0.25)) # about 24mm fov + + return cam_ori, cam_dir, cam_up, cam_f diff --git a/imaginaire/model_utils/gancraft/gaugan_lbl2col.csv b/imaginaire/model_utils/gancraft/gaugan_lbl2col.csv new file mode 100644 index 0000000000000000000000000000000000000000..ba061b7a5bda98899c8b1f653d5f204fccfa38e4 --- /dev/null +++ b/imaginaire/model_utils/gancraft/gaugan_lbl2col.csv @@ -0,0 +1,182 @@ +person,#00AC0D +bicycle,#012F47 +car,#0275B8 +motorcycle,#03C098 +airplane,#04434F +bus,#05FB29 +train,#06C312 +truck,#076728 +boat,#0809B6 +traffic-light,#09D3CF +fire-hydrant,#0A150B +street-sign,#0BF2A6 +stop-sign,#0C246F +parking-meter,#0D575D +bench,#0E46F9 +bird,#0FD881 +cat,#1058DF +dog,#118C76 +horse,#123A2C +sheep,#13C1D8 +cow,#14E67D +elephant,#152718 +bear,#165743 +zebra,#17AED2 +giraffe,#1858EF +hat,#195103 +backpack,#1AA5EA +umbrella,#1B19CC +shoe,#1C4DE6 +eye-glasses,#1D4823 +handbag,#1E09D6 +tie,#1F94FE +suitcase,#2073BD +frisbee,#21D0C5 +skis,#22F3D7 +snowboard,#23C52B +sports-ball,#24FE20 +kite,#254F0B +baseball-bat,#26AF68 +baseball-glove,#27C0D4 +skateboard,#28528A +surfboard,#2963B6 +tennis-racket,#2AD8EB +bottle,#2BB1A5 +plate,#2CF37D +wine-glass,#2D1D9C +cup,#2E936F +fork,#2F93E8 +knife,#308E02 +spoon,#31A71B +bowl,#3220D3 +banana,#33C1D9 +apple,#340997 +sandwich,#35B935 +orange,#367F33 +broccoli,#3720AE +carrot,#381F94 +hot-dog,#39CAB5 +pizza,#3AF41D +donut,#3B9743 +cake,#3CA323 +chair,#3DFE27 +couch,#3ECB89 +potted-plant,#3F7249 +bed,#40B729 +mirror,#411C97 +dining-table,#422283 
+window,#43802E +desk,#4480DA +toilet,#45A4B2 +door,#46356C +tv,#478503 +laptop,#48261F +mouse,#49E809 +remote,#4AF48A +keyboard,#4B111B +cell-phone,#4C4FAD +microwave,#4D84C7 +oven,#4E69A7 +toaster,#4F2A3D +sink,#50BA55 +refrigerator,#511F61 +blender,#52782C +book,#530122 +clock,#5441A2 +vase,#55E758 +scissors,#56A921 +teddy-bear,#573985 +hair-drier,#5823E8 +toothbrush,#5966FF +hair-brush,#5A7724 +banner,#5B0B00 +blanket,#5CAECB +branch,#5D5222 +bridge,#5E5BC5 +building-other,#5F807E +bush,#606E32 +cabinet,#6163FE +cage,#623550 +cardboard,#638CBE +carpet,#647988 +ceiling-other,#65AABD +ceiling-tile,#665481 +cloth,#67CBD1 +clothes,#684470 +clouds,#696969 +counter,#6AC478 +cupboard,#6B2F5B +curtain,#6C7FA8 +desk-stuff,#6DF474 +dirt,#6E6E28 +door-stuff,#6FCCB0 +fence,#706419 +floor-marble,#71B443 +floor-other,#72E867 +floor-stone,#734EFC +floor-tile,#748F23 +floor-wood,#759472 +flower,#760000 +fog,#77BA1D +food-other,#7817F1 +fruit,#79CF21 +furniture-other,#7A8D92 +grass,#7BC800 +gravel,#7C32C8 +ground-other,#7D3054 +hill,#7EC864 +house,#7F4502 +leaves,#80A945 +light,#81A365 +mat,#82C08C +metal,#835F2C +mirror-stuff,#84C575 +moss,#855EFD +mountain,#869664 +mud,#87716F +napkin,#88B25B +net,#892455 +paper,#8AA2A7 +pavement,#8B3027 +pillow,#8C5DCB +plant,#8DE61E +plastic,#8E629E +platform,#8F2A91 +playingfield,#90CDC6 +railing,#9170C7 +railroad,#92E712 +river,#9364C8 +road,#946E28 +rock,#956432 +roof,#9600B1 +rug,#978A29 +salad,#98725D +sand,#999900 +sea,#9AC6DA +shelf,#9B7FC9 +sky,#9CEEDD +skyscraper,#9DBBF2 +snow,#9E9EAA +solid-other,#9F79DB +stairs,#A06249 +stone,#A1A164 +straw,#A2A3EB +structural,#A3DED1 +table,#A47B69 +tent,#A5C3BA +textile-other,#A65280 +towel,#A7AED6 +tree,#A8C832 +vegetable,#A99410 +wall-brick,#AAD16A +wall-concrete,#AB32A4 +wall-other,#AC9B5E +wall-panel,#AD0E18 +wall-stone,#AE2974 +wall-tile,#AF3ABF +wall-wood,#B0C1C3 +water,#B1C8FF +waterdrops,#B20A88 +window-blind,#B356B8 +window-other,#B42B5B +wood,#B57B00 diff --git a/imaginaire/model_utils/gancraft/gaugan_reduction.csv b/imaginaire/model_utils/gancraft/gaugan_reduction.csv new file mode 100644 index 0000000000000000000000000000000000000000..a49ad38c2f8bddd04a12a08a09b1bbdab5944fc2 --- /dev/null +++ b/imaginaire/model_utils/gancraft/gaugan_reduction.csv @@ -0,0 +1,182 @@ +person,ignore +bicycle,ignore +car,ignore +motorcycle,ignore +airplane,ignore +bus,ignore +train,ignore +truck,ignore +boat,ignore +traffic-light,ignore +fire-hydrant,ignore +street-sign,ignore +stop-sign,ignore +parking-meter,ignore +bench,ignore +bird,ignore +cat,ignore +dog,ignore +horse,ignore +sheep,ignore +cow,ignore +elephant,ignore +bear,ignore +zebra,ignore +giraffe,ignore +hat,ignore +backpack,ignore +umbrella,ignore +shoe,ignore +eye-glasses,ignore +handbag,ignore +tie,ignore +suitcase,ignore +frisbee,ignore +skis,ignore +snowboard,ignore +sports-ball,ignore +kite,ignore +baseball-bat,ignore +baseball-glove,ignore +skateboard,ignore +surfboard,ignore +tennis-racket,ignore +bottle,ignore +plate,ignore +wine-glass,ignore +cup,ignore +fork,ignore +knife,ignore +spoon,ignore +bowl,ignore +banana,ignore +apple,ignore +sandwich,ignore +orange,ignore +broccoli,ignore +carrot,ignore +hot-dog,ignore +pizza,ignore +donut,ignore +cake,ignore +chair,ignore +couch,ignore +potted-plant,ignore +bed,ignore +mirror,ignore +dining-table,ignore +window,ignore +desk,ignore +toilet,ignore +door,ignore +tv,ignore +laptop,ignore +mouse,ignore +remote,ignore +keyboard,ignore +cell-phone,ignore +microwave,ignore +oven,ignore +toaster,ignore +sink,ignore 
+refrigerator,ignore +blender,ignore +book,ignore +clock,ignore +vase,ignore +scissors,ignore +teddy-bear,ignore +hair-drier,ignore +toothbrush,ignore +hair-brush,ignore +banner,ignore +blanket,ignore +branch,tree +bridge,ignore +building-other,ignore +bush,tree +cabinet,ignore +cage,ignore +cardboard,ignore +carpet,ignore +ceiling-other,ignore +ceiling-tile,ignore +cloth,ignore +clothes,ignore +clouds,sky +counter,ignore +cupboard,ignore +curtain,ignore +desk-stuff,ignore +dirt,dirt +door-stuff,ignore +fence,ignore +floor-marble,ignore +floor-other,ignore +floor-stone,ignore +floor-tile,ignore +floor-wood,ignore +flower,flower +fog,sky +food-other,ignore +fruit,ignore +furniture-other,ignore +grass,grass +gravel,gravel +ground-other,ignore +hill,grass +house,ignore +leaves,tree +light,ignore +mat,ignore +metal,ignore +mirror-stuff,ignore +moss,grass +mountain,grass +mud,dirt +napkin,ignore +net,ignore +paper,ignore +pavement,ignore +pillow,ignore +plant,flower +plastic,ignore +platform,ignore +playingfield,ignore +railing,ignore +railroad,ignore +river,water +road,ignore +rock,rock +roof,ignore +rug,ignore +salad,ignore +sand,sand +sea,water +shelf,ignore +sky,sky +skyscraper,ignore +snow,snow +solid-other,ignore +stairs,ignore +stone,stone +straw,grass +structural,ignore +table,ignore +tent,ignore +textile-other,ignore +towel,ignore +tree,tree +vegetable,ignore +wall-brick,ignore +wall-concrete,ignore +wall-other,ignore +wall-panel,ignore +wall-stone,ignore +wall-tile,ignore +wall-wood,ignore +water,water +waterdrops,ignore +window-blind,ignore +window-other,ignore +wood,ignore diff --git a/imaginaire/model_utils/gancraft/id2name_gg.csv b/imaginaire/model_utils/gancraft/id2name_gg.csv new file mode 100644 index 0000000000000000000000000000000000000000..bb52afe4132cdae36494c08dab6ac4982f572386 --- /dev/null +++ b/imaginaire/model_utils/gancraft/id2name_gg.csv @@ -0,0 +1,680 @@ +0,air,0,sky +1,stone,7368816,stone +2,granite,7368816,rock +3,polished_granite,7368816,rock +4,diorite,7368816,rock +5,polished_diorite,7368816,rock +6,andesite,7368816,rock +7,polished_andesite,7368816,rock +8,grass_block,8368696,grass +9,dirt,9923917,dirt +10,coarse_dirt,9923917,dirt +11,podzol,9923917,dirt +12,cobblestone,7368816,stone +13,oak_planks,9402184,wood +14,spruce_planks,9402184,wood +15,birch_planks,9402184,wood +16,jungle_planks,9402184,wood +17,acacia_planks,9402184,wood +18,dark_oak_planks,9402184,wood +19,oak_sapling,31744,plant +20,spruce_sapling,31744,plant +21,birch_sapling,31744,plant +22,jungle_sapling,31744,plant +23,acacia_sapling,31744,plant +24,dark_oak_sapling,31744,plant +25,bedrock,7368816,rock +26,water,4210943,water +27,lava,16711680, +28,sand,16247203,sand +29,red_sand,16247203,sand +30,gravel,16247203,gravel +31,gold_ore,7368816,rock +32,iron_ore,7368816,rock +33,coal_ore,7368816,rock +34,oak_log,9402184,tree +35,spruce_log,9402184,tree +36,birch_log,9402184,tree +37,jungle_log,9402184,tree +38,acacia_log,9402184,tree +39,dark_oak_log,9402184,tree +40,stripped_spruce_log,9402184,wood +41,stripped_birch_log,9402184,wood +42,stripped_jungle_log,9402184,wood +43,stripped_acacia_log,9402184,wood +44,stripped_dark_oak_log,9402184,wood +45,stripped_oak_log,9402184,wood +46,oak_wood,9402184,wood +47,spruce_wood,9402184,wood +48,birch_wood,9402184,wood +49,jungle_wood,9402184,wood +50,acacia_wood,9402184,wood +51,dark_oak_wood,9402184,wood +52,stripped_oak_wood,9402184,wood +53,stripped_spruce_wood,9402184,wood +54,stripped_birch_wood,9402184,wood +55,stripped_jungle_wood,9402184,wood 
+56,stripped_acacia_wood,9402184,wood +57,stripped_dark_oak_wood,9402184,wood +58,oak_leaves,31744,tree +59,spruce_leaves,31744,tree +60,birch_leaves,31744,tree +61,jungle_leaves,31744,tree +62,acacia_leaves,31744,tree +63,dark_oak_leaves,31744,tree +64,sponge,15066419, +65,wet_sponge,15066419, +66,glass,0, +67,lapis_ore,7368816, +68,lapis_block,10987431, +69,dispenser,7368816, +70,sandstone,7368816,sand +71,chiseled_sandstone,7368816,sand +72,cut_sandstone,7368816,sand +73,note_block,9402184, +74,white_bed,13092807, +75,orange_bed,13092807, +76,magenta_bed,13092807, +77,light_blue_bed,13092807, +78,yellow_bed,13092807, +79,lime_bed,13092807, +80,pink_bed,13092807, +81,gray_bed,13092807, +82,light_gray_bed,13092807, +83,cyan_bed,13092807, +84,purple_bed,13092807, +85,blue_bed,13092807, +86,brown_bed,13092807, +87,green_bed,13092807, +88,red_bed,13092807, +89,black_bed,13092807, +90,powered_rail,0, +91,detector_rail,0, +92,sticky_piston,7368816, +93,cobweb,13092807, +94,grass,31744,grass +95,fern,31744,grass +96,dead_bush,31744,grass +97,seagrass,4210943,water +98,tall_seagrass,4210943,water +99,piston,7368816, +100,piston_head,7368816, +101,white_wool,13092807, +102,orange_wool,13092807, +103,magenta_wool,13092807, +104,light_blue_wool,13092807, +105,yellow_wool,13092807, +106,lime_wool,13092807, +107,pink_wool,13092807, +108,gray_wool,13092807, +109,light_gray_wool,13092807, +110,cyan_wool,13092807, +111,purple_wool,13092807, +112,blue_wool,13092807, +113,brown_wool,13092807, +114,green_wool,13092807, +115,red_wool,13092807, +116,black_wool,13092807, +117,moving_piston,7368816, +118,dandelion,31744,flower +119,poppy,31744,flower +120,blue_orchid,31744,flower +121,allium,31744,flower +122,azure_bluet,31744,flower +123,red_tulip,31744,flower +124,orange_tulip,31744,flower +125,white_tulip,31744,flower +126,pink_tulip,31744,flower +127,oxeye_daisy,31744,flower +128,cornflower,31744,flower +129,wither_rose,31744,flower +130,lily_of_the_valley,31744,flower +131,brown_mushroom,31744,flower +132,red_mushroom,31744,flower +133,gold_block,10987431, +134,iron_block,10987431, +135,bricks,7368816, +136,tnt,16711680, +137,bookshelf,9402184, +138,mossy_cobblestone,7368816, +139,obsidian,7368816, +140,torch,0, +141,wall_torch,0, +142,fire,0, +143,spawner,7368816, +144,oak_stairs,9402184, +145,chest,9402184, +146,redstone_wire,0, +147,diamond_ore,7368816, +148,diamond_block,10987431, +149,crafting_table,9402184, +150,wheat,31744, +151,farmland,9923917, +152,furnace,7368816, +153,oak_sign,9402184, +154,spruce_sign,9402184, +155,birch_sign,9402184, +156,acacia_sign,9402184, +157,jungle_sign,9402184, +158,dark_oak_sign,9402184, +159,oak_door,9402184, +160,ladder,0, +161,rail,0, +162,cobblestone_stairs,7368816, +163,oak_wall_sign,9402184, +164,spruce_wall_sign,9402184, +165,birch_wall_sign,9402184, +166,acacia_wall_sign,9402184, +167,jungle_wall_sign,9402184, +168,dark_oak_wall_sign,9402184, +169,lever,0, +170,stone_pressure_plate,7368816, +171,iron_door,10987431, +172,oak_pressure_plate,9402184, +173,spruce_pressure_plate,9402184, +174,birch_pressure_plate,9402184, +175,jungle_pressure_plate,9402184, +176,acacia_pressure_plate,9402184, +177,dark_oak_pressure_plate,9402184, +178,redstone_ore,7368816, +179,redstone_torch,0, +180,redstone_wall_torch,0, +181,stone_button,0, +182,snow,16777215,snow +183,ice,10526975,snow +184,snow_block,16777215,snow +185,cactus,31744,plant +186,clay,10791096, +187,sugar_cane,31744,plant +188,jukebox,9402184, +189,oak_fence,9402184, +190,pumpkin,31744, +191,netherrack,7368816, 
+192,soul_sand,16247203, +193,glowstone,0, +194,nether_portal,0, +195,carved_pumpkin,31744, +196,jack_o_lantern,31744, +197,cake,0, +198,repeater,0, +199,white_stained_glass,0, +200,orange_stained_glass,0, +201,magenta_stained_glass,0, +202,light_blue_stained_glass,0, +203,yellow_stained_glass,0, +204,lime_stained_glass,0, +205,pink_stained_glass,0, +206,gray_stained_glass,0, +207,light_gray_stained_glass,0, +208,cyan_stained_glass,0, +209,purple_stained_glass,0, +210,blue_stained_glass,0, +211,brown_stained_glass,0, +212,green_stained_glass,0, +213,red_stained_glass,0, +214,black_stained_glass,0, +215,oak_trapdoor,9402184, +216,spruce_trapdoor,9402184, +217,birch_trapdoor,9402184, +218,jungle_trapdoor,9402184, +219,acacia_trapdoor,9402184, +220,dark_oak_trapdoor,9402184, +221,stone_bricks,7368816, +222,mossy_stone_bricks,7368816, +223,cracked_stone_bricks,7368816, +224,chiseled_stone_bricks,7368816, +225,infested_stone,10791096, +226,infested_cobblestone,10791096, +227,infested_stone_bricks,10791096, +228,infested_mossy_stone_bricks,10791096, +229,infested_cracked_stone_bricks,10791096, +230,infested_chiseled_stone_bricks,10791096, +231,brown_mushroom_block,9402184,tree +232,red_mushroom_block,9402184,tree +233,mushroom_stem,9402184,tree +234,iron_bars,10987431, +235,glass_pane,0, +236,melon,31744, +237,attached_pumpkin_stem,31744, +238,attached_melon_stem,31744, +239,pumpkin_stem,31744, +240,melon_stem,31744, +241,vine,31744,plant +242,oak_fence_gate,9402184, +243,brick_stairs,7368816, +244,stone_brick_stairs,7368816, +245,mycelium,8368696, +246,lily_pad,31744,grass +247,nether_bricks,7368816, +248,nether_brick_fence,7368816, +249,nether_brick_stairs,7368816, +250,nether_wart,31744, +251,enchanting_table,7368816, +252,brewing_stand,10987431, +253,cauldron,10987431, +254,end_portal,0, +255,end_portal_frame,7368816, +256,end_stone,7368816, +257,dragon_egg,31744, +258,redstone_lamp,0, +259,cocoa,31744, +260,sandstone_stairs,7368816, +261,emerald_ore,7368816, +262,ender_chest,7368816, +263,tripwire_hook,0, +264,tripwire,0, +265,emerald_block,10987431, +266,spruce_stairs,9402184, +267,birch_stairs,9402184, +268,jungle_stairs,9402184, +269,command_block,10987431, +270,beacon,0, +271,cobblestone_wall,7368816, +272,mossy_cobblestone_wall,7368816, +273,flower_pot,0, +274,potted_oak_sapling,0, +275,potted_spruce_sapling,0, +276,potted_birch_sapling,0, +277,potted_jungle_sapling,0, +278,potted_acacia_sapling,0, +279,potted_dark_oak_sapling,0, +280,potted_fern,0, +281,potted_dandelion,0, +282,potted_poppy,0, +283,potted_blue_orchid,0, +284,potted_allium,0, +285,potted_azure_bluet,0, +286,potted_red_tulip,0, +287,potted_orange_tulip,0, +288,potted_white_tulip,0, +289,potted_pink_tulip,0, +290,potted_oxeye_daisy,0, +291,potted_cornflower,0, +292,potted_lily_of_the_valley,0, +293,potted_wither_rose,0, +294,potted_red_mushroom,0, +295,potted_brown_mushroom,0, +296,potted_dead_bush,0, +297,potted_cactus,0, +298,carrots,31744, +299,potatoes,31744, +300,oak_button,0, +301,spruce_button,0, +302,birch_button,0, +303,jungle_button,0, +304,acacia_button,0, +305,dark_oak_button,0, +306,skeleton_skull,0, +307,skeleton_wall_skull,0, +308,wither_skeleton_skull,0, +309,wither_skeleton_wall_skull,0, +310,zombie_head,0, +311,zombie_wall_head,0, +312,player_head,0, +313,player_wall_head,0, +314,creeper_head,0, +315,creeper_wall_head,0, +316,dragon_head,0, +317,dragon_wall_head,0, +318,anvil,10987431, +319,chipped_anvil,10987431, +320,damaged_anvil,10987431, +321,trapped_chest,9402184, 
+322,light_weighted_pressure_plate,10987431, +323,heavy_weighted_pressure_plate,10987431, +324,comparator,0, +325,daylight_detector,9402184, +326,redstone_block,10987431, +327,nether_quartz_ore,7368816, +328,hopper,10987431, +329,quartz_block,7368816, +330,chiseled_quartz_block,7368816, +331,quartz_pillar,7368816, +332,quartz_stairs,7368816, +333,activator_rail,0, +334,dropper,7368816, +335,white_terracotta,7368816, +336,orange_terracotta,7368816, +337,magenta_terracotta,7368816, +338,light_blue_terracotta,7368816, +339,yellow_terracotta,7368816, +340,lime_terracotta,7368816, +341,pink_terracotta,7368816, +342,gray_terracotta,7368816, +343,light_gray_terracotta,7368816, +344,cyan_terracotta,7368816, +345,purple_terracotta,7368816, +346,blue_terracotta,7368816, +347,brown_terracotta,7368816, +348,green_terracotta,7368816, +349,red_terracotta,7368816, +350,black_terracotta,7368816, +351,white_stained_glass_pane,0, +352,orange_stained_glass_pane,0, +353,magenta_stained_glass_pane,0, +354,light_blue_stained_glass_pane,0, +355,yellow_stained_glass_pane,0, +356,lime_stained_glass_pane,0, +357,pink_stained_glass_pane,0, +358,gray_stained_glass_pane,0, +359,light_gray_stained_glass_pane,0, +360,cyan_stained_glass_pane,0, +361,purple_stained_glass_pane,0, +362,blue_stained_glass_pane,0, +363,brown_stained_glass_pane,0, +364,green_stained_glass_pane,0, +365,red_stained_glass_pane,0, +366,black_stained_glass_pane,0, +367,acacia_stairs,9402184, +368,dark_oak_stairs,9402184, +369,slime_block,10791096, +370,barrier,0, +371,iron_trapdoor,10987431, +372,prismarine,7368816, +373,prismarine_bricks,7368816, +374,dark_prismarine,7368816, +375,prismarine_stairs,7368816, +376,prismarine_brick_stairs,7368816, +377,dark_prismarine_stairs,7368816, +378,prismarine_slab,7368816, +379,prismarine_brick_slab,7368816, +380,dark_prismarine_slab,7368816, +381,sea_lantern,0, +382,hay_block,8368696, +383,white_carpet,13092807, +384,orange_carpet,13092807, +385,magenta_carpet,13092807, +386,light_blue_carpet,13092807, +387,yellow_carpet,13092807, +388,lime_carpet,13092807, +389,pink_carpet,13092807, +390,gray_carpet,13092807, +391,light_gray_carpet,13092807, +392,cyan_carpet,13092807, +393,purple_carpet,13092807, +394,blue_carpet,13092807, +395,brown_carpet,13092807, +396,green_carpet,13092807, +397,red_carpet,13092807, +398,black_carpet,13092807, +399,terracotta,7368816, +400,coal_block,7368816, +401,packed_ice,10526975, +402,sunflower,31744,flower +403,lilac,31744,flower +404,rose_bush,31744,flower +405,peony,31744,flower +406,tall_grass,31744,plant +407,large_fern,31744,plant +408,white_banner,9402184, +409,orange_banner,9402184, +410,magenta_banner,9402184, +411,light_blue_banner,9402184, +412,yellow_banner,9402184, +413,lime_banner,9402184, +414,pink_banner,9402184, +415,gray_banner,9402184, +416,light_gray_banner,9402184, +417,cyan_banner,9402184, +418,purple_banner,9402184, +419,blue_banner,9402184, +420,brown_banner,9402184, +421,green_banner,9402184, +422,red_banner,9402184, +423,black_banner,9402184, +424,white_wall_banner,9402184, +425,orange_wall_banner,9402184, +426,magenta_wall_banner,9402184, +427,light_blue_wall_banner,9402184, +428,yellow_wall_banner,9402184, +429,lime_wall_banner,9402184, +430,pink_wall_banner,9402184, +431,gray_wall_banner,9402184, +432,light_gray_wall_banner,9402184, +433,cyan_wall_banner,9402184, +434,purple_wall_banner,9402184, +435,blue_wall_banner,9402184, +436,brown_wall_banner,9402184, +437,green_wall_banner,9402184, +438,red_wall_banner,9402184, +439,black_wall_banner,9402184, 
+440,red_sandstone,7368816, +441,chiseled_red_sandstone,7368816, +442,cut_red_sandstone,7368816, +443,red_sandstone_stairs,7368816, +444,oak_slab,9402184, +445,spruce_slab,9402184, +446,birch_slab,9402184, +447,jungle_slab,9402184, +448,acacia_slab,9402184, +449,dark_oak_slab,9402184, +450,stone_slab,7368816, +451,smooth_stone_slab,7368816, +452,sandstone_slab,7368816, +453,cut_sandstone_slab,7368816, +454,petrified_oak_slab,7368816, +455,cobblestone_slab,7368816, +456,brick_slab,7368816, +457,stone_brick_slab,7368816, +458,nether_brick_slab,7368816, +459,quartz_slab,7368816, +460,red_sandstone_slab,7368816, +461,cut_red_sandstone_slab,7368816, +462,purpur_slab,7368816, +463,smooth_stone,7368816, +464,smooth_sandstone,7368816, +465,smooth_quartz,7368816, +466,smooth_red_sandstone,7368816, +467,spruce_fence_gate,9402184, +468,birch_fence_gate,9402184, +469,jungle_fence_gate,9402184, +470,acacia_fence_gate,9402184, +471,dark_oak_fence_gate,9402184, +472,spruce_fence,9402184, +473,birch_fence,9402184, +474,jungle_fence,9402184, +475,acacia_fence,9402184, +476,dark_oak_fence,9402184, +477,spruce_door,9402184, +478,birch_door,9402184, +479,jungle_door,9402184, +480,acacia_door,9402184, +481,dark_oak_door,9402184, +482,end_rod,0, +483,chorus_plant,31744, +484,chorus_flower,31744, +485,purpur_block,7368816, +486,purpur_pillar,7368816, +487,purpur_stairs,7368816, +488,end_stone_bricks,7368816, +489,beetroots,31744, +490,grass_path,9923917, +491,end_gateway,0, +492,repeating_command_block,10987431, +493,chain_command_block,10987431, +494,frosted_ice,10526975, +495,magma_block,7368816, +496,nether_wart_block,8368696, +497,red_nether_bricks,7368816, +498,bone_block,7368816, +499,structure_void,0, +500,observer,7368816, +501,shulker_box,8339378, +502,white_shulker_box,8339378, +503,orange_shulker_box,8339378, +504,magenta_shulker_box,8339378, +505,light_blue_shulker_box,8339378, +506,yellow_shulker_box,8339378, +507,lime_shulker_box,8339378, +508,pink_shulker_box,8339378, +509,gray_shulker_box,8339378, +510,light_gray_shulker_box,8339378, +511,cyan_shulker_box,8339378, +512,purple_shulker_box,8339378, +513,blue_shulker_box,8339378, +514,brown_shulker_box,8339378, +515,green_shulker_box,8339378, +516,red_shulker_box,8339378, +517,black_shulker_box,8339378, +518,white_glazed_terracotta,7368816, +519,orange_glazed_terracotta,7368816, +520,magenta_glazed_terracotta,7368816, +521,light_blue_glazed_terracotta,7368816, +522,yellow_glazed_terracotta,7368816, +523,lime_glazed_terracotta,7368816, +524,pink_glazed_terracotta,7368816, +525,gray_glazed_terracotta,7368816, +526,light_gray_glazed_terracotta,7368816, +527,cyan_glazed_terracotta,7368816, +528,purple_glazed_terracotta,7368816, +529,blue_glazed_terracotta,7368816, +530,brown_glazed_terracotta,7368816, +531,green_glazed_terracotta,7368816, +532,red_glazed_terracotta,7368816, +533,black_glazed_terracotta,7368816, +534,white_concrete,7368816, +535,orange_concrete,7368816, +536,magenta_concrete,7368816, +537,light_blue_concrete,7368816, +538,yellow_concrete,7368816, +539,lime_concrete,7368816, +540,pink_concrete,7368816, +541,gray_concrete,7368816, +542,light_gray_concrete,7368816, +543,cyan_concrete,7368816, +544,purple_concrete,7368816, +545,blue_concrete,7368816, +546,brown_concrete,7368816, +547,green_concrete,7368816, +548,red_concrete,7368816, +549,black_concrete,7368816, +550,white_concrete_powder,16247203, +551,orange_concrete_powder,16247203, +552,magenta_concrete_powder,16247203, +553,light_blue_concrete_powder,16247203, 
+554,yellow_concrete_powder,16247203, +555,lime_concrete_powder,16247203, +556,pink_concrete_powder,16247203, +557,gray_concrete_powder,16247203, +558,light_gray_concrete_powder,16247203, +559,cyan_concrete_powder,16247203, +560,purple_concrete_powder,16247203, +561,blue_concrete_powder,16247203, +562,brown_concrete_powder,16247203, +563,green_concrete_powder,16247203, +564,red_concrete_powder,16247203, +565,black_concrete_powder,16247203, +566,kelp,4210943, +567,kelp_plant,4210943, +568,dried_kelp_block,8368696, +569,turtle_egg,31744, +570,dead_tube_coral_block,7368816, +571,dead_brain_coral_block,7368816, +572,dead_bubble_coral_block,7368816, +573,dead_fire_coral_block,7368816, +574,dead_horn_coral_block,7368816, +575,tube_coral_block,7368816, +576,brain_coral_block,7368816, +577,bubble_coral_block,7368816, +578,fire_coral_block,7368816, +579,horn_coral_block,7368816, +580,dead_tube_coral,7368816, +581,dead_brain_coral,7368816, +582,dead_bubble_coral,7368816, +583,dead_fire_coral,7368816, +584,dead_horn_coral,7368816, +585,tube_coral,4210943, +586,brain_coral,4210943, +587,bubble_coral,4210943, +588,fire_coral,4210943, +589,horn_coral,4210943, +590,dead_tube_coral_fan,7368816, +591,dead_brain_coral_fan,7368816, +592,dead_bubble_coral_fan,7368816, +593,dead_fire_coral_fan,7368816, +594,dead_horn_coral_fan,7368816, +595,tube_coral_fan,4210943, +596,brain_coral_fan,4210943, +597,bubble_coral_fan,4210943, +598,fire_coral_fan,4210943, +599,horn_coral_fan,4210943, +600,dead_tube_coral_wall_fan,7368816, +601,dead_brain_coral_wall_fan,7368816, +602,dead_bubble_coral_wall_fan,7368816, +603,dead_fire_coral_wall_fan,7368816, +604,dead_horn_coral_wall_fan,7368816, +605,tube_coral_wall_fan,4210943, +606,brain_coral_wall_fan,4210943, +607,bubble_coral_wall_fan,4210943, +608,fire_coral_wall_fan,4210943, +609,horn_coral_wall_fan,4210943, +610,sea_pickle,4210943, +611,blue_ice,10526975, +612,conduit,0, +613,bamboo_sapling,9402184,plant +614,bamboo,9402184,plant +615,potted_bamboo,0, +616,void_air,0,dirt +617,cave_air,0,dirt +618,bubble_column,4210943, +619,polished_granite_stairs,7368816, +620,smooth_red_sandstone_stairs,7368816, +621,mossy_stone_brick_stairs,7368816, +622,polished_diorite_stairs,7368816, +623,mossy_cobblestone_stairs,7368816, +624,end_stone_brick_stairs,7368816, +625,stone_stairs,7368816, +626,smooth_sandstone_stairs,7368816, +627,smooth_quartz_stairs,7368816, +628,granite_stairs,7368816, +629,andesite_stairs,7368816, +630,red_nether_brick_stairs,7368816, +631,polished_andesite_stairs,7368816, +632,diorite_stairs,7368816, +633,polished_granite_slab,7368816, +634,smooth_red_sandstone_slab,7368816, +635,mossy_stone_brick_slab,7368816, +636,polished_diorite_slab,7368816, +637,mossy_cobblestone_slab,7368816, +638,end_stone_brick_slab,7368816, +639,smooth_sandstone_slab,7368816, +640,smooth_quartz_slab,7368816, +641,granite_slab,7368816, +642,andesite_slab,7368816, +643,red_nether_brick_slab,7368816, +644,polished_andesite_slab,7368816, +645,diorite_slab,7368816, +646,brick_wall,7368816, +647,prismarine_wall,7368816, +648,red_sandstone_wall,7368816, +649,mossy_stone_brick_wall,7368816, +650,granite_wall,7368816, +651,stone_brick_wall,7368816, +652,nether_brick_wall,7368816, +653,andesite_wall,7368816, +654,red_nether_brick_wall,7368816, +655,sandstone_wall,7368816, +656,end_stone_brick_wall,7368816, +657,diorite_wall,7368816, +658,scaffolding,0, +659,loom,9402184, +660,barrel,9402184, +661,smoker,7368816, +662,blast_furnace,7368816, +663,cartography_table,9402184, 
+664,fletching_table,9402184, +665,grindstone,10987431, +666,lectern,9402184, +667,smithing_table,9402184, +668,stonecutter,7368816, +669,bell,10987431, +670,lantern,10987431, +671,campfire,9402184, +672,sweet_berry_bush,31744, +673,structure_block,10987431, +674,jigsaw,10987431, +675,composter,9402184, +676,bee_nest,9402184, +677,beehive,9402184, +678,honey_block,10791096, +679,honeycomb_block,10791096, diff --git a/imaginaire/model_utils/gancraft/layers.py b/imaginaire/model_utils/gancraft/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..d49900b0187c575d37c59a4fa6f62fb06413ef1d --- /dev/null +++ b/imaginaire/model_utils/gancraft/layers.py @@ -0,0 +1,153 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +import numpy as np +import torch +import torch.nn as nn + + +class AffineMod(nn.Module): + r"""Learning affine modulation of activation. + + Args: + in_features (int): Number of input features. + style_features (int): Number of style features. + mod_bias (bool): Whether to modulate bias. + """ + + def __init__(self, + in_features, + style_features, + mod_bias=True + ): + super().__init__() + self.weight_alpha = nn.Parameter(torch.randn([in_features, style_features]) / np.sqrt(style_features)) + self.bias_alpha = nn.Parameter(torch.full([in_features], 1, dtype=torch.float)) # init to 1 + self.weight_beta = None + self.bias_beta = None + self.mod_bias = mod_bias + if mod_bias: + self.weight_beta = nn.Parameter(torch.randn([in_features, style_features]) / np.sqrt(style_features)) + self.bias_beta = nn.Parameter(torch.full([in_features], 0, dtype=torch.float)) + + @staticmethod + def _linear_f(x, w, b): + w = w.to(x.dtype) + x_shape = x.shape + x = x.reshape(-1, x_shape[-1]) + if b is not None: + b = b.to(x.dtype) + x = torch.addmm(b.unsqueeze(0), x, w.t()) + else: + x = x.matmul(w.t()) + x = x.reshape(*x_shape[:-1], -1) + return x + + # x: B, ... , Cin + # z: B, 1, 1, , Cz + def forward(self, x, z): + x_shape = x.shape + z_shape = z.shape + x = x.reshape(x_shape[0], -1, x_shape[-1]) + z = z.reshape(z_shape[0], 1, z_shape[-1]) + + alpha = self._linear_f(z, self.weight_alpha, self.bias_alpha) # [B, ..., I] + x = x * alpha + + if self.mod_bias: + beta = self._linear_f(z, self.weight_beta, self.bias_beta) # [B, ..., I] + x = x + beta + + x = x.reshape(*x_shape[:-1], x.shape[-1]) + return x + + +class ModLinear(nn.Module): + r"""Linear layer with affine modulation (Based on StyleGAN2 mod demod). + Equivalent to affine modulation following linear, but faster when the same modulation parameters are shared across + multiple inputs. + Args: + in_features (int): Number of input features. + out_features (int): Number of output features. + style_features (int): Number of style features. + bias (bool): Apply additive bias before the activation function? + mod_bias (bool): Whether to modulate bias. + output_mode (bool): If True, modulate output instead of input. 
+ weight_gain (float): Initialization gain + """ + + def __init__(self, + in_features, + out_features, + style_features, + bias=True, + mod_bias=True, + output_mode=False, + weight_gain=1, + bias_init=0 + ): + super().__init__() + weight_gain = weight_gain / np.sqrt(in_features) + self.weight = nn.Parameter(torch.randn([out_features, in_features]) * weight_gain) + self.bias = nn.Parameter(torch.full([out_features], np.float32(bias_init))) if bias else None + self.weight_alpha = nn.Parameter(torch.randn([in_features, style_features]) / np.sqrt(style_features)) + self.bias_alpha = nn.Parameter(torch.full([in_features], 1, dtype=torch.float)) # init to 1 + self.weight_beta = None + self.bias_beta = None + self.mod_bias = mod_bias + self.output_mode = output_mode + if mod_bias: + if output_mode: + mod_bias_dims = out_features + else: + mod_bias_dims = in_features + self.weight_beta = nn.Parameter(torch.randn([mod_bias_dims, style_features]) / np.sqrt(style_features)) + self.bias_beta = nn.Parameter(torch.full([mod_bias_dims], 0, dtype=torch.float)) + + @staticmethod + def _linear_f(x, w, b): + w = w.to(x.dtype) + x_shape = x.shape + x = x.reshape(-1, x_shape[-1]) + if b is not None: + b = b.to(x.dtype) + x = torch.addmm(b.unsqueeze(0), x, w.t()) + else: + x = x.matmul(w.t()) + x = x.reshape(*x_shape[:-1], -1) + return x + + # x: B, ... , Cin + # z: B, 1, 1, , Cz + def forward(self, x, z): + x_shape = x.shape + z_shape = z.shape + x = x.reshape(x_shape[0], -1, x_shape[-1]) + z = z.reshape(z_shape[0], 1, z_shape[-1]) + + alpha = self._linear_f(z, self.weight_alpha, self.bias_alpha) # [B, ..., I] + w = self.weight.to(x.dtype) # [O I] + w = w.unsqueeze(0) * alpha # [1 O I] * [B 1 I] = [B O I] + + if self.mod_bias: + beta = self._linear_f(z, self.weight_beta, self.bias_beta) # [B, ..., I] + if not self.output_mode: + x = x + beta + + b = self.bias + if b is not None: + b = b.to(x.dtype)[None, None, :] + if self.mod_bias and self.output_mode: + if b is None: + b = beta + else: + b = b + beta + + # [B ? I] @ [B I O] = [B ? O] + if b is not None: + x = torch.baddbmm(b, x, w.transpose(1, 2)) + else: + x = x.bmm(w.transpose(1, 2)) + x = x.reshape(*x_shape[:-1], x.shape[-1]) + return x diff --git a/imaginaire/model_utils/gancraft/loss.py b/imaginaire/model_utils/gancraft/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..b1811de5307535167f645b4ea8a889a468b41780 --- /dev/null +++ b/imaginaire/model_utils/gancraft/loss.py @@ -0,0 +1,96 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class GANLoss(nn.Module): + def __init__(self, target_real_label=1.0, target_fake_label=0.0): + r"""GAN loss constructor. + + Args: + target_real_label (float): Desired output label for the real images. + target_fake_label (float): Desired output label for the fake images. + """ + super(GANLoss, self).__init__() + self.real_label = target_real_label + self.fake_label = target_fake_label + self.real_label_tensor = None + self.fake_label_tensor = None + + def forward(self, input_x, t_real, weight=None, + reduce_dim=True, dis_update=True): + r"""GAN loss computation. + + Args: + input_x (tensor or list of tensors): Output values. + t_real (boolean): Is this output value for real images. + reduce_dim (boolean): Whether we reduce the dimensions first. 
This makes a difference when we use + multi-resolution discriminators. + weight (float): Weight to scale the loss value. + dis_update (boolean): Updating the discriminator or the generator. + Returns: + loss (tensor): Loss value. + """ + if isinstance(input_x, list): + loss = 0 + for pred_i in input_x: + if isinstance(pred_i, list): + pred_i = pred_i[-1] + loss_tensor = self.loss(pred_i, t_real, weight, + reduce_dim, dis_update) + bs = 1 if len(loss_tensor.size()) == 0 else loss_tensor.size(0) + new_loss = torch.mean(loss_tensor.view(bs, -1), dim=1) + loss += new_loss + return loss / len(input_x) + else: + return self.loss(input_x, t_real, weight, reduce_dim, dis_update) + + def loss(self, input_x, t_real, weight=None, + reduce_dim=True, dis_update=True): + r"""N+1 label GAN loss computation. + + Args: + input_x (tensor): Output values. + t_real (boolean): Is this output value for real images. + reduce_dim (boolean): Whether we reduce the dimensions first. This makes a difference when we use + multi-resolution discriminators. + weight (float): Weight to scale the loss value. + dis_update (boolean): Updating the discriminator or the generator. + Returns: + loss (tensor): Loss value. + """ + assert reduce_dim is True + pred = input_x['pred'].clone() + label = input_x['label'].clone() + batch_size = pred.size(0) + + # ignore label 0 + label[:, 0, ...] = 0 + pred[:, 0, ...] = 0 + pred = F.log_softmax(pred, dim=1) + assert pred.size(1) == (label.size(1) + 1) + if dis_update: + if t_real: + pred_real = pred[:, :-1, :, :] + loss = - label * pred_real + loss = torch.sum(loss, dim=1, keepdim=True) + else: + pred_fake = pred[:, -1, None, :, :] # N plus 1 + loss = - pred_fake + else: + assert t_real, "GAN loss must be aiming for real." + pred_real = pred[:, :-1, :, :] + loss = - label * pred_real + loss = torch.sum(loss, dim=1, keepdim=True) + + if weight is not None: + loss = loss * weight + if reduce_dim: + loss = torch.mean(loss) + else: + loss = loss.view(batch_size, -1).mean(dim=1) + return loss diff --git a/imaginaire/model_utils/gancraft/mc_lbl_reduction.py b/imaginaire/model_utils/gancraft/mc_lbl_reduction.py new file mode 100644 index 0000000000000000000000000000000000000000..03fec1d3b600cfd31358cf480924da5232e0104a --- /dev/null +++ b/imaginaire/model_utils/gancraft/mc_lbl_reduction.py @@ -0,0 +1,83 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. 
+# To view a copy of this license, check out LICENSE.md +import os +import csv + + +class ReducedLabelMapper: + def __init__(self): + this_path = os.path.dirname(os.path.abspath(__file__)) + print('[ReducedLabelMapper] Loading from {}'.format(this_path)) + + # Load Minecraft LUT + mcid2rdlbl_lut = {} + mcid2mclbl_lut = {} + with open(os.path.join(this_path, 'mc_reduction.csv'), newline='') as csvfile: + csvreader = csv.reader(csvfile, delimiter=',') + for row in csvreader: + mcid = int(row[0]) + mcid2rdlbl_lut[mcid] = row[3] + mcid2mclbl_lut[mcid] = row[1] + + # Load reduced label set + reduced_lbls = [] + rdlbl2rdid = {} + with open(os.path.join(this_path, 'reduced_coco_lbls.csv'), newline='') as csvfile: + csvreader = csv.reader(csvfile, delimiter=',') + for idx, row in enumerate(csvreader): + rdlbl2rdid[row[0]] = idx + reduced_lbls.append(row[0]) + print(['{}: {}'.format(rdid, rdlbl) for rdid, rdlbl in enumerate(reduced_lbls)]) + # The first label should always be 'ignore' + assert reduced_lbls[0] == 'ignore' + + # Generate Minecraft ID to Reduced ID LUT + mcid2rdid_lut = [] + for mcid in range(len(mcid2rdlbl_lut)): + rdlbl = mcid2rdlbl_lut[mcid] + if rdlbl == '': + rdlbl = 'ignore' + rdid = rdlbl2rdid[rdlbl] + mcid2rdid_lut.append(rdid) + + # ================= coco part ================== + gg_label_list = [] + gglbl2ggid = {} + with open(os.path.join(this_path, 'gaugan_lbl2col.csv'), newline='') as csvfile: + csvreader = csv.reader(csvfile, delimiter=',') + for idx, row in enumerate(csvreader): + gg_label_list.append(row[0]) + gglbl2ggid[row[0]] = idx + + # Load coco -> reduced mapping table + gglbl2rdid = {} + with open(os.path.join(this_path, 'gaugan_reduction.csv'), newline='') as csvfile: + csvreader = csv.reader(csvfile, delimiter=',') + for idx, row in enumerate(csvreader): + gglbl = row[0] + target_rdlbl = row[1] + ggid = gglbl2ggid[gglbl] + target_rdid = rdlbl2rdid[target_rdlbl] + gglbl2rdid[ggid] = target_rdid + ggid2rdid = [gglbl2rdid[i] for i in range(len(gglbl2rdid))] + + print('[ReducedLabelMapper] #Reduced Labels: {}'.format(len(reduced_lbls))) + + self.mcid2rdid_lut = mcid2rdid_lut + self.ggid2rdid = ggid2rdid + self.reduced_lbls = reduced_lbls + + self.ignore_id = rdlbl2rdid['ignore'] + self.dirt_id = rdlbl2rdid['dirt'] + self.water_id = rdlbl2rdid['water'] + + self.gglbl2ggid = gglbl2ggid + + def gglbl2ggid(self, gglbl): + return self.gglbl2ggid[gglbl] + + +if __name__ == '__main__': + mapper = ReducedLabelMapper() diff --git a/imaginaire/model_utils/gancraft/mc_reduction.csv b/imaginaire/model_utils/gancraft/mc_reduction.csv new file mode 100644 index 0000000000000000000000000000000000000000..254af7255d67be76b5d41cdbd162173e55ec0b9c --- /dev/null +++ b/imaginaire/model_utils/gancraft/mc_reduction.csv @@ -0,0 +1,680 @@ +0,air,0,sky +1,stone,7368816,stone +2,granite,7368816,rock +3,polished_granite,7368816,rock +4,diorite,7368816,rock +5,polished_diorite,7368816,rock +6,andesite,7368816,rock +7,polished_andesite,7368816,rock +8,grass_block,8368696,grass +9,dirt,9923917,dirt +10,coarse_dirt,9923917,dirt +11,podzol,9923917,dirt +12,cobblestone,7368816,stone +13,oak_planks,9402184, +14,spruce_planks,9402184, +15,birch_planks,9402184, +16,jungle_planks,9402184, +17,acacia_planks,9402184, +18,dark_oak_planks,9402184, +19,oak_sapling,31744,grass +20,spruce_sapling,31744,grass +21,birch_sapling,31744,grass +22,jungle_sapling,31744,grass +23,acacia_sapling,31744,grass +24,dark_oak_sapling,31744,grass +25,bedrock,7368816,rock +26,water,4210943,water +27,lava,16711680, 
+28,sand,16247203,sand +29,red_sand,16247203,sand +30,gravel,16247203,gravel +31,gold_ore,7368816,rock +32,iron_ore,7368816,rock +33,coal_ore,7368816,rock +34,oak_log,9402184,tree +35,spruce_log,9402184,tree +36,birch_log,9402184,tree +37,jungle_log,9402184,tree +38,acacia_log,9402184,tree +39,dark_oak_log,9402184,tree +40,stripped_spruce_log,9402184, +41,stripped_birch_log,9402184, +42,stripped_jungle_log,9402184, +43,stripped_acacia_log,9402184, +44,stripped_dark_oak_log,9402184, +45,stripped_oak_log,9402184, +46,oak_wood,9402184, +47,spruce_wood,9402184, +48,birch_wood,9402184, +49,jungle_wood,9402184, +50,acacia_wood,9402184, +51,dark_oak_wood,9402184, +52,stripped_oak_wood,9402184, +53,stripped_spruce_wood,9402184, +54,stripped_birch_wood,9402184, +55,stripped_jungle_wood,9402184, +56,stripped_acacia_wood,9402184, +57,stripped_dark_oak_wood,9402184, +58,oak_leaves,31744,tree +59,spruce_leaves,31744,tree +60,birch_leaves,31744,tree +61,jungle_leaves,31744,tree +62,acacia_leaves,31744,tree +63,dark_oak_leaves,31744,tree +64,sponge,15066419, +65,wet_sponge,15066419, +66,glass,0, +67,lapis_ore,7368816, +68,lapis_block,10987431, +69,dispenser,7368816, +70,sandstone,7368816,sand +71,chiseled_sandstone,7368816,sand +72,cut_sandstone,7368816,sand +73,note_block,9402184, +74,white_bed,13092807, +75,orange_bed,13092807, +76,magenta_bed,13092807, +77,light_blue_bed,13092807, +78,yellow_bed,13092807, +79,lime_bed,13092807, +80,pink_bed,13092807, +81,gray_bed,13092807, +82,light_gray_bed,13092807, +83,cyan_bed,13092807, +84,purple_bed,13092807, +85,blue_bed,13092807, +86,brown_bed,13092807, +87,green_bed,13092807, +88,red_bed,13092807, +89,black_bed,13092807, +90,powered_rail,0, +91,detector_rail,0, +92,sticky_piston,7368816, +93,cobweb,13092807, +94,grass,31744,grass +95,fern,31744,grass +96,dead_bush,31744,grass +97,seagrass,4210943,water +98,tall_seagrass,4210943,water +99,piston,7368816, +100,piston_head,7368816, +101,white_wool,13092807, +102,orange_wool,13092807, +103,magenta_wool,13092807, +104,light_blue_wool,13092807, +105,yellow_wool,13092807, +106,lime_wool,13092807, +107,pink_wool,13092807, +108,gray_wool,13092807, +109,light_gray_wool,13092807, +110,cyan_wool,13092807, +111,purple_wool,13092807, +112,blue_wool,13092807, +113,brown_wool,13092807, +114,green_wool,13092807, +115,red_wool,13092807, +116,black_wool,13092807, +117,moving_piston,7368816, +118,dandelion,31744,flower +119,poppy,31744,flower +120,blue_orchid,31744,flower +121,allium,31744,flower +122,azure_bluet,31744,flower +123,red_tulip,31744,flower +124,orange_tulip,31744,flower +125,white_tulip,31744,flower +126,pink_tulip,31744,flower +127,oxeye_daisy,31744,flower +128,cornflower,31744,flower +129,wither_rose,31744,flower +130,lily_of_the_valley,31744,flower +131,brown_mushroom,31744,flower +132,red_mushroom,31744,flower +133,gold_block,10987431, +134,iron_block,10987431, +135,bricks,7368816, +136,tnt,16711680, +137,bookshelf,9402184, +138,mossy_cobblestone,7368816, +139,obsidian,7368816, +140,torch,0, +141,wall_torch,0, +142,fire,0, +143,spawner,7368816, +144,oak_stairs,9402184, +145,chest,9402184, +146,redstone_wire,0, +147,diamond_ore,7368816, +148,diamond_block,10987431, +149,crafting_table,9402184, +150,wheat,31744, +151,farmland,9923917, +152,furnace,7368816, +153,oak_sign,9402184, +154,spruce_sign,9402184, +155,birch_sign,9402184, +156,acacia_sign,9402184, +157,jungle_sign,9402184, +158,dark_oak_sign,9402184, +159,oak_door,9402184, +160,ladder,0, +161,rail,0, +162,cobblestone_stairs,7368816, 
+163,oak_wall_sign,9402184, +164,spruce_wall_sign,9402184, +165,birch_wall_sign,9402184, +166,acacia_wall_sign,9402184, +167,jungle_wall_sign,9402184, +168,dark_oak_wall_sign,9402184, +169,lever,0, +170,stone_pressure_plate,7368816, +171,iron_door,10987431, +172,oak_pressure_plate,9402184, +173,spruce_pressure_plate,9402184, +174,birch_pressure_plate,9402184, +175,jungle_pressure_plate,9402184, +176,acacia_pressure_plate,9402184, +177,dark_oak_pressure_plate,9402184, +178,redstone_ore,7368816, +179,redstone_torch,0, +180,redstone_wall_torch,0, +181,stone_button,0, +182,snow,16777215,snow +183,ice,10526975,snow +184,snow_block,16777215,snow +185,cactus,31744,flower +186,clay,10791096,dirt +187,sugar_cane,31744,flower +188,jukebox,9402184, +189,oak_fence,9402184, +190,pumpkin,31744, +191,netherrack,7368816, +192,soul_sand,16247203, +193,glowstone,0, +194,nether_portal,0, +195,carved_pumpkin,31744, +196,jack_o_lantern,31744, +197,cake,0, +198,repeater,0, +199,white_stained_glass,0, +200,orange_stained_glass,0, +201,magenta_stained_glass,0, +202,light_blue_stained_glass,0, +203,yellow_stained_glass,0, +204,lime_stained_glass,0, +205,pink_stained_glass,0, +206,gray_stained_glass,0, +207,light_gray_stained_glass,0, +208,cyan_stained_glass,0, +209,purple_stained_glass,0, +210,blue_stained_glass,0, +211,brown_stained_glass,0, +212,green_stained_glass,0, +213,red_stained_glass,0, +214,black_stained_glass,0, +215,oak_trapdoor,9402184, +216,spruce_trapdoor,9402184, +217,birch_trapdoor,9402184, +218,jungle_trapdoor,9402184, +219,acacia_trapdoor,9402184, +220,dark_oak_trapdoor,9402184, +221,stone_bricks,7368816, +222,mossy_stone_bricks,7368816, +223,cracked_stone_bricks,7368816, +224,chiseled_stone_bricks,7368816, +225,infested_stone,10791096, +226,infested_cobblestone,10791096, +227,infested_stone_bricks,10791096, +228,infested_mossy_stone_bricks,10791096, +229,infested_cracked_stone_bricks,10791096, +230,infested_chiseled_stone_bricks,10791096, +231,brown_mushroom_block,9402184,tree +232,red_mushroom_block,9402184,tree +233,mushroom_stem,9402184,tree +234,iron_bars,10987431, +235,glass_pane,0, +236,melon,31744, +237,attached_pumpkin_stem,31744, +238,attached_melon_stem,31744, +239,pumpkin_stem,31744, +240,melon_stem,31744, +241,vine,31744,tree +242,oak_fence_gate,9402184, +243,brick_stairs,7368816, +244,stone_brick_stairs,7368816, +245,mycelium,8368696, +246,lily_pad,31744,grass +247,nether_bricks,7368816, +248,nether_brick_fence,7368816, +249,nether_brick_stairs,7368816, +250,nether_wart,31744, +251,enchanting_table,7368816, +252,brewing_stand,10987431, +253,cauldron,10987431, +254,end_portal,0, +255,end_portal_frame,7368816, +256,end_stone,7368816, +257,dragon_egg,31744, +258,redstone_lamp,0, +259,cocoa,31744, +260,sandstone_stairs,7368816, +261,emerald_ore,7368816, +262,ender_chest,7368816, +263,tripwire_hook,0, +264,tripwire,0, +265,emerald_block,10987431, +266,spruce_stairs,9402184, +267,birch_stairs,9402184, +268,jungle_stairs,9402184, +269,command_block,10987431, +270,beacon,0, +271,cobblestone_wall,7368816, +272,mossy_cobblestone_wall,7368816, +273,flower_pot,0, +274,potted_oak_sapling,0, +275,potted_spruce_sapling,0, +276,potted_birch_sapling,0, +277,potted_jungle_sapling,0, +278,potted_acacia_sapling,0, +279,potted_dark_oak_sapling,0, +280,potted_fern,0, +281,potted_dandelion,0, +282,potted_poppy,0, +283,potted_blue_orchid,0, +284,potted_allium,0, +285,potted_azure_bluet,0, +286,potted_red_tulip,0, +287,potted_orange_tulip,0, +288,potted_white_tulip,0, +289,potted_pink_tulip,0, 
+290,potted_oxeye_daisy,0, +291,potted_cornflower,0, +292,potted_lily_of_the_valley,0, +293,potted_wither_rose,0, +294,potted_red_mushroom,0, +295,potted_brown_mushroom,0, +296,potted_dead_bush,0, +297,potted_cactus,0, +298,carrots,31744, +299,potatoes,31744, +300,oak_button,0, +301,spruce_button,0, +302,birch_button,0, +303,jungle_button,0, +304,acacia_button,0, +305,dark_oak_button,0, +306,skeleton_skull,0, +307,skeleton_wall_skull,0, +308,wither_skeleton_skull,0, +309,wither_skeleton_wall_skull,0, +310,zombie_head,0, +311,zombie_wall_head,0, +312,player_head,0, +313,player_wall_head,0, +314,creeper_head,0, +315,creeper_wall_head,0, +316,dragon_head,0, +317,dragon_wall_head,0, +318,anvil,10987431, +319,chipped_anvil,10987431, +320,damaged_anvil,10987431, +321,trapped_chest,9402184, +322,light_weighted_pressure_plate,10987431, +323,heavy_weighted_pressure_plate,10987431, +324,comparator,0, +325,daylight_detector,9402184, +326,redstone_block,10987431, +327,nether_quartz_ore,7368816, +328,hopper,10987431, +329,quartz_block,7368816, +330,chiseled_quartz_block,7368816, +331,quartz_pillar,7368816, +332,quartz_stairs,7368816, +333,activator_rail,0, +334,dropper,7368816, +335,white_terracotta,7368816, +336,orange_terracotta,7368816, +337,magenta_terracotta,7368816, +338,light_blue_terracotta,7368816, +339,yellow_terracotta,7368816, +340,lime_terracotta,7368816, +341,pink_terracotta,7368816, +342,gray_terracotta,7368816, +343,light_gray_terracotta,7368816, +344,cyan_terracotta,7368816, +345,purple_terracotta,7368816, +346,blue_terracotta,7368816, +347,brown_terracotta,7368816, +348,green_terracotta,7368816, +349,red_terracotta,7368816, +350,black_terracotta,7368816, +351,white_stained_glass_pane,0, +352,orange_stained_glass_pane,0, +353,magenta_stained_glass_pane,0, +354,light_blue_stained_glass_pane,0, +355,yellow_stained_glass_pane,0, +356,lime_stained_glass_pane,0, +357,pink_stained_glass_pane,0, +358,gray_stained_glass_pane,0, +359,light_gray_stained_glass_pane,0, +360,cyan_stained_glass_pane,0, +361,purple_stained_glass_pane,0, +362,blue_stained_glass_pane,0, +363,brown_stained_glass_pane,0, +364,green_stained_glass_pane,0, +365,red_stained_glass_pane,0, +366,black_stained_glass_pane,0, +367,acacia_stairs,9402184, +368,dark_oak_stairs,9402184, +369,slime_block,10791096, +370,barrier,0, +371,iron_trapdoor,10987431, +372,prismarine,7368816, +373,prismarine_bricks,7368816, +374,dark_prismarine,7368816, +375,prismarine_stairs,7368816, +376,prismarine_brick_stairs,7368816, +377,dark_prismarine_stairs,7368816, +378,prismarine_slab,7368816, +379,prismarine_brick_slab,7368816, +380,dark_prismarine_slab,7368816, +381,sea_lantern,0, +382,hay_block,8368696, +383,white_carpet,13092807, +384,orange_carpet,13092807, +385,magenta_carpet,13092807, +386,light_blue_carpet,13092807, +387,yellow_carpet,13092807, +388,lime_carpet,13092807, +389,pink_carpet,13092807, +390,gray_carpet,13092807, +391,light_gray_carpet,13092807, +392,cyan_carpet,13092807, +393,purple_carpet,13092807, +394,blue_carpet,13092807, +395,brown_carpet,13092807, +396,green_carpet,13092807, +397,red_carpet,13092807, +398,black_carpet,13092807, +399,terracotta,7368816, +400,coal_block,7368816, +401,packed_ice,10526975,snow +402,sunflower,31744,flower +403,lilac,31744,flower +404,rose_bush,31744,flower +405,peony,31744,flower +406,tall_grass,31744,flower +407,large_fern,31744,flower +408,white_banner,9402184, +409,orange_banner,9402184, +410,magenta_banner,9402184, +411,light_blue_banner,9402184, +412,yellow_banner,9402184, 
+413,lime_banner,9402184, +414,pink_banner,9402184, +415,gray_banner,9402184, +416,light_gray_banner,9402184, +417,cyan_banner,9402184, +418,purple_banner,9402184, +419,blue_banner,9402184, +420,brown_banner,9402184, +421,green_banner,9402184, +422,red_banner,9402184, +423,black_banner,9402184, +424,white_wall_banner,9402184, +425,orange_wall_banner,9402184, +426,magenta_wall_banner,9402184, +427,light_blue_wall_banner,9402184, +428,yellow_wall_banner,9402184, +429,lime_wall_banner,9402184, +430,pink_wall_banner,9402184, +431,gray_wall_banner,9402184, +432,light_gray_wall_banner,9402184, +433,cyan_wall_banner,9402184, +434,purple_wall_banner,9402184, +435,blue_wall_banner,9402184, +436,brown_wall_banner,9402184, +437,green_wall_banner,9402184, +438,red_wall_banner,9402184, +439,black_wall_banner,9402184, +440,red_sandstone,7368816, +441,chiseled_red_sandstone,7368816, +442,cut_red_sandstone,7368816, +443,red_sandstone_stairs,7368816, +444,oak_slab,9402184, +445,spruce_slab,9402184, +446,birch_slab,9402184, +447,jungle_slab,9402184, +448,acacia_slab,9402184, +449,dark_oak_slab,9402184, +450,stone_slab,7368816, +451,smooth_stone_slab,7368816, +452,sandstone_slab,7368816, +453,cut_sandstone_slab,7368816, +454,petrified_oak_slab,7368816, +455,cobblestone_slab,7368816, +456,brick_slab,7368816, +457,stone_brick_slab,7368816, +458,nether_brick_slab,7368816, +459,quartz_slab,7368816, +460,red_sandstone_slab,7368816, +461,cut_red_sandstone_slab,7368816, +462,purpur_slab,7368816, +463,smooth_stone,7368816, +464,smooth_sandstone,7368816, +465,smooth_quartz,7368816, +466,smooth_red_sandstone,7368816, +467,spruce_fence_gate,9402184, +468,birch_fence_gate,9402184, +469,jungle_fence_gate,9402184, +470,acacia_fence_gate,9402184, +471,dark_oak_fence_gate,9402184, +472,spruce_fence,9402184, +473,birch_fence,9402184, +474,jungle_fence,9402184, +475,acacia_fence,9402184, +476,dark_oak_fence,9402184, +477,spruce_door,9402184, +478,birch_door,9402184, +479,jungle_door,9402184, +480,acacia_door,9402184, +481,dark_oak_door,9402184, +482,end_rod,0, +483,chorus_plant,31744, +484,chorus_flower,31744, +485,purpur_block,7368816, +486,purpur_pillar,7368816, +487,purpur_stairs,7368816, +488,end_stone_bricks,7368816, +489,beetroots,31744, +490,grass_path,9923917, +491,end_gateway,0, +492,repeating_command_block,10987431, +493,chain_command_block,10987431, +494,frosted_ice,10526975,snow +495,magma_block,7368816, +496,nether_wart_block,8368696, +497,red_nether_bricks,7368816, +498,bone_block,7368816, +499,structure_void,0, +500,observer,7368816, +501,shulker_box,8339378, +502,white_shulker_box,8339378, +503,orange_shulker_box,8339378, +504,magenta_shulker_box,8339378, +505,light_blue_shulker_box,8339378, +506,yellow_shulker_box,8339378, +507,lime_shulker_box,8339378, +508,pink_shulker_box,8339378, +509,gray_shulker_box,8339378, +510,light_gray_shulker_box,8339378, +511,cyan_shulker_box,8339378, +512,purple_shulker_box,8339378, +513,blue_shulker_box,8339378, +514,brown_shulker_box,8339378, +515,green_shulker_box,8339378, +516,red_shulker_box,8339378, +517,black_shulker_box,8339378, +518,white_glazed_terracotta,7368816, +519,orange_glazed_terracotta,7368816, +520,magenta_glazed_terracotta,7368816, +521,light_blue_glazed_terracotta,7368816, +522,yellow_glazed_terracotta,7368816, +523,lime_glazed_terracotta,7368816, +524,pink_glazed_terracotta,7368816, +525,gray_glazed_terracotta,7368816, +526,light_gray_glazed_terracotta,7368816, +527,cyan_glazed_terracotta,7368816, +528,purple_glazed_terracotta,7368816, 
+529,blue_glazed_terracotta,7368816, +530,brown_glazed_terracotta,7368816, +531,green_glazed_terracotta,7368816, +532,red_glazed_terracotta,7368816, +533,black_glazed_terracotta,7368816, +534,white_concrete,7368816, +535,orange_concrete,7368816, +536,magenta_concrete,7368816, +537,light_blue_concrete,7368816, +538,yellow_concrete,7368816, +539,lime_concrete,7368816, +540,pink_concrete,7368816, +541,gray_concrete,7368816, +542,light_gray_concrete,7368816, +543,cyan_concrete,7368816, +544,purple_concrete,7368816, +545,blue_concrete,7368816, +546,brown_concrete,7368816, +547,green_concrete,7368816, +548,red_concrete,7368816, +549,black_concrete,7368816, +550,white_concrete_powder,16247203, +551,orange_concrete_powder,16247203, +552,magenta_concrete_powder,16247203, +553,light_blue_concrete_powder,16247203, +554,yellow_concrete_powder,16247203, +555,lime_concrete_powder,16247203, +556,pink_concrete_powder,16247203, +557,gray_concrete_powder,16247203, +558,light_gray_concrete_powder,16247203, +559,cyan_concrete_powder,16247203, +560,purple_concrete_powder,16247203, +561,blue_concrete_powder,16247203, +562,brown_concrete_powder,16247203, +563,green_concrete_powder,16247203, +564,red_concrete_powder,16247203, +565,black_concrete_powder,16247203, +566,kelp,4210943, +567,kelp_plant,4210943, +568,dried_kelp_block,8368696, +569,turtle_egg,31744, +570,dead_tube_coral_block,7368816, +571,dead_brain_coral_block,7368816, +572,dead_bubble_coral_block,7368816, +573,dead_fire_coral_block,7368816, +574,dead_horn_coral_block,7368816, +575,tube_coral_block,7368816, +576,brain_coral_block,7368816, +577,bubble_coral_block,7368816, +578,fire_coral_block,7368816, +579,horn_coral_block,7368816, +580,dead_tube_coral,7368816, +581,dead_brain_coral,7368816, +582,dead_bubble_coral,7368816, +583,dead_fire_coral,7368816, +584,dead_horn_coral,7368816, +585,tube_coral,4210943, +586,brain_coral,4210943, +587,bubble_coral,4210943, +588,fire_coral,4210943, +589,horn_coral,4210943, +590,dead_tube_coral_fan,7368816, +591,dead_brain_coral_fan,7368816, +592,dead_bubble_coral_fan,7368816, +593,dead_fire_coral_fan,7368816, +594,dead_horn_coral_fan,7368816, +595,tube_coral_fan,4210943, +596,brain_coral_fan,4210943, +597,bubble_coral_fan,4210943, +598,fire_coral_fan,4210943, +599,horn_coral_fan,4210943, +600,dead_tube_coral_wall_fan,7368816, +601,dead_brain_coral_wall_fan,7368816, +602,dead_bubble_coral_wall_fan,7368816, +603,dead_fire_coral_wall_fan,7368816, +604,dead_horn_coral_wall_fan,7368816, +605,tube_coral_wall_fan,4210943, +606,brain_coral_wall_fan,4210943, +607,bubble_coral_wall_fan,4210943, +608,fire_coral_wall_fan,4210943, +609,horn_coral_wall_fan,4210943, +610,sea_pickle,4210943, +611,blue_ice,10526975,snow +612,conduit,0, +613,bamboo_sapling,9402184,flower +614,bamboo,9402184,tree +615,potted_bamboo,0, +616,void_air,0,dirt +617,cave_air,0,dirt +618,bubble_column,4210943, +619,polished_granite_stairs,7368816, +620,smooth_red_sandstone_stairs,7368816, +621,mossy_stone_brick_stairs,7368816, +622,polished_diorite_stairs,7368816, +623,mossy_cobblestone_stairs,7368816, +624,end_stone_brick_stairs,7368816, +625,stone_stairs,7368816, +626,smooth_sandstone_stairs,7368816, +627,smooth_quartz_stairs,7368816, +628,granite_stairs,7368816, +629,andesite_stairs,7368816, +630,red_nether_brick_stairs,7368816, +631,polished_andesite_stairs,7368816, +632,diorite_stairs,7368816, +633,polished_granite_slab,7368816, +634,smooth_red_sandstone_slab,7368816, +635,mossy_stone_brick_slab,7368816, +636,polished_diorite_slab,7368816, 
+637,mossy_cobblestone_slab,7368816, +638,end_stone_brick_slab,7368816, +639,smooth_sandstone_slab,7368816, +640,smooth_quartz_slab,7368816, +641,granite_slab,7368816, +642,andesite_slab,7368816, +643,red_nether_brick_slab,7368816, +644,polished_andesite_slab,7368816, +645,diorite_slab,7368816, +646,brick_wall,7368816, +647,prismarine_wall,7368816, +648,red_sandstone_wall,7368816, +649,mossy_stone_brick_wall,7368816, +650,granite_wall,7368816, +651,stone_brick_wall,7368816, +652,nether_brick_wall,7368816, +653,andesite_wall,7368816, +654,red_nether_brick_wall,7368816, +655,sandstone_wall,7368816, +656,end_stone_brick_wall,7368816, +657,diorite_wall,7368816, +658,scaffolding,0, +659,loom,9402184, +660,barrel,9402184, +661,smoker,7368816, +662,blast_furnace,7368816, +663,cartography_table,9402184, +664,fletching_table,9402184, +665,grindstone,10987431, +666,lectern,9402184, +667,smithing_table,9402184, +668,stonecutter,7368816, +669,bell,10987431, +670,lantern,10987431, +671,campfire,9402184, +672,sweet_berry_bush,31744, +673,structure_block,10987431, +674,jigsaw,10987431, +675,composter,9402184, +676,bee_nest,9402184, +677,beehive,9402184, +678,honey_block,10791096, +679,honeycomb_block,10791096, diff --git a/imaginaire/model_utils/gancraft/mc_utils.py b/imaginaire/model_utils/gancraft/mc_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..faa87e04060866541761586ffe8e41bcce9ca44c --- /dev/null +++ b/imaginaire/model_utils/gancraft/mc_utils.py @@ -0,0 +1,388 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +import matplotlib.pyplot as plt +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import csv +import time +# For binary dilation +from scipy import ndimage +import os +from imaginaire.model_utils.gancraft.mc_lbl_reduction import ReducedLabelMapper + + +def load_voxel_new(voxel_path, shape=[256, 512, 512]): + voxel_world = np.fromfile(voxel_path, dtype='int32') + voxel_world = voxel_world.reshape( + shape[1]//16, shape[2]//16, 16, 16, shape[0]) + voxel_world = voxel_world.transpose(4, 0, 2, 1, 3) + voxel_world = voxel_world.reshape(shape[0], shape[1], shape[2]) + voxel_world = np.ascontiguousarray(voxel_world) + voxel_world = torch.from_numpy(voxel_world.astype(np.int32)) + return voxel_world + + +def gen_corner_voxel(voxel): + r"""Converting voxel center array to voxel corner array. The size of the + produced array grows by 1 on every dimension. + + Args: + voxel (torch.IntTensor, CPU): Input voxel of three dimensions + """ + structure = np.zeros([3, 3, 3], dtype=np.bool) + structure[1:, 1:, 1:] = True + voxel_p = F.pad(voxel, (0, 1, 0, 1, 0, 1)) + corners = ndimage.binary_dilation(voxel_p.numpy(), structure) + corners = torch.tensor(corners, dtype=torch.int32) + return corners + + +def calc_height_map(voxel_t): + r"""Calculate height map given a voxel grid [Y, X, Z] as input. 
+ The height is defined as the Y index of the surface (non-air) block + + Args: + voxel (Y x X x Z torch.IntTensor, CPU): Input voxel of three dimensions + Output: + heightmap (X x Z torch.IntTensor) + """ + start_time = time.time() + m, h = torch.max((torch.flip(voxel_t, [0]) != 0).int(), dim=0, keepdim=False) + heightmap = voxel_t.shape[0] - 1 - h + heightmap[m == 0] = 0 # Special case when the whole vertical column is empty + + elapsed_time = time.time() - start_time + print("[GANcraft-utils] Heightmap time: {}".format(elapsed_time)) + return heightmap + + +def trans_vec_homo(m, v, is_vec=False): + r"""3-dimensional Homogeneous matrix and regular vector multiplication + Convert v to homogeneous vector, perform M-V multiplication, and convert back + Note that this function does not support autograd. + + Args: + m (4 x 4 tensor): a homogeneous matrix + v (3 tensor): a 3-d vector + vec (bool): if true, v is direction. Otherwise v is point + """ + if is_vec: + v = torch.tensor([v[0], v[1], v[2], 0], dtype=v.dtype) + else: + v = torch.tensor([v[0], v[1], v[2], 1], dtype=v.dtype) + v = torch.mv(m, v) + if not is_vec: + v = v / v[3] + v = v[:3] + return v + + +def cumsum_exclusive(tensor, dim): + cumsum = torch.cumsum(tensor, dim) + cumsum = torch.roll(cumsum, 1, dim) + cumsum.index_fill_(dim, torch.tensor([0], dtype=torch.long, device=tensor.device), 0) + return cumsum + + +def sample_depth_batched(depth2, nsamples, deterministic=False, use_box_boundaries=True, sample_depth=4): + r""" Make best effort to sample points within the same distance for every ray. + Exception: When there is not enough voxel. + + Args: + depth2 (N x 2 x 256 x 256 x 4 x 1 tensor): + - N: Batch. + - 2: Entrance / exit depth for each intersected box. + - 256, 256: Height, Width. + - 4: Number of intersected boxes along the ray. + - 1: One extra dim for consistent tensor dims. + depth2 can include NaNs. + deterministic (bool): Whether to use equal-distance sampling instead of random stratified sampling. + use_box_boundaries (bool): Whether to add the entrance / exit points into the sample. + sample_depth (float): Truncate the ray when it travels further than sample_depth inside voxels. + """ + + bs = depth2.size(0) + dim0 = depth2.size(2) + dim1 = depth2.size(3) + dists = depth2[:, 1] - depth2[:, 0] + dists[torch.isnan(dists)] = 0 # N, 256, 256, 4, 1 + accu_depth = torch.cumsum(dists, dim=-2) # N, 256, 256, 4, 1 + total_depth = accu_depth[..., [-1], :] # N, 256, 256, 1, 1 + + total_depth = torch.clamp(total_depth, None, sample_depth) + + # Ignore out of range box boundaries. Fill with random samples. 
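+ # The block below first prepares the optional boundary samples (cumulative box depths, with
+ # truncated or unused slots refilled by uniform random depths), then draws nsamples stratified
+ # depths in [0, total_depth), merges and sorts everything, and finally converts each sample
+ # midpoint from accumulated in-voxel distance back to a true ray depth by adding the gaps of
+ # empty space between boxes.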
+ if use_box_boundaries: + boundary_samples = accu_depth.clone().detach() + boundary_samples_filler = torch.rand_like(boundary_samples) * total_depth + bad_mask = (accu_depth > sample_depth) | (dists == 0) + boundary_samples[bad_mask] = boundary_samples_filler[bad_mask] + + rand_shape = [bs, dim0, dim1, nsamples, 1] + # 256, 256, N, 1 + if deterministic: + rand_samples = torch.empty(rand_shape, dtype=total_depth.dtype, device=total_depth.device) + rand_samples[..., :, 0] = torch.linspace(0, 1, nsamples+2)[1:-1] + else: + rand_samples = torch.rand(rand_shape, dtype=total_depth.dtype, device=total_depth.device) # 256, 256, N, 1 + # Stratified sampling as in NeRF + rand_samples = rand_samples / nsamples + rand_samples[..., :, 0] += torch.linspace(0, 1, nsamples+1, device=rand_samples.device)[:-1] + rand_samples = rand_samples * total_depth # 256, 256, N, 1 + + # Can also include boundaries + if use_box_boundaries: + rand_samples = torch.cat([rand_samples, boundary_samples, torch.zeros( + [bs, dim0, dim1, 1, 1], dtype=total_depth.dtype, device=total_depth.device)], dim=-2) + rand_samples, _ = torch.sort(rand_samples, dim=-2, descending=False) + + midpoints = (rand_samples[..., 1:, :] + rand_samples[..., :-1, :]) / 2 + new_dists = rand_samples[..., 1:, :] - rand_samples[..., :-1, :] + + # Scatter the random samples back + # 256, 256, 1, M, 1 > 256, 256, N, 1, 1 + idx = torch.sum(midpoints.unsqueeze(-3) > accu_depth.unsqueeze(-2), dim=-3) # 256, 256, M, 1 + # print(idx.shape, idx.max(), idx.min()) # max 3, min 0 + + depth_deltas = depth2[:, 0, :, :, 1:, :] - depth2[:, 1, :, :, :-1, :] # There might be NaNs! + depth_deltas = torch.cumsum(depth_deltas, dim=-2) + depth_deltas = torch.cat([depth2[:, 0, :, :, [0], :], depth_deltas+depth2[:, 0, :, :, [0], :]], dim=-2) + heads = torch.gather(depth_deltas, -2, idx) # 256 256 M 1 + # heads = torch.gather(depth2[0], -2, idx) # 256 256 M 1 + + # print(torch.any(torch.isnan(heads))) + rand_depth = heads + midpoints # 256 256 N 1 + + return rand_depth, new_dists, idx + + +def volum_rendering_relu(sigma, dists, dim=2): + free_energy = F.relu(sigma) * dists + + a = 1 - torch.exp(-free_energy.float()) # probability of it is not empty here + b = torch.exp(-cumsum_exclusive(free_energy, dim=dim)) # probability of everything is empty up to now + probs = a * b # probability of the ray hits something here + + return probs + + +class McVoxel(nn.Module): + r"""Voxel management.""" + + def __init__(self, voxel_t, preproc_ver): + super(McVoxel, self).__init__() + # Filter voxel + voxel_t[voxel_t == 246] = 0 # lily_pad + voxel_t[voxel_t == 241] = 0 # vine + voxel_t[voxel_t == 611] = 26 # Blue ice -> water + voxel_t[voxel_t == 183] = 26 # ice -> water + voxel_t[voxel_t == 401] = 25 # Packed ice -> bedrock + + if preproc_ver >= 3 and preproc_ver < 6: + voxel_t[voxel_t == 27] = 25 # Lava -> bedrock + voxel_t[voxel_t == 616] = 9 # void_air -> dirt + voxel_t[voxel_t == 617] = 25 # cave_air -> bedrock + + if preproc_ver >= 6: + voxel_t[voxel_t == 616] = 0 # void_air -> air + voxel_t[voxel_t == 617] = 0 # cave_air -> air + + # Simplify voxel + structure = ndimage.generate_binary_structure(3, 3) + mask = voxel_t.numpy() > 0 + if preproc_ver == 4: # Hollow bottom + mask = ndimage.morphology.binary_erosion(mask, structure=structure, iterations=2, border_value=1) + voxel_t[mask] = 0 + if preproc_ver >= 5: # Close cell before hollow bottom + mask = ndimage.morphology.binary_dilation(mask, iterations=1, border_value=1) + mask = ndimage.morphology.binary_erosion(mask, iterations=1, 
border_value=1) + mask = ndimage.morphology.binary_erosion(mask, structure=structure, iterations=2, border_value=1) + voxel_t[mask] = 0 + + self.register_buffer('voxel_t', voxel_t, persistent=False) + + self.trans_mat = torch.eye(4) # Transform voxel to world + # Generate heightmap for camera trajectory generation + self.heightmap = calc_height_map(self.voxel_t) + self._truncate_voxel() + # Convert voxel ([X, Y, Z], int32) to corner ([X+1, Y+1, Z+1], int32) (Requires CPU tensor) + corner_t = gen_corner_voxel(self.voxel_t) + self.register_buffer('corner_t', corner_t, persistent=False) + + # Generate 3D position to 1D feature LUT table + nfilledvox = torch.sum(self.corner_t > 0) + print('[GANcraft-utils] Number of filled voxels: {} / {}'.format(nfilledvox.item(), torch.numel(self.corner_t))) + # Zero means non-existent voxel. + self.corner_t[self.corner_t > 0] = torch.arange(start=1, end=nfilledvox+1, step=1, dtype=torch.int32) + self.nfilledvox = nfilledvox + + def world2local(self, v, is_vec=False): + mat_world2local = torch.inverse(self.trans_mat) + return trans_vec_homo(mat_world2local, v, is_vec) + + def _truncate_voxel(self): + gnd_level = self.heightmap.min() + sky_level = self.heightmap.max() + 1 + self.voxel_t = self.voxel_t[gnd_level:sky_level, :, :] + self.trans_mat[0, 3] += gnd_level + print('[GANcraft-utils] Voxel truncated. Gnd: {}; Sky: {}.'.format(gnd_level.item(), sky_level.item())) + + def is_sea(self, loc): + r"""loc: [2]: x, z.""" + x = int(loc[1]) + z = int(loc[2]) + if x < 0 or x > self.heightmap.size(0) or z < 0 or z > self.heightmap.size(1): + print('[McVoxel] is_sea(): Index out of bound.') + return True + y = self.heightmap[x, z] - self.trans_mat[0, 3] + y = int(y) + if self.voxel_t[y, x, z] == 26: + print('[McVoxel] is_sea(): Get a sea.') + print(self.voxel_t[y, x, z], self.voxel_t[y+1, x, z]) + return True + else: + return False + + +class MCLabelTranslator: + r"""Resolving mapping across Minecraft voxel, coco-stuff label and GANcraft reduced label set.""" + + def __init__(self): + this_path = os.path.dirname(os.path.abspath(__file__)) + # Load voxel name lut + id2name_lut = {} + id2color_lut = {} + id2glbl_lut = {} + with open(os.path.join(this_path, 'id2name_gg.csv'), newline='') as csvfile: + csvreader = csv.reader(csvfile, delimiter=',') + for row in csvreader: + id2name_lut[int(row[0])] = row[1] + id2color_lut[int(row[0])] = int(row[2]) + id2glbl_lut[int(row[0])] = row[3] + + # Load GauGAN color lut + glbl2color_lut = {} + glbl2cocoidx_lut = {} + with open(os.path.join(this_path, 'gaugan_lbl2col.csv'), newline='') as csvfile: + csvreader = csv.reader(csvfile, delimiter=',') + cocoidx = 1 # 0 is "Others" + for row in csvreader: + color = int(row[1].lstrip('#'), 16) + glbl2color_lut[row[0]] = color + glbl2cocoidx_lut[row[0]] = cocoidx + cocoidx += 1 + + # Generate id2ggcolor lut + id2ggcolor_lut = {} + for k, v in id2glbl_lut.items(): + if v: + id2ggcolor_lut[k] = glbl2color_lut[v] + else: + id2ggcolor_lut[k] = 0 + + # Generate id2cocoidx + id2cocoidx_lut = {} + for k, v in id2glbl_lut.items(): + if v: + id2cocoidx_lut[k] = glbl2cocoidx_lut[v] + else: + id2cocoidx_lut[k] = 0 + + self.id2color_lut = id2color_lut + self.id2name_lut = id2name_lut + self.id2glbl_lut = id2glbl_lut + self.id2ggcolor_lut = id2ggcolor_lut + self.id2cocoidx_lut = id2cocoidx_lut + + if True: + mapper = ReducedLabelMapper() + mcid2rdid_lut = mapper.mcid2rdid_lut + mcid2rdid_lut = torch.tensor(mcid2rdid_lut, dtype=torch.long) + self.mcid2rdid_lut = mcid2rdid_lut + self.num_reduced_lbls = 
len(mapper.reduced_lbls) + self.ignore_id = mapper.ignore_id + self.dirt_id = mapper.dirt_id + self.water_id = mapper.water_id + + self.mapper = mapper + + ggid2rdid_lut = mapper.ggid2rdid + [0] # Last index is ignore + ggid2rdid_lut = torch.tensor(ggid2rdid_lut, dtype=torch.long) + self.ggid2rdid_lut = ggid2rdid_lut + if True: + mc2coco_lut = list(zip(*sorted([(k, v) for k, v in self.id2cocoidx_lut.items()])))[1] + mc2coco_lut = torch.tensor(mc2coco_lut, dtype=torch.long) + self.mc2coco_lut = mc2coco_lut + + def gglbl2ggid(self, gglbl): + return self.mapper.gglbl2ggid[gglbl] + + def mc2coco(self, mc): + self.mc2coco_lut = self.mc2coco_lut.to(mc.device) + coco = self.mc2coco_lut[mc.long()] + return coco + + def mc2reduced(self, mc, ign2dirt=False): + self.mcid2rdid_lut = self.mcid2rdid_lut.to(mc.device) + reduced = self.mcid2rdid_lut[mc.long()] + if ign2dirt: + reduced[reduced == self.ignore_id] = self.dirt_id + return reduced + + def coco2reduced(self, coco): + self.ggid2rdid_lut = self.ggid2rdid_lut.to(coco.device) + reduced = self.ggid2rdid_lut[coco.long()] + return reduced + + def get_num_reduced_lbls(self): + return self.num_reduced_lbls + + @staticmethod + def uint32_to_4uint8(x): + dt1 = np.dtype(('i4', [('bytes', 'u1', 4)])) + color = x.view(dtype=dt1)['bytes'] + return color + + def mc_color(self, img): + r"""Obtaining Minecraft default color. + + Args: + img (H x W x 1 int32 numpy tensor): Segmentation map. + """ + lut = self.id2color_lut + lut = list(zip(*sorted([(k, v) for k, v in lut.items()])))[1] + lut = np.array(lut, dtype=np.uint32) + rgb = lut[img] + rgb = self.uint32_to_4uint8(rgb)[..., :3] + + return rgb + + +def rand_crop(cam_c, cam_res, target_res): + r"""Produces a new cam_c so that the effect of rendering with the new cam_c and target_res is the same as rendering + with the old parameters and then crop out target_res. + """ + d0 = np.random.randint(cam_res[0] - target_res[0] + 1) + d1 = np.random.randint(cam_res[1] - target_res[1] + 1) + cam_c = [cam_c[0]-d0, cam_c[1]-d1] + return cam_c + + +def segmask_smooth(seg_mask, kernel_size=7): + labels = F.avg_pool2d(seg_mask, kernel_size, 1, kernel_size//2) + onehot_idx = torch.argmax(labels, dim=1, keepdims=True) + labels.fill_(0.0) + labels.scatter_(1, onehot_idx, 1.0) + return labels + + +def colormap(x, cmap='viridis'): + x = np.nan_to_num(x, np.nan, np.nan, np.nan) + x = x - np.nanmin(x) + x = x / np.nanmax(x) + rgb = plt.get_cmap(cmap)(x)[..., :3] + return rgb diff --git a/imaginaire/model_utils/gancraft/reduced_coco_lbls.csv b/imaginaire/model_utils/gancraft/reduced_coco_lbls.csv new file mode 100644 index 0000000000000000000000000000000000000000..c82cc05572bbace78643911e9789f4f2cfd15f0e --- /dev/null +++ b/imaginaire/model_utils/gancraft/reduced_coco_lbls.csv @@ -0,0 +1,12 @@ +ignore +sky +tree +dirt +flower +grass +gravel +water +rock +stone +sand +snow \ No newline at end of file diff --git a/imaginaire/model_utils/gancraft/voxlib/Makefile b/imaginaire/model_utils/gancraft/voxlib/Makefile new file mode 100644 index 0000000000000000000000000000000000000000..de903af09d2feda89118edc743048d70ce963d24 --- /dev/null +++ b/imaginaire/model_utils/gancraft/voxlib/Makefile @@ -0,0 +1,11 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. 
+# To view a copy of this license, check out LICENSE.md + +all: + python setup.py build_ext --inplace + python setup.py install + +clean: + rm -rf *.o *.a *.so test build diff --git a/imaginaire/model_utils/gancraft/voxlib/__init__.py b/imaginaire/model_utils/gancraft/voxlib/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5fce15c92b99ef0ad9feaeb31d75664a0971b385 --- /dev/null +++ b/imaginaire/model_utils/gancraft/voxlib/__init__.py @@ -0,0 +1,7 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +from .positional_encoding import positional_encoding +from .sp_trilinear import sparse_trilinear_interp_worldcoord +from voxlib import ray_voxel_intersection_perspective diff --git a/imaginaire/model_utils/gancraft/voxlib/positional_encoding.py b/imaginaire/model_utils/gancraft/voxlib/positional_encoding.py new file mode 100644 index 0000000000000000000000000000000000000000..ef95d0bd47103233e1c4f32c70dd5b463c9eac1d --- /dev/null +++ b/imaginaire/model_utils/gancraft/voxlib/positional_encoding.py @@ -0,0 +1,63 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +import torch +from torch.autograd import Function +import voxlib + +# Cheatsheet: +# mark_dirty() must be used to mark any input that is modified inplace by the forward function. +# mark_non_differentiable() + + +class PositionalEncodingFunction(Function): + @staticmethod + def forward(ctx, in_feature, pe_degrees, dim, incl_orig): + out_feature = voxlib.positional_encoding(in_feature, pe_degrees, dim, incl_orig) + + ctx.save_for_backward(out_feature) + ctx.pe_degrees = pe_degrees + ctx.dim = dim + ctx.incl_orig = incl_orig + + return out_feature + + @staticmethod + def backward(ctx, out_feature_grad): + out_feature, = ctx.saved_tensors + + # torch::Tensor positional_encoding_backward(const torch::Tensor& out_feature_grad, + # const torch::Tensor& out_feature, int ndegrees, int dim, bool incl_orig) { + in_feature_grad = voxlib.positional_encoding_backward( + out_feature_grad, out_feature, ctx.pe_degrees, ctx.dim, ctx.incl_orig) + + return in_feature_grad, None, None, None + + +def positional_encoding(in_feature, pe_degrees, dim=-1, incl_orig=False): + return PositionalEncodingFunction.apply(in_feature, pe_degrees, dim, incl_orig) + +# input: N, C +# output: N, pe_degrees*C + + +def positional_encoding_pt(pts, pe_degrees, dim=-1, incl_orig=False): + import numpy as np + pe_stor = [] + for i in range(pe_degrees): + pe_stor.append(torch.sin(pts * np.pi * 2 ** i)) + pe_stor.append(torch.cos(pts * np.pi * 2 ** i)) + if incl_orig: + pe_stor.append(pts) + pe = torch.cat(pe_stor, dim=dim) + return pe + + +if __name__ == '__main__': + x = torch.rand(384, 512, 5, 48).cuda() * 1024 + y = positional_encoding_pt(x, 4, incl_orig=True) + y2 = positional_encoding(x, 4, incl_orig=True) + + print(torch.abs(y - y2)) + print(torch.allclose(y, y2, rtol=1e-05, atol=1e-05)) diff --git a/imaginaire/model_utils/gancraft/voxlib/positional_encoding_kernel.cu b/imaginaire/model_utils/gancraft/voxlib/positional_encoding_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..278fc165991d70aae549dd376b979f53703a0647 --- /dev/null +++ b/imaginaire/model_utils/gancraft/voxlib/positional_encoding_kernel.cu @@ 
-0,0 +1,285 @@ +// Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// This work is made available under the Nvidia Source Code License-NC. +// To view a copy of this license, check out LICENSE.md + +#include + +#include +#include +#include +#include + +#include +#include +#include +#include + + +#include +#include +#include + + +#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) + + +struct PE_Params { + int ndegrees; + int pre_size; + int post_size; + bool incl_orig; +}; + +// const int TILE_DIM_X = 16; // channel dim +// const int TILE_DIM_Y = 64; // entry dim +// dim3 dimGrid((p.post_size+TILE_DIM_X-1)/TILE_DIM_X, (p.pre_size+TILE_DIM_Y-1)/TILE_DIM_Y, 1); +// dim3 dimBlock(TILE_DIM_X, TILE_DIM_Y, 1); +template +__global__ void positional_encoding_kernel( + float* __restrict__ out_feature, + const float* __restrict__ in_feature, const PE_Params p) { + + const int idx_feat = blockIdx.x * TILE_DIM_X + threadIdx.x; + const int idx_entry_base = blockIdx.y * TILE_DIM_Y * DUP_Y + threadIdx.y * DUP_Y; + if (idx_feat >= p.post_size) { + return; + } + + int stride = p.ndegrees*2; + if (p.incl_orig) { + stride += 1; + } + + for (int j=0; j= p.pre_size) { + return; + } + float data = in_feature[idx_entry*p.post_size + idx_feat]; + + for (int i=0; i +__global__ void positional_encoding_backward_kernel( + float* __restrict__ in_feature_grad, + const float* __restrict__ out_feature_grad, const float* __restrict__ out_feature, const PE_Params p) { + + int idx_feat = blockIdx.x * TILE_DIM_X + threadIdx.x; + const int idx_entry_base = blockIdx.y * TILE_DIM_Y * DUP_Y + threadIdx.y * DUP_Y; + + if (idx_feat >= p.post_size) { + return; + } + + int stride = p.ndegrees*2; + if (p.incl_orig) { + stride += 1; + } + + for (int j=0; j= p.pre_size) { + return; + } + + float grad = 0.0f; + for (int i=0; i +torch::Tensor positional_encoding_cuda(const torch::Tensor& in_feature, int ndegrees, int dim, bool incl_orig) { + CHECK_CUDA(in_feature); + + int curDevice = -1; + cudaGetDevice(&curDevice); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); + torch::Device device = in_feature.device(); + + assert(in_feature.dtype() == torch::kFloat32); + + // Handle negative index + if (dim < 0) { + dim = in_feature.dim() + dim; + } + assert(dim >= 0 && dim < in_feature.dim()); + + // No need to be contiguous. Input and output has the same memory layout. + CHECK_CONTIGUOUS(in_feature); + + PE_Params p; + p.ndegrees = ndegrees; + p.incl_orig = incl_orig; + + // This only works for contiguous tensors... 
+ int pre_size = 1; + int post_size = 1; + for (int i=0; i out_feature_shape; + for (int i=0; i Each thread handle a single post_size + // Case 2: Concat at the middle (post_size > pre_size) --> Each thread handle + const int TILE_DIM_X = 16; // channel dim + const int TILE_DIM_Y = 64; // entry dim + //const int DUP_Y = 4; // Each thread handle multiple entries to save threads + const int DUP_Y = 8; // DGXA 64 samples per ray @ 256x256 + dim3 dimGrid((p.post_size+TILE_DIM_X-1)/TILE_DIM_X, (p.pre_size+(TILE_DIM_Y*DUP_Y)-1)/(TILE_DIM_Y*DUP_Y), 1); + dim3 dimBlock(TILE_DIM_X, TILE_DIM_Y, 1); + positional_encoding_kernel<<>>( + out_feature.data_ptr(), + in_feature.data_ptr(), p + ); + + THCudaCheck(cudaGetLastError()); + return out_feature; +} + +//in_feature_grad = voxrender_op.positional_encoding_backward(out_feature_grad, out_feature, ctx.pe_degrees, ctx.dim, ctx.incl_orig); +// Input: +// out_feature_grad: float32 [..., N*ndegree*2+incl_orig, ...] +// out_feature: float32 [..., N*ndegree*2+incl_orig, ...] +// ndegrees: int32 Degrees of PE encoding +// dim: int32 Dimension to concatenate +// incl_orig: bool Whether to include original feature vector or not +// Output: +// in_feature_grad: float32 [..., N, ...] +// std::vector +torch::Tensor positional_encoding_backward_cuda(const torch::Tensor& out_feature_grad_, const torch::Tensor& out_feature, int ndegrees, int dim, bool incl_orig) { + CHECK_CUDA(out_feature_grad_); + CHECK_CUDA(out_feature); + + const torch::Tensor out_feature_grad = out_feature_grad_.contiguous(); + + int curDevice = -1; + cudaGetDevice(&curDevice); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); + torch::Device device = out_feature_grad.device(); + + assert(out_feature_grad.dtype() == torch::kFloat32); + assert(out_feature.dtype() == torch::kFloat32); + assert(out_feature_grad.sizes() == out_feature.sizes()); + + // Handle negative index + if (dim < 0) { + dim = out_feature.dim() + dim; + } + assert(dim >= 0 && dim < out_feature.dim()); + + CHECK_CONTIGUOUS(out_feature_grad); + CHECK_CONTIGUOUS(out_feature); + + PE_Params p; + p.ndegrees = ndegrees; + p.incl_orig = incl_orig; + + int expansion_factor = ndegrees*2; + if (incl_orig) { + expansion_factor += 1; + } + // This only works for contiguous tensors... + int pre_size = 1; + int post_size = 1; + for (int i=0; i out_feature_shape; + for (int i=0; i Each thread handle a single post_size + // Case 2: Concat at the middle (post_size > pre_size) --> Each thread handle + const int TILE_DIM_X = 16; // channel dim + const int TILE_DIM_Y = 64; // entry dim + //const int DUP_Y = 4; // Nothing to amortize + const int DUP_Y = 8; // DGXA + dim3 dimGrid((p.post_size+TILE_DIM_X-1)/TILE_DIM_X, (p.pre_size+(TILE_DIM_Y*DUP_Y)-1)/(TILE_DIM_Y*DUP_Y), 1); + dim3 dimBlock(TILE_DIM_X, TILE_DIM_Y, 1); + positional_encoding_backward_kernel<<>>( + in_feature_grad.data_ptr(), + out_feature_grad.data_ptr(), out_feature.data_ptr(), p + ); + + THCudaCheck(cudaGetLastError()); + + return in_feature_grad; +} diff --git a/imaginaire/model_utils/gancraft/voxlib/ray_voxel_intersection.cu b/imaginaire/model_utils/gancraft/voxlib/ray_voxel_intersection.cu new file mode 100644 index 0000000000000000000000000000000000000000..7ef22dc309e2eb6d944c50d917235f0c62219cb6 --- /dev/null +++ b/imaginaire/model_utils/gancraft/voxlib/ray_voxel_intersection.cu @@ -0,0 +1,325 @@ +// Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// This work is made available under the Nvidia Source Code License-NC. 
+// To view a copy of this license, check out LICENSE.md +// +// The ray marching algorithm used in this file is a variety of modified Bresenham method: +// http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.42.3443&rep=rep1&type=pdf +// Search for "voxel traversal algorithm" for related information + +#include + +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +//#include +#include +#include +#include + +#include "voxlib_common.h" + +struct RVIP_Params { + int voxel_dims[3]; + int voxel_strides[3]; + int max_samples; + int img_dims[2]; + // Camera parameters + float cam_ori[3]; + float cam_fwd[3]; + float cam_side[3]; + float cam_up[3]; + float cam_c[2]; + float cam_f; + //unsigned long seed; +}; + +/* + out_voxel_id: torch CUDA int32 [ img_dims[0], img_dims[1], max_samples, 1] + out_depth: torch CUDA float [2, img_dims[0], img_dims[1], max_samples, 1] + out_raydirs: torch CUDA float [ img_dims[0], img_dims[1], 1, 3] + Image coordinates refer to the center of the pixel + [0, 0, 0] at voxel coordinate is at the corner of the corner block (instead of at the center) +*/ +template +static __global__ void ray_voxel_intersection_perspective_kernel(int32_t* __restrict__ out_voxel_id, float* __restrict__ out_depth, float* __restrict__ out_raydirs, +const int32_t* __restrict__ in_voxel, const RVIP_Params p) { + + int img_coords[2]; + img_coords[1] = blockIdx.x*TILE_DIM+threadIdx.x; + img_coords[0] = blockIdx.y*TILE_DIM+threadIdx.y; + if (img_coords[0] >= p.img_dims[0] || img_coords[1] >= p.img_dims[1]) { + return; + } + int pix_index = img_coords[0] * p.img_dims[1] + img_coords[1]; + + // Calculate ray origin and direction + float rayori[3], raydir[3]; + rayori[0] = p.cam_ori[0]; + rayori[1] = p.cam_ori[1]; + rayori[2] = p.cam_ori[2]; + + // Camera intrinsics + float ndc_imcoords[2]; + ndc_imcoords[0] = p.cam_c[0] - (float)img_coords[0]; // Flip height + ndc_imcoords[1] = (float)img_coords[1] - p.cam_c[1]; + + raydir[0] = p.cam_up[0] * ndc_imcoords[0] + p.cam_side[0] * ndc_imcoords[1] + p.cam_fwd[0] * p.cam_f; + raydir[1] = p.cam_up[1] * ndc_imcoords[0] + p.cam_side[1] * ndc_imcoords[1] + p.cam_fwd[1] * p.cam_f; + raydir[2] = p.cam_up[2] * ndc_imcoords[0] + p.cam_side[2] * ndc_imcoords[1] + p.cam_fwd[2] * p.cam_f; + normalize(raydir); + + // Save out_raydirs + out_raydirs[pix_index*3] = raydir[0]; + out_raydirs[pix_index*3+1] = raydir[1]; + out_raydirs[pix_index*3+2] = raydir[2]; + + float axis_t[3]; + int axis_int[3]; + //int axis_intbound[3]; + + // Current voxel + axis_int[0] = floorf(rayori[0]); + axis_int[1] = floorf(rayori[1]); + axis_int[2] = floorf(rayori[2]); + + #pragma unroll + for (int i=0; i<3; i++) { + if (raydir[i] > 0) { + // Initial t value + // Handle boundary case where rayori[i] is a whole number. 
Always round Up for the next block + //axis_t[i] = (ceilf(nextafterf(rayori[i], HUGE_VALF)) - rayori[i]) / raydir[i]; + axis_t[i] = ((float)(axis_int[i]+1) - rayori[i]) / raydir[i]; + } else if (raydir[i] < 0) { + axis_t[i] = ((float)axis_int[i] - rayori[i]) / raydir[i]; + } else { + axis_t[i] = HUGE_VALF; + } + } + + // Fused raymarching and sampling + bool quit = false; + for (int cur_plane=0; cur_plane < p.max_samples; cur_plane++) { // Last cycle is for calculating p2 + float t = nanf("0"); + float t2 = nanf("0"); + int32_t blk_id = 0; + // Find the next intersection + while (!quit) { + // Find the next smallest t + float tnow; + /* + #pragma unroll + for (int i=0; i<3; i++) { + if (axis_t[i] <= axis_t[(i+1)%3] && axis_t[i] <= axis_t[(i+2)%3]) { + // Update current t + tnow = axis_t[i]; + // Update t candidates + if (raydir[i] > 0) { + axis_int[i] += 1; + if (axis_int[i] >= p.voxel_dims[i]) { + quit = true; + } + axis_t[i] = ((float)(axis_int[i]+1) - rayori[i]) / raydir[i]; + } else { + axis_int[i] -= 1; + if (axis_int[i] < 0) { + quit = true; + } + axis_t[i] = ((float)axis_int[i] - rayori[i]) / raydir[i]; + } + break; // Avoid advancing multiple steps as axis_t is updated + } + } + */ + // Hand unroll + if (axis_t[0] <= axis_t[1] && axis_t[0] <= axis_t[2]) { + // Update current t + tnow = axis_t[0]; + // Update t candidates + if (raydir[0] > 0) { + axis_int[0] += 1; + if (axis_int[0] >= p.voxel_dims[0]) { + quit = true; + } + axis_t[0] = ((float)(axis_int[0]+1) - rayori[0]) / raydir[0]; + } else { + axis_int[0] -= 1; + if (axis_int[0] < 0) { + quit = true; + } + axis_t[0] = ((float)axis_int[0] - rayori[0]) / raydir[0]; + } + } else if (axis_t[1] <= axis_t[2]) { + tnow = axis_t[1]; + if (raydir[1] > 0) { + axis_int[1] += 1; + if (axis_int[1] >= p.voxel_dims[1]) { + quit = true; + } + axis_t[1] = ((float)(axis_int[1]+1) - rayori[1]) / raydir[1]; + } else { + axis_int[1] -= 1; + if (axis_int[1] < 0) { + quit = true; + } + axis_t[1] = ((float)axis_int[1] - rayori[1]) / raydir[1]; + } + } else { + tnow = axis_t[2]; + if (raydir[2] > 0) { + axis_int[2] += 1; + if (axis_int[2] >= p.voxel_dims[2]) { + quit = true; + } + axis_t[2] = ((float)(axis_int[2]+1) - rayori[2]) / raydir[2]; + } else { + axis_int[2] -= 1; + if (axis_int[2] < 0) { + quit = true; + } + axis_t[2] = ((float)axis_int[2] - rayori[2]) / raydir[2]; + } + } + + if (quit) { + break; + } + + // Skip empty space + // Could there be deadlock if the ray direction is away from the world? 
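+ // Note on the question above: the marching loop still terminates, because each step advances one
+ // axis index monotonically along the ray, and `quit` is set as soon as that index leaves the voxel bounds.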
+ if (axis_int[0] < 0 || axis_int[0] >= p.voxel_dims[0] || axis_int[1] < 0 || axis_int[1] >= p.voxel_dims[1] || axis_int[2] < 0 || axis_int[2] >= p.voxel_dims[2]) { + continue; + } + + // Test intersection using voxel grid + blk_id = in_voxel[axis_int[0]*p.voxel_strides[0] + axis_int[1]*p.voxel_strides[1] + axis_int[2]*p.voxel_strides[2]]; + if (blk_id == 0) { + continue; + } + + // Now that there is an intersection + t = tnow; + // Calculate t2 + /* + #pragma unroll + for (int i=0; i<3; i++) { + if (axis_t[i] <= axis_t[(i+1)%3] && axis_t[i] <= axis_t[(i+2)%3]) { + t2 = axis_t[i]; + break; + } + } + */ + // Hand unroll + if (axis_t[0] <= axis_t[1] && axis_t[0] <= axis_t[2]) { + t2 = axis_t[0]; + } else if (axis_t[1] <= axis_t[2]) { + t2 = axis_t[1]; + } else { + t2 = axis_t[2]; + } + break; + } // while !quit (ray marching loop) + + out_depth[pix_index*p.max_samples+cur_plane] = t; + out_depth[p.img_dims[0]*p.img_dims[1]*p.max_samples + pix_index*p.max_samples+cur_plane] = t2; + out_voxel_id[pix_index*p.max_samples+cur_plane] = blk_id; + } // cur_plane +} + + +/* + out: + out_voxel_id: torch CUDA int32 [ img_dims[0], img_dims[1], max_samples, 1] + out_depth: torch CUDA float [2, img_dims[0], img_dims[1], max_samples, 1] + out_raydirs: torch CUDA float [ img_dims[0], img_dims[1], 1, 3] + in: + in_voxel: torch CUDA int32 [X, Y, Z] [40, 512, 512] + cam_ori: torch float [3] + cam_dir: torch float [3] + cam_up: torch float [3] + cam_f: float + cam_c: int [2] + img_dims: int [2] + max_samples: int +*/ +std::vector ray_voxel_intersection_perspective_cuda(const torch::Tensor& in_voxel, const torch::Tensor& cam_ori, const torch::Tensor& cam_dir, const torch::Tensor& cam_up, float cam_f, const std::vector& cam_c, const std::vector& img_dims, int max_samples) { + CHECK_CUDA(in_voxel); + + int curDevice = -1; + cudaGetDevice(&curDevice); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); + torch::Device device = in_voxel.device(); + + //assert(in_voxel.dtype() == torch::kU8); + assert(in_voxel.dtype() == torch::kInt32); // Minecraft compatibility + assert(in_voxel.dim() == 3); + assert(cam_ori.dtype() == torch::kFloat32); + assert(cam_ori.numel() == 3); + assert(cam_dir.dtype() == torch::kFloat32); + assert(cam_dir.numel() == 3); + assert(cam_up.dtype() == torch::kFloat32); + assert(cam_up.numel() == 3); + assert(img_dims.size() == 2); + + RVIP_Params p; + + // Calculate camera rays + const torch::Tensor cam_ori_c = cam_ori.cpu(); + const torch::Tensor cam_dir_c = cam_dir.cpu(); + const torch::Tensor cam_up_c = cam_up.cpu(); + + // Get the coordinate frame of camera space in world space + normalize(p.cam_fwd, cam_dir_c.data_ptr()); + cross(p.cam_side, p.cam_fwd, cam_up_c.data_ptr()); + normalize(p.cam_side); + cross(p.cam_up, p.cam_side, p.cam_fwd); + normalize(p.cam_up); // Not absolutely necessary as both vectors are normalized. But just in case... 
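+ // At this point cam_fwd / cam_side / cam_up form an orthonormal camera basis; the kernel combines
+ // them with each pixel's offset from cam_c and the focal length cam_f to build per-pixel ray directions.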
+ + copyarr(p.cam_ori, cam_ori_c.data_ptr()); + + p.cam_f = cam_f; + p.cam_c[0] = cam_c[0]; + p.cam_c[1] = cam_c[1]; + p.max_samples = max_samples; + //printf("[Renderer] max_dist: %ld\n", max_dist); + + p.voxel_dims[0] = in_voxel.size(0); + p.voxel_dims[1] = in_voxel.size(1); + p.voxel_dims[2] = in_voxel.size(2); + p.voxel_strides[0] = in_voxel.stride(0); + p.voxel_strides[1] = in_voxel.stride(1); + p.voxel_strides[2] = in_voxel.stride(2); + + //printf("[Renderer] Voxel resolution: %ld, %ld, %ld\n", p.voxel_dims[0], p.voxel_dims[1], p.voxel_dims[2]); + + p.img_dims[0] = img_dims[0]; + p.img_dims[1] = img_dims[1]; + + // Create output tensors + // For Minecraft Seg Mask + torch::Tensor out_voxel_id = torch::empty({p.img_dims[0], p.img_dims[1], p.max_samples, 1}, torch::TensorOptions().dtype(torch::kInt32).device(device)); + + torch::Tensor out_depth; + // Produce two sets of localcoords, one for entry point, the other one for exit point. They share the same corner_ids. + out_depth = torch::empty({2, p.img_dims[0], p.img_dims[1], p.max_samples, 1}, torch::TensorOptions().dtype(torch::kFloat32).device(device)); + + torch::Tensor out_raydirs = torch::empty({p.img_dims[0], p.img_dims[1], 1, 3}, torch::TensorOptions().dtype(torch::kFloat32).device(device).requires_grad(false)); + + const int TILE_DIM = 8; + dim3 dimGrid((p.img_dims[1]+TILE_DIM-1)/TILE_DIM, (p.img_dims[0]+TILE_DIM-1)/TILE_DIM, 1); + dim3 dimBlock(TILE_DIM, TILE_DIM, 1); + + ray_voxel_intersection_perspective_kernel<<>>( + out_voxel_id.data_ptr(), out_depth.data_ptr(), out_raydirs.data_ptr(), in_voxel.data_ptr(), p + ); + + return {out_voxel_id, out_depth, out_raydirs}; +} diff --git a/imaginaire/model_utils/gancraft/voxlib/setup.py b/imaginaire/model_utils/gancraft/voxlib/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..1eca848211370c8ddf9dda55d5c67804a73061e9 --- /dev/null +++ b/imaginaire/model_utils/gancraft/voxlib/setup.py @@ -0,0 +1,25 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + +cxx_args = ['-fopenmp'] +nvcc_args = [] + +setup( + name='voxrender', + ext_modules=[ + CUDAExtension('voxlib', [ + 'voxlib.cpp', + 'ray_voxel_intersection.cu', + 'sp_trilinear_worldcoord_kernel.cu', + 'positional_encoding_kernel.cu' + ], + extra_compile_args={'cxx': cxx_args, 'nvcc': nvcc_args} + ) + ], + cmdclass={ + 'build_ext': BuildExtension + }) diff --git a/imaginaire/model_utils/gancraft/voxlib/sp_trilinear.py b/imaginaire/model_utils/gancraft/voxlib/sp_trilinear.py new file mode 100644 index 0000000000000000000000000000000000000000..1bad56fb23f6b8e2a8e41573f8b6b85f9a5693f1 --- /dev/null +++ b/imaginaire/model_utils/gancraft/voxlib/sp_trilinear.py @@ -0,0 +1,35 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +from torch.autograd import Function +import voxlib + +""" +It takes world coordinate as input instead of block-local coordinate. Corner IDs are looked up on-the-fly to +save memory. 
+""" + + +class SparseTrilinearWorldCoordFunction(Function): + @staticmethod + def forward(ctx, in_feature, corner_lut_t, in_worldcoord, ign_zero): + + out_feature = voxlib.sp_trilinear_worldcoord(in_feature, corner_lut_t, in_worldcoord, ign_zero, -1) + ctx.ign_zero = ign_zero + ctx.save_for_backward(in_feature, corner_lut_t, in_worldcoord) + + return out_feature + + @staticmethod + def backward(ctx, out_feature_grad): + in_feature, corner_lut_t, in_worldcoord = ctx.saved_tensors + + assert ctx.needs_input_grad[2] is False + in_feature_grad, = voxlib.sp_trilinear_worldcoord_backward( + out_feature_grad, in_feature, corner_lut_t, in_worldcoord, ctx.ign_zero, False) + return in_feature_grad, None, None, None, None + + +def sparse_trilinear_interp_worldcoord(in_feature, corner_lut_t, in_worldcoord, ign_zero=False): + return SparseTrilinearWorldCoordFunction.apply(in_feature, corner_lut_t, in_worldcoord, ign_zero) diff --git a/imaginaire/model_utils/gancraft/voxlib/sp_trilinear_worldcoord_kernel.cu b/imaginaire/model_utils/gancraft/voxlib/sp_trilinear_worldcoord_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..403d01fbda3f528211dd262d019a606f1f8f1640 --- /dev/null +++ b/imaginaire/model_utils/gancraft/voxlib/sp_trilinear_worldcoord_kernel.cu @@ -0,0 +1,527 @@ +// Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// This work is made available under the Nvidia Source Code License-NC. +// To view a copy of this license, check out LICENSE.md +// +// Fast routine for sparse tri-linear interpolation of high dimensional features. +// Ignore label is supported. + + +#include + +#include +#include +#include +#include + +#include +#include +#include + + +#include +#include +#include + + +#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) + + +struct SpTrilinear_wc_Params { + int in_feature_dim; + int in_feature_numentries; + int corner_lut_dims[3]; + int corner_lut_strides[3]; + int in_worldcoord_dims[8]; + int in_worldcoord_strides[8]; + int in_worldcoord_ndim; + int out_feature_dims[8]; + int out_feature_strides[8]; + bool ign_zero; +}; + + +// out_feature.data_ptr(), +// in_feature.data_ptr(), corner_lut_t.data_ptr(), in_worldcoord.data_ptr(), p +template +__global__ void sp_trilinear_worldcoord_kernel( + float* __restrict__ out_feature, + const float* __restrict__ in_feature, const int32_t* __restrict__ corner_lut_t, const float* __restrict__ in_worldcoord, SpTrilinear_wc_Params p) { + + const int GRID_X = gridDim.y; + int idx_entry = blockIdx.x * TILE_DIM_Y + threadIdx.y; + + // Index processing + //int index[7]; + int t = idx_entry; + int idx_in_worldcoord = 0; + int idx_out_feature = 0; + for (int i=p.in_worldcoord_ndim-2; i>=0; i--) { + int idx_t = t % p.in_worldcoord_dims[i]; + t = t / p.in_worldcoord_dims[i]; + idx_in_worldcoord += p.in_worldcoord_strides[i] * idx_t; + idx_out_feature += p.out_feature_strides[i] * idx_t; + } + if (t > 0) { + return; + } + int stride_in_worldcoord = p.in_worldcoord_strides[p.in_worldcoord_ndim-1]; + int stride_out_feature = p.out_feature_strides[p.in_worldcoord_ndim-1]; + + + float world_coords[3]; + world_coords[0] = in_worldcoord[idx_in_worldcoord]; + world_coords[1] = in_worldcoord[idx_in_worldcoord+stride_in_worldcoord]; + world_coords[2] = in_worldcoord[idx_in_worldcoord+stride_in_worldcoord*2]; + + float local_coords[3]; + int 
vox_coords[3]; + local_coords[0] = world_coords[0] - floorf(world_coords[0]); + vox_coords[0] = (int)floorf(world_coords[0]); + local_coords[1] = world_coords[1] - floorf(world_coords[1]); + vox_coords[1] = (int)floorf(world_coords[1]); + local_coords[2] = world_coords[2] - floorf(world_coords[2]); + vox_coords[2] = (int)floorf(world_coords[2]); + + float interp_weight[8]; + // 0,0,0 + interp_weight[0] = (1.0f-local_coords[0])*(1.0f-local_coords[1])*(1.0f-local_coords[2]); + // 0,0,1 + interp_weight[1] = (1.0f-local_coords[0])*(1.0f-local_coords[1])*(local_coords[2]); + // 0,1,0 + interp_weight[2] = (1.0f-local_coords[0])*(local_coords[1])*(1.0f-local_coords[2]); + // 0,1,1 + interp_weight[3] = (1.0f-local_coords[0])*(local_coords[1])*(local_coords[2]); + // 1,0,0 + interp_weight[4] = (local_coords[0])*(1.0f-local_coords[1])*(1.0f-local_coords[2]); + // 1,0,1 + interp_weight[5] = (local_coords[0])*(1.0f-local_coords[1])*(local_coords[2]); + // 1,1,0 + interp_weight[6] = (local_coords[0])*(local_coords[1])*(1.0f-local_coords[2]); + // 1,1,1 + interp_weight[7] = (local_coords[0])*(local_coords[1])*(local_coords[2]); + + int indices[8]; + // Hard boundary check (zero padding) + if (isnan(world_coords[0]) || isnan(world_coords[1]) || isnan(world_coords[2])) { + indices[0] = -1;indices[1] = -1;indices[2] = -1;indices[3] = -1; + indices[4] = -1;indices[5] = -1;indices[6] = -1;indices[7] = -1; + } else { + // Clamp to boundaries + int vox_coords_1[3]; + vox_coords_1[0] = min(max(vox_coords[0]+1, 0), p.corner_lut_dims[0]-1); + vox_coords_1[1] = min(max(vox_coords[1]+1, 0), p.corner_lut_dims[1]-1); + vox_coords_1[2] = min(max(vox_coords[2]+1, 0), p.corner_lut_dims[2]-1); + vox_coords[0] = min(max(vox_coords[0], 0), p.corner_lut_dims[0]-1); + vox_coords[1] = min(max(vox_coords[1], 0), p.corner_lut_dims[1]-1); + vox_coords[2] = min(max(vox_coords[2], 0), p.corner_lut_dims[2]-1); + int idx_corner_lut; + // 000 + idx_corner_lut = p.corner_lut_strides[0] * vox_coords[0] + + p.corner_lut_strides[1] * vox_coords[1] + + p.corner_lut_strides[2] * vox_coords[2]; + indices[0] = corner_lut_t[idx_corner_lut]; + // 001 + idx_corner_lut = p.corner_lut_strides[0] * vox_coords[0] + + p.corner_lut_strides[1] * vox_coords[1] + + p.corner_lut_strides[2] * vox_coords_1[2]; + indices[1] = corner_lut_t[idx_corner_lut]; + // 010 + idx_corner_lut = p.corner_lut_strides[0] * vox_coords[0] + + p.corner_lut_strides[1] * vox_coords_1[1] + + p.corner_lut_strides[2] * vox_coords[2]; + indices[2] = corner_lut_t[idx_corner_lut]; + // 011 + idx_corner_lut = p.corner_lut_strides[0] * vox_coords[0] + + p.corner_lut_strides[1] * vox_coords_1[1] + + p.corner_lut_strides[2] * vox_coords_1[2]; + indices[3] = corner_lut_t[idx_corner_lut]; + // 100 + idx_corner_lut = p.corner_lut_strides[0] * vox_coords_1[0] + + p.corner_lut_strides[1] * vox_coords[1] + + p.corner_lut_strides[2] * vox_coords[2]; + indices[4] = corner_lut_t[idx_corner_lut]; + // 101 + idx_corner_lut = p.corner_lut_strides[0] * vox_coords_1[0] + + p.corner_lut_strides[1] * vox_coords[1] + + p.corner_lut_strides[2] * vox_coords_1[2]; + indices[5] = corner_lut_t[idx_corner_lut]; + // 110 + idx_corner_lut = p.corner_lut_strides[0] * vox_coords_1[0] + + p.corner_lut_strides[1] * vox_coords_1[1] + + p.corner_lut_strides[2] * vox_coords[2]; + indices[6] = corner_lut_t[idx_corner_lut]; + // 111 + idx_corner_lut = p.corner_lut_strides[0] * vox_coords_1[0] + + p.corner_lut_strides[1] * vox_coords_1[1] + + p.corner_lut_strides[2] * vox_coords_1[2]; + indices[7] = 
corner_lut_t[idx_corner_lut]; + } + + if (p.ign_zero) { + // Zero indices are to be ignored +#pragma unroll + for (int i=0; i<8; i++) { + indices[i] -= 1; + } + } + + //int idx_feat = blockIdx.x * TILE_DIM_X * DUP_X + threadIdx.x; + int idx_feat = blockIdx.y * TILE_DIM_X + threadIdx.x; + for (int i=0; i= p.in_feature_dim) { + return; + } + float interp_feat = 0.0f; +#pragma unroll + for (int j=0; j<8; j++) { + if (indices[j] >= 0) { + interp_feat = fmaf(in_feature[indices[j]*p.in_feature_dim+idx_feat], interp_weight[j], interp_feat); + } + } + //out_feature[idx_entry*p.in_feature_dim+idx_feat] = interp_feat; + out_feature[idx_out_feature+stride_out_feature*idx_feat] = interp_feat; + //idx_feat += TILE_DIM_X; + idx_feat += TILE_DIM_X * GRID_X; + } +} + + +//sp_trilinear_worldcoord_backward2feature_kernel<<>>( +// in_feature_grad.data_ptr(), +// out_feature_grad.data_ptr(), in_feature.data_ptr(), in_corner_lut.data_ptr(), in_worldcoord.data_ptr(), p +// Backward to feature +template +__global__ void sp_trilinear_worldcoord_backward2feature_kernel( + float* __restrict__ in_feature_grad, + const float* __restrict__ out_feature_grad, const int32_t* __restrict__ corner_lut_t, const float* __restrict__ in_worldcoord, SpTrilinear_wc_Params p) { + + const int GRID_X = gridDim.x; + int idx_entry = blockIdx.y * TILE_DIM_Y + threadIdx.y; + + // Index processing + //int index[7]; + int t = idx_entry; + int idx_in_worldcoord = 0; + int idx_out_feature = 0; + for (int i=p.in_worldcoord_ndim-2; i>=0; i--) { + int idx_t = t % p.in_worldcoord_dims[i]; + t = t / p.in_worldcoord_dims[i]; + //index[i] = idx_t; + idx_in_worldcoord += p.in_worldcoord_strides[i] * idx_t; + idx_out_feature += p.out_feature_strides[i] * idx_t; + } + if (t > 0) { + return; + } + int stride_in_worldcoord = p.in_worldcoord_strides[p.in_worldcoord_ndim-1]; + int stride_out_feature = p.out_feature_strides[p.in_worldcoord_ndim-1]; + + float world_coords[3]; + world_coords[0] = in_worldcoord[idx_in_worldcoord]; + world_coords[1] = in_worldcoord[idx_in_worldcoord+stride_in_worldcoord]; + world_coords[2] = in_worldcoord[idx_in_worldcoord+stride_in_worldcoord*2]; + + float local_coords[3]; + int vox_coords[3]; + local_coords[0] = world_coords[0] - floorf(world_coords[0]); + vox_coords[0] = (int)floorf(world_coords[0]); + local_coords[1] = world_coords[1] - floorf(world_coords[1]); + vox_coords[1] = (int)floorf(world_coords[1]); + local_coords[2] = world_coords[2] - floorf(world_coords[2]); + vox_coords[2] = (int)floorf(world_coords[2]); + + float interp_weight[8]; + // 0,0,0 + interp_weight[0] = (1.0f-local_coords[0])*(1.0f-local_coords[1])*(1.0f-local_coords[2]); + // 0,0,1 + interp_weight[1] = (1.0f-local_coords[0])*(1.0f-local_coords[1])*(local_coords[2]); + // 0,1,0 + interp_weight[2] = (1.0f-local_coords[0])*(local_coords[1])*(1.0f-local_coords[2]); + // 0,1,1 + interp_weight[3] = (1.0f-local_coords[0])*(local_coords[1])*(local_coords[2]); + // 1,0,0 + interp_weight[4] = (local_coords[0])*(1.0f-local_coords[1])*(1.0f-local_coords[2]); + // 1,0,1 + interp_weight[5] = (local_coords[0])*(1.0f-local_coords[1])*(local_coords[2]); + // 1,1,0 + interp_weight[6] = (local_coords[0])*(local_coords[1])*(1.0f-local_coords[2]); + // 1,1,1 + interp_weight[7] = (local_coords[0])*(local_coords[1])*(local_coords[2]); + + int indices[8]; + // Hard boundary check (zero padding) + if (isnan(world_coords[0]) || isnan(world_coords[1]) || isnan(world_coords[2])) {// || + //vox_coords[0] < 0 || vox_coords[0] >= (p.corner_lut_dims[0]-1) || + //vox_coords[1] < 0 
|| vox_coords[1] >= (p.corner_lut_dims[1]-1) || + //vox_coords[2] < 0 || vox_coords[2] >= (p.corner_lut_dims[2]-1)) { + indices[0] = -1;indices[1] = -1;indices[2] = -1;indices[3] = -1; + indices[4] = -1;indices[5] = -1;indices[6] = -1;indices[7] = -1; + } else { + // Clamp to boundaries + int vox_coords_1[3]; + vox_coords_1[0] = min(max(vox_coords[0]+1, 0), p.corner_lut_dims[0]-1); + vox_coords_1[1] = min(max(vox_coords[1]+1, 0), p.corner_lut_dims[1]-1); + vox_coords_1[2] = min(max(vox_coords[2]+1, 0), p.corner_lut_dims[2]-1); + vox_coords[0] = min(max(vox_coords[0], 0), p.corner_lut_dims[0]-1); + vox_coords[1] = min(max(vox_coords[1], 0), p.corner_lut_dims[1]-1); + vox_coords[2] = min(max(vox_coords[2], 0), p.corner_lut_dims[2]-1); + int idx_corner_lut; + // 000 + idx_corner_lut = p.corner_lut_strides[0] * vox_coords[0] + + p.corner_lut_strides[1] * vox_coords[1] + + p.corner_lut_strides[2] * vox_coords[2]; + indices[0] = corner_lut_t[idx_corner_lut]; + // 001 + idx_corner_lut = p.corner_lut_strides[0] * vox_coords[0] + + p.corner_lut_strides[1] * vox_coords[1] + + p.corner_lut_strides[2] * vox_coords_1[2]; + indices[1] = corner_lut_t[idx_corner_lut]; + // 010 + idx_corner_lut = p.corner_lut_strides[0] * vox_coords[0] + + p.corner_lut_strides[1] * vox_coords_1[1] + + p.corner_lut_strides[2] * vox_coords[2]; + indices[2] = corner_lut_t[idx_corner_lut]; + // 011 + idx_corner_lut = p.corner_lut_strides[0] * vox_coords[0] + + p.corner_lut_strides[1] * vox_coords_1[1] + + p.corner_lut_strides[2] * vox_coords_1[2]; + indices[3] = corner_lut_t[idx_corner_lut]; + // 100 + idx_corner_lut = p.corner_lut_strides[0] * vox_coords_1[0] + + p.corner_lut_strides[1] * vox_coords[1] + + p.corner_lut_strides[2] * vox_coords[2]; + indices[4] = corner_lut_t[idx_corner_lut]; + // 101 + idx_corner_lut = p.corner_lut_strides[0] * vox_coords_1[0] + + p.corner_lut_strides[1] * vox_coords[1] + + p.corner_lut_strides[2] * vox_coords_1[2]; + indices[5] = corner_lut_t[idx_corner_lut]; + // 110 + idx_corner_lut = p.corner_lut_strides[0] * vox_coords_1[0] + + p.corner_lut_strides[1] * vox_coords_1[1] + + p.corner_lut_strides[2] * vox_coords[2]; + indices[6] = corner_lut_t[idx_corner_lut]; + // 111 + idx_corner_lut = p.corner_lut_strides[0] * vox_coords_1[0] + + p.corner_lut_strides[1] * vox_coords_1[1] + + p.corner_lut_strides[2] * vox_coords_1[2]; + indices[7] = corner_lut_t[idx_corner_lut]; + } + + if (p.ign_zero) { +#pragma unroll + for (int i=0; i<8; i++) { + indices[i] -= 1; + } + } + + //int idx_feat = blockIdx.x * TILE_DIM_X * DUP_X + threadIdx.x; + int idx_feat = blockIdx.x * TILE_DIM_X + threadIdx.x; + for (int i=0; i= p.in_feature_dim) { + return; + } + float grad = out_feature_grad[idx_out_feature+stride_out_feature*idx_feat]; +#pragma unroll + for (int j=0; j<8; j++) { + if (indices[j] >= 0) { + //indices[j]*p.in_feature_dim+idx_feat + atomicAdd(&in_feature_grad[indices[j]*p.in_feature_dim+idx_feat], grad * interp_weight[j]); + } + } + //idx_feat += TILE_DIM_X; + idx_feat += TILE_DIM_X * GRID_X; + } +} + +// in_feature, corner_lut_t, in_world_coord, ign_zero=False +// Input: +// in_feature: float32 [M C] +// in_corner_lut: int32 [X Y Z] +// in_worldcoord: float32 [..., 3] +// ---Index: int32 [..., 8], containing [0, M]. 0 is ignore label. +// ---Coord: float32 [..., 3] +// Output: +// Interp. 
Feat: float32 [..., C] +// std::vector +torch::Tensor sp_trilinear_worldcoord_cuda(const torch::Tensor& in_feature, const torch::Tensor& in_corner_lut, const torch::Tensor& in_worldcoord, bool ign_zero, int channel_pos) { + CHECK_CUDA(in_feature); + CHECK_CUDA(in_corner_lut); + CHECK_CUDA(in_worldcoord); + + int curDevice = -1; + cudaGetDevice(&curDevice); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); + torch::Device device = in_feature.device(); + + // assert(tensor.sizes() == std::vector{3, 4, 5}); + assert(in_feature.dtype() == torch::kFloat32); + assert(in_feature.dim() == 2); + assert(in_corner_lut.dtype() == torch::kInt32); + assert(in_corner_lut.dim() == 3); + assert(in_worldcoord.dtype() == torch::kFloat32); + assert(in_worldcoord.size(-1) == 3); + assert(in_worldcoord.dim() <= 8); + + CHECK_CONTIGUOUS(in_feature); + //CHECK_CONTIGUOUS(in_corner_lut); // Will still run correctly, but performance will suffer. + //CHECK_CONTIGUOUS(in_worldcoord); + + //int channel_pos = -1; // -1 for HWC, -3 for CHW + if (channel_pos < 0) { + channel_pos += in_worldcoord.dim(); + } + assert(channel_pos >= 0 && channel_pos < in_worldcoord.dim()); + + SpTrilinear_wc_Params p; + p.in_feature_dim = in_feature.size(1); + p.in_feature_numentries = in_feature.size(0); + p.in_worldcoord_ndim = in_worldcoord.dim(); + for (int i=0; i out_feature_shape; + //if (channel_first) { // Channel First format, suitable for 2D convolution + // //assert(false); + for (int i=0; i<<>>( + out_feature.data_ptr(), + in_feature.data_ptr(), in_corner_lut.data_ptr(), in_worldcoord.data_ptr(), p + ); + THCudaCheck(cudaGetLastError()); + return out_feature; +} + + +// Backward function for sparse trilinear interpolation +// Input: +// out_feature_grad: float32 [..., C] +// in_feature: float32 [M, C] +// in_corner_lut: int32 [X Y Z] +// ---in_index: int32 [..., 8], containing [0, M]. 0 is ignore label. +// in_worldcoord: float32 [..., 3] +// ign_zero: bool +// need_coord_grad: bool +// Output: +// in_feature_grad: float32 [M, C] +// in_coord_grad: float32 [..., 3] +std::vector sp_trilinear_worldcoord_backward_cuda(const torch::Tensor& out_feature_grad , const torch::Tensor& in_feature, const torch::Tensor& in_corner_lut, const torch::Tensor& in_worldcoord, bool ign_zero, bool need_coord_grad) { + assert(need_coord_grad == false); + CHECK_CUDA(out_feature_grad); + CHECK_CUDA(in_feature); + CHECK_CUDA(in_corner_lut); + CHECK_CUDA(in_worldcoord); + + int curDevice = -1; + cudaGetDevice(&curDevice); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); + torch::Device device = out_feature_grad.device(); + + //for (int i=0; i{3, 4, 5}); + assert(out_feature_grad.dtype() == torch::kFloat32); + for (int i=0; i<<>>( + in_feature_grad.data_ptr(), + out_feature_grad.data_ptr(), in_corner_lut.data_ptr(), in_worldcoord.data_ptr(), p + ); + } + + THCudaCheck(cudaGetLastError()); + return {in_feature_grad}; +} diff --git a/imaginaire/model_utils/gancraft/voxlib/voxlib.cpp b/imaginaire/model_utils/gancraft/voxlib/voxlib.cpp new file mode 100644 index 0000000000000000000000000000000000000000..70095052d71f53e5e519a5f57b3c0848998a1b22 --- /dev/null +++ b/imaginaire/model_utils/gancraft/voxlib/voxlib.cpp @@ -0,0 +1,31 @@ +// Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// This work is made available under the Nvidia Source Code License-NC. 
+// To view a copy of this license, check out LICENSE.md +#include +#include +#include +#include + +// Fast voxel traversal along rays +std::vector ray_voxel_intersection_perspective_cuda(const torch::Tensor& in_voxel, const torch::Tensor& cam_ori, const torch::Tensor& cam_dir, const torch::Tensor& cam_up, float cam_f, const std::vector& cam_c, const std::vector& img_dims, int max_samples); + + +// World Coordinate Sparse Trilinear Interpolation +torch::Tensor sp_trilinear_worldcoord_cuda(const torch::Tensor& in_feature, const torch::Tensor& in_corner_lut, const torch::Tensor& in_worldcoord, bool ign_zero, int channel_pos); + +std::vector sp_trilinear_worldcoord_backward_cuda(const torch::Tensor& out_feature_grad , const torch::Tensor& in_feature, const torch::Tensor& in_corner_lut, const torch::Tensor& in_worldcoord, bool ign_zero, bool need_coord_grad); + +// Fast & Memory Efficient Positional Encoding +torch::Tensor positional_encoding_cuda(const torch::Tensor& in_feature, int ndegrees, int dim, bool incl_orig); + +torch::Tensor positional_encoding_backward_cuda(const torch::Tensor& out_feature_grad, const torch::Tensor& out_feature, int ndegrees, int dim, bool incl_orig); + + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("ray_voxel_intersection_perspective", &ray_voxel_intersection_perspective_cuda, "Ray-voxel intersections given perspective camera parameters (CUDA)"); + m.def("sp_trilinear_worldcoord", &sp_trilinear_worldcoord_cuda, "Sparse Trilinear interpolation, world coordinate [forward] (CUDA)"); + m.def("sp_trilinear_worldcoord_backward", &sp_trilinear_worldcoord_backward_cuda, "Sparse Trilinear interpolation, world coordinate [backward] (CUDA)"); + m.def("positional_encoding", &positional_encoding_cuda, "Fused Positional Encoding [forward] (CUDA)"); + m.def("positional_encoding_backward", &positional_encoding_backward_cuda, "Fused Positional Encoding [backward] (CUDA)"); +} \ No newline at end of file diff --git a/imaginaire/model_utils/gancraft/voxlib/voxlib_common.h b/imaginaire/model_utils/gancraft/voxlib/voxlib_common.h new file mode 100644 index 0000000000000000000000000000000000000000..46b47fc80ecf802347607395ff04565732a4ee87 --- /dev/null +++ b/imaginaire/model_utils/gancraft/voxlib/voxlib_common.h @@ -0,0 +1,76 @@ +// Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// This work is made available under the Nvidia Source Code License-NC. 
+// To view a copy of this license, check out LICENSE.md +#ifndef VOXLIB_COMMON_H +#define VOXLIB_COMMON_H + +#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) +#define CHECK_CPU(x) TORCH_CHECK(x.device().is_cpu(), #x " must be a CPU tensor") + +#include +#include +// CUDA vector math functions +__host__ __device__ __forceinline__ int floor_div(int a, int b) { + int c = a / b; + + if (c * b > a) { + c--; + } + + return c; +} + +template +__host__ __forceinline__ void cross(scalar_t* r, const scalar_t* a, const scalar_t* b) { + r[0] = a[1]*b[2] - a[2]*b[1]; + r[1] = a[2]*b[0] - a[0]*b[2]; + r[2] = a[0]*b[1] - a[1]*b[0]; +} + +__device__ __host__ __forceinline__ float dot(const float* a, const float* b) { + return a[0] * b[0] + a[1] * b[1] + a[2] * b[2]; +} + +template +__device__ __host__ __forceinline__ void copyarr(scalar_t* r, const scalar_t* a) { + #pragma unroll + for (int i=0; i +__device__ __host__ __forceinline__ void normalize(scalar_t* a) { + scalar_t vec_len=0.0f; + #pragma unroll + for (int i=0; i +__device__ __host__ __forceinline__ void normalize(scalar_t* r, const scalar_t* a) { + scalar_t vec_len=0.0f; + #pragma unroll + for (int i=0; i= data_type_num_classes] = data_type_num_classes + data[data_type] = label_map / 255.0 + return data + + +def _encode_onehot(label_map, num_classes, use_dont_care): + r"""Make input one-hot. + + Args: + label_map (torch.Tensor): (C, H, W) tensor containing indices. + num_classes (int): Number of labels to expand tensor to. + use_dont_care (bool): Use the dont care label or not? + Returns: + output (torch.Tensor): (num_classes, H, W) one-hot tensor. + """ + # All labels lie in [0. num_classes - 1]. + # Encode dont care as num_classes. + label_map[label_map < 0] = num_classes + label_map[label_map >= num_classes] = num_classes + + size = label_map.size() + output_size = (num_classes + 1, size[-2], size[-1]) + output = torch.zeros(*output_size) + if label_map.dim() == 4: + output = output.unsqueeze(0).repeat(label_map.size(0), 1, 1, 1) + output = output.scatter_(1, label_map.data.long(), 1.0) + if not use_dont_care: + output = output[:, :num_classes, ...] + else: + output = output.scatter_(0, label_map.data.long(), 1.0) + if not use_dont_care: + output = output[:num_classes, ...] + return output diff --git a/imaginaire/model_utils/pix2pixHD.py b/imaginaire/model_utils/pix2pixHD.py new file mode 100644 index 0000000000000000000000000000000000000000..862eb591a665e3ea81b3f918696feac3640a6c94 --- /dev/null +++ b/imaginaire/model_utils/pix2pixHD.py @@ -0,0 +1,227 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +"""Utils for the pix2pixHD model.""" +import numpy as np +import torch + +from imaginaire.utils.data import get_paired_input_label_channel_number +from imaginaire.utils.distributed import dist_all_gather_tensor, is_master +from imaginaire.utils.distributed import master_only_print as print +from imaginaire.utils.trainer import (get_optimizer, get_optimizer_for_params, + wrap_model_and_optimizer) +from sklearn.cluster import KMeans + + +def cluster_features(cfg, train_data_loader, net_E, + preprocess=None, small_ratio=0.0625, is_cityscapes=True): + r"""Use clustering to compute the features. 
+ + Args: + cfg (obj): Global configuration file. + train_data_loader (obj): Dataloader for iterate through the training + set. + net_E (nn.Module): Pytorch network. + preprocess (function): Pre-processing function. + small_ratio (float): We only consider instance that at least occupy + $(small_ratio) amount of image space. + is_cityscapes (bool): Is this is the cityscape dataset? In the + Cityscapes dataset, the instance labels for car start with 26001, + 26002, ... + + Returns: + ( num_labels x num_cluster_centers x feature_dims): cluster centers. + """ + # Encode features. + label_nc = get_paired_input_label_channel_number(cfg.data) + feat_nc = cfg.gen.enc.num_feat_channels + n_clusters = getattr(cfg.gen.enc, 'num_clusters', 10) + # Compute features. + features = {} + for label in range(label_nc): + features[label] = np.zeros((0, feat_nc + 1)) + for data in train_data_loader: + if preprocess is not None: + data = preprocess(data) + feat = encode_features(net_E, feat_nc, label_nc, + data['images'], data['instance_maps'], + is_cityscapes) + # We only collect the feature vectors for the master GPU. + if is_master(): + for label in range(label_nc): + features[label] = np.append( + features[label], feat[label], axis=0) + # Clustering. + # We only perform clustering for the master GPU. + if is_master(): + for label in range(label_nc): + feat = features[label] + # We only consider segments that are greater than a pre-set + # threshold. + feat = feat[feat[:, -1] > small_ratio, :-1] + if feat.shape[0]: + n_clusters = min(feat.shape[0], n_clusters) + kmeans = KMeans(n_clusters=n_clusters, random_state=0).fit(feat) + n, d = kmeans.cluster_centers_.shape + this_cluster = getattr(net_E, 'cluster_%d' % label) + this_cluster[0:n, :] = torch.Tensor( + kmeans.cluster_centers_).float() + + +def encode_features(net_E, feat_nc, label_nc, image, inst, + is_cityscapes=True): + r"""Compute feature embeddings for an image image. + TODO(Ting-Chun): To make this funciton dataset independent. + + Args: + net_E (nn.Module): The encoder network. + feat_nc (int): Feature dimensions + label_nc (int): Number of segmentation labels. + image (tensor): Input image tensor. + inst (tensor): Input instance map. + is_cityscapes (bool): Is this is the cityscape dataset? In the + Cityscapes dataset, the instance labels for car start with 26001, + 26002, ... + Returns: + (list of list of numpy vectors): We will have $(label_nc) + list. For each list, it will record a list of feature vectors of + dimension $(feat_nc+1) where the first $(feat_nc) dimensions is + the representative feature of an instance and the last dimension + is the proportion. + """ + # h, w = inst.size()[2:] + feat_map = net_E(image, inst) + feature_map_gather = dist_all_gather_tensor(feat_map) + inst_gathered = dist_all_gather_tensor(inst) + # Initialize the cluster centers. + # For each feature vector, + # 0:feat_nc will be the feature vector. + # The feat_nc dimension record the percentage of the instance. + feature = {} + for i in range(label_nc): + feature[i] = np.zeros((0, feat_nc + 1)) + if is_master(): + all_feat_map = torch.cat(feature_map_gather, 0) + all_inst_map = torch.cat(inst_gathered, 0) + # Scan through the batches. 
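+        # For every instance id in every image, the feature vector at one
+        # representative pixel (the middle entry of the instance's pixel
+        # list) is stored together with the fraction of the image it covers.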
+ for n in range(all_feat_map.size()[0]): + feat_map = all_feat_map[n:(n + 1), :, :, :] + inst = all_inst_map[n:(n + 1), :, :, :] + fh, fw = feat_map.size()[2:] + inst_np = inst.cpu().numpy().astype(int) + for i in np.unique(inst_np): + if is_cityscapes: + label = i if i < 1000 else i // 1000 + else: + label = i + idx = (inst == int(i)).nonzero() + num = idx.size()[0] + # We will just pick the middle pixel as its representative + # feature. + idx = idx[num // 2, :] + val = np.zeros((1, feat_nc + 1)) + for k in range(feat_nc): + # We expect idx[0]=0 and idx[1]=0 as the number of sample + # per processing is 1 (idx[0]=0) and the channel number of + # the instance map is 1. + val[0, k] = feat_map[ + idx[0], idx[1] + k, idx[2], idx[3]].item() + val[0, feat_nc] = float(num) / (fh * fw) + feature[label] = np.append(feature[label], val, axis=0) + return feature + else: + return feature + + +def get_edges(t): + r""" Compute edge maps for a given input instance map. + + Args: + t (4D tensor): Input instance map. + Returns: + (4D tensor): Output edge map. + """ + edge = torch.cuda.ByteTensor(t.size()).zero_() + edge[:, :, :, 1:] = edge[:, :, :, 1:] | ( + t[:, :, :, 1:] != t[:, :, :, :-1]).byte() + edge[:, :, :, :-1] = edge[:, :, :, :-1] | ( + t[:, :, :, 1:] != t[:, :, :, :-1]).byte() + edge[:, :, 1:, :] = edge[:, :, 1:, :] | ( + t[:, :, 1:, :] != t[:, :, :-1, :]).byte() + edge[:, :, :-1, :] = edge[:, :, :-1, :] | ( + t[:, :, 1:, :] != t[:, :, :-1, :]).byte() + return edge.float() + + +def get_train_params(net, param_names_start_with=[], param_names_include=[]): + r"""Get train parameters. + + Args: + net (obj): Network object. + param_names_start_with (list of strings): Params whose names + start with any of the strings will be trained. + param_names_include (list of strings): Params whose names include + any of the strings will be trained. + """ + params_to_train = [] + params_dict = net.state_dict() + list_of_param_names_to_train = set() + # Iterate through all params in the network and check if we need to + # train it. + for key, value in params_dict.items(): + do_train = False + # If the param name starts with the target string (excluding + # the 'module' part etc), we will train this param. + key_s = key.replace('module.', '').replace('averaged_model.', '') + for param_name in param_names_start_with: + if key_s.startswith(param_name): + do_train = True + list_of_param_names_to_train.add(param_name) + + # Otherwise, if the param name includes the target string, + # we will also train it. + if not do_train: + for param_name in param_names_include: + if param_name in key_s: + do_train = True + full_param_name = \ + key_s[:(key_s.find(param_name) + len(param_name))] + list_of_param_names_to_train.add(full_param_name) + + # If we decide to train the param, add it to the list to train. + if do_train: + module = net + key_list = key.split('.') + for k in key_list: + module = getattr(module, k) + params_to_train += [module] + + print('Training layers: ', sorted(list_of_param_names_to_train)) + return params_to_train + + +def get_optimizer_with_params(cfg, net_G, net_D, param_names_start_with=[], + param_names_include=[]): + r"""Return the optimizer object. + + Args: + cfg (obj): Global config. + net_G (obj): Generator network. + net_D (obj): Discriminator network. + param_names_start_with (list of strings): Params whose names + start with any of the strings will be trained. + param_names_include (list of strings): Params whose names include + any of the strings will be trained. 
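+
+    Example (illustrative; 'weight_generator' is a placeholder prefix, the
+    real parameter names depend on the generator definition in the config):
+        get_optimizer_with_params(cfg, net_G, net_D,
+                                  param_names_start_with=['weight_generator'])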
+ """ + # If any of the param name lists is not empty, will only train + # these params. Otherwise will train the entire network (all params). + if param_names_start_with or param_names_include: + params = get_train_params(net_G, param_names_start_with, + param_names_include) + else: + params = net_G.parameters() + + opt_G = get_optimizer_for_params(cfg.gen_opt, params) + opt_D = get_optimizer(cfg.dis_opt, net_D) + return wrap_model_and_optimizer(cfg, net_G, net_D, opt_G, opt_D) diff --git a/imaginaire/model_utils/rename_inputs.py b/imaginaire/model_utils/rename_inputs.py new file mode 100644 index 0000000000000000000000000000000000000000..f40b3f98f6bf21f9efb21c9b7cd99226adbb2769 --- /dev/null +++ b/imaginaire/model_utils/rename_inputs.py @@ -0,0 +1,15 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md + + +def rename_inputs(cfg, is_inference, data): + assert hasattr(cfg, 'rename_inputs') + attr = getattr(cfg, 'rename_inputs') + for key in attr.keys(): + value = attr[key] + data[key] = data[value] + # Delete the old key. + del data[value] + return data diff --git a/imaginaire/model_utils/wc_vid2vid/render.py b/imaginaire/model_utils/wc_vid2vid/render.py new file mode 100644 index 0000000000000000000000000000000000000000..304b1cb3b58384ad8fb99bacf43ddda0bc0b2ff6 --- /dev/null +++ b/imaginaire/model_utils/wc_vid2vid/render.py @@ -0,0 +1,199 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +import pickle +import time + +import numpy as np + + +class SplatRenderer(object): + """Splatting 3D point cloud into image using precomputed mapping.""" + + def __init__(self): + self.reset() + + def reset(self): + """Reset the renderer.""" + # 1 = point seen before, 0 = not seen. + # This is numpy uint8 array of size (N, 1) + self.seen_mask = None + + # Time of first colorization of 3D point. + # This is numpy uint16 array of size (N, 1) + self.seen_time = None + + # colors[kp_idx] is color of kp_idx'th keypoint. + # This is a numpy uint8 array of size (N, 3) + self.colors = None + + self.time_taken = 0 + self.call_idx = 0 + + def num_points(self): + r"""Number of points with assigned colors.""" + return np.sum(self.seen_mask) + + def _resize_arrays(self, max_point_idx): + r"""Makes arrays bigger, if needed. + Args: + max_point_idx (int): Highest 3D point index seen so far. + """ + if self.colors is None: + old_max_point_idx = 0 + else: + old_max_point_idx = self.colors.shape[0] + + if max_point_idx > old_max_point_idx: + # Init new bigger arrays. + colors = np.zeros((max_point_idx, 3), dtype=np.uint8) + seen_mask = np.zeros((max_point_idx, 1), dtype=np.uint8) + seen_time = np.zeros((max_point_idx, 1), dtype=np.uint16) + # Copy old colors, if exist. + if old_max_point_idx > 0: + colors[:old_max_point_idx] = self.colors + seen_mask[:old_max_point_idx] = self.seen_mask + seen_time[:old_max_point_idx] = self.seen_time + # Reset pointers. + self.colors = colors + self.seen_mask = seen_mask + self.seen_time = seen_time + + def update_point_cloud(self, image, point_info): + r"""Updates point cloud with new points and colors. + Args: + image (H x W x 3, uint8): Select colors from this image to assign to + 3D points which do not have previously assigned colors. 
+ point_info (N x 3): (i, j, 3D point idx) per row containing + mapping of image pixel to 3D point in point cloud. + """ + if point_info is None or len(point_info) == 0: + return + + start = time.time() + self.call_idx += 1 + + i_idxs = point_info[:, 0] + j_idxs = point_info[:, 1] + point_idxs = point_info[:, 2] + + # Allocate memory for new colors. + max_point_idx = np.max(np.array(point_idxs)) + 1 + self._resize_arrays(max_point_idx) + # print('max point idx:', max_point_idx) + + # Save only the new colors. + self.colors[point_idxs] = \ + self.seen_mask[point_idxs] * self.colors[point_idxs] + \ + (1 - self.seen_mask[point_idxs]) * image[i_idxs, j_idxs] + + # Save point seen times. + self.seen_time[point_idxs] = \ + self.seen_mask[point_idxs] * self.seen_time[point_idxs] + \ + (1 - self.seen_mask[point_idxs]) * self.call_idx + + # Update seen point mask. + self.seen_mask[point_idxs] = 1 + + end = time.time() + self.time_taken += (end - start) + + def render_image(self, point_info, w, h, return_mask=False): + r"""Creates image of (h, w) and fills in colors. + Args: + point_info (N x 3): (i, j, 3D point idx) per row containing + mapping of image pixel to 3D point in point cloud. + w (int): Width of output image. + h (int): Height of output image. + return_mask (bool): Return binary mask of coloring. + Returns: + (tuple): + - output (H x W x 3, uint8): Image formed with mapping and colors. + - mask (H x W x 1, uint8): Binary (255 or 0) mask of colorization. + """ + output = np.zeros((h, w, 3), dtype=np.uint8) + mask = np.zeros((h, w, 1), dtype=np.uint8) + + if point_info is None or len(point_info) == 0: + if return_mask: + return output, mask + else: + return output + + start = time.time() + + i_idxs = point_info[:, 0] + j_idxs = point_info[:, 1] + point_idxs = point_info[:, 2] + + # Allocate memory for new colors. + max_point_idx = np.max(np.array(point_idxs)) + 1 + self._resize_arrays(max_point_idx) + + # num_found = np.sum(self.seen_mask[point_idxs]) + # print('Found %d points to color' % (num_found)) + + # Copy colors. + output[i_idxs, j_idxs] = self.colors[point_idxs] + + end = time.time() + self.time_taken += (end - start) + + if return_mask: + mask[i_idxs, j_idxs] = 255 * self.seen_mask[point_idxs] + return output, mask + else: + return output + + +def decode_unprojections(data): + r"""Unpickle unprojections and make array. + Args: + data (array of pickled info): Each pickled string has keypoint mapping + info. + Returns: + output (dict): Keys are the different resolutions, and values are padded + mapping information. + """ + + # Unpickle unprojections and store them in a dict with resolutions as keys. + all_unprojections = {} + for item in data: + info = pickle.loads(item) + + for resolution, value in info.items(): + if resolution not in all_unprojections: + all_unprojections[resolution] = [] + + if not value or value is None: + point_info = [] + else: + point_info = value + all_unprojections[resolution].append(point_info) + + outputs = {} + for resolution, values in all_unprojections.items(): + # Get max length of mapping. + max_len = 0 + for value in values: + max_len = max(max_len, len(value)) + # Entries are a 3-tuple of (i_idx, j_idx, point_idx). + assert len(value) % 3 == 0 + + # Pad each mapping to max_len. + values = [ + value + # Original info. + [-1] * (max_len - len(value)) + # Padding. + [len(value) // 3] * 3 # End sentinel with length. + for value in values + ] + + # Convert each mapping to numpy and reshape. 
+ values = [np.array(value).reshape(-1, 3) for value in values] + + # Stack and put in output. + # Shape is (T, N, 3). T is time steps, N is num mappings. + outputs[resolution] = np.stack(values, axis=0) + + return outputs diff --git a/imaginaire/optimizers/__init__.py b/imaginaire/optimizers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..69bedc71589e178b96730174c9860b8cc8430e55 --- /dev/null +++ b/imaginaire/optimizers/__init__.py @@ -0,0 +1,8 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +from .fromage import Fromage +from .madam import Madam + +__all__ = ['Fromage', 'Madam'] diff --git a/imaginaire/optimizers/__pycache__/__init__.cpython-38.pyc b/imaginaire/optimizers/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f4a0a8977871d0626a62bfdd2561625a4fc3b5d7 Binary files /dev/null and b/imaginaire/optimizers/__pycache__/__init__.cpython-38.pyc differ diff --git a/imaginaire/optimizers/__pycache__/fromage.cpython-38.pyc b/imaginaire/optimizers/__pycache__/fromage.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4f8740580fc34277745504381d463ec83bd9d493 Binary files /dev/null and b/imaginaire/optimizers/__pycache__/fromage.cpython-38.pyc differ diff --git a/imaginaire/optimizers/__pycache__/madam.cpython-38.pyc b/imaginaire/optimizers/__pycache__/madam.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a384b1798545cd7b6d9b3515e1afc2b7fd892f68 Binary files /dev/null and b/imaginaire/optimizers/__pycache__/madam.cpython-38.pyc differ diff --git a/imaginaire/optimizers/fromage.py b/imaginaire/optimizers/fromage.py new file mode 100644 index 0000000000000000000000000000000000000000..d00203de89f55fd122f71b7de8718ed7ef681ec8 --- /dev/null +++ b/imaginaire/optimizers/fromage.py @@ -0,0 +1,44 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +# import torch +import math + +from torch.optim.optimizer import Optimizer, required + + +class Fromage(Optimizer): + r"""Fromage optimizer implementation (https://arxiv.org/abs/2002.03432)""" + + def __init__(self, params, lr=required, momentum=0): + if lr is not required and lr < 0.0: + raise ValueError("Invalid learning rate: {}".format(lr)) + defaults = dict(lr=lr, momentum=momentum) + super(Fromage, self).__init__(params, defaults) + + def step(self, closure=None): + r"""Performs a single optimization step. + + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. 
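+
+        Example (typical usage; ``model``, ``x``, ``y`` and ``loss_fn`` are
+        placeholders)::
+
+            optimizer = Fromage(model.parameters(), lr=0.01)
+            optimizer.zero_grad()
+            loss = loss_fn(model(x), y)
+            loss.backward()
+            optimizer.step()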
+ """ + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + d_p = p.grad.data + d_p_norm = p.grad.norm() + p_norm = p.norm() + if p_norm > 0.0 and d_p_norm > 0.0: + p.data.add_(-group['lr'], d_p * (p_norm / d_p_norm)) + else: + p.data.add_(-group['lr'], d_p) + p.data /= math.sqrt(1 + group['lr'] ** 2) + + return loss diff --git a/imaginaire/optimizers/madam.py b/imaginaire/optimizers/madam.py new file mode 100644 index 0000000000000000000000000000000000000000..11bf71d049d9323e9ba646713413578ae5eb4503 --- /dev/null +++ b/imaginaire/optimizers/madam.py @@ -0,0 +1,54 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +import torch +from torch.optim.optimizer import Optimizer, required + + +class Madam(Optimizer): + r"""MADAM optimizer implementation (https://arxiv.org/abs/2006.14560)""" + def __init__(self, params, lr=required, scale=3.0, + g_bound=None, momentum=0): + self.scale = scale + self.g_bound = g_bound + defaults = dict(lr=lr, momentum=momentum) + super(Madam, self).__init__(params, defaults) + + def step(self, closure=None): + r"""Performs a single optimization step. + + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + + state = self.state[p] + if len(state) == 0: + state['max'] = self.scale * (p * p).mean().sqrt().item() + state['step'] = 0 + state['exp_avg_sq'] = torch.zeros_like(p) + + state['step'] += 1 + bias_correction = 1 - 0.999 ** state['step'] + state['exp_avg_sq'] = 0.999 * state[ + 'exp_avg_sq'] + 0.001 * p.grad.data ** 2 + g_normed = \ + p.grad.data / (state['exp_avg_sq'] / bias_correction).sqrt() + g_normed[torch.isnan(g_normed)] = 0 + if self.g_bound is not None: + g_normed.clamp_(-self.g_bound, self.g_bound) + + p.data *= torch.exp( + -group['lr'] * g_normed * torch.sign(p.data)) + p.data.clamp_(-state['max'], state['max']) + + return loss diff --git a/imaginaire/third_party/__init__.py b/imaginaire/third_party/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/imaginaire/third_party/__pycache__/__init__.cpython-38.pyc b/imaginaire/third_party/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d5c112b422e5e97bfa8f09f096fbac9b1e6bae54 Binary files /dev/null and b/imaginaire/third_party/__pycache__/__init__.cpython-38.pyc differ diff --git a/imaginaire/third_party/bias_act/__init__.py b/imaginaire/third_party/bias_act/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9dfe0aec6e5fd4a1538ed959abff7c5106784c9b --- /dev/null +++ b/imaginaire/third_party/bias_act/__init__.py @@ -0,0 +1,3 @@ +from .bias_act import FusedNonlinearity + +__all__ = ['FusedNonlinearity'] diff --git a/imaginaire/third_party/bias_act/__pycache__/__init__.cpython-38.pyc b/imaginaire/third_party/bias_act/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..11e3a34fdbfff20d3f67e073c1b6048c72f57216 Binary files /dev/null and b/imaginaire/third_party/bias_act/__pycache__/__init__.cpython-38.pyc differ diff --git 
a/imaginaire/third_party/bias_act/__pycache__/bias_act.cpython-38.pyc b/imaginaire/third_party/bias_act/__pycache__/bias_act.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d99f8e6245f42b7450319d0f9bd9974ee53b5402 Binary files /dev/null and b/imaginaire/third_party/bias_act/__pycache__/bias_act.cpython-38.pyc differ diff --git a/imaginaire/third_party/bias_act/bias_act.py b/imaginaire/third_party/bias_act/bias_act.py new file mode 100644 index 0000000000000000000000000000000000000000..29b01dc97884036aec1c42feb184c510c5ad0870 --- /dev/null +++ b/imaginaire/third_party/bias_act/bias_act.py @@ -0,0 +1,219 @@ +# flake8: noqa +import numpy as np +from types import SimpleNamespace + +import torch +from torch import nn + +import bias_act_cuda + +# ---------------------------------------------------------------------------- + +activation_funcs = { + 'linear': SimpleNamespace(func=lambda x, **_: x, def_alpha=0, def_gain=1, + cuda_idx=1, ref='', has_2nd_grad=False), + 'relu': SimpleNamespace(func=lambda x, **_: torch.nn.functional.relu(x), + def_alpha=0, def_gain=np.sqrt(2), cuda_idx=2, + ref='y', has_2nd_grad=False), + 'leakyrelu': SimpleNamespace( + func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), + def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', + has_2nd_grad=False), + 'tanh': SimpleNamespace(func=lambda x, **_: torch.tanh(x), def_alpha=0, + def_gain=1, cuda_idx=4, ref='y', has_2nd_grad=True), + 'sigmoid': SimpleNamespace(func=lambda x, **_: torch.sigmoid(x), + def_alpha=0, def_gain=1, cuda_idx=5, ref='y', + has_2nd_grad=True), + 'elu': SimpleNamespace(func=lambda x, **_: torch.nn.functional.elu(x), + def_alpha=0, def_gain=1, cuda_idx=6, ref='y', + has_2nd_grad=True), + 'selu': SimpleNamespace(func=lambda x, **_: torch.nn.functional.selu(x), + def_alpha=0, def_gain=1, cuda_idx=7, ref='y', + has_2nd_grad=True), + 'softplus': SimpleNamespace( + func=lambda x, **_: torch.nn.functional.softplus(x), def_alpha=0, + def_gain=1, cuda_idx=8, ref='y', has_2nd_grad=True), + 'swish': SimpleNamespace(func=lambda x, **_: torch.sigmoid(x) * x, + def_alpha=0, def_gain=np.sqrt(2), cuda_idx=9, + ref='x', has_2nd_grad=True), +} + +# ---------------------------------------------------------------------------- + +_null_tensor = torch.empty([0]) + + +def _bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, + impl='cuda'): + assert isinstance(x, torch.Tensor) + assert impl in ['ref', 'cuda'] + if impl == 'cuda' and x.device.type == 'cuda': + return _bias_act_cuda(dim=dim, act=act, alpha=alpha, gain=gain, + clamp=clamp).apply(x, b) + return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, + clamp=clamp) + + +# ---------------------------------------------------------------------------- + +def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None): + assert isinstance(x, torch.Tensor) + assert clamp is None or clamp >= 0 + spec = activation_funcs[act] + alpha = float(alpha if alpha is not None else spec.def_alpha) + gain = float(gain if gain is not None else spec.def_gain) + clamp = float(clamp if clamp is not None else -1) + + # Add bias. + if b is not None: + assert isinstance(b, torch.Tensor) and b.ndim == 1 + assert 0 <= dim < x.ndim + assert b.shape[0] == x.shape[dim] + x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)]) + + # Evaluate activation function. + alpha = float(alpha) + x = spec.func(x, alpha=alpha) + + # Scale by gain. 
+ gain = float(gain) + if gain != 1: + x = x * gain + + # Clamp. + if clamp >= 0: + x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type + return x + + +# ---------------------------------------------------------------------------- + +_bias_act_cuda_cache = dict() + + +def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None): + """Fast CUDA implementation of `bias_act()` using custom ops. + """ + # Parse arguments. + assert clamp is None or clamp >= 0 + spec = activation_funcs[act] + alpha = float(alpha if alpha is not None else spec.def_alpha) + gain = float(gain if gain is not None else spec.def_gain) + clamp = float(clamp if clamp is not None else -1) + + # Lookup from cache. + key = (dim, act, alpha, gain, clamp) + if key in _bias_act_cuda_cache: + return _bias_act_cuda_cache[key] + + # Forward op. + class BiasActCuda(torch.autograd.Function): + @staticmethod + def forward(ctx, x, b): # pylint: disable=arguments-differ + if x.ndim > 2 and x.stride()[1] == 1: + ctx.memory_format = torch.channels_last + else: + ctx.memory_format = torch.contiguous_format + x = x.contiguous(memory_format=ctx.memory_format) + b = b.contiguous() if b is not None else _null_tensor + y = x + if act != 'linear' or gain != 1 or clamp >= 0 or b is not \ + _null_tensor: + y = bias_act_cuda.bias_act_cuda(x, b, _null_tensor, _null_tensor, + _null_tensor, 0, dim, spec.cuda_idx, alpha, + gain, clamp) + ctx.save_for_backward( + x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor, + b if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor, + y if 'y' in spec.ref else _null_tensor) + return y + + @staticmethod + def backward(ctx, dy): # pylint: disable=arguments-differ + dy = dy.contiguous(memory_format=ctx.memory_format) + x, b, y = ctx.saved_tensors + dx = None + db = None + + if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]: + dx = dy + if act != 'linear' or gain != 1 or clamp >= 0: + dx = BiasActCudaGrad.apply(dy, x, b, y) + + if ctx.needs_input_grad[1]: + db = dx.sum([i for i in range(dx.ndim) if i != dim]) + + return dx, db + + # Backward op. + class BiasActCudaGrad(torch.autograd.Function): + @staticmethod + def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ + if x.ndim > 2 and x.stride()[1] == 1: + ctx.memory_format = torch.channels_last + else: + ctx.memory_format = torch.contiguous_format + dx = bias_act_cuda.bias_act_cuda(dy, b, x, y, _null_tensor, 1, dim, + spec.cuda_idx, alpha, gain, clamp) + ctx.save_for_backward( + dy if spec.has_2nd_grad else _null_tensor, + x, b, y) + return dx + + @staticmethod + def backward(ctx, d_dx): # pylint: disable=arguments-differ + d_dx = d_dx.contiguous(memory_format=ctx.memory_format) + dy, x, b, y = ctx.saved_tensors + d_dy = None + d_x = None + d_b = None + d_y = None + + if ctx.needs_input_grad[0]: + d_dy = BiasActCudaGrad.apply(d_dx, x, b, y) + + if spec.has_2nd_grad and ( + ctx.needs_input_grad[1] or ctx.needs_input_grad[2]): + d_x = bias_act_cuda.bias_act_cuda(d_dx, b, x, y, dy, 2, dim, spec.cuda_idx, + alpha, gain, clamp) + + if spec.has_2nd_grad and ctx.needs_input_grad[2]: + d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim]) + + return d_dy, d_x, d_b, d_y + + # Add to cache. 
+ _bias_act_cuda_cache[key] = BiasActCuda + return BiasActCuda + + +class FusedNonlinearity(nn.Module): + def __init__(self, nonlinearity, num_channels=None, lr_mul=1.0, alpha=None, impl='cuda', gain=None): + super().__init__() + if num_channels is not None: + self.bias = nn.Parameter(torch.zeros(num_channels)) + else: + self.register_parameter('bias', None) + self.nonlinearity = nonlinearity + self.gain = gain + self.alpha = alpha + self.lr_mul = lr_mul + self.impl = impl + + def forward(self, x): + bias = self.bias.type_as(x) * self.lr_mul if self.bias is not None else None + return _bias_act( + x, b=bias, dim=1, act=self.nonlinearity, + alpha=self.alpha, gain=self.gain, clamp=None, impl=self.impl + ) + + def __repr__(self): + mod_str = f'{self.__class__.__name__}(type={self.nonlinearity}' + if self.gain is not None: + mod_str += f', gain={self.gain}' + if self.alpha is not None: + mod_str += f', alpha={self.alpha}' + if self.lr_mul != 1: + mod_str += f', lr_mul={self.lr_mul}' + mod_str += ')' + return mod_str diff --git a/imaginaire/third_party/bias_act/setup.py b/imaginaire/third_party/bias_act/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..13607b6083bd59eba24c5f4ac48a34048b55f642 --- /dev/null +++ b/imaginaire/third_party/bias_act/setup.py @@ -0,0 +1,43 @@ +# flake8: noqa +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension +import os + + +cuda_version = os.getenv('CUDA_VERSION') +print('CUDA_VERSION: {}'.format(cuda_version)) + +nvcc_args = list() +# nvcc_args.append('-gencode') +# nvcc_args.append('arch=compute_50,code=sm_50') +# nvcc_args.append('-gencode') +# nvcc_args.append('arch=compute_52,code=sm_52') +# nvcc_args.append('-gencode') +# nvcc_args.append('arch=compute_60,code=sm_60') +# nvcc_args.append('-gencode') +# nvcc_args.append('arch=compute_61,code=sm_61') +nvcc_args.append('-gencode') +nvcc_args.append('arch=compute_70,code=sm_70') +nvcc_args.append('-gencode') +nvcc_args.append('arch=compute_75,code=sm_75') +if cuda_version is not None: + if cuda_version >= '11.0': + nvcc_args.append('-gencode') + nvcc_args.append('arch=compute_80,code=sm_80') +nvcc_args.append('-Xcompiler') +nvcc_args.append('-Wall') +nvcc_args.append('-std=c++14') + +setup( + name='bias_act_cuda', + py_modules=['bias_act'], + ext_modules=[ + CUDAExtension('bias_act_cuda', [ + './src/bias_act_cuda.cc', + './src/bias_act_cuda_kernel.cu' + ], extra_compile_args={'cxx': ['-Wall', '-std=c++14'], + 'nvcc': nvcc_args}) + ], + cmdclass={ + 'build_ext': BuildExtension + }) diff --git a/imaginaire/third_party/bias_act/src/bias_act_cuda.cc b/imaginaire/third_party/bias_act/src/bias_act_cuda.cc new file mode 100644 index 0000000000000000000000000000000000000000..cf975dbe6784e89cfa056574da8780d1e5f5b97d --- /dev/null +++ b/imaginaire/third_party/bias_act/src/bias_act_cuda.cc @@ -0,0 +1,103 @@ +// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. 
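+
+// Entry point for the fused bias/activation op wrapped by bias_act.py:
+// bias_act_cuda(x, b, xref, yref, dy, grad, dim, act, alpha, gain, clamp)
+// adds a bias along dimension `dim`, applies the activation selected by
+// `act`, multiplies by `gain` and optionally clamps, in one kernel launch.
+// `grad` selects the forward value (0), the first backward pass (1) or the
+// second-order backward pass (2); `xref`, `yref` and `dy` carry the saved
+// reference tensors needed by the backward passes.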
+ +#include +#include +#include +#include +#include +#include + +#include "bias_act_cuda.h" + +//------------------------------------------------------------------------ + +static bool has_same_layout(torch::Tensor x, torch::Tensor y) +{ + if (x.dim() != y.dim()) + return false; + for (int64_t i = 0; i < x.dim(); i++) + { + if (x.size(i) != y.size(i)) + return false; + if (x.size(i) >= 2 && x.stride(i) != y.stride(i)) + return false; + } + return true; +} + +//------------------------------------------------------------------------ + +static torch::Tensor bias_act_cuda(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp) +{ + // Validate arguments. + TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); + TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x"); + TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x"); + TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x"); + TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same dtype and device as x"); + TORCH_CHECK(x.numel() <= INT_MAX, "x is too large"); + TORCH_CHECK(b.dim() == 1, "b must have rank 1"); + TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds"); + TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements"); + TORCH_CHECK(grad >= 0, "grad must be non-negative"); + + // Validate layout. + TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense"); + TORCH_CHECK(b.is_contiguous(), "b must be contiguous"); + TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x"); + TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x"); + TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x"); + + // Create output tensor. + const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); + torch::Tensor y = torch::empty_like(x); + TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x"); + + // Initialize CUDA kernel parameters. + bias_act_kernel_params p; + p.x = x.data_ptr(); + p.b = (b.numel()) ? b.data_ptr() : NULL; + p.xref = (xref.numel()) ? xref.data_ptr() : NULL; + p.yref = (yref.numel()) ? yref.data_ptr() : NULL; + p.dy = (dy.numel()) ? dy.data_ptr() : NULL; + p.y = y.data_ptr(); + p.grad = grad; + p.act = act; + p.alpha = alpha; + p.gain = gain; + p.clamp = clamp; + p.sizeX = (int)x.numel(); + p.sizeB = (int)b.numel(); + p.stepB = (b.numel()) ? (int)x.stride(dim) : 1; + + // Choose CUDA kernel. + void* kernel; + AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda_kernel", [&] + { + kernel = choose_bias_act_kernel(p); + }); + TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func"); + + // Launch CUDA kernel. 
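+    // Each thread handles up to p.loopX elements, so a grid of gridSize
+    // blocks of blockSize threads covers all p.sizeX elements.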
+ p.loopX = 4; + int blockSize = 4 * 32; + int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1; + void* args[] = {&p}; + AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream())); + return y; +} + +//------------------------------------------------------------------------ + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + m.def("bias_act_cuda", &bias_act_cuda); +} + +//------------------------------------------------------------------------ diff --git a/imaginaire/third_party/bias_act/src/bias_act_cuda.h b/imaginaire/third_party/bias_act/src/bias_act_cuda.h new file mode 100644 index 0000000000000000000000000000000000000000..a32187e1fb7e3bae509d4eceaf900866866875a4 --- /dev/null +++ b/imaginaire/third_party/bias_act/src/bias_act_cuda.h @@ -0,0 +1,38 @@ +// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +//------------------------------------------------------------------------ +// CUDA kernel parameters. + +struct bias_act_kernel_params +{ + const void* x; // [sizeX] + const void* b; // [sizeB] or NULL + const void* xref; // [sizeX] or NULL + const void* yref; // [sizeX] or NULL + const void* dy; // [sizeX] or NULL + void* y; // [sizeX] + + int grad; + int act; + float alpha; + float gain; + float clamp; + + int sizeX; + int sizeB; + int stepB; + int loopX; +}; + +//------------------------------------------------------------------------ +// CUDA kernel selection. + +template void* choose_bias_act_kernel(const bias_act_kernel_params& p); + +//------------------------------------------------------------------------ diff --git a/imaginaire/third_party/bias_act/src/bias_act_cuda_kernel.cu b/imaginaire/third_party/bias_act/src/bias_act_cuda_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..9adbb942b5ce5740a5527449995e1887cda12816 --- /dev/null +++ b/imaginaire/third_party/bias_act/src/bias_act_cuda_kernel.cu @@ -0,0 +1,173 @@ +// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#include +#include "bias_act_cuda.h" + +//------------------------------------------------------------------------ +// Helpers. + +template struct InternalType; +template <> struct InternalType { typedef double scalar_t; }; +template <> struct InternalType { typedef float scalar_t; }; +template <> struct InternalType { typedef float scalar_t; }; + +//------------------------------------------------------------------------ +// CUDA kernel. 
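+
+// Each thread loops over p.loopX strided elements. The activation is chosen
+// statically through the A template parameter (matching cuda_idx on the
+// Python side); half-precision inputs are promoted to float via InternalType.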
+ +template +__global__ void bias_act_kernel(bias_act_kernel_params p) +{ + typedef typename InternalType::scalar_t scalar_t; + int G = p.grad; + scalar_t alpha = (scalar_t)p.alpha; + scalar_t gain = (scalar_t)p.gain; + scalar_t clamp = (scalar_t)p.clamp; + scalar_t one = (scalar_t)1; + scalar_t two = (scalar_t)2; + scalar_t expRange = (scalar_t)80; + scalar_t halfExpRange = (scalar_t)40; + scalar_t seluScale = (scalar_t)1.0507009873554804934193349852946; + scalar_t seluAlpha = (scalar_t)1.6732632423543772848170429916717; + + // Loop over elements. + int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x; + for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x) + { + // Load. + scalar_t x = (scalar_t)((const T*)p.x)[xi]; + scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0; + scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0; + scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0; + scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one; + scalar_t yy = (gain != 0) ? yref / gain : 0; + scalar_t y = 0; + + // Apply bias. + ((G == 0) ? x : xref) += b; + + // linear + if (A == 1) + { + if (G == 0) y = x; + if (G == 1) y = x; + } + + // relu + if (A == 2) + { + if (G == 0) y = (x > 0) ? x : 0; + if (G == 1) y = (yy > 0) ? x : 0; + } + + // lrelu + if (A == 3) + { + if (G == 0) y = (x > 0) ? x : x * alpha; + if (G == 1) y = (yy > 0) ? x : x * alpha; + } + + // tanh + if (A == 4) + { + if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); } + if (G == 1) y = x * (one - yy * yy); + if (G == 2) y = x * (one - yy * yy) * (-two * yy); + } + + // sigmoid + if (A == 5) + { + if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one); + if (G == 1) y = x * yy * (one - yy); + if (G == 2) y = x * yy * (one - yy) * (one - two * yy); + } + + // elu + if (A == 6) + { + if (G == 0) y = (x >= 0) ? x : exp(x) - one; + if (G == 1) y = (yy >= 0) ? x : x * (yy + one); + if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one); + } + + // selu + if (A == 7) + { + if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one); + if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha); + if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha); + } + + // softplus + if (A == 8) + { + if (G == 0) y = (x > expRange) ? x : log(exp(x) + one); + if (G == 1) y = x * (one - exp(-yy)); + if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); } + } + + // swish + if (A == 9) + { + if (G == 0) + y = (x < -expRange) ? 0 : x / (exp(-x) + one); + else + { + scalar_t c = exp(xref); + scalar_t d = c + one; + if (G == 1) + y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d); + else + y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d); + yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain; + } + } + + // Apply gain. + y *= gain * dy; + + // Clamp. + if (clamp >= 0) + { + if (G == 0) + y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp; + else + y = (yref > -clamp & yref < clamp) ? y : 0; + } + + // Store. + ((T*)p.y)[xi] = (T)y; + } +} + +//------------------------------------------------------------------------ +// CUDA kernel selection. 
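+
+// Maps the runtime activation id p.act onto the statically specialized
+// kernel instantiations above.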
+ +template void* choose_bias_act_kernel(const bias_act_kernel_params& p) +{ + if (p.act == 1) return (void*)bias_act_kernel; + if (p.act == 2) return (void*)bias_act_kernel; + if (p.act == 3) return (void*)bias_act_kernel; + if (p.act == 4) return (void*)bias_act_kernel; + if (p.act == 5) return (void*)bias_act_kernel; + if (p.act == 6) return (void*)bias_act_kernel; + if (p.act == 7) return (void*)bias_act_kernel; + if (p.act == 8) return (void*)bias_act_kernel; + if (p.act == 9) return (void*)bias_act_kernel; + return NULL; +} + +//------------------------------------------------------------------------ +// Template specializations. + +template void* choose_bias_act_kernel (const bias_act_kernel_params& p); +template void* choose_bias_act_kernel (const bias_act_kernel_params& p); +template void* choose_bias_act_kernel (const bias_act_kernel_params& p); + +//------------------------------------------------------------------------ diff --git a/imaginaire/third_party/channelnorm/channelnorm.py b/imaginaire/third_party/channelnorm/channelnorm.py new file mode 100644 index 0000000000000000000000000000000000000000..fdd46711ca0bf2b6bb650112fa364100f6d4c927 --- /dev/null +++ b/imaginaire/third_party/channelnorm/channelnorm.py @@ -0,0 +1,39 @@ +# flake8: noqa +from torch.autograd import Function, Variable +from torch.nn.modules.module import Module +import channelnorm_cuda + + +class ChannelNormFunction(Function): + @staticmethod + def forward(ctx, input1, norm_deg=2): + assert input1.is_contiguous() + b, _, h, w = input1.size() + output = input1.new(b, 1, h, w).zero_() + + channelnorm_cuda.forward(input1, output, norm_deg) + ctx.save_for_backward(input1, output) + ctx.norm_deg = norm_deg + + return output + + @staticmethod + def backward(ctx, grad_output): + input1, output = ctx.saved_tensors + + grad_input1 = Variable(input1.new(input1.size()).zero_()) + + channelnorm_cuda.backward(input1, output, grad_output.data, + grad_input1.data, ctx.norm_deg) + + return grad_input1, None + + +class ChannelNorm(Module): + + def __init__(self, norm_deg=2): + super(ChannelNorm, self).__init__() + self.norm_deg = norm_deg + + def forward(self, input1): + return ChannelNormFunction.apply(input1, self.norm_deg) diff --git a/imaginaire/third_party/channelnorm/setup.py b/imaginaire/third_party/channelnorm/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..8503ad2b254915bbbab391eb31baf0dcdc9a6bd1 --- /dev/null +++ b/imaginaire/third_party/channelnorm/setup.py @@ -0,0 +1,43 @@ +# flake8: noqa +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension +import os + + +cuda_version = os.getenv('CUDA_VERSION') +print('CUDA_VERSION: {}'.format(cuda_version)) + +nvcc_args = list() +# nvcc_args.append('-gencode') +# nvcc_args.append('arch=compute_50,code=sm_50') +# nvcc_args.append('-gencode') +# nvcc_args.append('arch=compute_52,code=sm_52') +# nvcc_args.append('-gencode') +# nvcc_args.append('arch=compute_60,code=sm_60') +# nvcc_args.append('-gencode') +# nvcc_args.append('arch=compute_61,code=sm_61') +nvcc_args.append('-gencode') +nvcc_args.append('arch=compute_70,code=sm_70') +nvcc_args.append('-gencode') +nvcc_args.append('arch=compute_75,code=sm_75') +if cuda_version is not None: + if cuda_version >= '11.0': + nvcc_args.append('-gencode') + nvcc_args.append('arch=compute_80,code=sm_80') +nvcc_args.append('-Xcompiler') +nvcc_args.append('-Wall') +nvcc_args.append('-std=c++14') + +setup( + name='channelnorm_cuda', + py_modules=['channelnorm'], + 
ext_modules=[ + CUDAExtension('channelnorm_cuda', [ + './src/channelnorm_cuda.cc', + './src/channelnorm_kernel.cu' + ], extra_compile_args={'cxx': ['-Wall', '-std=c++14'], + 'nvcc': nvcc_args}) + ], + cmdclass={ + 'build_ext': BuildExtension + }) diff --git a/imaginaire/third_party/channelnorm/src/channelnorm_cuda.cc b/imaginaire/third_party/channelnorm/src/channelnorm_cuda.cc new file mode 100644 index 0000000000000000000000000000000000000000..69d82eb184e97b2eefa9810ad156d1104cf84745 --- /dev/null +++ b/imaginaire/third_party/channelnorm/src/channelnorm_cuda.cc @@ -0,0 +1,31 @@ +#include +#include + +#include "channelnorm_kernel.cuh" + +int channelnorm_cuda_forward( + at::Tensor& input1, + at::Tensor& output, + int norm_deg) { + + channelnorm_kernel_forward(input1, output, norm_deg); + return 1; +} + + +int channelnorm_cuda_backward( + at::Tensor& input1, + at::Tensor& output, + at::Tensor& gradOutput, + at::Tensor& gradInput1, + int norm_deg) { + + channelnorm_kernel_backward(input1, output, gradOutput, gradInput1, norm_deg); + return 1; +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", &channelnorm_cuda_forward, "Channel norm forward (CUDA)"); + m.def("backward", &channelnorm_cuda_backward, "Channel norm backward (CUDA)"); +} + diff --git a/imaginaire/third_party/channelnorm/src/channelnorm_kernel.cu b/imaginaire/third_party/channelnorm/src/channelnorm_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..99ace6855a61373443a6ddff7c7858eb474e9e48 --- /dev/null +++ b/imaginaire/third_party/channelnorm/src/channelnorm_kernel.cu @@ -0,0 +1,177 @@ +#include +#include +#include + +#include "channelnorm_kernel.cuh" + +#define CUDA_NUM_THREADS 512 + +#define DIM0(TENSOR) ((TENSOR).x) +#define DIM1(TENSOR) ((TENSOR).y) +#define DIM2(TENSOR) ((TENSOR).z) +#define DIM3(TENSOR) ((TENSOR).w) + +#define DIM3_INDEX(TENSOR, xx, yy, zz, ww) ((TENSOR)[((xx) * (TENSOR##_stride.x)) + ((yy) * (TENSOR##_stride.y)) + ((zz) * (TENSOR##_stride.z)) + ((ww) * (TENSOR##_stride.w))]) + +using at::Half; + +template +__global__ void kernel_channelnorm_update_output( + const int n, + const scalar_t* __restrict__ input1, + const long4 input1_size, + const long4 input1_stride, + scalar_t* __restrict__ output, + const long4 output_size, + const long4 output_stride, + int norm_deg) { + + int index = blockIdx.x * blockDim.x + threadIdx.x; + + if (index >= n) { + return; + } + + int dim_b = DIM0(output_size); + int dim_c = DIM1(output_size); + int dim_h = DIM2(output_size); + int dim_w = DIM3(output_size); + int dim_chw = dim_c * dim_h * dim_w; + + int b = ( index / dim_chw ) % dim_b; + int y = ( index / dim_w ) % dim_h; + int x = ( index ) % dim_w; + + int i1dim_c = DIM1(input1_size); + int i1dim_h = DIM2(input1_size); + int i1dim_w = DIM3(input1_size); + int i1dim_chw = i1dim_c * i1dim_h * i1dim_w; + int i1dim_hw = i1dim_h * i1dim_w; + + float result = 0.0; + + for (int c = 0; c < i1dim_c; ++c) { + int i1Index = b * i1dim_chw + c * i1dim_hw + y * i1dim_w + x; + scalar_t val = input1[i1Index]; + result += static_cast(val * val); + } + result = sqrt(result); + output[index] = static_cast(result); +} + + +template +__global__ void kernel_channelnorm_backward_input1( + const int n, + const scalar_t* __restrict__ input1, const long4 input1_size, const long4 input1_stride, + const scalar_t* __restrict__ output, const long4 output_size, const long4 output_stride, + const scalar_t* __restrict__ gradOutput, const long4 gradOutput_size, const long4 gradOutput_stride, + scalar_t* 
__restrict__ gradInput, const long4 gradInput_size, const long4 gradInput_stride, + int norm_deg) { + + int index = blockIdx.x * blockDim.x + threadIdx.x; + + if (index >= n) { + return; + } + + float val = 0.0; + + int dim_b = DIM0(gradInput_size); + int dim_c = DIM1(gradInput_size); + int dim_h = DIM2(gradInput_size); + int dim_w = DIM3(gradInput_size); + int dim_chw = dim_c * dim_h * dim_w; + int dim_hw = dim_h * dim_w; + + int b = ( index / dim_chw ) % dim_b; + int y = ( index / dim_w ) % dim_h; + int x = ( index ) % dim_w; + + + int outIndex = b * dim_hw + y * dim_w + x; + val = static_cast(gradOutput[outIndex]) * static_cast(input1[index]) / (static_cast(output[outIndex])+1e-9); + gradInput[index] = static_cast(val); + +} + +void channelnorm_kernel_forward( + at::Tensor& input1, + at::Tensor& output, + int norm_deg) { + + const long4 input1_size = make_long4(input1.size(0), input1.size(1), input1.size(2), input1.size(3)); + const long4 input1_stride = make_long4(input1.stride(0), input1.stride(1), input1.stride(2), input1.stride(3)); + + const long4 output_size = make_long4(output.size(0), output.size(1), output.size(2), output.size(3)); + const long4 output_stride = make_long4(output.stride(0), output.stride(1), output.stride(2), output.stride(3)); + + int n = output.numel(); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(input1.type(), "channelnorm_forward", ([&] { + + kernel_channelnorm_update_output<<< (n + CUDA_NUM_THREADS - 1)/CUDA_NUM_THREADS, CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream() >>>( +//at::globalContext().getCurrentCUDAStream() >>>( + n, + input1.data(), + input1_size, + input1_stride, + output.data(), + output_size, + output_stride, + norm_deg); + + })); + + // TODO: ATen-equivalent check + + // THCudaCheck(cudaGetLastError()); +} + +void channelnorm_kernel_backward( + at::Tensor& input1, + at::Tensor& output, + at::Tensor& gradOutput, + at::Tensor& gradInput1, + int norm_deg) { + + const long4 input1_size = make_long4(input1.size(0), input1.size(1), input1.size(2), input1.size(3)); + const long4 input1_stride = make_long4(input1.stride(0), input1.stride(1), input1.stride(2), input1.stride(3)); + + const long4 output_size = make_long4(output.size(0), output.size(1), output.size(2), output.size(3)); + const long4 output_stride = make_long4(output.stride(0), output.stride(1), output.stride(2), output.stride(3)); + + const long4 gradOutput_size = make_long4(gradOutput.size(0), gradOutput.size(1), gradOutput.size(2), gradOutput.size(3)); + const long4 gradOutput_stride = make_long4(gradOutput.stride(0), gradOutput.stride(1), gradOutput.stride(2), gradOutput.stride(3)); + + const long4 gradInput1_size = make_long4(gradInput1.size(0), gradInput1.size(1), gradInput1.size(2), gradInput1.size(3)); + const long4 gradInput1_stride = make_long4(gradInput1.stride(0), gradInput1.stride(1), gradInput1.stride(2), gradInput1.stride(3)); + + int n = gradInput1.numel(); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(input1.type(), "channelnorm_backward_input1", ([&] { + + kernel_channelnorm_backward_input1<<< (n + CUDA_NUM_THREADS - 1)/CUDA_NUM_THREADS, CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream() >>>( +//at::globalContext().getCurrentCUDAStream() >>>( + n, + input1.data(), + input1_size, + input1_stride, + output.data(), + output_size, + output_stride, + gradOutput.data(), + gradOutput_size, + gradOutput_stride, + gradInput1.data(), + gradInput1_size, + gradInput1_stride, + norm_deg + ); + + })); + + // TODO: Add ATen-equivalent check + +// THCudaCheck(cudaGetLastError()); +} 
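The Python-facing module for this kernel is `ChannelNorm` in `channelnorm.py` above. A minimal usage sketch (assuming the `channelnorm_cuda` extension has been built with the bundled `setup.py` and that the `channelnorm` package directory is importable):

```python
# Illustrative only -- not part of the library sources.
import torch
from channelnorm import ChannelNorm

flow = torch.randn(2, 2, 64, 64, device='cuda')  # e.g. a 2-channel flow field
norm = ChannelNorm(norm_deg=2)(flow)             # per-pixel L2 norm over channels
assert norm.shape == (2, 1, 64, 64)
```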
diff --git a/imaginaire/third_party/channelnorm/src/channelnorm_kernel.cuh b/imaginaire/third_party/channelnorm/src/channelnorm_kernel.cuh new file mode 100644 index 0000000000000000000000000000000000000000..3e6223f7fe60feb4bf9e4f66c3d849b84c89dcda --- /dev/null +++ b/imaginaire/third_party/channelnorm/src/channelnorm_kernel.cuh @@ -0,0 +1,16 @@ +#pragma once + +#include + +void channelnorm_kernel_forward( + at::Tensor& input1, + at::Tensor& output, + int norm_deg); + + +void channelnorm_kernel_backward( + at::Tensor& input1, + at::Tensor& output, + at::Tensor& gradOutput, + at::Tensor& gradInput1, + int norm_deg); diff --git a/imaginaire/third_party/correlation/correlation.py b/imaginaire/third_party/correlation/correlation.py new file mode 100644 index 0000000000000000000000000000000000000000..e47739dff7475c0f29bff32bc2dc9f097161d144 --- /dev/null +++ b/imaginaire/third_party/correlation/correlation.py @@ -0,0 +1,105 @@ +# flake8: noqa +import torch +from torch.nn.modules.module import Module +from torch.autograd import Function +import correlation_cuda + + +class CorrelationFunction(Function): + + @staticmethod + def forward(ctx, + pad_size, + kernel_size, + max_displacement, + stride1, + stride2, + corr_multiply, + input1, + input2): + ctx.save_for_backward(input1, input2) + ctx.pad_size = pad_size + ctx.kernel_size = kernel_size + ctx.max_displacement = max_displacement + ctx.stride1 = stride1 + ctx.stride2 = stride2 + ctx.corr_multiply = corr_multiply + + with torch.cuda.device_of(input1): + rbot1 = input1.new() + rbot2 = input2.new() + output = input1.new() + + correlation_cuda.forward( + input1, + input2, + rbot1, + rbot2, + output, + ctx.pad_size, + ctx.kernel_size, + ctx.max_displacement, + ctx.stride1, + ctx.stride2, + ctx.corr_multiply) + + return output + + @staticmethod + def backward(ctx, grad_output): + input1, input2 = ctx.saved_tensors + + with torch.cuda.device_of(input1): + rbot1 = input1.new() + rbot2 = input2.new() + + grad_input1 = input1.new() + grad_input2 = input2.new() + + correlation_cuda.backward( + input1, + input2, + rbot1, + rbot2, + grad_output, + grad_input1, + grad_input2, + ctx.pad_size, + ctx.kernel_size, + ctx.max_displacement, + ctx.stride1, + ctx.stride2, + ctx.corr_multiply) + + return grad_input1, grad_input2 + +class Correlation(Module): + def __init__( + self, + pad_size=0, + kernel_size=0, + max_displacement=0, + stride1=1, + stride2=2, + corr_multiply=1): + super(Correlation, self).__init__() + self.pad_size = pad_size + self.kernel_size = kernel_size + self.max_displacement = max_displacement + self.stride1 = stride1 + self.stride2 = stride2 + self.corr_multiply = corr_multiply + + def forward(self, input1, input2): + + result = CorrelationFunction.apply( + self.pad_size, + self.kernel_size, + self.max_displacement, + self.stride1, + self.stride2, + self.corr_multiply, + input1, + input2) + + return result diff --git a/imaginaire/third_party/correlation/setup.py b/imaginaire/third_party/correlation/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..0c02aacc53102a5ef534db1a7cd69c546004f268 --- /dev/null +++ b/imaginaire/third_party/correlation/setup.py @@ -0,0 +1,43 @@ +# flake8: noqa +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension +import os + + +cuda_version = os.getenv('CUDA_VERSION') +print('CUDA_VERSION: {}'.format(cuda_version)) + +nvcc_args = list() +# nvcc_args.append('-gencode') +# nvcc_args.append('arch=compute_50,code=sm_50') +# nvcc_args.append('-gencode') 
+# nvcc_args.append('arch=compute_52,code=sm_52') +# nvcc_args.append('-gencode') +# nvcc_args.append('arch=compute_60,code=sm_60') +# nvcc_args.append('-gencode') +# nvcc_args.append('arch=compute_61,code=sm_61') +nvcc_args.append('-gencode') +nvcc_args.append('arch=compute_70,code=sm_70') +nvcc_args.append('-gencode') +nvcc_args.append('arch=compute_75,code=sm_75') +if cuda_version is not None: + if cuda_version >= '11.0': + nvcc_args.append('-gencode') + nvcc_args.append('arch=compute_80,code=sm_80') +nvcc_args.append('-Xcompiler') +nvcc_args.append('-Wall') +nvcc_args.append('-std=c++14') + +setup( + name='correlation_cuda', + py_modules=['correlation'], + ext_modules=[ + CUDAExtension('correlation_cuda', [ + './src/correlation_cuda.cc', + './src/correlation_cuda_kernel.cu' + ], extra_compile_args={'cxx': ['-Wall', '-std=c++14'], + 'nvcc': nvcc_args}) + ], + cmdclass={ + 'build_ext': BuildExtension + }) diff --git a/imaginaire/third_party/correlation/src/correlation_cuda.cc b/imaginaire/third_party/correlation/src/correlation_cuda.cc new file mode 100644 index 0000000000000000000000000000000000000000..feccd65295fa90a22564b08fc80464a76361a1aa --- /dev/null +++ b/imaginaire/third_party/correlation/src/correlation_cuda.cc @@ -0,0 +1,173 @@ +#include +#include +#include +#include +#include +#include + +#include "correlation_cuda_kernel.cuh" + +int correlation_forward_cuda(at::Tensor& input1, at::Tensor& input2, at::Tensor& rInput1, at::Tensor& rInput2, at::Tensor& output, + int pad_size, + int kernel_size, + int max_displacement, + int stride1, + int stride2, + int corr_type_multiply) +{ + + int batchSize = input1.size(0); + + int nInputChannels = input1.size(1); + int inputHeight = input1.size(2); + int inputWidth = input1.size(3); + + int kernel_radius = (kernel_size - 1) / 2; + int border_radius = kernel_radius + max_displacement; + + int paddedInputHeight = inputHeight + 2 * pad_size; + int paddedInputWidth = inputWidth + 2 * pad_size; + + int nOutputChannels = ((max_displacement/stride2)*2 + 1) * ((max_displacement/stride2)*2 + 1); + + int outputHeight = ceil(static_cast(paddedInputHeight - 2 * border_radius) / static_cast(stride1)); + int outputwidth = ceil(static_cast(paddedInputWidth - 2 * border_radius) / static_cast(stride1)); + + rInput1.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels}); + rInput2.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels}); + output.resize_({batchSize, nOutputChannels, outputHeight, outputwidth}); + + rInput1.fill_(0); + rInput2.fill_(0); + output.fill_(0); + + int success = correlation_forward_cuda_kernel( + output, + output.size(0), + output.size(1), + output.size(2), + output.size(3), + output.stride(0), + output.stride(1), + output.stride(2), + output.stride(3), + input1, + input1.size(1), + input1.size(2), + input1.size(3), + input1.stride(0), + input1.stride(1), + input1.stride(2), + input1.stride(3), + input2, + input2.size(1), + input2.stride(0), + input2.stride(1), + input2.stride(2), + input2.stride(3), + rInput1, + rInput2, + pad_size, + kernel_size, + max_displacement, + stride1, + stride2, + corr_type_multiply, + at::cuda::getCurrentCUDAStream() + //at::globalContext().getCurrentCUDAStream() + ); + + //check for errors + if (!success) { + AT_ERROR("CUDA call failed"); + } + + return 1; + +} + +int correlation_backward_cuda(at::Tensor& input1, at::Tensor& input2, at::Tensor& rInput1, at::Tensor& rInput2, at::Tensor& gradOutput, + at::Tensor& gradInput1, at::Tensor& gradInput2, + int pad_size, + 
int kernel_size, + int max_displacement, + int stride1, + int stride2, + int corr_type_multiply) +{ + + int batchSize = input1.size(0); + int nInputChannels = input1.size(1); + int paddedInputHeight = input1.size(2)+ 2 * pad_size; + int paddedInputWidth = input1.size(3)+ 2 * pad_size; + + int height = input1.size(2); + int width = input1.size(3); + + rInput1.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels}); + rInput2.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels}); + gradInput1.resize_({batchSize, nInputChannels, height, width}); + gradInput2.resize_({batchSize, nInputChannels, height, width}); + + rInput1.fill_(0); + rInput2.fill_(0); + gradInput1.fill_(0); + gradInput2.fill_(0); + + int success = correlation_backward_cuda_kernel(gradOutput, + gradOutput.size(0), + gradOutput.size(1), + gradOutput.size(2), + gradOutput.size(3), + gradOutput.stride(0), + gradOutput.stride(1), + gradOutput.stride(2), + gradOutput.stride(3), + input1, + input1.size(1), + input1.size(2), + input1.size(3), + input1.stride(0), + input1.stride(1), + input1.stride(2), + input1.stride(3), + input2, + input2.stride(0), + input2.stride(1), + input2.stride(2), + input2.stride(3), + gradInput1, + gradInput1.stride(0), + gradInput1.stride(1), + gradInput1.stride(2), + gradInput1.stride(3), + gradInput2, + gradInput2.size(1), + gradInput2.stride(0), + gradInput2.stride(1), + gradInput2.stride(2), + gradInput2.stride(3), + rInput1, + rInput2, + pad_size, + kernel_size, + max_displacement, + stride1, + stride2, + corr_type_multiply, + at::cuda::getCurrentCUDAStream() + //at::globalContext().getCurrentCUDAStream() + ); + + if (!success) { + AT_ERROR("CUDA call failed"); + } + + return 1; +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", &correlation_forward_cuda, "Correlation forward (CUDA)"); + m.def("backward", &correlation_backward_cuda, "Correlation backward (CUDA)"); +} + diff --git a/imaginaire/third_party/correlation/src/correlation_cuda_kernel.cu b/imaginaire/third_party/correlation/src/correlation_cuda_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..eaf86fc129137d055de7400916567c6669b45c19 --- /dev/null +++ b/imaginaire/third_party/correlation/src/correlation_cuda_kernel.cu @@ -0,0 +1,564 @@ +#include + +#include "correlation_cuda_kernel.cuh" + +#define CUDA_NUM_THREADS 1024 +#define THREADS_PER_BLOCK 32 +#define FULL_MASK 0xffffffff + +#include +#include +#include +#include + +using at::Half; + +template +__forceinline__ __device__ scalar_t warpReduceSum(scalar_t val) { + for (int offset = 16; offset > 0; offset /= 2) + val += __shfl_down_sync(FULL_MASK, val, offset); + return val; +} + +template +__forceinline__ __device__ scalar_t blockReduceSum(scalar_t val) { + + static __shared__ scalar_t shared[32]; + int lane = threadIdx.x % warpSize; + int wid = threadIdx.x / warpSize; + + val = warpReduceSum(val); + + if (lane == 0) + shared[wid] = val; + + __syncthreads(); + + val = (threadIdx.x < blockDim.x / warpSize) ? 
shared[lane] : 0; + + if (wid == 0) + val = warpReduceSum(val); + + return val; +} + + +template +__global__ void channels_first(const scalar_t* __restrict__ input, scalar_t* rinput, int channels, int height, int width, int pad_size) +{ + + // n (batch size), c (num of channels), y (height), x (width) + int n = blockIdx.x; + int y = blockIdx.y; + int x = blockIdx.z; + + int ch_off = threadIdx.x; + scalar_t value; + + int dimcyx = channels * height * width; + int dimyx = height * width; + + int p_dimx = (width + 2 * pad_size); + int p_dimy = (height + 2 * pad_size); + int p_dimyxc = channels * p_dimy * p_dimx; + int p_dimxc = p_dimx * channels; + + for (int c = ch_off; c < channels; c += THREADS_PER_BLOCK) { + value = input[n * dimcyx + c * dimyx + y * width + x]; + rinput[n * p_dimyxc + (y + pad_size) * p_dimxc + (x + pad_size) * channels + c] = value; + } +} + + +template +__global__ void correlation_forward(scalar_t* __restrict__ output, const int nOutputChannels, + const int outputHeight, const int outputWidth, const scalar_t* __restrict__ rInput1, + const int nInputChannels, const int inputHeight, const int inputWidth, + const scalar_t* __restrict__ rInput2, const int pad_size, const int kernel_size, + const int max_displacement, const int stride1, const int stride2) { + + int32_t pInputWidth = inputWidth + 2 * pad_size; + int32_t pInputHeight = inputHeight + 2 * pad_size; + + int32_t kernel_rad = (kernel_size - 1) / 2; + + int32_t displacement_rad = max_displacement / stride2; + + int32_t displacement_size = 2 * displacement_rad + 1; + + int32_t n = blockIdx.x; + int32_t y1 = blockIdx.y * stride1 + max_displacement; + int32_t x1 = blockIdx.z * stride1 + max_displacement; + int32_t c = threadIdx.x; + + int32_t pdimyxc = pInputHeight * pInputWidth * nInputChannels; + + int32_t pdimxc = pInputWidth * nInputChannels; + + int32_t pdimc = nInputChannels; + + int32_t tdimcyx = nOutputChannels * outputHeight * outputWidth; + int32_t tdimyx = outputHeight * outputWidth; + int32_t tdimx = outputWidth; + + int32_t nelems = kernel_size * kernel_size * pdimc; + + // element-wise product along channel axis + for (int tj = -displacement_rad; tj <= displacement_rad; ++tj) { + for (int ti = -displacement_rad; ti <= displacement_rad; ++ti) { + int x2 = x1 + ti * stride2; + int y2 = y1 + tj * stride2; + + float acc0 = 0.0f; + + for (int j = -kernel_rad; j <= kernel_rad; ++j) { + for (int i = -kernel_rad; i <= kernel_rad; ++i) { + // THREADS_PER_BLOCK + #pragma unroll + for (int ch = c; ch < pdimc; ch += blockDim.x) { + + int indx1 = n * pdimyxc + (y1 + j) * pdimxc + + (x1 + i) * pdimc + ch; + int indx2 = n * pdimyxc + (y2 + j) * pdimxc + + (x2 + i) * pdimc + ch; + acc0 += static_cast(rInput1[indx1] * rInput2[indx2]); + } + } + } + + if (blockDim.x == warpSize) { + __syncwarp(); + acc0 = warpReduceSum(acc0); + } else { + __syncthreads(); + acc0 = blockReduceSum(acc0); + } + + if (threadIdx.x == 0) { + + int tc = (tj + displacement_rad) * displacement_size + + (ti + displacement_rad); + const int tindx = n * tdimcyx + tc * tdimyx + blockIdx.y * tdimx + + blockIdx.z; + output[tindx] = static_cast(acc0 / nelems); + } + } + } +} + + +template +__global__ void correlation_backward_input1(int item, scalar_t* gradInput1, int nInputChannels, int inputHeight, int inputWidth, + const scalar_t* __restrict__ gradOutput, int nOutputChannels, int outputHeight, int outputWidth, + const scalar_t* __restrict__ rInput2, + int pad_size, + int kernel_size, + int max_displacement, + int stride1, + int stride2) + { + // n 
(batch size), c (num of channels), y (height), x (width) + + int n = item; + int y = blockIdx.x * stride1 + pad_size; + int x = blockIdx.y * stride1 + pad_size; + int c = blockIdx.z; + int tch_off = threadIdx.x; + + int kernel_rad = (kernel_size - 1) / 2; + int displacement_rad = max_displacement / stride2; + int displacement_size = 2 * displacement_rad + 1; + + int xmin = (x - kernel_rad - max_displacement) / stride1; + int ymin = (y - kernel_rad - max_displacement) / stride1; + + int xmax = (x + kernel_rad - max_displacement) / stride1; + int ymax = (y + kernel_rad - max_displacement) / stride1; + + if (xmax < 0 || ymax < 0 || xmin >= outputWidth || ymin >= outputHeight) { + // assumes gradInput1 is pre-allocated and zero filled + return; + } + + if (xmin > xmax || ymin > ymax) { + // assumes gradInput1 is pre-allocated and zero filled + return; + } + + xmin = max(0,xmin); + xmax = min(outputWidth-1,xmax); + + ymin = max(0,ymin); + ymax = min(outputHeight-1,ymax); + + int pInputWidth = inputWidth + 2 * pad_size; + int pInputHeight = inputHeight + 2 * pad_size; + + int pdimyxc = pInputHeight * pInputWidth * nInputChannels; + int pdimxc = pInputWidth * nInputChannels; + int pdimc = nInputChannels; + + int tdimcyx = nOutputChannels * outputHeight * outputWidth; + int tdimyx = outputHeight * outputWidth; + int tdimx = outputWidth; + + int odimcyx = nInputChannels * inputHeight* inputWidth; + int odimyx = inputHeight * inputWidth; + int odimx = inputWidth; + + scalar_t nelems = kernel_size * kernel_size * nInputChannels; + + __shared__ scalar_t prod_sum[THREADS_PER_BLOCK]; + prod_sum[tch_off] = 0; + + for (int tc = tch_off; tc < nOutputChannels; tc += THREADS_PER_BLOCK) { + + int i2 = (tc % displacement_size - displacement_rad) * stride2; + int j2 = (tc / displacement_size - displacement_rad) * stride2; + + int indx2 = n * pdimyxc + (y + j2)* pdimxc + (x + i2) * pdimc + c; + + scalar_t val2 = rInput2[indx2]; + + for (int j = ymin; j <= ymax; ++j) { + for (int i = xmin; i <= xmax; ++i) { + int tindx = n * tdimcyx + tc * tdimyx + j * tdimx + i; + prod_sum[tch_off] += gradOutput[tindx] * val2; + } + } + } + __syncthreads(); + + if(tch_off == 0) { + scalar_t reduce_sum = 0; + for(int idx = 0; idx < THREADS_PER_BLOCK; idx++) { + reduce_sum += prod_sum[idx]; + } + const int indx1 = n * odimcyx + c * odimyx + (y - pad_size) * odimx + (x - pad_size); + gradInput1[indx1] = reduce_sum / nelems; + } + +} + +template +__global__ void correlation_backward_input2(int item, scalar_t* gradInput2, int nInputChannels, int inputHeight, int inputWidth, + const scalar_t* __restrict__ gradOutput, int nOutputChannels, int outputHeight, int outputWidth, + const scalar_t* __restrict__ rInput1, + int pad_size, + int kernel_size, + int max_displacement, + int stride1, + int stride2) +{ + // n (batch size), c (num of channels), y (height), x (width) + + int n = item; + int y = blockIdx.x * stride1 + pad_size; + int x = blockIdx.y * stride1 + pad_size; + int c = blockIdx.z; + + int tch_off = threadIdx.x; + + int kernel_rad = (kernel_size - 1) / 2; + int displacement_rad = max_displacement / stride2; + int displacement_size = 2 * displacement_rad + 1; + + int pInputWidth = inputWidth + 2 * pad_size; + int pInputHeight = inputHeight + 2 * pad_size; + + int pdimyxc = pInputHeight * pInputWidth * nInputChannels; + int pdimxc = pInputWidth * nInputChannels; + int pdimc = nInputChannels; + + int tdimcyx = nOutputChannels * outputHeight * outputWidth; + int tdimyx = outputHeight * outputWidth; + int tdimx = outputWidth; + + int 
odimcyx = nInputChannels * inputHeight* inputWidth; + int odimyx = inputHeight * inputWidth; + int odimx = inputWidth; + + scalar_t nelems = kernel_size * kernel_size * nInputChannels; + + __shared__ scalar_t prod_sum[THREADS_PER_BLOCK]; + prod_sum[tch_off] = 0; + + for (int tc = tch_off; tc < nOutputChannels; tc += THREADS_PER_BLOCK) { + int i2 = (tc % displacement_size - displacement_rad) * stride2; + int j2 = (tc / displacement_size - displacement_rad) * stride2; + + int xmin = (x - kernel_rad - max_displacement - i2) / stride1; + int ymin = (y - kernel_rad - max_displacement - j2) / stride1; + + int xmax = (x + kernel_rad - max_displacement - i2) / stride1; + int ymax = (y + kernel_rad - max_displacement - j2) / stride1; + + if (xmax < 0 || ymax < 0 || xmin >= outputWidth || ymin >= outputHeight) { + // assumes gradInput2 is pre-allocated and zero filled + continue; + } + + if (xmin > xmax || ymin > ymax) { + // assumes gradInput2 is pre-allocated and zero filled + continue; + } + + xmin = max(0,xmin); + xmax = min(outputWidth-1,xmax); + + ymin = max(0,ymin); + ymax = min(outputHeight-1,ymax); + + int indx1 = n * pdimyxc + (y - j2)* pdimxc + (x - i2) * pdimc + c; + scalar_t val1 = rInput1[indx1]; + + for (int j = ymin; j <= ymax; ++j) { + for (int i = xmin; i <= xmax; ++i) { + int tindx = n * tdimcyx + tc * tdimyx + j * tdimx + i; + prod_sum[tch_off] += gradOutput[tindx] * val1; + } + } + } + + __syncthreads(); + + if(tch_off == 0) { + scalar_t reduce_sum = 0; + for(int idx = 0; idx < THREADS_PER_BLOCK; idx++) { + reduce_sum += prod_sum[idx]; + } + const int indx2 = n * odimcyx + c * odimyx + (y - pad_size) * odimx + (x - pad_size); + gradInput2[indx2] = reduce_sum / nelems; + } + +} + +int correlation_forward_cuda_kernel(at::Tensor& output, + int ob, + int oc, + int oh, + int ow, + int osb, + int osc, + int osh, + int osw, + + at::Tensor& input1, + int ic, + int ih, + int iw, + int isb, + int isc, + int ish, + int isw, + + at::Tensor& input2, + int gc, + int gsb, + int gsc, + int gsh, + int gsw, + + at::Tensor& rInput1, + at::Tensor& rInput2, + int pad_size, + int kernel_size, + int max_displacement, + int stride1, + int stride2, + int corr_type_multiply, + cudaStream_t stream) +{ + + int batchSize = ob; + + int nInputChannels = ic; + int inputWidth = iw; + int inputHeight = ih; + + int nOutputChannels = oc; + int outputWidth = ow; + int outputHeight = oh; + + dim3 blocks_grid(batchSize, inputHeight, inputWidth); + dim3 threads_block(THREADS_PER_BLOCK); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(input1.type(), "channels_first_fwd_1", ([&] { + + channels_first<<>>( + input1.data(), rInput1.data(), nInputChannels, inputHeight, inputWidth, pad_size); + + })); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(input2.type(), "channels_first_fwd_2", ([&] { + + channels_first<<>> ( + input2.data(), rInput2.data(), nInputChannels, inputHeight, inputWidth, pad_size); + + })); + + dim3 threadsPerBlock(THREADS_PER_BLOCK); + dim3 totalBlocksCorr(batchSize, outputHeight, outputWidth); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(input1.type(), "correlation_forward", ([&] { + + correlation_forward<<>> + (output.data(), nOutputChannels, outputHeight, outputWidth, + rInput1.data(), nInputChannels, inputHeight, inputWidth, + rInput2.data(), + pad_size, + kernel_size, + max_displacement, + stride1, + stride2); + + })); + + cudaError_t err = cudaGetLastError(); + + + // check for errors + if (err != cudaSuccess) { + printf("error in correlation_forward_cuda_kernel: %s\n", cudaGetErrorString(err)); + return 0; + } + + 
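+  // Kernels completed without a CUDA error; the C++ wrapper
+  // correlation_forward_cuda() treats a zero return as failure and raises AT_ERROR.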
return 1; +} + + +int correlation_backward_cuda_kernel( + at::Tensor& gradOutput, + int gob, + int goc, + int goh, + int gow, + int gosb, + int gosc, + int gosh, + int gosw, + + at::Tensor& input1, + int ic, + int ih, + int iw, + int isb, + int isc, + int ish, + int isw, + + at::Tensor& input2, + int gsb, + int gsc, + int gsh, + int gsw, + + at::Tensor& gradInput1, + int gisb, + int gisc, + int gish, + int gisw, + + at::Tensor& gradInput2, + int ggc, + int ggsb, + int ggsc, + int ggsh, + int ggsw, + + at::Tensor& rInput1, + at::Tensor& rInput2, + int pad_size, + int kernel_size, + int max_displacement, + int stride1, + int stride2, + int corr_type_multiply, + cudaStream_t stream) +{ + + int batchSize = gob; + int num = batchSize; + + int nInputChannels = ic; + int inputWidth = iw; + int inputHeight = ih; + + int nOutputChannels = goc; + int outputWidth = gow; + int outputHeight = goh; + + dim3 blocks_grid(batchSize, inputHeight, inputWidth); + dim3 threads_block(THREADS_PER_BLOCK); + + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(input1.type(), "lltm_forward_cuda", ([&] { + + channels_first<<>>( + input1.data(), + rInput1.data(), + nInputChannels, + inputHeight, + inputWidth, + pad_size + ); + })); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(input2.type(), "lltm_forward_cuda", ([&] { + + channels_first<<>>( + input2.data(), + rInput2.data(), + nInputChannels, + inputHeight, + inputWidth, + pad_size + ); + })); + + dim3 threadsPerBlock(THREADS_PER_BLOCK); + dim3 totalBlocksCorr(inputHeight, inputWidth, nInputChannels); + + for (int n = 0; n < num; ++n) { + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(input2.type(), "lltm_forward_cuda", ([&] { + + + correlation_backward_input1<<>> ( + n, gradInput1.data(), nInputChannels, inputHeight, inputWidth, + gradOutput.data(), nOutputChannels, outputHeight, outputWidth, + rInput2.data(), + pad_size, + kernel_size, + max_displacement, + stride1, + stride2); + })); + } + + for(int n = 0; n < batchSize; n++) { + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(rInput1.type(), "lltm_forward_cuda", ([&] { + + correlation_backward_input2<<>>( + n, gradInput2.data(), nInputChannels, inputHeight, inputWidth, + gradOutput.data(), nOutputChannels, outputHeight, outputWidth, + rInput1.data(), + pad_size, + kernel_size, + max_displacement, + stride1, + stride2); + + })); + } + + // check for errors + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) { + printf("error in correlation_backward_cuda_kernel: %s\n", cudaGetErrorString(err)); + return 0; + } + + return 1; +} diff --git a/imaginaire/third_party/correlation/src/correlation_cuda_kernel.cuh b/imaginaire/third_party/correlation/src/correlation_cuda_kernel.cuh new file mode 100644 index 0000000000000000000000000000000000000000..1586d3af6bc184bfea8482a991a6625f865f02b3 --- /dev/null +++ b/imaginaire/third_party/correlation/src/correlation_cuda_kernel.cuh @@ -0,0 +1,91 @@ +#pragma once + +#include +#include +#include + +int correlation_forward_cuda_kernel(at::Tensor& output, + int ob, + int oc, + int oh, + int ow, + int osb, + int osc, + int osh, + int osw, + + at::Tensor& input1, + int ic, + int ih, + int iw, + int isb, + int isc, + int ish, + int isw, + + at::Tensor& input2, + int gc, + int gsb, + int gsc, + int gsh, + int gsw, + + at::Tensor& rInput1, + at::Tensor& rInput2, + int pad_size, + int kernel_size, + int max_displacement, + int stride1, + int stride2, + int corr_type_multiply, + cudaStream_t stream); + + +int correlation_backward_cuda_kernel( + at::Tensor& gradOutput, + int gob, + int goc, + int goh, + int gow, + 
int gosb, + int gosc, + int gosh, + int gosw, + + at::Tensor& input1, + int ic, + int ih, + int iw, + int isb, + int isc, + int ish, + int isw, + + at::Tensor& input2, + int gsb, + int gsc, + int gsh, + int gsw, + + at::Tensor& gradInput1, + int gisb, + int gisc, + int gish, + int gisw, + + at::Tensor& gradInput2, + int ggc, + int ggsb, + int ggsc, + int ggsh, + int ggsw, + + at::Tensor& rInput1, + at::Tensor& rInput2, + int pad_size, + int kernel_size, + int max_displacement, + int stride1, + int stride2, + int corr_type_multiply, + cudaStream_t stream); diff --git a/imaginaire/third_party/flow_net/__init__.py b/imaginaire/third_party/flow_net/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/imaginaire/third_party/flow_net/flow_net.py b/imaginaire/third_party/flow_net/flow_net.py new file mode 100644 index 0000000000000000000000000000000000000000..41759c50fa1389b6fbe2e5db725a4b71ff3f2342 --- /dev/null +++ b/imaginaire/third_party/flow_net/flow_net.py @@ -0,0 +1,89 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +import torch +import torch.nn as nn +import torch.nn.functional as F +import types +from imaginaire.third_party.flow_net.flownet2 import models as \ + flownet2_models +from imaginaire.third_party.flow_net.flownet2.utils import tools \ + as flownet2_tools +from imaginaire.model_utils.fs_vid2vid import resample +from imaginaire.utils.io import get_checkpoint + + +class FlowNet(nn.Module): + def __init__(self, pretrained=True, fp16=False): + super().__init__() + flownet2_args = types.SimpleNamespace() + setattr(flownet2_args, 'fp16', fp16) + setattr(flownet2_args, 'rgb_max', 1.0) + if fp16: + print('FlowNet2 is running in fp16 mode.') + self.flowNet = flownet2_tools.module_to_dict(flownet2_models)[ + 'FlowNet2'](flownet2_args).to('cuda') + if pretrained: + flownet2_path = get_checkpoint('flownet2.pth.tar', + '1hF8vS6YeHkx3j2pfCeQqqZGwA_PJq_Da') + checkpoint = torch.load(flownet2_path, + map_location=torch.device('cpu')) + self.flowNet.load_state_dict(checkpoint['state_dict']) + self.flowNet.eval() + + def forward(self, input_A, input_B): + size = input_A.size() + assert(len(size) == 4 or len(size) == 5 or len(size) == 6) + if len(size) >= 5: + if len(size) == 5: + b, n, c, h, w = size + else: + b, t, n, c, h, w = size + input_A = input_A.contiguous().view(-1, c, h, w) + input_B = input_B.contiguous().view(-1, c, h, w) + flow, conf = self.compute_flow_and_conf(input_A, input_B) + if len(size) == 5: + return flow.view(b, n, 2, h, w), conf.view(b, n, 1, h, w) + else: + return flow.view(b, t, n, 2, h, w), conf.view(b, t, n, 1, h, w) + else: + return self.compute_flow_and_conf(input_A, input_B) + + def compute_flow_and_conf(self, im1, im2): + assert(im1.size()[1] == 3) + assert(im1.size() == im2.size()) + old_h, old_w = im1.size()[2], im1.size()[3] + new_h, new_w = old_h // 64 * 64, old_w // 64 * 64 + if old_h != new_h: + im1 = F.interpolate(im1, size=(new_h, new_w), mode='bilinear', + align_corners=False) + im2 = F.interpolate(im2, size=(new_h, new_w), mode='bilinear', + align_corners=False) + data1 = torch.cat([im1.unsqueeze(2), im2.unsqueeze(2)], dim=2) + with torch.no_grad(): + flow1 = self.flowNet(data1) + # img_diff = torch.sum(abs(im1 - resample(im2, flow1)), + # dim=1, keepdim=True) + # conf = torch.clamp(1 - img_diff, 0, 1) + + conf = 
(self.norm(im1 - resample(im2, flow1)) < 0.02).float() + + # data2 = torch.cat([im2.unsqueeze(2), im1.unsqueeze(2)], dim=2) + # with torch.no_grad(): + # flow2 = self.flowNet(data2) + # warped_flow2 = resample(flow2, flow1) + # flow_sum = self.norm(flow1 + warped_flow2) + # disocc = flow_sum > (0.05 * (self.norm(flow1) + + # self.norm(warped_flow2)) + 0.5) + # conf = 1 - disocc.float() + + if old_h != new_h: + flow1 = F.interpolate(flow1, size=(old_h, old_w), mode='bilinear', + align_corners=False) * old_h / new_h + conf = F.interpolate(conf, size=(old_h, old_w), mode='bilinear', + align_corners=False) + return flow1, conf + + def norm(self, t): + return torch.sum(t * t, dim=1, keepdim=True) diff --git a/imaginaire/third_party/flow_net/flownet2/models.py b/imaginaire/third_party/flow_net/flownet2/models.py new file mode 100644 index 0000000000000000000000000000000000000000..d0066464a01942c85909a7e5ddbc97e39f244623 --- /dev/null +++ b/imaginaire/third_party/flow_net/flownet2/models.py @@ -0,0 +1,474 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +import torch +from torch.nn import init +import torch.nn as nn +import resample2d +import channelnorm +import numpy as np +from imaginaire.third_party.flow_net.flownet2.networks import flownet_c +from imaginaire.third_party.flow_net.flownet2.networks import flownet_s +from imaginaire.third_party.flow_net.flownet2.networks import flownet_sd +from imaginaire.third_party.flow_net.flownet2.networks import flownet_fusion +from imaginaire.third_party.flow_net.flownet2.networks.submodules import \ + tofp16, tofp32 +'Parameter count = 162,518,834' + + +class FlowNet2(nn.Module): + def __init__(self, args, use_batch_norm=False, div_flow=20.): + super(FlowNet2, self).__init__() + self.batch_norm = use_batch_norm + self.div_flow = div_flow + self.rgb_max = args.rgb_max + self.args = args + self.channelnorm = channelnorm.ChannelNorm() + # First Block (FlowNetC) + self.flownetc = flownet_c.FlowNetC( + args, use_batch_norm=self.batch_norm) + self.upsample1 = nn.Upsample(scale_factor=4, mode='bilinear', + align_corners=False) + self.args = args + # if args.fp16: + # self.resample1 = nn.Sequential( + # tofp32(), resample2d.Resample2d(), tofp16()) + # else: + self.resample1 = resample2d.Resample2d() + # Block (FlowNetS1) + self.flownets_1 = flownet_s.FlowNetS( + args, use_batch_norm=self.batch_norm) + self.upsample2 = nn.Upsample(scale_factor=4, mode='bilinear', + align_corners=False) + # if args.fp16: + # self.resample2 = nn.Sequential( + # tofp32(), resample2d.Resample2d(), tofp16()) + # else: + self.resample2 = resample2d.Resample2d() + # Block (FlowNetS2) + self.flownets_2 = flownet_s.FlowNetS( + args, use_batch_norm=self.batch_norm) + # Block (FlowNetSD) + self.flownets_d = flownet_sd.FlowNetSD( + args, use_batch_norm=self.batch_norm) + self.upsample3 = nn.Upsample(scale_factor=4, mode='nearest') + self.upsample4 = nn.Upsample(scale_factor=4, mode='nearest') + # if args.fp16: + # self.resample3 = nn.Sequential( + # tofp32(), resample2d.Resample2d(), tofp16()) + # else: + self.resample3 = resample2d.Resample2d() + # if args.fp16: + # self.resample4 = nn.Sequential( + # tofp32(), resample2d.Resample2d(), tofp16()) + # else: + self.resample4 = resample2d.Resample2d() + # Block (FLowNetFusion) + self.flownetfusion = flownet_fusion.FlowNetFusion( + args, use_batch_norm=self.batch_norm) + for m in 
self.modules(): + if isinstance(m, nn.Conv2d): + if m.bias is not None: + init.uniform_(m.bias) + init.xavier_uniform_(m.weight) + if isinstance(m, nn.ConvTranspose2d): + if m.bias is not None: + init.uniform_(m.bias) + init.xavier_uniform_(m.weight) + + def init_deconv_bilinear(self, weight): + f_shape = weight.size() + height, width = f_shape[-2], f_shape[-1] + f = np.ceil(width / 2.0) + c = (2 * f - 1 - f % 2) / (2.0 * f) + bilinear = np.zeros([height, width]) + for x in range(width): + for y in range(height): + value = (1 - abs(x / f - c)) * (1 - abs(y / f - c)) + bilinear[x, y] = value + min_dim = min(f_shape[0], f_shape[1]) + weight.data.fill_(0.) + for i in range(min_dim): + weight.data[i, i, :, :] = torch.from_numpy(bilinear) + return + + def forward(self, inputs): + rgb_mean = inputs.contiguous().view(inputs.size()[:2] + (-1,)).mean( + dim=-1).view(inputs.size()[:2] + (1, 1, 1,)) + x = (inputs - rgb_mean) / self.rgb_max + x1 = x[:, :, 0, :, :] + x2 = x[:, :, 1, :, :] + x = torch.cat((x1, x2), dim=1) + # flownetc + flownetc_flow2 = self.flownetc(x)[0] + flownetc_flow = self.upsample1(flownetc_flow2 * self.div_flow) + # warp img1 to img0; + # magnitude of diff between img0 and and warped_img1, + if self.args.fp16: + resampled_img1 = self.resample1(tofp32()(x[:, 3:, :, :]), + flownetc_flow) + resampled_img1 = tofp16()(resampled_img1) + else: + resampled_img1 = self.resample1(x[:, 3:, :, :], flownetc_flow) + diff_img0 = x[:, :3, :, :] - resampled_img1 + norm_diff_img0 = self.channelnorm(diff_img0) + # concat img0, img1, img1->img0, flow, diff-mag ; + concat1 = torch.cat( + (x, resampled_img1, flownetc_flow / self.div_flow, norm_diff_img0), + dim=1) + # flownets1 + flownets1_flow2 = self.flownets_1(concat1)[0] + flownets1_flow = self.upsample2(flownets1_flow2 * self.div_flow) + # warp img1 to img0 using flownets1; + # magnitude of diff between img0 and and warped_img1 + if self.args.fp16: + resampled_img1 = self.resample2(tofp32()(x[:, 3:, :, :]), + flownets1_flow) + resampled_img1 = tofp16()(resampled_img1) + else: + resampled_img1 = self.resample2(x[:, 3:, :, :], flownets1_flow) + diff_img0 = x[:, :3, :, :] - resampled_img1 + norm_diff_img0 = self.channelnorm(diff_img0) + # concat img0, img1, img1->img0, flow, diff-mag + concat2 = torch.cat( + (x, + resampled_img1, + flownets1_flow / + self.div_flow, + norm_diff_img0), + dim=1) + # flownets2 + flownets2_flow2 = self.flownets_2(concat2)[0] + flownets2_flow = self.upsample4(flownets2_flow2 * self.div_flow) + norm_flownets2_flow = self.channelnorm(flownets2_flow) + if self.args.fp16: + diff_flownets2_flow = self.resample4(tofp32()(x[:, 3:, :, :]), + flownets2_flow) + diff_flownets2_flow = tofp16()(diff_flownets2_flow) + else: + diff_flownets2_flow = self.resample4(x[:, 3:, :, :], flownets2_flow) + diff_flownets2_img1 = self.channelnorm( + (x[:, :3, :, :] - diff_flownets2_flow)) + # flownetsd + flownetsd_flow2 = self.flownets_d(x)[0] + flownetsd_flow = self.upsample3(flownetsd_flow2 / self.div_flow) + norm_flownetsd_flow = self.channelnorm(flownetsd_flow) + if self.args.fp16: + diff_flownetsd_flow = self.resample3(tofp32()(x[:, 3:, :, :]), + flownetsd_flow) + diff_flownetsd_flow = tofp16()(diff_flownetsd_flow) + else: + diff_flownetsd_flow = self.resample3(x[:, 3:, :, :], flownetsd_flow) + diff_flownetsd_img1 = self.channelnorm( + (x[:, :3, :, :] - diff_flownetsd_flow)) + # concat img1 flownetsd, flownets2, norm_flownetsd, + # norm_flownets2, diff_flownetsd_img1, diff_flownets2_img1 + concat3 = torch.cat((x[:, :3, :, :], flownetsd_flow, 
flownets2_flow, + norm_flownetsd_flow, norm_flownets2_flow, + diff_flownetsd_img1, diff_flownets2_img1), dim=1) + flownetfusion_flow = self.flownetfusion(concat3) + return flownetfusion_flow + + +class FlowNet2C(flownet_c.FlowNetC): + def __init__(self, args, use_batch_norm=False, div_flow=20): + super( + FlowNet2C, + self).__init__( + args, + use_batch_norm=use_batch_norm, + div_flow=20) + self.rgb_max = args.rgb_max + + def forward(self, inputs): + rgb_mean = inputs.contiguous().view(inputs.size()[:2] + (-1,)).mean( + dim=-1).view(inputs.size()[:2] + (1, 1, 1,)) + x = (inputs - rgb_mean) / self.rgb_max + x1 = x[:, :, 0, :, :] + x2 = x[:, :, 1, :, :] + # FlownetC top input stream + out_conv1a = self.conv1(x1) + out_conv2a = self.conv2(out_conv1a) + out_conv3a = self.conv3(out_conv2a) + # FlownetC bottom input stream + out_conv1b = self.conv1(x2) + out_conv2b = self.conv2(out_conv1b) + out_conv3b = self.conv3(out_conv2b) + # Merge streams + out_corr = self.corr(out_conv3a, out_conv3b) # False + out_corr = self.corr_activation(out_corr) + # Redirect top input stream and concatenate + out_conv_redir = self.conv_redir(out_conv3a) + in_conv3_1 = torch.cat((out_conv_redir, out_corr), 1) + # Merged conv layers + out_conv3_1 = self.conv3_1(in_conv3_1) + out_conv4 = self.conv4_1(self.conv4(out_conv3_1)) + out_conv5 = self.conv5_1(self.conv5(out_conv4)) + out_conv6 = self.conv6_1(self.conv6(out_conv5)) + flow6 = self.predict_flow6(out_conv6) + flow6_up = self.upsampled_flow6_to_5(flow6) + out_deconv5 = self.deconv5(out_conv6) + concat5 = torch.cat((out_conv5, out_deconv5, flow6_up), 1) + flow5 = self.predict_flow5(concat5) + flow5_up = self.upsampled_flow5_to_4(flow5) + out_deconv4 = self.deconv4(concat5) + concat4 = torch.cat((out_conv4, out_deconv4, flow5_up), 1) + flow4 = self.predict_flow4(concat4) + flow4_up = self.upsampled_flow4_to_3(flow4) + out_deconv3 = self.deconv3(concat4) + concat3 = torch.cat((out_conv3_1, out_deconv3, flow4_up), 1) + flow3 = self.predict_flow3(concat3) + flow3_up = self.upsampled_flow3_to_2(flow3) + out_deconv2 = self.deconv2(concat3) + concat2 = torch.cat((out_conv2a, out_deconv2, flow3_up), 1) + flow2 = self.predict_flow2(concat2) + if self.training: + return flow2, flow3, flow4, flow5, flow6 + else: + return self.upsample1(flow2 * self.div_flow) + + +class FlowNet2S(flownet_s.FlowNetS): + def __init__(self, args, use_batch_norm=False, div_flow=20): + super(FlowNet2S, self).__init__(args, input_channels=6, + use_batch_norm=use_batch_norm) + self.rgb_max = args.rgb_max + self.div_flow = div_flow + + def forward(self, inputs): + rgb_mean = inputs.contiguous().view(inputs.size()[:2] + (-1,)).mean( + dim=-1).view(inputs.size()[:2] + (1, 1, 1,)) + x = (inputs - rgb_mean) / self.rgb_max + x = torch.cat((x[:, :, 0, :, :], x[:, :, 1, :, :]), dim=1) + out_conv1 = self.conv1(x) + out_conv2 = self.conv2(out_conv1) + out_conv3 = self.conv3_1(self.conv3(out_conv2)) + out_conv4 = self.conv4_1(self.conv4(out_conv3)) + out_conv5 = self.conv5_1(self.conv5(out_conv4)) + out_conv6 = self.conv6_1(self.conv6(out_conv5)) + flow6 = self.predict_flow6(out_conv6) + flow6_up = self.upsampled_flow6_to_5(flow6) + out_deconv5 = self.deconv5(out_conv6) + concat5 = torch.cat((out_conv5, out_deconv5, flow6_up), 1) + flow5 = self.predict_flow5(concat5) + flow5_up = self.upsampled_flow5_to_4(flow5) + out_deconv4 = self.deconv4(concat5) + concat4 = torch.cat((out_conv4, out_deconv4, flow5_up), 1) + flow4 = self.predict_flow4(concat4) + flow4_up = self.upsampled_flow4_to_3(flow4) + out_deconv3 = 
self.deconv3(concat4) + concat3 = torch.cat((out_conv3, out_deconv3, flow4_up), 1) + flow3 = self.predict_flow3(concat3) + flow3_up = self.upsampled_flow3_to_2(flow3) + out_deconv2 = self.deconv2(concat3) + concat2 = torch.cat((out_conv2, out_deconv2, flow3_up), 1) + flow2 = self.predict_flow2(concat2) + if self.training: + return flow2, flow3, flow4, flow5, flow6 + else: + return self.upsample1(flow2 * self.div_flow) + + +class FlowNet2SD(flownet_sd.FlowNetSD): + def __init__(self, args, use_batch_norm=False, div_flow=20): + super(FlowNet2SD, self).__init__(args, use_batch_norm=use_batch_norm) + self.rgb_max = args.rgb_max + self.div_flow = div_flow + + def forward(self, inputs): + rgb_mean = inputs.contiguous().view(inputs.size()[:2] + (-1,)).mean( + dim=-1).view(inputs.size()[:2] + (1, 1, 1,)) + x = (inputs - rgb_mean) / self.rgb_max + x = torch.cat((x[:, :, 0, :, :], x[:, :, 1, :, :]), dim=1) + out_conv0 = self.conv0(x) + out_conv1 = self.conv1_1(self.conv1(out_conv0)) + out_conv2 = self.conv2_1(self.conv2(out_conv1)) + out_conv3 = self.conv3_1(self.conv3(out_conv2)) + out_conv4 = self.conv4_1(self.conv4(out_conv3)) + out_conv5 = self.conv5_1(self.conv5(out_conv4)) + out_conv6 = self.conv6_1(self.conv6(out_conv5)) + flow6 = self.predict_flow6(out_conv6) + flow6_up = self.upsampled_flow6_to_5(flow6) + out_deconv5 = self.deconv5(out_conv6) + concat5 = torch.cat((out_conv5, out_deconv5, flow6_up), 1) + out_interconv5 = self.inter_conv5(concat5) + flow5 = self.predict_flow5(out_interconv5) + flow5_up = self.upsampled_flow5_to_4(flow5) + out_deconv4 = self.deconv4(concat5) + concat4 = torch.cat((out_conv4, out_deconv4, flow5_up), 1) + out_interconv4 = self.inter_conv4(concat4) + flow4 = self.predict_flow4(out_interconv4) + flow4_up = self.upsampled_flow4_to_3(flow4) + out_deconv3 = self.deconv3(concat4) + concat3 = torch.cat((out_conv3, out_deconv3, flow4_up), 1) + out_interconv3 = self.inter_conv3(concat3) + flow3 = self.predict_flow3(out_interconv3) + flow3_up = self.upsampled_flow3_to_2(flow3) + out_deconv2 = self.deconv2(concat3) + concat2 = torch.cat((out_conv2, out_deconv2, flow3_up), 1) + out_interconv2 = self.inter_conv2(concat2) + flow2 = self.predict_flow2(out_interconv2) + if self.training: + return flow2, flow3, flow4, flow5, flow6 + else: + return self.upsample1(flow2 * self.div_flow) + + +class FlowNet2CS(nn.Module): + def __init__(self, args, use_batch_norm=False, div_flow=20.): + super(FlowNet2CS, self).__init__() + self.use_batch_norm = use_batch_norm + self.div_flow = div_flow + self.rgb_max = args.rgb_max + self.args = args + self.channelnorm = channelnorm.ChannelNorm() + # First Block (FlowNetC) + self.flownetc = flownet_c.FlowNetC( + args, use_batch_norm=self.use_batch_norm) + self.upsample1 = nn.Upsample(scale_factor=4, mode='bilinear', + align_corners=False) + self.args = args + # if args.fp16: + # self.resample1 = nn.Sequential( + # tofp32(), resample2d.Resample2d(), tofp16()) + # else: + self.resample1 = resample2d.Resample2d() + # Block (FlowNetS1) + self.flownets_1 = flownet_s.FlowNetS( + args, use_batch_norm=self.use_batch_norm) + self.upsample2 = nn.Upsample(scale_factor=4, mode='bilinear', + align_corners=False) + for m in self.modules(): + if isinstance(m, nn.Conv2d): + if m.bias is not None: + init.uniform(m.bias) + init.xavier_uniform(m.weight) + if isinstance(m, nn.ConvTranspose2d): + if m.bias is not None: + init.uniform(m.bias) + init.xavier_uniform(m.weight) + + def forward(self, inputs): + rgb_mean = inputs.contiguous().view(inputs.size()[:2] + 
(-1,)).mean( + dim=-1).view(inputs.size()[:2] + (1, 1, 1,)) + x = (inputs - rgb_mean) / self.rgb_max + x1 = x[:, :, 0, :, :] + x2 = x[:, :, 1, :, :] + x = torch.cat((x1, x2), dim=1) + # flownetc + flownetc_flow2 = self.flownetc(x)[0] + flownetc_flow = self.upsample1(flownetc_flow2 * self.div_flow) + # warp img1 to img0; + # magnitude of diff between img0 and and warped_img1, + if self.args.fp16: + resampled_img1 = self.resample1(tofp32()(x[:, 3:, :, :]), + flownetc_flow) + resampled_img1 = tofp16()(resampled_img1) + else: + resampled_img1 = self.resample1(x[:, 3:, :, :], flownetc_flow) + diff_img0 = x[:, :3, :, :] - resampled_img1 + norm_diff_img0 = self.channelnorm(diff_img0) + # concat img0, img1, img1->img0, flow, diff-mag ; + concat1 = torch.cat( + (x, resampled_img1, flownetc_flow / self.div_flow, norm_diff_img0), + dim=1) + # flownets1 + flownets1_flow2 = self.flownets_1(concat1)[0] + flownets1_flow = self.upsample2(flownets1_flow2 * self.div_flow) + return flownets1_flow + + +class FlowNet2CSS(nn.Module): + def __init__(self, args, use_batch_norm=False, div_flow=20.): + super(FlowNet2CSS, self).__init__() + self.use_batch_norm = use_batch_norm + self.div_flow = div_flow + self.rgb_max = args.rgb_max + self.args = args + self.channelnorm = channelnorm.ChannelNorm() + # First Block (FlowNetC) + self.flownetc = flownet_c.FlowNetC( + args, use_batch_norm=self.use_batch_norm) + self.upsample1 = nn.Upsample(scale_factor=4, mode='bilinear', + align_corners=False) + self.args = args + # if args.fp16: + # self.resample1 = nn.Sequential( + # tofp32(), resample2d.Resample2d(), tofp16()) + # else: + self.resample1 = resample2d.Resample2d() + # Block (FlowNetS1) + self.flownets_1 = flownet_s.FlowNetS( + args, use_batch_norm=self.use_batch_norm) + self.upsample2 = nn.Upsample(scale_factor=4, mode='bilinear', + align_corners=False) + # if args.fp16: + # self.resample2 = nn.Sequential( + # tofp32(), resample2d.Resample2d(), tofp16()) + # else: + self.resample2 = resample2d.Resample2d() + # Block (FlowNetS2) + self.flownets_2 = flownet_s.FlowNetS( + args, use_batch_norm=self.use_batch_norm) + self.upsample3 = nn.Upsample(scale_factor=4, mode='nearest', + align_corners=False) + for m in self.modules(): + if isinstance(m, nn.Conv2d): + if m.bias is not None: + init.uniform(m.bias) + init.xavier_uniform(m.weight) + if isinstance(m, nn.ConvTranspose2d): + if m.bias is not None: + init.uniform(m.bias) + init.xavier_uniform(m.weight) + + def forward(self, inputs): + rgb_mean = inputs.contiguous().view(inputs.size()[:2] + (-1,)).mean( + dim=-1).view(inputs.size()[:2] + (1, 1, 1,)) + x = (inputs - rgb_mean) / self.rgb_max + x1 = x[:, :, 0, :, :] + x2 = x[:, :, 1, :, :] + x = torch.cat((x1, x2), dim=1) + # flownetc + flownetc_flow2 = self.flownetc(x)[0] + flownetc_flow = self.upsample1(flownetc_flow2 * self.div_flow) + # Warp img1 to img0; + # Magnitude of diff between img0 and and warped_img1, + if self.args.fp16: + resampled_img1 = self.resample1(tofp32()(x[:, 3:, :, :]), + flownetc_flow) + resampled_img1 = tofp16()(resampled_img1) + else: + resampled_img1 = self.resample1(x[:, 3:, :, :], flownetc_flow) + diff_img0 = x[:, :3, :, :] - resampled_img1 + norm_diff_img0 = self.channelnorm(diff_img0) + # concat img0, img1, img1->img0, flow, diff-mag ; + concat1 = torch.cat( + (x, resampled_img1, flownetc_flow / self.div_flow, norm_diff_img0), + dim=1) + # flownets1 + flownets1_flow2 = self.flownets_1(concat1)[0] + flownets1_flow = self.upsample2(flownets1_flow2 * self.div_flow) + # Warp img1 to img0 using 
flownets1; + # magnitude of diff between img0 and and warped_img1 + if self.args.fp16: + resampled_img1 = self.resample2(tofp32()(x[:, 3:, :, :]), + flownets1_flow) + resampled_img1 = tofp16()(resampled_img1) + else: + resampled_img1 = self.resample2(x[:, 3:, :, :], flownets1_flow) + diff_img0 = x[:, :3, :, :] - resampled_img1 + norm_diff_img0 = self.channelnorm(diff_img0) + # concat img0, img1, img1->img0, flow, diff-mag + concat2 = torch.cat( + (x, + resampled_img1, + flownets1_flow / + self.div_flow, + norm_diff_img0), + dim=1) + # flownets2 + flownets2_flow2 = self.flownets_2(concat2)[0] + flownets2_flow = self.upsample3(flownets2_flow2 * self.div_flow) + return flownets2_flow diff --git a/imaginaire/third_party/flow_net/flownet2/networks/__init__.py b/imaginaire/third_party/flow_net/flownet2/networks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3bb7ad34e7ce1c37ea1653d73cde323cdb5569e4 --- /dev/null +++ b/imaginaire/third_party/flow_net/flownet2/networks/__init__.py @@ -0,0 +1,6 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +# The file is duplicated from https://github.com/NVIDIA/flownet2-pytorch +# with some modifications. diff --git a/imaginaire/third_party/flow_net/flownet2/networks/flownet_c.py b/imaginaire/third_party/flow_net/flownet2/networks/flownet_c.py new file mode 100644 index 0000000000000000000000000000000000000000..e6b3719c26eda72d61429c3c52b49707cf48c558 --- /dev/null +++ b/imaginaire/third_party/flow_net/flownet2/networks/flownet_c.py @@ -0,0 +1,160 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +# The file is duplicated from https://github.com/NVIDIA/flownet2-pytorch +# with some modifications. +from torch.nn import init +import correlation +import torch +import torch.nn as nn +from .submodules import conv, predict_flow, deconv, tofp16, tofp32 + + +class FlowNetC(nn.Module): + def __init__(self, args, use_batch_norm=True, div_flow=20): + r"""FlowNet2 C module. Check out the FlowNet2 paper for more details + https://arxiv.org/abs/1612.01925 + + Args: + args (obj): Network initialization arguments + use_batch_norm (bool): Use batch norm or not. Default is true. + div_flow (int): Flow devision factor. Default is 20. 
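+
+        Note: the correlation layer below uses max_displacement=20 with
+        stride2=2, i.e. a 21 x 21 displacement grid that yields a 441-channel
+        cost volume; concatenated with the 32-channel conv_redir output this
+        gives the 473 input channels expected by conv3_1.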
+ """ + super(FlowNetC, self).__init__() + + self.use_batch_norm = use_batch_norm + self.div_flow = div_flow + + self.conv1 = conv(self.use_batch_norm, 3, 64, kernel_size=7, stride=2) + self.conv2 = conv(self.use_batch_norm, 64, 128, kernel_size=5, stride=2) + self.conv3 = conv(self.use_batch_norm, 128, 256, kernel_size=5, + stride=2) + self.conv_redir = conv(self.use_batch_norm, 256, 32, + kernel_size=1, stride=1) + self.args = args + # if args.fp16: + # self.corr = nn.Sequential( + # tofp32(), + # correlation.Correlation(pad_size=20, kernel_size=1, + # max_displacement=20, stride1=1, + # stride2=2, corr_multiply=1), + # tofp16()) + # else: + self.corr = correlation.Correlation(pad_size=20, kernel_size=1, + max_displacement=20, stride1=1, + stride2=2, corr_multiply=1) + + self.corr_activation = nn.LeakyReLU(0.1, inplace=True) + self.conv3_1 = conv(self.use_batch_norm, 473, 256) + self.conv4 = conv(self.use_batch_norm, 256, 512, stride=2) + self.conv4_1 = conv(self.use_batch_norm, 512, 512) + self.conv5 = conv(self.use_batch_norm, 512, 512, stride=2) + self.conv5_1 = conv(self.use_batch_norm, 512, 512) + self.conv6 = conv(self.use_batch_norm, 512, 1024, stride=2) + self.conv6_1 = conv(self.use_batch_norm, 1024, 1024) + + self.deconv5 = deconv(1024, 512) + self.deconv4 = deconv(1026, 256) + self.deconv3 = deconv(770, 128) + self.deconv2 = deconv(386, 64) + + self.predict_flow6 = predict_flow(1024) + self.predict_flow5 = predict_flow(1026) + self.predict_flow4 = predict_flow(770) + self.predict_flow3 = predict_flow(386) + self.predict_flow2 = predict_flow(194) + + self.upsampled_flow6_to_5 = nn.ConvTranspose2d( + 2, 2, 4, 2, 1, bias=True) + self.upsampled_flow5_to_4 = nn.ConvTranspose2d( + 2, 2, 4, 2, 1, bias=True) + self.upsampled_flow4_to_3 = nn.ConvTranspose2d( + 2, 2, 4, 2, 1, bias=True) + self.upsampled_flow3_to_2 = nn.ConvTranspose2d( + 2, 2, 4, 2, 1, bias=True) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + if m.bias is not None: + init.uniform_(m.bias) + init.xavier_uniform_(m.weight) + + if isinstance(m, nn.ConvTranspose2d): + if m.bias is not None: + init.uniform_(m.bias) + init.xavier_uniform_(m.weight) + # init_deconv_bilinear(m.weight) + self.upsample1 = nn.Upsample(scale_factor=4, mode='bilinear', + align_corners=False) + + def forward(self, x): + r""" + + Args: + x (tensor): Input tensors of concatenated images. + Returns: + flow2 (tensor): Output flow tensors. 
+ """ + x1 = x[:, 0:3, :, :] + x2 = x[:, 3::, :, :] + + out_conv1a = self.conv1(x1) + out_conv2a = self.conv2(out_conv1a) + out_conv3a = self.conv3(out_conv2a) + + # FlownetC bottom input stream + out_conv1b = self.conv1(x2) + + out_conv2b = self.conv2(out_conv1b) + out_conv3b = self.conv3(out_conv2b) + + # Merge streams + if self.args.fp16: + out_corr = self.corr(tofp32()(out_conv3a), + tofp32()(out_conv3b)) # False + out_corr = tofp16()(out_corr) + else: + out_corr = self.corr(out_conv3a, out_conv3b) # False + out_corr = self.corr_activation(out_corr) + + # Redirect top input stream and concatenate + out_conv_redir = self.conv_redir(out_conv3a) + + in_conv3_1 = torch.cat((out_conv_redir, out_corr), 1) + + # Merged conv layers + out_conv3_1 = self.conv3_1(in_conv3_1) + + out_conv4 = self.conv4_1(self.conv4(out_conv3_1)) + + out_conv5 = self.conv5_1(self.conv5(out_conv4)) + out_conv6 = self.conv6_1(self.conv6(out_conv5)) + + flow6 = self.predict_flow6(out_conv6) + flow6_up = self.upsampled_flow6_to_5(flow6) + out_deconv5 = self.deconv5(out_conv6) + + concat5 = torch.cat((out_conv5, out_deconv5, flow6_up), 1) + + flow5 = self.predict_flow5(concat5) + flow5_up = self.upsampled_flow5_to_4(flow5) + out_deconv4 = self.deconv4(concat5) + concat4 = torch.cat((out_conv4, out_deconv4, flow5_up), 1) + + flow4 = self.predict_flow4(concat4) + flow4_up = self.upsampled_flow4_to_3(flow4) + out_deconv3 = self.deconv3(concat4) + concat3 = torch.cat((out_conv3_1, out_deconv3, flow4_up), 1) + + flow3 = self.predict_flow3(concat3) + flow3_up = self.upsampled_flow3_to_2(flow3) + out_deconv2 = self.deconv2(concat3) + concat2 = torch.cat((out_conv2a, out_deconv2, flow3_up), 1) + + flow2 = self.predict_flow2(concat2) + + if self.training: + return flow2, flow3, flow4, flow5, flow6 + else: + return flow2, diff --git a/imaginaire/third_party/flow_net/flownet2/networks/flownet_fusion.py b/imaginaire/third_party/flow_net/flownet2/networks/flownet_fusion.py new file mode 100644 index 0000000000000000000000000000000000000000..edddd2446d906bfc0b93df47b6f18a45ac42bc79 --- /dev/null +++ b/imaginaire/third_party/flow_net/flownet2/networks/flownet_fusion.py @@ -0,0 +1,82 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +# The file is duplicated from https://github.com/NVIDIA/flownet2-pytorch +# with some modifications. +from torch.nn import init +import torch +import torch.nn as nn +from .submodules import conv, i_conv, predict_flow, deconv + + +class FlowNetFusion(nn.Module): + r"""FlowNet2 Fusion module. Check out the FlowNet2 paper for more details + https://arxiv.org/abs/1612.01925 + + Args: + args (obj): Network initialization arguments + use_batch_norm (bool): Use batch norm or not. Default is true. 
+ """ + def __init__(self, args, use_batch_norm=True): + super(FlowNetFusion, self).__init__() + + self.use_batch_norm = use_batch_norm + self.conv0 = conv(self.use_batch_norm, 11, 64) + self.conv1 = conv(self.use_batch_norm, 64, 64, stride=2) + self.conv1_1 = conv(self.use_batch_norm, 64, 128) + self.conv2 = conv(self.use_batch_norm, 128, 128, stride=2) + self.conv2_1 = conv(self.use_batch_norm, 128, 128) + + self.deconv1 = deconv(128, 32) + self.deconv0 = deconv(162, 16) + + self.inter_conv1 = i_conv(self.use_batch_norm, 162, 32) + self.inter_conv0 = i_conv(self.use_batch_norm, 82, 16) + + self.predict_flow2 = predict_flow(128) + self.predict_flow1 = predict_flow(32) + self.predict_flow0 = predict_flow(16) + + self.upsampled_flow2_to_1 = nn.ConvTranspose2d(2, 2, 4, 2, 1) + self.upsampled_flow1_to_0 = nn.ConvTranspose2d(2, 2, 4, 2, 1) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + if m.bias is not None: + init.uniform_(m.bias) + init.xavier_uniform_(m.weight) + + if isinstance(m, nn.ConvTranspose2d): + if m.bias is not None: + init.uniform_(m.bias) + init.xavier_uniform_(m.weight) + # init_deconv_bilinear(m.weight) + + def forward(self, x): + r""" + + Args: + x (tensor): Input tensors of concatenated images. + Returns: + flow2 (tensor): Output flow tensors. + """ + out_conv0 = self.conv0(x) + out_conv1 = self.conv1_1(self.conv1(out_conv0)) + out_conv2 = self.conv2_1(self.conv2(out_conv1)) + + flow2 = self.predict_flow2(out_conv2) + flow2_up = self.upsampled_flow2_to_1(flow2) + out_deconv1 = self.deconv1(out_conv2) + + concat1 = torch.cat((out_conv1, out_deconv1, flow2_up), 1) + out_interconv1 = self.inter_conv1(concat1) + flow1 = self.predict_flow1(out_interconv1) + flow1_up = self.upsampled_flow1_to_0(flow1) + out_deconv0 = self.deconv0(concat1) + + concat0 = torch.cat((out_conv0, out_deconv0, flow1_up), 1) + out_interconv0 = self.inter_conv0(concat0) + flow0 = self.predict_flow0(out_interconv0) + + return flow0 diff --git a/imaginaire/third_party/flow_net/flownet2/networks/flownet_s.py b/imaginaire/third_party/flow_net/flownet2/networks/flownet_s.py new file mode 100644 index 0000000000000000000000000000000000000000..ba8a76c5a5d66354e07aad2522a782961a50c24c --- /dev/null +++ b/imaginaire/third_party/flow_net/flownet2/networks/flownet_s.py @@ -0,0 +1,121 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +# The file is duplicated from https://github.com/NVIDIA/flownet2-pytorch +# with some modifications. +''' +Portions of this code copyright 2017, Clement Pinard +''' +from torch.nn import init +import torch +import torch.nn as nn +from .submodules import conv, predict_flow, deconv + + +class FlowNetS(nn.Module): + r"""FlowNet2 S module. Check out the FlowNet2 paper for more details + https://arxiv.org/abs/1612.01925 + + Args: + args (obj): Network initialization arguments + input_channels (int): Number of input channels. Default is 12. + use_batch_norm (bool): Use batch norm or not. Default is true. 
+ """ + def __init__(self, args, input_channels=12, use_batch_norm=True): + super(FlowNetS, self).__init__() + + self.use_batch_norm = use_batch_norm + self.conv1 = conv( + self.use_batch_norm, + input_channels, + 64, + kernel_size=7, + stride=2) + self.conv2 = conv(self.use_batch_norm, 64, 128, kernel_size=5, stride=2) + self.conv3 = conv(self.use_batch_norm, 128, 256, kernel_size=5, + stride=2) + self.conv3_1 = conv(self.use_batch_norm, 256, 256) + self.conv4 = conv(self.use_batch_norm, 256, 512, stride=2) + self.conv4_1 = conv(self.use_batch_norm, 512, 512) + self.conv5 = conv(self.use_batch_norm, 512, 512, stride=2) + self.conv5_1 = conv(self.use_batch_norm, 512, 512) + self.conv6 = conv(self.use_batch_norm, 512, 1024, stride=2) + self.conv6_1 = conv(self.use_batch_norm, 1024, 1024) + + self.deconv5 = deconv(1024, 512) + self.deconv4 = deconv(1026, 256) + self.deconv3 = deconv(770, 128) + self.deconv2 = deconv(386, 64) + + self.predict_flow6 = predict_flow(1024) + self.predict_flow5 = predict_flow(1026) + self.predict_flow4 = predict_flow(770) + self.predict_flow3 = predict_flow(386) + self.predict_flow2 = predict_flow(194) + + self.upsampled_flow6_to_5 = nn.ConvTranspose2d( + 2, 2, 4, 2, 1, bias=False) + self.upsampled_flow5_to_4 = nn.ConvTranspose2d( + 2, 2, 4, 2, 1, bias=False) + self.upsampled_flow4_to_3 = nn.ConvTranspose2d( + 2, 2, 4, 2, 1, bias=False) + self.upsampled_flow3_to_2 = nn.ConvTranspose2d( + 2, 2, 4, 2, 1, bias=False) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + if m.bias is not None: + init.uniform_(m.bias) + init.xavier_uniform_(m.weight) + + if isinstance(m, nn.ConvTranspose2d): + if m.bias is not None: + init.uniform_(m.bias) + init.xavier_uniform_(m.weight) + # init_deconv_bilinear(m.weight) + self.upsample1 = nn.Upsample(scale_factor=4, mode='bilinear', + align_corners=False) + + def forward(self, x): + r""" + + Args: + x (tensor): Input tensors of concatenated images. + Returns: + flow2 (tensor): Output flow tensors. + """ + out_conv1 = self.conv1(x) + + out_conv2 = self.conv2(out_conv1) + out_conv3 = self.conv3_1(self.conv3(out_conv2)) + out_conv4 = self.conv4_1(self.conv4(out_conv3)) + out_conv5 = self.conv5_1(self.conv5(out_conv4)) + out_conv6 = self.conv6_1(self.conv6(out_conv5)) + + flow6 = self.predict_flow6(out_conv6) + flow6_up = self.upsampled_flow6_to_5(flow6) + out_deconv5 = self.deconv5(out_conv6) + + concat5 = torch.cat((out_conv5, out_deconv5, flow6_up), 1) + flow5 = self.predict_flow5(concat5) + flow5_up = self.upsampled_flow5_to_4(flow5) + out_deconv4 = self.deconv4(concat5) + + concat4 = torch.cat((out_conv4, out_deconv4, flow5_up), 1) + flow4 = self.predict_flow4(concat4) + flow4_up = self.upsampled_flow4_to_3(flow4) + out_deconv3 = self.deconv3(concat4) + + concat3 = torch.cat((out_conv3, out_deconv3, flow4_up), 1) + flow3 = self.predict_flow3(concat3) + flow3_up = self.upsampled_flow3_to_2(flow3) + out_deconv2 = self.deconv2(concat3) + + concat2 = torch.cat((out_conv2, out_deconv2, flow3_up), 1) + flow2 = self.predict_flow2(concat2) + + if self.training: + return flow2, flow3, flow4, flow5, flow6 + else: + return flow2, diff --git a/imaginaire/third_party/flow_net/flownet2/networks/flownet_sd.py b/imaginaire/third_party/flow_net/flownet2/networks/flownet_sd.py new file mode 100644 index 0000000000000000000000000000000000000000..3f4340347252a9591d7540689abaae821d759060 --- /dev/null +++ b/imaginaire/third_party/flow_net/flownet2/networks/flownet_sd.py @@ -0,0 +1,121 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +# The file is duplicated from https://github.com/NVIDIA/flownet2-pytorch +# with some modifications. +import torch +import torch.nn as nn +from .submodules import conv, i_conv, predict_flow, deconv +from torch.nn import init + + +class FlowNetSD(nn.Module): + r"""FlowNet2 SD module. Check out the FlowNet2 paper for more details + https://arxiv.org/abs/1612.01925 + + Args: + args (obj): Network initialization arguments + use_batch_norm (bool): Use batch norm or not. Default is true. + """ + def __init__(self, args, use_batch_norm=True): + super(FlowNetSD, self).__init__() + + self.use_batch_norm = use_batch_norm + self.conv0 = conv(self.use_batch_norm, 6, 64) + self.conv1 = conv(self.use_batch_norm, 64, 64, stride=2) + self.conv1_1 = conv(self.use_batch_norm, 64, 128) + self.conv2 = conv(self.use_batch_norm, 128, 128, stride=2) + self.conv2_1 = conv(self.use_batch_norm, 128, 128) + self.conv3 = conv(self.use_batch_norm, 128, 256, stride=2) + self.conv3_1 = conv(self.use_batch_norm, 256, 256) + self.conv4 = conv(self.use_batch_norm, 256, 512, stride=2) + self.conv4_1 = conv(self.use_batch_norm, 512, 512) + self.conv5 = conv(self.use_batch_norm, 512, 512, stride=2) + self.conv5_1 = conv(self.use_batch_norm, 512, 512) + self.conv6 = conv(self.use_batch_norm, 512, 1024, stride=2) + self.conv6_1 = conv(self.use_batch_norm, 1024, 1024) + + self.deconv5 = deconv(1024, 512) + self.deconv4 = deconv(1026, 256) + self.deconv3 = deconv(770, 128) + self.deconv2 = deconv(386, 64) + + self.inter_conv5 = i_conv(self.use_batch_norm, 1026, 512) + self.inter_conv4 = i_conv(self.use_batch_norm, 770, 256) + self.inter_conv3 = i_conv(self.use_batch_norm, 386, 128) + self.inter_conv2 = i_conv(self.use_batch_norm, 194, 64) + + self.predict_flow6 = predict_flow(1024) + self.predict_flow5 = predict_flow(512) + self.predict_flow4 = predict_flow(256) + self.predict_flow3 = predict_flow(128) + self.predict_flow2 = predict_flow(64) + + self.upsampled_flow6_to_5 = nn.ConvTranspose2d(2, 2, 4, 2, 1) + self.upsampled_flow5_to_4 = nn.ConvTranspose2d(2, 2, 4, 2, 1) + self.upsampled_flow4_to_3 = nn.ConvTranspose2d(2, 2, 4, 2, 1) + self.upsampled_flow3_to_2 = nn.ConvTranspose2d(2, 2, 4, 2, 1) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + if m.bias is not None: + init.uniform_(m.bias) + init.xavier_uniform_(m.weight) + + if isinstance(m, nn.ConvTranspose2d): + if m.bias is not None: + init.uniform_(m.bias) + init.xavier_uniform_(m.weight) + # init_deconv_bilinear(m.weight) + self.upsample1 = nn.Upsample(scale_factor=4, mode='bilinear', + align_corners=False) + + def forward(self, x): + r""" + + Args: + x (tensor): Input tensors of concatenated images. + Returns: + flow2 (tensor): Output flow tensors. 
+ """ + out_conv0 = self.conv0(x) + out_conv1 = self.conv1_1(self.conv1(out_conv0)) + out_conv2 = self.conv2_1(self.conv2(out_conv1)) + + out_conv3 = self.conv3_1(self.conv3(out_conv2)) + out_conv4 = self.conv4_1(self.conv4(out_conv3)) + out_conv5 = self.conv5_1(self.conv5(out_conv4)) + out_conv6 = self.conv6_1(self.conv6(out_conv5)) + + flow6 = self.predict_flow6(out_conv6) + flow6_up = self.upsampled_flow6_to_5(flow6) + out_deconv5 = self.deconv5(out_conv6) + + concat5 = torch.cat((out_conv5, out_deconv5, flow6_up), 1) + out_interconv5 = self.inter_conv5(concat5) + flow5 = self.predict_flow5(out_interconv5) + + flow5_up = self.upsampled_flow5_to_4(flow5) + out_deconv4 = self.deconv4(concat5) + + concat4 = torch.cat((out_conv4, out_deconv4, flow5_up), 1) + out_interconv4 = self.inter_conv4(concat4) + flow4 = self.predict_flow4(out_interconv4) + flow4_up = self.upsampled_flow4_to_3(flow4) + out_deconv3 = self.deconv3(concat4) + + concat3 = torch.cat((out_conv3, out_deconv3, flow4_up), 1) + out_interconv3 = self.inter_conv3(concat3) + flow3 = self.predict_flow3(out_interconv3) + flow3_up = self.upsampled_flow3_to_2(flow3) + out_deconv2 = self.deconv2(concat3) + + concat2 = torch.cat((out_conv2, out_deconv2, flow3_up), 1) + out_interconv2 = self.inter_conv2(concat2) + flow2 = self.predict_flow2(out_interconv2) + + if self.training: + return flow2, flow3, flow4, flow5, flow6 + else: + return flow2, diff --git a/imaginaire/third_party/flow_net/flownet2/networks/submodules.py b/imaginaire/third_party/flow_net/flownet2/networks/submodules.py new file mode 100644 index 0000000000000000000000000000000000000000..f4ab504401c1473bcc52ae4a1029afd74eed6d11 --- /dev/null +++ b/imaginaire/third_party/flow_net/flownet2/networks/submodules.py @@ -0,0 +1,113 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +# The file is duplicated from https://github.com/NVIDIA/flownet2-pytorch +# with some modifications. 
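The 1026/770/386/194 channel counts in the FlowNetS and FlowNetSD decoders above come from concatenating an encoder skip feature, the corresponding deconv output, and the 2-channel upsampled flow; both networks end up with identical concat widths. A quick sanity check of that bookkeeping (plain Python, not part of the original sources):

```python
# Decoder channel bookkeeping for FlowNetS / FlowNetSD above:
# each concat is [encoder skip, deconv output, 2-channel upsampled flow].
encoder = {"conv2": 128, "conv3_1": 256, "conv4_1": 512, "conv5_1": 512}
deconv_out = {"deconv5": 512, "deconv4": 256, "deconv3": 128, "deconv2": 64}
flow_ch = 2  # every predicted flow map has (u, v) channels

concat5 = encoder["conv5_1"] + deconv_out["deconv5"] + flow_ch  # 1026
concat4 = encoder["conv4_1"] + deconv_out["deconv4"] + flow_ch  # 770
concat3 = encoder["conv3_1"] + deconv_out["deconv3"] + flow_ch  # 386
concat2 = encoder["conv2"]   + deconv_out["deconv2"] + flow_ch  # 194
assert (concat5, concat4, concat3, concat2) == (1026, 770, 386, 194)
```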
+import torch +import torch.nn as nn +import numpy as np + + +def conv(use_batch_norm, in_planes, out_planes, kernel_size=3, stride=1): + if use_batch_norm: + return nn.Sequential( + nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, + stride=stride, padding=(kernel_size - 1) // 2, + bias=False), + nn.BatchNorm2d(out_planes), + nn.LeakyReLU(0.1, inplace=True) + ) + else: + return nn.Sequential( + nn.Conv2d( + in_planes, + out_planes, + kernel_size=kernel_size, + stride=stride, + padding=( + kernel_size - 1) // 2, + bias=True), + nn.LeakyReLU( + 0.1, + inplace=True)) + + +def i_conv(use_batch_norm, in_planes, out_planes, kernel_size=3, stride=1, + bias=True): + if use_batch_norm: + return nn.Sequential( + nn.Conv2d( + in_planes, + out_planes, + kernel_size=kernel_size, + stride=stride, + padding=( + kernel_size - 1) // 2, + bias=bias), + nn.BatchNorm2d(out_planes), + ) + else: + return nn.Sequential( + nn.Conv2d( + in_planes, + out_planes, + kernel_size=kernel_size, + stride=stride, + padding=( + kernel_size - + 1) // + 2, + bias=bias), + ) + + +def predict_flow(in_planes): + return nn.Conv2d(in_planes, 2, kernel_size=3, stride=1, padding=1, + bias=True) + + +def deconv(in_planes, out_planes): + return nn.Sequential( + nn.ConvTranspose2d(in_planes, out_planes, kernel_size=4, stride=2, + padding=1, bias=True), + nn.LeakyReLU(0.1, inplace=True) + ) + + +class tofp16(nn.Module): + def __init__(self): + super(tofp16, self).__init__() + + def forward(self, input): + return input.half() + + +class tofp32(nn.Module): + def __init__(self): + super(tofp32, self).__init__() + + def forward(self, input): + return input.float() + + +def init_deconv_bilinear(weight): + f_shape = weight.size() + heigh, width = f_shape[-2], f_shape[-1] + f = np.ceil(width / 2.0) + c = (2 * f - 1 - f % 2) / (2.0 * f) + bilinear = np.zeros([heigh, width]) + for x in range(width): + for y in range(heigh): + value = (1 - abs(x / f - c)) * (1 - abs(y / f - c)) + bilinear[x, y] = value + weight.data.fill_(0.) + for i in range(f_shape[0]): + for j in range(f_shape[1]): + weight.data[i, j, :, :] = torch.from_numpy(bilinear) + + +def save_grad(grads, name): + def hook(grad): + grads[name] = grad + return hook diff --git a/imaginaire/third_party/flow_net/flownet2/utils/__init__.py b/imaginaire/third_party/flow_net/flownet2/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..13acefe2181136b1629ec31f9d122fb46bf26780 --- /dev/null +++ b/imaginaire/third_party/flow_net/flownet2/utils/__init__.py @@ -0,0 +1,4 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md diff --git a/imaginaire/third_party/flow_net/flownet2/utils/flow_utils.py b/imaginaire/third_party/flow_net/flownet2/utils/flow_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..7bffeba58a93de4379c8e9ed54af58b56baa13eb --- /dev/null +++ b/imaginaire/third_party/flow_net/flownet2/utils/flow_utils.py @@ -0,0 +1,219 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. 
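A small usage sketch of the `submodules.py` helpers defined above (assuming PyTorch is installed and the `conv`/`deconv` factories are in scope): a stride-2 `conv` block halves the spatial resolution, while `deconv` (ConvTranspose2d with kernel 4, stride 2, padding 1) doubles it again, which is how the encoder and decoder resolutions in the FlowNet variants line up.

```python
import torch

# conv: Conv2d (+ optional BatchNorm) followed by LeakyReLU(0.1); stride=2 halves H and W.
down = conv(use_batch_norm=True, in_planes=3, out_planes=64, stride=2)
# deconv: ConvTranspose2d(kernel=4, stride=2, padding=1) + LeakyReLU(0.1); doubles H and W.
up = deconv(in_planes=64, out_planes=32)

x = torch.randn(1, 3, 64, 64)
y = down(x)   # -> (1, 64, 32, 32)
z = up(y)     # -> (1, 32, 64, 64)
assert y.shape[-2:] == (32, 32) and z.shape[-2:] == (64, 64)
```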
+# To view a copy of this license, check out LICENSE.md +import numpy as np +import matplotlib.pyplot as plt +import os.path + +TAG_CHAR = np.array([202021.25], np.float32) + + +def readFlow(fn): + """ Read .flo file in Middlebury format""" + # Code adapted from: + # http://stackoverflow.com/questions/28013200/ + # reading-middlebury-flow-files-with-python-bytes-array-numpy + + # WARNING: this will work on little-endian architectures + # (eg Intel x86) only! + # print 'fn = %s'%(fn) + with open(fn, 'rb') as f: + magic = np.fromfile(f, np.float32, count=1) + if 202021.25 != magic: + print('Magic number incorrect. Invalid .flo file') + return None + else: + w = np.fromfile(f, np.int32, count=1) + h = np.fromfile(f, np.int32, count=1) + # print 'Reading %d x %d flo file\n' % (w, h) + data = np.fromfile(f, np.float32, count=2 * int(w) * int(h)) + # Reshape data into 3D array (columns, rows, bands) + # The reshape here is for visualization, the original code is + # (w,h,2) + return np.resize(data, (int(h), int(w), 2)) + + +def writeFlow(filename, uv, v=None): + """ Write optical flow to file. + + If v is None, uv is assumed to contain both u and v channels, + stacked in deep. + Original code by Deqing Sun, adapted from Daniel Scharstein. + """ + nBands = 2 + + if v is None: + assert(uv.ndim == 3) + assert(uv.shape[2] == 2) + u = uv[:, :, 0] + v = uv[:, :, 1] + else: + u = uv + + assert(u.shape == v.shape) + height, width = u.shape + f = open(filename, 'wb') + # write the header + f.write(TAG_CHAR) + np.array(width).astype(np.int32).tofile(f) + np.array(height).astype(np.int32).tofile(f) + # arrange into matrix form + tmp = np.zeros((height, width * nBands)) + tmp[:, np.arange(width) * 2] = u + tmp[:, np.arange(width) * 2 + 1] = v + tmp.astype(np.float32).tofile(f) + f.close() + + +# ref: https://github.com/sampepose/flownet2-tf/ +# blob/18f87081db44939414fc4a48834f9e0da3e69f4c/src/flowlib.py#L240 +def visulize_flow_file(flow_filename, save_dir=None): + flow_data = readFlow(flow_filename) + img = flow2img(flow_data) + # plt.imshow(img) + # plt.show() + if save_dir: + idx = flow_filename.rfind("/") + 1 + plt.imsave(os.path.join(save_dir, "%s-vis.png" % + flow_filename[idx:-4]), img) + + +def flow2img(flow_data): + """ + convert optical flow into color image + :param flow_data: + :return: color image + """ + # print(flow_data.shape) + # print(type(flow_data)) + u = flow_data[:, :, 0] + v = flow_data[:, :, 1] + + UNKNOW_FLOW_THRESHOLD = 1e7 + pr1 = abs(u) > UNKNOW_FLOW_THRESHOLD + pr2 = abs(v) > UNKNOW_FLOW_THRESHOLD + idx_unknown = (pr1 | pr2) + u[idx_unknown] = v[idx_unknown] = 0 + + # get max value in each direction + maxu = -999. + maxv = -999. + minu = 999. + minv = 999. 
+ maxu = max(maxu, np.max(u)) + maxv = max(maxv, np.max(v)) + minu = min(minu, np.min(u)) + minv = min(minv, np.min(v)) + + rad = np.sqrt(u ** 2 + v ** 2) + maxrad = max(-1, np.max(rad)) + u = u / maxrad + np.finfo(float).eps + v = v / maxrad + np.finfo(float).eps + + img = compute_color(u, v) + + idx = np.repeat(idx_unknown[:, :, np.newaxis], 3, axis=2) + img[idx] = 0 + + return np.uint8(img) + + +def compute_color(u, v): + """ + compute optical flow color map + :param u: horizontal optical flow + :param v: vertical optical flow + :return: + """ + + height, width = u.shape + img = np.zeros((height, width, 3)) + + NAN_idx = np.isnan(u) | np.isnan(v) + u[NAN_idx] = v[NAN_idx] = 0 + + colorwheel = make_color_wheel() + ncols = np.size(colorwheel, 0) + + rad = np.sqrt(u ** 2 + v ** 2) + + a = np.arctan2(-v, -u) / np.pi + + fk = (a + 1) / 2 * (ncols - 1) + 1 + + k0 = np.floor(fk).astype(int) + + k1 = k0 + 1 + k1[k1 == ncols + 1] = 1 + f = fk - k0 + + for i in range(0, np.size(colorwheel, 1)): + tmp = colorwheel[:, i] + col0 = tmp[k0 - 1] / 255 + col1 = tmp[k1 - 1] / 255 + col = (1 - f) * col0 + f * col1 + + idx = rad <= 1 + col[idx] = 1 - rad[idx] * (1 - col[idx]) + notidx = np.logical_not(idx) + + col[notidx] *= 0.75 + img[:, :, i] = np.uint8(np.floor(255 * col * (1 - NAN_idx))) + + return img + + +def make_color_wheel(): + """ + Generate color wheel according Middlebury color code + :return: Color wheel + """ + RY = 15 + YG = 6 + GC = 4 + CB = 11 + BM = 13 + MR = 6 + + ncols = RY + YG + GC + CB + BM + MR + + colorwheel = np.zeros([ncols, 3]) + + col = 0 + + # RY + colorwheel[0:RY, 0] = 255 + colorwheel[0:RY, 1] = np.transpose(np.floor(255 * np.arange(0, RY) / RY)) + col += RY + + # YG + colorwheel[col:col + YG, 0] = 255 - \ + np.transpose(np.floor(255 * np.arange(0, YG) / YG)) + colorwheel[col:col + YG, 1] = 255 + col += YG + + # GC + colorwheel[col:col + GC, 1] = 255 + colorwheel[col:col + GC, + 2] = np.transpose(np.floor(255 * np.arange(0, GC) / GC)) + col += GC + + # CB + colorwheel[col:col + CB, 1] = 255 - \ + np.transpose(np.floor(255 * np.arange(0, CB) / CB)) + colorwheel[col:col + CB, 2] = 255 + col += CB + + # BM + colorwheel[col:col + BM, 2] = 255 + colorwheel[col:col + BM, + 0] = np.transpose(np.floor(255 * np.arange(0, BM) / BM)) + col += + BM + + # MR + colorwheel[col:col + MR, 2] = 255 - \ + np.transpose(np.floor(255 * np.arange(0, MR) / MR)) + colorwheel[col:col + MR, 0] = 255 + + return colorwheel diff --git a/imaginaire/third_party/flow_net/flownet2/utils/frame_utils.py b/imaginaire/third_party/flow_net/flownet2/utils/frame_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ac1e3f83d6179afcb266e5923af5dd54fc3dd3fc --- /dev/null +++ b/imaginaire/third_party/flow_net/flownet2/utils/frame_utils.py @@ -0,0 +1,23 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +import numpy as np +from os.path import splitext +from scipy.misc import imread +from . 
import flow_utils + + +def read_gen(file_name): + ext = splitext(file_name)[-1] + if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg': + im = imread(file_name) + if im.shape[2] > 3: + return im[:, :, :3] + else: + return im + elif ext == '.bin' or ext == '.raw': + return np.load(file_name) + elif ext == '.flo': + return flow_utils.readFlow(file_name).astype(np.float32) + return [] diff --git a/imaginaire/third_party/flow_net/flownet2/utils/param_utils.py b/imaginaire/third_party/flow_net/flownet2/utils/param_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b084c9c35b957888acea86987ab25073e43feac0 --- /dev/null +++ b/imaginaire/third_party/flow_net/flownet2/utils/param_utils.py @@ -0,0 +1,275 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +import torch +import torch.nn as nn +import numpy as np + + +def parse_flownetc(modules, weights, biases): + keys = [ + 'conv1', + 'conv2', + 'conv3', + 'conv_redir', + 'conv3_1', + 'conv4', + 'conv4_1', + 'conv5', + 'conv5_1', + 'conv6', + 'conv6_1', + + 'deconv5', + 'deconv4', + 'deconv3', + 'deconv2', + + 'Convolution1', + 'Convolution2', + 'Convolution3', + 'Convolution4', + 'Convolution5', + + 'upsample_flow6to5', + 'upsample_flow5to4', + 'upsample_flow4to3', + 'upsample_flow3to2', + + ] + i = 0 + for m in modules: + if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): + weight = weights[keys[i]].copy() + bias = biases[keys[i]].copy() + if keys[i] == 'conv1': + m.weight.data[:, :, :, :] = torch.from_numpy( + np.flip(weight, axis=1).copy()) + m.bias.data[:] = torch.from_numpy(bias) + else: + m.weight.data[:, :, :, :] = torch.from_numpy(weight) + m.bias.data[:] = torch.from_numpy(bias) + + i = i + 1 + return + + +def parse_flownets(modules, weights, biases, param_prefix='net2_'): + keys = [ + 'conv1', + 'conv2', + 'conv3', + 'conv3_1', + 'conv4', + 'conv4_1', + 'conv5', + 'conv5_1', + 'conv6', + 'conv6_1', + + 'deconv5', + 'deconv4', + 'deconv3', + 'deconv2', + + 'predict_conv6', + 'predict_conv5', + 'predict_conv4', + 'predict_conv3', + 'predict_conv2', + + 'upsample_flow6to5', + 'upsample_flow5to4', + 'upsample_flow4to3', + 'upsample_flow3to2', + ] + for i, k in enumerate(keys): + if 'upsample' in k: + keys[i] = param_prefix + param_prefix + k + else: + keys[i] = param_prefix + k + i = 0 + for m in modules: + if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): + weight = weights[keys[i]].copy() + bias = biases[keys[i]].copy() + if keys[i] == param_prefix + 'conv1': + m.weight.data[:, 0:3, :, :] = torch.from_numpy( + np.flip(weight[:, 0:3, :, :], axis=1).copy()) + m.weight.data[:, 3:6, :, :] = torch.from_numpy( + np.flip(weight[:, 3:6, :, :], axis=1).copy()) + m.weight.data[:, 6:9, :, :] = torch.from_numpy( + np.flip(weight[:, 6:9, :, :], axis=1).copy()) + m.weight.data[:, 9::, :, :] = torch.from_numpy( + weight[:, 9:, :, :].copy()) + if m.bias is not None: + m.bias.data[:] = torch.from_numpy(bias) + else: + m.weight.data[:, :, :, :] = torch.from_numpy(weight) + if m.bias is not None: + m.bias.data[:] = torch.from_numpy(bias) + i = i + 1 + return + + +def parse_flownetsonly(modules, weights, biases, param_prefix=''): + keys = [ + 'conv1', + 'conv2', + 'conv3', + 'conv3_1', + 'conv4', + 'conv4_1', + 'conv5', + 'conv5_1', + 'conv6', + 'conv6_1', + + 'deconv5', + 'deconv4', + 'deconv3', + 'deconv2', + + 'Convolution1', + 
'Convolution2', + 'Convolution3', + 'Convolution4', + 'Convolution5', + + 'upsample_flow6to5', + 'upsample_flow5to4', + 'upsample_flow4to3', + 'upsample_flow3to2', + ] + for i, k in enumerate(keys): + if 'upsample' in k: + keys[i] = param_prefix + param_prefix + k + else: + keys[i] = param_prefix + k + i = 0 + for m in modules: + if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): + weight = weights[keys[i]].copy() + bias = biases[keys[i]].copy() + if keys[i] == param_prefix + 'conv1': + # print ("%s :"%(keys[i]), m.weight.size(), m.bias.size(), + # tf_w[keys[i]].shape[::-1]) + m.weight.data[:, 0:3, :, :] = torch.from_numpy( + np.flip(weight[:, 0:3, :, :], axis=1).copy()) + m.weight.data[:, 3:6, :, :] = torch.from_numpy( + np.flip(weight[:, 3:6, :, :], axis=1).copy()) + if m.bias is not None: + m.bias.data[:] = torch.from_numpy(bias) + else: + m.weight.data[:, :, :, :] = torch.from_numpy(weight) + if m.bias is not None: + m.bias.data[:] = torch.from_numpy(bias) + i = i + 1 + return + + +def parse_flownetsd(modules, weights, biases, param_prefix='netsd_'): + keys = [ + 'conv0', + 'conv1', + 'conv1_1', + 'conv2', + 'conv2_1', + 'conv3', + 'conv3_1', + 'conv4', + 'conv4_1', + 'conv5', + 'conv5_1', + 'conv6', + 'conv6_1', + + 'deconv5', + 'deconv4', + 'deconv3', + 'deconv2', + + 'interconv5', + 'interconv4', + 'interconv3', + 'interconv2', + + 'Convolution1', + 'Convolution2', + 'Convolution3', + 'Convolution4', + 'Convolution5', + + 'upsample_flow6to5', + 'upsample_flow5to4', + 'upsample_flow4to3', + 'upsample_flow3to2', + ] + for i, k in enumerate(keys): + keys[i] = param_prefix + k + + i = 0 + for m in modules: + if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): + weight = weights[keys[i]].copy() + bias = biases[keys[i]].copy() + if keys[i] == param_prefix + 'conv0': + m.weight.data[:, 0:3, :, :] = torch.from_numpy( + np.flip(weight[:, 0:3, :, :], axis=1).copy()) + m.weight.data[:, 3:6, :, :] = torch.from_numpy( + np.flip(weight[:, 3:6, :, :], axis=1).copy()) + if m.bias is not None: + m.bias.data[:] = torch.from_numpy(bias) + else: + m.weight.data[:, :, :, :] = torch.from_numpy(weight) + if m.bias is not None: + m.bias.data[:] = torch.from_numpy(bias) + i = i + 1 + + return + + +def parse_flownetfusion(modules, weights, biases, param_prefix='fuse_'): + keys = [ + 'conv0', + 'conv1', + 'conv1_1', + 'conv2', + 'conv2_1', + + 'deconv1', + 'deconv0', + + 'interconv1', + 'interconv0', + + '_Convolution5', + '_Convolution6', + '_Convolution7', + + 'upsample_flow2to1', + 'upsample_flow1to0', + ] + for i, k in enumerate(keys): + keys[i] = param_prefix + k + + i = 0 + for m in modules: + if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): + weight = weights[keys[i]].copy() + bias = biases[keys[i]].copy() + if keys[i] == param_prefix + 'conv0': + m.weight.data[:, 0:3, :, :] = torch.from_numpy( + np.flip(weight[:, 0:3, :, :], axis=1).copy()) + m.weight.data[:, 3::, :, :] = torch.from_numpy( + weight[:, 3:, :, :].copy()) + if m.bias is not None: + m.bias.data[:] = torch.from_numpy(bias) + else: + m.weight.data[:, :, :, :] = torch.from_numpy(weight) + if m.bias is not None: + m.bias.data[:] = torch.from_numpy(bias) + i = i + 1 + + return diff --git a/imaginaire/third_party/flow_net/flownet2/utils/tools.py b/imaginaire/third_party/flow_net/flownet2/utils/tools.py new file mode 100644 index 0000000000000000000000000000000000000000..d6c9208117824c712aa54e3a3871273b851de63c --- /dev/null +++ b/imaginaire/third_party/flow_net/flownet2/utils/tools.py @@ -0,0 +1,194 
@@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +import os +import time +import math +import subprocess +import shutil +from os.path import join +import numpy as np +from inspect import isclass +from pytz import timezone +from datetime import datetime +import inspect +import torch + + +def datestr(): + pacific = timezone('US/Pacific') + now = datetime.now(pacific) + return '{}{:02}{:02}_{:02}{:02}'.format( + now.year, now.month, now.day, now.hour, now.minute) + + +def module_to_dict(module, exclude=[]): + return dict([(x, getattr(module, x)) for x in dir(module) + if isclass(getattr(module, x)) + and x not in exclude + and getattr(module, x) not in exclude]) + + +class TimerBlock: + def __init__(self, title): + print(("{}".format(title))) + + def __enter__(self): + self.start = time.clock() + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.end = time.clock() + self.interval = self.end - self.start + + if exc_type is not None: + self.log("Operation failed\n") + else: + self.log("Operation finished\n") + + def log(self, string): + duration = time.clock() - self.start + units = 's' + if duration > 60: + duration = duration / 60. + units = 'm' + print((" [{:.3f}{}] {}".format(duration, units, string))) + + def log2file(self, fid, string): + fid = open(fid, 'a') + fid.write("%s\n" % (string)) + fid.close() + + +def add_arguments_for_module( + parser, + module, + argument_for_class, + default, + skip_params=[], + parameter_defaults={}): + argument_group = parser.add_argument_group(argument_for_class.capitalize()) + + module_dict = module_to_dict(module) + argument_group.add_argument( + '--' + argument_for_class, + type=str, + default=default, + choices=list( + module_dict.keys())) + + args, unknown_args = parser.parse_known_args() + class_obj = module_dict[vars(args)[argument_for_class]] + + argspec = inspect.getargspec(class_obj.__init__) + + defaults = argspec.defaults[::-1] if argspec.defaults else None + + args = argspec.args[::-1] + for i, arg in enumerate(args): + cmd_arg = '{}_{}'.format(argument_for_class, arg) + if arg not in skip_params + ['self', 'args']: + if arg in list(parameter_defaults.keys()): + argument_group.add_argument( + '--{}'.format(cmd_arg), + type=type( + parameter_defaults[arg]), + default=parameter_defaults[arg]) + elif (defaults is not None and i < len(defaults)): + argument_group.add_argument( + '--{}'.format(cmd_arg), + type=type( + defaults[i]), + default=defaults[i]) + else: + print(("[Warning]: non-default argument '{}' " + "detected on class '{}'. This argument " + "cannot be modified via the command line" + .format(arg, module.__class__.__name__))) + # We don't have a good way of dealing with + # inferring the type of the argument + # TODO: try creating a custom action and using ast's infer type? 
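One caveat on the `tools.py` helpers above: `TimerBlock` measures time with `time.clock()`, which was removed in Python 3.8, so on a recent interpreter entering the context manager raises `AttributeError`. A hedged sketch of the same usage pattern based on `time.perf_counter()` (my adaptation, not part of the original file):

```python
import time

class PerfTimerBlock:
    """Same usage pattern as TimerBlock above, but using time.perf_counter()."""

    def __init__(self, title):
        print(title)

    def __enter__(self):
        self.start = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.log("Operation failed\n" if exc_type is not None else "Operation finished\n")

    def log(self, string):
        duration = time.perf_counter() - self.start
        units = 's'
        if duration > 60:
            duration, units = duration / 60.0, 'm'
        print("  [{:.3f}{}] {}".format(duration, units, string))

# Example usage:
with PerfTimerBlock("Building FlowNet2") as block:
    block.log("networks constructed")
```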
+ # else: + # argument_group.add_argument('--{}'.format( + # cmd_arg), required=True) + + +def kwargs_from_args(args, argument_for_class): + argument_for_class = argument_for_class + '_' + return {key[len(argument_for_class):]: value for key, value in list(vars( + args).items()) if + argument_for_class in key and key != argument_for_class + 'class'} + + +def format_dictionary_of_losses(labels, values): + try: + string = ', '.join([('{}: {:' + + ('.3f' if value >= 0.001 else '.1e') + + '}').format(name, value) for name, value in + zip(labels, values)]) + except (TypeError, ValueError) as e: + print((list(zip(labels, values)))) + string = '[Log Error] ' + str(e) + + return string + + +class IteratorTimer(): + def __init__(self, iterable): + self.iterable = iterable + self.iterator = self.iterable.__iter__() + + def __iter__(self): + return self + + def __len__(self): + return len(self.iterable) + + def __next__(self): + start = time.time() + n = next(self.iterator) + self.last_duration = (time.time() - start) + return n + + next = __next__ + + +def gpumemusage(): + gpu_mem = subprocess.check_output( + "nvidia-smi | grep MiB | cut -f 3 -d '|'", + shell=True).replace( + ' ', + '').replace( + '\n', + '').replace( + 'i', + '') + all_stat = [float(a) for a in gpu_mem.replace('/', '').split('MB')[:-1]] + + gpu_mem = '' + for i in range(len(all_stat) / 2): + curr, tot = all_stat[2 * i], all_stat[2 * i + 1] + util = "%1.2f" % (100 * curr / tot) + '%' + cmem = str(int(math.ceil(curr / 1024.))) + 'GB' + gmem = str(int(math.ceil(tot / 1024.))) + 'GB' + gpu_mem += util + '--' + join(cmem, gmem) + ' ' + return gpu_mem + + +def update_hyperparameter_schedule(args, epoch, global_iteration, optimizer): + if args.schedule_lr_frequency > 0: + for param_group in optimizer.param_groups: + if (global_iteration + 1) % args.schedule_lr_frequency == 0: + param_group['lr'] /= float(args.schedule_lr_fraction) + param_group['lr'] = float( + np.maximum(param_group['lr'], 0.000001)) + + +def save_checkpoint(state, is_best, path, prefix, + filename='checkpoint.pth.tar'): + prefix_save = os.path.join(path, prefix) + name = prefix_save + '_' + filename + torch.save(state, name) + if is_best: + shutil.copyfile(name, prefix_save + '_model_best.pth.tar') diff --git a/imaginaire/third_party/resample2d/resample2d.py b/imaginaire/third_party/resample2d/resample2d.py new file mode 100644 index 0000000000000000000000000000000000000000..cbdea3fa9941894090aa124adda1b62e1ea5e012 --- /dev/null +++ b/imaginaire/third_party/resample2d/resample2d.py @@ -0,0 +1,62 @@ +# flake8: noqa +from torch.nn.modules.module import Module +from torch.autograd import Function, Variable +from torch.cuda.amp import autocast +import resample2d_cuda + + +class Resample2dFunction(Function): + + @staticmethod + # def forward(ctx, input1, input2, kernel_size=1, bilinear=True): + def forward(ctx, input1, input2, kernel_size=1): + assert input1.is_contiguous() + assert input2.is_contiguous() + + ctx.save_for_backward(input1, input2) + ctx.kernel_size = kernel_size + ctx.bilinear = True + + _, d, _, _ = input1.size() + b, _, h, w = input2.size() + output = input1.new(b, d, h, w).zero_() + + resample2d_cuda.forward(input1, input2, output, kernel_size) + + return output + + @staticmethod + def backward(ctx, grad_output): + grad_output = grad_output.contiguous() + assert grad_output.is_contiguous() + + input1, input2 = ctx.saved_tensors + + grad_input1 = Variable(input1.new(input1.size()).zero_()) + grad_input2 = Variable(input1.new(input2.size()).zero_()) + + # 
resample2d_cuda.backward(input1, input2, grad_output.data, + # grad_input1.data, grad_input2.data, + # ctx.kernel_size, ctx.bilinear) + resample2d_cuda.backward(input1, input2, grad_output.data, + grad_input1.data, grad_input2.data, + ctx.kernel_size) + + return grad_input1, grad_input2, None, None + + +class Resample2d(Module): + + def __init__(self, kernel_size=1, bilinear=True): + super(Resample2d, self).__init__() + self.kernel_size = kernel_size + self.bilinear = bilinear + + @autocast(False) + def forward(self, input1, input2): + input1, input2 = input1.float(), input2.float() + input1_c = input1.contiguous() + # return Resample2dFunction.apply( + # input1_c, input2, self.kernel_size, self.bilinear) + return Resample2dFunction.apply( + input1_c, input2, self.kernel_size) \ No newline at end of file diff --git a/imaginaire/third_party/resample2d/setup.py b/imaginaire/third_party/resample2d/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..54c14d9743bf514e01661bd1bc307a4e3f8986fe --- /dev/null +++ b/imaginaire/third_party/resample2d/setup.py @@ -0,0 +1,43 @@ +# flake8: noqa +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension +import os + + +cuda_version = os.getenv('CUDA_VERSION') +print('CUDA_VERSION: {}'.format(cuda_version)) + +nvcc_args = list() +# nvcc_args.append('-gencode') +# nvcc_args.append('arch=compute_50,code=sm_50') +# nvcc_args.append('-gencode') +# nvcc_args.append('arch=compute_52,code=sm_52') +# nvcc_args.append('-gencode') +# nvcc_args.append('arch=compute_60,code=sm_60') +# nvcc_args.append('-gencode') +# nvcc_args.append('arch=compute_61,code=sm_61') +nvcc_args.append('-gencode') +nvcc_args.append('arch=compute_70,code=sm_70') +nvcc_args.append('-gencode') +nvcc_args.append('arch=compute_75,code=sm_75') +if cuda_version is not None: + if cuda_version >= '11.0': + nvcc_args.append('-gencode') + nvcc_args.append('arch=compute_80,code=sm_80') +nvcc_args.append('-Xcompiler') +nvcc_args.append('-Wall') +nvcc_args.append('-std=c++14') + +setup( + name='resample2d_cuda', + py_modules=['resample2d'], + ext_modules=[ + CUDAExtension('resample2d_cuda', [ + './src/resample2d_cuda.cc', + './src/resample2d_kernel.cu' + ], extra_compile_args={'cxx': ['-Wall', '-std=c++14'], + 'nvcc': nvcc_args}) + ], + cmdclass={ + 'build_ext': BuildExtension + }) diff --git a/imaginaire/third_party/resample2d/src/resample2d_cuda.cc b/imaginaire/third_party/resample2d/src/resample2d_cuda.cc new file mode 100644 index 0000000000000000000000000000000000000000..b330a06bc0f20fe82c275e9a784f7ed91faf7717 --- /dev/null +++ b/imaginaire/third_party/resample2d/src/resample2d_cuda.cc @@ -0,0 +1,34 @@ +#include +#include + +#include "resample2d_kernel.cuh" + +int resample2d_cuda_forward( + at::Tensor& input1, + at::Tensor& input2, + at::Tensor& output, + int kernel_size/*, bool bilinear*/) { + resample2d_kernel_forward(input1, input2, output, kernel_size/*, + bilinear*/); + return 1; +} + +int resample2d_cuda_backward( + at::Tensor& input1, + at::Tensor& input2, + at::Tensor& gradOutput, + at::Tensor& gradInput1, + at::Tensor& gradInput2, + int kernel_size/*, bool bilinear*/) { + resample2d_kernel_backward(input1, input2, gradOutput, gradInput1, + gradInput2, kernel_size/*, bilinear*/); + return 1; +} + + + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", &resample2d_cuda_forward, "Resample2D forward (CUDA)"); + m.def("backward", &resample2d_cuda_backward, "Resample2D backward (CUDA)"); +} + diff --git 
a/imaginaire/third_party/resample2d/src/resample2d_kernel.cu b/imaginaire/third_party/resample2d/src/resample2d_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..654ca8e417b6d22ff623d72d19e24997a2db284c --- /dev/null +++ b/imaginaire/third_party/resample2d/src/resample2d_kernel.cu @@ -0,0 +1,328 @@ +#include +#include +#include + +#define CUDA_NUM_THREADS 512 +#define THREADS_PER_BLOCK 64 + +#define DIM0(TENSOR) ((TENSOR).x) +#define DIM1(TENSOR) ((TENSOR).y) +#define DIM2(TENSOR) ((TENSOR).z) +#define DIM3(TENSOR) ((TENSOR).w) + +#define DIM3_INDEX(TENSOR, xx, yy, zz, ww) ((TENSOR)[((xx) * (TENSOR##_stride.x)) + ((yy) * (TENSOR##_stride.y)) + ((zz) * (TENSOR##_stride.z)) + ((ww) * (TENSOR##_stride.w))]) + +template +__global__ void kernel_resample2d_update_output(const int n, + const scalar_t* __restrict__ input1, const long4 input1_size, const long4 input1_stride, + const scalar_t* __restrict__ input2, const long4 input2_size, const long4 input2_stride, + scalar_t* __restrict__ output, + const long4 output_size, const + long4 output_stride, int + kernel_size/*, bool bilinear*/) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + bool bilinear = true; + if (index >= n) { + return; + } + + scalar_t val = 0.0f; + + int dim_b = DIM0(output_size); + int dim_c = DIM1(output_size); + int dim_h = DIM2(output_size); + int dim_w = DIM3(output_size); + int dim_chw = dim_c * dim_h * dim_w; + int dim_hw = dim_h * dim_w; + + int b = ( index / dim_chw ) % dim_b; + int c = ( index / dim_hw ) % dim_c; + int y = ( index / dim_w ) % dim_h; + int x = ( index ) % dim_w; + + scalar_t dx = DIM3_INDEX(input2, b, 0, y, x); + scalar_t dy = DIM3_INDEX(input2, b, 1, y, x); + + scalar_t xf = static_cast(x) + dx; + scalar_t yf = static_cast(y) + dy; + scalar_t alpha = xf - floor(xf); // alpha + scalar_t beta = yf - floor(yf); // beta + + if (bilinear) { + int xL = max(min( int (floor(xf)), dim_w-1), 0); + int xR = max(min( int (floor(xf)+1), dim_w -1), 0); + int yT = max(min( int (floor(yf)), dim_h-1), 0); + int yB = max(min( int (floor(yf)+1), dim_h-1), 0); + + for (int fy = 0; fy < kernel_size; fy += 1) { + for (int fx = 0; fx < kernel_size; fx += 1) { + val += static_cast((1. - alpha)*(1. - beta) * DIM3_INDEX(input1, b, c, yT + fy, xL + fx)); + val += static_cast((alpha)*(1. - beta) * DIM3_INDEX(input1, b, c, yT + fy, xR + fx)); + val += static_cast((1. 
- alpha)*(beta) * DIM3_INDEX(input1, b, c, yB + fy, xL + fx)); + val += static_cast((alpha)*(beta) * DIM3_INDEX(input1, b, c, yB + fy, xR + fx)); + } + } + + output[index] = val; + } + else { + int xN = max(min( int (floor(xf + 0.5)), dim_w - 1), 0); + int yN = max(min( int (floor(yf + 0.5)), dim_h - 1), 0); + + output[index] = static_cast ( DIM3_INDEX(input1, b, c, yN, xN) ); + } + +} + + +template +__global__ void kernel_resample2d_backward_input1( + const int n, const scalar_t* __restrict__ input1, const long4 input1_size, const long4 input1_stride, + const scalar_t* __restrict__ input2, const long4 input2_size, const long4 input2_stride, + const scalar_t* __restrict__ gradOutput, const long4 gradOutput_size, const long4 gradOutput_stride, + scalar_t* __restrict__ gradInput, const long4 gradInput_size, const long4 + gradInput_stride, int kernel_size/*, bool bilinear*/) { + + int index = blockIdx.x * blockDim.x + threadIdx.x; + bool bilinear = true; + if (index >= n) { + return; + } + + int dim_b = DIM0(gradOutput_size); + int dim_c = DIM1(gradOutput_size); + int dim_h = DIM2(gradOutput_size); + int dim_w = DIM3(gradOutput_size); + int dim_chw = dim_c * dim_h * dim_w; + int dim_hw = dim_h * dim_w; + + int b = ( index / dim_chw ) % dim_b; + int c = ( index / dim_hw ) % dim_c; + int y = ( index / dim_w ) % dim_h; + int x = ( index ) % dim_w; + + scalar_t dx = DIM3_INDEX(input2, b, 0, y, x); + scalar_t dy = DIM3_INDEX(input2, b, 1, y, x); + + scalar_t xf = static_cast(x) + dx; + scalar_t yf = static_cast(y) + dy; + scalar_t alpha = xf - int(xf); // alpha + scalar_t beta = yf - int(yf); // beta + + int idim_h = DIM2(input1_size); + int idim_w = DIM3(input1_size); + + int xL = max(min( int (floor(xf)), idim_w-1), 0); + int xR = max(min( int (floor(xf)+1), idim_w -1), 0); + int yT = max(min( int (floor(yf)), idim_h-1), 0); + int yB = max(min( int (floor(yf)+1), idim_h-1), 0); + + for (int fy = 0; fy < kernel_size; fy += 1) { + for (int fx = 0; fx < kernel_size; fx += 1) { + atomicAdd(&DIM3_INDEX(gradInput, b, c, (yT + fy), (xL + fx)), (1-alpha)*(1-beta) * DIM3_INDEX(gradOutput, b, c, y, x)); + atomicAdd(&DIM3_INDEX(gradInput, b, c, (yT + fy), (xR + fx)), (alpha)*(1-beta) * DIM3_INDEX(gradOutput, b, c, y, x)); + atomicAdd(&DIM3_INDEX(gradInput, b, c, (yB + fy), (xL + fx)), (1-alpha)*(beta) * DIM3_INDEX(gradOutput, b, c, y, x)); + atomicAdd(&DIM3_INDEX(gradInput, b, c, (yB + fy), (xR + fx)), (alpha)*(beta) * DIM3_INDEX(gradOutput, b, c, y, x)); + } + } + +} + +template +__global__ void kernel_resample2d_backward_input2( + const int n, const scalar_t* __restrict__ input1, const long4 input1_size, const long4 input1_stride, + const scalar_t* __restrict__ input2, const long4 input2_size, const long4 input2_stride, + const scalar_t* __restrict__ gradOutput, const long4 gradOutput_size, const long4 gradOutput_stride, + scalar_t* __restrict__ gradInput, const long4 gradInput_size, const long4 + gradInput_stride, int kernel_size/*, bool bilinear*/) { + + int index = blockIdx.x * blockDim.x + threadIdx.x; + bool bilinear = true; + if (index >= n) { + return; + } + + scalar_t output = 0.0; + int kernel_rad = (kernel_size - 1)/2; + + int dim_b = DIM0(gradInput_size); + int dim_c = DIM1(gradInput_size); + int dim_h = DIM2(gradInput_size); + int dim_w = DIM3(gradInput_size); + int dim_chw = dim_c * dim_h * dim_w; + int dim_hw = dim_h * dim_w; + + int b = ( index / dim_chw ) % dim_b; + int c = ( index / dim_hw ) % dim_c; + int y = ( index / dim_w ) % dim_h; + int x = ( index ) % dim_w; + + int odim_c = 
DIM1(gradOutput_size); + + scalar_t dx = DIM3_INDEX(input2, b, 0, y, x); + scalar_t dy = DIM3_INDEX(input2, b, 1, y, x); + + scalar_t xf = static_cast(x) + dx; + scalar_t yf = static_cast(y) + dy; + + int xL = max(min( int (floor(xf)), dim_w-1), 0); + int xR = max(min( int (floor(xf)+1), dim_w -1), 0); + int yT = max(min( int (floor(yf)), dim_h-1), 0); + int yB = max(min( int (floor(yf)+1), dim_h-1), 0); + + if (c % 2) { + float gamma = 1 - (xf - floor(xf)); // alpha + for (int i = 0; i <= 2*kernel_rad; ++i) { + for (int j = 0; j <= 2*kernel_rad; ++j) { + for (int ch = 0; ch < odim_c; ++ch) { + output += (gamma) * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, (yB + j), (xL + i)); + output -= (gamma) * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, (yT + j), (xL + i)); + output += (1-gamma) * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, (yB + j), (xR + i)); + output -= (1-gamma) * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, (yT + j), (xR + i)); + } + } + } + } + else { + float gamma = 1 - (yf - floor(yf)); // alpha + for (int i = 0; i <= 2*kernel_rad; ++i) { + for (int j = 0; j <= 2*kernel_rad; ++j) { + for (int ch = 0; ch < odim_c; ++ch) { + output += (gamma) * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, (yT + j), (xR + i)); + output -= (gamma) * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, (yT + j), (xL + i)); + output += (1-gamma) * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, (yB + j), (xR + i)); + output -= (1-gamma) * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, (yB + j), (xL + i)); + } + } + } + + } + + gradInput[index] = output; + +} + +void resample2d_kernel_forward( + at::Tensor& input1, + at::Tensor& input2, + at::Tensor& output, + int kernel_size/*, + bool bilinear*/) { + + int n = output.numel(); + + const long4 input1_size = make_long4(input1.size(0), input1.size(1), input1.size(2), input1.size(3)); + const long4 input1_stride = make_long4(input1.stride(0), input1.stride(1), input1.stride(2), input1.stride(3)); + + const long4 input2_size = make_long4(input2.size(0), input2.size(1), input2.size(2), input2.size(3)); + const long4 input2_stride = make_long4(input2.stride(0), input2.stride(1), input2.stride(2), input2.stride(3)); + + const long4 output_size = make_long4(output.size(0), output.size(1), output.size(2), output.size(3)); + const long4 output_stride = make_long4(output.stride(0), output.stride(1), output.stride(2), output.stride(3)); + + // TODO: when atomicAdd gets resolved, change to AT_DISPATCH_FLOATING_TYPES_AND_HALF +// AT_DISPATCH_FLOATING_TYPES(input1.type(), "resample_forward_kernel", ([&] { + + kernel_resample2d_update_output<<< (n + CUDA_NUM_THREADS - 1)/CUDA_NUM_THREADS, CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream() >>>( +//at::globalContext().getCurrentCUDAStream() >>>( + n, + input1.data(), + input1_size, + input1_stride, + input2.data(), + input2_size, + input2_stride, + output.data(), + output_size, + output_stride, + kernel_size/*, + bilinear*/); + +// })); + + // TODO: ATen-equivalent check + + // THCudaCheck(cudaGetLastError()); + +} + +void resample2d_kernel_backward( + at::Tensor& input1, + at::Tensor& input2, + at::Tensor& gradOutput, + at::Tensor& gradInput1, + at::Tensor& gradInput2, + int kernel_size/*, + bool bilinear*/) { + + int n = gradOutput.numel(); + + const long4 input1_size = make_long4(input1.size(0), input1.size(1), input1.size(2), input1.size(3)); + const long4 input1_stride = 
make_long4(input1.stride(0), input1.stride(1), input1.stride(2), input1.stride(3)); + + const long4 input2_size = make_long4(input2.size(0), input2.size(1), input2.size(2), input2.size(3)); + const long4 input2_stride = make_long4(input2.stride(0), input2.stride(1), input2.stride(2), input2.stride(3)); + + const long4 gradOutput_size = make_long4(gradOutput.size(0), gradOutput.size(1), gradOutput.size(2), gradOutput.size(3)); + const long4 gradOutput_stride = make_long4(gradOutput.stride(0), gradOutput.stride(1), gradOutput.stride(2), gradOutput.stride(3)); + + const long4 gradInput1_size = make_long4(gradInput1.size(0), gradInput1.size(1), gradInput1.size(2), gradInput1.size(3)); + const long4 gradInput1_stride = make_long4(gradInput1.stride(0), gradInput1.stride(1), gradInput1.stride(2), gradInput1.stride(3)); + +// AT_DISPATCH_FLOATING_TYPES(input1.type(), "resample_backward_input1", ([&] { + + kernel_resample2d_backward_input1<<< (n + CUDA_NUM_THREADS - 1)/CUDA_NUM_THREADS, CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream() >>>( +//at::globalContext().getCurrentCUDAStream() >>>( + n, + input1.data(), + input1_size, + input1_stride, + input2.data(), + input2_size, + input2_stride, + gradOutput.data(), + gradOutput_size, + gradOutput_stride, + gradInput1.data(), + gradInput1_size, + gradInput1_stride, + kernel_size/*, + bilinear*/ + ); + +// })); + + const long4 gradInput2_size = make_long4(gradInput2.size(0), gradInput2.size(1), gradInput2.size(2), gradInput2.size(3)); + const long4 gradInput2_stride = make_long4(gradInput2.stride(0), gradInput2.stride(1), gradInput2.stride(2), gradInput2.stride(3)); + + n = gradInput2.numel(); + +// AT_DISPATCH_FLOATING_TYPES(gradInput2.type(), "resample_backward_input2", ([&] { + + + kernel_resample2d_backward_input2<<< (n + CUDA_NUM_THREADS - 1)/CUDA_NUM_THREADS, CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream() >>>( +//at::globalContext().getCurrentCUDAStream() >>>( + n, + input1.data(), + input1_size, + input1_stride, + input2.data(), + input2_size, + input2_stride, + gradOutput.data(), + gradOutput_size, + gradOutput_stride, + gradInput2.data(), + gradInput2_size, + gradInput2_stride, + kernel_size/*, + bilinear*/ + ); + +// })); + + // TODO: Use the ATen equivalent to get last error + + // THCudaCheck(cudaGetLastError()); + +} diff --git a/imaginaire/third_party/resample2d/src/resample2d_kernel.cuh b/imaginaire/third_party/resample2d/src/resample2d_kernel.cuh new file mode 100644 index 0000000000000000000000000000000000000000..3a815269a562e762cd7bd0c73af21d468d4eb2fd --- /dev/null +++ b/imaginaire/third_party/resample2d/src/resample2d_kernel.cuh @@ -0,0 +1,19 @@ +#pragma once + +#include + +void resample2d_kernel_forward( + at::Tensor& input1, + at::Tensor& input2, + at::Tensor& output, + int kernel_size/*, + bool bilinear*/); + +void resample2d_kernel_backward( + at::Tensor& input1, + at::Tensor& input2, + at::Tensor& gradOutput, + at::Tensor& gradInput1, + at::Tensor& gradInput2, + int kernel_size/*, + bool bilinear*/); \ No newline at end of file diff --git a/imaginaire/third_party/upfirdn2d/__init__.py b/imaginaire/third_party/upfirdn2d/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4c92bf9c45932a8578f64ef83dc0b067ebd27ca0 --- /dev/null +++ b/imaginaire/third_party/upfirdn2d/__init__.py @@ -0,0 +1,3 @@ +from .upfirdn2d import BlurUpsample, BlurDownsample, Blur + +__all__ = ['BlurUpsample', 'BlurDownsample', 'Blur'] diff --git a/imaginaire/third_party/upfirdn2d/__pycache__/__init__.cpython-38.pyc 
b/imaginaire/third_party/upfirdn2d/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eaa606ee083d81fe1526c72e9939d485d771c19a Binary files /dev/null and b/imaginaire/third_party/upfirdn2d/__pycache__/__init__.cpython-38.pyc differ diff --git a/imaginaire/third_party/upfirdn2d/__pycache__/upfirdn2d.cpython-38.pyc b/imaginaire/third_party/upfirdn2d/__pycache__/upfirdn2d.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bf77a6379ef867beeb0dd391f94950ab0d4b6c91 Binary files /dev/null and b/imaginaire/third_party/upfirdn2d/__pycache__/upfirdn2d.cpython-38.pyc differ diff --git a/imaginaire/third_party/upfirdn2d/setup.py b/imaginaire/third_party/upfirdn2d/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..d69207b2daaeaebaa81e5dfe3c1001656fe4248a --- /dev/null +++ b/imaginaire/third_party/upfirdn2d/setup.py @@ -0,0 +1,43 @@ +# flake8: noqa +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension +import os + + +cuda_version = os.getenv('CUDA_VERSION') +print('CUDA_VERSION: {}'.format(cuda_version)) + +nvcc_args = list() +# nvcc_args.append('-gencode') +# nvcc_args.append('arch=compute_50,code=sm_50') +# nvcc_args.append('-gencode') +# nvcc_args.append('arch=compute_52,code=sm_52') +# nvcc_args.append('-gencode') +# nvcc_args.append('arch=compute_60,code=sm_60') +# nvcc_args.append('-gencode') +# nvcc_args.append('arch=compute_61,code=sm_61') +nvcc_args.append('-gencode') +nvcc_args.append('arch=compute_70,code=sm_70') +nvcc_args.append('-gencode') +nvcc_args.append('arch=compute_75,code=sm_75') +if cuda_version is not None: + if cuda_version >= '11.0': + nvcc_args.append('-gencode') + nvcc_args.append('arch=compute_80,code=sm_80') +nvcc_args.append('-Xcompiler') +nvcc_args.append('-Wall') +nvcc_args.append('-std=c++14') + +setup( + name='upfirdn2d_cuda', + py_modules=['upfirdn2d'], + ext_modules=[ + CUDAExtension('upfirdn2d_cuda', [ + './src/upfirdn2d_cuda.cc', + './src/upfirdn2d_cuda_kernel.cu' + ], extra_compile_args={'cxx': ['-Wall', '-std=c++14'], + 'nvcc': nvcc_args}) + ], + cmdclass={ + 'build_ext': BuildExtension + }) diff --git a/imaginaire/third_party/upfirdn2d/src/upfirdn2d_cuda.cc b/imaginaire/third_party/upfirdn2d/src/upfirdn2d_cuda.cc new file mode 100644 index 0000000000000000000000000000000000000000..65df7a9ad78e4f6f7560feed79048983f60e8add --- /dev/null +++ b/imaginaire/third_party/upfirdn2d/src/upfirdn2d_cuda.cc @@ -0,0 +1,103 @@ +// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#include +#include +#include +#include "upfirdn2d_cuda.h" + +//------------------------------------------------------------------------ + +static torch::Tensor upfirdn2d_cuda(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain) +{ + // Validate arguments. 
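The `setup.py` files above (resample2d and upfirdn2d) pick their `-gencode` targets from the `CUDA_VERSION` environment variable. Note that `cuda_version >= '11.0'` is a plain string comparison, so a value like `'9.2'` would also satisfy it; a hedged numeric variant of that check:

```python
import os

# CUDA_VERSION is expected to look like '11.3'; compare the major version
# numerically instead of lexicographically ('9.2' >= '11.0' is True for strings).
cuda_version = os.getenv('CUDA_VERSION')
extra_gencode = []
if cuda_version is not None and int(cuda_version.split('.')[0]) >= 11:
    extra_gencode += ['-gencode', 'arch=compute_80,code=sm_80']  # sm_80 needs CUDA 11+
print(extra_gencode)
```

Both extensions are standard PyTorch `CUDAExtension` builds and are typically installed with `python setup.py install` (or `pip install .`) from their respective directories.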
+ TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); + TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x"); + TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32"); + TORCH_CHECK(x.numel() <= INT_MAX, "x is too large"); + TORCH_CHECK(f.numel() <= INT_MAX, "f is too large"); + TORCH_CHECK(x.dim() == 4, "x must be rank 4"); + TORCH_CHECK(f.dim() == 2, "f must be rank 2"); + TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1"); + TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1"); + TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1"); + + // Create output tensor. + const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); + int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx; + int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy; + TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1"); + torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format()); + TORCH_CHECK(y.numel() <= INT_MAX, "output is too large"); + + // Initialize CUDA kernel parameters. + upfirdn2d_kernel_params p; + p.x = x.data_ptr(); + p.f = f.data_ptr(); + p.y = y.data_ptr(); + p.up = make_int2(upx, upy); + p.down = make_int2(downx, downy); + p.pad0 = make_int2(padx0, pady0); + p.flip = (flip) ? 1 : 0; + p.gain = gain; + p.inSize = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0)); + p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0)); + p.filterSize = make_int2((int)f.size(1), (int)f.size(0)); + p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0)); + p.outSize = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0)); + p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0)); + p.sizeMajor = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z; + p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1; + + // Choose CUDA kernel. + upfirdn2d_kernel_spec spec; + AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda_kernel", [&] + { + spec = choose_upfirdn2d_kernel(p); + }); + + // Set looping options. + p.loopMajor = (p.sizeMajor - 1) / 16384 + 1; + p.loopMinor = spec.loopMinor; + p.loopX = spec.loopX; + p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1; + p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1; + + // Compute grid size. + dim3 blockSize, gridSize; + if (spec.tileOutW < 0) // large + { + blockSize = dim3(4, 32, 1); + gridSize = dim3( + ((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor, + (p.outSize.x - 1) / (blockSize.y * p.loopX) + 1, + p.launchMajor); + } + else // small + { + blockSize = dim3(256, 1, 1); + gridSize = dim3( + ((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor, + (p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1, + p.launchMajor); + } + + // Launch CUDA kernel. 
+ void* args[] = {&p}; + AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream())); + return y; +} + +//------------------------------------------------------------------------ + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + m.def("upfirdn2d_cuda", &upfirdn2d_cuda); +} + +//------------------------------------------------------------------------ diff --git a/imaginaire/third_party/upfirdn2d/src/upfirdn2d_cuda.h b/imaginaire/third_party/upfirdn2d/src/upfirdn2d_cuda.h new file mode 100644 index 0000000000000000000000000000000000000000..c9e2032bcac9d2abde7a75eea4d812da348afadd --- /dev/null +++ b/imaginaire/third_party/upfirdn2d/src/upfirdn2d_cuda.h @@ -0,0 +1,59 @@ +// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#include + +//------------------------------------------------------------------------ +// CUDA kernel parameters. + +struct upfirdn2d_kernel_params +{ + const void* x; + const float* f; + void* y; + + int2 up; + int2 down; + int2 pad0; + int flip; + float gain; + + int4 inSize; // [width, height, channel, batch] + int4 inStride; + int2 filterSize; // [width, height] + int2 filterStride; + int4 outSize; // [width, height, channel, batch] + int4 outStride; + int sizeMinor; + int sizeMajor; + + int loopMinor; + int loopMajor; + int loopX; + int launchMinor; + int launchMajor; +}; + +//------------------------------------------------------------------------ +// CUDA kernel specialization. + +struct upfirdn2d_kernel_spec +{ + void* kernel; + int tileOutW; + int tileOutH; + int loopMinor; + int loopX; +}; + +//------------------------------------------------------------------------ +// CUDA kernel selection. + +template upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p); + +//------------------------------------------------------------------------ diff --git a/imaginaire/third_party/upfirdn2d/src/upfirdn2d_cuda_kernel.cu b/imaginaire/third_party/upfirdn2d/src/upfirdn2d_cuda_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..d7f8938f7ac1220d934fe6a357de543a452445e4 --- /dev/null +++ b/imaginaire/third_party/upfirdn2d/src/upfirdn2d_cuda_kernel.cu @@ -0,0 +1,350 @@ +// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +#include +#include "upfirdn2d_cuda.h" + +//------------------------------------------------------------------------ +// Helpers. 
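For readers who do not want to trace the CUDA kernels that follow: upfirdn2d is the classic upsample / FIR-filter / downsample primitive. It zero-insert upsamples by `up`, pads, convolves with a small 2-D filter `f`, and keeps every `down`-th sample. Below is a minimal PyTorch sketch of the same operation under simplifying assumptions (one shared `up`/`down`/`pad` per axis); it illustrates the math, it is not the kernels themselves.

```python
import torch
import torch.nn.functional as F

def upfirdn2d_ref(x, f, up=1, down=1, pad=(0, 0), gain=1.0, flip=False):
    """Reference upfirdn2d on an NCHW tensor (same pad/up/down on both axes)."""
    n, c, in_h, in_w = x.shape
    fh, fw = f.shape
    # 1) Upsample by inserting up-1 zeros after every sample.
    x = x.reshape(n, c, in_h, 1, in_w, 1)
    x = F.pad(x, [0, up - 1, 0, 0, 0, up - 1])
    x = x.reshape(n, c, in_h * up, in_w * up)
    # 2) Pad with pad[0] before and pad[1] after on each axis.
    x = F.pad(x, [pad[0], pad[1], pad[0], pad[1]])
    # 3) Depthwise FIR filtering; conv2d cross-correlates, so flip f for a true
    #    convolution unless flip is requested (mirrors p.flip in the kernels).
    w = f if flip else f.flip([0, 1])
    w = (w * gain).to(x.dtype)[None, None].repeat(c, 1, 1, 1)
    x = F.conv2d(x, w, groups=c)
    # 4) Downsample by keeping every down-th sample.
    return x[:, :, ::down, ::down]

# The output size matches the formula in upfirdn2d_cuda above:
# out = (in * up + pad0 + pad1 - filter + down) // down.
x = torch.randn(1, 3, 64, 64)
f = torch.ones(4, 4) / 16.0
y = upfirdn2d_ref(x, f, up=2, down=1, pad=(2, 1))
out = (64 * 2 + 2 + 1 - 4 + 1) // 1   # = 128 with up=2, down=1, pad=(2, 1), 4-tap filter
assert y.shape[-2:] == (out, out) == (128, 128)
```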
+ +template struct InternalType; +template <> struct InternalType { typedef double scalar_t; }; +template <> struct InternalType { typedef float scalar_t; }; +template <> struct InternalType { typedef float scalar_t; }; + +static __device__ __forceinline__ int floor_div(int a, int b) +{ + int t = 1 - a / b; + return (a + t * b) / b - t; +} + +//------------------------------------------------------------------------ +// Generic CUDA implementation for large filters. + +template static __global__ void upfirdn2d_kernel_large(upfirdn2d_kernel_params p) +{ + typedef typename InternalType::scalar_t scalar_t; + + // Calculate thread index. + int minorBase = blockIdx.x * blockDim.x + threadIdx.x; + int outY = minorBase / p.launchMinor; + minorBase -= outY * p.launchMinor; + int outXBase = blockIdx.y * p.loopX * blockDim.y + threadIdx.y; + int majorBase = blockIdx.z * p.loopMajor; + if (outXBase >= p.outSize.x | outY >= p.outSize.y | majorBase >= p.sizeMajor) + return; + + // Setup Y receptive field. + int midY = outY * p.down.y + p.up.y - 1 - p.pad0.y; + int inY = min(max(floor_div(midY, p.up.y), 0), p.inSize.y); + int h = min(max(floor_div(midY + p.filterSize.y, p.up.y), 0), p.inSize.y) - inY; + int filterY = midY + p.filterSize.y - (inY + 1) * p.up.y; + if (p.flip) + filterY = p.filterSize.y - 1 - filterY; + + // Loop over major, minor, and X. + for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++) + for (int minorIdx = 0, minor = minorBase; minorIdx < p.loopMinor & minor < p.sizeMinor; minorIdx++, minor += p.launchMinor) + { + int nc = major * p.sizeMinor + minor; + int n = nc / p.inSize.z; + int c = nc - n * p.inSize.z; + for (int loopX = 0, outX = outXBase; loopX < p.loopX & outX < p.outSize.x; loopX++, outX += blockDim.y) + { + // Setup X receptive field. + int midX = outX * p.down.x + p.up.x - 1 - p.pad0.x; + int inX = min(max(floor_div(midX, p.up.x), 0), p.inSize.x); + int w = min(max(floor_div(midX + p.filterSize.x, p.up.x), 0), p.inSize.x) - inX; + int filterX = midX + p.filterSize.x - (inX + 1) * p.up.x; + if (p.flip) + filterX = p.filterSize.x - 1 - filterX; + + // Initialize pointers. + const T* xp = &((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w]; + const float* fp = &p.f[filterX * p.filterStride.x + filterY * p.filterStride.y]; + int filterStepX = ((p.flip) ? p.up.x : -p.up.x) * p.filterStride.x; + int filterStepY = ((p.flip) ? p.up.y : -p.up.y) * p.filterStride.y; + + // Inner loop. + scalar_t v = 0; + for (int y = 0; y < h; y++) + { + for (int x = 0; x < w; x++) + { + v += (scalar_t)(*xp) * (scalar_t)(*fp); + xp += p.inStride.x; + fp += filterStepX; + } + xp += p.inStride.y - w * p.inStride.x; + fp += filterStepY - w * filterStepX; + } + + // Store result. + v *= p.gain; + ((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v; + } + } +} + +//------------------------------------------------------------------------ +// Specialized CUDA implementation for small filters. + +template +static __global__ void upfirdn2d_kernel_small(upfirdn2d_kernel_params p) +{ + typedef typename InternalType::scalar_t scalar_t; + const int tileInW = ((tileOutW - 1) * downx + filterW - 1) / upx + 1; + const int tileInH = ((tileOutH - 1) * downy + filterH - 1) / upy + 1; + __shared__ volatile scalar_t sf[filterH][filterW]; + __shared__ volatile scalar_t sx[tileInH][tileInW][loopMinor]; + + // Calculate tile index. 
+ int minorBase = blockIdx.x; + int tileOutY = minorBase / p.launchMinor; + minorBase -= tileOutY * p.launchMinor; + minorBase *= loopMinor; + tileOutY *= tileOutH; + int tileOutXBase = blockIdx.y * p.loopX * tileOutW; + int majorBase = blockIdx.z * p.loopMajor; + if (tileOutXBase >= p.outSize.x | tileOutY >= p.outSize.y | majorBase >= p.sizeMajor) + return; + + // Load filter (flipped). + for (int tapIdx = threadIdx.x; tapIdx < filterH * filterW; tapIdx += blockDim.x) + { + int fy = tapIdx / filterW; + int fx = tapIdx - fy * filterW; + scalar_t v = 0; + if (fx < p.filterSize.x & fy < p.filterSize.y) + { + int ffx = (p.flip) ? fx : p.filterSize.x - 1 - fx; + int ffy = (p.flip) ? fy : p.filterSize.y - 1 - fy; + v = (scalar_t)p.f[ffx * p.filterStride.x + ffy * p.filterStride.y]; + } + sf[fy][fx] = v; + } + + // Loop over major and X. + for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++) + { + int baseNC = major * p.sizeMinor + minorBase; + int n = baseNC / p.inSize.z; + int baseC = baseNC - n * p.inSize.z; + for (int loopX = 0, tileOutX = tileOutXBase; loopX < p.loopX & tileOutX < p.outSize.x; loopX++, tileOutX += tileOutW) + { + // Load input pixels. + int tileMidX = tileOutX * downx + upx - 1 - p.pad0.x; + int tileMidY = tileOutY * downy + upy - 1 - p.pad0.y; + int tileInX = floor_div(tileMidX, upx); + int tileInY = floor_div(tileMidY, upy); + __syncthreads(); + for (int inIdx = threadIdx.x; inIdx < tileInH * tileInW * loopMinor; inIdx += blockDim.x) + { + int relC = inIdx; + int relInX = relC / loopMinor; + int relInY = relInX / tileInW; + relC -= relInX * loopMinor; + relInX -= relInY * tileInW; + int c = baseC + relC; + int inX = tileInX + relInX; + int inY = tileInY + relInY; + scalar_t v = 0; + if (inX >= 0 & inY >= 0 & inX < p.inSize.x & inY < p.inSize.y & c < p.inSize.z) + v = (scalar_t)((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w]; + sx[relInY][relInX][relC] = v; + } + + // Loop over output pixels. + __syncthreads(); + for (int outIdx = threadIdx.x; outIdx < tileOutH * tileOutW * loopMinor; outIdx += blockDim.x) + { + int relC = outIdx; + int relOutX = relC / loopMinor; + int relOutY = relOutX / tileOutW; + relC -= relOutX * loopMinor; + relOutX -= relOutY * tileOutW; + int c = baseC + relC; + int outX = tileOutX + relOutX; + int outY = tileOutY + relOutY; + + // Setup receptive field. + int midX = tileMidX + relOutX * downx; + int midY = tileMidY + relOutY * downy; + int inX = floor_div(midX, upx); + int inY = floor_div(midY, upy); + int relInX = inX - tileInX; + int relInY = inY - tileInY; + int filterX = (inX + 1) * upx - midX - 1; // flipped + int filterY = (inY + 1) * upy - midY - 1; // flipped + + // Inner loop. + if (outX < p.outSize.x & outY < p.outSize.y & c < p.outSize.z) + { + scalar_t v = 0; + #pragma unroll + for (int y = 0; y < filterH / upy; y++) + #pragma unroll + for (int x = 0; x < filterW / upx; x++) + v += sx[relInY + y][relInX + x][relC] * sf[filterY + y * upy][filterX + x * upx]; + v *= p.gain; + ((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v; + } + } + } + } +} + +//------------------------------------------------------------------------ +// CUDA kernel selection. 
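The tile sizes hard-coded in the small-filter specializations selected below are sized so that the cached input tile plus the filter fit comfortably in shared memory. A rough budget check for one specialization (a sketch, assuming float32 scalars; tileOutW=64, tileOutH=16, 7x7 filter, up=down=1, loopMinor=1):

```python
def tile_in(tile_out, down, filt, up):
    # Mirrors tileInW / tileInH in upfirdn2d_kernel_small above.
    return ((tile_out - 1) * down + filt - 1) // up + 1

tile_in_w = tile_in(64, 1, 7, 1)   # 70
tile_in_h = tile_in(16, 1, 7, 1)   # 22
smem_bytes = 4 * (7 * 7 + tile_in_h * tile_in_w * 1)   # sf[][] + sx[][][] buffers
print(tile_in_w, tile_in_h, smem_bytes)  # 70 22 6356 -> well under the default 48 KB limit
```

In the selection code that follows, the `s == 1` test distinguishes channels-last tensors (channel stride of 1) from contiguous NCHW tensors, which is why each filter-size case appears twice with different tile shapes.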
+
+template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p)
+{
+    int s = p.inStride.z, fx = p.filterSize.x, fy = p.filterSize.y;
+
+    upfirdn2d_kernel_spec spec = {(void*)upfirdn2d_kernel_large<T>, -1,-1,1, 4}; // contiguous
+    if (s == 1) spec = {(void*)upfirdn2d_kernel_large<T>, -1,-1,4, 1}; // channels_last
+
+    if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // contiguous
+    {
+        if (fx <= 7  && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 7,7,  64,16,1>,  64,16,1,  1};
+        if (fx <= 6  && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 6,6,  64,16,1>,  64,16,1,  1};
+        if (fx <= 5  && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 5,5,  64,16,1>,  64,16,1,  1};
+        if (fx <= 4  && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4,  64,16,1>,  64,16,1,  1};
+        if (fx <= 3  && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 3,3,  64,16,1>,  64,16,1,  1};
+        if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 24,1, 128,8,1>,  128,8,1,  1};
+        if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 20,1, 128,8,1>,  128,8,1,  1};
+        if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,1, 128,8,1>,  128,8,1,  1};
+        if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 12,1, 128,8,1>,  128,8,1,  1};
+        if (fx <= 8  && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 8,1,  128,8,1>,  128,8,1,  1};
+        if (fx <= 1  && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,24, 32,32,1>,  32,32,1,  1};
+        if (fx <= 1  && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,20, 32,32,1>,  32,32,1,  1};
+        if (fx <= 1  && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,16, 32,32,1>,  32,32,1,  1};
+        if (fx <= 1  && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,12, 32,32,1>,  32,32,1,  1};
+        if (fx <= 1  && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,8,  32,32,1>,  32,32,1,  1};
+    }
+    if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // channels_last
+    {
+        if (fx <= 7  && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 7,7,  16,16,8>,  16,16,8,  1};
+        if (fx <= 6  && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 6,6,  16,16,8>,  16,16,8,  1};
+        if (fx <= 5  && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 5,5,  16,16,8>,  16,16,8,  1};
+        if (fx <= 4  && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4,  16,16,8>,  16,16,8,  1};
+        if (fx <= 3  && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 3,3,  16,16,8>,  16,16,8,  1};
+        if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 24,1, 128,1,16>, 128,1,16, 1};
+        if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 20,1, 128,1,16>, 128,1,16, 1};
+        if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,1, 128,1,16>, 128,1,16, 1};
+        if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 12,1, 128,1,16>, 128,1,16, 1};
+        if (fx <= 8  && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 8,1,  128,1,16>, 128,1,16, 1};
+        if (fx <= 1  && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,24, 1,128,16>, 1,128,16, 1};
+        if (fx <= 1  && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,20, 1,128,16>, 1,128,16, 1};
+        if (fx <= 1  && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,16, 1,128,16>, 1,128,16, 1};
+        if (fx <= 1  && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,12, 1,128,16>, 1,128,16, 1};
+        if (fx <= 1  && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,8,  1,128,16>, 1,128,16, 1};
+    }
+    if (s != 1 && p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // contiguous
+    {
+        if (fx <= 8  && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 8,8, 64,16,1>, 64,16,1, 1};
+        if (fx <= 6  && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 6,6, 64,16,1>, 64,16,1, 1};
+        if (fx <= 4  && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 4,4, 64,16,1>, 64,16,1, 1};
+        if (fx <= 2  && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 2,2, 64,16,1>, 64,16,1, 1};
+    }
+    if (s == 1 && p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // channels_last
+    {
+        if (fx <= 8  && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 8,8, 16,16,8>, 16,16,8, 1};
+        if (fx <= 6  && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 6,6, 16,16,8>, 16,16,8, 1};
+        if (fx <= 4  && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
+        if (fx <= 2  && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 2,2, 16,16,8>, 16,16,8, 1};
+    }
+    if (s != 1 && p.up.x == 2 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // contiguous
+    {
+        if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 24,1, 128,8,1>, 128,8,1, 1};
+        if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 20,1, 128,8,1>, 128,8,1, 1};
+        if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 16,1, 128,8,1>, 128,8,1, 1};
+        if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 12,1, 128,8,1>, 128,8,1, 1};
+        if (fx <= 8  && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 8,1,  128,8,1>, 128,8,1, 1};
+    }
+    if (s == 1 && p.up.x == 2 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // channels_last
+    {
+        if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 24,1, 128,1,16>, 128,1,16, 1};
+        if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 20,1, 128,1,16>, 128,1,16, 1};
+        if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 16,1, 128,1,16>, 128,1,16, 1};
+        if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 12,1, 128,1,16>, 128,1,16, 1};
+        if (fx <= 8  && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 8,1,  128,1,16>, 128,1,16, 1};
+    }
+    if (s != 1 && p.up.x == 1 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // contiguous
+    {
+        if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,24, 32,32,1>, 32,32,1, 1};
+        if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,20, 32,32,1>, 32,32,1, 1};
+        if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,16, 32,32,1>, 32,32,1, 1};
+        if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,12, 32,32,1>, 32,32,1, 1};
+        if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,8,  32,32,1>, 32,32,1, 1};
+    }
+    if (s == 1 && p.up.x == 1 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // channels_last
+    {
+        if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,24, 1,128,16>, 1,128,16, 1};
+        if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,20, 1,128,16>, 1,128,16, 1};
+        if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,16, 1,128,16>, 1,128,16, 1};
+        if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,12, 1,128,16>, 1,128,16, 1};
+        if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,8,  1,128,16>, 1,128,16, 1};
+    }
+    if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 2) // contiguous
+    {
+        if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 8,8, 32,8,1>, 32,8,1, 1};
+        if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 6,6, 32,8,1>, 32,8,1, 1};
+        if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 4,4, 32,8,1>, 32,8,1, 1};
+        if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 2,2, 32,8,1>, 32,8,1, 1};
+    }
+    if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 2) // channels_last
+    {
+        if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 8,8, 8,8,8>, 8,8,8, 1};
+        if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 6,6, 8,8,8>, 8,8,8, 1};
+        if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 4,4, 8,8,8>, 8,8,8, 1};
+        if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 2,2, 8,8,8>, 8,8,8, 1};
+    }
+    if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 1) // contiguous
+    {
+        if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 24,1, 64,8,1>, 64,8,1, 1};
+        if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 20,1, 64,8,1>, 64,8,1, 1};
+        if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 16,1, 64,8,1>, 64,8,1, 1};
+        if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 12,1, 64,8,1>, 64,8,1, 1};
+        if (fx <= 8  && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 8,1,  64,8,1>, 64,8,1, 1};
+    }
+    if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 1) // channels_last
+    {
+        if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 24,1, 64,1,8>, 64,1,8, 1};
+        if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 20,1, 64,1,8>, 64,1,8, 1};
+        if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 16,1, 64,1,8>, 64,1,8, 1};
+        if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 12,1, 64,1,8>, 64,1,8, 1};
+        if (fx <= 8  && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 8,1,  64,1,8>, 64,1,8, 1};
+    }
+    if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 2) // contiguous
+    {
+        if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,24, 32,16,1>, 32,16,1, 1};
+        if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,20, 32,16,1>, 32,16,1, 1};
+        if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,16, 32,16,1>, 32,16,1, 1};
+        if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,12, 32,16,1>, 32,16,1, 1};
+        if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,8,  32,16,1>, 32,16,1, 1};
+    }
+    if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 2) // channels_last
+    {
+        if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,24, 1,64,8>, 1,64,8, 1};
+        if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,20, 1,64,8>, 1,64,8, 1};
+        if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,16, 1,64,8>, 1,64,8, 1};
+        if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,12, 1,64,8>, 1,64,8, 1};
+        if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,8,  1,64,8>, 1,64,8, 1};
+    }
+    return spec;
+}
+
+//------------------------------------------------------------------------
+// Template specializations.
+
+template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<double>   (const upfirdn2d_kernel_params& p);
+template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<float>    (const upfirdn2d_kernel_params& p);
+template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<c10::Half>(const upfirdn2d_kernel_params& p);
+
+//------------------------------------------------------------------------
diff --git a/imaginaire/third_party/upfirdn2d/upfirdn2d.py b/imaginaire/third_party/upfirdn2d/upfirdn2d.py
new file mode 100644
index 0000000000000000000000000000000000000000..8548efe56653d6d8083f68d6e6617ba84b398d1e
--- /dev/null
+++ b/imaginaire/third_party/upfirdn2d/upfirdn2d.py
@@ -0,0 +1,471 @@
+# flake8: noqa
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
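+#
+# Note: upfirdn2d() below dispatches to the compiled `upfirdn2d_cuda` extension
+# when the input is a CUDA tensor and falls back to the pure-PyTorch reference
+# implementation `_upfirdn2d_ref()` otherwise.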
+ +"""Custom PyTorch ops for efficient resampling of 2D images.""" + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn + +import upfirdn2d_cuda + + +def _parse_scaling(scaling): + if isinstance(scaling, int): + scaling = [scaling, scaling] + assert isinstance(scaling, (list, tuple)) + assert all(isinstance(x, int) for x in scaling) + sx, sy = scaling + assert sx >= 1 and sy >= 1 + return sx, sy + + +def _parse_padding(padding): + if isinstance(padding, int): + padding = [padding, padding] + assert isinstance(padding, (list, tuple)) + assert all(isinstance(x, int) for x in padding) + if len(padding) == 2: + padx, pady = padding + padding = [padx, padx, pady, pady] + padx0, padx1, pady0, pady1 = padding + return padx0, padx1, pady0, pady1 + + +def _get_filter_size(f): + if f is None: + return 1, 1 + assert isinstance(f, torch.Tensor) and f.ndim in [1, 2] + fw = f.shape[-1] + fh = f.shape[0] + assert fw >= 1 and fh >= 1 + return fw, fh + + +class BlurUpsample(nn.Module): + def __init__(self, + kernel=(1, 3, 3, 1), + factor=2, + padding_mode='zeros'): + super().__init__() + p = len(kernel) + px0 = (p + factor - 1) // 2 + px1 = (p - factor) // 2 + py0 = (p + factor - 1) // 2 + py1 = (p - factor) // 2 + + self.pad = [px0, px1, py0, py1] + self.factor = factor + self.register_buffer('kernel', setup_filter(kernel)) + self.kernel_1d = kernel + self.padding_mode = padding_mode + + def forward(self, x): + if self.padding_mode != 'zeros': + x = F.pad(x, list(self.pad) * 2, mode=self.padding_mode) + out = upfirdn2d( + x, self.kernel, up=self.factor, gain=self.factor ** 2) + else: + out = upfirdn2d( + x, self.kernel, up=self.factor, padding=self.pad, + gain=self.factor ** 2) + return out + + def extra_repr(self): + s = 'kernel={kernel_1d}, ' \ + 'padding_mode={padding_mode}, pad={pad}' + return s.format(**self.__dict__) + + +class BlurDownsample(nn.Module): + def __init__(self, kernel=(1, 3, 3, 1), factor=2, padding_mode='zeros'): + super().__init__() + p = len(kernel) + px0 = (p - factor + 1) // 2 + px1 = (p - factor) // 2 + py0 = (p - factor + 1) // 2 + py1 = (p - factor) // 2 + + self.pad = [px0, px1, py0, py1] + self.factor = factor + self.register_buffer('kernel', setup_filter(kernel)) + self.kernel_1d = kernel + self.padding_mode = padding_mode + + def forward(self, x): + if self.padding_mode != 'zeros': + x = F.pad(x, list(self.pad) * 2, mode=self.padding_mode) + out = upfirdn2d(x, self.kernel, down=self.factor) + else: + out = upfirdn2d(x, self.kernel, down=self.factor, padding=self.pad) + return out + + def extra_repr(self): + s = 'kernel={kernel_1d}, ' \ + 'padding_mode={padding_mode}, pad={pad}' + return s.format(**self.__dict__) + + +class Blur(nn.Module): + def __init__(self, + kernel=(1, 3, 3, 1), + pad=0, + padding_mode='zeros'): + super().__init__() + self.register_buffer('kernel', setup_filter(kernel)) + self.kernel_1d = kernel + self.padding_mode = padding_mode + self.pad = pad + + def forward(self, x): + if self.padding_mode != 'zeros': + x = F.pad(x, list(self.pad) * 2, mode=self.padding_mode) + out = upfirdn2d(x, self.kernel) + else: + out = upfirdn2d(x, self.kernel, padding=self.pad) + return out + + def extra_repr(self): + s = 'kernel={kernel_1d}, ' \ + 'padding_mode={padding_mode}, pad={pad}' + return s.format(**self.__dict__) + + +# ---------------------------------------------------------------------------- + +def setup_filter(f, device=torch.device('cpu'), normalize=True, + flip_filter=False, gain=1, separable=None): + r"""Convenience 
function to setup 2D FIR filter for `upfirdn2d()`. + + Args: + f: Torch tensor, numpy array, or python list of the shape + `[filter_height, filter_width]` (non-separable), + `[filter_taps]` (separable), + `[]` (impulse), or + `None` (identity). + device: Result device (default: cpu). + normalize: Normalize the filter so that it retains the magnitude + for constant input signal (DC)? (default: True). + flip_filter: Flip the filter? (default: False). + gain: Overall scaling factor for signal magnitude (default: 1). + separable: Return a separable filter? (default: select automatically). + + Returns: + Float32 tensor of the shape + `[filter_height, filter_width]` (non-separable) or + `[filter_taps]` (separable). + """ + # Validate. + if f is None: + f = 1 + f = torch.as_tensor(f, dtype=torch.float32) + assert f.ndim in [0, 1, 2] + assert f.numel() > 0 + if f.ndim == 0: + f = f[np.newaxis] + + # Separable? + if separable is None: + separable = (f.ndim == 1 and f.numel() >= 8) + if f.ndim == 1 and not separable: + f = f.ger(f) + assert f.ndim == (1 if separable else 2) + + # Apply normalize, flip, gain, and device. + if normalize: + f /= f.sum() + if flip_filter: + f = f.flip(list(range(f.ndim))) + f = f * (gain ** (f.ndim / 2)) + f = f.to(device=device) + return f + + +# ---------------------------------------------------------------------------- + +def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl='cuda'): + r"""Pad, upsample, filter, and downsample a batch of 2D images. + + Performs the following sequence of operations for each channel: + + 1. Upsample the image by inserting N-1 zeros after each pixel (`up`). + + 2. Pad the image with the specified number of zeros on each side (`padding`). + Negative padding corresponds to cropping the image. + + 3. Convolve the image with the specified 2D FIR filter (`f`), shrinking it + so that the footprint of all output pixels lies within the input image. + + 4. Downsample the image by keeping every Nth pixel (`down`). + + This sequence of operations bears close resemblance to scipy.signal.upfirdn(). + The fused op is considerably more efficient than performing the same calculation + using standard PyTorch ops. It supports gradients of arbitrary order. + + Args: + x: Float32/float64/float16 input tensor of the shape + `[batch_size, num_channels, in_height, in_width]`. + f: Float32 FIR filter of the shape + `[filter_height, filter_width]` (non-separable), + `[filter_taps]` (separable), or + `None` (identity). + up: Integer upsampling factor. Can be a single int or a list/tuple + `[x, y]` (default: 1). + down: Integer downsampling factor. Can be a single int or a list/tuple + `[x, y]` (default: 1). + padding: Padding with respect to the upsampled image. Can be a single number + or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` + (default: 0). + flip_filter: False = convolution, True = correlation (default: False). + gain: Overall scaling factor for signal magnitude (default: 1). + impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). + + Returns: + Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. 
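+
+    Example:
+        Minimal illustrative sketch of 2x upsampling with a 4-tap filter; the
+        padding and gain shown here are what `upsample2d()` below derives
+        automatically for this case:
+
+            >>> x = torch.randn(1, 3, 64, 64)
+            >>> f = setup_filter([1, 3, 3, 1])
+            >>> y = upfirdn2d(x, f, up=2, padding=[2, 1, 2, 1], gain=4)
+            >>> y.shape
+            torch.Size([1, 3, 128, 128])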
+ """ + assert isinstance(x, torch.Tensor) + assert impl in ['ref', 'cuda'] + if impl == 'cuda' and x.device.type == 'cuda': + return _upfirdn2d_cuda(up=up, down=down, padding=padding, + flip_filter=flip_filter, gain=gain).apply(x, f) + return _upfirdn2d_ref(x, f, up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain) + + +# ---------------------------------------------------------------------------- + +def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1): + """Slow reference implementation of `upfirdn2d()` using standard PyTorch ops. + """ + # Validate arguments. + assert isinstance(x, torch.Tensor) and x.ndim == 4 + if f is None: + f = torch.ones([1, 1], dtype=torch.float32, device=x.device) + assert isinstance(f, torch.Tensor) and f.ndim in [1, 2] + assert f.dtype == torch.float32 and not f.requires_grad + batch_size, num_channels, in_height, in_width = x.shape + upx, upy = _parse_scaling(up) + downx, downy = _parse_scaling(down) + padx0, padx1, pady0, pady1 = _parse_padding(padding) + + # Upsample by inserting zeros. + x = x.reshape([batch_size, num_channels, in_height, 1, in_width, 1]) + x = torch.nn.functional.pad(x, [0, upx - 1, 0, 0, 0, upy - 1]) + x = x.reshape([batch_size, num_channels, in_height * upy, in_width * upx]) + + # Pad or crop. + x = torch.nn.functional.pad(x, [max(padx0, 0), max(padx1, 0), max(pady0, 0), + max(pady1, 0)]) + x = x[:, :, max(-pady0, 0): x.shape[2] - max(-pady1, 0), + max(-padx0, 0): x.shape[3] - max(-padx1, 0)] + + # Setup filter. + f = f * (gain ** (f.ndim / 2)) + f = f.to(x.dtype) + if not flip_filter: + f = f.flip(list(range(f.ndim))) + + # Convolve with the filter. + f = f[np.newaxis, np.newaxis].repeat([num_channels, 1] + [1] * f.ndim) + if f.ndim == 4: + x = F.conv2d(input=x, weight=f, groups=num_channels) + else: + x = F.conv2d(input=x, weight=f.unsqueeze(2), groups=num_channels) + x = F.conv2d(input=x, weight=f.unsqueeze(3), groups=num_channels) + + # Downsample by throwing away pixels. + x = x[:, :, ::downy, ::downx] + return x + + +# ---------------------------------------------------------------------------- + +_upfirdn2d_cuda_cache = dict() + + +def _upfirdn2d_cuda(up=1, down=1, padding=0, flip_filter=False, gain=1): + """Fast CUDA implementation of `upfirdn2d()` using custom ops. + """ + # Parse arguments. + upx, upy = _parse_scaling(up) + downx, downy = _parse_scaling(down) + padx0, padx1, pady0, pady1 = _parse_padding(padding) + + # Lookup from cache. + key = ( + upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain) + if key in _upfirdn2d_cuda_cache: + return _upfirdn2d_cuda_cache[key] + + # Forward op. 
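+    # A dedicated autograd.Function subclass is created (and cached) for each
+    # (up, down, padding, flip_filter, gain) combination; its backward pass is
+    # simply another upfirdn2d call with up/down swapped, the padding adjusted,
+    # and the filter flip inverted.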
+ class Upfirdn2dCuda(torch.autograd.Function): + @staticmethod + def forward(ctx, x, f): # pylint: disable=arguments-differ + assert isinstance(x, torch.Tensor) and x.ndim == 4 + if f is None: + f = torch.ones([1, 1], dtype=torch.float32, device=x.device) + assert isinstance(f, torch.Tensor) and f.ndim in [1, 2] + y = x + if f.ndim == 2: + y = upfirdn2d_cuda.upfirdn2d_cuda(y, f, upx, upy, downx, downy, padx0, + padx1, pady0, pady1, flip_filter, gain) + else: + y = upfirdn2d_cuda.upfirdn2d_cuda(y, f.unsqueeze(0), upx, 1, downx, 1, + padx0, padx1, 0, 0, flip_filter, + np.sqrt(gain)) + y = upfirdn2d_cuda.upfirdn2d_cuda(y, f.unsqueeze(1), 1, upy, 1, downy, 0, 0, + pady0, pady1, flip_filter, np.sqrt(gain)) + ctx.save_for_backward(f) + ctx.x_shape = x.shape + return y + + @staticmethod + def backward(ctx, dy): # pylint: disable=arguments-differ + f, = ctx.saved_tensors + _, _, ih, iw = ctx.x_shape + _, _, oh, ow = dy.shape + fw, fh = _get_filter_size(f) + p = [ + fw - padx0 - 1, + iw * upx - ow * downx + padx0 - upx + 1, + fh - pady0 - 1, + ih * upy - oh * downy + pady0 - upy + 1, + ] + dx = None + df = None + + if ctx.needs_input_grad[0]: + dx = _upfirdn2d_cuda(up=down, down=up, padding=p, flip_filter=(not flip_filter), gain=gain).apply(dy, f) + + assert not ctx.needs_input_grad[1] + return dx, df + + # Add to cache. + _upfirdn2d_cuda_cache[key] = Upfirdn2dCuda + return Upfirdn2dCuda + + +# ---------------------------------------------------------------------------- + +def filter2d(x, f, padding=0, flip_filter=False, gain=1, impl='cuda'): + r"""Filter a batch of 2D images using the given 2D FIR filter. + + By default, the result is padded so that its shape matches the input. + User-specified padding is applied on top of that, with negative values + indicating cropping. Pixels outside the image are assumed to be zero. + + Args: + x: Float32/float64/float16 input tensor of the shape + `[batch_size, num_channels, in_height, in_width]`. + f: Float32 FIR filter of the shape + `[filter_height, filter_width]` (non-separable), + `[filter_taps]` (separable), or + `None` (identity). + padding: Padding with respect to the output. Can be a single number or a + list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` + (default: 0). + flip_filter: False = convolution, True = correlation (default: False). + gain: Overall scaling factor for signal magnitude (default: 1). + impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). + + Returns: + Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. + """ + padx0, padx1, pady0, pady1 = _parse_padding(padding) + fw, fh = _get_filter_size(f) + p = [ + padx0 + fw // 2, + padx1 + (fw - 1) // 2, + pady0 + fh // 2, + pady1 + (fh - 1) // 2, + ] + return upfirdn2d(x, f, padding=p, flip_filter=flip_filter, gain=gain, impl=impl) + + +# ---------------------------------------------------------------------------- + +def upsample2d(x, f, up=2, padding=0, flip_filter=False, gain=1, impl='cuda'): + r"""Upsample a batch of 2D images using the given 2D FIR filter. + + By default, the result is padded so that its shape is a multiple of the input. + User-specified padding is applied on top of that, with negative values + indicating cropping. Pixels outside the image are assumed to be zero. + + Args: + x: Float32/float64/float16 input tensor of the shape + `[batch_size, num_channels, in_height, in_width]`. 
+ f: Float32 FIR filter of the shape + `[filter_height, filter_width]` (non-separable), + `[filter_taps]` (separable), or + `None` (identity). + up: Integer upsampling factor. Can be a single int or a list/tuple + `[x, y]` (default: 1). + padding: Padding with respect to the output. Can be a single number or a + list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` + (default: 0). + flip_filter: False = convolution, True = correlation (default: False). + gain: Overall scaling factor for signal magnitude (default: 1). + impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). + + Returns: + Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. + """ + upx, upy = _parse_scaling(up) + padx0, padx1, pady0, pady1 = _parse_padding(padding) + fw, fh = _get_filter_size(f) + p = [ + padx0 + (fw + upx - 1) // 2, + padx1 + (fw - upx) // 2, + pady0 + (fh + upy - 1) // 2, + pady1 + (fh - upy) // 2, + ] + return upfirdn2d(x, f, up=up, padding=p, flip_filter=flip_filter, gain=gain * upx * upy, impl=impl) + + +# ---------------------------------------------------------------------------- + +def downsample2d(x, f, down=2, padding=0, flip_filter=False, gain=1, + impl='cuda'): + r"""Downsample a batch of 2D images using the given 2D FIR filter. + + By default, the result is padded so that its shape is a fraction of the input. + User-specified padding is applied on top of that, with negative values + indicating cropping. Pixels outside the image are assumed to be zero. + + Args: + x: Float32/float64/float16 input tensor of the shape + `[batch_size, num_channels, in_height, in_width]`. + f: Float32 FIR filter of the shape + `[filter_height, filter_width]` (non-separable), + `[filter_taps]` (separable), or + `None` (identity). + down: Integer downsampling factor. Can be a single int or a list/tuple + `[x, y]` (default: 1). + padding: Padding with respect to the input. Can be a single number or a + list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` + (default: 0). + flip_filter: False = convolution, True = correlation (default: False). + gain: Overall scaling factor for signal magnitude (default: 1). + impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). + + Returns: + Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. + """ + downx, downy = _parse_scaling(down) + padx0, padx1, pady0, pady1 = _parse_padding(padding) + fw, fh = _get_filter_size(f) + p = [ + padx0 + (fw - downx + 1) // 2, + padx1 + (fw - downx) // 2, + pady0 + (fh - downy + 1) // 2, + pady1 + (fh - downy) // 2, + ] + return upfirdn2d(x, f, down=down, padding=p, flip_filter=flip_filter, gain=gain, impl=impl) + +# ---------------------------------------------------------------------------- diff --git a/imaginaire/trainers/__init__.py b/imaginaire/trainers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..13acefe2181136b1629ec31f9d122fb46bf26780 --- /dev/null +++ b/imaginaire/trainers/__init__.py @@ -0,0 +1,4 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md diff --git a/imaginaire/trainers/base.py b/imaginaire/trainers/base.py new file mode 100644 index 0000000000000000000000000000000000000000..461aed4b7f092ac6a376d1db7b28ef0fd646901f --- /dev/null +++ b/imaginaire/trainers/base.py @@ -0,0 +1,982 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +import json +import os +import time + +import torch +import torchvision +import wandb +from torch.cuda.amp import GradScaler, autocast +from tqdm import tqdm + +from imaginaire.utils.distributed import is_master, master_only +from imaginaire.utils.distributed import master_only_print as print +from imaginaire.utils.io import save_pilimage_in_jpeg +from imaginaire.utils.meters import Meter +from imaginaire.utils.misc import to_cuda, to_device, requires_grad, to_channels_last +from imaginaire.utils.model_average import (calibrate_batch_norm_momentum, + reset_batch_norm) +from imaginaire.utils.visualization import tensor2pilimage + + +class BaseTrainer(object): + r"""Base trainer. We expect that all trainers inherit this class. + + Args: + cfg (obj): Global configuration. + net_G (obj): Generator network. + net_D (obj): Discriminator network. + opt_G (obj): Optimizer for the generator network. + opt_D (obj): Optimizer for the discriminator network. + sch_G (obj): Scheduler for the generator optimizer. + sch_D (obj): Scheduler for the discriminator optimizer. + train_data_loader (obj): Train data loader. + val_data_loader (obj): Validation data loader. + """ + + def __init__(self, + cfg, + net_G, + net_D, + opt_G, + opt_D, + sch_G, + sch_D, + train_data_loader, + val_data_loader): + super(BaseTrainer, self).__init__() + print('Setup trainer.') + + # Initialize models and data loaders. + self.cfg = cfg + self.net_G = net_G + if cfg.trainer.model_average_config.enabled: + # Two wrappers (DDP + model average). + self.net_G_module = self.net_G.module.module + else: + # One wrapper (DDP) + self.net_G_module = self.net_G.module + self.val_data_loader = val_data_loader + self.is_inference = train_data_loader is None + self.net_D = net_D + self.opt_G = opt_G + self.opt_D = opt_D + self.sch_G = sch_G + self.sch_D = sch_D + self.train_data_loader = train_data_loader + if self.cfg.trainer.channels_last: + self.net_G = self.net_G.to(memory_format=torch.channels_last) + self.net_D = self.net_D.to(memory_format=torch.channels_last) + + # Initialize amp. + if self.cfg.trainer.amp_config.enabled: + print("Using automatic mixed precision training.") + self.scaler_G = GradScaler(**vars(self.cfg.trainer.amp_config)) + self.scaler_D = GradScaler(**vars(self.cfg.trainer.amp_config)) + # In order to check whether the discriminator/generator has + # skipped the last parameter update due to gradient overflow. + self.last_step_count_G = 0 + self.last_step_count_D = 0 + self.skipped_G = False + self.skipped_D = False + + # Initialize data augmentation policy. + self.aug_policy = cfg.trainer.aug_policy + print("Augmentation policy: {}".format(self.aug_policy)) + + # Initialize loss functions. + # All loss names have weights. Some have criterion modules. + # Mapping from loss names to criterion modules. + self.criteria = torch.nn.ModuleDict() + # Mapping from loss names to loss weights. 
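+        # (Both mappings are filled in by the subclass via _init_loss(),
+        # which is called right below.)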
+ self.weights = dict() + self.losses = dict(gen_update=dict(), dis_update=dict()) + self.gen_losses = self.losses['gen_update'] + self.dis_losses = self.losses['dis_update'] + self._init_loss(cfg) + for loss_name, loss_weight in self.weights.items(): + print("Loss {:<20} Weight {}".format(loss_name, loss_weight)) + if loss_name in self.criteria.keys() and \ + self.criteria[loss_name] is not None: + self.criteria[loss_name].to('cuda') + + if self.is_inference: + # The initialization steps below can be skipped during inference. + return + + # Initialize logging attributes. + self.current_iteration = 0 + self.current_epoch = 0 + self.start_iteration_time = None + self.start_epoch_time = None + self.elapsed_iteration_time = 0 + self.time_iteration = None + self.time_epoch = None + self.best_fid = None + if self.cfg.speed_benchmark: + self.accu_gen_forw_iter_time = 0 + self.accu_gen_loss_iter_time = 0 + self.accu_gen_back_iter_time = 0 + self.accu_gen_step_iter_time = 0 + self.accu_gen_avg_iter_time = 0 + self.accu_dis_forw_iter_time = 0 + self.accu_dis_loss_iter_time = 0 + self.accu_dis_back_iter_time = 0 + self.accu_dis_step_iter_time = 0 + + # Initialize tensorboard and hparams. + self._init_tensorboard() + self._init_hparams() + + # Initialize validation parameters. + self.val_sample_size = getattr(cfg.trainer, 'val_sample_size', 50000) + self.kid_num_subsets = getattr(cfg.trainer, 'kid_num_subsets', 10) + self.kid_subset_size = self.val_sample_size // self.kid_num_subsets + self.metrics_path = os.path.join(torch.hub.get_dir(), 'metrics') + self.best_metrics = {} + self.eval_networks = getattr(cfg.trainer, 'eval_network', ['clean_inception']) + if self.cfg.metrics_iter is None: + self.cfg.metrics_iter = self.cfg.snapshot_save_iter + if self.cfg.metrics_epoch is None: + self.cfg.metrics_epoch = self.cfg.snapshot_save_epoch + + # AWS credentials. + if hasattr(cfg, 'aws_credentials_file'): + with open(cfg.aws_credentials_file) as fin: + self.credentials = json.load(fin) + else: + self.credentials = None + + if 'TORCH_HOME' not in os.environ: + os.environ['TORCH_HOME'] = os.path.join( + os.environ['HOME'], ".cache") + + def _init_tensorboard(self): + r"""Initialize the tensorboard. Different algorithms might require + different performance metrics. Hence, custom tensorboard + initialization might be necessary. + """ + # Logging frequency: self.cfg.logging_iter + self.meters = {} + + # Logging frequency: self.cfg.snapshot_save_iter + self.metric_meters = {} + + # Logging frequency: self.cfg.image_display_iter + self.image_meter = Meter('images', reduce=False) + + def _init_hparams(self): + r"""Initialize a dictionary of hyperparameters that we want to monitor + in the HParams dashboard in tensorBoard. + """ + self.hparam_dict = {} + + def _write_tensorboard(self): + r"""Write values to tensorboard. By default, we will log the time used + per iteration, time used per epoch, generator learning rate, and + discriminator learning rate. We will log all the losses as well as + custom meters. + """ + # Logs that are shared by all models. + self._write_to_meters({'time/iteration': self.time_iteration, + 'time/epoch': self.time_epoch, + 'optim/gen_lr': self.sch_G.get_last_lr()[0], + 'optim/dis_lr': self.sch_D.get_last_lr()[0]}, + self.meters, + reduce=False) + # Logs for loss values. Different models have different losses. + self._write_loss_meters() + # Other custom logs. 
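+        # (Child classes can override _write_custom_meters() to add
+        # model-specific logs at this point.)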
+ self._write_custom_meters() + + def _write_loss_meters(self): + r"""Write all loss values to tensorboard.""" + for update, losses in self.losses.items(): + # update is 'gen_update' or 'dis_update'. + assert update == 'gen_update' or update == 'dis_update' + for loss_name, loss in losses.items(): + if loss is not None: + full_loss_name = update + '/' + loss_name + if full_loss_name not in self.meters.keys(): + # Create a new meter if it doesn't exist. + self.meters[full_loss_name] = Meter( + full_loss_name, reduce=True) + self.meters[full_loss_name].write(loss.item()) + + def _write_custom_meters(self): + r"""Dummy member function to be overloaded by the child class. + In the child class, you can write down whatever you want to track. + """ + pass + + @staticmethod + def _write_to_meters(data, meters, reduce=True): + r"""Write values to meters.""" + if reduce or is_master(): + for key, value in data.items(): + if key not in meters: + meters[key] = Meter(key, reduce=reduce) + meters[key].write(value) + + def _flush_meters(self, meters): + r"""Flush all meters using the current iteration.""" + for meter in meters.values(): + meter.flush(self.current_iteration) + + def _pre_save_checkpoint(self): + r"""Implement the things you want to do before saving a checkpoint. + For example, you can compute the K-mean features (pix2pixHD) before + saving the model weights to a checkpoint. + """ + pass + + def save_checkpoint(self, current_epoch, current_iteration): + r"""Save network weights, optimizer parameters, scheduler parameters + to a checkpoint. + """ + self._pre_save_checkpoint() + _save_checkpoint(self.cfg, + self.net_G, self.net_D, + self.opt_G, self.opt_D, + self.sch_G, self.sch_D, + current_epoch, current_iteration) + + def load_checkpoint(self, cfg, checkpoint_path, resume=None, load_sch=True): + r"""Load network weights, optimizer parameters, scheduler parameters + from a checkpoint. + + Args: + cfg (obj): Global configuration. + checkpoint_path (str): Path to the checkpoint. + resume (bool or None): If not ``None``, will determine whether or + not to load optimizers in addition to network weights. + """ + if os.path.exists(checkpoint_path): + # If checkpoint_path exists, we will load its weights to + # initialize our network. + if resume is None: + resume = False + elif os.path.exists(os.path.join(cfg.logdir, 'latest_checkpoint.txt')): + # This is for resuming the training from the previously saved + # checkpoint. + fn = os.path.join(cfg.logdir, 'latest_checkpoint.txt') + with open(fn, 'r') as f: + line = f.read().splitlines() + checkpoint_path = os.path.join(cfg.logdir, line[0].split(' ')[-1]) + if resume is None: + resume = True + else: + # checkpoint not found and not specified. We will train + # everything from scratch. 
+ current_epoch = 0 + current_iteration = 0 + print('No checkpoint found.') + resume = False + return resume, current_epoch, current_iteration + # Load checkpoint + checkpoint = torch.load( + checkpoint_path, map_location=lambda storage, loc: storage) + current_epoch = 0 + current_iteration = 0 + if resume: + self.net_G.load_state_dict(checkpoint['net_G'], strict=self.cfg.trainer.strict_resume) + if not self.is_inference: + self.net_D.load_state_dict(checkpoint['net_D'], strict=self.cfg.trainer.strict_resume) + if 'opt_G' in checkpoint: + current_epoch = checkpoint['current_epoch'] + current_iteration = checkpoint['current_iteration'] + self.opt_G.load_state_dict(checkpoint['opt_G']) + self.opt_D.load_state_dict(checkpoint['opt_D']) + if load_sch: + self.sch_G.load_state_dict(checkpoint['sch_G']) + self.sch_D.load_state_dict(checkpoint['sch_D']) + else: + if self.cfg.gen_opt.lr_policy.iteration_mode: + self.sch_G.last_epoch = current_iteration + else: + self.sch_G.last_epoch = current_epoch + if self.cfg.dis_opt.lr_policy.iteration_mode: + self.sch_D.last_epoch = current_iteration + else: + self.sch_D.last_epoch = current_epoch + print('Load from: {}'.format(checkpoint_path)) + else: + print('Load network weights only.') + else: + try: + self.net_G.load_state_dict(checkpoint['net_G'], strict=self.cfg.trainer.strict_resume) + if 'net_D' in checkpoint: + self.net_D.load_state_dict(checkpoint['net_D'], strict=self.cfg.trainer.strict_resume) + except Exception: + if self.cfg.trainer.model_average_config.enabled: + net_G_module = self.net_G.module.module + else: + net_G_module = self.net_G.module + if hasattr(net_G_module, 'load_pretrained_network'): + net_G_module.load_pretrained_network(self.net_G, checkpoint['net_G']) + print('Load generator weights only.') + else: + raise ValueError('Checkpoint cannot be loaded.') + + print('Done with loading the checkpoint.') + return resume, current_epoch, current_iteration + + def start_of_epoch(self, current_epoch): + r"""Things to do before an epoch. + + Args: + current_epoch (int): Current number of epoch. + """ + self._start_of_epoch(current_epoch) + self.current_epoch = current_epoch + self.start_epoch_time = time.time() + + def start_of_iteration(self, data, current_iteration): + r"""Things to do before an iteration. + + Args: + data (dict): Data used for the current iteration. + current_iteration (int): Current number of iteration. + """ + data = self._start_of_iteration(data, current_iteration) + data = to_cuda(data) + if self.cfg.trainer.channels_last: + data = to_channels_last(data) + self.current_iteration = current_iteration + if not self.is_inference: + self.net_D.train() + self.net_G.train() + # torch.cuda.synchronize() + self.start_iteration_time = time.time() + return data + + def end_of_iteration(self, data, current_epoch, current_iteration): + r"""Things to do after an iteration. + + Args: + data (dict): Data used for the current iteration. + current_epoch (int): Current number of epoch. + current_iteration (int): Current number of iteration. + """ + self.current_iteration = current_iteration + self.current_epoch = current_epoch + # Update the learning rate policy for the generator if operating in the + # iteration mode. + if self.cfg.gen_opt.lr_policy.iteration_mode: + self.sch_G.step() + # Update the learning rate policy for the discriminator if operating in + # the iteration mode. 
+ if self.cfg.dis_opt.lr_policy.iteration_mode: + self.sch_D.step() + + # Accumulate time + # torch.cuda.synchronize() + self.elapsed_iteration_time += time.time() - self.start_iteration_time + # Logging. + if current_iteration % self.cfg.logging_iter == 0: + ave_t = self.elapsed_iteration_time / self.cfg.logging_iter + self.time_iteration = ave_t + print('Iteration: {}, average iter time: ' + '{:6f}.'.format(current_iteration, ave_t)) + self.elapsed_iteration_time = 0 + + if self.cfg.speed_benchmark: + # Below code block only needed when analyzing computation + # bottleneck. + print('\tGenerator FWD time {:6f}'.format( + self.accu_gen_forw_iter_time / self.cfg.logging_iter)) + print('\tGenerator LOS time {:6f}'.format( + self.accu_gen_loss_iter_time / self.cfg.logging_iter)) + print('\tGenerator BCK time {:6f}'.format( + self.accu_gen_back_iter_time / self.cfg.logging_iter)) + print('\tGenerator STP time {:6f}'.format( + self.accu_gen_step_iter_time / self.cfg.logging_iter)) + print('\tGenerator AVG time {:6f}'.format( + self.accu_gen_avg_iter_time / self.cfg.logging_iter)) + + print('\tDiscriminator FWD time {:6f}'.format( + self.accu_dis_forw_iter_time / self.cfg.logging_iter)) + print('\tDiscriminator LOS time {:6f}'.format( + self.accu_dis_loss_iter_time / self.cfg.logging_iter)) + print('\tDiscriminator BCK time {:6f}'.format( + self.accu_dis_back_iter_time / self.cfg.logging_iter)) + print('\tDiscriminator STP time {:6f}'.format( + self.accu_dis_step_iter_time / self.cfg.logging_iter)) + + print('{:6f}'.format(ave_t)) + + self.accu_gen_forw_iter_time = 0 + self.accu_gen_loss_iter_time = 0 + self.accu_gen_back_iter_time = 0 + self.accu_gen_step_iter_time = 0 + self.accu_gen_avg_iter_time = 0 + self.accu_dis_forw_iter_time = 0 + self.accu_dis_loss_iter_time = 0 + self.accu_dis_back_iter_time = 0 + self.accu_dis_step_iter_time = 0 + + self._end_of_iteration(data, current_epoch, current_iteration) + + # Save everything to the checkpoint. + if current_iteration % self.cfg.snapshot_save_iter == 0: + if current_iteration >= self.cfg.snapshot_save_start_iter: + self.save_checkpoint(current_epoch, current_iteration) + + # Compute metrics. + if current_iteration % self.cfg.metrics_iter == 0: + self.save_image(self._get_save_path('images', 'jpg'), data) + self.write_metrics() + + # Compute image to be saved. + elif current_iteration % self.cfg.image_save_iter == 0: + self.save_image(self._get_save_path('images', 'jpg'), data) + elif current_iteration % self.cfg.image_display_iter == 0: + image_path = os.path.join(self.cfg.logdir, 'images', 'current.jpg') + self.save_image(image_path, data) + + # Logging. + self._write_tensorboard() + if current_iteration % self.cfg.logging_iter == 0: + # Write all logs to tensorboard. + self._flush_meters(self.meters) + + from torch.distributed import barrier + import torch.distributed as dist + if dist.is_initialized(): + barrier() + + def end_of_epoch(self, data, current_epoch, current_iteration): + r"""Things to do after an epoch. + + Args: + data (dict): Data used for the current iteration. + + current_epoch (int): Current number of epoch. + current_iteration (int): Current number of iteration. + """ + # Update the learning rate policy for the generator if operating in the + # epoch mode. + self.current_iteration = current_iteration + self.current_epoch = current_epoch + if not self.cfg.gen_opt.lr_policy.iteration_mode: + self.sch_G.step() + # Update the learning rate policy for the discriminator if operating + # in the epoch mode. 
+ if not self.cfg.dis_opt.lr_policy.iteration_mode: + self.sch_D.step() + elapsed_epoch_time = time.time() - self.start_epoch_time + # Logging. + print('Epoch: {}, total time: {:6f}.'.format(current_epoch, + elapsed_epoch_time)) + self.time_epoch = elapsed_epoch_time + self._end_of_epoch(data, current_epoch, current_iteration) + + # Save everything to the checkpoint. + if current_iteration % self.cfg.snapshot_save_iter == 0: + if current_epoch >= self.cfg.snapshot_save_start_epoch: + self.save_checkpoint(current_epoch, current_iteration) + + # Compute metrics. + if current_iteration % self.cfg.metrics_iter == 0: + self.save_image(self._get_save_path('images', 'jpg'), data) + self.write_metrics() + + def pre_process(self, data): + r"""Custom data pre-processing function. Utilize this function if you + need to preprocess your data before sending it to the generator and + discriminator. + + Args: + data (dict): Data used for the current iteration. + """ + + def recalculate_batch_norm_statistics(self, data_loader, averaged=True): + r"""Update the statistics in the moving average model. + + Args: + data_loader (torch.utils.data.DataLoader): Data loader for + estimating the statistics. + averaged (Boolean): True/False, we recalculate batch norm statistics for EMA/regular + """ + if not self.cfg.trainer.model_average_config.enabled: + return + if averaged: + net_G = self.net_G.module.averaged_model + else: + net_G = self.net_G_module + model_average_iteration = \ + self.cfg.trainer.model_average_config.num_batch_norm_estimation_iterations + if model_average_iteration == 0: + return + with torch.no_grad(): + # Accumulate bn stats.. + net_G.train() + # Reset running stats. + net_G.apply(reset_batch_norm) + for cal_it, cal_data in enumerate(data_loader): + if cal_it >= model_average_iteration: + print('Done with {} iterations of updating batch norm ' + 'statistics'.format(model_average_iteration)) + break + cal_data = to_device(cal_data, 'cuda') + cal_data = self.pre_process(cal_data) + # Averaging over all batches + net_G.apply(calibrate_batch_norm_momentum) + net_G(cal_data) + + def save_image(self, path, data): + r"""Compute visualization images and save them to the disk. + + Args: + path (str): Location of the file. + data (dict): Data used for the current iteration. + """ + self.net_G.eval() + vis_images = self._get_visualizations(data) + if is_master() and vis_images is not None: + vis_images = torch.cat( + [img for img in vis_images if img is not None], dim=3).float() + vis_images = (vis_images + 1) / 2 + print('Save output images to {}'.format(path)) + vis_images.clamp_(0, 1) + os.makedirs(os.path.dirname(path), exist_ok=True) + image_grid = torchvision.utils.make_grid( + vis_images, nrow=1, padding=0, normalize=False) + if self.cfg.trainer.image_to_tensorboard: + self.image_meter.write_image(image_grid, self.current_iteration) + torchvision.utils.save_image(image_grid, path, nrow=1) + wandb.log({os.path.splitext(os.path.basename(path))[0]: [wandb.Image(path)]}) + + def write_metrics(self): + r"""Write metrics to the tensorboard.""" + cur_fid = self._compute_fid() + if cur_fid is not None: + if self.best_fid is not None: + self.best_fid = min(self.best_fid, cur_fid) + else: + self.best_fid = cur_fid + metric_dict = {'FID': cur_fid, 'best_FID': self.best_fid} + self._write_to_meters(metric_dict, self.metric_meters, reduce=False) + self._flush_meters(self.metric_meters) + + def _get_save_path(self, subdir, ext): + r"""Get the image save path. 
+ + Args: + subdir (str): Sub-directory under the main directory for saving + the outputs. + ext (str): Filename extension for the image (e.g., jpg, png, ...). + Return: + (str): image filename to be used to save the visualization results. + """ + subdir_path = os.path.join(self.cfg.logdir, subdir) + if not os.path.exists(subdir_path): + os.makedirs(subdir_path, exist_ok=True) + return os.path.join( + subdir_path, 'epoch_{:05}_iteration_{:09}.{}'.format( + self.current_epoch, self.current_iteration, ext)) + + def _get_outputs(self, net_D_output, real=True): + r"""Return output values. Note that when the gan mode is relativistic. + It will do the difference before returning. + + Args: + net_D_output (dict): + real_outputs (tensor): Real output values. + fake_outputs (tensor): Fake output values. + real (bool): Return real or fake. + """ + + def _get_difference(a, b): + r"""Get difference between two lists of tensors or two tensors. + + Args: + a: list of tensors or tensor + b: list of tensors or tensor + """ + out = list() + for x, y in zip(a, b): + if isinstance(x, list): + res = _get_difference(x, y) + else: + res = x - y + out.append(res) + return out + + if real: + if self.cfg.trainer.gan_relativistic: + return _get_difference(net_D_output['real_outputs'], net_D_output['fake_outputs']) + else: + return net_D_output['real_outputs'] + else: + if self.cfg.trainer.gan_relativistic: + return _get_difference(net_D_output['fake_outputs'], net_D_output['real_outputs']) + else: + return net_D_output['fake_outputs'] + + def _start_of_epoch(self, current_epoch): + r"""Operations to do before starting an epoch. + + Args: + current_epoch (int): Current number of epoch. + """ + pass + + def _start_of_iteration(self, data, current_iteration): + r"""Operations to do before starting an iteration. + + Args: + data (dict): Data used for the current iteration. + current_iteration (int): Current epoch number. + Returns: + (dict): Data used for the current iteration. They might be + processed by the custom _start_of_iteration function. + """ + return data + + def _end_of_iteration(self, data, current_epoch, current_iteration): + r"""Operations to do after an iteration. + + Args: + data (dict): Data used for the current iteration. + current_epoch (int): Current number of epoch. + current_iteration (int): Current epoch number. + """ + pass + + def _end_of_epoch(self, data, current_epoch, current_iteration): + r"""Operations to do after an epoch. + + Args: + data (dict): Data used for the current iteration. + current_epoch (int): Current number of epoch. + current_iteration (int): Current epoch number. + """ + pass + + def _get_visualizations(self, data): + r"""Compute visualization outputs. + + Args: + data (dict): Data used for the current iteration. + """ + return None + + def _compute_fid(self): + r"""FID computation function to be overloaded.""" + return None + + def _init_loss(self, cfg): + r"""Every trainer should implement its own init loss function.""" + raise NotImplementedError + + def gen_update(self, data): + r"""Update the generator. + + Args: + data (dict): Data used for the current iteration. + """ + update_finished = False + while not update_finished: + # Set requires_grad flags. + requires_grad(self.net_G_module, True) + requires_grad(self.net_D, False) + + # Compute the loss. + self._time_before_forward() + with autocast(enabled=self.cfg.trainer.amp_config.enabled): + total_loss = self.gen_forward(data) + if total_loss is None: + return + + # Zero-grad and backpropagate the loss. 
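+            # (scaler_G scales the loss before backward and unscales the
+            # gradients before clipping/stepping; a step skipped due to AMP
+            # overflow is detected below by comparing opt_G._step_count.)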
+ self.opt_G.zero_grad(set_to_none=True) + self._time_before_backward() + self.scaler_G.scale(total_loss).backward() + + # Optionally clip gradient norm. + if hasattr(self.cfg.gen_opt, 'clip_grad_norm'): + self.scaler_G.unscale_(self.opt_G) + total_norm = torch.nn.utils.clip_grad_norm_( + self.net_G_module.parameters(), + self.cfg.gen_opt.clip_grad_norm + ) + self.gen_grad_norm = total_norm + if torch.isfinite(total_norm) and \ + total_norm > self.cfg.gen_opt.clip_grad_norm: + # print(f"Gradient norm of the generator ({total_norm}) " + # f"too large.") + if getattr(self.cfg.gen_opt, 'skip_grad', False): + print(f"Skip gradient update.") + self.opt_G.zero_grad(set_to_none=True) + self.scaler_G.step(self.opt_G) + self.scaler_G.update() + break + # else: + # print(f"Clip gradient norm to " + # f"{self.cfg.gen_opt.clip_grad_norm}.") + + # Perform an optimizer step. + self._time_before_step() + self.scaler_G.step(self.opt_G) + self.scaler_G.update() + # Whether the step above was skipped. + if self.last_step_count_G == self.opt_G._step_count: + print("Generator overflowed!") + if not torch.isfinite(total_loss): + print("Generator loss is not finite. Skip this iteration!") + update_finished = True + else: + self.last_step_count_G = self.opt_G._step_count + update_finished = True + + self._extra_gen_step(data) + + # Update model average. + self._time_before_model_avg() + if self.cfg.trainer.model_average_config.enabled: + self.net_G.module.update_average() + + self._detach_losses() + self._time_before_leave_gen() + + def gen_forward(self, data): + r"""Every trainer should implement its own generator forward.""" + raise NotImplementedError + + def _extra_gen_step(self, data): + pass + + def dis_update(self, data): + r"""Update the discriminator. + + Args: + data (dict): Data used for the current iteration. + """ + update_finished = False + while not update_finished: + # Set requires_grad flags. + requires_grad(self.net_G_module, False) + requires_grad(self.net_D, True) + + # Compute the loss. + self._time_before_forward() + with autocast(enabled=self.cfg.trainer.amp_config.enabled): + total_loss = self.dis_forward(data) + if total_loss is None: + return + + # Zero-grad and backpropagate the loss. + self.opt_D.zero_grad(set_to_none=True) + self._time_before_backward() + self.scaler_D.scale(total_loss).backward() + + # Optionally clip gradient norm. + if hasattr(self.cfg.dis_opt, 'clip_grad_norm'): + self.scaler_D.unscale_(self.opt_D) + total_norm = torch.nn.utils.clip_grad_norm_( + self.net_D.parameters(), self.cfg.dis_opt.clip_grad_norm + ) + self.dis_grad_norm = total_norm + if torch.isfinite(total_norm) and \ + total_norm > self.cfg.dis_opt.clip_grad_norm: + print(f"Gradient norm of the discriminator ({total_norm}) " + f"too large.") + if getattr(self.cfg.dis_opt, 'skip_grad', False): + print(f"Skip gradient update.") + self.opt_D.zero_grad(set_to_none=True) + self.scaler_D.step(self.opt_D) + self.scaler_D.update() + continue + else: + print(f"Clip gradient norm to " + f"{self.cfg.dis_opt.clip_grad_norm}.") + + # Perform an optimizer step. + self._time_before_step() + self.scaler_D.step(self.opt_D) + self.scaler_D.update() + # Whether the step above was skipped. + if self.last_step_count_D == self.opt_D._step_count: + print("Discriminator overflowed!") + if not torch.isfinite(total_loss): + print("Discriminator loss is not finite. 
" + "Skip this iteration!") + update_finished = True + else: + self.last_step_count_D = self.opt_D._step_count + update_finished = True + + self._extra_dis_step(data) + + self._detach_losses() + self._time_before_leave_dis() + + def dis_forward(self, data): + r"""Every trainer should implement its own discriminator forward.""" + raise NotImplementedError + + def _extra_dis_step(self, data): + pass + + def test(self, data_loader, output_dir, inference_args): + r"""Compute results images for a batch of input data and save the + results in the specified folder. + + Args: + data_loader (torch.utils.data.DataLoader): PyTorch dataloader. + output_dir (str): Target location for saving the output image. + """ + if self.cfg.trainer.model_average_config.enabled: + net_G = self.net_G.module.averaged_model + else: + net_G = self.net_G.module + net_G.eval() + + print('# of samples %d' % len(data_loader)) + for it, data in enumerate(tqdm(data_loader)): + data = self.start_of_iteration(data, current_iteration=-1) + with torch.no_grad(): + output_images, file_names = \ + net_G.inference(data, **vars(inference_args)) + for output_image, file_name in zip(output_images, file_names): + fullname = os.path.join(output_dir, file_name + '.jpg') + output_image = tensor2pilimage(output_image.clamp_(-1, 1), + minus1to1_normalized=True) + save_pilimage_in_jpeg(fullname, output_image) + + def _get_total_loss(self, gen_forward): + r"""Return the total loss to be backpropagated. + Args: + gen_forward (bool): If ``True``, backpropagates the generator loss, + otherwise the discriminator loss. + """ + losses = self.gen_losses if gen_forward else self.dis_losses + total_loss = torch.tensor(0., device=torch.device('cuda')) + # Iterates over all possible losses. + for loss_name in self.weights: + # If it is for the current model (gen/dis). + if loss_name in losses: + # Multiply it with the corresponding weight + # and add it to the total loss. + total_loss += losses[loss_name] * self.weights[loss_name] + losses['total'] = total_loss # logging purpose + return total_loss + + def _detach_losses(self): + r"""Detach all logging variables to prevent potential memory leak.""" + for loss_name in self.gen_losses: + self.gen_losses[loss_name] = self.gen_losses[loss_name].detach() + for loss_name in self.dis_losses: + self.dis_losses[loss_name] = self.dis_losses[loss_name].detach() + + def _time_before_forward(self): + r""" + Record time before applying forward. + """ + if self.cfg.speed_benchmark: + torch.cuda.synchronize() + self.forw_time = time.time() + + def _time_before_loss(self): + r""" + Record time before computing loss. + """ + if self.cfg.speed_benchmark: + torch.cuda.synchronize() + self.loss_time = time.time() + + def _time_before_backward(self): + r""" + Record time before applying backward. + """ + if self.cfg.speed_benchmark: + torch.cuda.synchronize() + self.back_time = time.time() + + def _time_before_step(self): + r""" + Record time before updating the weights + """ + if self.cfg.speed_benchmark: + torch.cuda.synchronize() + self.step_time = time.time() + + def _time_before_model_avg(self): + r""" + Record time before applying model average. + """ + if self.cfg.speed_benchmark: + torch.cuda.synchronize() + self.avg_time = time.time() + + def _time_before_leave_gen(self): + r""" + Record forward, backward, loss, and model average time for the + generator update. 
+ """ + if self.cfg.speed_benchmark: + torch.cuda.synchronize() + end_time = time.time() + self.accu_gen_forw_iter_time += self.loss_time - self.forw_time + self.accu_gen_loss_iter_time += self.back_time - self.loss_time + self.accu_gen_back_iter_time += self.step_time - self.back_time + self.accu_gen_step_iter_time += self.avg_time - self.step_time + self.accu_gen_avg_iter_time += end_time - self.avg_time + + def _time_before_leave_dis(self): + r""" + Record forward, backward, loss time for the discriminator update. + """ + if self.cfg.speed_benchmark: + torch.cuda.synchronize() + end_time = time.time() + self.accu_dis_forw_iter_time += self.loss_time - self.forw_time + self.accu_dis_loss_iter_time += self.back_time - self.loss_time + self.accu_dis_back_iter_time += self.step_time - self.back_time + self.accu_dis_step_iter_time += end_time - self.step_time + + +@master_only +def _save_checkpoint(cfg, + net_G, net_D, + opt_G, opt_D, + sch_G, sch_D, + current_epoch, current_iteration): + r"""Save network weights, optimizer parameters, scheduler parameters + in the checkpoint. + + Args: + cfg (obj): Global configuration. + net_D (obj): Discriminator network. + opt_G (obj): Optimizer for the generator network. + opt_D (obj): Optimizer for the discriminator network. + sch_G (obj): Scheduler for the generator optimizer. + sch_D (obj): Scheduler for the discriminator optimizer. + current_epoch (int): Current epoch. + current_iteration (int): Current iteration. + """ + latest_checkpoint_path = 'epoch_{:05}_iteration_{:09}_checkpoint.pt'.format( + current_epoch, current_iteration) + save_path = os.path.join(cfg.logdir, latest_checkpoint_path) + torch.save( + { + 'net_G': net_G.state_dict(), + 'net_D': net_D.state_dict(), + 'opt_G': opt_G.state_dict(), + 'opt_D': opt_D.state_dict(), + 'sch_G': sch_G.state_dict(), + 'sch_D': sch_D.state_dict(), + 'current_epoch': current_epoch, + 'current_iteration': current_iteration, + }, + save_path, + ) + fn = os.path.join(cfg.logdir, 'latest_checkpoint.txt') + with open(fn, 'wt') as f: + f.write('latest_checkpoint: %s' % latest_checkpoint_path) + print('Save checkpoint to {}'.format(save_path)) + return save_path diff --git a/imaginaire/trainers/fs_vid2vid.py b/imaginaire/trainers/fs_vid2vid.py new file mode 100644 index 0000000000000000000000000000000000000000..6f30b4dd588667b8cef1e433c5503f9ac419e190 --- /dev/null +++ b/imaginaire/trainers/fs_vid2vid.py @@ -0,0 +1,292 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +import os + +import imageio +import numpy as np +import torch +from tqdm import tqdm + + +from imaginaire.model_utils.fs_vid2vid import (concat_frames, get_fg_mask, + pre_process_densepose, + random_roll) +from imaginaire.model_utils.pix2pixHD import get_optimizer_with_params +from imaginaire.trainers.vid2vid import Trainer as vid2vidTrainer +from imaginaire.utils.distributed import is_master +from imaginaire.utils.distributed import master_only_print as print +from imaginaire.utils.misc import to_cuda +from imaginaire.utils.visualization import tensor2flow, tensor2im + + +class Trainer(vid2vidTrainer): + r"""Initialize vid2vid trainer. + + Args: + cfg (obj): Global configuration. + net_G (obj): Generator network. + net_D (obj): Discriminator network. + opt_G (obj): Optimizer for the generator network. + opt_D (obj): Optimizer for the discriminator network. 
+ sch_G (obj): Scheduler for the generator optimizer. + sch_D (obj): Scheduler for the discriminator optimizer. + train_data_loader (obj): Train data loader. + val_data_loader (obj): Validation data loader. + """ + + def __init__(self, cfg, net_G, net_D, opt_G, opt_D, sch_G, sch_D, + train_data_loader, val_data_loader): + super(Trainer, self).__init__(cfg, net_G, net_D, opt_G, + opt_D, sch_G, sch_D, + train_data_loader, val_data_loader) + + def _start_of_iteration(self, data, current_iteration): + r"""Things to do before an iteration. + + Args: + data (dict): Data used for the current iteration. + current_iteration (int): Current number of iteration. + """ + data = self.pre_process(data) + return to_cuda(data) + + def pre_process(self, data): + r"""Do any data pre-processing here. + + Args: + data (dict): Data used for the current iteration. + """ + data_cfg = self.cfg.data + if hasattr(data_cfg, 'for_pose_dataset') and \ + ('pose_maps-densepose' in data_cfg.input_labels): + pose_cfg = data_cfg.for_pose_dataset + data['label'] = pre_process_densepose(pose_cfg, data['label'], + self.is_inference) + data['few_shot_label'] = pre_process_densepose( + pose_cfg, data['few_shot_label'], self.is_inference) + return data + + def get_test_output_images(self, data): + r"""Get the visualization output of test function. + + Args: + data (dict): Training data at the current iteration. + """ + vis_images = [ + tensor2im(data['few_shot_images'][:, 0]), + self.visualize_label(data['label'][:, -1]), + tensor2im(data['images'][:, -1]), + tensor2im(self.net_G_output['fake_images']), + ] + return vis_images + + def get_data_t(self, data, net_G_output, data_prev, t): + r"""Get data at current time frame given the sequence of data. + + Args: + data (dict): Training data for current iteration. + net_G_output (dict): Output of the generator (for previous frame). + data_prev (dict): Data for previous frame. + t (int): Current time. + """ + label = data['label'][:, t] if 'label' in data else None + image = data['images'][:, t] + + if data_prev is not None: + nG = self.cfg.data.num_frames_G + prev_labels = concat_frames(data_prev['prev_labels'], + data_prev['label'], nG - 1) + prev_images = concat_frames( + data_prev['prev_images'], + net_G_output['fake_images'].detach(), nG - 1) + else: + prev_labels = prev_images = None + + data_t = dict() + data_t['label'] = label + data_t['image'] = image + data_t['ref_labels'] = data['few_shot_label'] if 'few_shot_label' \ + in data else None + data_t['ref_images'] = data['few_shot_images'] + data_t['prev_labels'] = prev_labels + data_t['prev_images'] = prev_images + data_t['real_prev_image'] = data['images'][:, t - 1] if t > 0 else None + + # if 'landmarks_xy' in data: + # data_t['landmarks_xy'] = data['landmarks_xy'][:, t] + # data_t['ref_landmarks_xy'] = data['few_shot_landmarks_xy'] + return data_t + + def post_process(self, data, net_G_output): + r"""Do any postprocessing of the data / output here. + + Args: + data (dict): Training data at the current iteration. + net_G_output (dict): Output of the generator. + """ + if self.has_fg: + fg_mask = get_fg_mask(data['label'], self.has_fg) + if net_G_output['fake_raw_images'] is not None: + net_G_output['fake_raw_images'] = \ + net_G_output['fake_raw_images'] * fg_mask + + return data, net_G_output + + def test(self, test_data_loader, root_output_dir, inference_args): + r"""Run inference on the specified sequence. + + Args: + test_data_loader (object): Test data loader. + root_output_dir (str): Location to dump outputs. 
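`get_data_t` above conditions the generator on the last `num_frames_G - 1` labels and generated frames via `concat_frames`. A minimal sketch of that rolling-buffer idea, assuming frames are stacked along a time dimension (the helper name and shapes here are illustrative, not the library's implementation):

```python
import torch

def keep_last_frames(prev, new, n):
    """Append `new` (N, 1, C, H, W) to `prev` and keep only the last n frames."""
    buf = new if prev is None else torch.cat([prev, new], dim=1)
    return buf[:, -n:]

buf, n_prev = None, 2                       # e.g. num_frames_G - 1
for t in range(5):
    frame = torch.randn(1, 1, 3, 8, 8)      # dummy frame
    buf = keep_last_frames(buf, frame, n_prev)
print(buf.shape)                            # torch.Size([1, 2, 3, 8, 8])
```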
+ inference_args (optional): Optional args. + """ + self.reset() + test_data_loader.dataset.set_sequence_length(0) + # Set the inference sequences. + test_data_loader.dataset.set_inference_sequence_idx( + inference_args.driving_seq_index, + inference_args.few_shot_seq_index, + inference_args.few_shot_frame_index) + + video = [] + for idx, data in enumerate(tqdm(test_data_loader)): + key = data['key']['images'][0][0] + filename = key.split('/')[-1] + + # Create output dir for this sequence. + if idx == 0: + seq_name = '%03d' % inference_args.driving_seq_index + output_dir = os.path.join(root_output_dir, seq_name) + os.makedirs(output_dir, exist_ok=True) + video_path = output_dir + + # Get output and save images. + data['img_name'] = filename + data = self.start_of_iteration(data, current_iteration=-1) + output = self.test_single(data, output_dir, inference_args) + video.append(output) + + # Save output as mp4. + imageio.mimsave(video_path + '.mp4', video, fps=15) + + def save_image(self, path, data): + r"""Save the output images to path. + Note when the generate_raw_output is FALSE. Then, + first_net_G_output['fake_raw_images'] is None and will not be displayed. + In model average mode, we will plot the flow visualization twice. + + Args: + path (str): Save path. + data (dict): Training data for current iteration. + """ + self.net_G.eval() + if self.cfg.trainer.model_average_config.enabled: + self.net_G.module.averaged_model.eval() + + self.net_G_output = None + with torch.no_grad(): + first_net_G_output, last_net_G_output, _ = self.gen_frames(data) + if self.cfg.trainer.model_average_config.enabled: + first_net_G_output_avg, last_net_G_output_avg, _ = \ + self.gen_frames(data, use_model_average=True) + + def get_images(data, net_G_output, return_first_frame=True, + for_model_average=False): + r"""Get the ourput images to save. + + Args: + data (dict): Training data for current iteration. + net_G_output (dict): Generator output. + return_first_frame (bool): Return output for first frame in the + sequence. + for_model_average (bool): For model average output. + Return: + vis_images (list of numpy arrays): Visualization images. 
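The test loop above accumulates per-frame outputs in a list and writes them out once with `imageio.mimsave`. A self-contained version of that last step, assuming the ffmpeg backend (`imageio-ffmpeg`) is installed and using random frames and a placeholder path:

```python
import imageio
import numpy as np

frames = [np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8) for _ in range(30)]
imageio.mimsave('/tmp/preview.mp4', frames, fps=15)   # one clip at 15 frames per second
```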
+ """ + frame_idx = 0 if return_first_frame else -1 + warped_idx = 0 if return_first_frame else 1 + vis_images = [] + if not for_model_average: + vis_images += [ + tensor2im(data['few_shot_images'][:, frame_idx]), + self.visualize_label(data['label'][:, frame_idx]), + tensor2im(data['images'][:, frame_idx]) + ] + vis_images += [ + tensor2im(net_G_output['fake_images']), + tensor2im(net_G_output['fake_raw_images'])] + if not for_model_average: + vis_images += [ + tensor2im(net_G_output['warped_images'][warped_idx]), + tensor2flow(net_G_output['fake_flow_maps'][warped_idx]), + tensor2im(net_G_output['fake_occlusion_masks'][warped_idx], + normalize=False) + ] + return vis_images + + if is_master(): + vis_images_first = get_images(data, first_net_G_output) + if self.cfg.trainer.model_average_config.enabled: + vis_images_first += get_images(data, first_net_G_output_avg, + for_model_average=True) + if self.sequence_length > 1: + vis_images_last = get_images(data, last_net_G_output, + return_first_frame=False) + if self.cfg.trainer.model_average_config.enabled: + vis_images_last += get_images(data, last_net_G_output_avg, + return_first_frame=False, + for_model_average=True) + + # If generating a video, the first row of each batch will be + # the first generated frame and the flow/mask for warping the + # reference image, and the second row will be the last + # generated frame and the flow/mask for warping the previous + # frame. If using model average, the frames generated by model + # average will be at the rightmost columns. + vis_images = [[np.vstack((im_first, im_last)) + for im_first, im_last in + zip(imgs_first, imgs_last)] + for imgs_first, imgs_last in zip(vis_images_first, + vis_images_last) + if imgs_first is not None] + else: + vis_images = vis_images_first + + image_grid = np.hstack([np.vstack(im) for im in vis_images + if im is not None]) + + print('Save output images to {}'.format(path)) + os.makedirs(os.path.dirname(path), exist_ok=True) + imageio.imwrite(path, image_grid) + + def finetune(self, data, inference_args): + r"""Finetune the model for a few iterations on the inference data.""" + # Get the list of params to finetune. + self.net_G, self.net_D, self.opt_G, self.opt_D = \ + get_optimizer_with_params(self.cfg, self.net_G, self.net_D, + param_names_start_with=[ + 'weight_generator.fc', 'conv_img', + 'up']) + data_finetune = {k: v for k, v in data.items()} + ref_labels = data_finetune['few_shot_label'] + ref_images = data_finetune['few_shot_images'] + + # Number of iterations to finetune. + iterations = getattr(inference_args, 'finetune_iter', 100) + for it in range(1, iterations + 1): + # Randomly set one of the reference images as target. + idx = np.random.randint(ref_labels.size(1)) + tgt_label, tgt_image = ref_labels[:, idx], ref_images[:, idx] + # Randomly shift and flip the target image. + tgt_label, tgt_image = random_roll([tgt_label, tgt_image]) + data_finetune['label'] = tgt_label.unsqueeze(1) + data_finetune['images'] = tgt_image.unsqueeze(1) + + self.gen_update(data_finetune) + self.dis_update(data_finetune) + if (it % (iterations // 10)) == 0: + print(it) + + self.has_finetuned = True diff --git a/imaginaire/trainers/funit.py b/imaginaire/trainers/funit.py new file mode 100644 index 0000000000000000000000000000000000000000..71141064cc92639120a8bd45b1a982bed0c5f37a --- /dev/null +++ b/imaginaire/trainers/funit.py @@ -0,0 +1,244 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
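The grid assembled in `save_image` above stacks the first-frame and last-frame visualizations vertically within each column and then concatenates the columns horizontally. The same layout on dummy images:

```python
import numpy as np

first = [np.zeros((64, 64, 3), np.uint8) for _ in range(3)]      # first-frame panels
last = [np.full((64, 64, 3), 255, np.uint8) for _ in range(3)]   # last-frame panels

grid = np.hstack([np.vstack((a, b)) for a, b in zip(first, last)])
print(grid.shape)   # (128, 192, 3): two rows of three 64x64 panels
```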
+# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +import torch +from torch import nn +import torch.nn.functional as F +import numpy as np +import os +from imaginaire.evaluation import compute_fid, compute_kid +from imaginaire.utils.diff_aug import apply_diff_aug +from imaginaire.losses import GANLoss +from imaginaire.trainers.base import BaseTrainer +from imaginaire.utils.distributed import is_master + + +class Trainer(BaseTrainer): + r"""Reimplementation of the FUNIT (https://arxiv.org/abs/1905.01723) + algorithm. + + Args: + cfg (obj): Global configuration. + net_G (obj): Generator network. + net_D (obj): Discriminator network. + opt_G (obj): Optimizer for the generator network. + opt_D (obj): Optimizer for the discriminator network. + sch_G (obj): Scheduler for the generator optimizer. + sch_D (obj): Scheduler for the discriminator optimizer. + train_data_loader (obj): Train data loader. + val_data_loader (obj): Validation data loader. + """ + + def __init__(self, cfg, net_G, net_D, opt_G, opt_D, sch_G, sch_D, + train_data_loader, val_data_loader): + self.best_kid = None + self.use_fid = getattr(cfg.trainer, 'use_fid', False) + self.use_kid = getattr(cfg.trainer, 'use_kid', True) + self.kid_num_subsets = getattr(cfg.trainer, 'kid_num_subsets', 1) + self.kid_sample_size = getattr(cfg.trainer, 'kid_sample_size', 256) + self.kid_subset_size = getattr(cfg.trainer, 'kid_subset_size', 256) + super().__init__(cfg, net_G, net_D, opt_G, opt_D, sch_G, sch_D, + train_data_loader, val_data_loader) + + def _init_loss(self, cfg): + r"""Initialize loss terms. In FUNIT, we have several loss terms + including the GAN loss, the image reconstruction loss, the feature + matching loss, and the gradient penalty loss. + + Args: + cfg (obj): Global configuration. + """ + self.criteria['gan'] = GANLoss(cfg.trainer.gan_mode) + self.criteria['image_recon'] = nn.L1Loss() + self.criteria['feature_matching'] = nn.L1Loss() + + for loss_name, loss_weight in cfg.trainer.loss_weight.__dict__.items(): + if loss_weight > 0: + self.weights[loss_name] = loss_weight + + def gen_forward(self, data): + r"""Compute the loss for FUNIT generator. + + Args: + data (dict): Training data at the current iteration. + """ + + net_G_output = self.net_G(data) + + # Differentiable augmentation. + keys = ['images_recon', 'images_trans'] + net_D_output = self.net_D(data, apply_diff_aug( + net_G_output, keys, self.aug_policy)) + + self._time_before_loss() + + # GAN loss + # We use both the translation and reconstruction streams. + if 'gan' in self.weights: + self.gen_losses['gan'] = 0.5 * ( + self.criteria['gan']( + net_D_output['fake_out_trans'], + True, dis_update=False) + + self.criteria['gan']( + net_D_output['fake_out_recon'], + True, dis_update=False)) + + # Image reconstruction loss + if 'image_recon' in self.weights: + self.gen_losses['image_recon'] = \ + self.criteria['image_recon'](net_G_output['images_recon'], + data['images_content']) + + # Feature matching loss + if 'feature_matching' in self.weights: + self.gen_losses['feature_matching'] = \ + self.criteria['feature_matching']( + net_D_output['fake_features_trans'], + net_D_output['real_features_style']) + + # Compute total loss + total_loss = self._get_total_loss(gen_forward=True) + return total_loss + + def dis_forward(self, data): + r"""Compute the loss for FUNIT discriminator. + + Args: + data (dict): Training data at the current iteration. 
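The feature matching term registered above is an L1 penalty between discriminator features of the translated image and of a real image from the target class. A stand-alone sketch with made-up feature shapes:

```python
import torch
import torch.nn as nn

l1 = nn.L1Loss()
fake_features = torch.randn(2, 512, 8, 8, requires_grad=True)   # features of translated images
real_features = torch.randn(2, 512, 8, 8)                       # features of real style images

fm_loss = l1(fake_features, real_features.detach())
fm_loss.backward()   # gradients flow only into the generator side
```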
+ """ + with torch.no_grad(): + net_G_output = self.net_G(data) + net_G_output['images_trans'].requires_grad = True + net_D_output = self.net_D( + apply_diff_aug(data, ['images_style'], self.aug_policy), + apply_diff_aug(net_G_output, ['images_trans'], self.aug_policy), + recon=False) + + self._time_before_loss() + + self.dis_losses['gan'] = \ + self.criteria['gan'](net_D_output['real_out_style'], True) + \ + self.criteria['gan'](net_D_output['fake_out_trans'], False) + + # Compute total loss + total_loss = self._get_total_loss(gen_forward=False) + return total_loss + + def _get_visualizations(self, data): + r"""Compute visualization image. + + Args: + data (dict): The current batch. + """ + net_G_for_evaluation = self.net_G + with torch.no_grad(): + net_G_output = net_G_for_evaluation(data) + vis_images = [data['images_content'], + data['images_style'], + net_G_output['images_recon'], + net_G_output['images_trans']] + _, _, h, w = net_G_output['images_recon'].size() + if 'attn_a' in net_G_output: + for i in range(net_G_output['attn_a'].size(1)): + vis_images += [ + F.interpolate( + net_G_output['attn_a'][:, i:i + 1, :, :], ( + h, w)).expand(-1, 3, -1, -1)] + for i in range(net_G_output['attn_a'].size(1)): + vis_images += [ + F.interpolate( + net_G_output['attn_b'][:, i:i + 1, :, :], ( + h, w)).expand(-1, 3, -1, -1)] + if self.cfg.trainer.model_average_config.enabled: + net_G_for_evaluation = self.net_G.module.averaged_model + net_G_output = net_G_for_evaluation(data) + vis_images += [net_G_output['images_recon'], + net_G_output['images_trans']] + return vis_images + + def _compute_fid(self): + r"""Compute FID. We will compute a FID value per test class. That is + if you have 30 test classes, we will compute 30 different FID values. + We will then report the mean of the FID values as the final + performance number as described in the FUNIT paper. 
+ """ + self.net_G.eval() + if self.cfg.trainer.model_average_config.enabled: + net_G_for_evaluation = self.net_G.module.averaged_model + else: + net_G_for_evaluation = self.net_G + + all_fid_values = [] + num_test_classes = self.val_data_loader.dataset.num_style_classes + for class_idx in range(num_test_classes): + fid_path = self._get_save_path(os.path.join('fid', str(class_idx)), + 'npy') + self.val_data_loader.dataset.set_sample_class_idx(class_idx) + + fid_value = compute_fid(fid_path, self.val_data_loader, + net_G_for_evaluation, 'images_style', + 'images_trans') + all_fid_values.append(fid_value) + + if is_master(): + mean_fid = np.mean(all_fid_values) + print('Epoch {:05}, Iteration {:09}, Mean FID {}'.format( + self.current_epoch, self.current_iteration, mean_fid)) + return mean_fid + else: + return None + + def _compute_kid(self): + self.net_G.eval() + if self.cfg.trainer.model_average_config.enabled: + net_G_for_evaluation = self.net_G.module.averaged_model + else: + net_G_for_evaluation = self.net_G + + all_kid_values = [] + num_test_classes = self.val_data_loader.dataset.num_style_classes + for class_idx in range(num_test_classes): + kid_path = self._get_save_path(os.path.join('kid', str(class_idx)), + 'npy') + self.val_data_loader.dataset.set_sample_class_idx(class_idx) + + kid_value = compute_kid( + kid_path, self.val_data_loader, net_G_for_evaluation, + 'images_style', 'images_trans', + num_subsets=self.kid_num_subsets, + sample_size=self.kid_sample_size, + subset_size=self.kid_subset_size) + all_kid_values.append(kid_value) + + if is_master(): + mean_kid = np.mean(all_kid_values) + print('Epoch {:05}, Iteration {:09}, Mean FID {}'.format( + self.current_epoch, self.current_iteration, mean_kid)) + return mean_kid + else: + return None + + def write_metrics(self): + r"""Write metrics to the tensorboard.""" + metric_dict = {} + if self.use_kid: + cur_kid = self._compute_kid() + if cur_kid is not None: + if self.best_kid is not None: + self.best_kid = min(self.best_kid, cur_kid) + else: + self.best_kid = cur_kid + metric_dict.update({'KID': cur_kid, 'best_KID': self.best_kid}) + if self.use_fid: + cur_fid = self._compute_fid() + if cur_fid is not None: + if self.best_fid is not None: + self.best_fid = min(self.best_fid, cur_fid) + else: + self.best_fid = cur_fid + metric_dict.update({'FID': cur_fid, 'best_FID': self.best_fid}) + + if is_master(): + self._write_to_meters(metric_dict, self.metric_meters) + self._flush_meters(self.metric_meters) diff --git a/imaginaire/trainers/gancraft.py b/imaginaire/trainers/gancraft.py new file mode 100644 index 0000000000000000000000000000000000000000..167e26161493ea82901d9d8e14f030ad71444c0e --- /dev/null +++ b/imaginaire/trainers/gancraft.py @@ -0,0 +1,327 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. 
+# To view a copy of this license, check out LICENSE.md +import collections +import os + +import torch +import torch.nn as nn + +from imaginaire.config import Config +from imaginaire.generators.spade import Generator as SPADEGenerator +from imaginaire.losses import (FeatureMatchingLoss, GaussianKLLoss, PerceptualLoss) +from imaginaire.model_utils.gancraft.loss import GANLoss +from imaginaire.trainers.base import BaseTrainer +from imaginaire.utils.distributed import master_only_print as print +from imaginaire.utils.io import get_checkpoint +from imaginaire.utils.misc import split_labels, to_device +from imaginaire.utils.trainer import ModelAverage, WrappedModel +from imaginaire.utils.visualization import tensor2label + + +class GauGANLoader(object): + r"""Manages the SPADE/GauGAN model used to generate pseudo-GTs for training GANcraft. + + Args: + gaugan_cfg (Config): SPADE configuration. + """ + + def __init__(self, gaugan_cfg): + print('[GauGANLoader] Loading GauGAN model.') + cfg = Config(gaugan_cfg.config) + default_checkpoint_path = os.path.basename(gaugan_cfg.config).split('.yaml')[0] + '-' + \ + cfg.pretrained_weight + '.pt' + checkpoint = get_checkpoint(default_checkpoint_path, cfg.pretrained_weight) + ckpt = torch.load(checkpoint) + + net_G = WrappedModel(ModelAverage(SPADEGenerator(cfg.gen, cfg.data).to('cuda'))) + net_G.load_state_dict(ckpt['net_G']) + self.net_GG = net_G.module.averaged_model + self.net_GG.eval() + self.net_GG.half() + print('[GauGANLoader] GauGAN loading complete.') + + def eval(self, label, z=None, style_img=None): + r"""Produce output given segmentation and other conditioning inputs. + random style will be used if neither z nor style_img is provided. + + Args: + label (N x C x H x W tensor): One-hot segmentation mask of shape. + z: Style vector. + style_img: Style image. + """ + inputs = {'label': label[:, :-1].detach().half()} + random_style = True + + if z is not None: + random_style = False + inputs['z'] = z.detach().half() + elif style_img is not None: + random_style = False + inputs['images'] = style_img.detach().half() + + net_GG_output = self.net_GG(inputs, random_style=random_style) + + return net_GG_output['fake_images'] + + +class Trainer(BaseTrainer): + r"""Initialize GANcraft trainer. + + Args: + cfg (Config): Global configuration. + net_G (obj): Generator network. + net_D (obj): Discriminator network. + opt_G (obj): Optimizer for the generator network. + opt_D (obj): Optimizer for the discriminator network. + sch_G (obj): Scheduler for the generator optimizer. + sch_D (obj): Scheduler for the discriminator optimizer. + train_data_loader (obj): Train data loader. + val_data_loader (obj): Validation data loader. + """ + + def __init__(self, + cfg, + net_G, + net_D, + opt_G, + opt_D, + sch_G, + sch_D, + train_data_loader, + val_data_loader): + super(Trainer, self).__init__(cfg, net_G, net_D, opt_G, + opt_D, sch_G, sch_D, + train_data_loader, val_data_loader) + + # Load the pseudo-GT network only if in training mode, else not needed. + if not self.is_inference: + self.gaugan_model = GauGANLoader(cfg.trainer.gaugan_loader) + + def _init_loss(self, cfg): + r"""Initialize loss terms. + + Args: + cfg (obj): Global configuration. 
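`GauGANLoader` above keeps a pretrained generator frozen (eval mode, fp16, no gradients) purely to produce pseudo ground truth. The same handling on a toy network, assuming a CUDA device; the checkpoint path is illustrative:

```python
import torch
import torch.nn as nn

net = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(),
                    nn.Conv2d(8, 3, 3, padding=1))
# state = torch.load('pretrained.pt'); net.load_state_dict(state['net_G'])  # illustrative path

net = net.cuda().eval().half()                 # frozen, half precision
with torch.no_grad():                          # never part of the training graph
    pseudo_gt = net(torch.randn(1, 3, 64, 64, device='cuda').half())
```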
+        """
+        if hasattr(cfg.trainer.loss_weight, 'gan'):
+            self.criteria['GAN'] = GANLoss()
+            self.weights['GAN'] = cfg.trainer.loss_weight.gan
+        if hasattr(cfg.trainer.loss_weight, 'pseudo_gan'):
+            self.criteria['PGAN'] = GANLoss()
+            self.weights['PGAN'] = cfg.trainer.loss_weight.pseudo_gan
+        if hasattr(cfg.trainer.loss_weight, 'l2'):
+            self.criteria['L2'] = nn.MSELoss()
+            self.weights['L2'] = cfg.trainer.loss_weight.l2
+        if hasattr(cfg.trainer.loss_weight, 'l1'):
+            self.criteria['L1'] = nn.L1Loss()
+            self.weights['L1'] = cfg.trainer.loss_weight.l1
+        if hasattr(cfg.trainer, 'perceptual_loss'):
+            self.criteria['Perceptual'] = \
+                PerceptualLoss(
+                    network=cfg.trainer.perceptual_loss.mode,
+                    layers=cfg.trainer.perceptual_loss.layers,
+                    weights=cfg.trainer.perceptual_loss.weights)
+            self.weights['Perceptual'] = cfg.trainer.loss_weight.perceptual
+        # Setup the feature matching loss.
+        if hasattr(cfg.trainer.loss_weight, 'feature_matching'):
+            self.criteria['FeatureMatching'] = FeatureMatchingLoss()
+            self.weights['FeatureMatching'] = \
+                cfg.trainer.loss_weight.feature_matching
+        # Setup the Gaussian KL divergence loss.
+        if hasattr(cfg.trainer.loss_weight, 'kl'):
+            self.criteria['GaussianKL'] = GaussianKLLoss()
+            self.weights['GaussianKL'] = cfg.trainer.loss_weight.kl
+
+    def _start_of_epoch(self, current_epoch):
+        torch.cuda.empty_cache()  # Prevent the first iteration from running OOM.
+
+    def _start_of_iteration(self, data, current_iteration):
+        r"""Model specific custom start of iteration process. We will do two
+        things. First, put all the data to GPU. Second, we will resize the
+        input so that it becomes a multiple of the factor for bug-free
+        convolutional operations. This factor is given by the yaml file.
+        E.g., base = getattr(self.net_G, 'base', 32)
+
+        Args:
+            data (dict): The current batch.
+            current_iteration (int): The iteration number of the current batch.
+        """
+        data = to_device(data, 'cuda')
+
+        # Sample camera poses and pseudo-GTs.
+        with torch.no_grad():
+            samples = self.net_G.module.sample_camera(data, self.gaugan_model.eval)
+
+        return {**data, **samples}
+
+    def gen_forward(self, data):
+        r"""Compute the loss for the GANcraft generator.
+
+        Args:
+            data (dict): Training data at the current iteration.
+        """
+        net_G_output = self.net_G(data, random_style=False)
+
+        self._time_before_loss()
+
+        if 'GAN' in self.criteria or 'PGAN' in self.criteria:
+            incl_pseudo_real = False
+            if 'FeatureMatching' in self.criteria:
+                incl_pseudo_real = True
+            net_D_output = self.net_D(data, net_G_output, incl_real=False, incl_pseudo_real=incl_pseudo_real)
+            output_fake = net_D_output['fake_outputs']  # Choose from real_outputs and fake_outputs.
+
+            gan_loss = self.criteria['GAN'](output_fake, True, dis_update=False)
+            if 'GAN' in self.criteria:
+                self.gen_losses['GAN'] = gan_loss
+            if 'PGAN' in self.criteria:
+                self.gen_losses['PGAN'] = gan_loss
+
+        if 'FeatureMatching' in self.criteria:
+            self.gen_losses['FeatureMatching'] = self.criteria['FeatureMatching'](
+                net_D_output['fake_features'], net_D_output['pseudo_real_features'])
+
+        if 'GaussianKL' in self.criteria:
+            self.gen_losses['GaussianKL'] = self.criteria['GaussianKL'](net_G_output['mu'], net_G_output['logvar'])
+
+        # Perceptual loss is always between fake image and pseudo real image.
+        if 'Perceptual' in self.criteria:
+            self.gen_losses['Perceptual'] = self.criteria['Perceptual'](
+                net_G_output['fake_images'], data['pseudo_real_img'])
+
+        # Reconstruction loss between fake and pseudo real.
+ if 'L2' in self.criteria: + self.gen_losses['L2'] = self.criteria['L2'](net_G_output['fake_images'], data['pseudo_real_img']) + if 'L1' in self.criteria: + self.gen_losses['L1'] = self.criteria['L1'](net_G_output['fake_images'], data['pseudo_real_img']) + + total_loss = 0 + for key in self.criteria: + total_loss = total_loss + self.gen_losses[key] * self.weights[key] + + self.gen_losses['total'] = total_loss + return total_loss + + def dis_forward(self, data): + r"""Compute the loss for GANcraft discriminator. + + Args: + data (dict): Training data at the current iteration. + """ + if 'GAN' not in self.criteria and 'PGAN' not in self.criteria: + return + + with torch.no_grad(): + net_G_output = self.net_G(data, random_style=False) + net_G_output['fake_images'] = net_G_output['fake_images'].detach() + + incl_real = False + incl_pseudo_real = False + if 'GAN' in self.criteria: + incl_real = True + if 'PGAN' in self.criteria: + incl_pseudo_real = True + net_D_output = self.net_D(data, net_G_output, incl_real=incl_real, incl_pseudo_real=incl_pseudo_real) + + self._time_before_loss() + total_loss = 0 + if 'GAN' in self.criteria: + output_fake = net_D_output['fake_outputs'] + output_real = net_D_output['real_outputs'] + + fake_loss = self.criteria['GAN'](output_fake, False, dis_update=True) + true_loss = self.criteria['GAN'](output_real, True, dis_update=True) + self.dis_losses['GAN/fake'] = fake_loss + self.dis_losses['GAN/true'] = true_loss + self.dis_losses['GAN'] = fake_loss + true_loss + total_loss = total_loss + self.dis_losses['GAN'] * self.weights['GAN'] + if 'PGAN' in self.criteria: + output_fake = net_D_output['fake_outputs'] + output_pseudo_real = net_D_output['pseudo_real_outputs'] + + fake_loss = self.criteria['PGAN'](output_fake, False, dis_update=True) + true_loss = self.criteria['PGAN'](output_pseudo_real, True, dis_update=True) + self.dis_losses['PGAN/fake'] = fake_loss + self.dis_losses['PGAN/true'] = true_loss + self.dis_losses['PGAN'] = fake_loss + true_loss + total_loss = total_loss + self.dis_losses['PGAN'] * self.weights['PGAN'] + + self.dis_losses['total'] = total_loss + return total_loss + + def _get_visualizations(self, data): + r"""Compute visualization image. + + Args: + data (dict): The current batch. + """ + with torch.no_grad(): + label_lengths = self.train_data_loader.dataset.get_label_lengths() + labels = split_labels(data['label'], label_lengths) + + # Get visualization of the real image and segmentation mask. + segmap = tensor2label(labels['seg_maps'], label_lengths['seg_maps'], output_normalized_tensor=True) + segmap = torch.cat([x.unsqueeze(0) for x in segmap], 0) + + # Get output from GANcraft model + net_G_output_randstyle = self.net_G(data, random_style=True) + net_G_output = self.net_G(data, random_style=False) + + vis_images = [data['images'], segmap, net_G_output_randstyle['fake_images'], net_G_output['fake_images']] + + if 'fake_masks' in data: + # Get pseudo-GT. 
+ labels = split_labels(data['fake_masks'], label_lengths) + segmap = tensor2label(labels['seg_maps'], label_lengths['seg_maps'], output_normalized_tensor=True) + segmap = torch.cat([x.unsqueeze(0) for x in segmap], 0) + vis_images.append(segmap) + + if 'pseudo_real_img' in data: + vis_images.append(data['pseudo_real_img']) + + if self.cfg.trainer.model_average_config.enabled: + net_G_model_average_output = self.net_G.module.averaged_model(data, random_style=True) + vis_images.append(net_G_model_average_output['fake_images']) + return vis_images + + def load_checkpoint(self, cfg, checkpoint_path, resume=None, load_sch=True): + r"""Load network weights, optimizer parameters, scheduler parameters + from a checkpoint. + + Args: + cfg (obj): Global configuration. + checkpoint_path (str): Path to the checkpoint. + resume (bool or None): If not ``None``, will determine whether or + not to load optimizers in addition to network weights. + """ + ret = super().load_checkpoint(cfg, checkpoint_path, resume, load_sch) + + if getattr(cfg.trainer, 'reset_opt_g_on_resume', False): + self.opt_G.state = collections.defaultdict(dict) + print('[GANcraft::load_checkpoint] Resetting opt_G.state') + if getattr(cfg.trainer, 'reset_opt_d_on_resume', False): + self.opt_D.state = collections.defaultdict(dict) + print('[GANcraft::load_checkpoint] Resetting opt_D.state') + + return ret + + def test(self, data_loader, output_dir, inference_args): + r"""Compute results images for a batch of input data and save the + results in the specified folder. + + Args: + data_loader (torch.utils.data.DataLoader): PyTorch dataloader. + output_dir (str): Target location for saving the output image. + """ + if self.cfg.trainer.model_average_config.enabled: + net_G = self.net_G.module.averaged_model + else: + net_G = self.net_G.module + net_G.eval() + + torch.cuda.empty_cache() + with torch.no_grad(): + net_G.inference(output_dir, **vars(inference_args)) diff --git a/imaginaire/trainers/munit.py b/imaginaire/trainers/munit.py new file mode 100644 index 0000000000000000000000000000000000000000..d0bc4b42e1d2b82ac4702cb8f99808603818e3d9 --- /dev/null +++ b/imaginaire/trainers/munit.py @@ -0,0 +1,312 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +import torch + +from imaginaire.evaluation import compute_fid +from imaginaire.losses import (GANLoss, GaussianKLLoss, + PerceptualLoss) +from imaginaire.trainers.base import BaseTrainer +from imaginaire.utils.misc import random_shift +from imaginaire.utils.distributed import master_only_print as print +from imaginaire.utils.diff_aug import apply_diff_aug + + +class Trainer(BaseTrainer): + r"""Reimplementation of the MUNIT (https://arxiv.org/abs/1804.04732) + algorithm. + + Args: + cfg (obj): Global configuration. + net_G (obj): Generator network. + net_D (obj): Discriminator network. + opt_G (obj): Optimizer for the generator network. + opt_D (obj): Optimizer for the discriminator network. + sch_G (obj): Scheduler for the generator optimizer. + sch_D (obj): Scheduler for the discriminator optimizer. + train_data_loader (obj): Train data loader. + val_data_loader (obj): Validation data loader. 
+ """ + + def __init__(self, cfg, net_G, net_D, opt_G, opt_D, sch_G, sch_D, + train_data_loader, val_data_loader): + super().__init__(cfg, net_G, net_D, opt_G, opt_D, sch_G, sch_D, + train_data_loader, val_data_loader) + self.gan_recon = getattr(cfg.trainer, 'gan_recon', False) + self.best_fid_a = None + self.best_fid_b = None + + def _init_loss(self, cfg): + r"""Initialize loss terms. In MUNIT, we have several loss terms + including the GAN loss, the image reconstruction loss, the content + reconstruction loss, the style reconstruction loss, the cycle + reconstruction loss. We also have an optional perceptual loss. A user + can choose to have gradient penalty or consistency regularization too. + + Args: + cfg (obj): Global configuration. + """ + self.criteria['gan'] = GANLoss(cfg.trainer.gan_mode) + self.criteria['kl'] = GaussianKLLoss() + self.criteria['image_recon'] = torch.nn.L1Loss() + if getattr(cfg.trainer.loss_weight, 'perceptual', 0) > 0: + self.criteria['perceptual'] = \ + PerceptualLoss(network=cfg.trainer.perceptual_mode, + layers=cfg.trainer.perceptual_layers) + + for loss_name, loss_weight in cfg.trainer.loss_weight.__dict__.items(): + if loss_weight > 0: + self.weights[loss_name] = loss_weight + + def gen_forward(self, data): + r"""Compute the loss for MUNIT generator. + + Args: + data (dict): Training data at the current iteration. + """ + cycle_recon = 'cycle_recon' in self.weights + image_recon = 'image_recon' in self.weights + perceptual = 'perceptual' in self.weights + within_latent_recon = 'style_recon_within' in self.weights or \ + 'content_recon_within' in self.weights + + net_G_output = self.net_G(data, + image_recon=image_recon, + cycle_recon=cycle_recon, + within_latent_recon=within_latent_recon) + + # Differentiable augmentation. 
+ keys = ['images_ab', 'images_ba'] + if self.gan_recon: + keys += ['images_aa', 'images_bb'] + net_D_output = self.net_D(data, + apply_diff_aug( + net_G_output, keys, self.aug_policy), + real=False, + gan_recon=self.gan_recon) + + self._time_before_loss() + + # GAN loss + if self.gan_recon: + self.gen_losses['gan_a'] = \ + 0.5 * (self.criteria['gan'](net_D_output['out_ba'], + True, dis_update=False) + + self.criteria['gan'](net_D_output['out_aa'], + True, dis_update=False)) + self.gen_losses['gan_b'] = \ + 0.5 * (self.criteria['gan'](net_D_output['out_ab'], + True, dis_update=False) + + self.criteria['gan'](net_D_output['out_bb'], + True, dis_update=False)) + else: + self.gen_losses['gan_a'] = self.criteria['gan']( + net_D_output['out_ba'], True, dis_update=False) + self.gen_losses['gan_b'] = self.criteria['gan']( + net_D_output['out_ab'], True, dis_update=False) + self.gen_losses['gan'] = \ + self.gen_losses['gan_a'] + self.gen_losses['gan_b'] + + # Perceptual loss + if perceptual: + self.gen_losses['perceptual_a'] = \ + self.criteria['perceptual'](net_G_output['images_ab'], + data['images_a']) + self.gen_losses['perceptual_b'] = \ + self.criteria['perceptual'](net_G_output['images_ba'], + data['images_b']) + self.gen_losses['perceptual'] = \ + self.gen_losses['perceptual_a'] + \ + self.gen_losses['perceptual_b'] + + # Image reconstruction loss + if image_recon: + self.gen_losses['image_recon'] = \ + self.criteria['image_recon'](net_G_output['images_aa'], + data['images_a']) + \ + self.criteria['image_recon'](net_G_output['images_bb'], + data['images_b']) + + # Style reconstruction loss + self.gen_losses['style_recon_a'] = torch.abs( + net_G_output['style_ba'] - + net_G_output['style_a_rand']).mean() + self.gen_losses['style_recon_b'] = torch.abs( + net_G_output['style_ab'] - + net_G_output['style_b_rand']).mean() + self.gen_losses['style_recon'] = \ + self.gen_losses['style_recon_a'] + self.gen_losses['style_recon_b'] + + if within_latent_recon: + self.gen_losses['style_recon_aa'] = torch.abs( + net_G_output['style_aa'] - + net_G_output['style_a'].detach()).mean() + self.gen_losses['style_recon_bb'] = torch.abs( + net_G_output['style_bb'] - + net_G_output['style_b'].detach()).mean() + self.gen_losses['style_recon_within'] = \ + self.gen_losses['style_recon_aa'] + \ + self.gen_losses['style_recon_bb'] + + # Content reconstruction loss + self.gen_losses['content_recon_a'] = torch.abs( + net_G_output['content_ab'] - + net_G_output['content_a'].detach()).mean() + self.gen_losses['content_recon_b'] = torch.abs( + net_G_output['content_ba'] - + net_G_output['content_b'].detach()).mean() + self.gen_losses['content_recon'] = \ + self.gen_losses['content_recon_a'] + \ + self.gen_losses['content_recon_b'] + + if within_latent_recon: + self.gen_losses['content_recon_aa'] = torch.abs( + net_G_output['content_aa'] - + net_G_output['content_a'].detach()).mean() + self.gen_losses['content_recon_bb'] = torch.abs( + net_G_output['content_bb'] - + net_G_output['content_b'].detach()).mean() + self.gen_losses['content_recon_within'] = \ + self.gen_losses['content_recon_aa'] + \ + self.gen_losses['content_recon_bb'] + + # KL loss + self.gen_losses['kl'] = \ + self.criteria['kl'](net_G_output['style_a']) + \ + self.criteria['kl'](net_G_output['style_b']) + + # Cycle reconstruction loss + if cycle_recon: + self.gen_losses['cycle_recon'] = \ + torch.abs(net_G_output['images_aba'] - + data['images_a']).mean() + \ + torch.abs(net_G_output['images_bab'] - + data['images_b']).mean() + + # Compute total loss + 
total_loss = self._get_total_loss(gen_forward=True) + return total_loss + + def dis_forward(self, data): + r"""Compute the loss for MUNIT discriminator. + + Args: + data (dict): Training data at the current iteration. + """ + with torch.no_grad(): + net_G_output = self.net_G(data, + image_recon=self.gan_recon, + latent_recon=False, + cycle_recon=False, + within_latent_recon=False) + net_G_output['images_ba'].requires_grad = True + net_G_output['images_ab'].requires_grad = True + + # Differentiable augmentation. + keys_fake = ['images_ab', 'images_ba'] + if self.gan_recon: + keys_fake += ['images_aa', 'images_bb'] + keys_real = ['images_a', 'images_b'] + + net_D_output = self.net_D( + apply_diff_aug(data, keys_real, self.aug_policy), + apply_diff_aug(net_G_output, keys_fake, self.aug_policy), + gan_recon=self.gan_recon) + + self._time_before_loss() + + # GAN loss. + self.dis_losses['gan_a'] = \ + self.criteria['gan'](net_D_output['out_a'], True) + \ + self.criteria['gan'](net_D_output['out_ba'], False) + self.dis_losses['gan_b'] = \ + self.criteria['gan'](net_D_output['out_b'], True) + \ + self.criteria['gan'](net_D_output['out_ab'], False) + self.dis_losses['gan'] = \ + self.dis_losses['gan_a'] + self.dis_losses['gan_b'] + + # Consistency regularization. + self.dis_losses['consistency_reg'] = \ + torch.tensor(0., device=torch.device('cuda')) + if 'consistency_reg' in self.weights: + data_aug, net_G_output_aug = {}, {} + data_aug['images_a'] = random_shift(data['images_a'].flip(-1)) + data_aug['images_b'] = random_shift(data['images_b'].flip(-1)) + net_G_output_aug['images_ab'] = \ + random_shift(net_G_output['images_ab'].flip(-1)) + net_G_output_aug['images_ba'] = \ + random_shift(net_G_output['images_ba'].flip(-1)) + net_D_output_aug = self.net_D(data_aug, net_G_output_aug) + feature_names = ['fea_ba', 'fea_ab', + 'fea_a', 'fea_b'] + for feature_name in feature_names: + self.dis_losses['consistency_reg'] += \ + torch.pow(net_D_output_aug[feature_name] - + net_D_output[feature_name], 2).mean() + + # Compute total loss + total_loss = self._get_total_loss(gen_forward=False) + return total_loss + + def _get_visualizations(self, data): + r"""Compute visualization image. + + Args: + data (dict): The current batch. + """ + if self.cfg.trainer.model_average_config.enabled: + net_G_for_evaluation = self.net_G.module.averaged_model + else: + net_G_for_evaluation = self.net_G + with torch.no_grad(): + net_G_output = net_G_for_evaluation(data, random_style=False) + net_G_output_random = net_G_for_evaluation(data) + vis_images = [data['images_a'], + data['images_b'], + net_G_output['images_aa'], + net_G_output['images_bb'], + net_G_output['images_ab'], + net_G_output_random['images_ab'], + net_G_output['images_ba'], + net_G_output_random['images_ba'], + net_G_output['images_aba'], + net_G_output['images_bab']] + return vis_images + + def write_metrics(self): + r"""Compute metrics and save them to tensorboard""" + cur_fid_a, cur_fid_b = self._compute_fid() + if self.best_fid_a is not None: + self.best_fid_a = min(self.best_fid_a, cur_fid_a) + else: + self.best_fid_a = cur_fid_a + if self.best_fid_b is not None: + self.best_fid_b = min(self.best_fid_b, cur_fid_b) + else: + self.best_fid_b = cur_fid_b + self._write_to_meters({'FID_a': cur_fid_a, + 'best_FID_a': self.best_fid_a, + 'FID_b': cur_fid_b, + 'best_FID_b': self.best_fid_b}, + self.metric_meters) + self._flush_meters(self.metric_meters) + + def _compute_fid(self): + r"""Compute FID for both domains. 
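The consistency regularization in `dis_forward` above penalizes the squared distance between discriminator features of a batch and of a flipped-and-shifted view of that same batch. The sketch below uses a guessed `random_shift` (pad and crop at a random offset) and average pooling as a stand-in for discriminator features; neither is the library's actual implementation:

```python
import torch
import torch.nn.functional as F

def random_shift(x, max_shift=4):
    """Illustrative stand-in: reflect-pad, then crop at a random offset."""
    n, c, h, w = x.shape
    dx, dy = torch.randint(-max_shift, max_shift + 1, (2,)).tolist()
    x = F.pad(x, (max_shift,) * 4, mode='reflect')
    return x[:, :, max_shift + dy:max_shift + dy + h, max_shift + dx:max_shift + dx + w]

images = torch.randn(2, 3, 64, 64)
feat = lambda x: F.avg_pool2d(x, 8)                  # stand-in for discriminator features

aug = random_shift(images.flip(-1))                  # augmented view of the same batch
reg = torch.pow(feat(aug) - feat(images), 2).mean()  # consistency penalty
```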
+ """ + self.net_G.eval() + if self.cfg.trainer.model_average_config.enabled: + net_G_for_evaluation = self.net_G.module.averaged_model + else: + net_G_for_evaluation = self.net_G + fid_a_path = self._get_save_path('fid_a', 'npy') + fid_b_path = self._get_save_path('fid_b', 'npy') + fid_value_a = compute_fid(fid_a_path, self.val_data_loader, + net_G_for_evaluation, 'images_a', 'images_ba') + fid_value_b = compute_fid(fid_b_path, self.val_data_loader, + net_G_for_evaluation, 'images_b', 'images_ab') + print('Epoch {:05}, Iteration {:09}, FID a {}, FID b {}'.format( + self.current_epoch, self.current_iteration, + fid_value_a, fid_value_b)) + return fid_value_a, fid_value_b diff --git a/imaginaire/trainers/pix2pixHD.py b/imaginaire/trainers/pix2pixHD.py new file mode 100644 index 0000000000000000000000000000000000000000..af8ed264adb0f5e2be61ce258ea3221b3b365f6a --- /dev/null +++ b/imaginaire/trainers/pix2pixHD.py @@ -0,0 +1,202 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +import functools + +import torch + +from imaginaire.evaluation import compute_fid +from imaginaire.losses import FeatureMatchingLoss, GANLoss, PerceptualLoss +from imaginaire.model_utils.pix2pixHD import cluster_features, get_edges +from imaginaire.trainers.spade import Trainer as SPADETrainer +from imaginaire.utils.distributed import master_only_print as print +from imaginaire.utils.misc import to_cuda + + +class Trainer(SPADETrainer): + r"""Initialize pix2pixHD trainer. + + Args: + cfg (obj): Global configuration. + net_G (obj): Generator network. + net_D (obj): Discriminator network. + opt_G (obj): Optimizer for the generator network. + opt_D (obj): Optimizer for the discriminator network. + sch_G (obj): Scheduler for the generator optimizer. + sch_D (obj): Scheduler for the discriminator optimizer. + train_data_loader (obj): Train data loader. + val_data_loader (obj): Validation data loader. + """ + + def __init__(self, cfg, net_G, net_D, opt_G, opt_D, sch_G, sch_D, + train_data_loader, val_data_loader): + super(Trainer, self).__init__(cfg, net_G, net_D, opt_G, + opt_D, sch_G, sch_D, + train_data_loader, val_data_loader) + + def _assign_criteria(self, name, criterion, weight): + r"""Assign training loss terms. + + Args: + name (str): Loss name + criterion (obj): Loss object. + weight (float): Loss weight. It should be non-negative. + """ + self.criteria[name] = criterion + self.weights[name] = weight + + def _init_loss(self, cfg): + r"""Initialize training loss terms. In pix2pixHD, there are three + loss terms: GAN loss, feature matching loss, and perceptual loss. + + Args: + cfg (obj): Global configuration. + """ + self.criteria = dict() + self.weights = dict() + trainer_cfg = cfg.trainer + loss_weight = cfg.trainer.loss_weight + # GAN loss and feature matching loss. + self._assign_criteria('GAN', + GANLoss(trainer_cfg.gan_mode), + loss_weight.gan) + self._assign_criteria('FeatureMatching', + FeatureMatchingLoss(), + loss_weight.feature_matching) + self._assign_criteria('Perceptual', + PerceptualLoss( + network=cfg.trainer.perceptual_loss.mode, + layers=cfg.trainer.perceptual_loss.layers, + weights=cfg.trainer.perceptual_loss.weights), + loss_weight.perceptual) + + def _start_of_iteration(self, data, current_iteration): + r"""Things to do before an iteration. + + Args: + data (dict): Data used for the current iteration. 
+ current_iteration (int): Current number of iteration. + """ + return self.pre_process(data) + + def gen_forward(self, data): + r"""Compute the loss for pix2pixHD generator. + + Args: + data (dict): Training data at the current iteration. + """ + net_G_output = self.net_G(data) + net_D_output = self.net_D(data, net_G_output) + + self._time_before_loss() + + output_fake = self._get_outputs(net_D_output, real=False) + self.gen_losses['GAN'] = \ + self.criteria['GAN'](output_fake, True, dis_update=False) + + self.gen_losses['FeatureMatching'] = self.criteria['FeatureMatching']( + net_D_output['fake_features'], net_D_output['real_features']) + + if hasattr(self.cfg.trainer, 'perceptual_loss'): + self.gen_losses['Perceptual'] = self.criteria['Perceptual']( + net_G_output['fake_images'], data['images']) + + total_loss = self.gen_losses['GAN'].new_tensor([0]) + for key in self.criteria: + total_loss += self.gen_losses[key] * self.weights[key] + + self.gen_losses['total'] = total_loss + return total_loss + + def dis_forward(self, data): + r"""Compute the loss for pix2pixHD discriminator. + + Args: + data (dict): Training data at the current iteration. + """ + with torch.no_grad(): + net_G_output = self.net_G(data) + net_G_output['fake_images'] = net_G_output['fake_images'].detach() + net_D_output = self.net_D(data, net_G_output) + + self._time_before_loss() + + output_fake = self._get_outputs(net_D_output, real=False) + output_real = self._get_outputs(net_D_output, real=True) + fake_loss = self.criteria['GAN'](output_fake, False, dis_update=True) + true_loss = self.criteria['GAN'](output_real, True, dis_update=True) + self.dis_losses['GAN'] = fake_loss + true_loss + total_loss = self.dis_losses['GAN'] * self.weights['GAN'] + self.dis_losses['total'] = total_loss + return total_loss + + def pre_process(self, data): + r"""Data pre-processing step for the pix2pixHD method. It takes a + dictionary as input where the dictionary contains a label field. The + label field is the concatenation of the segmentation mask and the + instance map. In this function, we will replace the instance map with + an edge map. We will also add a "instance_maps" field to the dictionary. + + Args: + data (dict): Input dictionary. + data['label']: Input label map where the last channel is the + instance map. + """ + data = to_cuda(data) + if self.cfg.trainer.model_average_config.enabled: + net_G = self.net_G.module.module + else: + net_G = self.net_G.module + if net_G.contain_instance_map: + inst_maps = data['label'][:, -1:] + edge_maps = get_edges(inst_maps) + data['instance_maps'] = inst_maps.clone() + data['label'][:, -1:] = edge_maps + return data + + def _pre_save_checkpoint(self): + r"""Implement the things you want to do before saving the checkpoints. + For example, you can compute the K-mean features (pix2pixHD) before + saving the model weights to the checkponts. + """ + if hasattr(self.cfg.gen, 'enc'): + if self.cfg.trainer.model_average_config.enabled: + net_E = self.net_G.module.averaged_model.encoder + else: + net_E = self.net_G.module.encoder + is_cityscapes = getattr(self.cfg.gen, 'is_cityscapes', False) + cluster_features(self.cfg, self.val_data_loader, + net_E, + self.pre_process, + is_cityscapes) + + def _compute_fid(self): + r"""We will compute FID for the regular model using the eval mode. + For the moving average model, we will use the eval mode. 
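`pre_process` above swaps the instance channel of the label map for an edge map via `get_edges`. The function below is only an illustration of that idea (a pixel is an edge where its instance id differs from a horizontal or vertical neighbour), not the library's implementation:

```python
import torch

def instance_edges(inst):
    """Boundary mask for an (N, 1, H, W) instance-id map."""
    edge = torch.zeros_like(inst, dtype=torch.bool)
    edge[:, :, :, 1:] |= inst[:, :, :, 1:] != inst[:, :, :, :-1]
    edge[:, :, :, :-1] |= inst[:, :, :, 1:] != inst[:, :, :, :-1]
    edge[:, :, 1:, :] |= inst[:, :, 1:, :] != inst[:, :, :-1, :]
    edge[:, :, :-1, :] |= inst[:, :, 1:, :] != inst[:, :, :-1, :]
    return edge.float()

inst = torch.zeros(1, 1, 6, 6)
inst[:, :, :, 3:] = 1                       # two instances split down the middle
print(instance_edges(inst)[0, 0])           # ones along the vertical boundary
```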
+ """ + self.net_G.eval() + net_G_for_evaluation = \ + functools.partial(self.net_G, random_style=True) + regular_fid_path = self._get_save_path('regular_fid', 'npy') + regular_fid_value = compute_fid(regular_fid_path, + self.val_data_loader, + net_G_for_evaluation, + preprocess=self.pre_process) + print('Epoch {:05}, Iteration {:09}, Regular FID {}'.format( + self.current_epoch, self.current_iteration, regular_fid_value)) + if self.cfg.trainer.model_average_config.enabled: + avg_net_G_for_evaluation = \ + functools.partial(self.net_G.module.averaged_model, + random_style=True) + fid_path = self._get_save_path('average_fid', 'npy') + fid_value = compute_fid(fid_path, self.val_data_loader, + avg_net_G_for_evaluation, + preprocess=self.pre_process) + print('Epoch {:05}, Iteration {:09}, FID {}'.format( + self.current_epoch, self.current_iteration, fid_value)) + self.net_G.float() + return regular_fid_value, fid_value + else: + self.net_G.float() + return regular_fid_value diff --git a/imaginaire/trainers/spade.py b/imaginaire/trainers/spade.py new file mode 100644 index 0000000000000000000000000000000000000000..a6132f20d019c18d3bd8670f00ed8b5a7a389e85 --- /dev/null +++ b/imaginaire/trainers/spade.py @@ -0,0 +1,282 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +import functools +import math + +import torch +import torch.nn.functional as F + +from imaginaire.evaluation import compute_fid +from imaginaire.losses import (FeatureMatchingLoss, GANLoss, GaussianKLLoss, + PerceptualLoss) +from imaginaire.trainers.base import BaseTrainer +from imaginaire.utils.distributed import master_only_print as print +from imaginaire.utils.model_average import reset_batch_norm, \ + calibrate_batch_norm_momentum +from imaginaire.utils.misc import split_labels, to_device +from imaginaire.utils.visualization import tensor2label + + +class Trainer(BaseTrainer): + r"""Initialize SPADE trainer. + + Args: + cfg (Config): Global configuration. + net_G (obj): Generator network. + net_D (obj): Discriminator network. + opt_G (obj): Optimizer for the generator network. + opt_D (obj): Optimizer for the discriminator network. + sch_G (obj): Scheduler for the generator optimizer. + sch_D (obj): Scheduler for the discriminator optimizer. + train_data_loader (obj): Train data loader. + val_data_loader (obj): Validation data loader. + """ + + def __init__(self, + cfg, + net_G, + net_D, + opt_G, + opt_D, + sch_G, + sch_D, + train_data_loader, + val_data_loader): + super(Trainer, self).__init__(cfg, net_G, net_D, opt_G, + opt_D, sch_G, sch_D, + train_data_loader, val_data_loader) + if cfg.data.type == 'imaginaire.datasets.paired_videos': + self.video_mode = True + else: + self.video_mode = False + + def _init_loss(self, cfg): + r"""Initialize loss terms. + + Args: + cfg (obj): Global configuration. + """ + self.criteria['GAN'] = GANLoss(cfg.trainer.gan_mode) + self.weights['GAN'] = cfg.trainer.loss_weight.gan + # Setup the perceptual loss. Note that perceptual loss can run in + # fp16 mode for additional speed. We find that running on fp16 mode + # leads to improve training speed while maintaining the same accuracy. 
+ if hasattr(cfg.trainer, 'perceptual_loss'): + self.criteria['Perceptual'] = \ + PerceptualLoss( + network=cfg.trainer.perceptual_loss.mode, + layers=cfg.trainer.perceptual_loss.layers, + weights=cfg.trainer.perceptual_loss.weights) + self.weights['Perceptual'] = cfg.trainer.loss_weight.perceptual + # Setup the feature matching loss. + self.criteria['FeatureMatching'] = FeatureMatchingLoss() + self.weights['FeatureMatching'] = \ + cfg.trainer.loss_weight.feature_matching + # Setup the Gaussian KL divergence loss. + self.criteria['GaussianKL'] = GaussianKLLoss() + self.weights['GaussianKL'] = cfg.trainer.loss_weight.kl + + def _start_of_iteration(self, data, current_iteration): + r"""Model specific custom start of iteration process. We will do two + things. First, put all the data to GPU. Second, we will resize the + input so that it becomes multiple of the factor for bug-free + convolutional operations. This factor is given by the yaml file. + E.g., base = getattr(self.net_G, 'base', 32) + + Args: + data (dict): The current batch. + current_iteration (int): The iteration number of the current batch. + """ + data = to_device(data, 'cuda') + data = self._resize_data(data) + return data + + def gen_forward(self, data): + r"""Compute the loss for SPADE generator. + + Args: + data (dict): Training data at the current iteration. + """ + net_G_output = self.net_G(data) + net_D_output = self.net_D(data, net_G_output) + + self._time_before_loss() + + output_fake = self._get_outputs(net_D_output, real=False) + self.gen_losses['GAN'] = self.criteria['GAN'](output_fake, True, dis_update=False) + + self.gen_losses['FeatureMatching'] = self.criteria['FeatureMatching']( + net_D_output['fake_features'], net_D_output['real_features']) + + if self.net_G_module.use_style_encoder: + self.gen_losses['GaussianKL'] = \ + self.criteria['GaussianKL'](net_G_output['mu'], + net_G_output['logvar']) + else: + self.gen_losses['GaussianKL'] = \ + self.gen_losses['GAN'].new_tensor([0]) + + if hasattr(self.cfg.trainer, 'perceptual_loss'): + self.gen_losses['Perceptual'] = self.criteria['Perceptual']( + net_G_output['fake_images'], data['images']) + + total_loss = self.gen_losses['GAN'].new_tensor([0]) + for key in self.criteria: + total_loss += self.gen_losses[key] * self.weights[key] + + self.gen_losses['total'] = total_loss + return total_loss + + def dis_forward(self, data): + r"""Compute the loss for SPADE discriminator. + + Args: + data (dict): Training data at the current iteration. + """ + with torch.no_grad(): + net_G_output = self.net_G(data) + net_G_output['fake_images'] = net_G_output['fake_images'].detach() + net_D_output = self.net_D(data, net_G_output) + + self._time_before_loss() + + output_fake = self._get_outputs(net_D_output, real=False) + output_real = self._get_outputs(net_D_output, real=True) + fake_loss = self.criteria['GAN'](output_fake, False, dis_update=True) + true_loss = self.criteria['GAN'](output_real, True, dis_update=True) + self.dis_losses['GAN/fake'] = fake_loss + self.dis_losses['GAN/true'] = true_loss + self.dis_losses['GAN'] = fake_loss + true_loss + total_loss = self.dis_losses['GAN'] * self.weights['GAN'] + self.dis_losses['total'] = total_loss + return total_loss + + def _get_visualizations(self, data): + r"""Compute visualization image. We will first recalculate the batch + statistics for the moving average model. + + Args: + data (dict): The current batch. 
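The `GaussianKL` term above regularizes the style encoder's `(mu, logvar)` towards a standard normal. The usual closed form is shown below; the exact reduction (sum versus mean over elements) used by `GaussianKLLoss` may differ:

```python
import torch

def gaussian_kl(mu, logvar):
    """KL( N(mu, exp(logvar)) || N(0, 1) ), summed over all elements."""
    return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

mu = torch.zeros(4, 256, requires_grad=True)
logvar = torch.zeros(4, 256, requires_grad=True)
print(gaussian_kl(mu, logvar))   # tensor(0., ...): already a standard normal
```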
+ """ + self.recalculate_batch_norm_statistics( + self.train_data_loader) + with torch.no_grad(): + label_lengths = self.train_data_loader.dataset.get_label_lengths() + labels = split_labels(data['label'], label_lengths) + # Get visualization of the segmentation mask. + vis_images = list() + vis_images.append(data['images']) + net_G_output = self.net_G(data, random_style=True) + # print(labels.keys()) + for key in labels.keys(): + if 'seg' in key: + segmaps = tensor2label(labels[key], label_lengths[key], output_normalized_tensor=True) + segmaps = torch.cat([x.unsqueeze(0) for x in segmaps], 0) + vis_images.append(segmaps) + if 'edge' in key: + edgemaps = torch.cat((labels[key], labels[key], labels[key]), 1) + vis_images.append(edgemaps) + + vis_images.append(net_G_output['fake_images']) + if self.cfg.trainer.model_average_config.enabled: + net_G_model_average_output = \ + self.net_G.module.averaged_model(data, random_style=True) + vis_images.append(net_G_model_average_output['fake_images']) + return vis_images + + def recalculate_batch_norm_statistics(self, data_loader): + r"""Update the statistics in the moving average model. + + Args: + data_loader (pytorch data loader): Data loader for estimating the + statistics. + """ + if not self.cfg.trainer.model_average_config.enabled: + return + model_average_iteration = \ + self.cfg.trainer.model_average_config.num_batch_norm_estimation_iterations + if model_average_iteration == 0: + return + with torch.no_grad(): + # Accumulate bn stats.. + self.net_G.module.averaged_model.train() + # Reset running stats. + self.net_G.module.averaged_model.apply(reset_batch_norm) + for cal_it, cal_data in enumerate(data_loader): + if cal_it >= model_average_iteration: + print('Done with {} iterations of updating batch norm ' + 'statistics'.format(model_average_iteration)) + break + # cal_data = to_device(cal_data, 'cuda') + cal_data = self._start_of_iteration(cal_data, 0) + # Averaging over all batches + self.net_G.module.averaged_model.apply( + calibrate_batch_norm_momentum) + self.net_G.module.averaged_model(cal_data) + + def write_metrics(self): + r"""If moving average model presents, we have two meters one for + regular FID and one for average FID. If no moving average model, + we just report average FID. + """ + if self.cfg.trainer.model_average_config.enabled: + regular_fid, average_fid = self._compute_fid() + metric_dict = {'FID/average': average_fid, 'FID/regular': regular_fid} + self._write_to_meters(metric_dict, self.metric_meters, reduce=False) + else: + regular_fid = self._compute_fid() + metric_dict = {'FID/regular': regular_fid} + self._write_to_meters(metric_dict, self.metric_meters, reduce=False) + self._flush_meters(self.metric_meters) + + def _compute_fid(self): + r"""We will compute FID for the regular model using the eval mode. + For the moving average model, we will use the eval mode. 
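`recalculate_batch_norm_statistics` above resets the running statistics of the averaged generator and repopulates them by forwarding a few calibration batches. The same recipe on a plain module (batch count and shapes are arbitrary):

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8), nn.ReLU())

for m in model.modules():                              # wipe the running mean/var
    if isinstance(m, nn.modules.batchnorm._BatchNorm):
        m.reset_running_stats()

model.train()
with torch.no_grad():                                  # stats update without gradients
    for _ in range(20):
        model(torch.randn(8, 3, 32, 32))
model.eval()
```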
+ """ + self.net_G.eval() + net_G_for_evaluation = \ + functools.partial(self.net_G, random_style=True) + regular_fid_path = self._get_save_path('regular_fid', 'npy') + preprocess = \ + functools.partial(self._start_of_iteration, current_iteration=0) + + regular_fid_value = compute_fid(regular_fid_path, + self.val_data_loader, + net_G_for_evaluation, + preprocess=preprocess) + print('Epoch {:05}, Iteration {:09}, Regular FID {}'.format( + self.current_epoch, self.current_iteration, regular_fid_value)) + if self.cfg.trainer.model_average_config.enabled: + avg_net_G_for_evaluation = \ + functools.partial(self.net_G.module.averaged_model, + random_style=True) + fid_path = self._get_save_path('average_fid', 'npy') + fid_value = compute_fid(fid_path, self.val_data_loader, + avg_net_G_for_evaluation, + preprocess=preprocess) + print('Epoch {:05}, Iteration {:09}, FID {}'.format( + self.current_epoch, self.current_iteration, fid_value)) + self.net_G.float() + return regular_fid_value, fid_value + else: + self.net_G.float() + return regular_fid_value + + def _resize_data(self, data): + r"""Resize input label maps and images so that it can be properly + generated by the generator. + + Args: + data (dict): Input dictionary contains 'label' and 'image fields. + """ + base = getattr(self.net_G, 'base', 32) + sy = math.floor(data['label'].size()[2] * 1.0 // base) * base + sx = math.floor(data['label'].size()[3] * 1.0 // base) * base + data['label'] = F.interpolate( + data['label'], size=[sy, sx], mode='nearest') + if 'images' in data.keys(): + data['images'] = F.interpolate( + data['images'], size=[sy, sx], mode='bicubic') + return data diff --git a/imaginaire/trainers/unit.py b/imaginaire/trainers/unit.py new file mode 100644 index 0000000000000000000000000000000000000000..9b3a4e0926acd2affcd96c6a9a09463c5056951c --- /dev/null +++ b/imaginaire/trainers/unit.py @@ -0,0 +1,210 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +import torch +from torch import nn + +from imaginaire.evaluation import compute_fid +from imaginaire.losses import GANLoss, PerceptualLoss # GaussianKLLoss +from imaginaire.trainers.base import BaseTrainer + + +class Trainer(BaseTrainer): + r"""Reimplementation of the UNIT (https://arxiv.org/abs/1703.00848) + algorithm. + + Args: + cfg (obj): Global configuration. + net_G (obj): Generator network. + net_D (obj): Discriminator network. + opt_G (obj): Optimizer for the generator network. + opt_D (obj): Optimizer for the discriminator network. + sch_G (obj): Scheduler for the generator optimizer. + sch_D (obj): Scheduler for the discriminator optimizer. + train_data_loader (obj): Train data loader. + val_data_loader (obj): Validation data loader. + """ + + def __init__(self, cfg, net_G, net_D, opt_G, opt_D, sch_G, sch_D, + train_data_loader, val_data_loader): + super().__init__(cfg, net_G, net_D, opt_G, opt_D, sch_G, sch_D, + train_data_loader, val_data_loader) + self.best_fid_a = None + self.best_fid_b = None + + def _init_loss(self, cfg): + r"""Initialize loss terms. In UNIT, we have several loss terms + including the GAN loss, the image reconstruction loss, the cycle + reconstruction loss, and the gaussian kl loss. We also have an + optional perceptual loss. A user can choose to have the gradient + penalty loss too. + + Args: + cfg (obj): Global configuration. 
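`_resize_data` above snaps the label map and image down to the nearest multiple of the generator's stride (`base`) so every downsampling stage divides evenly, and uses `nearest` interpolation to keep the one-hot labels intact. The same computation on a dummy label map:

```python
import math
import torch
import torch.nn.functional as F

base = 32                                   # overall stride assumed for the generator
label = torch.randn(1, 20, 270, 380)        # arbitrary, non-divisible size

sy = int(math.floor(label.size(2) * 1.0 // base) * base)   # 256
sx = int(math.floor(label.size(3) * 1.0 // base) * base)   # 352
label = F.interpolate(label, size=[sy, sx], mode='nearest')
print(label.shape)                          # torch.Size([1, 20, 256, 352])
```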
+ """ + self.criteria['gan'] = GANLoss(cfg.trainer.gan_mode) + # self.criteria['gaussian_kl'] = GaussianKLLoss() + self.criteria['image_recon'] = nn.L1Loss() + self.criteria['cycle_recon'] = nn.L1Loss() + if getattr(cfg.trainer.loss_weight, 'perceptual', 0) > 0: + self.criteria['perceptual'] = \ + PerceptualLoss(network=cfg.trainer.perceptual_mode, + layers=cfg.trainer.perceptual_layers) + + for loss_name, loss_weight in cfg.trainer.loss_weight.__dict__.items(): + if loss_weight > 0: + self.weights[loss_name] = loss_weight + + def gen_forward(self, data): + r"""Compute the loss for UNIT generator. + + Args: + data (dict): Training data at the current iteration. + """ + cycle_recon = 'cycle_recon' in self.weights + perceptual = 'perceptual' in self.weights + net_G_output = self.net_G(data, cycle_recon=cycle_recon) + net_D_output = self.net_D(data, net_G_output, real=False) + + self._time_before_loss() + + # GAN loss + self.gen_losses['gan_a'] = self.criteria['gan']( + net_D_output['out_ba'], True, dis_update=False) + self.gen_losses['gan_b'] = self.criteria['gan']( + net_D_output['out_ab'], True, dis_update=False) + self.gen_losses['gan'] = \ + self.gen_losses['gan_a'] + self.gen_losses['gan_b'] + + # Perceptual loss + if perceptual: + self.gen_losses['perceptual_a'] = \ + self.criteria['perceptual'](net_G_output['images_ab'], + data['images_a']) + self.gen_losses['perceptual_b'] = \ + self.criteria['perceptual'](net_G_output['images_ba'], + data['images_b']) + self.gen_losses['perceptual'] = \ + self.gen_losses['perceptual_a'] + \ + self.gen_losses['perceptual_b'] + + # Image reconstruction loss + self.gen_losses['image_recon'] = \ + self.criteria['image_recon'](net_G_output['images_aa'], + data['images_a']) + \ + self.criteria['image_recon'](net_G_output['images_bb'], + data['images_b']) + + """ + # KL loss + self.gen_losses['gaussian_kl'] = \ + self.criteria['gaussian_kl'](net_G_output['content_mu_a']) + \ + self.criteria['gaussian_kl'](net_G_output['content_mu_b']) + \ + self.criteria['gaussian_kl'](net_G_output['content_mu_a_recon']) + \ + self.criteria['gaussian_kl'](net_G_output['content_mu_b_recon']) + """ + + # Cycle reconstruction loss + if cycle_recon: + self.gen_losses['cycle_recon_aba'] = \ + self.criteria['cycle_recon'](net_G_output['images_aba'], + data['images_a']) + self.gen_losses['cycle_recon_bab'] = \ + self.criteria['cycle_recon'](net_G_output['images_bab'], + data['images_b']) + self.gen_losses['cycle_recon'] = \ + self.gen_losses['cycle_recon_aba'] + \ + self.gen_losses['cycle_recon_bab'] + + # Compute total loss + total_loss = self._get_total_loss(gen_forward=True) + return total_loss + + def dis_forward(self, data): + r"""Compute the loss for UNIT discriminator. + + Args: + data (dict): Training data at the current iteration. + """ + with torch.no_grad(): + net_G_output = self.net_G(data, image_recon=False, + cycle_recon=False) + net_G_output['images_ba'].requires_grad = True + net_G_output['images_ab'].requires_grad = True + net_D_output = self.net_D(data, net_G_output) + + self._time_before_loss() + + # GAN loss. 
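+        # Each domain's discriminator is pushed to score real images as real
+        # (True) and translated images (out_ba, out_ab) as fake (False); the
+        # concrete objective (hinge, least squares, ...) is whatever
+        # cfg.trainer.gan_mode selected for GANLoss in _init_loss above.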
+ self.dis_losses['gan_a'] = \ + self.criteria['gan'](net_D_output['out_a'], True) + \ + self.criteria['gan'](net_D_output['out_ba'], False) + self.dis_losses['gan_b'] = \ + self.criteria['gan'](net_D_output['out_b'], True) + \ + self.criteria['gan'](net_D_output['out_ab'], False) + self.dis_losses['gan'] = \ + self.dis_losses['gan_a'] + self.dis_losses['gan_b'] + + # Compute total loss + total_loss = self._get_total_loss(gen_forward=False) + return total_loss + + def _get_visualizations(self, data): + r"""Compute visualization image. + + Args: + data (dict): The current batch. + """ + if self.cfg.trainer.model_average_config.enabled: + net_G_for_evaluation = self.net_G.module.averaged_model + else: + net_G_for_evaluation = self.net_G + with torch.no_grad(): + net_G_output = net_G_for_evaluation(data) + vis_images = [data['images_a'], + data['images_b'], + net_G_output['images_aa'], + net_G_output['images_bb'], + net_G_output['images_ab'], + net_G_output['images_ba'], + net_G_output['images_aba'], + net_G_output['images_bab']] + return vis_images + + def write_metrics(self): + r"""Compute metrics and save them to tensorboard""" + cur_fid_a, cur_fid_b = self._compute_fid() + if self.best_fid_a is not None: + self.best_fid_a = min(self.best_fid_a, cur_fid_a) + else: + self.best_fid_a = cur_fid_a + if self.best_fid_b is not None: + self.best_fid_b = min(self.best_fid_b, cur_fid_b) + else: + self.best_fid_b = cur_fid_b + self._write_to_meters({'FID_a': cur_fid_a, + 'best_FID_a': self.best_fid_a, + 'FID_b': cur_fid_b, + 'best_FID_b': self.best_fid_b}, + self.metric_meters) + self._flush_meters(self.metric_meters) + + def _compute_fid(self): + r"""Compute FID for both domains. + """ + self.net_G.eval() + if self.cfg.trainer.model_average_config.enabled: + net_G_for_evaluation = self.net_G.module.averaged_model + else: + net_G_for_evaluation = self.net_G + fid_a_path = self._get_save_path('fid_a', 'npy') + fid_b_path = self._get_save_path('fid_b', 'npy') + fid_value_a = compute_fid(fid_a_path, self.val_data_loader, + net_G_for_evaluation, 'images_a', 'images_ba') + fid_value_b = compute_fid(fid_b_path, self.val_data_loader, + net_G_for_evaluation, 'images_b', 'images_ab') + print('Epoch {:05}, Iteration {:09}, FID a {}, FID b {}'.format( + self.current_epoch, self.current_iteration, + fid_value_a, fid_value_b)) + return fid_value_a, fid_value_b diff --git a/imaginaire/trainers/vid2vid.py b/imaginaire/trainers/vid2vid.py new file mode 100644 index 0000000000000000000000000000000000000000..cd6b46fa64f6738c21636f785470e25120324b57 --- /dev/null +++ b/imaginaire/trainers/vid2vid.py @@ -0,0 +1,913 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. 
+# To view a copy of this license, check out LICENSE.md +import os + +from torch.cuda.amp import autocast +import imageio +import numpy as np +import torch +from tqdm import tqdm + +from imaginaire.evaluation.fid import compute_fid +from imaginaire.losses import (FeatureMatchingLoss, FlowLoss, GANLoss, + PerceptualLoss) +from imaginaire.model_utils.fs_vid2vid import (concat_frames, detach, + get_fg_mask, + pre_process_densepose, resample) +from imaginaire.trainers.base import BaseTrainer +from imaginaire.utils.distributed import is_master +from imaginaire.utils.distributed import master_only_print as print +from imaginaire.utils.misc import get_nested_attr, split_labels, to_cuda +from imaginaire.utils.visualization import (tensor2flow, tensor2im, tensor2label) +from imaginaire.utils.visualization.pose import tensor2pose + + +class Trainer(BaseTrainer): + r"""Initialize vid2vid trainer. + + Args: + cfg (obj): Global configuration. + net_G (obj): Generator network. + net_D (obj): Discriminator network. + opt_G (obj): Optimizer for the generator network. + opt_D (obj): Optimizer for the discriminator network. + sch_G (obj): Scheduler for the generator optimizer. + sch_D (obj): Scheduler for the discriminator optimizer. + train_data_loader (obj): Train data loader. + val_data_loader (obj): Validation data loader. + """ + + def __init__(self, cfg, net_G, net_D, opt_G, opt_D, sch_G, sch_D, + train_data_loader, val_data_loader): + super(Trainer, self).__init__(cfg, net_G, net_D, opt_G, + opt_D, sch_G, sch_D, + train_data_loader, val_data_loader) + # Below is for testing setting, the FID computation during training + # is just for getting a quick idea of the performance. It does not + # equal to the final performance evaluation. + # Below, we will determine how many videos that we want to do + # evaluation, and the length of each video. + # It is better to keep the number of videos to be multiple of 8 so + # that all the GPUs in a node will contribute equally to the + # evaluation. None of them is idol. + self.sample_size = ( + getattr(cfg.trainer, 'num_videos_to_test', 64), + getattr(cfg.trainer, 'num_frames_per_video', 10) + ) + + self.sequence_length = 1 + if not self.is_inference: + self.train_dataset = self.train_data_loader.dataset + self.sequence_length_max = \ + min(getattr(cfg.data.train, 'max_sequence_length', 100), + self.train_dataset.sequence_length_max) + self.Tensor = torch.cuda.FloatTensor + self.has_fg = getattr(cfg.data, 'has_foreground', False) + + self.net_G_output = self.data_prev = None + self.net_G_module = self.net_G.module + if self.cfg.trainer.model_average_config.enabled: + self.net_G_module = self.net_G_module.module + + def _assign_criteria(self, name, criterion, weight): + r"""Assign training loss terms. + + Args: + name (str): Loss name + criterion (obj): Loss object. + weight (float): Loss weight. It should be non-negative. + """ + self.criteria[name] = criterion + self.weights[name] = weight + + def _init_loss(self, cfg): + r"""Initialize training loss terms. In vid2vid, in addition to + the GAN loss, feature matching loss, and perceptual loss used in + pix2pixHD, we also add temporal GAN (and feature matching) loss, + and flow warping loss. Optionally, we can also add an additional + face discriminator for the face region. + + Args: + cfg (obj): Global configuration. + """ + self.criteria = dict() + self.weights = dict() + trainer_cfg = cfg.trainer + loss_weight = cfg.trainer.loss_weight + + # GAN loss and feature matching loss. 
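+        # _assign_criteria registers each criterion and its weight under the
+        # same key, so every term added below contributes
+        # weight * loss to the total computed in get_gen_losses / get_dis_losses.
+        # For example, a hypothetical config with
+        #   trainer.loss_weight.gan: 1.0
+        #   trainer.loss_weight.feature_matching: 10.0
+        # would weight feature matching ten times as strongly as the GAN term.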
+ self._assign_criteria('GAN', + GANLoss(trainer_cfg.gan_mode), + loss_weight.gan) + self._assign_criteria('FeatureMatching', + FeatureMatchingLoss(), + loss_weight.feature_matching) + + # Perceptual loss. + perceptual_loss = cfg.trainer.perceptual_loss + self._assign_criteria('Perceptual', + PerceptualLoss( + network=perceptual_loss.mode, + layers=perceptual_loss.layers, + weights=perceptual_loss.weights, + num_scales=getattr(perceptual_loss, + 'num_scales', 1)), + loss_weight.perceptual) + + # L1 Loss. + if getattr(loss_weight, 'L1', 0) > 0: + self._assign_criteria('L1', torch.nn.L1Loss(), loss_weight.L1) + + # Whether to add an additional discriminator for specific regions. + self.add_dis_cfg = getattr(self.cfg.dis, 'additional_discriminators', + None) + if self.add_dis_cfg is not None: + for name in self.add_dis_cfg: + add_dis_cfg = self.add_dis_cfg[name] + self.weights['GAN_' + name] = add_dis_cfg.loss_weight + self.weights['FeatureMatching_' + name] = \ + loss_weight.feature_matching + + # Temporal GAN loss. + self.num_temporal_scales = get_nested_attr(self.cfg.dis, + 'temporal.num_scales', 0) + for s in range(self.num_temporal_scales): + self.weights['GAN_T%d' % s] = loss_weight.temporal_gan + self.weights['FeatureMatching_T%d' % s] = \ + loss_weight.feature_matching + + # Flow loss. It consists of three parts: L1 loss compared to GT, + # warping loss when used to warp images, and loss on the occlusion mask. + self.use_flow = hasattr(cfg.gen, 'flow') + if self.use_flow: + self.criteria['Flow'] = FlowLoss(cfg) + self.weights['Flow'] = self.weights['Flow_L1'] = \ + self.weights['Flow_Warp'] = \ + self.weights['Flow_Mask'] = loss_weight.flow + + # Other custom losses. + self._define_custom_losses() + + def _define_custom_losses(self): + r"""All other custom losses are defined here.""" + pass + + def _start_of_epoch(self, current_epoch): + r"""Things to do before an epoch. When current_epoch is smaller than + $(single_frame_epoch), we only train a single frame and the generator is + just an image generator. After that, we start doing temporal training + and train multiple frames. We will double the number of training frames + every $(num_epochs_temporal_step) epochs. + + Args: + current_epoch (int): Current number of epoch. + """ + cfg = self.cfg + # Only generates one frame at the beginning of training + if current_epoch < cfg.single_frame_epoch: + self.train_dataset.sequence_length = 1 + # Then add the temporal network to generator, and train multiple frames. + elif current_epoch == cfg.single_frame_epoch: + self.init_temporal_network() + + # Double the length of training sequence every few epochs. + temp_epoch = current_epoch - cfg.single_frame_epoch + if temp_epoch > 0: + sequence_length = \ + cfg.data.train.initial_sequence_length * \ + (2 ** (temp_epoch // cfg.num_epochs_temporal_step)) + sequence_length = min(sequence_length, self.sequence_length_max) + if sequence_length > self.sequence_length: + self.sequence_length = sequence_length + self.train_dataset.set_sequence_length(sequence_length) + print('------- Updating sequence length to %d -------' % + sequence_length) + + def init_temporal_network(self): + r"""Initialize temporal training when beginning to train multiple + frames. Set the sequence length to $(initial_sequence_length). + """ + self.tensorboard_init = False + # Update training sequence length. 
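+        # Temporal training starts from cfg.data.train.initial_sequence_length;
+        # _start_of_epoch then doubles the trained sequence length every
+        # num_epochs_temporal_step epochs, capped at sequence_length_max.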
+ self.sequence_length = self.cfg.data.train.initial_sequence_length + if not self.is_inference: + self.train_dataset.set_sequence_length(self.sequence_length) + print('------ Now start training %d frames -------' % + self.sequence_length) + + def _start_of_iteration(self, data, current_iteration): + r"""Things to do before an iteration. + + Args: + data (dict): Data used for the current iteration. + current_iteration (int): Current number of iteration. + """ + data = self.pre_process(data) + return to_cuda(data) + + def pre_process(self, data): + r"""Do any data pre-processing here. + + Args: + data (dict): Data used for the current iteration. + """ + data_cfg = self.cfg.data + if hasattr(data_cfg, 'for_pose_dataset') and \ + ('pose_maps-densepose' in data_cfg.input_labels): + pose_cfg = data_cfg.for_pose_dataset + data['label'] = pre_process_densepose(pose_cfg, data['label'], + self.is_inference) + return data + + def post_process(self, data, net_G_output): + r"""Do any postprocessing of the data / output here. + + Args: + data (dict): Training data at the current iteration. + net_G_output (dict): Output of the generator. + """ + return data, net_G_output + + def gen_update(self, data): + r"""Update the vid2vid generator. We update in the fashion of + dis_update (frame 1), gen_update (frame 1), + dis_update (frame 2), gen_update (frame 2), ... in each iteration. + + Args: + data (dict): Training data at the current iteration. + """ + # Whether to reuse generator output for both gen_update and + # dis_update. It saves time but consumes a bit more memory. + reuse_gen_output = getattr(self.cfg.trainer, 'reuse_gen_output', True) + + past_frames = [None, None] + net_G_output = None + data_prev = None + for t in range(self.sequence_length): + data_t = self.get_data_t(data, net_G_output, data_prev, t) + data_prev = data_t + + # Discriminator update. + if reuse_gen_output: + net_G_output = self.net_G(data_t) + else: + with torch.no_grad(): + net_G_output = self.net_G(data_t) + data_t, net_G_output = self.post_process(data_t, net_G_output) + + # Get losses and update D if image generated by network in training. + if 'fake_images_source' not in net_G_output: + net_G_output['fake_images_source'] = 'in_training' + if net_G_output['fake_images_source'] != 'pretrained': + net_D_output, _ = self.net_D(data_t, detach(net_G_output), past_frames) + self.get_dis_losses(net_D_output) + + # Generator update. + if not reuse_gen_output: + net_G_output = self.net_G(data_t) + data_t, net_G_output = self.post_process(data_t, net_G_output) + + # Get losses and update G if image generated by network in training. + if 'fake_images_source' not in net_G_output: + net_G_output['fake_images_source'] = 'in_training' + if net_G_output['fake_images_source'] != 'pretrained': + net_D_output, past_frames = \ + self.net_D(data_t, net_G_output, past_frames) + self.get_gen_losses(data_t, net_G_output, net_D_output) + + # update average + if self.cfg.trainer.model_average_config.enabled: + self.net_G.module.update_average() + + def dis_update(self, data): + r"""The update is already done in gen_update. + + Args: + data (dict): Training data at the current iteration. + """ + pass + + def reset(self): + r"""Reset the trainer (for inference) at the beginning of a sequence. 
+ """ + # print('Resetting trainer.') + self.net_G_output = self.data_prev = None + self.t = 0 + + self.test_in_model_average_mode = getattr( + self, 'test_in_model_average_mode', self.cfg.trainer.model_average_config.enabled) + if self.test_in_model_average_mode: + net_G_module = self.net_G.module.averaged_model + else: + net_G_module = self.net_G.module + if hasattr(net_G_module, 'reset'): + net_G_module.reset() + + def create_sequence_output_dir(self, output_dir, key): + r"""Create output subdir for this sequence. + + Args: + output_dir (str): Root output dir. + key (str): LMDB key which contains sequence name and file name. + Returns: + output_dir (str): Output subdir for this sequence. + seq_name (str): Name of this sequence. + """ + seq_dir = '/'.join(key.split('/')[:-1]) + output_dir = os.path.join(output_dir, seq_dir) + os.makedirs(output_dir, exist_ok=True) + seq_name = seq_dir.replace('/', '-') + return output_dir, seq_name + + def test(self, test_data_loader, root_output_dir, inference_args): + r"""Run inference on all sequences. + + Args: + test_data_loader (object): Test data loader. + root_output_dir (str): Location to dump outputs. + inference_args (optional): Optional args. + """ + + # Go over all sequences. + loader = test_data_loader + num_inference_sequences = loader.dataset.num_inference_sequences() + for sequence_idx in range(num_inference_sequences): + loader.dataset.set_inference_sequence_idx(sequence_idx) + print('Seq id: %d, Seq length: %d' % + (sequence_idx + 1, len(loader))) + + # Reset model at start of new inference sequence. + self.reset() + self.sequence_length = len(loader) + + # Go over all frames of this sequence. + video = [] + for idx, data in enumerate(tqdm(loader)): + key = data['key']['images'][0][0] + filename = key.split('/')[-1] + + # Create output dir for this sequence. + if idx == 0: + output_dir, seq_name = \ + self.create_sequence_output_dir(root_output_dir, key) + video_path = os.path.join(output_dir, '..', seq_name) + + # Get output and save images. + data['img_name'] = filename + data = self.start_of_iteration(data, current_iteration=-1) + output = self.test_single(data, output_dir, inference_args) + video.append(output) + + # Save output as mp4. + imageio.mimsave(video_path + '.mp4', video, fps=15) + + def test_single(self, data, output_dir=None, inference_args=None): + r"""The inference function. If output_dir exists, also save the + output image. + Args: + data (dict): Training data at the current iteration. + output_dir (str): Save image directory. + inference_args (obj): Inference args. + """ + if getattr(inference_args, 'finetune', False): + if not getattr(self, 'has_finetuned', False): + self.finetune(data, inference_args) + + net_G = self.net_G + if self.test_in_model_average_mode: + net_G = net_G.module.averaged_model + net_G.eval() + + data_t = self.get_data_t(data, self.net_G_output, self.data_prev, 0) + if self.is_inference or self.sequence_length > 1: + self.data_prev = data_t + + # Generator forward. 
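+        # Inference runs frame by frame: get_data_t above already packed the
+        # previously generated frames into data_t['prev_images'], so a single
+        # forward pass below produces the next frame conditioned on them.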
+ with torch.no_grad(): + self.net_G_output = net_G(data_t) + + if output_dir is None: + return self.net_G_output + + save_fake_only = getattr(inference_args, 'save_fake_only', False) + if save_fake_only: + image_grid = tensor2im(self.net_G_output['fake_images'])[0] + else: + vis_images = self.get_test_output_images(data) + image_grid = np.hstack([np.vstack(im) for im in + vis_images if im is not None]) + if 'img_name' in data: + save_name = data['img_name'].split('.')[0] + '.jpg' + else: + save_name = '%04d.jpg' % self.t + output_filename = os.path.join(output_dir, save_name) + os.makedirs(output_dir, exist_ok=True) + imageio.imwrite(output_filename, image_grid) + self.t += 1 + + return image_grid + + def get_test_output_images(self, data): + r"""Get the visualization output of test function. + + Args: + data (dict): Training data at the current iteration. + """ + vis_images = [ + self.visualize_label(data['label'][:, -1]), + tensor2im(data['images'][:, -1]), + tensor2im(self.net_G_output['fake_images']), + ] + return vis_images + + def gen_frames(self, data, use_model_average=False): + r"""Generate a sequence of frames given a sequence of data. + + Args: + data (dict): Training data at the current iteration. + use_model_average (bool): Whether to use model average + for update or not. + """ + net_G_output = None # Previous generator output. + data_prev = None # Previous data. + if use_model_average: + net_G = self.net_G.module.averaged_model + else: + net_G = self.net_G + + # Iterate through the length of sequence. + all_info = {'inputs': [], 'outputs': []} + for t in range(self.sequence_length): + # Get the data at the current time frame. + data_t = self.get_data_t(data, net_G_output, data_prev, t) + data_prev = data_t + + # Generator forward. + with torch.no_grad(): + net_G_output = net_G(data_t) + + # Do any postprocessing if necessary. + data_t, net_G_output = self.post_process(data_t, net_G_output) + + if t == 0: + # Get the output at beginning of sequence for visualization. + first_net_G_output = net_G_output + + all_info['inputs'].append(data_t) + all_info['outputs'].append(net_G_output) + + return first_net_G_output, net_G_output, all_info + + def get_gen_losses(self, data_t, net_G_output, net_D_output): + r"""Compute generator losses. + + Args: + data_t (dict): Training data at the current time t. + net_G_output (dict): Output of the generator. + net_D_output (dict): Output of the discriminator. + """ + update_finished = False + while not update_finished: + with autocast(enabled=self.cfg.trainer.amp_config.enabled): + # Individual frame GAN loss and feature matching loss. + self.gen_losses['GAN'], self.gen_losses['FeatureMatching'] = \ + self.compute_gan_losses(net_D_output['indv'], + dis_update=False) + + # Perceptual loss. + self.gen_losses['Perceptual'] = self.criteria['Perceptual']( + net_G_output['fake_images'], data_t['image']) + + # L1 loss. + if getattr(self.cfg.trainer.loss_weight, 'L1', 0) > 0: + self.gen_losses['L1'] = self.criteria['L1']( + net_G_output['fake_images'], data_t['image']) + + # Raw (hallucinated) output image losses (GAN and perceptual). 
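+                # 'fake_raw_images' is the directly synthesized frame before it
+                # is blended with the warped previous output; when the dataset
+                # marks a foreground (has_fg), its perceptual loss is restricted
+                # to the foreground mask below.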
+ if 'raw' in net_D_output: + raw_GAN_losses = self.compute_gan_losses( + net_D_output['raw'], dis_update=False + ) + fg_mask = get_fg_mask(data_t['label'], self.has_fg) + raw_perceptual_loss = self.criteria['Perceptual']( + net_G_output['fake_raw_images'] * fg_mask, + data_t['image'] * fg_mask) + self.gen_losses['GAN'] += raw_GAN_losses[0] + self.gen_losses['FeatureMatching'] += raw_GAN_losses[1] + self.gen_losses['Perceptual'] += raw_perceptual_loss + + # Additional discriminator losses. + if self.add_dis_cfg is not None: + for name in self.add_dis_cfg: + (self.gen_losses['GAN_' + name], + self.gen_losses['FeatureMatching_' + name]) = \ + self.compute_gan_losses(net_D_output[name], + dis_update=False) + + # Flow and mask loss. + if self.use_flow: + (self.gen_losses['Flow_L1'], self.gen_losses['Flow_Warp'], + self.gen_losses['Flow_Mask']) = self.criteria['Flow']( + data_t, net_G_output, self.current_epoch) + + # Temporal GAN loss and feature matching loss. + if self.cfg.trainer.loss_weight.temporal_gan > 0: + if self.sequence_length > 1: + for s in range(self.num_temporal_scales): + loss_GAN, loss_FM = self.compute_gan_losses( + net_D_output['temporal_%d' % s], + dis_update=False + ) + self.gen_losses['GAN_T%d' % s] = loss_GAN + self.gen_losses['FeatureMatching_T%d' % s] = loss_FM + + # Other custom losses. + self._get_custom_gen_losses(data_t, net_G_output, net_D_output) + + # Sum all losses together. + total_loss = self.Tensor(1).fill_(0) + for key in self.gen_losses: + if key != 'total': + total_loss += self.gen_losses[key] * self.weights[key] + self.gen_losses['total'] = total_loss + + # Zero-grad and backpropagate the loss. + self.opt_G.zero_grad(set_to_none=True) + self.scaler_G.scale(total_loss).backward() + + # Optionally clip gradient norm. + if hasattr(self.cfg.gen_opt, 'clip_grad_norm'): + self.scaler_G.unscale_(self.opt_G) + total_norm = torch.nn.utils.clip_grad_norm_( + self.net_G_module.parameters(), + self.cfg.gen_opt.clip_grad_norm + ) + if torch.isfinite(total_norm) and \ + total_norm > self.cfg.gen_opt.clip_grad_norm: + print(f"Gradient norm of the generator ({total_norm}) " + f"too large, clipping it to " + f"{self.cfg.gen_opt.clip_grad_norm}.") + + # Perform an optimizer step. + self.scaler_G.step(self.opt_G) + self.scaler_G.update() + # Whether the step above was skipped. + if self.last_step_count_G == self.opt_G._step_count: + print("Generator overflowed!") + else: + self.last_step_count_G = self.opt_G._step_count + update_finished = True + + def _get_custom_gen_losses(self, data_t, net_G_output, net_D_output): + r"""All other custom generator losses go here. + + Args: + data_t (dict): Training data at the current time t. + net_G_output (dict): Output of the generator. + net_D_output (dict): Output of the discriminator. + """ + pass + + def get_dis_losses(self, net_D_output): + r"""Compute discriminator losses. + + Args: + net_D_output (dict): Output of the discriminator. + """ + update_finished = False + while not update_finished: + with autocast(enabled=self.cfg.trainer.amp_config.enabled): + # Individual frame GAN loss. + self.dis_losses['GAN'] = self.compute_gan_losses( + net_D_output['indv'], dis_update=True + ) + + # Raw (hallucinated) output image GAN loss. + if 'raw' in net_D_output: + raw_loss = self.compute_gan_losses(net_D_output['raw'], + dis_update=True) + self.dis_losses['GAN'] += raw_loss + + # Additional GAN loss. 
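+                # Mirrors the generator side: every named discriminator listed
+                # in cfg.dis.additional_discriminators (e.g. a face
+                # discriminator) contributes its own GAN term here.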
+ if self.add_dis_cfg is not None: + for name in self.add_dis_cfg: + self.dis_losses['GAN_' + name] = \ + self.compute_gan_losses(net_D_output[name], + dis_update=True) + + # Temporal GAN loss. + if self.cfg.trainer.loss_weight.temporal_gan > 0: + if self.sequence_length > 1: + for s in range(self.num_temporal_scales): + self.dis_losses['GAN_T%d' % s] = \ + self.compute_gan_losses( + net_D_output['temporal_%d' % s], + dis_update=True + ) + + # Other custom losses. + self._get_custom_dis_losses(net_D_output) + + # Sum all losses together. + total_loss = self.Tensor(1).fill_(0) + for key in self.dis_losses: + if key != 'total': + total_loss += self.dis_losses[key] * self.weights[key] + self.dis_losses['total'] = total_loss + + # Zero-grad and backpropagate the loss. + self.opt_D.zero_grad(set_to_none=True) + self._time_before_backward() + self.scaler_D.scale(total_loss).backward() + + # Optionally clip gradient norm. + if hasattr(self.cfg.dis_opt, 'clip_grad_norm'): + self.scaler_D.unscale_(self.opt_D) + total_norm = torch.nn.utils.clip_grad_norm_( + self.net_D.parameters(), self.cfg.dis_opt.clip_grad_norm + ) + if torch.isfinite(total_norm) and \ + total_norm > self.cfg.dis_opt.clip_grad_norm: + print(f"Gradient norm of the discriminator ({total_norm}) " + f"too large, clipping it to " + f"{self.cfg.dis_opt.clip_grad_norm}.") + + # Perform an optimizer step. + self._time_before_step() + self.scaler_D.step(self.opt_D) + self.scaler_D.update() + # Whether the step above was skipped. + if self.last_step_count_D == self.opt_D._step_count: + print("Discriminator overflowed!") + else: + self.last_step_count_D = self.opt_D._step_count + update_finished = True + + def _get_custom_dis_losses(self, net_D_output): + r"""All other custom losses go here. + + Args: + net_D_output (dict): Output of the discriminator. + """ + pass + + def compute_gan_losses(self, net_D_output, dis_update): + r"""Compute GAN loss and feature matching loss. + + Args: + net_D_output (dict): Output of the discriminator. + dis_update (bool): Whether to update discriminator. + """ + if net_D_output['pred_fake'] is None: + return self.Tensor(1).fill_(0) if dis_update else [ + self.Tensor(1).fill_(0), self.Tensor(1).fill_(0)] + if dis_update: + # Get the GAN loss for real/fake outputs. + GAN_loss = \ + self.criteria['GAN'](net_D_output['pred_fake']['output'], False, + dis_update=True) + \ + self.criteria['GAN'](net_D_output['pred_real']['output'], True, + dis_update=True) + return GAN_loss + else: + # Get the GAN loss and feature matching loss for fake output. + GAN_loss = self.criteria['GAN']( + net_D_output['pred_fake']['output'], True, dis_update=False) + + FM_loss = self.criteria['FeatureMatching']( + net_D_output['pred_fake']['features'], + net_D_output['pred_real']['features']) + return GAN_loss, FM_loss + + def get_data_t(self, data, net_G_output, data_prev, t): + r"""Get data at current time frame given the sequence of data. + + Args: + data (dict): Training data for current iteration. + net_G_output (dict): Output of the generator (for previous frame). + data_prev (dict): Data for previous frame. + t (int): Current time. + """ + label = data['label'][:, t] + image = data['images'][:, t] + + if data_prev is not None: + # Concat previous labels/fake images to the ones before. 
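+            # concat_frames keeps a rolling window of at most num_frames_G - 1
+            # past labels / generated frames, so the generator is always
+            # conditioned on a fixed-length history.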
+ num_frames_G = self.cfg.data.num_frames_G + prev_labels = concat_frames(data_prev['prev_labels'], + data_prev['label'], num_frames_G - 1) + prev_images = concat_frames( + data_prev['prev_images'], + net_G_output['fake_images'].detach(), num_frames_G - 1) + else: + prev_labels = prev_images = None + + data_t = dict() + data_t['label'] = label + data_t['image'] = image + data_t['prev_labels'] = prev_labels + data_t['prev_images'] = prev_images + data_t['real_prev_image'] = data['images'][:, t - 1] if t > 0 else None + return data_t + + def _end_of_iteration(self, data, current_epoch, current_iteration): + r"""Print the errors to console.""" + if not torch.distributed.is_initialized(): + if current_iteration % self.cfg.logging_iter == 0: + message = '(epoch: %d, iters: %d) ' % (current_epoch, + current_iteration) + for k, v in self.gen_losses.items(): + if k != 'total': + message += '%s: %.3f, ' % (k, v) + message += '\n' + for k, v in self.dis_losses.items(): + if k != 'total': + message += '%s: %.3f, ' % (k, v) + print(message) + + def write_metrics(self): + r"""If moving average model presents, we have two meters one for + regular FID and one for average FID. If no moving average model, + we just report average FID. + """ + if self.cfg.trainer.model_average_config.enabled: + regular_fid, average_fid = self._compute_fid() + if regular_fid is None or average_fid is None: + return + metric_dict = {'FID/average': average_fid, 'FID/regular': regular_fid} + self._write_to_meters(metric_dict, self.metric_meters, reduce=False) + else: + regular_fid = self._compute_fid() + if regular_fid is None: + return + metric_dict = {'FID/regular': regular_fid} + self._write_to_meters(metric_dict, self.metric_meters, reduce=False) + self._flush_meters(self.metric_meters) + + def _compute_fid(self): + r"""Compute FID values.""" + self.net_G.eval() + self.net_G_output = None + # Due to complicated video evaluation procedure we are using, we will + # pass the trainer to the evaluation code instead of the + # generator network. + # net_G_for_evaluation = self.net_G + trainer = self + self.test_in_model_average_mode = False + regular_fid_path = self._get_save_path('regular_fid', 'npy') + few_shot = True if 'few_shot' in self.cfg.data.type else False + regular_fid_value = compute_fid(regular_fid_path, self.val_data_loader, + trainer, + sample_size=self.sample_size, + is_video=True, few_shot_video=few_shot) + print('Epoch {:05}, Iteration {:09}, Regular FID {}'.format( + self.current_epoch, self.current_iteration, regular_fid_value)) + if self.cfg.trainer.model_average_config.enabled: + # Due to complicated video evaluation procedure we are using, + # we will pass the trainer to the evaluation code instead of the + # generator network. + # avg_net_G_for_evaluation = self.net_G.module.averaged_model + trainer_avg_mode = self + self.test_in_model_average_mode = True + # The above flag will be reset after computing FID. + fid_path = self._get_save_path('average_fid', 'npy') + few_shot = True if 'few_shot' in self.cfg.data.type else False + fid_value = compute_fid(fid_path, self.val_data_loader, + trainer_avg_mode, + sample_size=self.sample_size, + is_video=True, few_shot_video=few_shot) + print('Epoch {:05}, Iteration {:09}, Average FID {}'.format( + self.current_epoch, self.current_iteration, fid_value)) + self.net_G.float() + return regular_fid_value, fid_value + else: + self.net_G.float() + return regular_fid_value + + def visualize_label(self, label): + r"""Visualize the input label when saving to image. 
+ + Args: + label (tensor): Input label tensor. + """ + cfgdata = self.cfg.data + if hasattr(cfgdata, 'for_pose_dataset'): + label = tensor2pose(self.cfg, label) + elif hasattr(cfgdata, 'input_labels') and \ + 'seg_maps' in cfgdata.input_labels: + for input_type in cfgdata.input_types: + if 'seg_maps' in input_type: + num_labels = cfgdata.one_hot_num_classes.seg_maps + label = tensor2label(label, num_labels) + elif getattr(cfgdata, 'label_channels', 1) > 3: + label = tensor2im(label.sum(1, keepdim=True)) + else: + label = tensor2im(label) + return label + + def save_image(self, path, data): + r"""Save the output images to path. + Note when the generate_raw_output is FALSE. Then, + first_net_G_output['fake_raw_images'] is None and will not be displayed. + In model average mode, we will plot the flow visualization twice. + Args: + path (str): Save path. + data (dict): Training data for current iteration. + """ + self.net_G.eval() + if self.cfg.trainer.model_average_config.enabled: + self.net_G.module.averaged_model.eval() + self.net_G_output = None + with torch.no_grad(): + first_net_G_output, net_G_output, all_info = self.gen_frames(data) + if self.cfg.trainer.model_average_config.enabled: + first_net_G_output_avg, net_G_output_avg, _ = self.gen_frames( + data, use_model_average=True) + + # Visualize labels. + label_lengths = self.train_data_loader.dataset.get_label_lengths() + labels = split_labels(data['label'], label_lengths) + vis_labels_start, vis_labels_end = [], [] + for key, value in labels.items(): + if key == 'seg_maps': + vis_labels_start.append(self.visualize_label(value[:, -1])) + vis_labels_end.append(self.visualize_label(value[:, 0])) + else: + vis_labels_start.append(tensor2im(value[:, -1])) + vis_labels_end.append(tensor2im(value[:, 0])) + + if is_master(): + vis_images = [ + *vis_labels_start, + tensor2im(data['images'][:, -1]), + tensor2im(net_G_output['fake_images']), + tensor2im(net_G_output['fake_raw_images'])] + if self.cfg.trainer.model_average_config.enabled: + vis_images += [ + tensor2im(net_G_output_avg['fake_images']), + tensor2im(net_G_output_avg['fake_raw_images'])] + + if self.sequence_length > 1: + vis_images_first = [ + *vis_labels_end, + tensor2im(data['images'][:, 0]), + tensor2im(first_net_G_output['fake_images']), + tensor2im(first_net_G_output['fake_raw_images']) + ] + if self.cfg.trainer.model_average_config.enabled: + vis_images_first += [ + tensor2im(first_net_G_output_avg['fake_images']), + tensor2im(first_net_G_output_avg['fake_raw_images'])] + + if self.use_flow: + flow_gt, conf_gt = self.criteria['Flow'].flowNet( + data['images'][:, -1], data['images'][:, -2]) + warped_image_gt = resample(data['images'][:, -1], flow_gt) + vis_images_first += [ + tensor2flow(flow_gt), + tensor2im(conf_gt, normalize=False), + tensor2im(warped_image_gt), + ] + vis_images += [ + tensor2flow(net_G_output['fake_flow_maps']), + tensor2im(net_G_output['fake_occlusion_masks'], + normalize=False), + tensor2im(net_G_output['warped_images']), + ] + if self.cfg.trainer.model_average_config.enabled: + vis_images_first += [ + tensor2flow(flow_gt), + tensor2im(conf_gt, normalize=False), + tensor2im(warped_image_gt), + ] + vis_images += [ + tensor2flow(net_G_output_avg['fake_flow_maps']), + tensor2im(net_G_output_avg['fake_occlusion_masks'], + normalize=False), + tensor2im(net_G_output_avg['warped_images'])] + + vis_images = [[np.vstack((im_first, im)) + for im_first, im in zip(imgs_first, imgs)] + for imgs_first, imgs in zip(vis_images_first, + vis_images) + if imgs is not None] 
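+            # At this point every visualization entry holds the first and the
+            # last frame stacked vertically for each sample in the batch.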
+ + image_grid = np.hstack([np.vstack(im) for im in + vis_images if im is not None]) + + print('Save output images to {}'.format(path)) + os.makedirs(os.path.dirname(path), exist_ok=True) + imageio.imwrite(path, image_grid) + + # Gather all outputs for dumping into video. + if self.sequence_length > 1: + output_images = [] + for item in all_info['outputs']: + output_images.append(tensor2im(item['fake_images'])[0]) + + imageio.mimwrite(os.path.splitext(path)[0] + '.mp4', + output_images, fps=2, macro_block_size=None) + + self.net_G.float() diff --git a/imaginaire/trainers/wc_vid2vid.py b/imaginaire/trainers/wc_vid2vid.py new file mode 100644 index 0000000000000000000000000000000000000000..78b3007cd1767ab3df3b43524747d858b1d3c454 --- /dev/null +++ b/imaginaire/trainers/wc_vid2vid.py @@ -0,0 +1,503 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +import os +import time + +import imageio +import numpy as np +import torch +from tqdm import tqdm + +from imaginaire.losses import MaskedL1Loss +from imaginaire.model_utils.fs_vid2vid import concat_frames, resample +from imaginaire.trainers.vid2vid import Trainer as Vid2VidTrainer +from imaginaire.utils.distributed import is_master +from imaginaire.utils.distributed import master_only_print as print +from imaginaire.utils.misc import split_labels, to_cuda +from imaginaire.utils.visualization import tensor2flow, tensor2im + + +class Trainer(Vid2VidTrainer): + r"""Initialize world consistent vid2vid trainer. + + Args: + cfg (obj): Global configuration. + net_G (obj): Generator network. + net_D (obj): Discriminator network. + opt_G (obj): Optimizer for the generator network. + opt_D (obj): Optimizer for the discriminator network. + sch_G (obj): Scheduler for the generator optimizer. + sch_D (obj): Scheduler for the discriminator optimizer. + train_data_loader (obj): Train data loader. + val_data_loader (obj): Validation data loader. + """ + + def __init__(self, cfg, net_G, net_D, opt_G, opt_D, sch_G, sch_D, + train_data_loader, val_data_loader): + super(Trainer, self).__init__(cfg, net_G, net_D, opt_G, + opt_D, sch_G, sch_D, + train_data_loader, val_data_loader) + self.guidance_start_after = getattr(cfg.gen.guidance, 'start_from', 0) + self.train_data_loader = train_data_loader + + def _define_custom_losses(self): + r"""All other custom losses are defined here.""" + # Setup the guidance loss. + self.criteria['Guidance'] = MaskedL1Loss(normalize_over_valid=True) + self.weights['Guidance'] = self.cfg.trainer.loss_weight.guidance + + def start_of_iteration(self, data, current_iteration): + r"""Things to do before an iteration. + + Args: + data (dict): Data used for the current iteration. + current_iteration (int): Current iteration number. + """ + self.net_G_module.reset_renderer(is_flipped_input=data['is_flipped']) + # Keep unprojections on cpu to prevent unnecessary transfer. + unprojections = data.pop('unprojections') + data = to_cuda(data) + data['unprojections'] = unprojections + + self.current_iteration = current_iteration + if not self.is_inference: + self.net_D.train() + self.net_G.train() + self.start_iteration_time = time.time() + return data + + def reset(self): + r"""Reset the trainer (for inference) at the beginning of a sequence.""" + # Inference time. 
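+        # The world-consistent generator keeps a persistent guidance renderer
+        # (built from previously generated frames); it must be cleared at the
+        # start of every new sequence so stale guidance cannot leak in.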
+ self.net_G_module.reset_renderer(is_flipped_input=False) + + # print('Resetting trainer.') + self.net_G_output = self.data_prev = None + self.t = 0 + + test_in_model_average_mode = getattr( + self, 'test_in_model_average_mode', False) + if test_in_model_average_mode: + if hasattr(self.net_G.module.averaged_model, 'reset'): + self.net_G.module.averaged_model.reset() + else: + if hasattr(self.net_G.module, 'reset'): + self.net_G.module.reset() + + def create_sequence_output_dir(self, output_dir, key): + r"""Create output subdir for this sequence. + + Args: + output_dir (str): Root output dir. + key (str): LMDB key which contains sequence name and file name. + Returns: + output_dir (str): Output subdir for this sequence. + seq_name (str): Name of this sequence. + """ + seq_dir = '/'.join(key.split('/')[:-1]) + output_dir = os.path.join(output_dir, seq_dir) + os.makedirs(output_dir, exist_ok=True) + os.makedirs(output_dir + '/all', exist_ok=True) + os.makedirs(output_dir + '/fake', exist_ok=True) + seq_name = seq_dir.replace('/', '-') + return output_dir, seq_name + + def test(self, test_data_loader, root_output_dir, inference_args): + r"""Run inference on all sequences. + + Args: + test_data_loader (object): Test data loader. + root_output_dir (str): Location to dump outputs. + inference_args (optional): Optional args. + """ + + # Go over all sequences. + loader = test_data_loader + num_inference_sequences = loader.dataset.num_inference_sequences() + for sequence_idx in range(num_inference_sequences): + loader.dataset.set_inference_sequence_idx(sequence_idx) + print('Seq id: %d, Seq length: %d' % + (sequence_idx + 1, len(loader))) + + # Reset model at start of new inference sequence. + self.reset() + self.sequence_length = len(loader) + + # Go over all frames of this sequence. + video = [] + for idx, data in enumerate(tqdm(loader)): + key = data['key']['images'][0][0] + filename = key.split('/')[-1] + + # Create output dir for this sequence. + if idx == 0: + output_dir, seq_name = \ + self.create_sequence_output_dir(root_output_dir, key) + video_path = os.path.join(output_dir, '..', seq_name) + + # Get output, and save all vis to all/. + data['img_name'] = filename + data = to_cuda(data) + output = self.test_single(data, output_dir=output_dir + '/all') + + # Dump just the fake image here. + fake = tensor2im(output['fake_images'])[0] + video.append(fake) + imageio.imsave(output_dir + '/fake/%s.jpg' % (filename), fake) + + # Save as mp4 and gif. + imageio.mimsave(video_path + '.mp4', video, fps=15) + + def test_single(self, data, output_dir=None, save_fake_only=False): + r"""The inference function. If output_dir exists, also save the + output image. + + Args: + data (dict): Training data at the current iteration. + output_dir (str): Save image directory. + save_fake_only (bool): Only save the fake output image. + """ + if self.is_inference and self.cfg.trainer.model_average_config.enabled: + test_in_model_average_mode = True + else: + test_in_model_average_mode = getattr( + self, 'test_in_model_average_mode', False) + data_t = self.get_data_t(data, self.net_G_output, self.data_prev, 0) + if self.sequence_length > 1: + self.data_prev = data_t + + # Generator forward. + # Reset renderer if first time step. 
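+        # Unlike reset(), this reset also passes the sequence's flip flag so
+        # the cached guidance matches the (possibly flipped) input frames.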
+ if self.t == 0: + self.net_G_module.reset_renderer( + is_flipped_input=data['is_flipped']) + with torch.no_grad(): + if test_in_model_average_mode: + net_G = self.net_G.module.averaged_model + else: + net_G = self.net_G + self.net_G_output = net_G(data_t) + + if output_dir is not None: + if save_fake_only: + image_grid = tensor2im(self.net_G_output['fake_images'])[0] + else: + vis_images = self.get_test_output_images(data) + image_grid = np.hstack([np.vstack(im) for im in + vis_images if im is not None]) + if 'img_name' in data: + save_name = data['img_name'].split('.')[0] + '.jpg' + else: + save_name = '%04d.jpg' % self.t + output_filename = os.path.join(output_dir, save_name) + os.makedirs(output_dir, exist_ok=True) + imageio.imwrite(output_filename, image_grid) + self.t += 1 + + return self.net_G_output + + def get_test_output_images(self, data): + r"""Get the visualization output of test function. + + Args: + data (dict): Training data at the current iteration. + """ + # Visualize labels. + label_lengths = self.val_data_loader.dataset.get_label_lengths() + labels = split_labels(data['label'], label_lengths) + vis_labels = [] + for key, value in labels.items(): + if key == 'seg_maps': + vis_labels.append(self.visualize_label(value[:, -1])) + else: + vis_labels.append(tensor2im(value[:, -1])) + + # Get gt image. + im = tensor2im(data['images'][:, -1]) + + # Get guidance image and masks. + if self.net_G_output['guidance_images_and_masks'] is not None: + guidance_image = tensor2im( + self.net_G_output['guidance_images_and_masks'][:, :3]) + guidance_mask = tensor2im( + self.net_G_output['guidance_images_and_masks'][:, 3:4], + normalize=False) + else: + guidance_image = [np.zeros_like(item) for item in im] + guidance_mask = [np.zeros_like(item) for item in im] + + # Create output. + vis_images = [ + *vis_labels, + im, + guidance_image, guidance_mask, + tensor2im(self.net_G_output['fake_images']), + ] + return vis_images + + def gen_frames(self, data, use_model_average=False): + r"""Generate a sequence of frames given a sequence of data. + + Args: + data (dict): Training data at the current iteration. + use_model_average (bool): Whether to use model average + for update or not. + """ + net_G_output = None # Previous generator output. + data_prev = None # Previous data. + if use_model_average: + net_G = self.net_G.module.averaged_model + else: + net_G = self.net_G + + # Iterate through the length of sequence. + self.net_G_module.reset_renderer(is_flipped_input=data['is_flipped']) + + all_info = {'inputs': [], 'outputs': []} + for t in range(self.sequence_length): + # Get the data at the current time frame. + data_t = self.get_data_t(data, net_G_output, data_prev, t) + data_prev = data_t + + # Generator forward. + with torch.no_grad(): + net_G_output = net_G(data_t) + + # Do any postprocessing if necessary. + data_t, net_G_output = self.post_process(data_t, net_G_output) + + if t == 0: + # Get the output at beginning of sequence for visualization. + first_net_G_output = net_G_output + + all_info['inputs'].append(data_t) + all_info['outputs'].append(net_G_output) + + return first_net_G_output, net_G_output, all_info + + def _get_custom_gen_losses(self, data_t, net_G_output, net_D_output): + r"""All other custom generator losses go here. + + Args: + data_t (dict): Training data at the current time t. + net_G_output (dict): Output of the generator. + net_D_output (dict): Output of the discriminator. + """ + # Compute guidance loss. 
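+        # guidance_images_and_masks packs the rendered guidance image in the
+        # first three channels and its validity mask in the fourth; the masked
+        # L1 loss below only penalizes pixels where guidance exists
+        # (normalize_over_valid=True presumably averages over valid pixels
+        # rather than the whole image).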
+ if net_G_output['guidance_images_and_masks'] is not None: + guidance_image = net_G_output['guidance_images_and_masks'][:, :3] + guidance_mask = net_G_output['guidance_images_and_masks'][:, 3:] + self.gen_losses['Guidance'] = self.criteria['Guidance']( + net_G_output['fake_images'], guidance_image, guidance_mask) + else: + self.gen_losses['Guidance'] = self.Tensor(1).fill_(0) + + def get_data_t(self, data, net_G_output, data_prev, t): + r"""Get data at current time frame given the sequence of data. + + Args: + data (dict): Training data for current iteration. + net_G_output (dict): Output of the generator (for previous frame). + data_prev (dict): Data for previous frame. + t (int): Current time. + """ + label = data['label'][:, t] + image = data['images'][:, t] + + # Get keypoint mapping. + unprojection = None + if t >= self.guidance_start_after: + if 'unprojections' in data: + try: + # Remove unwanted padding. + unprojection = {} + for key, value in data['unprojections'].items(): + value = value[0, t].cpu().numpy() + length = value[-1][0] + unprojection[key] = value[:length] + except: # noqa + pass + + if data_prev is not None: + # Concat previous labels/fake images to the ones before. + num_frames_G = self.cfg.data.num_frames_G + prev_labels = concat_frames(data_prev['prev_labels'], + data_prev['label'], num_frames_G - 1) + prev_images = concat_frames( + data_prev['prev_images'], + net_G_output['fake_images'].detach(), num_frames_G - 1) + else: + prev_labels = prev_images = None + + data_t = dict() + data_t['label'] = label + data_t['image'] = image + data_t['prev_labels'] = prev_labels + data_t['prev_images'] = prev_images + data_t['real_prev_image'] = data['images'][:, t - 1] if t > 0 else None + data_t['unprojection'] = unprojection + return data_t + + def save_image(self, path, data): + r"""Save the output images to path. + Note when the generate_raw_output is FALSE. Then, + first_net_G_output['fake_raw_images'] is None and will not be displayed. + In model average mode, we will plot the flow visualization twice. + + Args: + path (str): Save path. + data (dict): Training data for current iteration. + """ + self.net_G.eval() + if self.cfg.trainer.model_average_config.enabled: + self.net_G.module.averaged_model.eval() + self.net_G_output = None + with torch.no_grad(): + first_net_G_output, net_G_output, all_info = self.gen_frames(data) + if self.cfg.trainer.model_average_config.enabled: + first_net_G_output_avg, net_G_output_avg = self.gen_frames( + data, use_model_average=True) + + # Visualize labels. 
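+        # Unlike the base vid2vid save_image, non-segmentation label channels
+        # are converted with the per-key normalize flag stored on the training
+        # dataset.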
+ label_lengths = self.train_data_loader.dataset.get_label_lengths() + labels = split_labels(data['label'], label_lengths) + vis_labels_start, vis_labels_end = [], [] + for key, value in labels.items(): + if 'seg_maps' in key: + vis_labels_start.append(self.visualize_label(value[:, -1])) + vis_labels_end.append(self.visualize_label(value[:, 0])) + else: + normalize = self.train_data_loader.dataset.normalize[key] + vis_labels_start.append( + tensor2im(value[:, -1], normalize=normalize)) + vis_labels_end.append( + tensor2im(value[:, 0], normalize=normalize)) + + if is_master(): + vis_images = [ + *vis_labels_start, + tensor2im(data['images'][:, -1]), + tensor2im(net_G_output['fake_images']), + tensor2im(net_G_output['fake_raw_images'])] + if self.cfg.trainer.model_average_config.enabled: + vis_images += [ + tensor2im(net_G_output_avg['fake_images']), + tensor2im(net_G_output_avg['fake_raw_images'])] + + if self.sequence_length > 1: + if net_G_output['guidance_images_and_masks'] is not None: + guidance_image = tensor2im( + net_G_output['guidance_images_and_masks'][:, :3]) + guidance_mask = tensor2im( + net_G_output['guidance_images_and_masks'][:, 3:4], + normalize=False) + else: + im = tensor2im(data['images'][:, -1]) + guidance_image = [np.zeros_like(item) for item in im] + guidance_mask = [np.zeros_like(item) for item in im] + vis_images += [guidance_image, guidance_mask] + + vis_images_first = [ + *vis_labels_end, + tensor2im(data['images'][:, 0]), + tensor2im(first_net_G_output['fake_images']), + tensor2im(first_net_G_output['fake_raw_images']), + [np.zeros_like(item) for item in guidance_image], + [np.zeros_like(item) for item in guidance_mask] + ] + if self.cfg.trainer.model_average_config.enabled: + vis_images_first += [ + tensor2im(first_net_G_output_avg['fake_images']), + tensor2im(first_net_G_output_avg['fake_raw_images'])] + + if self.use_flow: + flow_gt, conf_gt = self.criteria['Flow'].flowNet( + data['images'][:, -1], data['images'][:, -2]) + warped_image_gt = resample(data['images'][:, -1], flow_gt) + vis_images_first += [ + tensor2flow(flow_gt), + tensor2im(conf_gt, normalize=False), + tensor2im(warped_image_gt), + ] + vis_images += [ + tensor2flow(net_G_output['fake_flow_maps']), + tensor2im(net_G_output['fake_occlusion_masks'], + normalize=False), + tensor2im(net_G_output['warped_images']), + ] + if self.cfg.trainer.model_average_config.enabled: + vis_images_first += [ + tensor2flow(flow_gt), + tensor2im(conf_gt, normalize=False), + tensor2im(warped_image_gt), + ] + vis_images += [ + tensor2flow(net_G_output_avg['fake_flow_maps']), + tensor2im(net_G_output_avg['fake_occlusion_masks'], + normalize=False), + tensor2im(net_G_output_avg['warped_images'])] + + vis_images = [[np.vstack((im_first, im)) + for im_first, im in zip(imgs_first, imgs)] + for imgs_first, imgs in zip(vis_images_first, + vis_images) + if imgs is not None] + + image_grid = np.hstack([np.vstack(im) for im in + vis_images if im is not None]) + + print('Save output images to {}'.format(path)) + os.makedirs(os.path.dirname(path), exist_ok=True) + imageio.imwrite(path, image_grid) + + # Gather all inputs and outputs for dumping into video. 
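+            # Two clips are written per visualization: the generated frames and
+            # a matching guidance-only video, which makes it easy to see how
+            # much of each frame was covered by rendered guidance.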
+ if self.sequence_length > 1: + input_images, output_images, output_guidance = [], [], [] + for item in all_info['inputs']: + input_images.append(tensor2im(item['image'])[0]) + for item in all_info['outputs']: + output_images.append(tensor2im(item['fake_images'])[0]) + if item['guidance_images_and_masks'] is not None: + output_guidance.append(tensor2im( + item['guidance_images_and_masks'][:, :3])[0]) + else: + output_guidance.append(np.zeros_like(output_images[-1])) + + imageio.mimwrite(os.path.splitext(path)[0] + '.mp4', + output_images, fps=2, macro_block_size=None) + imageio.mimwrite(os.path.splitext(path)[0] + '_guidance.mp4', + output_guidance, fps=2, macro_block_size=None) + + # for idx, item in enumerate(output_guidance): + # imageio.imwrite(os.path.splitext( + # path)[0] + '_guidance_%d.jpg' % (idx), item) + # for idx, item in enumerate(input_images): + # imageio.imwrite(os.path.splitext( + # path)[0] + '_input_%d.jpg' % (idx), item) + + self.net_G.float() + + def _compute_fid(self): + r"""Compute fid. Ignore for faster training.""" + return None + + def load_checkpoint(self, cfg, checkpoint_path, resume=None, load_sch=True): + r"""Save network weights, optimizer parameters, scheduler parameters + in the checkpoint. + + Args: + cfg (obj): Global configuration. + checkpoint_path (str): Path to the checkpoint. + """ + # Create the single image model. + if self.train_data_loader is None: + load_single_image_model_weights = False + else: + load_single_image_model_weights = True + self.net_G.module._init_single_image_model( + load_weights=load_single_image_model_weights) + + # Call the original super function. + return super().load_checkpoint(cfg, checkpoint_path, resume, load_sch) diff --git a/imaginaire/utils/__init__.py b/imaginaire/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..13acefe2181136b1629ec31f9d122fb46bf26780 --- /dev/null +++ b/imaginaire/utils/__init__.py @@ -0,0 +1,4 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. 
+# To view a copy of this license, check out LICENSE.md diff --git a/imaginaire/utils/__pycache__/__init__.cpython-38.pyc b/imaginaire/utils/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0671333f42ff39838dabfd22f45a229e98c2033f Binary files /dev/null and b/imaginaire/utils/__pycache__/__init__.cpython-38.pyc differ diff --git a/imaginaire/utils/__pycache__/distributed.cpython-38.pyc b/imaginaire/utils/__pycache__/distributed.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..14d9ab6fd67fee0fbc2f552b61d8aea46723bc79 Binary files /dev/null and b/imaginaire/utils/__pycache__/distributed.cpython-38.pyc differ diff --git a/imaginaire/utils/__pycache__/init_weight.cpython-38.pyc b/imaginaire/utils/__pycache__/init_weight.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..75e710fcc11586e2e8df484362244f5ca9367701 Binary files /dev/null and b/imaginaire/utils/__pycache__/init_weight.cpython-38.pyc differ diff --git a/imaginaire/utils/__pycache__/misc.cpython-38.pyc b/imaginaire/utils/__pycache__/misc.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..26718e225ce3330bb85d0a6851ee629aff81c6d1 Binary files /dev/null and b/imaginaire/utils/__pycache__/misc.cpython-38.pyc differ diff --git a/imaginaire/utils/__pycache__/model_average.cpython-38.pyc b/imaginaire/utils/__pycache__/model_average.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..03c944016eebac22362c35a46c8fd7bd6327f7b0 Binary files /dev/null and b/imaginaire/utils/__pycache__/model_average.cpython-38.pyc differ diff --git a/imaginaire/utils/__pycache__/trainer.cpython-38.pyc b/imaginaire/utils/__pycache__/trainer.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f699fc92b20ab9dfa0768b4479f377770b31447b Binary files /dev/null and b/imaginaire/utils/__pycache__/trainer.cpython-38.pyc differ diff --git a/imaginaire/utils/cudnn.py b/imaginaire/utils/cudnn.py new file mode 100644 index 0000000000000000000000000000000000000000..c7a5cc3b5607c56e997a6c38c184e4b3f4e302f8 --- /dev/null +++ b/imaginaire/utils/cudnn.py @@ -0,0 +1,22 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +import torch.backends.cudnn as cudnn + +from imaginaire.utils.distributed import master_only_print as print + + +def init_cudnn(deterministic, benchmark): + r"""Initialize the cudnn module. The two things to consider is whether to + use cudnn benchmark and whether to use cudnn deterministic. If cudnn + benchmark is set, then the cudnn deterministic is automatically false. + + Args: + deterministic (bool): Whether to use cudnn deterministic. + benchmark (bool): Whether to use cudnn benchmark. + """ + cudnn.deterministic = deterministic + cudnn.benchmark = benchmark + print('cudnn benchmark: {}'.format(benchmark)) + print('cudnn deterministic: {}'.format(deterministic)) diff --git a/imaginaire/utils/data.py b/imaginaire/utils/data.py new file mode 100644 index 0000000000000000000000000000000000000000..22268c955ff7df4e5933e8ce2fb0b38c9c0e2f4a --- /dev/null +++ b/imaginaire/utils/data.py @@ -0,0 +1,612 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. 
+# To view a copy of this license, check out LICENSE.md +# flake8: noqa: E712 +"""Utils for handling datasets.""" + +import time +import numpy as np +from PIL import Image + +# https://github.com/albumentations-team/albumentations#comments +import cv2 +# from imaginaire.utils.distributed import master_only_print as print +import albumentations as alb # noqa nopep8 + +cv2.setNumThreads(0) +cv2.ocl.setUseOpenCL(False) + +IMG_EXTENSIONS = ('jpg', 'jpeg', 'png', 'ppm', 'bmp', + 'pgm', 'tif', 'tiff', 'webp', + 'JPG', 'JPEG', 'PNG', 'PPM', 'BMP', + 'PGM', 'TIF', 'TIFF', 'WEBP') +HDR_IMG_EXTENSIONS = ('hdr',) +VIDEO_EXTENSIONS = 'mp4' + + +class Augmentor(object): + r"""Handles data augmentation using albumentations library.""" + + def __init__(self, aug_list, individual_video_frame_aug_list, image_data_types, is_mask, + keypoint_data_types, interpolator): + r"""Initializes augmentation pipeline. + + Args: + aug_list (list): List of augmentation operations in sequence. + individual_video_frame_aug_list (list): List of augmentation operations in sequence that will be applied + to individual frames of videos independently. + image_data_types (list): List of keys in expected inputs. + is_mask (dict): Whether this data type is discrete masks? + keypoint_data_types (list): List of keys which are keypoints. + """ + + self.aug_list = aug_list + self.individual_video_frame_aug_list = individual_video_frame_aug_list + self.image_data_types = image_data_types + self.is_mask = is_mask + self.crop_h, self.crop_w = None, None + self.resize_h, self.resize_w = None, None + self.resize_smallest_side = None + self.max_time_step = 1 + self.keypoint_data_types = keypoint_data_types + self.interpolator = interpolator + + self.augment_ops = self._build_augmentation_ops() + self.individual_video_frame_augmentation_ops = self._build_individual_video_frame_augmentation_ops() + # Both crop and resize can't be none at the same time. + if self.crop_h is None and self.resize_smallest_side is None and \ + self.resize_h is None: + raise ValueError('resize_smallest_side, resize_h_w, ' + 'and crop_h_w cannot all be missing.') + # If resize_smallest_side is given, resize_h_w should not be give. + if self.resize_smallest_side is not None: + assert self.resize_h is None, \ + 'Cannot have both `resize_smallest_side` and `resize_h_w` set.' + if self.resize_smallest_side is None and self.resize_h is None: + self.resize_h, self.resize_w = self.crop_h, self.crop_w + + def _build_individual_video_frame_augmentation_ops(self): + r"""Builds sequence of augmentation ops that will be applied to each frame in the video independently. + Returns: + (list of alb.ops): List of augmentation ops. + """ + augs = [] + for key, value in self.individual_video_frame_aug_list.items(): + if key == 'random_scale_limit': + if type(value) == float: + scale_limit_lb = scale_limit_ub = value + p = 1 + else: + scale_limit_lb = value['scale_limit_lb'] + scale_limit_ub = value['scale_limit_ub'] + p = value['p'] + augs.append(alb.RandomScale(scale_limit=(-scale_limit_lb, scale_limit_ub), p=p)) + elif key == 'random_crop_h_w': + h, w = value.split(',') + h, w = int(h), int(w) + self.crop_h, self.crop_w = h, w + augs.append(alb.PadIfNeeded(min_height=h, min_width=w)) + augs.append(alb.RandomCrop(h, w, always_apply=True, p=1)) + return augs + + def _build_augmentation_ops(self): + r"""Builds sequence of augmentation ops. + Returns: + (list of alb.ops): List of augmentation ops. 
+ """ + augs = [] + for key, value in self.aug_list.items(): + if key == 'resize_smallest_side': + if isinstance(value, int): + self.resize_smallest_side = value + else: + h, w = value.split(',') + h, w = int(h), int(w) + self.resize_smallest_side = (h, w) + elif key == 'resize_h_w': + h, w = value.split(',') + h, w = int(h), int(w) + self.resize_h, self.resize_w = h, w + elif key == 'random_resize_h_w_aspect': + aspect_start, aspect_end = value.find('('), value.find(')') + aspect = value[aspect_start+1:aspect_end] + aspect_min, aspect_max = aspect.split(',') + h, w = value[:aspect_start].split(',')[:2] + h, w = int(h), int(w) + aspect_min, aspect_max = float(aspect_min), float(aspect_max) + augs.append(alb.RandomResizedCrop( + h, w, scale=(1, 1), + ratio=(aspect_min, aspect_max), always_apply=True, p=1)) + self.resize_h, self.resize_w = h, w + elif key == 'rotate': + augs.append(alb.Rotate( + limit=value, always_apply=True, p=1)) + elif key == 'random_rotate_90': + augs.append(alb.RandomRotate90(always_apply=False, p=0.5)) + elif key == 'random_scale_limit': + augs.append(alb.RandomScale(scale_limit=(0, value), p=1)) + elif key == 'random_crop_h_w': + h, w = value.split(',') + h, w = int(h), int(w) + self.crop_h, self.crop_w = h, w + augs.append(alb.RandomCrop(h, w, always_apply=True, p=1)) + elif key == 'center_crop_h_w': + h, w = value.split(',') + h, w = int(h), int(w) + self.crop_h, self.crop_w = h, w + augs.append(alb.CenterCrop(h, w, always_apply=True, p=1)) + elif key == 'horizontal_flip': + # This is handled separately as we need to keep track if this + # was applied in order to correctly modify keypoint data. + if value: + augs.append(alb.HorizontalFlip(always_apply=False, p=0.5)) + # The options below including contrast, blur, motion_blur, compression, gamma + # were used during developing face-vid2vid. + elif key == 'contrast': + brightness_limit = value['brightness_limit'] + contrast_limit = value['contrast_limit'] + p = value['p'] + augs.append(alb.RandomBrightnessContrast( + brightness_limit=brightness_limit, contrast_limit=contrast_limit, p=p)) + elif key == 'blur': + blur_limit = value['blur_limit'] + p = value['p'] + augs.append(alb.Blur(blur_limit=blur_limit, p=p)) + elif key == 'motion_blur': + blur_limit = value['blur_limit'] + p = value['p'] + augs.append(alb.MotionBlur(blur_limit=blur_limit, p=p)) + elif key == 'compression': + quality_lower = value['quality_lower'] + p = value['p'] + augs.append(alb.ImageCompression(quality_lower=quality_lower, p=p)) + elif key == 'gamma': + gamma_limit_lb = value['gamma_limit_lb'] + gamma_limit_ub = value['gamma_limit_ub'] + p = value['p'] + augs.append(alb.RandomGamma(gamma_limit=(gamma_limit_lb, gamma_limit_ub), p=p)) + elif key == 'max_time_step': + self.max_time_step = value + assert self.max_time_step >= 1, \ + 'max_time_step has to be at least 1' + else: + raise ValueError('Unknown augmentation %s' % (key)) + return augs + + def _choose_image_key(self, inputs): + r"""Choose key to replace with 'image' for input to albumentations. + + Returns: + key (str): Chosen key to be replace with 'image' + """ + if 'image' in inputs: + return 'image' + for data_type in inputs: + if data_type in self.image_data_types: + return data_type + + def _choose_keypoint_key(self, inputs): + r"""Choose key to replace with 'keypoints' for input to albumentations. 
+ Returns: + key (str): Chosen key to be replace with 'keypoints' + """ + if not self.keypoint_data_types: + return None + if 'keypoints' in inputs: + return 'keypoints' + for data_type in inputs: + if data_type in self.keypoint_data_types: + return data_type + + def _create_augmentation_targets(self, inputs): + r"""Create additional targets as required by the albumentation library. + + Args: + inputs (dict): Keys are from self.augmentable_data_types. Values can + be numpy.ndarray or list of numpy.ndarray + (image or list of images). + Returns: + (dict): + - targets (dict): Dict containing mapping of keys to image/mask types. + - new_inputs (dict): Dict containing mapping of keys to data. + """ + # Get additional target list. + targets, new_inputs = {}, {} + for data_type in inputs: + if data_type in self.keypoint_data_types: + # Keypoint-type. + target_type = 'keypoints' + elif data_type in self.image_data_types: + # Image-type. + # Find the target type (image/mask) based on interpolation + # method. + if self.is_mask[data_type]: + target_type = 'mask' + else: + target_type = 'image' + else: + raise ValueError( + 'Data type: %s is not image or keypoint' % (data_type)) + + current_data_type_inputs = inputs[data_type] + if not isinstance(current_data_type_inputs, list): + current_data_type_inputs = [current_data_type_inputs] + + # Create additional_targets and inputs when there are multiples. + for idx, new_input in enumerate(current_data_type_inputs): + key = data_type + if idx > 0: + key = '%s::%05d' % (key, idx) + targets[key] = target_type + new_inputs[key] = new_input + + return targets, new_inputs + + def _collate_augmented(self, augmented): + r"""Collate separated images back into sequence, grouped by keys. + + Args: + augmented (dict): Dict containing frames with keys of the form + 'key', 'key::00001', 'key::00002', ..., 'key::N'. + Returns: + (dict): + - outputs (dict): Dict with list of collated inputs, i.e. frames of + - same key are arranged in order ['key', 'key::00001', ..., 'key::N']. + """ + full_keys = sorted(augmented.keys()) + outputs = {} + for full_key in full_keys: + if '::' not in full_key: + # First occurrence of this key. + key = full_key + outputs[key] = [] + else: + key = full_key.split('::')[0] + outputs[key].append(augmented[full_key]) + return outputs + + def _get_resize_h_w(self, height, width): + r"""Get height and width to resize to, given smallest side. + + Args: + height (int): Input image height. + width (int): Input image width. + Returns: + (dict): + - height (int): Height to resize image to. + - width (int): Width to resize image to. + """ + if self.resize_smallest_side is None: + return self.resize_h, self.resize_w + + if isinstance(self.resize_smallest_side, int): + resize_smallest_height, resize_smallest_width = self.resize_smallest_side, self.resize_smallest_side + else: + resize_smallest_height, resize_smallest_width = self.resize_smallest_side + + if height * resize_smallest_width <= width * resize_smallest_height: + new_height = resize_smallest_height + new_width = int(np.round(new_height * width / float(height))) + else: + new_width = resize_smallest_width + new_height = int(np.round(new_width * height / float(width))) + return new_height, new_width + + def _perform_unpaired_augmentation(self, inputs, augment_ops): + r"""Perform different data augmentation on different image inputs. Note that this operation only works + + Args: + inputs (dict): Keys are from self.image_data_types. Values are list + of numpy.ndarray (list of images). 
+ augment_ops (list): The augmentation operations. + Returns: + (dict): + - augmented (dict): Augmented inputs, with same keys as inputs. + - is_flipped (dict): Flag which tells if images have been LR flipped. + """ + # Process each data type separately as this is unpaired augmentation. + is_flipped = {} + for data_type in inputs: + assert data_type in self.image_data_types + augmented, flipped_flag = self._perform_paired_augmentation( + {data_type: inputs[data_type]}, augment_ops) + inputs[data_type] = augmented[data_type] + is_flipped[data_type] = flipped_flag + return inputs, is_flipped + + def _perform_paired_augmentation(self, inputs, augment_ops): + r"""Perform same data augmentation on all inputs. + + Args: + inputs (dict): Keys are from self.augmentable_data_types. Values are + list of numpy.ndarray (list of images). + augment_ops (list): The augmentation operations. + + Returns: + (dict): + - augmented (dict): Augmented inputs, with same keys as inputs. + - is_flipped (bool): Flag which tells if images have been LR flipped. + """ + # Different data types may have different sizes and we use the largest one as the original size. + # Convert PIL images to numpy array. + self.original_h, self.original_w = 0, 0 + for data_type in inputs: + if data_type in self.keypoint_data_types or \ + data_type not in self.image_data_types: + continue + for idx in range(len(inputs[data_type])): + value = inputs[data_type][idx] + # Get resize h, w. + w, h = get_image_size(value) + self.original_h, self.original_w = max(self.original_h, h), max(self.original_w, w) + # self.original_h, self.original_w = h, w + # self.resize_h, self.resize_w = self._get_resize_h_w(h, w) + # Convert to numpy array with 3 dims (H, W, C). + value = np.array(value) + if value.ndim == 2: + value = value[..., np.newaxis] + inputs[data_type][idx] = value + self.resize_h, self.resize_w = self._get_resize_h_w(self.original_h, self.original_w) + + # Add resize op to augmentation ops. + aug_ops_with_resize = [alb.Resize( + self.resize_h, self.resize_w, interpolation=getattr(cv2, self.interpolator), always_apply=1, p=1 + )] + augment_ops + + # Create targets. + targets, new_inputs = self._create_augmentation_targets(inputs) + extra_params = {} + + # Albumentation requires a key called 'image' and + # a key called 'keypoints', if any keypoints are being passed in. + # Arbitrarily choose one key of image type to be 'image'. + chosen_image_key = self._choose_image_key(inputs) + new_inputs['image'] = new_inputs.pop(chosen_image_key) + targets['image'] = targets.pop(chosen_image_key) + # Arbitrarily choose one key of keypoint type to be 'keypoints'. + chosen_keypoint_key = self._choose_keypoint_key(inputs) + if chosen_keypoint_key is not None: + new_inputs['keypoints'] = new_inputs.pop(chosen_keypoint_key) + targets['keypoints'] = targets.pop(chosen_keypoint_key) + extra_params['keypoint_params'] = alb.KeypointParams( + format='xy', remove_invisible=False) + + # Do augmentation. + augmented = alb.ReplayCompose( + aug_ops_with_resize, additional_targets=targets, + **extra_params)(**new_inputs) + augmentation_params = augmented.pop('replay') + + # Check if flipping has occurred. + is_flipped = False + for augmentation_param in augmentation_params['transforms']: + if 'HorizontalFlip' in augmentation_param['__class_fullname__']: + is_flipped = augmentation_param['applied'] + self.is_flipped = is_flipped + + # Replace the key 'image' with chosen_image_key, same for 'keypoints'. 
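The paired augmentation above relies on albumentations' `ReplayCompose` with `additional_targets`, so one sampled set of parameters is applied to every registered image, mask, and keypoint key. A minimal sketch with assumed toy inputs:

```python
import numpy as np
import albumentations as alb

image = np.random.randint(0, 255, (240, 320, 3), dtype=np.uint8)
style = image.copy()
mask = np.random.randint(0, 2, (240, 320), dtype=np.uint8)

# 'style_image' is an extra key treated like an image; 'mask' is a built-in target.
aug = alb.ReplayCompose(
    [alb.Resize(256, 256), alb.HorizontalFlip(p=0.5)],
    additional_targets={'style_image': 'image'},
)
out = aug(image=image, style_image=style, mask=mask)

# out['replay'] records the sampled parameters, which is how the code above
# detects whether the horizontal flip was actually applied.
flipped = any('HorizontalFlip' in t['__class_fullname__'] and t['applied']
              for t in out['replay']['transforms'])
```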
+ augmented[chosen_image_key] = augmented.pop('image') + if chosen_keypoint_key is not None: + augmented[chosen_keypoint_key] = augmented.pop('keypoints') + + # Pack images back into a sequence. + augmented = self._collate_augmented(augmented) + + # Convert keypoint types to np.array from list. + for data_type in self.keypoint_data_types: + augmented[data_type] = np.array(augmented[data_type]) + + return augmented, is_flipped + + def perform_augmentation(self, inputs, paired, augment_ops): + r"""Entry point for augmentation. + + Args: + inputs (dict): Keys are from self.augmentable_data_types. Values are + list of numpy.ndarray (list of images). + paired (bool): Apply same augmentation to all input keys? + augment_ops (list): The augmentation operations. + """ + # Make sure that all inputs are of same size, else trouble will + # ensue. This is because different images might have different + # aspect ratios. + # Check within data type. + for data_type in inputs: + if data_type in self.keypoint_data_types or \ + data_type not in self.image_data_types: + continue + for idx in range(len(inputs[data_type])): + if idx == 0: + w, h = get_image_size(inputs[data_type][idx]) + else: + this_w, this_h = get_image_size(inputs[data_type][idx]) + # assert this_w == w and this_h == h + # assert this_w / (1.0 * this_h) == w / (1.0 * h) + # Check across data types. + if paired and self.resize_smallest_side is not None: + for idx, data_type in enumerate(inputs): + if data_type in self.keypoint_data_types or \ + data_type not in self.image_data_types: + continue + if paired: + return self._perform_paired_augmentation(inputs, augment_ops) + else: + return self._perform_unpaired_augmentation(inputs, augment_ops) + + +def load_from_lmdb(keys, lmdbs): + r"""Load keys from lmdb handles. + + Args: + keys (dict): This has data_type as key, and a list of paths into LMDB as + values. + lmdbs (dict): This has data_type as key, and LMDB handle as value. + Returns: + data (dict): This has data_type as key, and a list of decoded items from + LMDBs as value. + """ + data = {} + for data_type in keys: + if data_type not in data: + data[data_type] = [] + data_type_keys = keys[data_type] + if not isinstance(data_type_keys, list): + data_type_keys = [data_type_keys] + for key in data_type_keys: + data[data_type].append(lmdbs[data_type].getitem_by_path( + key.encode(), data_type)) + return data + + +def load_from_folder(keys, handles): + r"""Load keys from lmdb handles. + + Args: + keys (dict): This has data_type as key, and a list of paths as + values. + handles (dict): This has data_type as key, and Folder handle as value. + Returns: + data (dict): This has data_type as key, and a list of decoded items from + folders as value. + """ + data = {} + for data_type in keys: + if data_type not in data: + data[data_type] = [] + data_type_keys = keys[data_type] + if not isinstance(data_type_keys, list): + data_type_keys = [data_type_keys] + for key in data_type_keys: + data[data_type].append(handles[data_type].getitem_by_path( + key.encode(), data_type)) + return data + + +def load_from_object_store(keys, handles): + r"""Load keys from AWS S3 handles. + + Args: + keys (dict): This has data_type as key, and a list of paths as + values. + handles (dict): This has data_type as key, and Folder handle as value. + Returns: + data (dict): This has data_type as key, and a list of decoded items from + folders as value. 
+ """ + data = {} + for data_type in keys: + if data_type not in data: + data[data_type] = [] + data_type_keys = keys[data_type] + if not isinstance(data_type_keys, list): + data_type_keys = [data_type_keys] + for key in data_type_keys: + while True: + try: + data[data_type].append(handles[data_type].getitem_by_path(key, data_type)) + except Exception as e: + print(e) + print(key, data_type) + print('Retrying in 30 seconds') + time.sleep(30) + continue + break + return data + + +def get_paired_input_image_channel_number(data_cfg): + r"""Get number of channels for the input image. + + Args: + data_cfg (obj): Data configuration structure. + Returns: + num_channels (int): Number of input image channels. + """ + num_channels = 0 + for ix, data_type in enumerate(data_cfg.input_types): + for k in data_type: + if k in data_cfg.input_image: + num_channels += data_type[k].num_channels + print('Concatenate %s for input.' % data_type) + print('\tNum. of channels in the input image: %d' % num_channels) + return num_channels + + +def get_paired_input_label_channel_number(data_cfg, video=False): + r"""Get number of channels for the input label map. + + Args: + data_cfg (obj): Data configuration structure. + video (bool): Whether we are dealing with video data. + Returns: + num_channels (int): Number of input label map channels. + """ + num_labels = 0 + if not hasattr(data_cfg, 'input_labels'): + return num_labels + for ix, data_type in enumerate(data_cfg.input_types): + for k in data_type: + if k in data_cfg.input_labels: + if hasattr(data_cfg, 'one_hot_num_classes') and k in data_cfg.one_hot_num_classes: + num_labels += data_cfg.one_hot_num_classes[k] + if getattr(data_cfg, 'use_dont_care', False): + num_labels += 1 + else: + num_labels += data_type[k].num_channels + print('Concatenate %s for input.' % data_type) + + if video: + num_time_steps = getattr(data_cfg.train, 'initial_sequence_length', + None) + num_labels *= num_time_steps + num_labels += get_paired_input_image_channel_number(data_cfg) * ( + num_time_steps - 1) + + print('\tNum. of channels in the input label: %d' % num_labels) + return num_labels + + +def get_class_number(data_cfg): + r"""Get number of classes for class-conditional GAN model + + Args: + data_cfg (obj): Data configuration structure. + + Returns: + (int): Number of classes. + """ + return data_cfg.num_classes + + +def get_crop_h_w(augmentation): + r"""Get height and width of crop. + + Args: + augmentation (dict): Dict of applied augmentations. + + Returns: + (dict): + - crop_h (int): Height of the image crop. + - crop_w (int): Width of the image crop. + """ + print(augmentation.__dict__.keys()) + for k in augmentation.__dict__.keys(): + if 'crop_h_w' in k: + filed = augmentation[k] + crop_h, crop_w = filed.split(',') + crop_h = int(crop_h) + crop_w = int(crop_w) + # assert crop_w == crop_h, 'This implementation only ' \ + # 'supports square-shaped images.' + print('\tCrop size: (%d, %d)' % (crop_h, crop_w)) + return crop_h, crop_w + raise AttributeError + + +def get_image_size(x): + try: + w, h = x.size + except Exception: + h, w, _ = x.shape + return w, h diff --git a/imaginaire/utils/dataset.py b/imaginaire/utils/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..605bb2e20d2a22a70e38254bea46bd73177c8c5a --- /dev/null +++ b/imaginaire/utils/dataset.py @@ -0,0 +1,120 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. 
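`load_from_object_store` above retries a failed fetch every 30 seconds until it succeeds. The same pattern as a small standalone helper, with a hypothetical `fetch` callable and an optional attempt cap (the original retries forever):

```python
import time

def fetch_with_retry(fetch, key, wait_seconds=30, max_attempts=None):
    """Call fetch(key), sleeping and retrying on any exception."""
    attempt = 0
    while True:
        try:
            return fetch(key)
        except Exception as e:  # broad catch mirrors the original loop
            attempt += 1
            if max_attempts is not None and attempt >= max_attempts:
                raise
            print(e, '- retrying in', wait_seconds, 'seconds')
            time.sleep(wait_seconds)
```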
+# To view a copy of this license, check out LICENSE.md +import importlib + +import torch +import torch.distributed as dist + +from imaginaire.utils.distributed import master_only_print as print + + +def _get_train_and_val_dataset_objects(cfg): + r"""Return dataset objects for the training and validation sets. + + Args: + cfg (obj): Global configuration file. + + Returns: + (dict): + - train_dataset (obj): PyTorch training dataset object. + - val_dataset (obj): PyTorch validation dataset object. + """ + dataset_module = importlib.import_module(cfg.data.type) + train_dataset = dataset_module.Dataset(cfg, is_inference=False) + if hasattr(cfg.data.val, 'type'): + for key in ['type', 'input_types', 'input_image']: + setattr(cfg.data, key, getattr(cfg.data.val, key)) + dataset_module = importlib.import_module(cfg.data.type) + val_dataset = dataset_module.Dataset(cfg, is_inference=True) + print('Train dataset length:', len(train_dataset)) + print('Val dataset length:', len(val_dataset)) + return train_dataset, val_dataset + + +def _get_data_loader(cfg, dataset, batch_size, not_distributed=False, + shuffle=True, drop_last=True, seed=0): + r"""Return data loader . + + Args: + cfg (obj): Global configuration file. + dataset (obj): PyTorch dataset object. + batch_size (int): Batch size. + not_distributed (bool): Do not use distributed samplers. + + Return: + (obj): Data loader. + """ + not_distributed = not_distributed or not dist.is_initialized() + if not_distributed: + sampler = None + else: + sampler = torch.utils.data.distributed.DistributedSampler(dataset, seed=seed) + num_workers = getattr(cfg.data, 'num_workers', 8) + persistent_workers = getattr(cfg.data, 'persistent_workers', False) + data_loader = torch.utils.data.DataLoader( + dataset, + batch_size=batch_size, + shuffle=shuffle and (sampler is None), + sampler=sampler, + pin_memory=True, + num_workers=num_workers, + drop_last=drop_last, + persistent_workers=persistent_workers if num_workers > 0 else False + ) + return data_loader + + +def get_train_and_val_dataloader(cfg, seed=0): + r"""Return dataset objects for the training and validation sets. + + Args: + cfg (obj): Global configuration file. + + Returns: + (dict): + - train_data_loader (obj): Train data loader. + - val_data_loader (obj): Val data loader. + """ + train_dataset, val_dataset = _get_train_and_val_dataset_objects(cfg) + train_data_loader = _get_data_loader(cfg, train_dataset, cfg.data.train.batch_size, drop_last=True, seed=seed) + not_distributed = getattr(cfg.data, 'val_data_loader_not_distributed', False) + not_distributed = 'video' in cfg.data.type or not_distributed + val_data_loader = _get_data_loader( + cfg, val_dataset, cfg.data.val.batch_size, not_distributed, + shuffle=False, drop_last=getattr(cfg.data.val, 'drop_last', False), seed=seed) + return train_data_loader, val_data_loader + + +def _get_test_dataset_object(cfg): + r"""Return dataset object for the test set + + Args: + cfg (obj): Global configuration file. + + Returns: + (obj): PyTorch dataset object. + """ + dataset_module = importlib.import_module(cfg.test_data.type) + test_dataset = dataset_module.Dataset(cfg, is_inference=True, is_test=True) + return test_dataset + + +def get_test_dataloader(cfg): + r"""Return dataset objects for testing + + Args: + cfg (obj): Global configuration file. + + Returns: + (obj): Val data loader. It may not contain the ground truth. 
+ """ + test_dataset = _get_test_dataset_object(cfg) + not_distributed = getattr( + cfg.test_data, 'val_data_loader_not_distributed', False) + not_distributed = 'video' in cfg.test_data.type or not_distributed + test_data_loader = _get_data_loader( + cfg, test_dataset, cfg.test_data.test.batch_size, not_distributed, + shuffle=False) + return test_data_loader diff --git a/imaginaire/utils/diff_aug.py b/imaginaire/utils/diff_aug.py new file mode 100644 index 0000000000000000000000000000000000000000..3410004fb68aef074750359cbb420e5dd340bd45 --- /dev/null +++ b/imaginaire/utils/diff_aug.py @@ -0,0 +1,142 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md + +# Differentiable Augmentation for Data-Efficient GAN Training +# Shengyu Zhao, Zhijian Liu, Ji Lin, Jun-Yan Zhu, and Song Han +# https://arxiv.org/pdf/2006.10738 +# Modified from https://github.com/mit-han-lab/data-efficient-gans +import torch +import torch.nn.functional as F + + +def apply_diff_aug(data, keys, aug_policy, inplace=False, **kwargs): + r"""Applies differentiable augmentation. + Args: + data (dict): Input data. + keys (list of str): Keys to the data values that we want to apply + differentiable augmentation to. + aug_policy (str): Type of augmentation(s), ``'color'``, + ``'translation'``, or ``'cutout'`` separated by ``','``. + """ + if aug_policy == '': + return data + data_aug = data if inplace else {} + for key, value in data.items(): + if key in keys: + data_aug[key] = diff_aug(data[key], aug_policy, **kwargs) + else: + data_aug[key] = data[key] + return data_aug + + +def diff_aug(x, policy='', channels_first=True, **kwargs): + if policy: + if not channels_first: + x = x.permute(0, 3, 1, 2) + for p in policy.split(','): + for f in AUGMENT_FNS[p]: + x = f(x, **kwargs) + if not channels_first: + x = x.permute(0, 2, 3, 1) + x = x.contiguous() + return x + + +def rand_brightness(x, **kwargs): + x = x + (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, + device=x.device) - 0.5) + return x + + +def rand_saturation(x, **kwargs): + x_mean = x.mean(dim=1, keepdim=True) + x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, + device=x.device) * 2) + x_mean + return x + + +def rand_contrast(x, **kwargs): + x_mean = x.mean(dim=[1, 2, 3], keepdim=True) + x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, + device=x.device) + 0.5) + x_mean + return x + + +def rand_translation(x, ratio=0.125, **kwargs): + shift_x, shift_y = int(x.size(2) * ratio + 0.5), int( + x.size(3) * ratio + 0.5) + translation_x = torch.randint(-shift_x, shift_x + 1, size=[x.size(0), 1, 1], + device=x.device) + translation_y = torch.randint(-shift_y, shift_y + 1, size=[x.size(0), 1, 1], + device=x.device) + # noinspection PyTypeChecker + grid_batch, grid_x, grid_y = torch.meshgrid( + torch.arange(x.size(0), dtype=torch.long, device=x.device), + torch.arange(x.size(2), dtype=torch.long, device=x.device), + torch.arange(x.size(3), dtype=torch.long, device=x.device), + ) + grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1) + grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1) + x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0]) + x = x_pad.permute(0, 2, 3, 1).contiguous()[ + grid_batch, grid_x, grid_y].permute(0, 3, 1, 2) + return x + + +def rand_cutout(x, ratio=0.5, **kwargs): + cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5) + offset_x = 
torch.randint(0, x.size(2) + (1 - cutout_size[0] % 2), + size=[x.size(0), 1, 1], device=x.device) + offset_y = torch.randint(0, x.size(3) + (1 - cutout_size[1] % 2), + size=[x.size(0), 1, 1], device=x.device) + # noinspection PyTypeChecker + grid_batch, grid_x, grid_y = torch.meshgrid( + torch.arange(x.size(0), dtype=torch.long, device=x.device), + torch.arange(cutout_size[0], dtype=torch.long, device=x.device), + torch.arange(cutout_size[1], dtype=torch.long, device=x.device), + ) + grid_x = torch.clamp(grid_x + offset_x - cutout_size[0] // 2, min=0, + max=x.size(2) - 1) + grid_y = torch.clamp(grid_y + offset_y - cutout_size[1] // 2, min=0, + max=x.size(3) - 1) + mask = torch.ones(x.size(0), x.size(2), x.size(3), + dtype=x.dtype, device=x.device) + mask[grid_batch, grid_x, grid_y] = 0 + x = x * mask.unsqueeze(1) + return x + + +def rand_translation_scale(x, trans_r=0.125, scale_r=0.125, + mode='bilinear', padding_mode='reflection', + **kwargs): + assert x.dim() == 4, "Input must be a 4D tensor." + batch_size = x.size(0) + + # Identity transformation. + theta = torch.eye(2, 3, device=x.device).unsqueeze(0).repeat( + batch_size, 1, 1) + + # Translation, uniformly sampled from (-trans_r, trans_r). + translate = \ + 2 * trans_r * torch.rand(batch_size, 2, device=x.device) - trans_r + theta[:, :, 2] += translate + + # Scaling, uniformly sampled from (1-scale_r, 1+scale_r). + scale = \ + 2 * scale_r * torch.rand(batch_size, 2, device=x.device) - scale_r + theta[:, :, :2] += torch.diag_embed(scale) + + grid = F.affine_grid(theta, x.size()) + x = F.grid_sample( + x.float(), grid.float(), mode=mode, padding_mode=padding_mode) + return x + + +AUGMENT_FNS = { + 'color': [rand_brightness, rand_saturation, rand_contrast], + 'translation': [rand_translation], + 'translation_scale': [rand_translation_scale], + 'cutout': [rand_cutout], +} diff --git a/imaginaire/utils/distributed.py b/imaginaire/utils/distributed.py new file mode 100644 index 0000000000000000000000000000000000000000..d7ec9d1099684e58a80a61107fe828e292352002 --- /dev/null +++ b/imaginaire/utils/distributed.py @@ -0,0 +1,117 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +import functools +import ctypes + +import torch +import torch.distributed as dist + + +def init_dist(local_rank, backend='nccl', **kwargs): + r"""Initialize distributed training""" + if dist.is_available(): + if dist.is_initialized(): + return torch.cuda.current_device() + torch.cuda.set_device(local_rank) + dist.init_process_group(backend=backend, init_method='env://', **kwargs) + + # Increase the L2 fetch granularity for faster speed. + _libcudart = ctypes.CDLL('libcudart.so') + # Set device limit on the current device + # cudaLimitMaxL2FetchGranularity = 0x05 + pValue = ctypes.cast((ctypes.c_int * 1)(), ctypes.POINTER(ctypes.c_int)) + _libcudart.cudaDeviceSetLimit(ctypes.c_int(0x05), ctypes.c_int(128)) + _libcudart.cudaDeviceGetLimit(pValue, ctypes.c_int(0x05)) + # assert pValue.contents.value == 128 + + +def get_rank(): + r"""Get rank of the thread.""" + rank = 0 + if dist.is_available(): + if dist.is_initialized(): + rank = dist.get_rank() + return rank + + +def get_world_size(): + r"""Get world size. 
How many GPUs are available in this job.""" + world_size = 1 + if dist.is_available(): + if dist.is_initialized(): + world_size = dist.get_world_size() + return world_size + + +def master_only(func): + r"""Apply this function only to the master GPU.""" + @functools.wraps(func) + def wrapper(*args, **kwargs): + r"""Simple function wrapper for the master function""" + if get_rank() == 0: + return func(*args, **kwargs) + else: + return None + return wrapper + + +def is_master(): + r"""check if current process is the master""" + return get_rank() == 0 + + +def is_local_master(): + return torch.cuda.current_device() == 0 + + +@master_only +def master_only_print(*args): + r"""master-only print""" + print(*args) + + +def dist_reduce_tensor(tensor, rank=0, reduce='mean'): + r""" Reduce to rank 0 """ + world_size = get_world_size() + if world_size < 2: + return tensor + with torch.no_grad(): + dist.reduce(tensor, dst=rank) + if get_rank() == rank: + if reduce == 'mean': + tensor /= world_size + elif reduce == 'sum': + pass + else: + raise NotImplementedError + return tensor + + +def dist_all_reduce_tensor(tensor, reduce='mean'): + r""" Reduce to all ranks """ + world_size = get_world_size() + if world_size < 2: + return tensor + with torch.no_grad(): + dist.all_reduce(tensor) + if reduce == 'mean': + tensor /= world_size + elif reduce == 'sum': + pass + else: + raise NotImplementedError + return tensor + + +def dist_all_gather_tensor(tensor): + r""" gather to all ranks """ + world_size = get_world_size() + if world_size < 2: + return [tensor] + tensor_list = [ + torch.ones_like(tensor) for _ in range(dist.get_world_size())] + with torch.no_grad(): + dist.all_gather(tensor_list, tensor) + return tensor_list diff --git a/imaginaire/utils/gpu_affinity.py b/imaginaire/utils/gpu_affinity.py new file mode 100644 index 0000000000000000000000000000000000000000..3f4e9cb40a5a5f9185e903af55694b5952cfe0ff --- /dev/null +++ b/imaginaire/utils/gpu_affinity.py @@ -0,0 +1,61 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +import math +import os +import pynvml + +pynvml.nvmlInit() + + +def systemGetDriverVersion(): + r"""Get Driver Version""" + return pynvml.nvmlSystemGetDriverVersion() + + +def deviceGetCount(): + r"""Get number of devices""" + return pynvml.nvmlDeviceGetCount() + + +class device(object): + r"""Device used for nvml.""" + _nvml_affinity_elements = math.ceil(os.cpu_count() / 64) + + def __init__(self, device_idx): + super().__init__() + self.handle = pynvml.nvmlDeviceGetHandleByIndex(device_idx) + + def getName(self): + r"""Get obect name""" + return pynvml.nvmlDeviceGetName(self.handle) + + def getCpuAffinity(self): + r"""Get CPU affinity""" + affinity_string = '' + for j in pynvml.nvmlDeviceGetCpuAffinity( + self.handle, device._nvml_affinity_elements): + # assume nvml returns list of 64 bit ints + affinity_string = '{:064b}'.format(j) + affinity_string + affinity_list = [int(x) for x in affinity_string] + affinity_list.reverse() # so core 0 is in 0th element of list + + return [i for i, e in enumerate(affinity_list) if e != 0] + + +def set_affinity(gpu_id=None): + r"""Set GPU affinity + + Args: + gpu_id (int): Which gpu device. 
+ """ + if gpu_id is None: + gpu_id = int(os.getenv('LOCAL_RANK', 0)) + + dev = device(gpu_id) + os.sched_setaffinity(0, dev.getCpuAffinity()) + + # list of ints + # representing the logical cores this process is now affinitied with + return os.sched_getaffinity(0) diff --git a/imaginaire/utils/init_weight.py b/imaginaire/utils/init_weight.py new file mode 100644 index 0000000000000000000000000000000000000000..80d826c27d7fe1ab75bfe565b40531acd02abd2b --- /dev/null +++ b/imaginaire/utils/init_weight.py @@ -0,0 +1,84 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +import torch +from torch.nn import init + + +def weights_init(init_type='normal', gain=0.02, bias=None): + r"""Initialize weights in the network. + + Args: + init_type (str): The name of the initialization scheme. + gain (float): The parameter that is required for the initialization + scheme. + bias (object): If not ``None``, specifies the initialization parameter + for bias. + + Returns: + (obj): init function to be applied. + """ + + def init_func(m): + r"""Init function + + Args: + m: module to be weight initialized. + """ + class_name = m.__class__.__name__ + if hasattr(m, 'weight') and ( + class_name.find('Conv') != -1 or + class_name.find('Linear') != -1 or + class_name.find('Embedding') != -1): + lr_mul = getattr(m, 'lr_mul', 1.) + gain_final = gain / lr_mul + if init_type == 'normal': + init.normal_(m.weight.data, 0.0, gain_final) + elif init_type == 'xavier': + init.xavier_normal_(m.weight.data, gain=gain_final) + elif init_type == 'xavier_uniform': + init.xavier_uniform_(m.weight.data, gain=gain_final) + elif init_type == 'kaiming': + init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') + with torch.no_grad(): + m.weight.data *= gain_final + elif init_type == 'kaiming_linear': + init.kaiming_normal_( + m.weight.data, a=0, mode='fan_in', nonlinearity='linear' + ) + with torch.no_grad(): + m.weight.data *= gain_final + elif init_type == 'orthogonal': + init.orthogonal_(m.weight.data, gain=gain_final) + elif init_type == 'none': + pass + # m.reset_parameters() + else: + raise NotImplementedError( + 'initialization method [%s] is ' + 'not implemented' % init_type) + if hasattr(m, 'bias') and m.bias is not None: + if init_type == 'none': + pass + elif bias is not None: + bias_type = getattr(bias, 'type', 'normal') + if bias_type == 'normal': + bias_gain = getattr(bias, 'gain', 0.5) + init.normal_(m.bias.data, 0.0, bias_gain) + else: + raise NotImplementedError( + 'initialization method [%s] is ' + 'not implemented' % bias_type) + else: + init.constant_(m.bias.data, 0.0) + return init_func + + +def weights_rescale(): + def init_func(m): + if hasattr(m, 'init_gain'): + for name, p in m.named_parameters(): + if 'output_scale' not in name: + p.data.mul_(m.init_gain) + return init_func diff --git a/imaginaire/utils/io.py b/imaginaire/utils/io.py new file mode 100644 index 0000000000000000000000000000000000000000..2f3aa1d737d7a2b0922082c43d4f0c573a482063 --- /dev/null +++ b/imaginaire/utils/io.py @@ -0,0 +1,136 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. 
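`weights_init` above is a factory that returns a per-module initializer; it is meant to be handed to `nn.Module.apply`, which walks every submodule. A usage sketch, assuming the `imaginaire` package from this repository is importable:

```python
import torch.nn as nn
from imaginaire.utils.init_weight import weights_init

net = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1),
    nn.ReLU(),
    nn.Conv2d(16, 3, 3, padding=1),
)
# Conv/Linear/Embedding weights get Xavier-normal init with the given gain;
# biases fall back to zeros because no bias config is passed.
net.apply(weights_init(init_type='xavier', gain=0.02))
```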
+# To view a copy of this license, check out LICENSE.md +import os + +import requests +import torch.distributed as dist +import torchvision.utils + +from imaginaire.utils.distributed import is_master + + +def save_pilimage_in_jpeg(fullname, output_img): + r"""Save PIL Image to JPEG. + + Args: + fullname (str): Full save path. + output_img (PIL Image): Image to be saved. + """ + dirname = os.path.dirname(fullname) + os.makedirs(dirname, exist_ok=True) + output_img.save(fullname, 'JPEG', quality=99) + + +def save_intermediate_training_results( + visualization_images, logdir, current_epoch, current_iteration): + r"""Save intermediate training results for debugging purpose. + + Args: + visualization_images (tensor): Image where pixel values are in [-1, 1]. + logdir (str): Where to save the image. + current_epoch (int): Current training epoch. + current_iteration (int): Current training iteration. + """ + visualization_images = (visualization_images + 1) / 2 + output_filename = os.path.join( + logdir, 'images', + 'epoch_{:05}iteration{:09}.jpg'.format( + current_epoch, current_iteration)) + print('Save output images to {}'.format(output_filename)) + os.makedirs(os.path.dirname(output_filename), exist_ok=True) + image_grid = torchvision.utils.make_grid( + visualization_images.data, nrow=1, padding=0, normalize=False) + torchvision.utils.save_image(image_grid, output_filename, nrow=1) + + +def download_file_from_google_drive(URL, destination): + r"""Download a file from google drive. + + Args: + URL: GDrive file ID. + destination: Path to save the file. + + Returns: + + """ + download_file(f"https://docs.google.com/uc?export=download&id={URL}", destination) + + +def download_file(URL, destination): + r"""Download a file from google drive or pbss by using the url. + + Args: + URL: GDrive URL or PBSS pre-signed URL for the checkpoint. + destination: Path to save the file. + + Returns: + + """ + session = requests.Session() + response = session.get(URL, stream=True) + token = get_confirm_token(response) + if token: + params = {'confirm': token} + response = session.get(URL, params=params, stream=True) + save_response_content(response, destination) + + +def get_confirm_token(response): + r"""Get confirm token + + Args: + response: Check if the file exists. + + Returns: + + """ + for key, value in response.cookies.items(): + if key.startswith('download_warning'): + return value + return None + + +def save_response_content(response, destination): + r"""Save response content + + Args: + response: + destination: Path to save the file. + + Returns: + + """ + chunk_size = 32768 + with open(destination, "wb") as f: + for chunk in response.iter_content(chunk_size): + if chunk: + f.write(chunk) + + +def get_checkpoint(checkpoint_path, url=''): + r"""Get the checkpoint path. If it does not exist yet, download it from + the url. + + Args: + checkpoint_path (str): Checkpoint path. + url (str): URL to download checkpoint. + Returns: + (str): Full checkpoint path. 
+ """ + if 'TORCH_HOME' not in os.environ: + os.environ['TORCH_HOME'] = os.getcwd() + save_dir = os.path.join(os.environ['TORCH_HOME'], 'checkpoints') + os.makedirs(save_dir, exist_ok=True) + full_checkpoint_path = os.path.join(save_dir, checkpoint_path) + if not os.path.exists(full_checkpoint_path): + os.makedirs(os.path.dirname(full_checkpoint_path), exist_ok=True) + if is_master(): + print('Downloading {}'.format(url)) + if 'pbss.s8k.io' not in url: + url = f"https://docs.google.com/uc?export=download&id={url}" + download_file(url, full_checkpoint_path) + if dist.is_available() and dist.is_initialized(): + dist.barrier() + return full_checkpoint_path diff --git a/imaginaire/utils/lmdb.py b/imaginaire/utils/lmdb.py new file mode 100644 index 0000000000000000000000000000000000000000..df40c146b73295598cde04fd94a6869c6a5e69d2 --- /dev/null +++ b/imaginaire/utils/lmdb.py @@ -0,0 +1,216 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +import glob +import os + +import lmdb +from tqdm import tqdm + +from imaginaire.utils import path + + +def construct_file_path(root, data_type, sequence, filename, ext): + """Get file path for our dataset structure.""" + return '%s/%s/%s/%s.%s' % (root, data_type, sequence, filename, ext) + + +def check_and_add(filepath, key, filepaths, keys, remove_missing=False): + r"""Add filepath and key to list of filepaths and keys. + + Args: + filepath (str): Filepath to add. + key (str): LMDB key for this filepath. + filepaths (list): List of filepaths added so far. + keys (list): List of keys added so far. + remove_missing (bool): If ``True``, removes missing files, otherwise + raises an error. + Returns: + (int): Size of file at filepath. + """ + if not os.path.exists(filepath): + print(filepath + ' does not exist.') + if remove_missing: + return -1 + else: + raise FileNotFoundError(filepath + ' does not exist.') + filepaths.append(filepath) + keys.append(key) + return os.path.getsize(filepath) + + +def write_entry(txn, key, filepath): + r"""Dump binary contents of file associated with key to LMDB. + + Args: + txn: handle to LMDB. + key (str): LMDB key for this filepath. + filepath (str): Filepath to add. + """ + with open(filepath, 'rb') as f: + data = f.read() + txn.put(key.encode('ascii'), data) + + +def build_lmdb(filepaths, keys, output_filepath, map_size, large): + r"""Write out lmdb containing (key, contents of filepath) to file. + + Args: + filepaths (list): List of filepath strings. + keys (list): List of key strings associated with filepaths. + output_filepath (str): Location to write LMDB to. + map_size (int): Size of LMDB. + large (bool): Is the dataset large? + """ + if large: + db = lmdb.open(output_filepath, map_size=map_size, writemap=True) + else: + db = lmdb.open(output_filepath, map_size=map_size) + txn = db.begin(write=True) + print('Writing LMDB to:', output_filepath) + for filepath, key in tqdm(zip(filepaths, keys), total=len(keys)): + write_entry(txn, key, filepath) + txn.commit() + + +def get_all_filenames_from_list(list_name): + r"""Get all filenames from list. + + Args: + list_name (str): Path to filename list. + Returns: + all_filenames (dict): Folder name for key, and filename for values. 
+ """ + with open(list_name, 'rt') as f: + lines = f.readlines() + lines = [line.strip() for line in lines] + all_filenames = dict() + for line in lines: + if '/' in line: + file_str = line.split('/')[0:-1] + folder_name = os.path.join(*file_str) + image_name = line.split('/')[-1].replace('.jpg', '') + else: + folder_name = '.' + image_name = line.replace('.jpg', '') + if folder_name in all_filenames: + all_filenames[folder_name].append(image_name) + else: + all_filenames[folder_name] = [image_name] + return all_filenames + + +def get_lmdb_data_types(cfg): + r"""Get the data types which should be put in LMDB. + + Args: + cfg: Configuration object. + """ + data_types, extensions = [], [] + for data_type in cfg.data.input_types: + name = list(data_type.keys()) + assert len(name) == 1 + name = name[0] + info = data_type[name] + + if 'computed_on_the_fly' not in info: + info['computed_on_the_fly'] = False + is_lmdb = not info['computed_on_the_fly'] + if not is_lmdb: + continue + + ext = info['ext'] + data_types.append(name) + extensions.append(ext) + + cfg.data.data_types = data_types + cfg.data.extensions = extensions + return cfg + + +def create_metadata(data_root=None, cfg=None, paired=None, input_list=''): + r"""Main function. + + Args: + data_root (str): Location of dataset root. + cfg (object): Loaded config object. + paired (bool): Paired or unpaired data. + input_list (str): Path to filename containing list of inputs. + Returns: + (tuple): + - all_filenames (dict): Key of data type, values with sequences. + - extensions (dict): Extension of each data type. + """ + cfg = get_lmdb_data_types(cfg) + + # Get list of all data_types in the dataset. + available_data_types = path.get_immediate_subdirectories(data_root) + print(available_data_types) + required_data_types = cfg.data.data_types + data_exts = cfg.data.extensions + + # Find filenames. + assert set(required_data_types).issubset(set(available_data_types)), \ + print(set(required_data_types) - set(available_data_types), 'missing') + + # Find extensions for each data type. + extensions = {} + for data_type, data_ext in zip(required_data_types, data_exts): + extensions[data_type] = data_ext + print('Data file extensions:', extensions) + + if paired: + if input_list != '': + all_filenames = get_all_filenames_from_list(input_list) + else: + # Get list of all sequences in the dataset. + if 'data_keypoint' in required_data_types: + search_dir = 'data_keypoint' + elif 'data_segmaps' in required_data_types: + search_dir = 'data_segmaps' + else: + search_dir = required_data_types[0] + print('Searching in dir: %s' % search_dir) + sequences = path.get_recursive_subdirectories( + os.path.join(data_root, search_dir), + extensions[search_dir]) + print('Found %d sequences' % (len(sequences))) + + # Get filenames in each sequence. + all_filenames = {} + for sequence in sequences: + folder = '%s/%s/%s/*.%s' % ( + data_root, search_dir, sequence, + extensions[search_dir]) + filenames = sorted(glob.glob(folder)) + filenames = [ + os.path.splitext(os.path.basename(filename))[0] for + filename in filenames] + all_filenames[sequence] = filenames + total_filenames = [len(filenames) + for _, filenames in all_filenames.items()] + print('Found %d files' % (sum(total_filenames))) + else: + # Get sequences in each data type. + all_filenames = {} + for data_type in required_data_types: + all_filenames[data_type] = {} + sequences = path.get_recursive_subdirectories( + os.path.join(data_root, data_type), extensions[data_type]) + + # Get filenames in each sequence. 
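The paired branch of `create_metadata` above indexes files per sequence with `glob` and strips extensions to build the LMDB keys. A reduced sketch with hypothetical paths:

```python
import glob
import os

data_root, data_type, ext = 'datasets/demo', 'images', 'png'  # assumed layout
sequences = ['seq_000', 'seq_001']                            # normally discovered on disk

all_filenames = {}
for sequence in sequences:
    pattern = '%s/%s/%s/*.%s' % (data_root, data_type, sequence, ext)
    names = sorted(glob.glob(pattern))
    all_filenames[sequence] = [os.path.splitext(os.path.basename(n))[0] for n in names]
print(sum(len(v) for v in all_filenames.values()), 'files indexed')
```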
+ total_filenames = 0 + for sequence in sequences: + folder = '%s/%s/%s/*.%s' % ( + data_root, data_type, sequence, extensions[data_type]) + filenames = sorted(glob.glob(folder)) + filenames = [ + os.path.splitext(os.path.basename(filename))[0] for + filename in filenames] + all_filenames[data_type][sequence] = filenames + total_filenames += len(filenames) + print('Data type: %s, Found %d sequences, Found %d files' % + (data_type, len(sequences), total_filenames)) + + return all_filenames, extensions diff --git a/imaginaire/utils/logging.py b/imaginaire/utils/logging.py new file mode 100644 index 0000000000000000000000000000000000000000..449e0b7b892d0e11baccc5c2c2333afec8501422 --- /dev/null +++ b/imaginaire/utils/logging.py @@ -0,0 +1,51 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +import datetime +import os + +from imaginaire.utils.distributed import master_only +from imaginaire.utils.distributed import master_only_print as print +from imaginaire.utils.meters import set_summary_writer + + +def get_date_uid(): + """Generate a unique id based on date. + Returns: + str: Return uid string, e.g. '20171122171307111552'. + """ + return str(datetime.datetime.now().strftime("%Y_%m%d_%H%M_%S")) + + +def init_logging(config_path, logdir): + r"""Create log directory for storing checkpoints and output images. + + Args: + config_path (str): Path to the configuration file. + logdir (str): Log directory name + Returns: + str: Return log dir + """ + config_file = os.path.basename(config_path) + root_dir = 'logs' + date_uid = get_date_uid() + # example: logs/2019_0125_1047_58_spade_cocostuff + log_file = '_'.join([date_uid, os.path.splitext(config_file)[0]]) + if logdir is None: + logdir = os.path.join(root_dir, log_file) + return date_uid, logdir + + +@master_only +def make_logging_dir(logdir): + r"""Create the logging directory + + Args: + logdir (str): Log directory name + """ + print('Make folder {}'.format(logdir)) + os.makedirs(logdir, exist_ok=True) + tensorboard_dir = os.path.join(logdir, 'tensorboard') + os.makedirs(tensorboard_dir, exist_ok=True) + set_summary_writer(tensorboard_dir) diff --git a/imaginaire/utils/meters.py b/imaginaire/utils/meters.py new file mode 100644 index 0000000000000000000000000000000000000000..3befb7b1e5fc44c00d3fe29092e75777afa64caa --- /dev/null +++ b/imaginaire/utils/meters.py @@ -0,0 +1,149 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +import math +from datetime import timedelta + +import torch +import wandb +from wandb import AlertLevel +from torch.utils.tensorboard import SummaryWriter + +from imaginaire.utils.distributed import master_only, dist_all_reduce_tensor, \ + is_master, get_rank + +from imaginaire.utils.distributed import master_only_print as print + +LOG_WRITER = None +LOG_DIR = None + + +@torch.no_grad() +def sn_reshape_weight_to_matrix(weight): + r"""Reshape weight to obtain the matrix form. + + Args: + weight (Parameters): pytorch layer parameter tensor. 
+ """ + weight_mat = weight + height = weight_mat.size(0) + return weight_mat.reshape(height, -1) + + +@torch.no_grad() +def get_weight_stats(mod): + r"""Get weight state + + Args: + mod: Pytorch module + """ + if mod.weight_orig.grad is not None: + grad_norm = mod.weight_orig.grad.data.norm().item() + else: + grad_norm = 0. + weight_norm = mod.weight_orig.data.norm().item() + weight_mat = sn_reshape_weight_to_matrix(mod.weight_orig) + sigma = torch.sum(mod.weight_u * torch.mv(weight_mat, mod.weight_v)) + return grad_norm, weight_norm, sigma + + +@master_only +def set_summary_writer(log_dir): + r"""Set summary writer + + Args: + log_dir (str): Log directory. + """ + global LOG_DIR, LOG_WRITER + LOG_DIR = log_dir + LOG_WRITER = SummaryWriter(log_dir=log_dir) + + +def write_summary(name, summary, step, hist=False): + """Utility function for write summary to log_writer. + """ + global LOG_WRITER + lw = LOG_WRITER + if lw is None: + raise Exception("Log writer not set.") + if hist: + lw.add_histogram(name, summary, step) + else: + lw.add_scalar(name, summary, step) + + +class Meter(object): + """Meter is to keep track of statistics along steps. + Meters write values for purpose like printing average values. + Meters can be flushed to log files (i.e. TensorBoard for now) + regularly. + + Args: + name (str): the name of meter + reduce (bool): If ``True``, perform a distributed reduce for the log + values across all GPUs. + """ + + def __init__(self, name, reduce=True): + self.name = name + self.reduce = reduce + self.values = [] + + def reset(self): + r"""Reset the meter values""" + if not self.reduce and get_rank() != 0: + return + self.values = [] + + def write(self, value): + r"""Record the value""" + if not self.reduce and get_rank() != 0: + return + if value is not None: + self.values.append(value) + + def flush(self, step): + r"""Write the value in the tensorboard. + + Args: + step (int): Epoch or iteration number. + """ + if not self.reduce and get_rank() != 0: + return + values = torch.tensor(self.values, device="cuda") + if self.reduce: + values = dist_all_reduce_tensor(values) + + if not all(math.isfinite(x) for x in values): + print("meter {} contained a nan or inf.".format(self.name)) + if is_master(): + wandb.alert( + title='NaN', + text=f'Meter {self.name} contained a nan or inf.', + level=AlertLevel.WARN, + wait_duration=timedelta(minutes=120) + ) + filtered_values = list(filter(lambda x: math.isfinite(x), self.values)) + if float(len(filtered_values)) != 0: + value = float(sum(filtered_values)) / float(len(filtered_values)) + if is_master(): + write_summary(self.name, value, step) + wandb.log({self.name: value}, step=step) + self.reset() + + @master_only + def write_image(self, img_grid, step): + r"""Write the value in the tensorboard. + + Args: + img_grid: + step (int): Epoch or iteration number. + """ + if not self.reduce and get_rank() != 0: + return + global LOG_WRITER + lw = LOG_WRITER + if lw is None: + raise Exception("Log writer not set.") + lw.add_image("Visualizations", img_grid, step) diff --git a/imaginaire/utils/misc.py b/imaginaire/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..11ae68652975a2c7e2396c4d7eda2fa2f61fe5a5 --- /dev/null +++ b/imaginaire/utils/misc.py @@ -0,0 +1,269 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. 
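On the master rank, `Meter.flush` above averages the values recorded since the last flush and writes one scalar per step to TensorBoard (and to wandb). The TensorBoard half in isolation, as a minimal sketch with an assumed log directory:

```python
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir='logs/tensorboard')  # assumed log directory
values = [0.52, 0.47, 0.45]                         # values collected between flushes
writer.add_scalar('train/loss', sum(values) / len(values), global_step=10)
writer.close()
```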
+# To view a copy of this license, check out LICENSE.md +"""Miscellaneous utils.""" +import collections +from collections import OrderedDict + +import torch +import torch.nn.functional as F +string_classes = (str, bytes) + + +def split_labels(labels, label_lengths): + r"""Split concatenated labels into their parts. + + Args: + labels (torch.Tensor): Labels obtained through concatenation. + label_lengths (OrderedDict): Containing order of labels & their lengths. + + Returns: + + """ + assert isinstance(label_lengths, OrderedDict) + start = 0 + outputs = {} + for data_type, length in label_lengths.items(): + end = start + length + if labels.dim() == 5: + outputs[data_type] = labels[:, :, start:end] + elif labels.dim() == 4: + outputs[data_type] = labels[:, start:end] + elif labels.dim() == 3: + outputs[data_type] = labels[start:end] + start = end + return outputs + + +def requires_grad(model, require=True): + r""" Set a model to require gradient or not. + + Args: + model (nn.Module): Neural network model. + require (bool): Whether the network requires gradient or not. + + Returns: + + """ + for p in model.parameters(): + p.requires_grad = require + + +def to_device(data, device): + r"""Move all tensors inside data to device. + + Args: + data (dict, list, or tensor): Input data. + device (str): 'cpu' or 'cuda'. + """ + assert device in ['cpu', 'cuda'] + if isinstance(data, torch.Tensor): + data = data.to(torch.device(device)) + return data + elif isinstance(data, collections.abc.Mapping): + return {key: to_device(data[key], device) for key in data} + elif isinstance(data, collections.abc.Sequence) and \ + not isinstance(data, string_classes): + return [to_device(d, device) for d in data] + else: + return data + + +def to_cuda(data): + r"""Move all tensors inside data to gpu. + + Args: + data (dict, list, or tensor): Input data. + """ + return to_device(data, 'cuda') + + +def to_cpu(data): + r"""Move all tensors inside data to cpu. + + Args: + data (dict, list, or tensor): Input data. + """ + return to_device(data, 'cpu') + + +def to_half(data): + r"""Move all floats to half. + + Args: + data (dict, list or tensor): Input data. + """ + if isinstance(data, torch.Tensor) and torch.is_floating_point(data): + data = data.half() + return data + elif isinstance(data, collections.abc.Mapping): + return {key: to_half(data[key]) for key in data} + elif isinstance(data, collections.abc.Sequence) and \ + not isinstance(data, string_classes): + return [to_half(d) for d in data] + else: + return data + + +def to_float(data): + r"""Move all halfs to float. + + Args: + data (dict, list or tensor): Input data. + """ + if isinstance(data, torch.Tensor) and torch.is_floating_point(data): + data = data.float() + return data + elif isinstance(data, collections.abc.Mapping): + return {key: to_float(data[key]) for key in data} + elif isinstance(data, collections.abc.Sequence) and \ + not isinstance(data, string_classes): + return [to_float(d) for d in data] + else: + return data + + +def to_channels_last(data): + r"""Move all data to ``channels_last`` format. + + Args: + data (dict, list or tensor): Input data. 
+ """ + if isinstance(data, torch.Tensor): + if data.dim() == 4: + data = data.to(memory_format=torch.channels_last) + return data + elif isinstance(data, collections.abc.Mapping): + return {key: to_channels_last(data[key]) for key in data} + elif isinstance(data, collections.abc.Sequence) and \ + not isinstance(data, string_classes): + return [to_channels_last(d) for d in data] + else: + return data + + +def slice_tensor(data, start, end): + r"""Slice all tensors from start to end. + Args: + data (dict, list or tensor): Input data. + """ + if isinstance(data, torch.Tensor): + data = data[start:end] + return data + elif isinstance(data, collections.abc.Mapping): + return {key: slice_tensor(data[key], start, end) for key in data} + elif isinstance(data, collections.abc.Sequence) and \ + not isinstance(data, string_classes): + return [slice_tensor(d, start, end) for d in data] + else: + return data + + +def get_and_setattr(cfg, name, default): + r"""Get attribute with default choice. If attribute does not exist, set it + using the default value. + + Args: + cfg (obj) : Config options. + name (str) : Attribute name. + default (obj) : Default attribute. + + Returns: + (obj) : Desired attribute. + """ + if not hasattr(cfg, name) or name not in cfg.__dict__: + setattr(cfg, name, default) + return getattr(cfg, name) + + +def get_nested_attr(cfg, attr_name, default): + r"""Iteratively try to get the attribute from cfg. If not found, return + default. + + Args: + cfg (obj): Config file. + attr_name (str): Attribute name (e.g. XXX.YYY.ZZZ). + default (obj): Default return value for the attribute. + + Returns: + (obj): Attribute value. + """ + names = attr_name.split('.') + atr = cfg + for name in names: + if not hasattr(atr, name): + return default + atr = getattr(atr, name) + return atr + + +def gradient_norm(model): + r"""Return the gradient norm of model. + + Args: + model (PyTorch module): Your network. + + """ + total_norm = 0 + for p in model.parameters(): + if p.grad is not None: + param_norm = p.grad.norm(2) + total_norm += param_norm.item() ** 2 + return total_norm ** (1. / 2) + + +def random_shift(x, offset=0.05, mode='bilinear', padding_mode='reflection'): + r"""Randomly shift the input tensor. + + Args: + x (4D tensor): The input batch of images. + offset (int): The maximum offset ratio that is between [0, 1]. + The maximum shift is offset * image_size for each direction. + mode (str): The resample mode for 'F.grid_sample'. + padding_mode (str): The padding mode for 'F.grid_sample'. + + Returns: + x (4D tensor) : The randomly shifted image. + """ + assert x.dim() == 4, "Input must be a 4D tensor." + batch_size = x.size(0) + theta = torch.eye(2, 3, device=x.device).unsqueeze(0).repeat( + batch_size, 1, 1) + theta[:, :, 2] = 2 * offset * torch.rand(batch_size, 2) - offset + grid = F.affine_grid(theta, x.size()) + x = F.grid_sample(x, grid, mode=mode, padding_mode=padding_mode) + return x + + +# def truncated_gaussian(threshold, size, seed=None, device=None): +# r"""Apply the truncated gaussian trick to trade diversity for quality +# +# Args: +# threshold (float): Truncation threshold. +# size (list of integer): Tensor size. +# seed (int): Random seed. +# device: +# """ +# state = None if seed is None else np.random.RandomState(seed) +# values = truncnorm.rvs(-threshold, threshold, +# size=size, random_state=state) +# return torch.tensor(values, device=device).float() + + +def apply_imagenet_normalization(input): + r"""Normalize using ImageNet mean and std. 
+ + Args: + input (4D tensor NxCxHxW): The input images, assuming to be [-1, 1]. + + Returns: + Normalized inputs using the ImageNet normalization. + """ + # normalize the input back to [0, 1] + normalized_input = (input + 1) / 2 + # normalize the input using the ImageNet mean and std + mean = normalized_input.new_tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1) + std = normalized_input.new_tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1) + output = (normalized_input - mean) / std + return output diff --git a/imaginaire/utils/model_average.py b/imaginaire/utils/model_average.py new file mode 100644 index 0000000000000000000000000000000000000000..470428147c9e6cc55df74ad14c010a15cb874a29 --- /dev/null +++ b/imaginaire/utils/model_average.py @@ -0,0 +1,215 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +import copy + +import torch +from torch import nn +from imaginaire.layers.weight_norm import remove_weight_norms +from imaginaire.utils.misc import requires_grad + + +def reset_batch_norm(m): + r"""Reset batch norm statistics + + Args: + m: Pytorch module + """ + if hasattr(m, 'reset_running_stats'): + m.reset_running_stats() + + +def calibrate_batch_norm_momentum(m): + r"""Calibrate batch norm momentum + + Args: + m: Pytorch module + """ + if hasattr(m, 'reset_running_stats'): + # if m._get_name() == 'SyncBatchNorm': + if 'BatchNorm' in m._get_name(): + m.momentum = 1.0 / float(m.num_batches_tracked + 1) + + +class ModelAverage(nn.Module): + r"""In this model average implementation, the spectral layers are + absorbed in the model parameter by default. If such options are + turned on, be careful with how you do the training. Remember to + re-estimate the batch norm parameters before using the model. + + Args: + module (torch nn module): Torch network. + beta (float): Moving average weights. How much we weight the past. + start_iteration (int): From which iteration, we start the update. + remove_sn (bool): Whether we remove the spectral norm when we it. + """ + def __init__( + self, module, beta=0.9999, start_iteration=1000, + remove_wn_wrapper=True + ): + super(ModelAverage, self).__init__() + self.module = module + # A shallow copy creates a new object which stores the reference of + # the original elements. + # A deep copy creates a new object and recursively adds the copies of + # nested objects present in the original elements. + self.averaged_model = copy.deepcopy(self.module).to('cuda') + self.beta = beta + self.remove_wn_wrapper = remove_wn_wrapper + self.start_iteration = start_iteration + # This buffer is to track how many iterations has the model been + # trained for. We will ignore the first $(start_iterations) and start + # the averaging after. + self.register_buffer('num_updates_tracked', + torch.tensor(0, dtype=torch.long)) + self.num_updates_tracked = self.num_updates_tracked.to('cuda') + # if self.remove_sn: + # # If we want to remove the spectral norm, we first copy the + # # weights to the moving average model. + # self.copy_s2t() + # + # def fn_remove_sn(m): + # r"""Remove spectral norm.""" + # if hasattr(m, 'weight_orig'): + # remove_spectral_norm(m) + # + # self.averaged_model.apply(fn_remove_sn) + # self.dim = 0 + if self.remove_wn_wrapper: + self.copy_s2t() + + self.averaged_model.apply(remove_weight_norms) + self.dim = 0 + else: + self.averaged_model.eval() + + # Averaged model does not require grad. 
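+        # (its parameters are updated only via update_average() / copy_s2t(), never by backprop)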
+ requires_grad(self.averaged_model, False) + + def forward(self, *inputs, **kwargs): + r"""PyTorch module forward function overload.""" + return self.module(*inputs, **kwargs) + + @torch.no_grad() + def update_average(self): + r"""Update the moving average.""" + self.num_updates_tracked += 1 + if self.num_updates_tracked <= self.start_iteration: + beta = 0. + else: + beta = self.beta + source_dict = self.module.state_dict() + target_dict = self.averaged_model.state_dict() + for key in target_dict: + if 'num_batches_tracked' in key: + continue + if self.remove_wn_wrapper: + if key.endswith('weight'): + # This is a weight parameter. + if key + '_ori' in source_dict: + # This parameter has scaled lr. + source_param = \ + source_dict[key + '_ori'] * \ + source_dict[key + '_scale'] + elif key + '_orig' in source_dict: + # This parameter has spectral norm + # but not scaled lr. + source_param = source_dict[key + '_orig'] + elif key in source_dict: + # This parameter does not have + # weight normalization wrappers. + source_param = source_dict[key] + else: + raise ValueError( + f"{key} required in the averaged model but not " + f"found in the regular model." + ) + source_param = source_param.detach() + + if key + '_orig' in source_dict: + # This parameter has spectral norm. + source_param = self.sn_compute_weight( + source_param, + source_dict[key + '_u'], + source_dict[key + '_v'], + ) + elif key.endswith('bias') and key + '_ori' in source_dict: + # This is a bias parameter and has scaled lr. + source_param = source_dict[key + '_ori'] * \ + source_dict[key + '_scale'] + else: + # This is a normal parameter. + source_param = source_dict[key] + target_dict[key].data.mul_(beta).add_( + source_param.data, alpha=1 - beta + ) + else: + target_dict[key].data.mul_(beta).add_( + source_dict[key].data, alpha=1 - beta + ) + + @torch.no_grad() + def copy_t2s(self): + r"""Copy the original weights to the moving average weights.""" + target_dict = self.module.state_dict() + source_dict = self.averaged_model.state_dict() + beta = 0. + for key in source_dict: + target_dict[key].data.copy_( + target_dict[key].data * beta + + source_dict[key].data * (1 - beta)) + + @torch.no_grad() + def copy_s2t(self): + r""" Copy state_dictionary from source to target. + Here source is the regular module and the target is the moving + average module. Basically, we will copy weights in the regular module + to the moving average module. + """ + source_dict = self.module.state_dict() + target_dict = self.averaged_model.state_dict() + beta = 0. + for key in source_dict: + target_dict[key].data.copy_( + target_dict[key].data * beta + + source_dict[key].data * (1 - beta)) + + def __repr__(self): + r"""Returns a string that holds a printable representation of an + object""" + return self.module.__repr__() + + def sn_reshape_weight_to_matrix(self, weight): + r"""Reshape weight to obtain the matrix form. + + Args: + weight (Parameters): pytorch layer parameter tensor. + + Returns: + (Parameters): Reshaped weight matrix + """ + weight_mat = weight + if self.dim != 0: + # permute dim to front + weight_mat = weight_mat.permute( + self.dim, + *[d for d in range(weight_mat.dim()) if d != self.dim]) + height = weight_mat.size(0) + return weight_mat.reshape(height, -1) + + def sn_compute_weight(self, weight, u, v): + r"""Compute the spectral norm normalized matrix. + + Args: + weight (Parameters): pytorch layer parameter tensor. + u (tensor): left singular vectors. 
+ v (tensor) right singular vectors + + Returns: + (Parameters): weight parameter object. + """ + weight_mat = self.sn_reshape_weight_to_matrix(weight) + sigma = torch.sum(u * torch.mv(weight_mat, v)) + weight = weight / sigma + return weight diff --git a/imaginaire/utils/path.py b/imaginaire/utils/path.py new file mode 100644 index 0000000000000000000000000000000000000000..e576fc91e66d7c1931b0fb3f349363b49f62c8d5 --- /dev/null +++ b/imaginaire/utils/path.py @@ -0,0 +1,36 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +"""Utils to deal with directories and paths.""" + +import glob +import os + + +def get_immediate_subdirectories(input_dir): + """List dirs immediately under input_dir. + + Args: + input_dir (str): Directory to list children of. + Returns: + (list): List of directory paths relative to input_dir. + """ + return sorted([name for name in os.listdir(input_dir) + if os.path.isdir(os.path.join(input_dir, name))]) + + +def get_recursive_subdirectories(input_dir, ext): + """List dirs recursively under input_dir. + + Args: + input_dir (str): Directory to list children of. + ext (str): Extension of files expected in this directory. + Returns: + (list): List of directory paths relative to input_dir. + """ + lines = glob.glob('%s/**/*.%s' % (input_dir, ext), recursive=True) + dirpaths = [os.path.dirname(item) for item in lines] + dirpaths = [os.path.relpath(item, input_dir) for item in dirpaths] + dirpaths = sorted(list(set(dirpaths))) + return dirpaths diff --git a/imaginaire/utils/trainer.py b/imaginaire/utils/trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..0eb8593ac2904433ce408619d7f28c6fa80ababd --- /dev/null +++ b/imaginaire/utils/trainer.py @@ -0,0 +1,341 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +import importlib +import random +import numpy as np +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.optim import SGD, Adam, RMSprop, lr_scheduler + +from imaginaire.optimizers import Fromage, Madam +from imaginaire.utils.distributed import get_rank, get_world_size +from imaginaire.utils.distributed import master_only_print as print +from imaginaire.utils.init_weight import weights_init, weights_rescale +from imaginaire.utils.model_average import ModelAverage + + +def set_random_seed(seed, by_rank=False): + r"""Set random seeds for everything. + + Args: + seed (int): Random seed. + by_rank (bool): + """ + if by_rank: + seed += get_rank() + print(f"Using random seed {seed}") + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def get_trainer(cfg, net_G, net_D=None, + opt_G=None, opt_D=None, + sch_G=None, sch_D=None, + train_data_loader=None, + val_data_loader=None): + """Return the trainer object. + + Args: + cfg (Config): Loaded config object. + net_G (obj): Generator network object. + net_D (obj): Discriminator network object. + opt_G (obj): Generator optimizer object. + opt_D (obj): Discriminator optimizer object. + sch_G (obj): Generator optimizer scheduler object. + sch_D (obj): Discriminator optimizer scheduler object. + train_data_loader (obj): Train data loader. 
+ val_data_loader (obj): Validation data loader. + + Returns: + (obj): Trainer object. + """ + trainer_lib = importlib.import_module(cfg.trainer.type) + trainer = trainer_lib.Trainer(cfg, net_G, net_D, + opt_G, opt_D, + sch_G, sch_D, + train_data_loader, val_data_loader) + return trainer + + +def get_model_optimizer_and_scheduler(cfg, seed=0): + r"""Return the networks, the optimizers, and the schedulers. We will + first set the random seed to a fixed value so that each GPU copy will be + initialized to have the same network weights. We will then use different + random seeds for different GPUs. After this we will wrap the generator + with a moving average model if applicable. It is followed by getting the + optimizers and data distributed data parallel wrapping. + + Args: + cfg (obj): Global configuration. + seed (int): Random seed. + + Returns: + (dict): + - net_G (obj): Generator network object. + - net_D (obj): Discriminator network object. + - opt_G (obj): Generator optimizer object. + - opt_D (obj): Discriminator optimizer object. + - sch_G (obj): Generator optimizer scheduler object. + - sch_D (obj): Discriminator optimizer scheduler object. + """ + # We first set the random seed to be the same so that we initialize each + # copy of the network in exactly the same way so that they have the same + # weights and other parameters. The true seed will be the seed. + set_random_seed(seed, by_rank=False) + # Construct networks + lib_G = importlib.import_module(cfg.gen.type) + lib_D = importlib.import_module(cfg.dis.type) + net_G = lib_G.Generator(cfg.gen, cfg.data) + net_D = lib_D.Discriminator(cfg.dis, cfg.data) + print('Initialize net_G and net_D weights using ' + 'type: {} gain: {}'.format(cfg.trainer.init.type, + cfg.trainer.init.gain)) + init_bias = getattr(cfg.trainer.init, 'bias', None) + net_G.apply(weights_init( + cfg.trainer.init.type, cfg.trainer.init.gain, init_bias)) + net_D.apply(weights_init( + cfg.trainer.init.type, cfg.trainer.init.gain, init_bias)) + net_G.apply(weights_rescale()) + net_D.apply(weights_rescale()) + # for name, p in net_G.named_parameters(): + # if 'modulation' in name and 'bias' in name: + # nn.init.constant_(p.data, 1.) + net_G = net_G.to('cuda') + net_D = net_D.to('cuda') + # Different GPU copies of the same model will receive noises + # initialized with different random seeds (if applicable) thanks to the + # set_random_seed command (GPU #K has random seed = args.seed + K). + set_random_seed(seed, by_rank=True) + print('net_G parameter count: {:,}'.format(_calculate_model_size(net_G))) + print('net_D parameter count: {:,}'.format(_calculate_model_size(net_D))) + + # Optimizer + opt_G = get_optimizer(cfg.gen_opt, net_G) + opt_D = get_optimizer(cfg.dis_opt, net_D) + + net_G, net_D, opt_G, opt_D = \ + wrap_model_and_optimizer(cfg, net_G, net_D, opt_G, opt_D) + + # Scheduler + sch_G = get_scheduler(cfg.gen_opt, opt_G) + sch_D = get_scheduler(cfg.dis_opt, opt_D) + + return net_G, net_D, opt_G, opt_D, sch_G, sch_D + + +def wrap_model_and_optimizer(cfg, net_G, net_D, opt_G, opt_D): + r"""Wrap the networks and the optimizers with AMP DDP and (optionally) + model average. + + Args: + cfg (obj): Global configuration. + net_G (obj): Generator network object. + net_D (obj): Discriminator network object. + opt_G (obj): Generator optimizer object. + opt_D (obj): Discriminator optimizer object. + + Returns: + (dict): + - net_G (obj): Generator network object. + - net_D (obj): Discriminator network object. + - opt_G (obj): Generator optimizer object. 
+ - opt_D (obj): Discriminator optimizer object. + """ + # Apply model average wrapper. + if cfg.trainer.model_average_config.enabled: + if hasattr(cfg.trainer.model_average_config, 'g_smooth_img'): + # Specifies half-life of the running average of generator weights. + cfg.trainer.model_average_config.beta = \ + 0.5 ** (cfg.data.train.batch_size * + get_world_size() / cfg.trainer.model_average_config.g_smooth_img) + print(f"EMA Decay Factor: {cfg.trainer.model_average_config.beta}") + net_G = ModelAverage(net_G, cfg.trainer.model_average_config.beta, + cfg.trainer.model_average_config.start_iteration, + cfg.trainer.model_average_config.remove_sn) + if cfg.trainer.model_average_config.enabled: + net_G_module = net_G.module + else: + net_G_module = net_G + if hasattr(net_G_module, 'custom_init'): + net_G_module.custom_init() + + net_G = _wrap_model(cfg, net_G) + net_D = _wrap_model(cfg, net_D) + return net_G, net_D, opt_G, opt_D + + +def _calculate_model_size(model): + r"""Calculate number of parameters in a PyTorch network. + + Args: + model (obj): PyTorch network. + + Returns: + (int): Number of parameters. + """ + return sum(p.numel() for p in model.parameters() if p.requires_grad) + + +class WrappedModel(nn.Module): + r"""Dummy wrapping the module. + """ + + def __init__(self, module): + super(WrappedModel, self).__init__() + self.module = module + + def forward(self, *args, **kwargs): + r"""PyTorch module forward function overload.""" + return self.module(*args, **kwargs) + + +def _wrap_model(cfg, model): + r"""Wrap a model for distributed data parallel training. + + Args: + model (obj): PyTorch network model. + + Returns: + (obj): Wrapped PyTorch network model. + """ + if torch.distributed.is_available() and dist.is_initialized(): + # ddp = cfg.trainer.distributed_data_parallel + find_unused_parameters = cfg.trainer.distributed_data_parallel_params.find_unused_parameters + return torch.nn.parallel.DistributedDataParallel( + model, + device_ids=[cfg.local_rank], + output_device=cfg.local_rank, + find_unused_parameters=find_unused_parameters, + broadcast_buffers=False + ) + # if ddp == 'pytorch': + # return torch.nn.parallel.DistributedDataParallel( + # model, + # device_ids=[cfg.local_rank], + # output_device=cfg.local_rank, + # find_unused_parameters=find_unused_parameters, + # broadcast_buffers=False) + # else: + # delay_allreduce = cfg.trainer.delay_allreduce + # return apex.parallel.DistributedDataParallel( + # model, delay_allreduce=delay_allreduce) + else: + return WrappedModel(model) + + +def get_scheduler(cfg_opt, opt): + """Return the scheduler object. + + Args: + cfg_opt (obj): Config for the specific optimization module (gen/dis). + opt (obj): PyTorch optimizer object. + + Returns: + (obj): Scheduler + """ + if cfg_opt.lr_policy.type == 'step': + scheduler = lr_scheduler.StepLR( + opt, + step_size=cfg_opt.lr_policy.step_size, + gamma=cfg_opt.lr_policy.gamma) + elif cfg_opt.lr_policy.type == 'constant': + scheduler = lr_scheduler.LambdaLR(opt, lambda x: 1) + elif cfg_opt.lr_policy.type == 'linear': + # Start linear decay from here. + decay_start = cfg_opt.lr_policy.decay_start + # End linear decay here. + # Continue to train using the lowest learning rate till the end. + decay_end = cfg_opt.lr_policy.decay_end + # Lowest learning rate multiplier. + decay_target = cfg_opt.lr_policy.decay_target + + def sch(x): + return min( + max(((x - decay_start) * decay_target + decay_end - x) / ( + decay_end - decay_start + ), decay_target), 1. 
+ ) + scheduler = lr_scheduler.LambdaLR(opt, lambda x: sch(x)) + else: + return NotImplementedError('Learning rate policy {} not implemented.'. + format(cfg_opt.lr_policy.type)) + return scheduler + + +def get_optimizer(cfg_opt, net): + r"""Return the scheduler object. + + Args: + cfg_opt (obj): Config for the specific optimization module (gen/dis). + net (obj): PyTorch network object. + + Returns: + (obj): Pytorch optimizer + """ + if hasattr(net, 'get_param_groups'): + # Allow the network to use different hyper-parameters (e.g., learning + # rate) for different parameters. + params = net.get_param_groups(cfg_opt) + else: + params = net.parameters() + return get_optimizer_for_params(cfg_opt, params) + + +def get_optimizer_for_params(cfg_opt, params): + r"""Return the scheduler object. + + Args: + cfg_opt (obj): Config for the specific optimization module (gen/dis). + params (obj): Parameters to be trained by the parameters. + + Returns: + (obj): Optimizer + """ + # We will use fuse optimizers by default. + fused_opt = cfg_opt.fused_opt + try: + from apex.optimizers import FusedAdam + except: # noqa + fused_opt = False + + if cfg_opt.type == 'adam': + if fused_opt: + opt = FusedAdam(params, + lr=cfg_opt.lr, eps=cfg_opt.eps, + betas=(cfg_opt.adam_beta1, cfg_opt.adam_beta2)) + else: + opt = Adam(params, + lr=cfg_opt.lr, eps=cfg_opt.eps, + betas=(cfg_opt.adam_beta1, cfg_opt.adam_beta2)) + + elif cfg_opt.type == 'madam': + g_bound = getattr(cfg_opt, 'g_bound', None) + opt = Madam(params, lr=cfg_opt.lr, + scale=cfg_opt.scale, g_bound=g_bound) + elif cfg_opt.type == 'fromage': + opt = Fromage(params, lr=cfg_opt.lr) + elif cfg_opt.type == 'rmsprop': + opt = RMSprop(params, lr=cfg_opt.lr, + eps=cfg_opt.eps, weight_decay=cfg_opt.weight_decay) + elif cfg_opt.type == 'sgd': + if fused_opt: + from apex.optimizers import FusedSGD + opt = FusedSGD(params, + lr=cfg_opt.lr, + momentum=cfg_opt.momentum, + weight_decay=cfg_opt.weight_decay) + else: + opt = SGD(params, + lr=cfg_opt.lr, + momentum=cfg_opt.momentum, + weight_decay=cfg_opt.weight_decay) + else: + raise NotImplementedError( + 'Optimizer {} is not yet implemented.'.format(cfg_opt.type)) + return opt diff --git a/imaginaire/utils/visualization/__init__.py b/imaginaire/utils/visualization/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..27e3e7c0383a7e5f032593a50930e9d48bd0292b --- /dev/null +++ b/imaginaire/utils/visualization/__init__.py @@ -0,0 +1,9 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +from .common import tensor2im, tensor2flow, tensor2label, tensor2pilimage +from .common import save_tensor_image + +__all__ = ['tensor2im', 'tensor2flow', 'tensor2label', 'tensor2pilimage', + 'save_tensor_image'] diff --git a/imaginaire/utils/visualization/common.py b/imaginaire/utils/visualization/common.py new file mode 100644 index 0000000000000000000000000000000000000000..b4b68c5b670c386fe9bef916db13c92682b81bd2 --- /dev/null +++ b/imaginaire/utils/visualization/common.py @@ -0,0 +1,314 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. 
+# To view a copy of this license, check out LICENSE.md +import cv2 +import numpy as np +import PIL +from PIL import Image +import torch +import torchvision +import os + + +def save_tensor_image( + filename, image, minus1to1_normalized=False): + r"""Convert a 3 dimensional torch tensor to a PIL image with the desired + width and height. + + Args: + filename (str): Image filename to be saved to. + image (3 x W1 x H1 tensor): Image tensor + minus1to1_normalized (bool): True if the tensor values are in [-1, + 1]. Otherwise, we assume the values are in [0, 1]. + + Returns: + (PIL image): The resulting PIL image. + """ + if len(image.size()) != 3: + raise ValueError('Image tensor dimension does not equal = 3.') + if image.size(0) != 3: + raise ValueError('Image has more than 3 channels.') + if minus1to1_normalized: + # Normalize back to [0, 1] + image = (image + 1) * 0.5 + dirname = os.path.dirname(filename) + os.makedirs(dirname, exist_ok=True) + image_grid = torchvision.utils.make_grid( + image, nrow=1, padding=0, normalize=False) + torchvision.utils.save_image(image_grid, filename, nrow=1) + return + + +def tensor2pilimage(image, width=None, height=None, minus1to1_normalized=False): + r"""Convert a 3 dimensional torch tensor to a PIL image with the desired + width and height. + + Args: + image (3 x W1 x H1 tensor): Image tensor + width (int): Desired width for the result PIL image. + height (int): Desired height for the result PIL image. + minus1to1_normalized (bool): True if the tensor values are in [-1, + 1]. Otherwise, we assume the values are in [0, 1]. + + Returns: + (PIL image): The resulting PIL image. + """ + if len(image.size()) != 3: + raise ValueError('Image tensor dimension does not equal = 3.') + if image.size(0) != 3: + raise ValueError('Image has more than 3 channels.') + if minus1to1_normalized: + # Normalize back to [0, 1] + image = (image + 1) * 0.5 + image = image.detach().cpu().squeeze().numpy() + image = np.transpose(image, (1, 2, 0)) * 255 + output_img = Image.fromarray(np.uint8(image)) + if width is not None and height is not None: + output_img = output_img.resize((width, height), Image.BICUBIC) + return output_img + + +def tensor2im(image_tensor, imtype=np.uint8, normalize=True, + three_channel_output=True): + r"""Convert tensor to image. + + Args: + image_tensor (torch.tensor or list of torch.tensor): If tensor then + (NxCxHxW) or (NxTxCxHxW) or (CxHxW). + imtype (np.dtype): Type of output image. + normalize (bool): Is the input image normalized or not? + three_channel_output (bool): Should single channel images be made 3 + channel in output? + + Returns: + (numpy.ndarray, list if case 1, 2 above). 
+ """ + if image_tensor is None: + return None + if isinstance(image_tensor, list): + return [tensor2im(x, imtype, normalize) for x in image_tensor] + if image_tensor.dim() == 5 or image_tensor.dim() == 4: + return [tensor2im(image_tensor[idx], imtype, normalize) + for idx in range(image_tensor.size(0))] + + if image_tensor.dim() == 3: + image_numpy = image_tensor.cpu().float().numpy() + if normalize: + image_numpy = (np.transpose( + image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 + else: + image_numpy = np.transpose(image_numpy, (1, 2, 0)) * 255.0 + image_numpy = np.clip(image_numpy, 0, 255) + if image_numpy.shape[2] == 1 and three_channel_output: + image_numpy = np.repeat(image_numpy, 3, axis=2) + elif image_numpy.shape[2] > 3: + image_numpy = image_numpy[:, :, :3] + return image_numpy.astype(imtype) + + +def tensor2label(segmap, n_label=None, imtype=np.uint8, + colorize=True, output_normalized_tensor=False): + r"""Convert segmentation mask tensor to color image. + Args: + segmap (tensor) of + If tensor then (NxCxHxW) or (NxTxCxHxW) or (CxHxW). + n_label (int): If None, then segmap.size(0). + imtype (np.dtype): Type of output image. + colorize (bool): Put colors in. + + Returns: + (numpy.ndarray or normalized torch image). + """ + if segmap is None: + return None + if isinstance(segmap, list): + return [tensor2label(x, n_label, + imtype, colorize, + output_normalized_tensor) for x in segmap] + if segmap.dim() == 5 or segmap.dim() == 4: + return [tensor2label(segmap[idx], n_label, + imtype, colorize, + output_normalized_tensor) + for idx in range(segmap.size(0))] + + segmap = segmap.float() + if not output_normalized_tensor: + segmap = segmap.cpu() + if n_label is None: + n_label = segmap.size(0) + if n_label > 1: + segmap = segmap.max(0, keepdim=True)[1] + + if output_normalized_tensor: + if n_label == 0: + segmap = Colorize(256)(segmap).to('cuda') + else: + segmap = Colorize(n_label)(segmap).to('cuda') + return 2 * (segmap.float() / 255) - 1 + else: + if colorize: + segmap = Colorize(n_label)(segmap) + segmap = np.transpose(segmap.numpy(), (1, 2, 0)) + else: + segmap = segmap.cpu().numpy() + return segmap.astype(imtype) + + +def tensor2flow(tensor, imtype=np.uint8): + r"""Convert flow tensor to color image. + + Args: + tensor (tensor) of + If tensor then (NxCxHxW) or (NxTxCxHxW) or (CxHxW). + imtype (np.dtype): Type of output image. + + Returns: + (numpy.ndarray or normalized torch image). + """ + if tensor is None: + return None + if isinstance(tensor, list): + tensor = [t for t in tensor if t is not None] + if not tensor: + return None + return [tensor2flow(t, imtype) for t in tensor] + if tensor.dim() == 5 or tensor.dim() == 4: + return [tensor2flow(tensor[b]) for b in range(tensor.size(0))] + + tensor = tensor.detach().cpu().float().numpy() + tensor = np.transpose(tensor, (1, 2, 0)) + + hsv = np.zeros((tensor.shape[0], tensor.shape[1], 3), dtype=imtype) + hsv[:, :, 0] = 255 + hsv[:, :, 1] = 255 + mag, ang = cv2.cartToPolar(tensor[..., 0], tensor[..., 1]) + hsv[..., 0] = ang * 180 / np.pi / 2 + hsv[..., 2] = cv2.normalize(mag, None, 0, 255, cv2.NORM_MINMAX) + rgb = cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB) + return rgb + + +def plot_keypoints(image, keypoints, normalize=True): + r"""Plot keypoints on image. + + Args: + image (PIL.Image, or numpy.ndarray, or torch.Tensor): Input image. + keypoints (np.ndarray or torch.Tensor, Nx2): Keypoint locations. + normalize (bool): Whether to normalize the image or not. 
+ """ + if isinstance(image, PIL.Image.Image): + image = np.array(image) + if isinstance(image, torch.Tensor): + image = tensor2im(image, normalize=normalize) + if isinstance(image, np.ndarray): + assert image.ndim == 3 + assert image.shape[-1] == 1 or image.shape[-1] == 3 + if isinstance(keypoints, torch.Tensor): + keypoints = keypoints.cpu().numpy() + assert keypoints.ndim == 2 and keypoints.shape[1] == 2 + + cv2_image = np.ascontiguousarray(image[:, :, ::-1]) # RGB to BGR. + for idx in range(keypoints.shape[0]): + keypoint = np.round(keypoints[idx]).astype(np.int) + cv2_image = cv2.circle(cv2_image, tuple(keypoint), + 5, (0, 255, 0), -1) + image = np.ascontiguousarray(cv2_image[:, :, ::-1]) + return image + + +def labelcolormap(N): + r"""Create colors for segmentation label ids. + + Args: + N (int): Number of labels. + """ + if N == 35: # GTA/cityscape train + cmap = np.array([(0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0), + (111, 74, 0), (81, 0, 81), (128, 64, 128), + (244, 35, 232), (250, 170, 160), (230, 150, 140), + (70, 70, 70), (102, 102, 156), (190, 153, 153), + (180, 165, 180), (150, 100, 100), (150, 120, 90), + (153, 153, 153), (153, 153, 153), (250, 170, 30), + (220, 220, 0), (107, 142, 35), (152, 251, 152), + (70, 130, 180), (220, 20, 60), (255, 0, 0), + (0, 0, 142), (0, 0, 70), (0, 60, 100), (0, 0, 90), + (0, 0, 110), (0, 80, 100), (0, 0, 230), (119, 11, 32), + (0, 0, 142)], + dtype=np.uint8) + elif N == 20: # GTA/cityscape eval + cmap = np.array([(128, 64, 128), (244, 35, 232), (70, 70, 70), + (102, 102, 156), (190, 153, 153), (153, 153, 153), + (250, 170, 30), (220, 220, 0), (107, 142, 35), + (152, 251, 152), (220, 20, 60), (255, 0, 0), + (0, 0, 142), (0, 0, 70), (0, 60, 100), (0, 80, 100), + (0, 0, 230), (119, 11, 32), (70, 130, 180), (0, 0, 0)], + dtype=np.uint8) + else: + cmap = np.zeros([N, 3]).astype(np.uint8) + for i in range(N): + r, g, b = np.zeros(3) + for j in range(8): + r = r + (1 << (7 - j)) * ((i & (1 << (3 * j))) >> (3 * j)) + g = g + (1 << (7 - j)) * \ + ((i & (1 << (3 * j + 1))) >> (3 * j + 1)) + b = b + (1 << (7 - j)) * \ + ((i & (1 << (3 * j + 2))) >> (3 * j + 2)) + cmap[i, :] = np.array([r, g, b]) + return cmap + + +class Colorize(object): + """Class to colorize segmentation maps.""" + + def __init__(self, n=35): + self.cmap = labelcolormap(n) + self.cmap = torch.from_numpy(self.cmap[:n]) + + def __call__(self, seg_map): + r""" + + Args: + seg_map (tensor): Input Segmentation maps to be colorized. + """ + size = seg_map.size() + color_image = torch.ByteTensor(3, size[1], size[2]).fill_(0) + for label in range(0, len(self.cmap)): + mask = (label == seg_map[0]).cpu() + color_image[0][mask] = self.cmap[label][0] + color_image[1][mask] = self.cmap[label][1] + color_image[2][mask] = self.cmap[label][2] + return color_image + + +def plot_keypoints_on_black(resize_h, resize_w, crop_h, crop_w, is_flipped, + cfgdata, keypoints): + r"""Plot keypoints on black image. + + Args: + resize_h (int): Height to be resized to. + resize_w (int): Width to be resized to. + crop_h (int): Height of the cropping. + crop_w (int): Width of the cropping. + is_flipped (bool): If image is a flipped version. + cfgdata (obj): Data configuration object. + keypoints (np.ndarray): Keypoint locations. Shape of + (Nx2) or (TxNx2). + + Returns: + (list of np.ndarray): List of images (output_h, output_w, 3). + """ + if keypoints.ndim == 2 and keypoints.shape[1] == 2: + keypoints = keypoints[np.newaxis, ...] 
+ + outputs = [] + for t_idx in range(keypoints.shape[0]): + cv2_image = np.zeros((crop_h, crop_w, 3)).astype(np.uint8) + for idx in range(keypoints[t_idx].shape[0]): + keypoint = np.round(keypoints[t_idx][idx]).astype(np.int) + cv2_image = cv2.circle(cv2_image, tuple(keypoint), + 5, (0, 255, 0), -1) + image = np.ascontiguousarray(cv2_image[:, :, ::-1]) # BGR to RGB. + outputs.append(image) + + return outputs diff --git a/imaginaire/utils/visualization/face.py b/imaginaire/utils/visualization/face.py new file mode 100644 index 0000000000000000000000000000000000000000..19728cda34ddf75552fab9c421a6d76af7983542 --- /dev/null +++ b/imaginaire/utils/visualization/face.py @@ -0,0 +1,491 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +import numpy as np +import cv2 +import torch +from scipy.optimize import curve_fit +from scipy.signal import medfilt +import warnings +from imaginaire.utils.io import get_checkpoint + + +def connect_face_keypoints(resize_h, resize_w, crop_h, crop_w, original_h, + original_w, is_flipped, cfgdata, keypoints): + r"""Connect the face keypoints to edges and draw the sketch. + + Args: + resize_h (int): Height the input image was resized to. + resize_w (int): Width the input image was resized to. + crop_h (int): Height the input image was cropped. + crop_w (int): Width the input image was cropped. + original_h (int): Original height of the input image. + original_w (int): Original width of the input image. + is_flipped (bool): Is the input image flipped. + cfgdata (obj): Data configuration. + keypoints (NxKx2 numpy array): Facial landmarks (with K keypoints). + + Returns: + (list of HxWxC numpy array): Drawn label map. + """ + if hasattr(cfgdata, 'for_face_dataset'): + face_cfg = cfgdata.for_face_dataset + # Whether to add the upper part of face to label map. + add_upper_face = getattr(face_cfg, 'add_upper_face', False) + # Whether to add distance transform output to label map. + add_dist_map = getattr(face_cfg, 'add_distance_transform', False) + # Whether to add positional encoding to label map. + add_pos_encode = add_dist_map and getattr( + face_cfg, 'add_positional_encode', False) + else: + add_upper_face = add_dist_map = add_pos_encode = False + + # Mapping from keypoint index to facial part. + part_list = [[list(range(0, 17)) + ( + (list(range(68, 83)) + [0]) if add_upper_face else [])], # ai_emoji + [range(17, 22)], # right eyebrow + [range(22, 27)], # left eyebrow + [[28, 31], range(31, 36), [35, 28]], # nose + [[36, 37, 38, 39], [39, 40, 41, 36]], # right eye + [[42, 43, 44, 45], [45, 46, 47, 42]], # left eye + [range(48, 55), [54, 55, 56, 57, 58, 59, 48], + range(60, 65), [64, 65, 66, 67, 60]], # mouth and tongue + ] + if add_upper_face: + pts = keypoints[:, :17, :].astype(np.int32) + baseline_y = (pts[:, 0:1, 1] + pts[:, -1:, 1]) / 2 + upper_pts = pts[:, 1:-1, :].copy() + upper_pts[:, :, 1] = baseline_y + ( + baseline_y - upper_pts[:, :, 1]) * 2 // 3 + keypoints = np.hstack((keypoints, upper_pts[:, ::-1, :])) + + edge_len = 3 # Interpolate 3 keypoints to form a curve when drawing edges. + bw = max(1, resize_h // 256) # Width of the stroke. + + outputs = [] + for t_idx in range(keypoints.shape[0]): + # Edge map for the face region from keypoints. 
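+        # im_edges accumulates every facial edge; im_dists optionally stacks a distance-transform channel for each drawn edge.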
+ im_edges = np.zeros((resize_h, resize_w, 1), np.uint8) + im_dists = np.zeros((resize_h, resize_w, 0), np.uint8) + for edge_list in part_list: + for e, edge in enumerate(edge_list): + # Edge map for the current edge. + im_edge = np.zeros((resize_h, resize_w, 1), np.uint8) + # Divide a long edge into multiple small edges when drawing. + for i in range(0, max(1, len(edge) - 1), edge_len - 1): + sub_edge = edge[i:i + edge_len] + x = keypoints[t_idx, sub_edge, 0] + y = keypoints[t_idx, sub_edge, 1] + + # Interp keypoints to get the curve shape. + curve_x, curve_y = interp_points(x, y) + draw_edge(im_edges, curve_x, curve_y, bw=bw) + if add_dist_map: + draw_edge(im_edge, curve_x, curve_y, bw=bw) + + if add_dist_map: + # Add distance transform map on each facial part. + im_dist = cv2.distanceTransform(255 - im_edge, + cv2.DIST_L1, 3) + im_dist = np.clip((im_dist / 3), 0, 255) + im_dists = np.dstack((im_dists, im_dist)) + + if add_pos_encode and e == 0: + # Add positional encoding for the first edge. + from math import pi + im_pos = np.zeros((resize_h, resize_w, 0), np.float32) + for l in range(10): # noqa: E741 + dist = (im_dist.astype(np.float32) - 127.5) / 127.5 + sin = np.sin(pi * (2 ** l) * dist) + cos = np.cos(pi * (2 ** l) * dist) + im_pos = np.dstack((im_pos, sin, cos)) + + # Combine all components to form the final label map. + if add_dist_map: + im_edges = np.dstack((im_edges, im_dists)) + im_edges = im_edges.astype(np.float32) / 255.0 + if add_pos_encode: + im_edges = np.dstack((im_edges, im_pos)) + outputs.append(im_edges) + return outputs + + +def normalize_and_connect_face_keypoints(cfg, is_inference, data): + r"""Normalize face keypoints w.r.t. reference face keypoints and connect + keypoints to form 2D images. + + Args: + cfg (obj): Data configuration. + is_inference (bool): Is doing inference or not. + data (dict): Input data. + + Returns: + (dict): Output data. + """ + assert is_inference + resize_h, resize_w = data['images'][0].shape[-2:] + + keypoints = data['label'].numpy()[0] + ref_keypoints = data['few_shot_label'].numpy()[0] + + # Get the normalization params and prev data if it's been computed before. + dist_scales = prev_keypoints = None + if 'common_attr' in data and 'prev_data' in data['common_attr']: + dist_scales = data['common_attr']['dist_scales'] + prev_keypoints = data['common_attr']['prev_data'] + + def concat(prev, now, t): + r"""Concat prev and now frames in first dimension, up to t frames.""" + if prev is None: + return now + return np.vstack([prev, now])[-t:] + + # Normalize face keypoints w.r.t. reference face keypoints. + keypoints, dist_scales = \ + normalize_face_keypoints(keypoints[0], ref_keypoints[0], dist_scales, + momentum=getattr(cfg.for_face_dataset, + 'normalize_momentum', 0.9)) + keypoints = keypoints[np.newaxis, :] + + # Temporally smooth the face keypoints by median filtering. + ks = getattr(cfg.for_face_dataset, 'smooth_kernel_size', 5) + concat_keypoints = concat(prev_keypoints, keypoints, ks) + if ks > 1 and concat_keypoints.shape[0] == ks: + keypoints = smooth_face_keypoints(concat_keypoints, ks) + + # Store the computed params. + if 'common_attr' not in data: + data['common_attr'] = dict() + data['common_attr']['dist_scales'] = dist_scales + data['common_attr']['prev_data'] = concat_keypoints + + # Draw the keypoints to turn them into images. 
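+    # Render both the normalized target keypoints and the reference keypoints into label maps below.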
+ labels = [] + for kpt in [keypoints, ref_keypoints]: + label = connect_face_keypoints(resize_h, resize_w, None, None, None, + None, False, cfg, kpt) + labels += [torch.from_numpy(label[0]).permute(2, 0, 1).unsqueeze(0)] + data['label'], data['few_shot_label'] = labels + return data + + +def smooth_face_keypoints(concat_keypoints, ks): + r""" Temporally smooth the face keypoints by median filtering. + + Args: + concat_keypoints (TxKx2 numpy array): Face keypoints to be filtered. + ks (int): Filter kernel size. + + Returns: + (1xKx2 numpy array): Output face keypoints. + """ + # Median filtering. + filtered_keypoints = medfilt(concat_keypoints, kernel_size=[ks, 1, 1]) + # Fill in any zero keypoints with the value from previous frame. + if (filtered_keypoints == 0).any(): + for t in range(1, filtered_keypoints.shape[0]): + kpt_prev = filtered_keypoints[t - 1] + kpt_cur = filtered_keypoints[t] + kpt_max = np.maximum(kpt_cur, kpt_prev) + kpt_cur[kpt_cur == 0] = kpt_max[kpt_cur == 0] + filtered_keypoints[t] = kpt_cur + keypoints = filtered_keypoints[ks // 2: ks // 2 + 1] + return keypoints + + +def normalize_face_keypoints(keypoints, ref_keypoints, dist_scales=None, + momentum=0.9): + r"""Normalize face keypoints w.r.t. the reference face keypoints. + + Args: + keypoints (Kx2 numpy array): Target facial keypoints to be normalized. + ref_keypoints (Kx2 numpy array): Reference facial keypoints. + dist_scales (list of list of floats): Normalization params. + momentum (float): Temporal momentum for the normalization params. + + Returns: + (Kx2 numpy array): Normalized facial keypoints. + """ + if keypoints.shape[0] == 68: + central_keypoints = [8] + part_list = [[0, 16], [1, 15], [2, 14], [3, 13], [4, 12], + [5, 11], [6, 10], [7, 9, 8], + [17, 26], [18, 25], [19, 24], [20, 23], [21, 22], + [27], [28], [29], [30], [31, 35], [32, 34], [33], + [36, 45], [37, 44], [38, 43], [39, 42], [40, 47], [41, 46], + [48, 54], [49, 53], [50, 52], [51], [55, 59], [56, 58], + [57], + [60, 64], [61, 63], [62], [65, 67], [66] + ] + else: + raise ValueError('Input keypoints type not supported.') + + face_cen = np.mean(keypoints[central_keypoints, :], axis=0) + ref_face_cen = np.mean(ref_keypoints[central_keypoints, :], axis=0) + + def get_mean_dists(pts, face_cen): + r"""Get mean distances of the points from face center.""" + mean_dists_x, mean_dists_y = [], [] + pts_cen = np.mean(pts, axis=0) + for p, pt in enumerate(pts): + mean_dists_x.append(np.linalg.norm(pt - pts_cen)) + mean_dists_y.append(np.linalg.norm(pts_cen - face_cen)) + mean_dist_x = sum(mean_dists_x) / len(mean_dists_x) + 1e-3 + mean_dist_y = sum(mean_dists_y) / len(mean_dists_y) + 1e-3 + return mean_dist_x, mean_dist_y + + dist_scale_x, dist_scale_y = [None] * len(part_list), \ + [None] * len(part_list) + if dist_scales is None: + dist_scale_x_prev = dist_scale_y_prev = img_scale = None + else: + dist_scale_x_prev, dist_scale_y_prev, img_scale = dist_scales + if img_scale is None: + img_scale = (keypoints[:, 0].max() - keypoints[:, 0].min()) \ + / (ref_keypoints[:, 0].max() - ref_keypoints[:, 0].min()) + + for i, pts_idx in enumerate(part_list): + pts = keypoints[pts_idx] + pts = pts[pts[:, 0] != 0] + if pts.shape[0]: + ref_pts = ref_keypoints[pts_idx] + mean_dist_x, mean_dist_y = get_mean_dists(pts, face_cen) + ref_dist_x, ref_dist_y = get_mean_dists(ref_pts, ref_face_cen) + dist_scale_x[i] = ref_dist_x / mean_dist_x * img_scale + dist_scale_y[i] = ref_dist_y / mean_dist_y * img_scale + if dist_scale_x_prev is not None: + dist_scale_x[i] = 
dist_scale_x_prev[i] * momentum + \ + dist_scale_x[i] * (1 - momentum) + dist_scale_y[i] = dist_scale_y_prev[i] * momentum + \ + dist_scale_y[i] * (1 - momentum) + + pts_cen = np.mean(pts, axis=0) + pts = (pts - pts_cen) * dist_scale_x[i] + \ + (pts_cen - face_cen) * dist_scale_y[i] + face_cen + keypoints[pts_idx] = pts + + return keypoints, [dist_scale_x, dist_scale_y, img_scale] + + +def npy_to_tensor(keypoints): + r"""Convert numpy array to pytorch tensor.""" + return torch.from_numpy(keypoints).unsqueeze(0) + + +def get_dlib_landmarks_from_image( + imgs, predictor_path='shape_predictor_68_face_landmarks.dat'): + r"""Get face keypoints from an image. + + Args: + imgs (N x 3 x H x W tensor or N x H x W x 3 numpy array): Input images. + predictor_path (str): Path to the predictor model. + """ + import dlib + predictor_path = get_checkpoint(predictor_path, + url='1l9zT-AI1yKlfyAb_wl_RjLBSaiWQr8dr') + if type(imgs) == torch.Tensor: + imgs = ((imgs + 1) / 2 * 255).byte() + imgs = np.transpose(imgs.cpu().numpy(), (0, 2, 3, 1)) + detector = dlib.get_frontal_face_detector() + predictor = dlib.shape_predictor(predictor_path) + points = np.zeros([imgs.shape[0], 68, 2], dtype=int) + for i in range(imgs.shape[0]): + img = imgs[i] + dets = detector(img, 1) + if len(dets) > 0: + # Only returns the first face. + shape = predictor(img, dets[0]) + for b in range(68): + points[i, b, 0] = shape.part(b).x + points[i, b, 1] = shape.part(b).y + return points + + +def get_126_landmarks_from_image(imgs, landmarks_network): + r"""Get face keypoints from an image. + + Args: + imgs (Nx3xHxW tensor or NxHxWx3 numpy array): + Input images. + landmarks_network (obj): The landmark detection network. + + Return: + (Nx126x2 numpy array): Predicted landmarks. + """ + if type(imgs) == torch.Tensor: + imgs = ((imgs + 1) / 2 * 255).byte() + imgs = np.transpose(imgs.cpu().numpy(), (0, 2, 3, 1)) + + landmarks = [] + for i in range(imgs.shape[0]): + img = imgs[i] + out_boxes, landmark = \ + landmarks_network.get_face_boxes_and_landmarks(img) + if len(landmark) > 1: + # Pick the largest face in the image. + face_size_max = face_index = 0 + for i, out_box in enumerate(out_boxes): + face_size = max(out_box[2] - out_box[0], + out_box[1] - out_box[1]) + if face_size > face_size_max: + face_size_max = face_size + face_index = i + landmark = landmark[face_index] + elif len(landmark) == 1: + landmark = landmark[0] + else: + landmark = np.zeros((126, 2), dtype=np.float32) + landmarks += [landmark[np.newaxis]] + landmarks = np.vstack(landmarks).astype(np.float32) + return landmarks + + +def convert_face_landmarks_to_image(cfgdata, landmarks, output_size, + output_tensor=True, cpu_only=False): + r"""Convert the facial landmarks to a label map. + + Args: + cfgdata (obj): Data configuration. + landmarks + output_size (tuple of int): H, W of output label map. + output_tensor (bool): Output tensors instead of numpy arrays. + cpu_only (bool): Output CPU tensor only. + + Returns: + (NxCxHxW tensor or list of HxWxC numpy arrays): Label maps. + """ + h, w = output_size + labels = connect_face_keypoints(h, w, None, None, None, None, False, + cfgdata, landmarks) + if not output_tensor: + return labels + labels = [torch.from_numpy(label).permute(2, 0, 1).unsqueeze(0) + for label in labels] + labels = torch.cat(labels) + if cpu_only: + return labels + return labels.cuda() + + +def add_face_keypoints(label_map, image, keypoints): + r"""Add additional keypoints to label map. 
+ + Args: + label_map (Nx1xHxW tensor or None) + image (Nx3xHxW tensor) + keypoints (NxKx2 tensor) + """ + if label_map is None: + label_map = torch.zeros_like(image)[:, :1] + x, y = keypoints[:, :, 0], keypoints[:, :, 1] + h, w = image.shape[-2:] + x = ((x + 1) / 2 * w).long() + y = ((y + 1) / 2 * h).long() + bs = torch.arange(label_map.shape[0]).cuda().view(-1, 1).expand_as(x) + label_map[bs, :, y, x] = 1 + return label_map + + +def draw_edge(im, x, y, bw=1, color=(255, 255, 255), draw_end_points=False): + r"""Set colors given a list of x and y coordinates for the edge. + + Args: + im (HxWxC numpy array): Canvas to draw. + x (1D numpy array): x coordinates of the edge. + y (1D numpy array): y coordinates of the edge. + bw (int): Width of the stroke. + color (list or tuple of int): Color to draw. + draw_end_points (bool): Whether to draw end points of the edge. + """ + if x is not None and x.size: + h, w = im.shape[0], im.shape[1] + # Draw edge. + for i in range(-bw, bw): + for j in range(-bw, bw): + yy = np.maximum(0, np.minimum(h - 1, y + i)) + xx = np.maximum(0, np.minimum(w - 1, x + j)) + set_color(im, yy, xx, color) + + # Draw endpoints. + if draw_end_points: + for i in range(-bw * 2, bw * 2): + for j in range(-bw * 2, bw * 2): + if (i ** 2) + (j ** 2) < (4 * bw ** 2): + yy = np.maximum(0, np.minimum(h - 1, np.array( + [y[0], y[-1]]) + i)) + xx = np.maximum(0, np.minimum(w - 1, np.array( + [x[0], x[-1]]) + j)) + set_color(im, yy, xx, color) + + +def set_color(im, yy, xx, color): + r"""Set pixels of the image to the given color. + + Args: + im (HxWxC numpy array): Canvas to draw. + xx (1D numpy array): x coordinates of the pixels. + yy (1D numpy array): y coordinates of the pixels. + color (list or tuple of int): Color to draw. + """ + if type(color) != list and type(color) != tuple: + color = [color] * 3 + if len(im.shape) == 3 and im.shape[2] == 3: + if (im[yy, xx] == 0).all(): + im[yy, xx, 0], im[yy, xx, 1], im[yy, xx, 2] = \ + color[0], color[1], color[2] + else: + for c in range(3): + im[yy, xx, c] = ((im[yy, xx, c].astype(float) + + color[c]) / 2).astype(np.uint8) + else: + im[yy, xx] = color[0] + + +def interp_points(x, y): + r"""Given the start and end points, interpolate to get a curve/line. + + Args: + x (1D array): x coordinates of the points to interpolate. + y (1D array): y coordinates of the points to interpolate. + + Returns: + (dict): + - curve_x (1D array): x coordinates of the interpolated points. + - curve_y (1D array): y coordinates of the interpolated points. 
+ """ + if abs(x[:-1] - x[1:]).max() < abs(y[:-1] - y[1:]).max(): + curve_y, curve_x = interp_points(y, x) + if curve_y is None: + return None, None + else: + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + try: + if len(x) < 3: + popt, _ = curve_fit(linear, x, y) + else: + popt, _ = curve_fit(func, x, y) + if abs(popt[0]) > 1: + return None, None + except Exception: + return None, None + if x[0] > x[-1]: + x = list(reversed(x)) + y = list(reversed(y)) + curve_x = np.linspace(x[0], x[-1], int(np.round(x[-1]-x[0]))) + if len(x) < 3: + curve_y = linear(curve_x, *popt) + else: + curve_y = func(curve_x, *popt) + return curve_x.astype(int), curve_y.astype(int) + + +def func(x, a, b, c): + r"""Quadratic fitting function.""" + return a * x**2 + b * x + c + + +def linear(x, a, b): + r"""Linear fitting function.""" + return a * x + b diff --git a/imaginaire/utils/visualization/pose.py b/imaginaire/utils/visualization/pose.py new file mode 100644 index 0000000000000000000000000000000000000000..bca83f971e5f7c7d0206013bb42de278814244ba --- /dev/null +++ b/imaginaire/utils/visualization/pose.py @@ -0,0 +1,409 @@ +# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is made available under the Nvidia Source Code License-NC. +# To view a copy of this license, check out LICENSE.md +import numpy as np +import random +import importlib +from .common import tensor2im, tensor2label +from .face import draw_edge, interp_points +from imaginaire.model_utils.fs_vid2vid import extract_valid_pose_labels + + +def draw_openpose_npy(resize_h, resize_w, crop_h, crop_w, original_h, + original_w, is_flipped, cfgdata, keypoints_npy): + r"""Connect the OpenPose keypoints to edges and draw the pose map. + + Args: + resize_h (int): Height the input image was resized to. + resize_w (int): Width the input image was resized to. + crop_h (int): Height the input image was cropped. + crop_w (int): Width the input image was cropped. + original_h (int): Original height of the input image. + original_w (int): Original width of the input image. + is_flipped (bool): Is the input image flipped. + cfgdata (obj): Data configuration. + keypoints_npy (dict): OpenPose keypoint dict. + + Returns: + (list of HxWxC numpy array): Drawn label map. + """ + pose_cfg = cfgdata.for_pose_dataset + # Whether to draw only the basic keypoints. + basic_points_only = getattr(pose_cfg, 'basic_points_only', False) + # Whether to remove the face labels to avoid overfitting. + remove_face_labels = getattr(pose_cfg, 'remove_face_labels', False) + # Whether to randomly drop some keypoints to avoid overfitting. + random_drop_prob = getattr(pose_cfg, 'random_drop_prob', 0) + + # Get the list of edges to draw. + edge_lists = define_edge_lists(basic_points_only) + op_key = cfgdata.keypoint_data_types[0] + for input_type in cfgdata.input_types: + if op_key in input_type: + nc = input_type[op_key].num_channels + if crop_h is not None: + h, w = crop_h, crop_w + else: + h, w = resize_h, resize_w + + outputs = [] + for keypoint_npy in keypoints_npy: + person_keypoints = np.asarray(keypoint_npy).reshape(-1, 137, 3)[0] + # Separate out the keypoint array to different parts. + pose_pts = person_keypoints[:25] + face_pts = person_keypoints[25: (25 + 70)] + hand_pts_l = person_keypoints[(25 + 70): (25 + 70 + 21)] + hand_pts_r = person_keypoints[-21:] + all_pts = [pose_pts, face_pts, hand_pts_l, hand_pts_r] + # Remove the keypoints with low confidence. 
+ all_pts = [extract_valid_keypoints(pts, edge_lists) + for pts in all_pts] + + # Connect the keypoints to form the label map. + pose_img = connect_pose_keypoints(all_pts, edge_lists, + (h, w, nc), + basic_points_only, + remove_face_labels, + random_drop_prob) + pose_img = pose_img.astype(np.float32) / 255.0 + outputs.append(pose_img) + return outputs + + +def openpose_to_npy_largest_only(inputs): + r"""Convert OpenPose dicts to numpy arrays of keypoints. Only return the + largest/tallest person in each dict. + + Args: + inputs (list of dicts): List of OpenPose dicts. + + Returns: + (list of numpy arrays): Keypoints. + """ + return base_openpose_to_npy(inputs, return_largest_only=True) + + +def openpose_to_npy(inputs): + r"""Conver OpenPose dicts to numpy arrays of keypoints. + + Args: + inputs (list of dicts): List of OpenPose dicts. + + Returns: + (list of numpy arrays): Keypoints. + """ + return base_openpose_to_npy(inputs, return_largest_only=False) + + +def base_openpose_to_npy(inputs, return_largest_only=False): + r"""Convert OpenPose dicts to numpy arrays of keypoints. + + Args: + inputs (list of dicts): List of OpenPose dicts. + return_largest_only (bool): Whether to return only the largest person. + + Returns: + (list of numpy arrays): Keypoints. + """ + outputs_npy = [] + for input in inputs: + people_dict = input['people'] + n_ppl = max(1, len(people_dict)) + output_npy = np.zeros((n_ppl, 25 + 70 + 21 + 21, 3), dtype=np.float32) + y_len_max = 0 + for i, person_dict in enumerate(people_dict): + # Extract corresponding keypoints from the dict. + pose_pts = np.array(person_dict["pose_keypoints_2d"]).reshape(25, 3) + face_pts = np.array(person_dict["face_keypoints_2d"]).reshape(70, 3) + hand_pts_l = np.array(person_dict["hand_left_keypoints_2d"] + ).reshape(21, 3) + hand_pts_r = np.array(person_dict["hand_right_keypoints_2d"] + ).reshape(21, 3) + + if return_largest_only: + # Get the body length. + y = pose_pts[pose_pts[:, 2] > 0.01, 1] + y_len = y.max() - y.min() + if y_len > y_len_max: + y_len_max = y_len + max_ind = i + + # Concatenate all keypoint together. + output_npy[i] = np.vstack([pose_pts, face_pts, + hand_pts_l, hand_pts_r]) + if return_largest_only: + # Only return the largest person in the dict. + output_npy = output_npy[max_ind: max_ind + 1] + + outputs_npy += [output_npy.astype(np.float32)] + return outputs_npy + + +def extract_valid_keypoints(pts, edge_lists): + r"""Use only the valid keypoints by looking at the detection confidences. + If the confidences for all keypoints in an edge are above threshold, + keep the keypoints. Otherwise, their coordinates will be set to zero. + + Args: + pts (Px3 numpy array): Keypoint xy coordinates + confidence. + edge_lists (nested list of ints): List of keypoint indices for edges. + + Returns: + (Px2 numpy array): Output keypoints. + """ + pose_edge_list, _, hand_edge_list, _, face_list = edge_lists + p = pts.shape[0] + thre = 0.1 if p == 70 else 0.01 + output = np.zeros((p, 2)) + + if p == 70: # ai_emoji + for edge_list in face_list: + for edge in edge_list: + if (pts[edge, 2] > thre).all(): + output[edge, :] = pts[edge, :2] + elif p == 21: # hand + for edge in hand_edge_list: + if (pts[edge, 2] > thre).all(): + output[edge, :] = pts[edge, :2] + else: # pose + valid = (pts[:, 2] > thre) + output[valid, :] = pts[valid, :2] + + return output + + +def connect_pose_keypoints(pts, edge_lists, size, basic_points_only, + remove_face_labels, random_drop_prob): + r"""Draw edges by connecting the keypoints onto the label map. 
+ + Args: + pts (Px3 numpy array): Keypoint xy coordinates + confidence. + edge_lists (nested list of ints): List of keypoint indices for edges. + size (tuple of int): Output size. + basic_points_only (bool): Whether to use only the basic keypoints. + remove_face_labels (bool): Whether to remove face labels. + random_drop_prob (float): Probability to randomly drop keypoints. + + Returns: + (HxWxC numpy array): Output label map. + """ + pose_pts, face_pts, hand_pts_l, hand_pts_r = pts + h, w, c = size + body_edges = np.zeros((h, w, c), np.uint8) + # If using one-hot, different parts of the body will be drawn to + # different channels. + use_one_hot = c > 3 + if use_one_hot: + assert c == 27 + pose_edge_list, pose_color_list, hand_edge_list, hand_color_list, \ + face_list = edge_lists + + # Draw pose edges. + h = int(pose_pts[:, 1].max() - pose_pts[:, 1].min()) + bw = max(1, h // 150) # Stroke width. + body_edges = draw_edges(body_edges, pose_pts, [pose_edge_list], bw, + use_one_hot, random_drop_prob, + colors=pose_color_list, draw_end_points=True) + + if not basic_points_only: + # Draw hand edges. + bw = max(1, h // 450) + for i, hand_pts in enumerate([hand_pts_l, hand_pts_r]): + if use_one_hot: + k = 24 + i + body_edges[:, :, k] = draw_edges(body_edges[:, :, k], hand_pts, + [hand_edge_list], + bw, False, random_drop_prob, + colors=[255] * len(hand_pts)) + else: + body_edges = draw_edges(body_edges, hand_pts, [hand_edge_list], + bw, False, random_drop_prob, + colors=hand_color_list) + # Draw face edges. + if not remove_face_labels: + if use_one_hot: + k = 26 + body_edges[:, :, k] = draw_edges(body_edges[:, :, k], face_pts, + face_list, bw, False, + random_drop_prob) + else: + body_edges = draw_edges(body_edges, face_pts, face_list, bw, + False, random_drop_prob) + return body_edges + + +def draw_edges(canvas, keypoints, edges_list, bw, use_one_hot, + random_drop_prob=0, edge_len=2, colors=None, + draw_end_points=False): + r"""Draw all the edges in the edge list on the canvas. + + Args: + canvas (HxWxK numpy array): Canvas to draw. + keypoints (Px2 numpy array): Keypoints. + edge_list (nested list of ints): List of keypoint indices for edges. + bw (int): Stroke width. + use_one_hot (bool): Use one-hot encoding or not. + random_drop_prob (float): Probability to randomly drop keypoints. + edge_len (int): Number of keypoints in an edge. + colors (tuple of int): Color to draw. + draw_end_points (bool): Whether to draw end points for edges. + + Returns: + (HxWxK numpy array): Output. + """ + k = 0 + for edge_list in edges_list: + for i, edge in enumerate(edge_list): + for j in range(0, max(1, len(edge) - 1), edge_len - 1): + if random.random() > random_drop_prob: + sub_edge = edge[j:j + edge_len] + x, y = keypoints[sub_edge, 0], keypoints[sub_edge, 1] + if 0 not in x: # Get rid of invalid keypoints. + curve_x, curve_y = interp_points(x, y) + if use_one_hot: + # If using one-hot, draw to different channels of + # the canvas. + draw_edge(canvas[:, :, k], curve_x, curve_y, + bw=bw, color=255, + draw_end_points=draw_end_points) + else: + color = colors[i] if colors is not None \ + else (255, 255, 255) + draw_edge(canvas, curve_x, curve_y, + bw=bw, color=color, + draw_end_points=draw_end_points) + k += 1 + return canvas + + +def define_edge_lists(basic_points_only): + r"""Define the list of keypoints that should be connected to form the edges. + + Args: + basic_points_only (bool): Whether to use only the basic keypoints. + """ + # Pose edges and corresponding colors. 
+ pose_edge_list = [ + [17, 15], [15, 0], [0, 16], [16, 18], # head + [0, 1], [1, 8], # body + [1, 2], [2, 3], [3, 4], # right arm + [1, 5], [5, 6], [6, 7], # left arm + [8, 9], [9, 10], [10, 11], # right leg + [8, 12], [12, 13], [13, 14] # left leg + ] + pose_color_list = [ + [153, 0, 153], [153, 0, 102], [102, 0, 153], [51, 0, 153], + [153, 0, 51], [153, 0, 0], + [153, 51, 0], [153, 102, 0], [153, 153, 0], + [102, 153, 0], [51, 153, 0], [0, 153, 0], + [0, 153, 51], [0, 153, 102], [0, 153, 153], + [0, 102, 153], [0, 51, 153], [0, 0, 153], + ] + + if not basic_points_only: + pose_edge_list += [ + [11, 24], [11, 22], [22, 23], # right foot + [14, 21], [14, 19], [19, 20] # left foot + ] + pose_color_list += [ + [0, 153, 153], [0, 153, 153], [0, 153, 153], + [0, 0, 153], [0, 0, 153], [0, 0, 153] + ] + + # Hand edges and corresponding colors. + hand_edge_list = [ + [0, 1, 2, 3, 4], + [0, 5, 6, 7, 8], + [0, 9, 10, 11, 12], + [0, 13, 14, 15, 16], + [0, 17, 18, 19, 20] + ] + hand_color_list = [ + [204, 0, 0], [163, 204, 0], [0, 204, 82], [0, 82, 204], [163, 0, 204] + ] + + # Face edges. + face_list = [ + [range(0, 17)], # face contour + [range(17, 22)], # left eyebrow + [range(22, 27)], # right eyebrow + [[28, 31], range(31, 36), [35, 28]], # nose + [[36, 37, 38, 39], [39, 40, 41, 36]], # left eye + [[42, 43, 44, 45], [45, 46, 47, 42]], # right eye + [range(48, 55), [54, 55, 56, 57, 58, 59, 48]], # mouth + ] + + return pose_edge_list, pose_color_list, hand_edge_list, hand_color_list, \ + face_list + + +def tensor2pose(cfg, label_tensor): + r"""Convert output tensor to a numpy pose map. + + Args: + label_tensor (3D/4D/5D tensor): Label tensor. + + Returns: + (HxWx3 numpy array or list of numpy arrays): Pose map. + """ + if label_tensor.dim() == 5 or label_tensor.dim() == 4: + return [tensor2pose(cfg, label_tensor[idx]) + for idx in range(label_tensor.size(0))] + + # If adding additional discriminators, draw the bbox for the regions + # (e.g. faces) too. + add_dis_cfg = getattr(cfg.dis, 'additional_discriminators', None) + if add_dis_cfg is not None: + crop_coords = [] + for name in add_dis_cfg: + v = add_dis_cfg[name].vis + file, crop_func = v.split('::') + file = importlib.import_module(file) + crop_func = getattr(file, crop_func) + crop_coord = crop_func(cfg.data, label_tensor) + if len(crop_coord) > 0: + if type(crop_coord[0]) == list: + crop_coords.extend(crop_coord) + else: + crop_coords.append(crop_coord) + + pose_cfg = cfg.data.for_pose_dataset + pose_type = getattr(pose_cfg, 'pose_type', 'both') + remove_face_labels = getattr(pose_cfg, 'remove_face_labels', False) + label_tensor = extract_valid_pose_labels(label_tensor, pose_type, + remove_face_labels) + + # If using both DensePose and OpenPose, overlay one image onto the other + # to get the visualization map. + dp_key = 'pose_maps-densepose' + op_key = 'poses-openpose' + use_densepose = use_openpose = False + for input_type in cfg.data.input_types: + if dp_key in input_type: + dp_ch = input_type[dp_key].num_channels + use_densepose = True + elif op_key in input_type: + op_ch = input_type[op_key].num_channels + use_openpose = True + if use_densepose: + label_img = tensor2im(label_tensor[:dp_ch]) + if use_openpose: + openpose = label_tensor[-op_ch:] + openpose = tensor2im(openpose) if op_ch == 3 else \ + tensor2label(openpose, op_ch) + if use_densepose: + label_img[openpose != 0] = openpose[openpose != 0] + else: + label_img = openpose + + # Draw the bbox for the regions for the additional discriminator. 
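+    # (crop_coords were collected above from the crop functions declared in
+    # cfg.dis.additional_discriminators; the box is drawn by setting the four
+    # border rows/columns of each region to 255.)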
+ if add_dis_cfg is not None: + for crop_coord in crop_coords: + ys, ye, xs, xe = crop_coord + label_img[ys, xs:xe, :] = label_img[ye - 1, xs:xe, :] \ + = label_img[ys:ye, xs, :] = label_img[ys:ye, xe - 1, :] = 255 + + if len(label_img.shape) == 2: + label_img = np.repeat(label_img[:, :, np.newaxis], 3, axis=2) + return label_img diff --git a/inference/draw_points.py b/inference/draw_points.py new file mode 100644 index 0000000000000000000000000000000000000000..97364c42d15be50ca87b8e66fbda448371b0e032 --- /dev/null +++ b/inference/draw_points.py @@ -0,0 +1,38 @@ +import numpy as np +from scipy.interpolate import interp1d +import matplotlib.pyplot as plt +from matplotlib.image import imread +import csv,os,sys + +data = sys.argv[1] +assert data.endswith('satView_polish.png') +img_path = os.path.join('dataset/CVACT/satview_correct',data) + +# img_path = './dataset/CVACT/satview_correct/__-DFIFxvZBCn1873qkqXA_satView_polish.png' +csv_path = 'vis_video/pixels.csv' +select_points = [28, 44, 53] + +x_list,y_list = [],[] +x_whole,y_whole = [],[] +with open(csv_path, 'r') as csvfile: + reader = csv.DictReader(csvfile) + for i,row in enumerate(reader): + x,y = float(row['x']),float(row['y']) + if i in select_points: + x_list.append(x) + y_list.append(y) + print(i,x,y) + x_whole.append(x) + y_whole.append(y) +fig, ax = plt.subplots() + + +img = imread(img_path) +plt.imshow(img) +plt.plot(x_whole, y_whole, 'r-',label='Smooth curve', linewidth=4) +plt.scatter(x_list,y_list,marker='o', s=0, color='red') +plt.axis('off') +plt.xlim([0, 256]) +plt.ylim([256, 0]) +plt.subplots_adjust(left=0, right=1, top=1, bottom=0) +plt.savefig('point_curve.png', bbox_inches='tight', pad_inches=0) diff --git a/inference/get_score_from_dir.py b/inference/get_score_from_dir.py new file mode 100644 index 0000000000000000000000000000000000000000..691362706ca5339dce2779917c18580ddb854bbb --- /dev/null +++ b/inference/get_score_from_dir.py @@ -0,0 +1,68 @@ +from unittest import result +from matplotlib.pyplot import hist +from torch.utils import data +from torch.utils.data.dataset import Dataset + +import os,torch +from PIL import Image +import torchvision.transforms as T +from pytorch_msssim import ssim, ms_ssim, SSIM, MS_SSIM +import torch.nn.functional as F +from imaginaire.evaluation.segmentation import get_segmentation_hist_model,get_miou,compute_hist +import lpips +from easydict import EasyDict as edict +from tqdm import tqdm +import piq +from torch.utils.data import DataLoader +from piq import FID,KID +import numpy as np + +result_path = 'result/Ours-pers-sin-sty' +gt_path = 'dataset/CVACT/streetview_test' + + +class Dataset_img(Dataset): + def __init__(self, dir): + self.dir = dir + self.datalist = sorted(os.listdir(dir)) + + def __len__(self): + return len(self.datalist) + + def __getitem__(self, index): + img = os.path.join(self.dir,self.datalist[index]) + img = Image.open(img).convert('RGB') + img = T.ToTensor()(img) + return {'images':img} + + + +data_gt = Dataset_img(gt_path) +data_pred = Dataset_img(result_path) + + +loss_fn_alex = lpips.LPIPS(net='alex',eval_mode=True).cuda() +loss_fn_squeeze = lpips.LPIPS(net='squeeze',eval_mode=True).cuda() + + +data_list = os.listdir(result_path) +results = edict() +results.psnr = [] +results.ssim = [] +results.alex = [] +results.squeeze = [] +results.RMSE = [] + +dataloader_pred = DataLoader(data_pred,batch_size=1,shuffle=False,num_workers=10) +dataloader_gt = DataLoader(data_gt,batch_size=1,shuffle=False,num_workers=10) +for i in 
tqdm(zip(dataloader_pred,dataloader_gt),ncols=100): + pred = i[0]['images'].cuda() + gt = i[1]['images'].cuda() + results.psnr.append(-10*F.mse_loss(pred,gt).log10().item()) + results.ssim.append(ssim(pred, gt,data_range=1.).item()) + results.alex.append(torch.mean(loss_fn_alex((pred*2.)-1, (2.*gt)-1)).cpu().item()) + results.squeeze.append(torch.mean(loss_fn_squeeze((pred*2.)-1, (2.*gt)-1)).cpu().item()) + results.RMSE.append(torch.sqrt(F.mse_loss(pred,gt)).item()*255) + +for i in results: + print("%-10s"%i, ':',np.mean(results[i])) diff --git a/inference/img2vid.py b/inference/img2vid.py new file mode 100644 index 0000000000000000000000000000000000000000..68d33d2b8276935856875c5687af19b768adb245 --- /dev/null +++ b/inference/img2vid.py @@ -0,0 +1,73 @@ +import os +import cv2 +from PIL import Image + +def image_to_video(img_dir,image_names, media_path): + fourcc = cv2.VideoWriter_fourcc('M', 'P', '4', 'V') + fps = 20 + image = Image.open(os.path.join( img_dir , image_names[0])) + media_writer = cv2.VideoWriter(media_path, fourcc, fps, image.size) + for image_name in image_names: + im = cv2.imread(os.path.join(img_dir, image_name)) + media_writer.write(im) + print(image_name, 'combined') + media_writer.release() + print('end') + +def img_pair2vid(sat_list,grd_list,angle_list=None,media_path= 'output.mp4'): + fourcc = cv2.VideoWriter_fourcc('M', 'P', '4', 'V') + out = cv2.VideoWriter(media_path, fourcc, 12.0, (512, 256)) + out_sat = cv2.VideoWriter(media_path.replace('.mp4','_sat.mp4'), fourcc, 12.0, (389, 389)) + assert len(sat_list) == len(grd_list) + for i in range(len(sat_list)): + + img1 = cv2.imread(os.path.join( img_dir , sat_list[i])) + img2 = cv2.imread(os.path.join( img_dir , grd_list[i])) + img3 = cv2.imread(os.path.join( img_dir , grd_list[i].replace('.png','_depth.png'))) + + + if angle_list!=None: + angle = angle_list[i] + left_pixel = int((angle/180)*256) + if angle<0: + img2 = cv2.hconcat([img2[:,left_pixel:,:],img2[:,:left_pixel,:]]) + img3= cv2.hconcat([img3[:,left_pixel:,:],img3[:,:left_pixel,:]]) + else: + img2 = cv2.hconcat([img2[:,left_pixel:,:],img2[:,:left_pixel,:]]) + img3 = cv2.hconcat([img3[:,left_pixel:,:],img3[:,:left_pixel,:]]) + merged_image = cv2.vconcat([img2,img3]) + out.write(merged_image) + out_sat.write(img1) + out.release() + out_sat.release() + +if __name__=='__main__': + import csv + img_dir = 'vis_video' + img_list = sorted(os.listdir(img_dir)) + sat_list = [] + grd_list = [] + for img in img_list: + if '.png' in img: + if 'satdepth' in img: + continue + if 'grdView_pano.png' in img: + continue + if 'grdView' in img: + if '_depth.png' not in img: + grd_list.append(img) + elif 'satView' in img: + sat_list.append(img) + sat_list = sat_list[:-1] + grd_list = grd_list[:-1] + media_path = os.path.join(img_dir,'output_cat.mp4') + angle_list = [] + with open(os.path.join(img_dir,'pixels.csv') , 'r') as csvfile: + reader = csv.DictReader(csvfile) + for row in reader: + angle = float(row['angle']) + angle_list.append(angle) + print(angle_list) + + img_pair2vid(sat_list,grd_list,angle_list,media_path= media_path) + print('save 2 ',media_path) \ No newline at end of file diff --git a/inference/img2vid_interpolation.py b/inference/img2vid_interpolation.py new file mode 100644 index 0000000000000000000000000000000000000000..462dd2829e472652bd066d71b2b39bb87de5ba77 --- /dev/null +++ b/inference/img2vid_interpolation.py @@ -0,0 +1,26 @@ +import os +import cv2 +from PIL import Image + + +def img_pair2vid(sat_list,media_path= 'interpolation.mp4'): + fourcc = 
cv2.VideoWriter_fourcc('M', 'P', '4', 'V') + out = cv2.VideoWriter(media_path, fourcc, 12.0, (512, 128)) + for i in range(len(sat_list)): + + img1 = cv2.imread(os.path.join( img_dir , sat_list[i])) + + out.write(img1) + out.release() + +if __name__=='__main__': + import csv + img_dir = 'vis_interpolation' + img_list = sorted(os.listdir(img_dir)) + sat_list = [] + for img in img_list: + sat_list.append(img) + media_path = os.path.join(img_dir,'interpolation.mp4') + + img_pair2vid(sat_list,media_path= media_path) + print('save 2 ',media_path) \ No newline at end of file diff --git a/inference/quick_demo_interpolation.sh b/inference/quick_demo_interpolation.sh new file mode 100644 index 0000000000000000000000000000000000000000..ad760675784c53cd3a3b2c4df38c6efb3c2e619a --- /dev/null +++ b/inference/quick_demo_interpolation.sh @@ -0,0 +1,10 @@ +CUDA_VISIBLE_DEVICES=0 python offline_train_test.py --task=test_interpolation \ +--yaml=sat2density_cvact \ +--test_ckpt_path=2u87bj8w \ +--sty_img1=YL81FiK9PucIvAkr1FHkpA_grdView.png \ +--sty_img2=pdZmLHYEhe2PHj_8-WHMhw_grdView.png \ +--demo_img=VAMM6sIEbYAY5E6ZD_RMKg_satView_polish.png \ +--data.root=demo_img + + +python inference/img2vid_interpolation.py \ No newline at end of file diff --git a/inference/quick_demo_video.sh b/inference/quick_demo_video.sh new file mode 100644 index 0000000000000000000000000000000000000000..cc8d9732df05c5056582bfe4e344b07a43fb3d73 --- /dev/null +++ b/inference/quick_demo_video.sh @@ -0,0 +1,32 @@ +sat_name=__-DFIFxvZBCn1873qkqXA_satView_polish.png + +sty_name=VAMM6sIEbYAY5E6ZD_RMKg_grdView.png + +# select point +# First find a starting point, then press and hold the left mouse button, draw any shape, +# then release the left mouse button, and press 'q' on the keyboard to end the point selection process + +# better select regions near the center of the satellite image. 'q' to end select point. 
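+# (When run, select_points.py writes the chosen track to vis_video/pixels.csv;
+# the test_vid inference step and inference/img2vid.py read that file back.)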
+# python inference/select_points.py ${sat_name} + +# inference +# if you want use illumination from another image , you could add --sty_img=WsKPDHEgLwrhrJXcUU34xA_grdView.png +CUDA_VISIBLE_DEVICES=0 python offline_train_test.py --yaml=sat2density_cvact \ +--test_ckpt_path=2u87bj8w \ +--task=test_vid \ +--demo_img=${sat_name} --sty_img=${sty_name} \ +--data.root=demo_img + + +# make video +python inference/img2vid.py + +# visualize vis_video/volume_data.vtk with ParaView + + +# python test.py --yaml=sat2density_cvact \ +# --test_ckpt_path=2u87bj8w \ +# --task=test_vid \ +# --demo_img=__-DFIFxvZBCn1873qkqXA_satView_polish.png \ +# --sty_img=VAMM6sIEbYAY5E6ZD_RMKg_grdView.png \ +# --data.root=demo_img diff --git a/inference/select_points.py b/inference/select_points.py new file mode 100644 index 0000000000000000000000000000000000000000..36cae5d190000a4ab8b4693faa26ee0cd142005d --- /dev/null +++ b/inference/select_points.py @@ -0,0 +1,110 @@ +import matplotlib.pyplot as plt +from matplotlib.widgets import Cursor +from matplotlib.image import imread +import numpy as np +import csv,os +from scipy.interpolate import interp1d +import sys +data = sys.argv[1] +assert data.endswith('satView_polish.png') +dirs = os.path.join('dataset/CVACT/satview_correct',data) +if not os.path.exists(dirs): + dirs = dirs.replace('dataset/CVACT','demo_img') +sav_pth = 'vis_video' +if not os.path.exists(sav_pth): + os.mkdir(sav_pth) + +img = imread(dirs) + +fig = plt.figure() +fig.set_size_inches(1,1,forward=False) +ax = plt.Axes(fig, [0., 0., 1., 1.]) +ax.set_axis_off() +ax.imshow(img) + +coords = [] + +def ondrag(event): + if event.button != 1: + return + x, y = int(event.xdata), int(event.ydata) + coords.append((x, y)) + ax.plot([x], [y], 'o', color='red') + fig.canvas.draw_idle() +fig.add_axes(ax) +cursor = Cursor(ax, useblit=True, color='red', linewidth=1) +fig.canvas.mpl_connect('motion_notify_event', ondrag) +plt.show() +plt.close() + + +unique_lst = list(dict.fromkeys(coords)) +pixels = [] +for x in coords: + if x in unique_lst: + if x not in pixels: + pixels.append(x) +print(pixels) + +########################################### + +from scipy.interpolate import splprep, splev + +points = pixels +points = np.array(points) +tck, u = splprep(points.T, s=25, per=0) +u_new = np.linspace(u.min(), u.max(), 80) +x_new, y_new = splev(u_new, tck) + +plt.plot(points[:,0], points[:,1], 'ro', label='Original curve') +plt.plot(x_new, y_new, 'b-', label='Smooth curve') +plt.legend() +plt.show() +plt.close() + +fig, ax = plt.subplots() + + +pixels = [tuple(sublist[:2]) for sublist in zip(x_new,y_new)] +########################################### +img = imread(dirs) +fig, ax = plt.subplots() +ax.set_xticks([]) +ax.set_yticks([]) +ax.imshow(img) +plt.plot(x_new, y_new, 'r-', label='Smooth curve') +fig.subplots_adjust(left=0, right=1, top=1, bottom=0) +plt.savefig(os.path.join(sav_pth,os.path.basename(dirs)).replace('.png','_sat_track.png'),bbox_inches='tight', pad_inches=0) +plt.close() + +########################################### +angle_list = [] +for i,pixel in enumerate(pixels[:-1]): + img = imread(dirs) + + x1, y1 = pixel + x2, y2 = pixels[i+1] + dx, dy = x2 - x1, y2 - y1 + angle_save = np.degrees(np.arctan2(dy, dx))+90 + if angle_save>180: + angle_save = angle_save-360 + angle_list.append(angle_save) + length = np.sqrt(dx ** 2 + dy ** 2) + angle = np.arctan2(dy, dx) * 180 / np.pi + fig, ax = plt.subplots() + ax.set_xticks([]) + ax.set_yticks([]) + ax.imshow(img) + ax.arrow(x1, y1, dx*10, dy*10, color='red', 
width=length, head_width=4*length, head_length=5*length) + + name = '_sat'+'%05d' % int(i) + ".png" + plt.savefig(os.path.join(sav_pth,os.path.basename(dirs)).replace('.png',name),bbox_inches='tight') + plt.close() + + +with open( os.path.join(sav_pth,'pixels.csv'), 'w', newline='') as csvfile: + writer = csv.writer(csvfile) + writer.writerow(['x', 'y','angle']) + for i, (x, y) in enumerate(pixels[:-1]): + writer.writerow([x, y,angle_list[i]]) +print('save to pixels.csv',len(pixels[:-1])) \ No newline at end of file diff --git a/inference/single_style_test_cvact.sh b/inference/single_style_test_cvact.sh new file mode 100644 index 0000000000000000000000000000000000000000..8def276ac04fc4e689e00cc288b737b9d4f7d74e --- /dev/null +++ b/inference/single_style_test_cvact.sh @@ -0,0 +1,9 @@ +python offline_train_test.py --yaml=sat2density_cvact --test_ckpt_path=2u87bj8w --sty_img=t1DOCdyniuWDC5JPqm4MWA_grdView.png +python offline_train_test.py --yaml=sat2density_cvact --test_ckpt_path=2u87bj8w --sty_img=XefdeYLN_XZEaG2VLPFVtA_grdView.png +python offline_train_test.py --yaml=sat2density_cvact --test_ckpt_path=2u87bj8w --sty_img=H2y6n9mCg53Ip1-0_UigRQ_grdView.png +python offline_train_test.py --yaml=sat2density_cvact --test_ckpt_path=2u87bj8w --sty_img=YOBJgPIILw9PbSFvnYZFZg_grdView.png +python offline_train_test.py --yaml=sat2density_cvact --test_ckpt_path=2u87bj8w --sty_img=lqgXjFM3zR8EWbiWWfgjNA_grdView.png +python offline_train_test.py --yaml=sat2density_cvact --test_ckpt_path=2u87bj8w --sty_img=fOX6greOFJMH8IlA8Gm5hg_grdView.png +python offline_train_test.py --yaml=sat2density_cvact --test_ckpt_path=2u87bj8w --sty_img=uZoS7QAxnEGlw22PtslB_Q_grdView.png +python offline_train_test.py --yaml=sat2density_cvact --test_ckpt_path=2u87bj8w --sty_img=FVJZ86vbU43hYf4-uM4lFg_grdView.png +python offline_train_test.py --yaml=sat2density_cvact --test_ckpt_path=2u87bj8w --sty_img=uHD7qJude23nqRPLVrdKIA_grdView.png \ No newline at end of file diff --git a/inference/single_style_test_cvusa.sh b/inference/single_style_test_cvusa.sh new file mode 100644 index 0000000000000000000000000000000000000000..fc529e24d3c259aed28bb0e3b2073a7be85b952a --- /dev/null +++ b/inference/single_style_test_cvusa.sh @@ -0,0 +1,9 @@ +CUDA_VISIBLE_DEVICES=1 python offline_train_test.py --yaml=sat2density_cvusa --test_ckpt_path=2cqv8uh4 --sty_img=0001227.jpg +CUDA_VISIBLE_DEVICES=1 python offline_train_test.py --yaml=sat2density_cvusa --test_ckpt_path=2cqv8uh4 --sty_img=0044093.jpg +CUDA_VISIBLE_DEVICES=1 python offline_train_test.py --yaml=sat2density_cvusa --test_ckpt_path=2cqv8uh4 --sty_img=0015421.jpg +CUDA_VISIBLE_DEVICES=1 python offline_train_test.py --yaml=sat2density_cvusa --test_ckpt_path=2cqv8uh4 --sty_img=0040767.jpg +CUDA_VISIBLE_DEVICES=1 python offline_train_test.py --yaml=sat2density_cvusa --test_ckpt_path=2cqv8uh4 --sty_img=0014628.jpg +CUDA_VISIBLE_DEVICES=1 python offline_train_test.py --yaml=sat2density_cvusa --test_ckpt_path=2cqv8uh4 --sty_img=0027413.jpg +CUDA_VISIBLE_DEVICES=1 python offline_train_test.py --yaml=sat2density_cvusa --test_ckpt_path=2cqv8uh4 --sty_img=0021324.jpg +CUDA_VISIBLE_DEVICES=1 python offline_train_test.py --yaml=sat2density_cvusa --test_ckpt_path=2cqv8uh4 --sty_img=0013073.jpg +CUDA_VISIBLE_DEVICES=1 python offline_train_test.py --yaml=sat2density_cvusa --test_ckpt_path=2cqv8uh4 --sty_img=0028546.jpg diff --git a/inference/synthesis_video.sh b/inference/synthesis_video.sh new file mode 100644 index 0000000000000000000000000000000000000000..51424a8517ff8ded75821a724f0aa9fe1af9667a --- 
/dev/null
+++ b/inference/synthesis_video.sh
@@ -0,0 +1,16 @@
+### A demo for synthesizing a ground-view video
+name=__-DFIFxvZBCn1873qkqXA_satView_polish.png
+
+# select point
+# First find a starting point, then press and hold the left mouse button and draw any shape;
+# release the left mouse button and press 'q' on the keyboard to finish the point selection process.
+
+# It is best to select a region near the center of the satellite image. Press 'q' to finish selecting points.
+python inference/select_points.py ${name}
+
+# inference
+# If you want to use the illumination from another image, you can add --sty_img=WsKPDHEgLwrhrJXcUU34xA_grdView.png
+CUDA_VISIBLE_DEVICES=0 python offline_train_test.py --task=test_vid --yaml=sat2density_cvact --test_ckpt_path=2u87bj8w --demo_img=${name}
+
+# make video
+python inference/img2vid.py
\ No newline at end of file
diff --git a/model/__init__.py b/model/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/model/__pycache__/__init__.cpython-38.pyc b/model/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e731b482fc158f51e8a1b1b6b2621e4c0449ea8e
Binary files /dev/null and b/model/__pycache__/__init__.cpython-38.pyc differ
diff --git a/model/__pycache__/base_model.cpython-38.pyc b/model/__pycache__/base_model.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c9524dcddb0cc6363017d0e966e0117927837f5f
Binary files /dev/null and b/model/__pycache__/base_model.cpython-38.pyc differ
diff --git a/model/__pycache__/craft_feature.cpython-38.pyc b/model/__pycache__/craft_feature.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..62b9f2cf7b5051cb8c45316bfd596a1067d65a44
Binary files /dev/null and b/model/__pycache__/craft_feature.cpython-38.pyc differ
diff --git a/model/__pycache__/geometry_transform.cpython-38.pyc b/model/__pycache__/geometry_transform.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..921aee90ab4d8c80090f4ba55b13989d9849dbc4
Binary files /dev/null and b/model/__pycache__/geometry_transform.cpython-38.pyc differ
diff --git a/model/base_model.py b/model/base_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..bdab0f131e1881ed51578cd2027062dad910ab99
--- /dev/null
+++ b/model/base_model.py
@@ -0,0 +1,572 @@
+import os
+import torch
+from abc import ABC, abstractmethod
+import wandb
+import options
+import utils
+from pytorch_msssim import ssim, SSIM
+import numpy as np
+import torchvision
+from tqdm import tqdm
+import lpips
+from imaginaire.losses import FeatureMatchingLoss, GaussianKLLoss, PerceptualLoss,GANLoss
+import cv2
+from imaginaire.utils.trainer import get_scheduler
+from .geometry_transform import render_sat
+from model import geometry_transform
+import csv
+
+
+
+class BaseModel(ABC):
+    """This class is an abstract base class (ABC) for models.
+    To create a subclass, you need to implement the following five functions:
+        -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
+        -- <set_input>: unpack data from dataset and apply preprocessing.
+        -- <forward>: produce intermediate results.
+        -- <optimize_parameters>: calculate losses, gradients, and update network weights.
+        -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
+    """
+
+    def __init__(self, opt,wandb=None):
+        """Initialize the BaseModel class.
+ Parameters: + opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions + When creating your custom class, you need to implement your own initialization. + In this function, you should first call + Then, you need to define four lists: + -- self.loss_names (str list): specify the training losses that you want to plot and save. + -- self.model_names (str list): define networks used in our training. + """ + self.wandb = wandb + if opt.isTrain: + opt.save_dir =wandb.dir + options.save_options_file(opt,opt.save_dir) + self.opt = opt + self.device = "cpu" if opt.cpu or not torch.cuda.is_available() else "cuda:{}".format(opt.gpu) + # torch.backends.cudnn.benchmark = True + self.model_names = [] + self.train_loader = None + self.val_loader = None + self.sty_loader = None + self.loss_fn_alex = lpips.LPIPS(net='alex',eval_mode=True).cuda() + if opt.task=='test': + self.loss_fn_sque = lpips.LPIPS(net='squeeze',eval_mode=True).cuda() + self.mseloss = torch.nn.MSELoss(True,True) + self.criteria = {} + self.weights = {} + if hasattr(opt.optim.loss_weight, 'GaussianKL'): + if opt.optim.loss_weight.GaussianKL: + self.criteria['GaussianKL'] = GaussianKLLoss() + self.weights['GaussianKL'] = opt.optim.loss_weight.GaussianKL + if hasattr(opt.optim.loss_weight, 'L1'): + if opt.optim.loss_weight.L1: + self.criteria['L1'] = torch.nn.L1Loss(True,True) + self.weights['L1'] = opt.optim.loss_weight.L1 + if hasattr(opt.optim.loss_weight, 'L2'): + if opt.optim.loss_weight.L2: + self.criteria['L2'] = torch.nn.MSELoss(True,True) + self.weights['L2'] = opt.optim.loss_weight.L2 + if hasattr(opt.optim.loss_weight, 'SSIM'): + if opt.optim.loss_weight.SSIM: + self.criteria['SSIM'] = SSIM(data_range =1., size_average=True, channel=3) + self.weights['SSIM'] = opt.optim.loss_weight.SSIM + if hasattr(opt.optim.loss_weight, 'Perceptual'): + if opt.optim.loss_weight.Perceptual: + self.criteria['Perceptual'] = \ + PerceptualLoss( + network=opt.optim.perceptual_loss.mode, + layers=opt.optim.perceptual_loss.layers, + weights=opt.optim.perceptual_loss.weights).to(self.device) + self.weights['Perceptual'] = opt.optim.loss_weight.Perceptual + if hasattr(opt.optim.loss_weight, 'sky_inner'): + if opt.optim.loss_weight.sky_inner: + self.criteria['sky_inner'] = torch.nn.L1Loss(True,True) + self.weights['sky_inner'] = opt.optim.loss_weight.sky_inner + if hasattr(opt.optim.loss_weight, 'feature_matching'): + if opt.optim.loss_weight.feature_matching: + self.criteria['feature_matching'] = FeatureMatchingLoss() + self.weights['feature_matching'] = opt.optim.loss_weight.feature_matching + self.weights['GAN'] = opt.optim.loss_weight.GAN + self.criteria['GAN'] = GANLoss(gan_mode=opt.optim.gan_mode) + + + @staticmethod + def modify_commandline_options(parser, is_train): + """Add new model-specific options, and rewrite default values for existing options. + Parameters: + parser -- original option parser + is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. + Returns: + the modified parser. + """ + return parser + + @abstractmethod + def set_input(self, input): + """Unpack input data from the dataloader and perform necessary pre-processing steps. + Parameters: + input (dict): includes the data itself and its metadata information. + """ + pass + + @abstractmethod + def forward(self): + """Run forward pass; called by both functions and .""" + pass + + def save_checkpoint(self,ep=0,latest=False): + """ + save trained models. 
+ Args: + ep (int, optional): model epochs. Defaults to 0. + latest (bool, optional): qhether it is the latest model. Defaults to False. + """ + ckpt_save_path = os.path.join(self.wandb.dir,'checkpoint') + if not os.path.exists(ckpt_save_path): + os.mkdir(ckpt_save_path) + utils.save_checkpoint(self,ep=ep,latest=latest,output_path=ckpt_save_path) + if not latest: + print("checkpoint saved: {0}, epoch {1} ".format(self.opt.name,ep)) + + + + def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0): + """Fix InstanceNorm checkpoints incompatibility (prior to 0.4)""" + key = keys[i] + if i + 1 == len(keys): # at the end, pointing to a parameter/buffer + if module.__class__.__name__.startswith('InstanceNorm') and \ + (key == 'running_mean' or key == 'running_var'): + if getattr(module, key) is None: + state_dict.pop('.'.join(keys)) + if module.__class__.__name__.startswith('InstanceNorm') and \ + (key == 'num_batches_tracked'): + state_dict.pop('.'.join(keys)) + else: + self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1) + + + def setup_optimizer(self,opt): + # initialize optimizers; schedulers will be automatically created by function . + self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.optim.lr_gen, betas=(opt.optim.beta1, 0.999),eps=1.e-7) + if opt.isTrain: + self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=opt.optim.lr_dis, betas=(opt.optim.beta1, 0.999)) + if opt.optim.lr_policy: + self.sch_G = get_scheduler(opt.optim, self.optimizer_G) + self.sch_D = get_scheduler(opt.optim, self.optimizer_D) + + def optimize_parameters(self,opt): + self.netG.train() + # update Discriminators + self.backward_D(opt) # calculate gradients for D + + # update Generator + self.backward_G(opt) # calculate graidents for G + + psnr1 = -10*self.mseloss(self.fake_B.detach(),self.real_B.detach()).log10().item() + ssim_ = ssim(self.real_B.detach().float(), self.fake_B.detach().float(),data_range=1.) + + out_dict = { + "train_ssim": ssim_, + "train_psnr1": psnr1, + } + # adjust learning rates according to schedule + if opt.optim.lr_policy: + out_dict["lr_D"]=self.sch_D.get_lr()[0] + out_dict["lr_G"]=self.sch_G.get_lr()[0] + out_dict.update(self.loss) + out_dict.update(self.dis_losses) + self.wandb.log(out_dict) + + def validation(self,opt): + """Used for validation and test in Center Ground-View Synthesis setting + + Args: + opt (_type_): option dict + """ + print(10*"*","validate",10*"*") + self.netG.eval() + # six image reconstruction metrics + psnr_val = [] + ssim_val = [] + lpips_ale_val = [] + lpips_squ_val = [] + rmse_val = [] + sd_val = [] + with torch.no_grad(): + # set the sky of all images with predefined sky histogram. 
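+            # (only the first batch of sty_loader is read; its sky histogram is kept
+            # in style_temp and reused for every validation sample below.)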
+ if opt.sty_img: + for _,data in enumerate(self.sty_loader): + self.set_input(data) + self.style_temp=self.sky_histc + break + + for _,data in enumerate(tqdm(self.val_loader,ncols=100)): + self.set_input(data) + # if true: use the sky of predefined image + # if false: use the sky of corresponding GT + if opt.sty_img: + self.sky_histc = self.style_temp + + self.forward(opt) + rmse = torch.sqrt(self.mseloss(self.fake_B*255.,self.real_B*255.)).item() + sd = sd_func(self.real_B,self.fake_B) + rmse_val.append(rmse) + sd_val.append(sd) + + psnr1 = -10*self.mseloss(self.fake_B,self.real_B).log10().item() + ssim_ = ssim(self.real_B, self.fake_B,data_range=1.).item() + lpips_ale = torch.mean(self.loss_fn_alex((self.real_B*2.)-1, (2.*self.fake_B)-1)).cpu() + if opt.task=='test': + lpips_sque = torch.mean(self.loss_fn_sque((self.real_B*2.)-1, (2.*self.fake_B)-1)).cpu() + lpips_squ_val.append(lpips_sque) + psnr_val.append(psnr1) + ssim_val.append(ssim_) + lpips_ale_val.append(lpips_ale) + + if opt.task in ['vis_test']: + if not os.path.exists(opt.vis_dir): + os.mkdir(opt.vis_dir) + + sat_opacity,sat_depth = render_sat(opt,self.out_put['voxel']) + + self.out_put['depth'] = (self.out_put['depth']/self.out_put['depth'].max())*255. + sat_depth = (sat_depth/sat_depth.max())*255. + for i in range(len(self.fake_B)): + depth_save = cv2.applyColorMap(self.out_put['depth'][i].squeeze().cpu().numpy().astype(np.uint8), cv2.COLORMAP_TURBO) + depth_sat_save = cv2.applyColorMap(sat_depth[i].squeeze().cpu().numpy().astype(np.uint8), cv2.COLORMAP_TURBO) + # cat generated ground images, GT ground images, predicted ground depth + torchvision.utils.save_image([self.fake_B[i].cpu(),self.real_B[i].cpu(),torch.flip(torch.from_numpy(depth_save).permute(2,0,1)/255.,[0])],os.path.join(opt.vis_dir,os.path.basename(self.image_paths[i]))) + # cat GT satellite images, predicted satellite depth + torchvision.utils.save_image( [self.real_A[i].cpu() ,torch.flip(torch.from_numpy(depth_sat_save).permute(2,0,1)/255.,[0])],os.path.join(opt.vis_dir,os.path.basename(self.image_paths[i]).rsplit('.', 1)[0]+'_sat.jpg')) + # ground opacity + torchvision.utils.save_image([self.out_put['opacity'][i]] ,os.path.join(opt.vis_dir,os.path.basename(self.image_paths[i]).rsplit('.', 1)[0]+'_sat.jpg')) + psnr_avg = np.average(psnr_val) + ssim_avg = np.average(ssim_val) + + lpips_ale_avg = np.average(lpips_ale_val) + if 'test' in opt.task: + lpips_squ_avg = np.average(lpips_squ_val) + + rmse_avg = np.average(rmse_val) + sd_avg = np.average(sd_val) + if opt.task in ["train" , "Train"]: + out_dict = { + 'val_psnr': psnr_avg, + 'val_ssim': ssim_avg, + 'val_lpips_ale':lpips_ale_avg, + 'val_rmse':rmse_avg, + 'val_sd':sd_avg + } + if opt.task=='test': + out_dict['val_lpips_squ'] = lpips_squ_avg + self.wandb.log(out_dict,commit=False) + else: + print( + { + 'val_rmse':rmse_avg, + 'val_ssim': ssim_avg, + 'val_psnr': psnr_avg, + 'val_sd':sd_avg, + 'val_lpips_ale':lpips_ale_avg, + 'val_lpips_squ':lpips_squ_avg, + } + ) + with open('test_output.csv', mode='a', newline='') as csv_file: + writer = csv.writer(csv_file) + writer.writerow([rmse_avg, ssim_avg, psnr_avg, sd_avg, lpips_ale_avg, lpips_squ_avg]) + + def test_vid(self,opt): + """Used for synthesis ground video + + Args: + opt (_type_): option dict + """ + ckpt_list = os.listdir('wandb/') + for i in ckpt_list: + if opt.test_ckpt_path in i: + ckpt_path = i + + ckpt = torch.load(os.path.join('wandb/',ckpt_path,'files/checkpoint/model.pth'))['netG'] + print('load success!') + 
self.netG.load_state_dict(ckpt,strict=True) + self.netG.eval() + print(10*"*","test_video",10*"*") + + + pixels = [] + if os.path.exists('vis_video/pixels.csv'): + + with open('vis_video/pixels.csv', 'r') as csvfile: + reader = csv.DictReader(csvfile) + for row in reader: + x = float(row['x']) #x is + y = float(row['y']) + pixels.append((x, y)) + else: + print('only render center point without vis_video/pixels.csv') + pixels = [(128,128)] + + if opt.sty_img: + # inference with illumination from other images + for idx,data in enumerate(self.sty_loader): + self.set_input(data) + self.style_temp=self.sky_histc + break + with torch.no_grad(): + for idx,data in enumerate(self.val_loader): + self.set_input(data) + if opt.sty_img: + self.sky_histc = self.style_temp + for i,(x,y) in enumerate(pixels): + opt.origin_H_W = [(y-128)/128 , (x-128)/128] + print(opt.origin_H_W) + self.forward(opt) + + + + if not os.path.exists('vis_video'): + os.mkdir('vis_video') + + # save voxel to visalize & satellite depth, works well on cvact + if i==0: + # pre-process for better visualize + volume_data = self.out_put.voxel.squeeze().cpu().numpy().transpose((1,2,0)) + volume_data = np.clip(volume_data, None, 10) + + import pyvista as pv + + grid = pv.UniformGrid() + grid.dimensions = volume_data.shape + grid.spacing = (1, 1, 1) + grid.origin = (0, 0, 0) + grid.point_data['values'] = volume_data.flatten(order='F') + grid.save(os.path.join('vis_video',"volume_data.vtk") ) # vtk file could be visualized by ParaView app + + sat_opacity,sat_depth = render_sat(opt,self.out_put['voxel']) + sat_depth = (2 - sat_depth)/(opt.data.max_height/15)*255. + depth_sat_save = cv2.applyColorMap(sat_depth[0].squeeze().cpu().numpy().astype(np.uint8), cv2.COLORMAP_TURBO) + torchvision.utils.save_image(torch.flip(torch.from_numpy(depth_sat_save).permute(2,0,1)/255.,[0]) ,os.path.join('vis_video',os.path.basename(self.image_paths[0])).replace('.png','_satdepth.png')) + torchvision.utils.save_image( [self.real_A[0].cpu() ] ,os.path.join('vis_video',os.path.basename(self.image_paths[0]).replace('.png','_sat.png'))) + torchvision.utils.save_image( [self.real_B[0].cpu() ] ,os.path.join('vis_video',os.path.basename(self.image_paths[0]).replace('.png','_pano.png'))) + + self.out_put['depth'] = (self.out_put['depth']/self.out_put['depth'].max())*255. 
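+                    # (the predicted depth map is max-normalized to 0-255, colorized with
+                    # OpenCV's TURBO colormap, and flipped from BGR to RGB before saving.)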
+ depth_save = cv2.applyColorMap(self.out_put['depth'][0].squeeze().cpu().numpy().astype(np.uint8), cv2.COLORMAP_TURBO) + depth_save = torch.flip(torch.from_numpy(depth_save).permute(2,0,1)/255.,[0]) + + + save_img = self.out_put.pred[0].cpu() + name = '%05d' % int(i) + ".png" + torchvision.utils.save_image(save_img,os.path.join('vis_video',os.path.basename(self.image_paths[0])).replace('.png',name)) + + save_img = depth_save + name = '%05d' % int(i) + "_depth.png" + torchvision.utils.save_image(save_img,os.path.join('vis_video',os.path.basename(self.image_paths[0])).replace('.png',name)) + + # save_img = self.out_put.generator_inputs[0][:3,:,:] + # name = '%05d' % int(i) + "_color_project.png" + # torchvision.utils.save_image(save_img,os.path.join('vis_video',os.path.basename(self.image_paths[0])).replace('.png',name)) + + def test_interpolation(self,opt): + """Used for test interpolation + + Args: + opt (_type_): option dict + """ + ckpt_list = os.listdir('wandb/') + for i in ckpt_list: + if opt.test_ckpt_path in i: + ckpt_path = i + + ckpt = torch.load(os.path.join('wandb/',ckpt_path,'files/checkpoint/model.pth'))['netG'] + print('load success!') + self.netG.load_state_dict(ckpt,strict=True) + self.netG.eval() + + pixels = [(128,128)] + if opt.sty_img1: + for idx,data in enumerate(self.sty_loader1): + self.set_input(data) + self.style_temp1=self.sky_histc + break + if opt.sty_img2: + for idx,data in enumerate(self.sty_loader2): + self.set_input(data) + self.style_temp2=self.sky_histc + break + + with torch.no_grad(): + for idx,data in enumerate(self.val_loader): + self.set_input(data) + self.sky_histc1 = self.style_temp1 + self.sky_histc2 = self.style_temp2 + x,y = pixels[0] + opt.origin_H_W = [(y-128)/128 , (x-128)/128] + print(opt.origin_H_W) + + + estimated_height = self.netG.depth_model(self.real_A) + geo_outputs = geometry_transform.render(opt,self.real_A,estimated_height,self.netG.pano_direction,PE=self.netG.PE) + generator_inputs,opacity,depth = geo_outputs['rgb'],geo_outputs['opacity'],geo_outputs['depth'] + if self.netG.gen_cfg.cat_opa: + generator_inputs = torch.cat((generator_inputs,opacity),dim=1) + if self.netG.gen_cfg.cat_depth: + generator_inputs = torch.cat((generator_inputs,depth),dim=1) + _, _, z1 = self.netG.style_encode(self.sky_histc1) + _, _, z2 = self.netG.style_encode(self.sky_histc2) + num_inter = 60 + for i in range(num_inter): + z = z1 * (1-i/(num_inter-1)) + z2* (i/(num_inter-1)) + z = self.netG.style_model(z) + output_RGB = self.netG.denoise_model(generator_inputs,z) + + save_img = output_RGB.cpu() + name = 'img{:03d}.png'.format(i) + if not os.path.exists('vis_interpolation'): + os.mkdir('vis_interpolation') + torchvision.utils.save_image(save_img,os.path.join('vis_interpolation',name)) + + + + + + + def test_speed(self,opt): + self.netG.eval() + random_input = torch.randn(1, 3, 256, 256).to(opt.device) + starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True) + iterations = 300 + + times = torch.zeros(iterations) + with torch.no_grad(): + for _ in range(50): + _ = self.netG(random_input,None,opt) + for iter in range(iterations): + starter.record() + _ = self.netG(random_input,None,opt) + ender.record() + torch.cuda.synchronize() + curr_time = starter.elapsed_time(ender) # 计算时间 + times[iter] = curr_time + # print(curr_time) + + mean_time = times.mean().item() + print("Inference time: {:.6f}, FPS: {} ".format(mean_time, 1000/mean_time)) + + + def test_sty(self,opt): + ckpt_list = os.listdir('wandb/') + for i in ckpt_list: 
+ if opt.test_ckpt_path in i: + ckpt_path = i + + ckpt = torch.load(os.path.join('wandb/',ckpt_path,'files/checkpoint/model.pth'))['netG'] + print('load success!') + self.netG.load_state_dict(ckpt,strict=True) + self.netG.eval() + print(10*"*","test_sty",10*"*") + self.netG.eval() + self.style_temp_list = [] + with torch.no_grad(): + num_val_loader = len(self.val_loader) + for i in range(num_val_loader): + for idx,data in enumerate(tqdm(self.val_loader,ncols=100)): + self.set_input(data) + + if i==0: + self.style_temp_list.append(self.sky_histc) + name = '%05d' % int(idx) + torchvision.utils.save_image( [self.real_A[0].cpu() ] ,os.path.join(opt.vis_dir,os.path.basename(self.image_paths[0]).replace('.png',name+'_sat.png'))) + self.sky_histc = self.style_temp_list[i] + self.forward(opt) + if not os.path.exists(opt.vis_dir): + os.mkdir(opt.vis_dir) + name = '%05d' % int(idx)+'_'+'%05d' % int(i) + name= name+ '.png' + torchvision.utils.save_image(self.fake_B[0].cpu(),os.path.join(opt.vis_dir, name)) + + def train(self,opt): + self.validation(opt) + for current_epoch in range(opt.max_epochs): + print(10*'-','current epoch is ',current_epoch,10*'-') + for idx,data in enumerate(tqdm(self.train_loader,ncols=100)): + self.set_input(data) + self.optimize_parameters(opt) + if idx%500==0 : + out_ing_dict = { + 'train_input': wandb.Image(self.real_A[0].float()), + 'train_pred_and_gt': wandb.Image(torch.cat([self.fake_B,self.real_B],2)[0].float()), + } + if hasattr(self.out_put, 'inter_RGB'): + out_ing_dict["train_inner_pred"] = wandb.Image(self.out_put.inter_RGB[0].float()) + if opt.arch.gen.transform_mode in ['volum_rendering']: + out_ing_dict['train_inner_opacity'] = wandb.Image(self.out_put.opacity[0].float()) + self.wandb.log(out_ing_dict,commit=False) + if opt.optim.lr_policy.iteration_mode: + self.sch_G.step() + self.sch_D.step() + if not opt.optim.lr_policy.iteration_mode: + self.sch_G.step() + self.sch_D.step() + self.validation(opt) + if current_epoch%5==0: + self.save_checkpoint(ep=current_epoch) + self.save_checkpoint(ep=current_epoch) + + def test(self,opt): + ckpt_list = os.listdir('wandb/') + for i in ckpt_list: + if '.zip' not in i: + if opt.test_ckpt_path in i: + ckpt_path = i + + ckpt = torch.load(os.path.join('wandb/',ckpt_path,'files/checkpoint/model.pth'))['netG'] + print('load success!') + self.netG.load_state_dict(ckpt,strict=True) + # print(10*"*","validate",10*"*") + self.validation(opt) + print('if --task=vis_test,visible results will be saved,you can add "--vis_dir=xxx" to save in other dictionary',opt.vis_dir) + + + def _get_outputs(self, net_D_output, real=True): + r"""Return output values. Note that when the gan mode is relativistic. + It will do the difference before returning. + + Args: + net_D_output (dict): + real_outputs (tensor): Real output values. + fake_outputs (tensor): Fake output values. + real (bool): Return real or fake. + """ + + def _get_difference(a, b): + r"""Get difference between two lists of tensors or two tensors. 
+ + Args: + a: list of tensors or tensor + b: list of tensors or tensor + """ + out = list() + for x, y in zip(a, b): + if isinstance(x, list): + res = _get_difference(x, y) + else: + res = x - y + out.append(res) + return out + + if real: + return net_D_output['real_outputs'] + else: + return net_D_output['fake_outputs'] + + +def sd_func(real, fake): + ''' + ref: page 6 in https://arxiv.org/abs/1511.05440 + ''' + dgt1 = torch.abs(torch.diff(real,dim=-2))[:, :, 1:, 1:-1] + dgt2 = torch.abs(torch.diff(real, dim=-1))[:, :, 1:-1, 1:] + dpred1 = torch.abs(torch.diff(fake, dim=-2))[:, :, 1:, 1:-1] + dpred2 = torch.abs(torch.diff(fake, dim=-1))[:, :, 1:-1, 1:] + return 10*torch.log10(1.**2/torch.mean(torch.abs(dgt1+dgt2-dpred1-dpred2))).cpu().item() \ No newline at end of file diff --git a/model/craft_feature.py b/model/craft_feature.py new file mode 100644 index 0000000000000000000000000000000000000000..7e073bc3c05dd0e308336ef6b41fa8c621727b11 --- /dev/null +++ b/model/craft_feature.py @@ -0,0 +1,146 @@ +# from this import d +import torch +from .base_model import BaseModel +import importlib +from torch.utils.data import DataLoader +from easydict import EasyDict as edict + +class Model(BaseModel): + def __init__(self, opt, wandb=None): + + """Initialize the Generator. + Parameters: + opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions + """ + BaseModel.__init__(self, opt,wandb) + self.loss_names = ['G_GAN', 'G_L1', 'D_real', 'D_fake'] + + + def set_input(self, input): + """Unpack input data from the dataloader and perform necessary pre-processing steps. + Parameters: + input (dict): include the data itself and its metadata information. + self.real_A: aerial images + self.real_B: ground images + self.image_paths: images paths of ground images + self.sky_mask: the sky mask of ground images + self.sky_histc: the histogram of selected sky + """ + self.real_A = input['sat' ].to(self.device) + self.real_B = input['pano'].to(self.device) if 'pano' in input else None # for testing + self.image_paths = input['paths'] + if self.opt.data.sky_mask: + self.sky_mask = input['sky_mask'].to(self.device) if 'sky_mask' in input else None # for testing + if self.opt.data.histo_mode and self.opt.data.sky_mask: + self.sky_histc = input['sky_histc'].to(self.device) if 'sky_histc' in input else None # for testing + else: self.sky_histc = None + + def forward(self,opt): + """Run forward pass; called by both functions and .""" + # origin_H_W is the inital localization of camera + if opt.task != 'test_vid': + opt.origin_H_W=None + if hasattr(opt.arch.gen,'style_inject'): + # replace the predicted sky with selected sky histogram + if opt.arch.gen.style_inject == 'histo': + self.out_put = self.netG(self.real_A,self.sky_histc.detach(),opt) + else: + raise Exception('Unknown style inject mode') + else: + self.out_put = self.netG(self.real_A,None,opt) + self.out_put = edict(self.out_put) + self.fake_B = self.out_put.pred + # perceptive image + + def backward_D(self,opt): + """Calculate GAN loss for the discriminator""" + self.optimizer_D.zero_grad() + self.netG.eval() + with torch.no_grad(): + self.forward(opt) + self.out_put.pred = self.out_put.pred.detach() + net_D_output = self.netD(self.real_B, self.out_put) + + output_fake = self._get_outputs(net_D_output, real=False) + output_real = self._get_outputs(net_D_output, real=True) + fake_loss = self.criteria['GAN'](output_fake, False, dis_update=True) + true_loss = self.criteria['GAN'](output_real, True, dis_update=True) + 
self.dis_losses = dict() + self.dis_losses['GAN/fake'] = fake_loss + self.dis_losses['GAN/true'] = true_loss + self.dis_losses['DIS'] = fake_loss + true_loss + self.dis_losses['DIS'].backward() + self.optimizer_D.step() + + + def backward_G(self,opt): + self.optimizer_G.zero_grad() + self.loss = {} + self.netG.train() + self.forward(opt) + net_D_output = self.netD(self.real_B, self.out_put) + pred_fake = self._get_outputs(net_D_output, real=False) + self.loss['GAN'] = self.criteria['GAN'](pred_fake, True, dis_update=False) + if 'GaussianKL' in self.criteria: + self.loss['GaussianKL'] = self.criteria['GaussianKL'](self.out_put['mu'], self.out_put['logvar']) + if 'L1' in self.criteria: + self.loss['L1'] = self.criteria['L1'](self.real_B,self.fake_B) + if 'L2' in self.criteria: + self.loss['L2'] = self.criteria['L2'](self.real_B,self.fake_B) + if 'SSIM' in self.criteria: + self.loss['SSIM'] = 1-self.criteria['SSIM'](self.real_B, self.fake_B) + if 'GaussianKL' in self.criteria: + self.loss['GaussianKL'] = self.criteria['GaussianKL'](self.out_put['mu'], self.out_put['logvar']) + if 'sky_inner' in self.criteria: + self.loss['sky_inner'] = self.criteria['sky_inner'](self.out_put.opacity, 1-self.sky_mask) + if 'Perceptual' in self.criteria: + self.loss['Perceptual'] = self.criteria['Perceptual'](self.fake_B,self.real_B) + if 'feature_matching' in self.criteria: + self.loss['feature_matching'] = self.criteria['feature_matching'](net_D_output['fake_features'], net_D_output['real_features']) + self.loss_G = 0 + for key in self.loss: + self.loss_G += self.loss[key] * self.weights[key] + self.loss['total'] = self.loss_G + self.loss_G.backward() + self.optimizer_G.step() # udpate G's weights + + + def load_dataset(self,opt): + data = importlib.import_module("data.{}".format(opt.data.dataset)) + if opt.task in ["train", "Train"]: + train_data = data.Dataset(opt,"train",opt.data.train_sub) + + self.train_loader = DataLoader(train_data,batch_size=opt.batch_size,shuffle=True,num_workers=opt.data.num_workers,drop_last=True) + self.len_train_loader = len(self.train_loader) + + val_data = data.Dataset(opt,"val") + opt.batch_size = 1 if opt.task in ["test" , "val","vis_test",'test_vid','test_sty'] else opt.batch_size + opt.batch_size = 1 if opt.task=='test_speed' else opt.batch_size + self.val_loader = DataLoader(val_data,batch_size=opt.batch_size,shuffle=False,num_workers=opt.data.num_workers) + self.len_val_loader = len(self.val_loader) + # you can select one random image as a style of all predicted skys + # if None, we use the corresponding style of GT + if opt.sty_img: + sty_data = data.Dataset(opt,sty_img = opt.sty_img) + self.sty_loader = DataLoader(sty_data,batch_size=1,num_workers=1,shuffle=False) + # The followings are only used for test the illumination interpolation. 
+ if opt.sty_img1: + sty1_data = data.Dataset(opt,sty_img = opt.sty_img1) + self.sty_loader1 = DataLoader(sty1_data,batch_size=1,num_workers=1,shuffle=False) + if opt.sty_img2: + sty2_data = data.Dataset(opt,sty_img = opt.sty_img2) + self.sty_loader2 = DataLoader(sty2_data,batch_size=1,num_workers=1,shuffle=False) + + def build_networks(self, opt): + if 'imaginaire' in opt.arch.gen.netG: + lib_G = importlib.import_module(opt.arch.gen.netG) + self.netG = lib_G.Generator(opt).to(self.device) + else: + raise Exception('Unknown discriminator function') + + if opt.isTrain: # define a discriminator; conditional GANs need to take both input and output images; Therefore, #channels for D is input_nc + output_nc + if opt.arch.dis.netD == 'imaginaire.discriminators.multires_patch_pano': + lib_D = importlib.import_module(opt.arch.dis.netD) + self.netD = lib_D.Discriminator(opt.arch.dis).to(self.device) + else: + raise Exception('Unknown discriminator function') diff --git a/model/geometry_transform.py b/model/geometry_transform.py new file mode 100644 index 0000000000000000000000000000000000000000..e0a349a1cf9ee2ea858c3f608365a701583d3977 --- /dev/null +++ b/model/geometry_transform.py @@ -0,0 +1,272 @@ +import numpy as np +import torch,math +from PIL import Image +import torchvision +from easydict import EasyDict as edict + +def position_produce(opt): + depth_channel = opt.arch.gen.depth_arch.output_nc + if opt.optim.ground_prior: + depth_channel = depth_channel+1 + z_ = torch.arange(depth_channel)/depth_channel + x_ = torch.arange(opt.data.sat_size[1])/opt.data.sat_size[1] + y_ = torch.arange(opt.data.sat_size[0])/opt.data.sat_size[0] + Z,X,Y = torch.meshgrid(z_,x_,y_) + input = torch.cat((Z[...,None],X[...,None],Y[...,None]),dim=-1).to(opt.device) + pos = positional_encoding(opt,input) + pos = pos.permute(3,0,1,2) + return pos + +def positional_encoding(opt,input): # [B,...,N] + shape = input.shape + freq = 2**torch.arange(opt.arch.gen.PE_channel,dtype=torch.float32,device=opt.device)*np.pi # [L] + spectrum = input[...,None]*freq # [B,...,N,L] + sin,cos = spectrum.sin(),spectrum.cos() # [B,...,N,L] + input_enc = torch.stack([sin,cos],dim=-2) # [B,...,N,2,L] + input_enc = input_enc.view(*shape[:-1],-1) # [B,...,2NL] + return input_enc + + + +def get_original_coord(opt): + ''' + pano_direction [X,Y,Z] x right,y up,z out + ''' + W,H = opt.data.pano_size + _y = np.repeat(np.array(range(W)).reshape(1,W), H, axis=0) + _x = np.repeat(np.array(range(H)).reshape(1,H), W, axis=0).T + + if opt.data.dataset in ['CVACT_Shi', 'CVACT', 'CVACThalf']: + _theta = (1 - 2 * (_x) / H) * np.pi/2 # latitude + elif opt.data.dataset in ['CVUSA']: + _theta = (1 - 2 * (_x) / H) * np.pi/4 + # _phi = math.pi* ( 1 -2* (_y)/W ) # longtitude + _phi = math.pi*( - 0.5 - 2* (_y)/W ) + axis0 = (np.cos(_theta)*np.cos(_phi)).reshape(H, W, 1) + axis1 = np.sin(_theta).reshape(H, W, 1) + axis2 = (-np.cos(_theta)*np.sin(_phi)).reshape(H, W, 1) + pano_direction = np.concatenate((axis0, axis1, axis2), axis=2) + return pano_direction + + +def render(opt,feature,voxel,pano_direction,PE=None): + ''' + render ground images from ssatellite images + + feature: B,C,H_sat,W_sat feature or a input RGB + voxel: B,N,H_sat,W_sat density of each grid + PE: whether add position encoding , default is None + pano_direction: pano ray direction by their definition + ''' + # pano_W,pano_H = opt.data.pano_size + sat_W,sat_H = opt.data.sat_size + BS = feature.size(0) + ##### get origin, sample point ,depth + + if opt.data.dataset =='CVACT_Shi': + origin_height=2 
## the height of photo taken in real world scale + realworld_scale = 30 ## the real world scale corresponding to [-1,1] regular cooridinate + elif opt.data.dataset == 'CVUSA': + origin_height=2 + realworld_scale = 55 + else: + assert Exception('Not implement yet') + + assert sat_W==sat_H + pixel_resolution = realworld_scale/sat_W #### pixel resolution of satellite image in realworld + + if opt.data.sample_total_length: + sample_total_length = opt.data.sample_total_length + else: sample_total_length = (int(max(np.sqrt((realworld_scale/2)**2+(realworld_scale/2)**2+(2)**2), \ + np.sqrt((realworld_scale/2)**2+(realworld_scale/2)**2+(opt.data.max_height-origin_height)**2))/pixel_resolution))/(sat_W/2) + + origin_z = torch.ones([BS,1])*(-1+(origin_height/(realworld_scale/2))) ### -1 is the loweast position in regular cooridinate + ##### origin_z: which can be definition by origin height + if opt.origin_H_W is None: ### origin_H_W is the photo taken space in regular coordinate + origin_H,origin_w = torch.zeros([BS,1]),torch.zeros([BS,1]) + else: + origin_H,origin_w = torch.ones([BS,1])*opt.origin_H_W[0],torch.ones([BS,1])*opt.origin_H_W[1] + origin = torch.cat([origin_w,origin_z,origin_H],dim=1).to(opt.device)[:,None,None,:] ## w,z,h, samiliar to NERF coordinate definition + sample_len = ((torch.arange(opt.data.sample_number)+1)*(sample_total_length/opt.data.sample_number)).to(opt.device) + ### sample_len: For sample distance is fixed, so we can easily calculate sample len along a way by max length and sample number + origin = origin[...,None] + pano_direction = pano_direction[...,None] ### the direction has been normalized + depth = sample_len[None,None,None,None,:] + sample_point = origin + pano_direction * depth #0.0000],-0.8667],0.0000 w,z,h + # x points right, y points up, z points backwards scene nerf + # ray_depth = sample_point-origin + + if opt.optim.ground_prior: + voxel = torch.cat([torch.ones(voxel.size(0),1,voxel.size(2),voxel.size(3),device=opt.device)*1000,voxel],1) + + # voxel[:,0,:,:] = 100 + N = voxel.size(1) + voxel_low = -1 + voxel_max = -1 + opt.data.max_height/(realworld_scale/2) ### voxel highest space in normal space + grid = sample_point.permute(0,4,1,2,3)[...,[0,2,1]] ### BS,NUM_point,W,H,3 + grid[...,2] = ((grid[...,2]-voxel_low)/(voxel_max-voxel_low))*2-1 ### grid_space change to sample space by scale the z space + grid = grid.float() ## [1, 300, 256, 512, 3] + + color_input = feature.unsqueeze(2).repeat(1, 1, N, 1, 1) + alpha_grid = torch.nn.functional.grid_sample(voxel.unsqueeze(1), grid) + + color_grid = torch.nn.functional.grid_sample(color_input, grid) + if PE is not None: + PE_grid = torch.nn.functional.grid_sample(PE[None,...], grid[:1,...]) + color_grid = torch.cat([color_grid,PE_grid.repeat(BS, 1, 1, 1, 1)],dim=1) + + depth_sample = depth.permute(0,1,2,4,3).view(1,-1,opt.data.sample_number,1) + feature_size = color_grid.size(1) + color_grid = color_grid.permute(0,3,4,2,1).view(BS,-1,opt.data.sample_number,feature_size) + alpha_grid = alpha_grid.permute(0,3,4,2,1).view(BS,-1,opt.data.sample_number) + intv = sample_total_length/opt.data.sample_number + output = composite(opt, rgb_samples=color_grid,density_samples=alpha_grid,depth_samples=depth_sample,intv = intv) + output['voxel'] = voxel + return output + +def composite(opt,rgb_samples,density_samples,depth_samples,intv): + """generate 2d ground images according to ray + + Args: + opt (_type_): option dict + rgb_samples (_type_): rgb (sampled from satellite image) belongs to the ray which start from the ground 
camera to world + density_samples (_type_): density (sampled from the predicted voxel of satellite image) belongs to the ray which start from the ground camera to world + depth_samples (_type_): depth of the ray which start from the ground camera to world + intv (_type_): interval of the ray's depth which start from the ground camera to world + + Returns: + 2d ground images (rgd, opacity, and depth) + """ + + sigma_delta = density_samples*intv # [B,HW,N] + alpha = 1-(-sigma_delta).exp_() # [B,HW,N] + T = (-torch.cat([torch.zeros_like(sigma_delta[...,:1]),sigma_delta[...,:-1]],dim=2).cumsum(dim=2)) .exp_() # [B,HW,N] + prob = (T*alpha)[...,None] # [B,HW,N,1] + # integrate RGB and depth weighted by probability + depth = (depth_samples*prob).sum(dim=2) # [B,HW,1] + rgb = (rgb_samples*prob).sum(dim=2) # [B,HW,3] + opacity = prob.sum(dim=2) # [B,HW,1] + depth = depth.permute(0,2,1).view(depth.size(0),-1,opt.data.pano_size[1],opt.data.pano_size[0]) + rgb = rgb.permute(0,2,1).view(rgb.size(0),-1,opt.data.pano_size[1],opt.data.pano_size[0]) + opacity = opacity.view(opacity.size(0),1,opt.data.pano_size[1],opt.data.pano_size[0]) + return {'rgb':rgb,'opacity':opacity,'depth':depth} + + +def get_sat_ori(opt): + W,H = opt.data.sat_size + y_range = (torch.arange(H,dtype=torch.float32,)+0.5)/(0.5*H)-1 + x_range = (torch.arange(W,dtype=torch.float32,)+0.5)/(0.5*H)-1 + Y,X = torch.meshgrid(y_range,x_range) + Z = torch.ones_like(Y) + xy_grid = torch.stack([X,Z,Y],dim=-1)[None,:,:] + return xy_grid + +def render_sat(opt,voxel): + ''' + voxel: voxel has been processed + ''' + # pano_W,pano_H = opt.data.pano_size + sat_W,sat_H = opt.data.sat_size + sat_ori = get_sat_ori(opt) + sat_dir = torch.tensor([0,-1,0])[None,None,None,:] + + ##### get origin, sample point ,depth + if opt.data.dataset =='CVACT_Shi': + origin_height=2 + realworld_scale = 30 + elif opt.data.dataset == 'CVUSA': + origin_height=2 + realworld_scale = 55 + + else: + assert Exception('Not implement yet') + + pixel_resolution = realworld_scale/sat_W #### pixel resolution of satellite image in realworld + # if opt.data.sample_total_length: + # sample_total_length = opt.data.sample_total_length + # else: sample_total_length = (int(max(np.sqrt((realworld_scale/2)**2+(realworld_scale/2)**2+(2)**2), \ + # np.sqrt((realworld_scale/2)**2+(realworld_scale/2)**2+(opt.data.max_height-origin_height)**2))/pixel_resolution))/(sat_W/2) + sample_total_length = 2 + # #### sample_total_length: it can be definition in future, which is the farest length between sample point and original ponit + # assert sat_W==sat_H + + origin = sat_ori.to(opt.device) ## w,z,h, samiliar to NERF coordinate definition + sample_len = ((torch.arange(opt.data.sample_number)+1)*(sample_total_length/opt.data.sample_number)).to(opt.device) + ### sample_len: For sample distance is fixed, so we can easily calculate sample len along a way by max length and sample number + origin = origin[...,None].to(opt.device) + direction = sat_dir[...,None].to(opt.device) ### the direction has been normalized + depth = sample_len[None,None,None,None,:] + sample_point = origin + direction * depth #0.0000],-0.8667],0.0000 w,z,h + + + N = voxel.size(1) + voxel_low = -1 + voxel_max = -1 + opt.data.max_height/(realworld_scale/2) ### voxel highest space in normal space + # axis_voxel = (torch.arange(N)/N) * (voxel_max-voxel_low) +voxel_low + grid = sample_point.permute(0,4,1,2,3)[...,[0,2,1]] ### BS,NUM_point,W,H,3 + grid[...,2] = ((grid[...,2]-voxel_low)/(voxel_max-voxel_low))*2-1 ### grid_space change to 
sample space by scale the z space + grid = grid.float() ## [1, 300, 256, 512, 3] + alpha_grid = torch.nn.functional.grid_sample(voxel.unsqueeze(1), grid) + + depth_sample = depth.permute(0,1,2,4,3).view(1,-1,opt.data.sample_number,1) + alpha_grid = alpha_grid.permute(0,3,4,2,1).view(opt.batch_size,-1,opt.data.sample_number) + # color_grid = torch.flip(color_grid,[2]) + # alpha_grid = torch.flip(alpha_grid,[2]) + intv = sample_total_length/opt.data.sample_number + output = composite_sat(opt,density_samples=alpha_grid,depth_samples=depth_sample,intv = intv) + return output['opacity'],output['depth'] + +def composite_sat(opt,density_samples,depth_samples,intv): + sigma_delta = density_samples*intv # [B,HW,N] + alpha = 1-(-sigma_delta).exp_() # [B,HW,N] + T = (-torch.cat([torch.zeros_like(sigma_delta[...,:1]),sigma_delta[...,:-1]],dim=2).cumsum(dim=2)) .exp_() # [B,HW,N] + prob = (T*alpha)[...,None] # [B,HW,N,1] + depth = (depth_samples*prob).sum(dim=2) # [B,HW,1] + opacity = prob.sum(dim=2) # [B,HW,1] + depth = depth.permute(0,2,1).view(depth.size(0),-1,opt.data.sat_size[1],opt.data.sat_size[0]) + opacity = opacity.view(opacity.size(0),1,opt.data.sat_size[1],opt.data.sat_size[0]) + # return rgb,depth,opacity,prob # [B,HW,K] + return {'opacity':opacity,'depth':depth} + +if __name__ == '__main__': + # test_demo + opt=edict() + opt.device = 'cuda' + opt.data = edict() + opt.data.pano_size = [512,256] + opt.data.sat_size = [256,256] + opt.data.dataset = 'CVACT_Shi' + opt.data.max_height = 20 + opt.data.sample_number = 300 + opt.arch = edict() + opt.optim = edict() + opt.optim.ground_prior = False + opt.arch.gen.transform_mode = 'volum_rendering' + # opt.arch.gen.transform_mode = 'proj_like_radus' + BS = 1 + opt.data.sample_total_length = 1 + sat_name = './CVACT/satview_correct/__-DFIFxvZBCn1873qkqXA_satView_polish.png' + a = Image.open(sat_name) + a = np.array(a).astype(np.float32) + a = torch.from_numpy(a) + a = a.permute(2, 0, 1).unsqueeze(0).to(opt.device).repeat(BS,1,1,1)/255. + + + pano = sat_name.replace('satview_correct','streetview').replace('_satView_polish','_grdView') + pano = np.array(Image.open(pano)).astype(np.float32) + pano = torch.from_numpy(pano) + pano = pano.permute(2, 0, 1).unsqueeze(0).to(opt.device).repeat(BS,1,1,1)/255. 
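    # ------------------------------------------------------------------
    # Added note (illustrative, not part of the original script): the weights
    # produced by composite()/composite_sat() follow the standard volume
    # rendering approximation,
    #     alpha_i = 1 - exp(-sigma_i * intv)
    #     T_i     = exp(-sum_{j < i} sigma_j * intv)
    #     w_i     = T_i * alpha_i
    # so rgb = sum_i w_i * c_i, depth = sum_i w_i * d_i, and opacity = sum_i w_i <= 1.
    # For example, with intv = 0.5 and densities [0, 2, 4] along one ray:
    #     alpha ~= [0, 0.632, 0.865], T = [1, 1, 0.368],
    #     w ~= [0, 0.632, 0.318], opacity ~= 0.95.
    # ------------------------------------------------------------------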
+ voxel=torch.zeros([BS, 65, 256, 256]).to(opt.device) + pano_direction = torch.from_numpy(get_original_coord(opt)).unsqueeze(0).to(opt.device) + + import time + star = time.time() + with torch.autograd.profiler.profile(enabled=True, use_cuda=True, record_shapes=False, profile_memory=False) as prof: + rgb,opacity =render(opt,a,voxel,pano_direction) + print(prof.table()) + + print(time.time()-star) + + torchvision.utils.save_image(torch.cat([rgb,pano],2), opt.arch.gen.transform_mode + '.png') + print( opt.arch.gen.transform_mode + '.png') + torchvision.utils.save_image(opacity, 'opa.png') \ No newline at end of file diff --git a/model/sample.py b/model/sample.py new file mode 100644 index 0000000000000000000000000000000000000000..2ddfa80dbee7d1f54a76a9b39fa17c428bac4f81 --- /dev/null +++ b/model/sample.py @@ -0,0 +1,110 @@ +import cv2 +import numpy as np +import torch +import torchvision + +class Equirectangular(): + """ + Random sample a panorama image into a perspective view + take https://github.com/fuenwang/Equirec2Perspec/blob/master/Equirec2Perspec.py as a reference + """ + def __init__(self, width = 256, height = 256, FovX = 100, theta = [0, 0]): + """ + width: output image's width + height: output image's height + FovX: perspective camera FOV on x-axis (degree) + theta: theta field where img's theta degree from + """ + self.theta = theta + self.width = width + self.height = height + self.type = type + + #create x-axis coordinates and corresponding y-axis coordinates + x = np.arange(width) + y = np.arange(height) + x, y = np.meshgrid(x, y) + + #create homogenerous coordinates + z = np.ones_like(x) + xyz = np.concatenate([x[..., None], y[..., None], z[..., None]], axis=-1) + + #translation matrix + f = 0.5 * width * 1 / np.tan(np.radians(FovX/2)) + # cx = (width - 1) / 2.0 + # cy = (height - 1) / 2.0 + cx = (width) / 2.0 + cy = (height) / 2.0 + K = np.array([ + [f, 0, cx], + [0, f, cy], + [0, 0, 1], + ], np.float32) + K_inv = np.linalg.inv(K) + xyz = xyz @ K_inv.T + self.xyz = xyz ### self.xyz is the direction of the each ray in the camera space when camera is fixed + + + + def __call__(self, img1): + batch = img1.shape[0] + PHI, THETA = self.getRandomRotation(batch) + y_axis = np.array([0.0, 1.0, 0.0], np.float32) + x_axis = np.array([1.0, 0.0, 0.0], np.float32) + #rotation matrix + xy_grid = [] + for i in range(batch): + R1, _ = cv2.Rodrigues(y_axis * np.radians(PHI[i])) + R2, _ = cv2.Rodrigues(np.dot(R1, x_axis) * np.radians(THETA[i])) + R = R2 @ R1 + #rotate + xyz = self.xyz @ R.T ### ### xyz is the direction of the each ray in the camera space when camera is rotate + norm = np.linalg.norm(xyz, axis=-1, keepdims=True) + xyz_norm = xyz / norm + + #transfer to image coordinates + xy = self.xyz2xy(xyz_norm) + device = img1.device + xy = torch.from_numpy(xy).to(device).unsqueeze(0) + xy_grid.append(xy) + xy = torch.cat(xy_grid,dim=0) + + #resample + return xy + + def xyz2xy(self, xyz_norm): + #normlize + x = xyz_norm[..., 0] + y = xyz_norm[..., 1] + z = xyz_norm[..., 2] + + lon = np.arctan2(x, z) + lat = np.arcsin(y) + ### transfer to the lon and lat + + X = lon / (np.pi) + Y = lat / (np.pi) * 2 + xy = np.stack([X, Y], axis=-1) + xy = xy.astype(np.float32) + + return xy + + def getRandomRotation(self,batch_size): + # phi = np.random.rand(batch_size) * 360 -180 + phi = np.random.randint(-180,180,batch_size) + assert(self.theta[0] value + --key1.key2.key3= --> None + --key1.key2.key3 --> True + --key1.key2.key3! 
--> False + """ + opt_cmd = {} + for arg in args: + assert(arg.startswith("--")) + if "=" not in arg[2:]: + key_str,value = (arg[2:-1],"false") if arg[-1]=="!" else (arg[2:],"true") + else: + key_str,value = arg[2:].split("=") + keys_sub = key_str.split(".") + opt_sub = opt_cmd + for k in keys_sub[:-1]: + if k not in opt_sub: opt_sub[k] = {} + opt_sub = opt_sub[k] + assert keys_sub[-1] not in opt_sub,keys_sub[-1] + opt_sub[keys_sub[-1]] = yaml.safe_load(value) + opt_cmd = edict(opt_cmd) + return opt_cmd + +def set(opt_cmd={}): + log.info("setting configurations...") + # load config from yaml file + assert("yaml" in opt_cmd) + fname = "options/{}.yaml".format(opt_cmd.yaml) + opt_base = load_options(fname) + # override with command line arguments + opt = override_options(opt_base,opt_cmd,key_stack=[],safe_check=True) + process_options(opt) + log.options(opt) + return opt + +def load_options(fname): + with open(fname) as file: + opt = edict(yaml.safe_load(file)) + if "_parent_" in opt: + # load parent yaml file(s) as base options + parent_fnames = opt.pop("_parent_") + if type(parent_fnames) is str: + parent_fnames = [parent_fnames] + for parent_fname in parent_fnames: + opt_parent = load_options(parent_fname) + opt_parent = override_options(opt_parent,opt,key_stack=[]) + opt = opt_parent + print("loading {}...".format(fname)) + return opt + +def override_options(opt,opt_over,key_stack=None,safe_check=False): + for key,value in opt_over.items(): + print(key,value) + if isinstance(value,dict): + # parse child options (until leaf nodes are reached) + opt[key] = override_options(opt.get(key,dict()),value,key_stack=key_stack+[key],safe_check=safe_check) + else: + # ensure command line argument to override is also in yaml file + if safe_check and key not in opt: + add_new = None + while add_new not in ["y","n"]: + key_str = ".".join(key_stack+[key]) + add_new = input("\"{}\" not found in original opt, add? (y/n) ".format(key_str)) + if add_new=="n": + print("safe exiting...") + exit() + opt[key] = value + return opt + +def process_options(opt): + # set seed + if opt.seed is not None: + random.seed(opt.seed) + np.random.seed(opt.seed) + torch.manual_seed(opt.seed) + torch.cuda.manual_seed_all(opt.seed) + else: + # create random string as run ID + randkey = "".join(random.choice(string.ascii_uppercase) for _ in range(4)) + opt.name = str(opt.name)+"_{}".format(randkey) + assert(isinstance(opt.gpu,int)) # disable multi-GPU support for now, single is enough + opt.device = "cpu" if opt.cpu or not torch.cuda.is_available() else "cuda:{}".format(opt.gpu) + +def save_options_file(opt,output_path): + opt_fname = "{}/options.yaml".format(output_path) + if os.path.isfile(opt_fname): + with open(opt_fname) as file: + opt_old = yaml.safe_load(file) + if opt!=opt_old: + # prompt if options are not identical + opt_new_fname = "{}/options_temp.yaml".format(output_path) + with open(opt_new_fname,"w") as file: + yaml.safe_dump(utils.to_dict(opt),file,default_flow_style=False,indent=4) + print("existing options file found (different from current one)...") + os.system("diff {} {}".format(opt_fname,opt_new_fname)) + os.system("rm {}".format(opt_new_fname)) + override = None + while override not in ["y","n"]: + override = input("override? 
(y/n) ") + if override=="n": + print("safe exiting...") + exit() + else: print("existing options file found (identical)") + else: print("(creating new options file...)") + with open(opt_fname,"w") as file: + yaml.safe_dump(utils.to_dict(opt),file,default_flow_style=False,indent=4) diff --git a/options/base.yaml b/options/base.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a7a95869c8e09d9afb0990f3cf8df5c362fb68e0 --- /dev/null +++ b/options/base.yaml @@ -0,0 +1,57 @@ +project: 0_test +name: +model: shi +yaml: +seed: 6004 +task: Train +Group: DEBUG +gpu: 0 +cpu: false +load: +arch: {} +test_ckpt_path: +demo_img: __-DFIFxvZBCn1873qkqXA_grdView.png + +# for testing +sty_img: + +sky_img: + +# only for illumination interpolation visualization +sty_img1: +sty_img2: +sky_img1: +sky_img2: + +data: + sky_mask: + root: + dataset: + num_workers: 24 + histo_mode: + sample_total_length: + train_sub: + +optim: + lr_gen: 0.0001 # learning rate (main) + lr_dis: 0.0004 + gan_mode: hinge + beta1: 0 + ground_prior: + loss_weight: + GAN: 1 + L1: 1 + perceptual_loss: + mode: 'vgg19' + layers: ['relu_3_1', 'relu_4_1', 'relu_5_1'] + weights: [0.125, 0.25, 1.0] + lr_policy: + iteration_mode: False + type: step + step_size: 13 + gamma: 0.1 +batch_size: 16 # batch size +resume: false # not test +fp16: # not test +vis_dir: 'vis' +max_epochs: 30 diff --git a/options/sat2density_cvact.yaml b/options/sat2density_cvact.yaml new file mode 100644 index 0000000000000000000000000000000000000000..87c7de7ca0e77a0b21e758fd48bf87e04f9adb4d --- /dev/null +++ b/options/sat2density_cvact.yaml @@ -0,0 +1,91 @@ +_parent_: options/base.yaml + +gpu_ids: '0' + +## config for wandb +project: 'sat2pano' +Group: 'craft_feature' + + +model: craft_feature # model/craft_feature +arch: + gen: ## config for generator + netG: imaginaire.generators.craft_2stage_add_style + weight_norm_type: spectral + activation_norm_type: instance + padding_mode: reflect + transform_mode: volum_rendering + feature_model: + style_inject: histo # use histogram to inject illumination, chose list [histo, perspective] + cat_PE: + cat_opa: true + cat_depth: true + depth_arch: # Density Net + name: depth + num_filters: 32 + num_downsamples: 4 + num_res_blocks: 6 + output_nc: 64 + render_arch: # Render Net + name: render + num_filters: 64 + num_downsamples: 4 + num_res_blocks: 9 + output_nc: 3 + style_enc_cfg: # style injection + input_image_channels: 3 + num_filters: 256 + kernel_size: 3 + style_dims: 128 + interm_style_dims: 256 + hidden_channel: 256 + weight_norm_type: spectral + dis: # discriminator + netD: imaginaire.discriminators.multires_patch_pano + num_filters: 64 + max_num_filters: 512 + num_discriminators: 3 + num_layers: 3 + weight_norm_type: spectral + activation_norm_type: instance + + + +data: # data options + dataset: CVACT_Shi # dataset name + root: ./dataset/CVACT/ + sat_size: [256,256] + pano_size: [512, 128] + sample_number: 100 # points per ray + max_height: 8 # pre-defined density space in height axis + sky_mask: true + histo_mode: rgb + # val: + # sub: 500 + + +optim: + lr_gen: 0.00005 + lr_dis: 0.00005 + gan_mode: non_saturated #'hinge', 'least_square', 'non_saturated', 'wasserstein' + loss_weight: + L1: 1 + L2: 10 + GaussianKL: 0.1 + feature_matching: 10.0 + Perceptual: 10 + sky_inner: 1 + GAN: 1 + + lr_policy: + iteration_mode: False # iteration or epoch + type: step + step_size: 45 + gamma: 0.1 + + ground_prior: true + +######## for test, if only style, will random choice one style for save dir +only_style: 
+only_img: +save_dir: \ No newline at end of file diff --git a/options/sat2density_cvusa.yaml b/options/sat2density_cvusa.yaml new file mode 100644 index 0000000000000000000000000000000000000000..398aecc4a2d73763933cc7eb2c36ff3d84b396cb --- /dev/null +++ b/options/sat2density_cvusa.yaml @@ -0,0 +1,10 @@ +_parent_: options/sat2density_cvact.yaml + + +arch: + dis: + num_discriminators: 2 + +data: + dataset: CVUSA + root: ./dataset/CVUSA/ \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..3db762bf06fb5e9eb3016f8649fe0306bb84e599 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,18 @@ +cmake +numpy==1.22.1 +Pillow>=8.3.2 +scipy +scikit-image +tqdm==4.35.0 +cython +qimage2ndarray +requests==2.25.1 +tensorboard +lpips +easydict +termcolor +wandb==0.13.3 +pytorch_msssim +opencv-contrib-python==4.6.0.66 +albumentations +pyvista \ No newline at end of file diff --git a/scripts/INSTALL.md b/scripts/INSTALL.md new file mode 100644 index 0000000000000000000000000000000000000000..7f2f095ebf5e18fb6d369c7e7b2a8285dd854cfa --- /dev/null +++ b/scripts/INSTALL.md @@ -0,0 +1,11 @@ +1. Andconda +2. install Cuda11.1 and cudnn (requested by [imaginaire](https://github.com/NVlabs/imaginaire)) + If you are not root user, you can install cuda in you '~' path +3. make sure the 'nvcc -V' \& 'cat /usr/local/cuda/include/cudnn_version.h | grep CUDNN_MAJOR -A 2' + have a right output +4. apt-get update \& apt-get install cmake (if you are not in docker content, you should add 'sudo before every apt-get') +5. conda activate your-env-name (make sure python>3.8) +6. pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu111 + or https://mirror.sjtu.edu.cn/pytorch-wheels/cu111/torch_stable.html +7. test 'torch.cuda.is_available()' +8. bash scripts/install \ No newline at end of file diff --git a/scripts/build_docker.sh b/scripts/build_docker.sh new file mode 100644 index 0000000000000000000000000000000000000000..a6adb8ed2a92efb3931f8e57cc49f5b5974823a2 --- /dev/null +++ b/scripts/build_docker.sh @@ -0,0 +1,30 @@ +#!/bin/bash + +key=${1} + +rm -rf Dockerfile +echo "FROM nvcr.io/nvidia/pytorch:${key}-py3" > Dockerfile +input="Dockerfile.base" + +while IFS= read -r line +do + echo "$line" >> Dockerfile +done < "$input" + +input="scripts/requirements.txt" +while IFS= read -r line +do + echo "RUN pip install $line" >> Dockerfile +done < "$input" + + +for p in correlation channelnorm resample2d bias_act upfirdn2d; do + echo "COPY imaginaire/third_party/$p $p" >> Dockerfile + echo "RUN cd $p && rm -rf build dist *-info && python setup.py install" >> Dockerfile +done + +# Compile GANcraft libraries. +echo "COPY imaginaire/model_utils/gancraft/voxlib gancraft/voxlib" >> Dockerfile +echo "RUN cd gancraft/voxlib && make" >> Dockerfile + +docker build -t nvcr.io/nvidian/lpr-imagine/imaginaire:${key}-py3 . 
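Configuration note: `options/sat2density_cvusa.yaml` above only declares `_parent_: options/sat2density_cvact.yaml`, which in turn inherits `options/base.yaml`; `load_options` merges each child over its parent, and dotted command-line flags are then applied on top by `parse_arguments`/`override_options`. The sketch below mirrors that override grammar with invented helper names (`parse_dotted_args`, `deep_override`); it is illustrative only and not part of the repository.

```python
# Minimal sketch of the dotted-flag override grammar used by options.py
# (helper names are invented for illustration).
import yaml

def parse_dotted_args(args):
    out = {}
    for arg in args:
        assert arg.startswith("--")
        if "=" not in arg[2:]:
            # "--k" -> True, "--k!" -> False (and "--k=" would parse to None)
            key_str, value = (arg[2:-1], "false") if arg.endswith("!") else (arg[2:], "true")
        else:
            key_str, value = arg[2:].split("=")
        node, keys = out, key_str.split(".")
        for k in keys[:-1]:
            node = node.setdefault(k, {})
        node[keys[-1]] = yaml.safe_load(value)  # "0.00005" -> 5e-05, "false" -> False
    return out

def deep_override(base, over):
    for k, v in over.items():
        base[k] = deep_override(base.get(k, {}), v) if isinstance(v, dict) else v
    return base

# e.g. python train.py --yaml=sat2density_cvact --optim.lr_gen=0.00005 --data.sky_mask!
cmd = parse_dotted_args(["--optim.lr_gen=0.00005", "--data.sky_mask!"])
cfg = deep_override({"optim": {"lr_gen": 0.0001}, "data": {"sky_mask": True}}, cmd)
print(cfg)  # {'optim': {'lr_gen': 5e-05}, 'data': {'sky_mask': False}}
```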
diff --git a/scripts/build_index.py b/scripts/build_index.py new file mode 100644 index 0000000000000000000000000000000000000000..427701dc963c4d6afbd5e5a1301de54da1687889 --- /dev/null +++ b/scripts/build_index.py @@ -0,0 +1,67 @@ +import argparse +import json +import os +import sys + +sys.path.append('.') +from imaginaire.utils.lmdb import create_metadata # noqa: E402 +from imaginaire.config import Config # noqa: E402 + + +def parse_args(): + r"""Parse user input arguments""" + parser = argparse.ArgumentParser(description='Folder -> LMDB conversion') + parser.add_argument('--data_root', type=str, required=True, + help='Input data location.') + parser.add_argument('--output_root', type=str, default='', + help='Input data location.') + parser.add_argument('--config', type=str, required=True, + help='Config with label info.') + parser.add_argument('--paired', default=False, action='store_true', + help='Is the input data paired?') + parser.add_argument('--input_list', type=str, default='', + help='list of images that will be used.') + args = parser.parse_args() + return args + + +def main(): + r""" Build lmdb for training/testing. + Usage: + python scripts/build_index.py \ + --data_root /mnt/bigdata01/datasets/test_image \ + --output_root /mnt/bigdata01/datasets/test_image/lmdb_0/ \ + --overwrite + """ + args = parse_args() + if args.output_root == '': + args.output_root = args.data_root + cfg = Config(args.config) + + all_filenames, extensions = \ + create_metadata( + data_root=args.data_root, + cfg=cfg, + paired=args.paired, + input_list=args.input_list) + + os.makedirs(args.output_root, exist_ok=True) + + if args.paired: + base = args.data_root.split('/')[-1] + new_all_filenames = dict() + for key in all_filenames.keys(): + new_all_filenames['{}/{}'.format(base, key)] = all_filenames[key] + all_filenames = new_all_filenames.copy() + + # Output list of all filenames. + with open(args.output_root + '/all_filenames.json', 'w') as fout: + json.dump(all_filenames, fout, indent=4) + + # Output metadata. 
+ with open(args.output_root + '/metadata.json', 'w') as fout: + json.dump(extensions, fout, indent=4) + + +if __name__ == "__main__": + main() diff --git a/scripts/build_lmdb.py b/scripts/build_lmdb.py new file mode 100644 index 0000000000000000000000000000000000000000..544afaa7b2b77f78f64c9e82ad60278437b00d21 --- /dev/null +++ b/scripts/build_lmdb.py @@ -0,0 +1,125 @@ +import copy +import shutil +import argparse +import json +import sys +import os +from tqdm import tqdm + +sys.path.append('.') +from imaginaire.utils.lmdb import create_metadata, \ + construct_file_path, check_and_add, build_lmdb # noqa: E402 +from imaginaire.config import Config # noqa: E402 + + +def parse_args(): + r"""Parse user input arguments""" + parser = argparse.ArgumentParser(description='Folder -> LMDB conversion') + parser.add_argument('--data_root', type=str, required=True, + help='Input data location.') + parser.add_argument('--config', type=str, required=True, + help='Config with label info.') + parser.add_argument('--output_root', type=str, required=True, + help='Output LMDB location') + parser.add_argument('--input_list', type=str, default='', + help='list of images that will be used.') + parser.add_argument('--metadata_factor', type=float, default=0.75, + help='Factor of filesize to allocate for metadata?') + parser.add_argument('--overwrite', default=False, action='store_true', + help='Overwrite output file if exists') + parser.add_argument('--paired', default=False, action='store_true', + help='Is the input data paired?') + parser.add_argument('--large', default=False, action='store_true', + help='Is the dataset large?') + parser.add_argument('--remove_missing', default=False, action='store_true', + help='Remove missing files from paired datasets?') + args = parser.parse_args() + return args + + +def main(): + r""" Build lmdb for training/testing. + Usage: + python scripts/build_lmdb.py \ + --config configs/data_image.yaml \ + --data_root /mnt/bigdata01/datasets/test_image \ + --output_root /mnt/bigdata01/datasets/test_image/lmdb_0/ \ + --overwrite + """ + args = parse_args() + cfg = Config(args.config) + + # Check if output file already exists. + if os.path.exists(args.output_root): + if args.overwrite: + print('Deleting existing output LMDB.') + shutil.rmtree(args.output_root) + else: + print('Output root LMDB already exists. Use --overwrite. ' + + 'Exiting...') + return + + all_filenames, extensions = \ + create_metadata(data_root=args.data_root, + cfg=cfg, + paired=args.paired, + input_list=args.input_list) + required_data_types = cfg.data.data_types + + # Build LMDB. + os.makedirs(args.output_root) + for data_type in required_data_types: + data_size = 0 + print('Data type:', data_type) + filepaths, keys = [], [] + print('>> Building file list.') + + # Get appropriate list of files. + if args.paired: + filenames = all_filenames + else: + filenames = all_filenames[data_type] + + for sequence in tqdm(filenames): + for filename in copy.deepcopy(filenames[sequence]): + filepath = construct_file_path( + args.data_root, data_type, sequence, filename, + extensions[data_type]) + key = '%s/%s' % (sequence, filename) + filesize = check_and_add(filepath, key, filepaths, keys, + remove_missing=args.remove_missing) + + # Remove file from list, if missing. + if filesize == -1 and args.paired and args.remove_missing: + print('Removing %s from list' % (filename)) + filenames[sequence].remove(filename) + data_size += filesize + + # Remove empty sequences. 
+ if args.paired and args.remove_missing: + for sequence in copy.deepcopy(all_filenames): + if not all_filenames[sequence]: + all_filenames.pop(sequence) + + # Allocate size. + data_size = max(int((1 + args.metadata_factor) * data_size), 1e9) + print('Reserved size: %s, %dGB' % (data_type, data_size // 1e9)) + + # Write LMDB to file. + output_filepath = os.path.join(args.output_root, data_type) + build_lmdb(filepaths, keys, output_filepath, data_size, args.large) + + # Output list of all filenames. + if args.output_root: + with open(args.output_root + '/all_filenames.json', 'w') as fout: + json.dump(all_filenames, fout, indent=4) + + # Output metadata. + with open(args.output_root + '/metadata.json', 'w') as fout: + json.dump(extensions, fout, indent=4) + else: + return all_filenames, extensions + + +if __name__ == "__main__": + main() diff --git a/scripts/build_lmdb.sh b/scripts/build_lmdb.sh new file mode 100644 index 0000000000000000000000000000000000000000..a749cd18124634cb2de57433f1b9d16ca291919c --- /dev/null +++ b/scripts/build_lmdb.sh @@ -0,0 +1,9 @@ +MODEL=$1 +DATASET=$2 + +for SPLIT in test train; do + RAW=dataset/${DATASET}_raw/${SPLIT} + LMDB=dataset/${DATASET}/${SPLIT} + echo ${LMDB} + python scripts/build_lmdb.py --config configs/projects/${MODEL}/${DATASET}/ampO1.yaml --data_root ${RAW} --output_root ${LMDB} --overwrite +done \ No newline at end of file diff --git a/scripts/download_dataset.py b/scripts/download_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..156910fca512b2e789236e2e1fe94bcf5ff90473 --- /dev/null +++ b/scripts/download_dataset.py @@ -0,0 +1,49 @@ +import argparse +import os +import tarfile +import sys + +sys.path.append('.') +from imaginaire.utils.io import download_file_from_google_drive # noqa: E402 + + +def parse_args(): + parser = argparse.ArgumentParser(description='Download and process dataset') + parser.add_argument('--dataset', help='Name of the dataset.', required=True, + choices=['afhq_dog2cat', + 'animal_faces']) + parser.add_argument('--data_dir', default='./dataset', + help='Directory to save all datasets.') + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + if args.dataset == 'afhq_dog2cat': + url = '1XaiwS0eRctqm-JEDezOBy4TXriAQgc4_' + elif args.dataset == 'animal_faces': + url = '1ftr1xWm0VakGlLUWi7-hdAt9W37luQOA' + else: + raise ValueError('Invalid dataset {}.'.format(args.dataset)) + + # Create the dataset directory. + if not os.path.exists(args.data_dir): + os.makedirs(args.data_dir) + + # Download the compressed dataset. + folder_path = os.path.join(args.data_dir, args.dataset + '_raw') + compressed_path = folder_path + '.tar.gz' + if not os.path.exists(compressed_path) and not os.path.exists(folder_path): + print("Downloading the dataset {}.".format(args.dataset)) + download_file_from_google_drive(url, compressed_path) + + # Extract the dataset. 
+ if not os.path.exists(folder_path): + print("Extracting the dataset {}.".format(args.dataset)) + with tarfile.open(compressed_path) as tar: + tar.extractall(folder_path) + + +if __name__ == "__main__": + main() diff --git a/scripts/download_test_data.py b/scripts/download_test_data.py new file mode 100644 index 0000000000000000000000000000000000000000..c8cede6105c60509ed38fc7d7b473a5c782044a3 --- /dev/null +++ b/scripts/download_test_data.py @@ -0,0 +1,60 @@ +import argparse +import os +import sys +import tarfile +sys.path.append('.') +from imaginaire.utils.io import download_file_from_google_drive # noqa: E402 + +URLS = { + 'pix2pixhd': '1Xg9m184zkuG8H0LHdBtSzt2VbMi3SWwR', + 'spade': '1ESm-gHWu_aMHnKF42qkGc8qf1SBECsgf', + 'funit': '1a-EE_6RsYPUoKxEl5oXrpRmKYUltqaD-', + 'coco_funit': '1JYVYB0Q1VStDLOb0SBJbN1vkaf6KrGDh', + 'unit': '17BbwnCG7qF7FI-t9VkORv2XCKqlrY1CO', + 'munit': '1VPgHGuQfmm1N1Vh56wr34wtAwaXzjXtH', + 'vid2vid': '1SHvGPMq-55GDUQ0Ac2Ng0eyG5xCPeKhc', + 'fs_vid2vid': '1fTj0HHjzcitgsSeG5O_aWMF8yvCQUQkN', + 'wc_vid2vid/cityscapes': '1KKzrTHfbpBY9xtLqK8e3QvX8psSdrFcD', + 'wc_vid2vid/mannequin': '1mafZf9KJrwUGGI1kBTvwgehHSqP5iaA0', + 'gancraft': '1m6q7ZtYJjxFL0SQ_WzMbvoLZxXmI5_vJ', +} + + +def parse_args(): + parser = argparse.ArgumentParser(description='Download test data.') + parser.add_argument('--model_name', required=True, + help='Name of the model.') + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + test_data_dir = 'projects/' + args.model_name + '/test_data' + print(test_data_dir) + assert args.model_name in URLS, 'No sample test data available' + url = URLS[args.model_name] + + if os.path.exists(test_data_dir): + print('Test data exists at', test_data_dir) + compressed_path = test_data_dir + '.tar.gz' + # Extract the dataset. + print('Extracting test data to', test_data_dir) + with tarfile.open(compressed_path) as tar: + tar.extractall(path=test_data_dir) + else: + os.makedirs(test_data_dir, exist_ok=True) + # Download the compressed dataset. + compressed_path = test_data_dir + '.tar.gz' + if not os.path.exists(compressed_path): + print('Downloading test data to', compressed_path) + download_file_from_google_drive(url, compressed_path) + + # Extract the dataset. + print('Extracting test data to', test_data_dir) + with tarfile.open(compressed_path) as tar: + tar.extractall(path=test_data_dir) + + +if __name__ == "__main__": + main() diff --git a/scripts/download_weights.sh b/scripts/download_weights.sh new file mode 100644 index 0000000000000000000000000000000000000000..7cac24c86c499d0aede03adf46c3d246ce9eb008 --- /dev/null +++ b/scripts/download_weights.sh @@ -0,0 +1,16 @@ +CHECKPOINTS="run-20230219_141512-2u87bj8w.zip" + +if [ ! -d "wandb" ]; then + mkdir wandb +fi + +for checkpoint in $CHECKPOINTS ; do + echo "Downloading $checkpoint"; + if [ ! -f "wandb/$checkpoint" ]; then + wget https://github.com/sat2density/checkpoints/releases/download/cvusa/$checkpoint -P wandb + fi + echo "Unzipping $checkpoint"; + if [ ! 
-d "wandb/${checkpoint%.*}" ]; then + unzip wandb/$checkpoint -d wandb + fi +done \ No newline at end of file diff --git a/scripts/install.bat b/scripts/install.bat new file mode 100644 index 0000000000000000000000000000000000000000..c5e40f6fdf6ea7ddf8beb84574b06846b7308bea --- /dev/null +++ b/scripts/install.bat @@ -0,0 +1,25 @@ +@ECHO OFF +FOR /F "tokens=*" %%g IN ('nvcc --version') do (set ver=%%g) + +echo %ver% +set CUDA_VERSION=%ver:~11,4% +echo %CUDA_VERSION% + +pip install torch==1.9.0+cu111 torchvision==0.10.0+cu111 torchaudio===0.9.0 -f https://download.pytorch.org/whl/torch_stable.html + +pip install --upgrade -r scripts/requirements.txt + +echo %cd% +set curr_directory=%cd% +echo %curr_directory% + +for %%p in (correlation channelnorm resample2d bias_act upfirdn2d) do ( + cd %curr_directory% + cd imaginaire\third_party\%%p + rmdir /s /q build dist *info + python setup.py install + cd %curr_directory% +) + + + diff --git a/scripts/install.sh b/scripts/install.sh new file mode 100644 index 0000000000000000000000000000000000000000..b6fe8e261838c64f60d6911ecaa9a76821a7448a --- /dev/null +++ b/scripts/install.sh @@ -0,0 +1,37 @@ +#!/bin/sh +CURRENT=$(pwd) + +# Check CUDA_VERSION +export CUDA_VERSION=$(nvcc --version| grep -Po "(\d+\.)+\d+" | head -1) + +apt-get update && apt-get install -y --allow-downgrades --allow-change-held-packages --no-install-recommends \ + build-essential \ + git \ + curl \ + vim \ + tmux \ + wget \ + bzip2 \ + unzip \ + g++ \ + ca-certificates \ + ffmpeg \ + libx264-dev \ + imagemagick + +pip install --upgrade -r scripts/requirements.txt +# pip install -U git+https://github.com/pyvista/pyvista.git@main + + +for p in correlation channelnorm resample2d bias_act upfirdn2d; do + cd imaginaire/third_party/${p}; + rm -rf build dist *info; + python setup.py install; + cd ${CURRENT}; +done + +# for p in gancraft/voxlib; do +# cd imaginaire/model_utils/${p}; +# make all +# cd ${CURRENT}; +# done diff --git a/scripts/start_local_docker.sh b/scripts/start_local_docker.sh new file mode 100644 index 0000000000000000000000000000000000000000..78dca4657ef90b26726aea5b9e20e79213128e2c --- /dev/null +++ b/scripts/start_local_docker.sh @@ -0,0 +1,10 @@ +docker run \ + --gpus all \ + --shm-size 32g \ + --ipc=host \ + -it \ + -v /mnt:/mnt \ + -v ~/:/home \ + nvcr.io/nvidian/lpr-imagine/imaginaire:${1}-py3 \ + /bin/bash + diff --git a/test.py b/test.py new file mode 100644 index 0000000000000000000000000000000000000000..965a1c5ff5f3abd9809f164e4afff43eac64f110 --- /dev/null +++ b/test.py @@ -0,0 +1,350 @@ +import importlib +import os +import os.path as osp +import sys +import warnings + +import torch + +import options +from utils import log + +warnings.filterwarnings("ignore") + +import matplotlib.pyplot as plt +import numpy as np +import torchvision.transforms as transforms +from matplotlib.widgets import Cursor +from PIL import Image +from scipy.interpolate import interp1d, splev, splprep +from torch.utils.data import default_convert,default_collate +import torchvision + +from model.geometry_transform import render_sat,render +import cv2 +import imageio + +def get_checkpoint(opt): + if opt.test_ckpt_path == '2u87bj8w': + opt.test_ckpt_path = osp.join('wandb/run-20230219_141512-2u87bj8w/files/checkpoint/model.pth') + elif opt.test_ckpt_path == '2cqv8uh4': + opt.test_ckpt_path = osp.join('wandb/run-20230303_142752-2cqv8uh4/files/checkpoint/model.pth') + else: + pass + + +def img_read(img,size=None,datatype='RGB'): + img = Image.open(img).convert('RGB' if datatype=='RGB' else "L") 
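    # Added note: RGB images are resized with BICUBIC below, while mask-like
    # inputs loaded with datatype='L' (e.g. the sky mask) use NEAREST,
    # presumably so that label values are not blurred by interpolation.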
+ if size: + if type(size) is int: + size = (size,size) + img = img.resize(size = size,resample=Image.BICUBIC if datatype=='RGB' else Image.NEAREST) + img = transforms.ToTensor()(img) + return img + +def select_points(sat_image): + fig = plt.figure() + fig.set_size_inches(1,1,forward=False) + ax = plt.Axes(fig, [0., 0., 1., 1.]) + ax.set_axis_off() + ax.imshow(sat_image) + + coords = [] + + def ondrag(event): + if event.button != 1: + return + x, y = int(event.xdata), int(event.ydata) + coords.append((x, y)) + ax.plot([x], [y], 'o', color='red') + fig.canvas.draw_idle() + + fig.add_axes(ax) + cursor = Cursor(ax, useblit=True, color='red', linewidth=1) + fig.canvas.mpl_connect('motion_notify_event', ondrag) + plt.show() + plt.close() + + unique_lst = list(dict.fromkeys(coords)) + pixels = [] + for x in coords: + if x in unique_lst: + if x not in pixels: + pixels.append(x) + print(pixels) + pixels = np.array(pixels) + tck, u = splprep(pixels.T, s=25, per=0) + u_new = np.linspace(u.min(), u.max(), 80) + x_new, y_new = splev(u_new, tck) + + smooth_path = np.array([x_new,y_new]).T + + angles = np.arctan2(y_new[1:]-y_new[:-1],x_new[1:]-x_new[:-1]) + + return pixels, angles, smooth_path + +def volume2pyvista(volume_data): + import pyvista as pv + grid = pv.UniformGrid() + grid.dimensions = volume_data.shape + grid.spacing = (1, 1, 1) + grid.origin = (0, 0, 0) + grid.point_data['values'] = volume_data.flatten(order='F') + return grid + + +def img_pair2vid(sat_list,save_dir,media_path= 'interpolation.mp4'): + fourcc = cv2.VideoWriter_fourcc('M', 'P', '4', 'V') + out = cv2.VideoWriter(media_path, fourcc, 12.0, (512, 128)) + for i in range(len(sat_list)): + + img1 = cv2.imread(os.path.join( save_dir , sat_list[i])) + + out.write(img1) + out.release() + +@torch.no_grad() +def test_vid(model, opt): + ckpt = torch.load(opt.test_ckpt_path, map_location='cpu') + model.netG.load_state_dict(ckpt['netG']) + model.netG.eval() + + # for idx, data in enumerate(model.val_loader): + # import pdb; pdb.set_trace() + demo_imgpath = opt.demo_img + sty_imgpath = opt.sty_img + if opt.sky_img is None: + sky_imgpath = opt.sty_img.replace('image','sky') + else: + sky_imgpath = opt.sky_img + + sat = img_read(demo_imgpath, size=opt.data.sat_size) + pano = img_read(sty_imgpath, size=opt.data.pano_size) + + input_dict = {} + input_dict['sat'] = sat + input_dict['pano'] = pano + input_dict['paths'] = demo_imgpath + + + if opt.data.sky_mask: + sky = img_read(sky_imgpath, size=opt.data.pano_size, datatype='L') + input_a = pano*sky + sky_histc = torch.cat([input_a[i].histc()[10:] for i in reversed(range(3))]) + input_dict['sky_histc'] = sky_histc + input_dict['sky_mask'] = sky + else: + sky_histc = None + + for key in input_dict.keys(): + if isinstance(input_dict[key], torch.Tensor): + input_dict[key] = input_dict[key].unsqueeze(0) + + model.set_input(input_dict) + + model.style_temp = model.sky_histc + + pixels, angles, smooth_path = select_points(sat_image=sat.permute(1,2,0).numpy()) + + rendered_image_list = [] + rendered_depth_list = [] + + + volume_data = None + + for i, (x,y) in enumerate(pixels): + opt.origin_H_W = [(y-128)/128, (x-128)/128] # TODO: hard code should be removed in the future + print('Rendering at ({}, {})'.format(x,y)) + model.forward(opt) + + rgb = model.out_put.pred[0].clamp(min=0,max=1.0).cpu().numpy().transpose((1,2,0)) + rgb = np.array(rgb*255, dtype=np.uint8) + rendered_image_list.append(rgb) + + rendered_depth_list.append( + model.out_put.depth[0,0].cpu().numpy() + ) + + + sat_opacity, sat_depth = 
render_sat(opt,model.out_put.voxel) + + volume_data = model.out_put.voxel[0].cpu().numpy().transpose((1,2,0)) + volume_data = np.clip(volume_data, None, 10) + + volume_export = volume2pyvista(volume_data) + + os.makedirs(opt.save_dir, exist_ok=True) + volume_export.save(os.path.join(opt.save_dir, 'volume.vtk')) + + # save rendered images + os.makedirs(osp.join(opt.save_dir,'rendered_images'), exist_ok=True) + + for i, img in enumerate(rendered_image_list): + plt.imsave(osp.join(opt.save_dir,'rendered_images','{:05d}.png'.format(i)), img) + + os.makedirs(osp.join(opt.save_dir,'rendered_depth'), exist_ok=True) + + os.makedirs(osp.join(opt.save_dir, + 'rendered_images+depths'), exist_ok=True) + + for i, img in enumerate(rendered_depth_list): + depth = np.array(img/img.max()*255,dtype=np.uint8) + depth = cv2.applyColorMap(depth, cv2.COLORMAP_TURBO) + plt.imsave(osp.join(opt.save_dir,'rendered_depth','{:05d}.png'.format(i)), depth) + image_and_depth = np.concatenate((rendered_image_list[i], depth), axis=0) + + plt.imsave(osp.join(opt.save_dir,'rendered_images+depths','{:05d}.png'.format(i)), image_and_depth) + + os.makedirs(osp.join(opt.save_dir,'sat_images'), exist_ok=True) + + for i, (x,y) in enumerate(pixels): + + + # plt.plot(x, y, 'o', color='red') + + sat_rgb = sat.permute(1,2,0).numpy() + sat_rgb = np.array(sat_rgb*255, dtype=np.uint8) + fig = plt.figure() + fig.set_size_inches(1,1,forward=False) + ax = plt.Axes(fig, [0., 0., 1., 1.]) + ax.set_axis_off() + ax.imshow(sat_rgb) + ax.plot(pixels[:i+1,0], pixels[:i+1,1], 'r-', color='red') + ax.plot(x, y, 'o', color='red', markersize=2) + # if i < len(pixels)-1: + # # ax.plot([x,pixels[0,0]],[y,pixels[0,1]],'r-') + # # else: + # ax.plot([x,pixels[i+1,0]],[y,pixels[i+1,1]],'r-') + fig.add_axes(ax) + plt.savefig(osp.join(opt.save_dir,'sat_images','{:05d}.png'.format(i)),bbox_inches='tight', pad_inches=0, dpi=256) + + print('Done') + + +@torch.no_grad() +def test_interpolation(model,opt): + ckpt = torch.load(opt.test_ckpt_path, map_location='cpu') + model.netG.load_state_dict(ckpt['netG']) + model.netG.eval() + + + + + sat = img_read(opt.demo_img , size=opt.data.sat_size) + pano1 = img_read(opt.sty_img1 , size=opt.data.pano_size) + pano2 = img_read(opt.sty_img2 , size=opt.data.pano_size) + + + input_dict = {} + input_dict['sat'] = sat + input_dict['paths'] = opt.demo_img + + # black_ground = torch.zeros_like(pano1) + sky_imgpath1 = opt.sty_img1.replace('image','sky') + sky_imgpath2 = opt.sty_img2.replace('image','sky') + + sky = img_read(sky_imgpath1, size=opt.data.pano_size, datatype='L') + input_a = pano1*sky + sky_histc1 = torch.cat([input_a[i].histc()[10:] for i in reversed(range(3))]) + + # for idx in range(len(input_a)): + # if idx == 0: + # sky_histc1 = input_a[idx].histc()[10:] + # else: + # sky_histc1 = torch.cat([input_a[idx].histc()[10:],sky_histc1],dim=0) + + sky = img_read(sky_imgpath2, size=opt.data.pano_size, datatype='L') + input_b = pano2*sky + sky_histc2 = torch.cat([input_b[i].histc()[10:] for i in reversed(range(3))]) + # for idx in range(len(input_b)): + # if idx == 0: + # sky_histc2 = input_b[idx].histc()[10:] + # else: + # sky_histc2 = torch.cat([input_b[idx].histc()[10:],sky_histc2],dim=0) + + for key in input_dict.keys(): + if isinstance(input_dict[key], torch.Tensor): + input_dict[key] = input_dict[key].unsqueeze(0) + + model.set_input(input_dict) + pixels = [(128,128)] + + x,y = pixels[0] + opt.origin_H_W = [(y-128)/128 , (x-128)/128] + print(opt.origin_H_W) + + estimated_height = model.netG.depth_model(model.real_A) + 
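    # Added note: the geometry is rendered once from the predicted density below,
    # then the two sky-histogram style codes are linearly blended,
    #     z_t = (1 - t) * z1 + t * z2,  t in [0, 1],
    # and each blended code is decoded into one frame, so only the illumination
    # changes across the interpolation video while the scene stays fixed.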
geo_outputs = render(opt,model.real_A,estimated_height,model.netG.pano_direction,PE=model.netG.PE) + generator_inputs,opacity,depth = geo_outputs['rgb'],geo_outputs['opacity'],geo_outputs['depth'] + if model.netG.gen_cfg.cat_opa: + generator_inputs = torch.cat((generator_inputs,opacity),dim=1) + if model.netG.gen_cfg.cat_depth: + generator_inputs = torch.cat((generator_inputs,depth),dim=1) + _, _, z1 = model.netG.style_encode(sky_histc1.unsqueeze(0).to(model.device)) + _, _, z2 = model.netG.style_encode(sky_histc2.unsqueeze(0).to(model.device)) + num_inter = 60 + for i in range(num_inter): + z = z1 * (1-i/(num_inter-1)) + z2* (i/(num_inter-1)) + z = model.netG.style_model(z) + output_RGB = model.netG.denoise_model(generator_inputs,z) + + save_img = output_RGB.cpu() + name = 'img{:03d}.png'.format(i) + torchvision.utils.save_image(save_img,os.path.join(opt.save_dir,name)) + + img_list = sorted(os.listdir(opt.save_dir)) + sat_list = [] + for img in img_list: + sat_list.append(img) + media_path = os.path.join(opt.save_dir,'interpolation.mp4') + + img_pair2vid(sat_list,opt.save_dir,media_path) + print('Done, save 2 ',media_path) + +def main(): + log.process(os.getpid()) + log.title("[{}] (PyTorch code for testing Sat2Density and debug".format(sys.argv[0])) + + opt_cmd = options.parse_arguments(sys.argv[1:]) + opt = options.set(opt_cmd=opt_cmd) + opt.isTrain = False + opt.name = opt.yaml if opt.name is None else opt.name + opt.batch_size = 1 + + if opt.save_dir is None: + raise Exception("Please specify the save dir") + + get_checkpoint(opt) + + mode = importlib.import_module("model.{}".format(opt.model)) + m = mode.Model(opt) + + # m.load_dataset(opt) + m.build_networks(opt) + + if os.path.exists(opt.save_dir): + import shutil + shutil.rmtree(opt.save_dir) + if opt.task == 'test_vid': + test_vid(m, opt) + if opt.task == 'test_interpolation': + assert opt.sty_img1 + assert opt.sty_img2 + os.makedirs(opt.save_dir, exist_ok=True) + test_interpolation(m,opt) + + # import pdb; pdb.set_trace() + + # print(m) + # # test or visualization + # if opt.task == 'test_vid': + # m.test_vid(opt) + # elif opt.task == 'test_sty': + # m.test_sty(opt) + # elif opt.task == 'test_interpolation': + # m.test_interpolation(opt) + # else: + # raise RuntimeError("Unknow task") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/train.py b/train.py new file mode 100644 index 0000000000000000000000000000000000000000..f4e9dfac88fde48243d964d294b548989c999836 --- /dev/null +++ b/train.py @@ -0,0 +1,30 @@ +import os,sys +import importlib +import options +import warnings +import wandb +warnings.filterwarnings("ignore") +os.environ['WANDB_IGNORE_GLOBS'] = '*.pth' # not save checkpoint in cloud + +def main(): + opt_cmd = options.parse_arguments(sys.argv[1:]) + opt = options.set(opt_cmd=opt_cmd) + assert opt.task in ["train","Train"] + opt.isTrain = True + opt.name = opt.yaml if opt.name is None else opt.name + wandb_log = wandb.init( + project=opt.project, + name=opt.name, + group=opt.Group, + config=opt, + ) + mode = importlib.import_module("model.{}".format(opt.model)) + m = mode.Model(opt,wandb_log) + + m.load_dataset(opt) + m.build_networks(opt) + m.setup_optimizer(opt) + m.train(opt) + +if __name__=="__main__": + main() diff --git a/utils.py b/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..4967e14d5cd14425c34a32392f4c939e90db3fea --- /dev/null +++ b/utils.py @@ -0,0 +1,166 @@ +import termcolor,os,shutil,torch +from easydict import EasyDict as edict +from 
collections import OrderedDict +import math +import numpy as np +from torch.nn import init + +def get_time(sec): + """ + Convert seconds to days, hours, minutes, and seconds + """ + d = int(sec//(24*60*60)) + h = int(sec//(60*60)%24) + m = int((sec//60)%60) + s = int(sec%60) + return d,h,m,s + +# convert to colored strings +def red(message,**kwargs): return termcolor.colored(str(message),color="red",attrs=[k for k,v in kwargs.items() if v is True]) +def green(message,**kwargs): return termcolor.colored(str(message),color="green",attrs=[k for k,v in kwargs.items() if v is True]) +def blue(message,**kwargs): return termcolor.colored(str(message),color="blue",attrs=[k for k,v in kwargs.items() if v is True]) +def cyan(message,**kwargs): return termcolor.colored(str(message),color="cyan",attrs=[k for k,v in kwargs.items() if v is True]) +def yellow(message,**kwargs): return termcolor.colored(str(message),color="yellow",attrs=[k for k,v in kwargs.items() if v is True]) +def magenta(message,**kwargs): return termcolor.colored(str(message),color="magenta",attrs=[k for k,v in kwargs.items() if v is True]) +def grey(message,**kwargs): return termcolor.colored(str(message),color="grey",attrs=[k for k,v in kwargs.items() if v is True]) + + + +def openreadtxt(file_name): + + file = open(file_name,'r') + file_data = file.read().splitlines() + return file_data + +def to_dict(D,dict_type=dict): + D = dict_type(D) + for k,v in D.items(): + if isinstance(v,dict): + D[k] = to_dict(v,dict_type) + return D + +class Log: + def __init__(self): pass + def process(self,pid): + print(grey("Process ID: {}".format(pid),bold=True)) + def title(self,message): + print(yellow(message,bold=True,underline=True)) + def info(self,message): + print(magenta(message,bold=True)) + def options(self,opt,level=0): + for key,value in sorted(opt.items()): + if isinstance(value,(dict,edict)): + print(" "*level+cyan("* ")+green(key)+":") + self.options(value,level+1) + else: + print(" "*level+cyan("* ")+green(key)+":",yellow(value)) + def loss_train(self,opt,ep,lr,loss,timer): + if not opt.max_epoch: return + message = grey("[train] ",bold=True) + message += "epoch {}/{}".format(cyan(ep,bold=True),opt.max_epoch) + message += ", lr:{}".format(yellow("{:.2e}".format(lr),bold=True)) + message += ", loss:{}".format(red("{:.3e}".format(loss),bold=True)) + message += ", time:{}".format(blue("{0}-{1:02d}:{2:02d}:{3:02d}".format(*get_time(timer.elapsed)),bold=True)) + message += " (ETA:{})".format(blue("{0}-{1:02d}:{2:02d}:{3:02d}".format(*get_time(timer.arrival)))) + print(message) + def loss_val(self,opt,loss): + message = grey("[val] ",bold=True) + message += "loss:{}".format(red("{:.3e}".format(loss),bold=True)) + print(message) +log = Log() + +def save_checkpoint(model,ep,latest=False,children=None,output_path=None): + + os.makedirs("{0}/model".format(output_path),exist_ok=True) + checkpoint = dict( + epoch=ep, + netG=model.netG.state_dict(), + netD=model.netD.state_dict() + ) + + torch.save(checkpoint,"{0}/model.pth".format(output_path)) + if not latest: + shutil.copy("{0}/model.pth".format(output_path), + "{0}/model/{1}.pth".format(output_path,ep)) # if ep is None, track it instead + +def filt_ckpt_keys(ckpt, item_name, model_name): + # if item_name in ckpt: + assert item_name in ckpt, "Cannot find [%s] in the checkpoints." 
% item_name + d = ckpt[item_name] + d_filt = OrderedDict() + for k, v in d.items(): + k_list = k.split('.') + if k_list[0] == model_name: + if k_list[1] == 'module': + d_filt['.'.join(k_list[2:])] = v + else: + d_filt['.'.join(k_list[1:])] = v + return d_filt + +def requires_grad(model, flag=True): + for p in model.parameters(): + p.requires_grad = flag + +def filt_ckpt_keys(ckpt, item_name, model_name): + # if item_name in ckpt: + assert item_name in ckpt, "Cannot find [%s] in the checkpoints." % item_name + d = ckpt[item_name] + d_filt = OrderedDict() + for k, v in d.items(): + k_list = k.split('.') + if k_list[0] == model_name: + if k_list[1] == 'module': + d_filt['.'.join(k_list[2:])] = v + else: + d_filt['.'.join(k_list[1:])] = v + return d_filt + +def get_ray_pano(batch_img): + _,_,H,W = batch_img.size() + _y = np.repeat(np.array(range(W)).reshape(1,W), H, axis=0) + _x = np.repeat(np.array(range(H)).reshape(1,H), W, axis=0).T + + _theta = (1 - 2 * (_x) / H) * np.pi/2 # latitude + _phi = 2*math.pi*(0.5 - (_y)/W ) # longtitude + axis0 = (np.cos(_theta)*np.cos(_phi)).reshape(1,H, W) + axis1 = np.sin(_theta).reshape(1,H, W) + axis2 = (-np.cos(_theta)*np.sin(_phi)).reshape(1, H, W) + original_coord = np.concatenate((axis0, axis1, axis2), axis=0) + + return original_coord + +def init_weights(net, init_type='kaiming', init_gain=0.02): + """Initialize network weights. + Parameters: + net (network) -- network to be initialized + init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal + init_gain (float) -- scaling factor for normal, xavier and orthogonal. + We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might + work better for some applications. Feel free to try yourself. + """ + def init_func(m): # define the initialization function + classname = m.__class__.__name__ + if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): + if init_type == 'normal': + init.normal_(m.weight.data, 0.0, init_gain) + elif init_type == 'xavier': + init.xavier_normal_(m.weight.data, gain=init_gain) + elif init_type == 'kaiming': + init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') + elif init_type == 'orthogonal': + init.orthogonal_(m.weight.data, gain=init_gain) + else: + raise NotImplementedError('initialization method [%s] is not implemented' % init_type) + if hasattr(m, 'bias') and m.bias is not None: + init.constant_(m.bias.data, 0.0) + elif classname.find('BatchNorm2d') != -1: # BatchNorm Layer's weight is not a matrix; only normal distribution applies. + init.normal_(m.weight.data, 1.0, init_gain) + init.constant_(m.bias.data, 0.0) + + print('initialize network with %s' % init_type) + net.apply(init_func) + + +if __name__=='__main__': + a = torch.zeros([2,3,200,100]) + cood = get_ray_pano(a)
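As a reference for the sky-histogram style code that `app.py` and `test.py` pass to the style encoder, the vector can be reproduced in isolation as follows. This is a hedged sketch that reuses the demo image paths shipped in `demo_img/`; it is not a new API.

```python
# Standalone sketch of the sky-histogram style vector used in app.py / test.py:
# mask the panorama with its sky segmentation, take a 100-bin histogram per
# channel (torch.histc default), drop the first 10 bins (presumably to discard
# the masked-out zeros), and concatenate the channels in reversed order.
import torch
import torchvision.transforms as transforms
from PIL import Image

to_tensor = transforms.ToTensor()
pano = to_tensor(Image.open("demo_img/case1/groundview.image.png").convert("RGB"))  # [3, H, W]
sky  = to_tensor(Image.open("demo_img/case1/groundview.sky.png").convert("L"))      # [1, H, W]

masked = pano * sky                                    # keep only sky pixels
sky_histc = torch.cat([masked[c].histc()[10:]          # 90 bins per channel
                       for c in reversed(range(3))])   # channel order reversed, as in the original code
print(sky_histc.shape)                                 # torch.Size([270])
```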