import os
import sys
import logging
import argparse
from datetime import datetime
from shutil import copyfile

import numpy as np
import torch
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from rich import print
from tqdm import tqdm
from pyhocon import ConfigFactory

# Make the project-local `models` / `data` packages importable regardless of the
# working directory; this must run before the project imports below.
sys.path.append(os.path.dirname(__file__))

from models.fields import SingleVarianceNetwork
from models.featurenet import FeatureNet
from models.trainer_generic import GenericTrainer
from models.sparse_sdf_network import SparseSdfNetwork
from models.rendering_network import GeneralRenderingNetwork
from data.blender_general_narrow_all_eval_new_data import BlenderPerView


def _mean_or_zero(losses, key):
    """Return ``losses[key].mean()``, or 0 when the loss dict or that entry is None."""
    if losses is None or losses[key] is None:
        return 0
    return losses[key].mean()


class Runner:
    """Build the networks from a HOCON config, then train, validate or export meshes.

    The networks are handed to a ``GenericTrainer``, which is wrapped in
    ``torch.nn.DataParallel``. Checkpoints live in ``<base_exp_dir>/checkpoints``.
    """

    def __init__(self, conf_path, mode='train', is_continue=False, is_restore=False, restore_lod0=False,
                 local_rank=0, specific_dataset_name='GSO'):
        """
        Args:
            conf_path: path of the HOCON experiment config.
            mode: 'train', 'val' or 'export_mesh' (a 'train' prefix also triggers a source backup).
            is_continue: resume from the newest checkpoint of ``general.base_exp_dir``.
                'export_mesh' always resumes.
            is_restore: stored only; not read anywhere in this file.
            restore_lod0: initialise the lod1 networks from the checkpoint's lod0 weights
                and skip restoring the optimizer / step counters.
            local_rank: CUDA device index used for all networks.
            specific_dataset_name: dataset name forwarded to ``BlenderPerView``. In
                'val'/'export_mesh' mode it also names the output folder ``../<name>``.
        """
        # Initial setting
        self.device = torch.device('cuda:%d' % local_rank)
        self.num_devices = torch.cuda.device_count()
        self.is_continue = is_continue or (mode == "export_mesh")
        self.is_restore = is_restore
        self.restore_lod0 = restore_lod0
        self.mode = mode
        self.specific_dataset_name = specific_dataset_name
        self.model_list = []
        self.logger = logging.getLogger('exp_logger')

        print("detected %d GPUs" % self.num_devices)

        self.conf_path = conf_path
        self.conf = ConfigFactory.parse_file(conf_path)
        self.timestamp = None
        if not self.is_continue:
            # Fresh training runs get a unique, timestamped experiment directory.
            self.timestamp = '_{:%Y_%m_%d_%H_%M_%S}'.format(datetime.now())
            self.base_exp_dir = self.conf['general.base_exp_dir'] + self.timestamp
        else:
            # Resumed runs reuse the configured directory so its checkpoints can be found.
            self.base_exp_dir = self.conf['general.base_exp_dir']
        self.conf['general.base_exp_dir'] = self.base_exp_dir
        print("base_exp_dir: " + self.base_exp_dir)
        os.makedirs(self.base_exp_dir, exist_ok=True)
        self.iter_step = 0
        self.val_step = 0

        # training parameters
        self.end_iter = self.conf.get_int('train.end_iter')
        self.save_freq = self.conf.get_int('train.save_freq')
        self.report_freq = self.conf.get_int('train.report_freq')
        self.val_freq = self.conf.get_int('train.val_freq')
        self.val_mesh_freq = self.conf.get_int('train.val_mesh_freq')
        self.batch_size = self.num_devices  # one sample per GPU; the trainer is wrapped in DataParallel
        self.validate_resolution_level = self.conf.get_int('train.validate_resolution_level')
        self.learning_rate = self.conf.get_float('train.learning_rate')
        self.learning_rate_milestone = self.conf.get_list('train.learning_rate_milestone')
        self.learning_rate_factor = self.conf.get_float('train.learning_rate_factor')
        self.use_white_bkgd = self.conf.get_bool('train.use_white_bkgd')
        self.N_rays = self.conf.get_int('train.N_rays')

        # warmup params for sdf gradient
        self.anneal_start_lod0 = self.conf.get_float('train.anneal_start', default=0)
        self.anneal_end_lod0 = self.conf.get_float('train.anneal_end', default=0)
        self.anneal_start_lod1 = self.conf.get_float('train.anneal_start_lod1', default=0)
        self.anneal_end_lod1 = self.conf.get_float('train.anneal_end_lod1', default=0)

        self.writer = None

        # Networks
        self.num_lods = self.conf.get_int('model.num_lods')

        self.rendering_network_outside = None
        self.sdf_network_lod0 = None
        self.sdf_network_lod1 = None
        self.variance_network_lod0 = None
        self.variance_network_lod1 = None
        self.rendering_network_lod0 = None
        self.rendering_network_lod1 = None
        self.pyramid_feature_network = None  # extract 2d pyramid feature maps from images, used for geometry
        self.pyramid_feature_network_lod1 = None  # may use different feature network for different lod

        # * pyramid_feature_network
        self.pyramid_feature_network = FeatureNet().to(self.device)
        self.sdf_network_lod0 = SparseSdfNetwork(**self.conf['model.sdf_network_lod0']).to(self.device)
        self.variance_network_lod0 = SingleVarianceNetwork(**self.conf['model.variance_network']).to(self.device)

        if self.num_lods > 1:
            self.sdf_network_lod1 = SparseSdfNetwork(**self.conf['model.sdf_network_lod1']).to(self.device)
            self.variance_network_lod1 = SingleVarianceNetwork(**self.conf['model.variance_network']).to(self.device)

        self.rendering_network_lod0 = GeneralRenderingNetwork(**self.conf['model.rendering_network']).to(
            self.device)

        if self.num_lods > 1:
            self.pyramid_feature_network_lod1 = FeatureNet().to(self.device)
            self.rendering_network_lod1 = GeneralRenderingNetwork(
                **self.conf['model.rendering_network_lod1']).to(self.device)

        if self.mode in ('export_mesh', 'val'):
            # Evaluation outputs go to a per-dataset folder one level up, not into the experiment dir.
            base_exp_dir_to_store = os.path.join("../", self.specific_dataset_name)
        else:
            base_exp_dir_to_store = self.base_exp_dir

        print(f"Store in: {base_exp_dir_to_store}")
        # Renderer model
        self.trainer = GenericTrainer(
            self.rendering_network_outside,
            self.pyramid_feature_network,
            self.pyramid_feature_network_lod1,
            self.sdf_network_lod0,
            self.sdf_network_lod1,
            self.variance_network_lod0,
            self.variance_network_lod1,
            self.rendering_network_lod0,
            self.rendering_network_lod1,
            **self.conf['model.trainer'],
            timestamp=self.timestamp,
            base_exp_dir=base_exp_dir_to_store,
            conf=self.conf)

        self.data_setup()  # * data setup

        self.optimizer_setup()

        # Load checkpoint
        latest_model_name = self._find_latest_checkpoint() if self.is_continue else None
        if latest_model_name is not None:
            self.logger.info('Find checkpoint: {}'.format(latest_model_name))
            self.load_checkpoint(latest_model_name)

        self.trainer = torch.nn.DataParallel(self.trainer).to(self.device)

        if self.mode[:5] == 'train':
            self.file_backup()

    def _find_latest_checkpoint(self):
        """Return the name of the newest ``ckpt*pth`` file in ``<base_exp_dir>/checkpoints``.

        ``save_checkpoint`` zero-pads the iteration to six digits, so a lexical sort
        orders the files by iteration (for iterations below 1,000,000).

        Raises:
            FileNotFoundError: if the directory is missing or contains no checkpoint.
        """
        ckpt_dir = os.path.join(self.base_exp_dir, 'checkpoints')
        model_list = sorted(name for name in os.listdir(ckpt_dir)
                            if name.startswith('ckpt') and name.endswith('pth'))
        if not model_list:
            raise FileNotFoundError('no checkpoint (ckpt*.pth) found in {}'.format(ckpt_dir))
        return model_list[-1]

    def optimizer_setup(self):
        """Create an Adam optimizer over the trainer's trainable parameters."""
        self.params_to_train = self.trainer.get_trainable_params()
        self.optimizer = torch.optim.Adam(self.params_to_train, lr=self.learning_rate)

    def data_setup(self):
        """Create the train/val ``BlenderPerView`` datasets and their DataLoaders, and prime the val iterator."""
        self.train_dataset = BlenderPerView(
            root_dir=self.conf['dataset.trainpath'],
            split=self.conf.get_string('dataset.train_split', default='train'),
            split_filepath=self.conf.get_string('dataset.train_split_filepath', default=None),
            n_views=self.conf['dataset.nviews'],
            downSample=self.conf['dataset.imgScale_train'],
            N_rays=self.N_rays,
            batch_size=self.batch_size,
            clean_image=True,  # True for training
            importance_sample=self.conf.get_bool('dataset.importance_sample', default=False),
            specific_dataset_name=self.specific_dataset_name,
        )

        self.val_dataset = BlenderPerView(
            root_dir=self.conf['dataset.valpath'],
            split=self.conf.get_string('dataset.test_split', default='test'),
            split_filepath=self.conf.get_string('dataset.val_split_filepath', default=None),
            n_views=3,
            downSample=self.conf['dataset.imgScale_test'],
            N_rays=self.N_rays,
            batch_size=self.batch_size,
            clean_image=self.conf.get_bool('dataset.mask_out_image', default=False) if self.mode != 'train' else False,
            importance_sample=self.conf.get_bool('dataset.importance_sample', default=False),
            test_ref_views=self.conf.get_list('dataset.test_ref_views', default=[]),
            specific_dataset_name=self.specific_dataset_name,
        )

        self.train_dataloader = DataLoader(self.train_dataset,
                                           shuffle=True,
                                           num_workers=4 * self.batch_size,
                                           batch_size=self.batch_size,
                                           pin_memory=True,
                                           drop_last=True,
                                           )

        self.val_dataloader = DataLoader(self.val_dataset,
                                         shuffle=False,
                                         num_workers=4 * self.batch_size,
                                         batch_size=self.batch_size,
                                         pin_memory=True,
                                         drop_last=False,
                                         )

        self.val_dataloader_iterator = iter(self.val_dataloader)  # - should be after "reconstruct_metas_for_gru_fusion"

    def _background_rgb(self):
        """Return the background value passed to the trainer: 1.0 for white backgrounds, else None."""
        return 1.0 if self.use_white_bkgd else None

    def _alpha_inter_ratios(self):
        """Return the warm-up ratios ``(lod0, lod1)`` for the current ``iter_step``.

        lod0 is annealed only when there is a single level of detail; otherwise it is 1.
        """
        if self.num_lods == 1:
            alpha_inter_ratio_lod0 = self.get_alpha_inter_ratio(self.anneal_start_lod0, self.anneal_end_lod0)
        else:
            alpha_inter_ratio_lod0 = 1.
        alpha_inter_ratio_lod1 = self.get_alpha_inter_ratio(self.anneal_start_lod1, self.anneal_end_lod1)
        return alpha_inter_ratio_lod0, alpha_inter_ratio_lod1

    def train(self):
        """Run the optimisation loop until ``end_iter``, logging, checkpointing and validating periodically."""
        self.writer = SummaryWriter(log_dir=os.path.join(self.base_exp_dir, 'logs'))
        res_step = self.end_iter - self.iter_step
        epochs = int(1 + res_step // len(self.train_dataloader))

        self.adjust_learning_rate()
        print("starting training learning rate: {:.5f}".format(self.optimizer.param_groups[0]['lr']))

        background_rgb = self._background_rgb()

        for epoch_i in range(epochs):
            # Stop exactly at end_iter: past it the cosine schedule would raise the lr again.
            if self.iter_step >= self.end_iter:
                break

            print("current epoch %d" % epoch_i)
            # Wrap the raw loader each epoch; re-wrapping a tqdm object would nest progress bars.
            for batch in tqdm(self.train_dataloader):
                batch['batch_idx'] = torch.arange(self.batch_size)  # used to get meta

                # - warmup params
                alpha_inter_ratio_lod0, alpha_inter_ratio_lod1 = self._alpha_inter_ratios()

                losses = self.trainer(
                    batch,
                    background_rgb=background_rgb,
                    alpha_inter_ratio_lod0=alpha_inter_ratio_lod0,
                    alpha_inter_ratio_lod1=alpha_inter_ratio_lod1,
                    iter_step=self.iter_step,
                    mode='train',
                )

                losses_lod0 = losses['losses_lod0']
                losses_lod1 = losses['losses_lod1']

                # Sum the per-lod losses that the trainer actually produced.
                loss = 0
                for loss_type in ('loss_lod0', 'loss_lod1'):
                    if losses[loss_type] is not None:
                        loss = loss + losses[loss_type].mean()

                self.optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.params_to_train, 1.0)
                self.optimizer.step()

                self.iter_step += 1

                if self.iter_step % self.report_freq == 0:
                    self._report_training(loss, losses_lod0, losses_lod1,
                                          alpha_inter_ratio_lod0, alpha_inter_ratio_lod1)

                if self.iter_step % self.save_freq == 0:
                    self.save_checkpoint()

                if self.iter_step % self.val_freq == 0:
                    self.validate()

                # - adjust learning rate
                self.adjust_learning_rate()

                if self.iter_step >= self.end_iter:
                    break

        self.writer.flush()

    def _report_training(self, loss, losses_lod0, losses_lod1, alpha_inter_ratio_lod0, alpha_inter_ratio_lod1):
        """Write the training scalars to TensorBoard and print a console summary for ``iter_step``."""
        step = self.iter_step
        self.writer.add_scalar('Loss/loss', loss, step)

        if losses_lod0 is not None:
            self.writer.add_scalar('Loss/d_loss_lod0', _mean_or_zero(losses_lod0, 'depth_loss'), step)
            self.writer.add_scalar('Loss/sparse_loss_lod0', _mean_or_zero(losses_lod0, 'sparse_loss'), step)
            self.writer.add_scalar('Loss/color_loss_lod0', _mean_or_zero(losses_lod0, 'color_fine_loss'), step)
            self.writer.add_scalar('statis/psnr_lod0', _mean_or_zero(losses_lod0, 'psnr'), step)
            self.writer.add_scalar('param/variance_lod0',
                                   1. / torch.exp(self.variance_network_lod0.variance * 10), step)
            self.writer.add_scalar('param/eikonal_loss', _mean_or_zero(losses_lod0, 'gradient_error_loss'), step)

        ######## - lod 1
        if self.num_lods > 1:
            self.writer.add_scalar('Loss/d_loss_lod1', _mean_or_zero(losses_lod1, 'depth_loss'), step)
            self.writer.add_scalar('Loss/sparse_loss_lod1', _mean_or_zero(losses_lod1, 'sparse_loss'), step)
            self.writer.add_scalar('Loss/color_loss_lod1', _mean_or_zero(losses_lod1, 'color_fine_loss'), step)
            self.writer.add_scalar('statis/sdf_mean_lod1', _mean_or_zero(losses_lod1, 'sdf_mean'), step)
            self.writer.add_scalar('statis/psnr_lod1', _mean_or_zero(losses_lod1, 'psnr'), step)
            self.writer.add_scalar('statis/sparseness_0.01_lod1', _mean_or_zero(losses_lod1, 'sparseness_1'), step)
            self.writer.add_scalar('statis/sparseness_0.02_lod1', _mean_or_zero(losses_lod1, 'sparseness_2'), step)
            self.writer.add_scalar('param/variance_lod1',
                                   1. / torch.exp(self.variance_network_lod1.variance * 10), step)

        print(self.base_exp_dir)
        print(
            'iter:{:8>d} '
            'loss = {:.4f} '
            'd_loss_lod0 = {:.4f} '
            'color_loss_lod0 = {:.4f} '
            'sparse_loss_lod0= {:.4f} '
            'd_loss_lod1 = {:.4f} '
            'color_loss_lod1 = {:.4f} '
            ' lr = {:.5f}'.format(
                step, loss,
                _mean_or_zero(losses_lod0, 'depth_loss'),
                _mean_or_zero(losses_lod0, 'color_fine_loss'),
                _mean_or_zero(losses_lod0, 'sparse_loss'),
                _mean_or_zero(losses_lod1, 'depth_loss'),
                _mean_or_zero(losses_lod1, 'color_fine_loss'),
                self.optimizer.param_groups[0]['lr']))

        print('alpha_inter_ratio_lod0 = {:.4f} alpha_inter_ratio_lod1 = {:.4f}\n'.format(
            alpha_inter_ratio_lod0, alpha_inter_ratio_lod1))

        if losses_lod0 is not None:
            print(
                'iter:{:8>d} '
                'variance = {:.5f} '
                'weights_sum = {:.4f} '
                'weights_sum_fg = {:.4f} '
                'alpha_sum = {:.4f} '
                'sparse_weight= {:.4f} '
                'background_loss = {:.4f} '
                'background_weight = {:.4f} '
                .format(
                    step,
                    losses_lod0['variance'].mean(),
                    losses_lod0['weights_sum'].mean(),
                    losses_lod0['weights_sum_fg'].mean(),
                    losses_lod0['alpha_sum'].mean(),
                    losses_lod0['sparse_weight'].mean(),
                    losses_lod0['fg_bg_loss'].mean(),
                    losses_lod0['fg_bg_weight'].mean(),
                ))

        if losses_lod1 is not None:
            print(
                'iter:{:8>d} '
                'variance = {:.5f} '
                ' weights_sum = {:.4f} '
                'alpha_sum = {:.4f} '
                'fg_bg_loss = {:.4f} '
                'fg_bg_weight = {:.4f} '
                'sparse_weight= {:.4f} '
                'fg_bg_loss = {:.4f} '
                'fg_bg_weight = {:.4f} '
                .format(
                    step,
                    losses_lod1['variance'].mean(),
                    losses_lod1['weights_sum'].mean(),
                    losses_lod1['alpha_sum'].mean(),
                    losses_lod1['fg_bg_loss'].mean(),
                    losses_lod1['fg_bg_weight'].mean(),
                    losses_lod1['sparse_weight'].mean(),
                    losses_lod1['fg_bg_loss'].mean(),
                    losses_lod1['fg_bg_weight'].mean(),
                ))

    def adjust_learning_rate(self):
        """Cosine schedule: lr goes from ``learning_rate`` at step 0 to ``0.1 * learning_rate`` at ``end_iter``."""
        learning_rate = (np.cos(np.pi * self.iter_step / self.end_iter) + 1.0) * 0.5 * 0.9 + 0.1
        learning_rate = self.learning_rate * learning_rate
        for g in self.optimizer.param_groups:
            g['lr'] = learning_rate

    def get_alpha_inter_ratio(self, start, end):
        """Return a warm-up ratio rising linearly from 0 at ``start`` to 1 at ``end`` (1 if ``end`` is 0)."""
        if end == 0.0:
            return 1.0
        elif self.iter_step < start:
            return 0.0
        elif end <= start:
            # Degenerate window: the ramp is already complete (avoids division by zero / negative ratios).
            return 1.0
        else:
            return np.min([1.0, (self.iter_step - start) / (end - start)])

    def file_backup(self):
        """Copy the ``.py`` files of each ``general.recording`` dir, plus the config, into ``<base_exp_dir>/recording``."""
        dir_lis = self.conf['general.recording']
        os.makedirs(os.path.join(self.base_exp_dir, 'recording'), exist_ok=True)
        for dir_name in dir_lis:
            cur_dir = os.path.join(self.base_exp_dir, 'recording', dir_name)
            os.makedirs(cur_dir, exist_ok=True)
            for f_name in os.listdir(dir_name):
                if f_name.endswith('.py'):
                    copyfile(os.path.join(dir_name, f_name), os.path.join(cur_dir, f_name))

        # copy configs
        copyfile(self.conf_path, os.path.join(self.base_exp_dir, 'recording', 'config.conf'))

    def _checkpoint_networks(self):
        """Return the ``(checkpoint key, module)`` pairs that are saved and restored; modules may be None."""
        return [
            ('rendering_network_outside', self.rendering_network_outside),
            ('sdf_network_lod0', self.sdf_network_lod0),
            ('sdf_network_lod1', self.sdf_network_lod1),
            ('pyramid_feature_network', self.pyramid_feature_network),
            ('pyramid_feature_network_lod1', self.pyramid_feature_network_lod1),
            ('variance_network_lod0', self.variance_network_lod0),
            ('variance_network_lod1', self.variance_network_lod1),
            ('rendering_network_lod0', self.rendering_network_lod0),
            ('rendering_network_lod1', self.rendering_network_lod1),
        ]

    def load_checkpoint(self, checkpoint_name):
        """Restore network weights (and, when resuming, optimizer and step counters) from a checkpoint file."""

        def load_state_dict(network, checkpoint, comment):
            """Best-effort load of ``checkpoint[comment]`` into ``network``.

            Only entries whose key also exists in the network are copied; the rest keep
            their current values. Failures are reported rather than raised, so one
            missing or mismatched sub-network does not abort the whole restore.
            """
            if network is None:
                return
            try:
                pretrained_dict = checkpoint[comment]
                model_dict = network.state_dict()
                # 1. filter out keys the network does not have
                pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
                # 2. overwrite matching entries in the existing state dict
                model_dict.update(pretrained_dict)
                # 3. load the merged dict: it has exactly the network's keys, so strict loading accepts it
                network.load_state_dict(model_dict)
            except (KeyError, RuntimeError) as err:
                print(comment + " load fails")
                self.logger.warning('%s load fails: %s', comment, err)

        checkpoint = torch.load(os.path.join(self.base_exp_dir, 'checkpoints', checkpoint_name),
                                map_location=self.device)

        for key, network in self._checkpoint_networks():
            load_state_dict(network, checkpoint, key)

        if self.restore_lod0:  # use the trained lod0 networks to initialize lod1 networks
            load_state_dict(self.sdf_network_lod1, checkpoint, 'sdf_network_lod0')
            load_state_dict(self.pyramid_feature_network_lod1, checkpoint, 'pyramid_feature_network')
            load_state_dict(self.rendering_network_lod1, checkpoint, 'rendering_network_lod0')

        if self.is_continue and not self.restore_lod0:
            try:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            except (KeyError, ValueError, RuntimeError) as err:
                print("load optimizer fails")
                self.logger.warning('load optimizer fails: %s', err)
            self.iter_step = checkpoint['iter_step']
            self.val_step = checkpoint.get('val_step', 0)

        self.logger.info('End')

    def save_checkpoint(self):
        """Write optimizer state, step counters and all non-None networks to ``checkpoints/ckpt_<iter>.pth``."""
        checkpoint = {
            'optimizer': self.optimizer.state_dict(),
            'iter_step': self.iter_step,
            'val_step': self.val_step,
        }
        for key, network in self._checkpoint_networks():
            if network is not None:
                checkpoint[key] = network.state_dict()

        os.makedirs(os.path.join(self.base_exp_dir, 'checkpoints'), exist_ok=True)
        torch.save(checkpoint,
                   os.path.join(self.base_exp_dir, 'checkpoints', 'ckpt_{:0>6d}.pth'.format(self.iter_step)))

    def _next_val_batch(self):
        """Return the next validation batch, restarting the val loader once it is exhausted."""
        try:
            return next(self.val_dataloader_iterator)
        except StopIteration:
            self.val_dataloader_iterator = iter(self.val_dataloader)  # reset
            return next(self.val_dataloader_iterator)

    def _run_eval(self, mode):
        """Run the trainer on one validation batch with ``save_vis=True`` in the given trainer ``mode``."""
        print("iter_step: ", self.iter_step)
        self.logger.info('Validate begin')
        self.val_step += 1

        batch = self._next_val_batch()
        background_rgb = self._background_rgb()
        batch['batch_idx'] = torch.arange(self.batch_size)

        # - warmup params
        alpha_inter_ratio_lod0, alpha_inter_ratio_lod1 = self._alpha_inter_ratios()

        self.trainer(
            batch,
            background_rgb=background_rgb,
            alpha_inter_ratio_lod0=alpha_inter_ratio_lod0,
            alpha_inter_ratio_lod1=alpha_inter_ratio_lod1,
            iter_step=self.iter_step,
            save_vis=True,
            mode=mode,
        )

    def validate(self, resolution_level=-1):
        """Validate on the next val batch (``resolution_level`` is accepted but unused)."""
        self._run_eval('val')

    def export_mesh(self, resolution_level=-1):
        """Export a mesh for the next val batch (``resolution_level`` is accepted but unused)."""
        self._run_eval('export_mesh')


if __name__ == '__main__':
    torch.set_default_dtype(torch.float32)
    FORMAT = "[%(filename)s:%(lineno)s - %(funcName)20s() ] %(message)s"
    logging.basicConfig(level=logging.INFO, format=FORMAT)

    parser = argparse.ArgumentParser()
    parser.add_argument('--conf', type=str, default='./confs/base.conf')
    parser.add_argument('--mode', type=str, default='train')
    parser.add_argument('--threshold', type=float, default=0.0)
    parser.add_argument('--is_continue', default=False, action="store_true")
    parser.add_argument('--is_restore', default=False, action="store_true")
    parser.add_argument('--is_finetune', default=False, action="store_true")
    parser.add_argument('--train_from_scratch', default=False, action="store_true")
    parser.add_argument('--restore_lod0', default=False, action="store_true")
    parser.add_argument('--local_rank', type=int, default=0)
    parser.add_argument('--specific_dataset_name', type=str, default='GSO')

    args = parser.parse_args()

    torch.cuda.set_device(args.local_rank)
    torch.backends.cudnn.benchmark = True  # ! make training 2x faster

    runner = Runner(args.conf, args.mode, args.is_continue, args.is_restore, args.restore_lod0,
                    args.local_rank, specific_dataset_name=args.specific_dataset_name)

    if args.mode == 'train':
        runner.train()
    elif args.mode == 'val':
        for _ in range(len(runner.val_dataset)):
            runner.validate()
    elif args.mode == 'export_mesh':
        for _ in range(len(runner.val_dataset)):
            runner.export_mesh()