|
import os |
|
import logging |
|
import argparse |
|
import numpy as np |
|
from shutil import copyfile |
|
import torch |
|
from torch.utils.data import DataLoader |
|
from torch.utils.tensorboard import SummaryWriter |
|
from rich import print |
|
from tqdm import tqdm |
|
from pyhocon import ConfigFactory |
|
|
|
import sys |
|
sys.path.append(os.path.dirname(__file__)) |
|
|
|
from models.fields import SingleVarianceNetwork |
|
from models.featurenet import FeatureNet |
|
from models.trainer_generic import GenericTrainer |
|
from models.sparse_sdf_network import SparseSdfNetwork |
|
from models.rendering_network import GeneralRenderingNetwork |
|
from data.blender_general_narrow_all_eval_new_data import BlenderPerView |
|
|
|
|
|
from datetime import datetime |
|
|
|
class Runner:
    """Experiment driver for the sparse-SDF reconstruction pipeline.

    Builds the feature, SDF, variance and rendering networks from a HOCON
    config and wraps them in a ``GenericTrainer``, which is then wrapped in
    ``torch.nn.DataParallel``. It also sets up the ``BlenderPerView``
    train/val datasets and an Adam optimizer. The public entry points are
    ``train``, ``validate`` and ``export_mesh``. Networks for the second
    level of detail ("lod1") are only created when ``model.num_lods > 1``.
    """

    def __init__(self, conf_path, mode='train', is_continue=False,
                 is_restore=False, restore_lod0=False, local_rank=0,
                 specific_dataset_name=None):
        """Build models, data and optimizer, and optionally resume.

        Args:
            conf_path: Path to the HOCON experiment configuration.
            mode: Run mode. ``'export_mesh'`` forces resuming from the latest
                checkpoint. Modes starting with ``'train'`` back up source
                files into the experiment directory.
            is_continue: Resume from the newest checkpoint in
                ``general.base_exp_dir`` instead of creating a new,
                timestamped experiment directory.
            is_restore: Stored on the instance; not otherwise used here.
            restore_lod0: When loading a checkpoint, also initialise the lod1
                networks from the lod0 weights, and skip restoring the
                optimizer and step counters.
            local_rank: CUDA device index the networks are placed on.
            specific_dataset_name: Passed to ``BlenderPerView``. In ``'val'``
                and ``'export_mesh'`` modes it also names the output
                directory ``../<specific_dataset_name>``. When ``None``, it
                falls back to the module-level CLI
                ``args.specific_dataset_name`` if the script was launched
                from the command line, and otherwise to ``'GSO'`` (the CLI
                default).

        Raises:
            FileNotFoundError: When resuming and the checkpoint directory or
                any ``ckpt*.pth`` file is missing.
        """
        self.device = torch.device('cuda:%d' % local_rank)
        self.num_devices = torch.cuda.device_count()
        self.is_continue = is_continue or (mode == "export_mesh")
        self.is_restore = is_restore
        self.restore_lod0 = restore_lod0
        self.mode = mode
        self.model_list = []
        self.logger = logging.getLogger('exp_logger')

        if specific_dataset_name is None:
            # Backward compatibility: this value used to be read directly from
            # the global ``args`` created under ``if __name__ == '__main__'``.
            # That global does not exist when the module is imported.
            cli_args = globals().get('args')
            specific_dataset_name = getattr(cli_args, 'specific_dataset_name', 'GSO')
        self.specific_dataset_name = specific_dataset_name

        print("detected %d GPUs" % self.num_devices)

        self.conf_path = conf_path
        self.conf = ConfigFactory.parse_file(conf_path)
        self.timestamp = None
        if not self.is_continue:
            # Fresh runs get their own timestamp-suffixed experiment directory.
            self.timestamp = '_{:%Y_%m_%d_%H_%M_%S}'.format(datetime.now())
            self.base_exp_dir = self.conf['general.base_exp_dir'] + self.timestamp
        else:
            self.base_exp_dir = self.conf['general.base_exp_dir']
        # Write the resolved directory back so downstream consumers of the
        # config see the same path.
        self.conf['general.base_exp_dir'] = self.base_exp_dir
        print("base_exp_dir: " + self.base_exp_dir)
        os.makedirs(self.base_exp_dir, exist_ok=True)

        self.iter_step = 0
        self.val_step = 0

        # Training schedule and logging cadence.
        self.end_iter = self.conf.get_int('train.end_iter')
        self.save_freq = self.conf.get_int('train.save_freq')
        self.report_freq = self.conf.get_int('train.report_freq')
        self.val_freq = self.conf.get_int('train.val_freq')
        self.val_mesh_freq = self.conf.get_int('train.val_mesh_freq')
        # One sample per visible GPU.
        self.batch_size = self.num_devices
        self.validate_resolution_level = self.conf.get_int('train.validate_resolution_level')
        self.learning_rate = self.conf.get_float('train.learning_rate')
        self.learning_rate_milestone = self.conf.get_list('train.learning_rate_milestone')
        self.learning_rate_factor = self.conf.get_float('train.learning_rate_factor')
        self.use_white_bkgd = self.conf.get_bool('train.use_white_bkgd')
        self.N_rays = self.conf.get_int('train.N_rays')

        # Annealing windows (in iterations) for the alpha interpolation ratio.
        self.anneal_start_lod0 = self.conf.get_float('train.anneal_start', default=0)
        self.anneal_end_lod0 = self.conf.get_float('train.anneal_end', default=0)
        self.anneal_start_lod1 = self.conf.get_float('train.anneal_start_lod1', default=0)
        self.anneal_end_lod1 = self.conf.get_float('train.anneal_end_lod1', default=0)

        self.writer = None

        self.num_lods = self.conf.get_int('model.num_lods')

        self.rendering_network_outside = None
        self.sdf_network_lod0 = None
        self.sdf_network_lod1 = None
        self.variance_network_lod0 = None
        self.variance_network_lod1 = None
        self.rendering_network_lod0 = None
        self.rendering_network_lod1 = None
        self.pyramid_feature_network = None
        self.pyramid_feature_network_lod1 = None

        self.pyramid_feature_network = FeatureNet().to(self.device)
        self.sdf_network_lod0 = SparseSdfNetwork(**self.conf['model.sdf_network_lod0']).to(self.device)
        self.variance_network_lod0 = SingleVarianceNetwork(**self.conf['model.variance_network']).to(self.device)

        if self.num_lods > 1:
            self.sdf_network_lod1 = SparseSdfNetwork(**self.conf['model.sdf_network_lod1']).to(self.device)
            self.variance_network_lod1 = SingleVarianceNetwork(**self.conf['model.variance_network']).to(self.device)

        self.rendering_network_lod0 = GeneralRenderingNetwork(**self.conf['model.rendering_network']).to(
            self.device)

        if self.num_lods > 1:
            self.pyramid_feature_network_lod1 = FeatureNet().to(self.device)
            self.rendering_network_lod1 = GeneralRenderingNetwork(
                **self.conf['model.rendering_network_lod1']).to(self.device)

        if self.mode in ('export_mesh', 'val'):
            base_exp_dir_to_store = os.path.join("../", self.specific_dataset_name)
        else:
            base_exp_dir_to_store = self.base_exp_dir

        print(f"Store in: {base_exp_dir_to_store}")

        self.trainer = GenericTrainer(
            self.rendering_network_outside,
            self.pyramid_feature_network,
            self.pyramid_feature_network_lod1,
            self.sdf_network_lod0,
            self.sdf_network_lod1,
            self.variance_network_lod0,
            self.variance_network_lod1,
            self.rendering_network_lod0,
            self.rendering_network_lod1,
            **self.conf['model.trainer'],
            timestamp=self.timestamp,
            base_exp_dir=base_exp_dir_to_store,
            conf=self.conf)

        self.data_setup()
        self.optimizer_setup()

        if self.is_continue:
            latest_model_name = self._find_latest_checkpoint()
            self.logger.info('Find checkpoint: {}'.format(latest_model_name))
            self.load_checkpoint(latest_model_name)

        self.trainer = torch.nn.DataParallel(self.trainer).to(self.device)

        if self.mode.startswith('train'):
            self.file_backup()

    def _find_latest_checkpoint(self):
        """Return the file name of the newest ``ckpt*pth`` under ``<base_exp_dir>/checkpoints``.

        ``save_checkpoint`` names files ``ckpt_{iter:06d}.pth``, so
        lexicographic order equals iteration order up to iteration 999999.

        Raises:
            FileNotFoundError: If the directory is missing or holds no checkpoint.
        """
        ckpt_dir = os.path.join(self.base_exp_dir, 'checkpoints')
        candidates = sorted(name for name in os.listdir(ckpt_dir)
                            if name.startswith('ckpt') and name.endswith('pth'))
        if not candidates:
            raise FileNotFoundError('no ckpt*.pth checkpoint found in {}'.format(ckpt_dir))
        return candidates[-1]

    def _named_networks(self):
        """Return ``(checkpoint_key, network)`` pairs for every sub-network.

        Entries may hold ``None`` for networks that were not built.
        """
        return [
            ('rendering_network_outside', self.rendering_network_outside),
            ('sdf_network_lod0', self.sdf_network_lod0),
            ('sdf_network_lod1', self.sdf_network_lod1),
            ('pyramid_feature_network', self.pyramid_feature_network),
            ('pyramid_feature_network_lod1', self.pyramid_feature_network_lod1),
            ('variance_network_lod0', self.variance_network_lod0),
            ('variance_network_lod1', self.variance_network_lod1),
            ('rendering_network_lod0', self.rendering_network_lod0),
            ('rendering_network_lod1', self.rendering_network_lod1),
        ]

    def optimizer_setup(self):
        """Create an Adam optimizer over the trainer's trainable parameters."""
        self.params_to_train = self.trainer.get_trainable_params()
        self.optimizer = torch.optim.Adam(self.params_to_train, lr=self.learning_rate)

    def data_setup(self):
        """Build the train/val ``BlenderPerView`` datasets and their DataLoaders.

        The val dataset uses 3 source views. Its ``clean_image`` flag comes
        from ``dataset.mask_out_image`` except in ``'train'`` mode, where it
        is False. A persistent iterator over the val loader is kept so that
        successive ``validate``/``export_mesh`` calls walk through it.
        """
        self.train_dataset = BlenderPerView(
            root_dir=self.conf['dataset.trainpath'],
            split=self.conf.get_string('dataset.train_split', default='train'),
            split_filepath=self.conf.get_string('dataset.train_split_filepath', default=None),
            n_views=self.conf['dataset.nviews'],
            downSample=self.conf['dataset.imgScale_train'],
            N_rays=self.N_rays,
            batch_size=self.batch_size,
            clean_image=True,
            importance_sample=self.conf.get_bool('dataset.importance_sample', default=False),
            specific_dataset_name=self.specific_dataset_name,
        )

        self.val_dataset = BlenderPerView(
            root_dir=self.conf['dataset.valpath'],
            split=self.conf.get_string('dataset.test_split', default='test'),
            split_filepath=self.conf.get_string('dataset.val_split_filepath', default=None),
            n_views=3,
            downSample=self.conf['dataset.imgScale_test'],
            N_rays=self.N_rays,
            batch_size=self.batch_size,
            clean_image=self.conf.get_bool('dataset.mask_out_image',
                                           default=False) if self.mode != 'train' else False,
            importance_sample=self.conf.get_bool('dataset.importance_sample', default=False),
            test_ref_views=self.conf.get_list('dataset.test_ref_views', default=[]),
            specific_dataset_name=self.specific_dataset_name,
        )

        self.train_dataloader = DataLoader(self.train_dataset,
                                           shuffle=True,
                                           num_workers=4 * self.batch_size,
                                           batch_size=self.batch_size,
                                           pin_memory=True,
                                           drop_last=True)

        self.val_dataloader = DataLoader(self.val_dataset,
                                         shuffle=False,
                                         num_workers=4 * self.batch_size,
                                         batch_size=self.batch_size,
                                         pin_memory=True,
                                         drop_last=False)

        self.val_dataloader_iterator = iter(self.val_dataloader)

    def _background_rgb(self):
        """Return 1.0 (white background) if ``train.use_white_bkgd`` is set, else None."""
        return 1.0 if self.use_white_bkgd else None

    def _alpha_inter_ratios(self):
        """Return ``(alpha_inter_ratio_lod0, alpha_inter_ratio_lod1)`` for the current step.

        With more than one LOD, lod0 is treated as fully annealed (1.0).
        """
        if self.num_lods == 1:
            ratio_lod0 = self.get_alpha_inter_ratio(self.anneal_start_lod0, self.anneal_end_lod0)
        else:
            ratio_lod0 = 1.
        ratio_lod1 = self.get_alpha_inter_ratio(self.anneal_start_lod1, self.anneal_end_lod1)
        return ratio_lod0, ratio_lod1

    @staticmethod
    def _mean_or_zero(losses, key):
        """Return ``losses[key].mean()``, or 0 when ``losses`` or that entry is None."""
        if losses is None or losses[key] is None:
            return 0
        return losses[key].mean()

    def train(self):
        """Optimize until ``self.iter_step`` reaches ``train.end_iter``.

        Each step runs the trainer in ``'train'`` mode and sums the non-None
        ``loss_lod0`` / ``loss_lod1`` means. It then back-propagates, clips
        the gradient norm to 1.0 and steps Adam. Reporting, checkpointing and
        validation happen every ``report_freq`` / ``save_freq`` / ``val_freq``
        iterations. The TensorBoard writer is closed on exit.
        """
        self.writer = SummaryWriter(log_dir=os.path.join(self.base_exp_dir, 'logs'))
        try:
            res_step = self.end_iter - self.iter_step
            epochs = int(1 + res_step // len(self.train_dataloader))

            self.adjust_learning_rate()
            print("starting training learning rate: {:.5f}".format(self.optimizer.param_groups[0]['lr']))

            background_rgb = self._background_rgb()

            for epoch_i in range(epochs):
                print("current epoch %d" % epoch_i)
                # Wrap the raw loader each epoch. Re-wrapping the previous
                # epoch's tqdm object would nest progress bars.
                for batch in tqdm(self.train_dataloader):
                    # Stop exactly at end_iter. Past it, the cosine LR
                    # schedule would start increasing again.
                    if self.iter_step >= self.end_iter:
                        return
                    self._train_step(batch, background_rgb)
        finally:
            self.writer.close()

    def _train_step(self, batch, background_rgb):
        """Run one optimization step on ``batch`` plus periodic report/save/validate."""
        batch['batch_idx'] = torch.arange(self.batch_size)

        alpha_inter_ratio_lod0, alpha_inter_ratio_lod1 = self._alpha_inter_ratios()

        losses = self.trainer(
            batch,
            background_rgb=background_rgb,
            alpha_inter_ratio_lod0=alpha_inter_ratio_lod0,
            alpha_inter_ratio_lod1=alpha_inter_ratio_lod1,
            iter_step=self.iter_step,
            mode='train',
        )

        losses_lod0 = losses['losses_lod0']
        losses_lod1 = losses['losses_lod1']

        loss = 0
        for loss_type in ('loss_lod0', 'loss_lod1'):
            if losses[loss_type] is not None:
                # .mean() reduces per-replica losses gathered by DataParallel.
                loss = loss + losses[loss_type].mean()

        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.params_to_train, 1.0)
        self.optimizer.step()

        self.iter_step += 1

        if self.iter_step % self.report_freq == 0:
            self._report(loss, losses_lod0, losses_lod1,
                         alpha_inter_ratio_lod0, alpha_inter_ratio_lod1)

        if self.iter_step % self.save_freq == 0:
            self.save_checkpoint()

        if self.iter_step % self.val_freq == 0:
            self.validate()

        self.adjust_learning_rate()

    def _report(self, loss, losses_lod0, losses_lod1,
                alpha_inter_ratio_lod0, alpha_inter_ratio_lod1):
        """Write the current iteration's scalars to TensorBoard and print a console summary."""
        step = self.iter_step
        self.writer.add_scalar('Loss/loss', loss, step)

        if losses_lod0 is not None:
            self.writer.add_scalar('Loss/d_loss_lod0', losses_lod0['depth_loss'].mean(), step)
            self.writer.add_scalar('Loss/sparse_loss_lod0', losses_lod0['sparse_loss'].mean(), step)
            self.writer.add_scalar('Loss/color_loss_lod0',
                                   self._mean_or_zero(losses_lod0, 'color_fine_loss'), step)
            self.writer.add_scalar('statis/psnr_lod0', self._mean_or_zero(losses_lod0, 'psnr'), step)
            self.writer.add_scalar('param/variance_lod0',
                                   1. / torch.exp(self.variance_network_lod0.variance * 10), step)
            self.writer.add_scalar('param/eikonal_loss', losses_lod0['gradient_error_loss'].mean(), step)

        if self.num_lods > 1 and losses_lod1 is not None:
            self.writer.add_scalar('Loss/d_loss_lod1', losses_lod1['depth_loss'].mean(), step)
            self.writer.add_scalar('Loss/sparse_loss_lod1', losses_lod1['sparse_loss'].mean(), step)
            self.writer.add_scalar('Loss/color_loss_lod1',
                                   self._mean_or_zero(losses_lod1, 'color_fine_loss'), step)
            self.writer.add_scalar('statis/sdf_mean_lod1', losses_lod1['sdf_mean'].mean(), step)
            self.writer.add_scalar('statis/psnr_lod1', self._mean_or_zero(losses_lod1, 'psnr'), step)
            self.writer.add_scalar('statis/sparseness_0.01_lod1',
                                   self._mean_or_zero(losses_lod1, 'sparseness_1'), step)
            self.writer.add_scalar('statis/sparseness_0.02_lod1',
                                   self._mean_or_zero(losses_lod1, 'sparseness_2'), step)
            self.writer.add_scalar('param/variance_lod1',
                                   1. / torch.exp(self.variance_network_lod1.variance * 10), step)

        print(self.base_exp_dir)
        print(
            'iter:{:>8d} '
            'loss = {:.4f} '
            'd_loss_lod0 = {:.4f} '
            'color_loss_lod0 = {:.4f} '
            'sparse_loss_lod0= {:.4f} '
            'd_loss_lod1 = {:.4f} '
            'color_loss_lod1 = {:.4f} '
            ' lr = {:.5f}'.format(
                step, loss,
                self._mean_or_zero(losses_lod0, 'depth_loss'),
                self._mean_or_zero(losses_lod0, 'color_fine_loss'),
                self._mean_or_zero(losses_lod0, 'sparse_loss'),
                self._mean_or_zero(losses_lod1, 'depth_loss'),
                self._mean_or_zero(losses_lod1, 'color_fine_loss'),
                self.optimizer.param_groups[0]['lr']))

        print('alpha_inter_ratio_lod0 = {:.4f} alpha_inter_ratio_lod1 = {:.4f}\n'.format(
            alpha_inter_ratio_lod0, alpha_inter_ratio_lod1))

        if losses_lod0 is not None:
            print(
                'iter:{:>8d} '
                'variance = {:.5f} '
                'weights_sum = {:.4f} '
                'weights_sum_fg = {:.4f} '
                'alpha_sum = {:.4f} '
                'sparse_weight= {:.4f} '
                'background_loss = {:.4f} '
                'background_weight = {:.4f} '
                .format(
                    step,
                    losses_lod0['variance'].mean(),
                    losses_lod0['weights_sum'].mean(),
                    losses_lod0['weights_sum_fg'].mean(),
                    losses_lod0['alpha_sum'].mean(),
                    losses_lod0['sparse_weight'].mean(),
                    losses_lod0['fg_bg_loss'].mean(),
                    losses_lod0['fg_bg_weight'].mean(),
                ))

        if losses_lod1 is not None:
            print(
                'iter:{:>8d} '
                'variance = {:.5f} '
                ' weights_sum = {:.4f} '
                'alpha_sum = {:.4f} '
                'fg_bg_loss = {:.4f} '
                'fg_bg_weight = {:.4f} '
                'sparse_weight= {:.4f} '
                .format(
                    step,
                    losses_lod1['variance'].mean(),
                    losses_lod1['weights_sum'].mean(),
                    losses_lod1['alpha_sum'].mean(),
                    losses_lod1['fg_bg_loss'].mean(),
                    losses_lod1['fg_bg_weight'].mean(),
                    losses_lod1['sparse_weight'].mean(),
                ))

    def adjust_learning_rate(self):
        """Set every param group's LR from a cosine schedule over ``[0, end_iter]``.

        The LR factor decays from 1.0 at step 0 to 0.1 at ``end_iter``, and
        stays at 0.1 if ``iter_step`` is past ``end_iter``.
        """
        progress = min(self.iter_step / self.end_iter, 1.0)
        learning_rate = (np.cos(np.pi * progress) + 1.0) * 0.5 * 0.9 + 0.1
        learning_rate = self.learning_rate * learning_rate
        for g in self.optimizer.param_groups:
            g['lr'] = learning_rate

    def get_alpha_inter_ratio(self, start, end):
        """Return the linear annealing ratio in ``[0, 1]`` for the current step.

        Returns 1.0 when ``end`` is 0 (annealing disabled), and 0.0 before
        ``start``. Otherwise it ramps linearly and is capped at 1.0. A
        degenerate window (``end <= start``) jumps straight to 1.0 instead of
        dividing by zero.
        """
        if end == 0.0:
            return 1.0
        if self.iter_step < start:
            return 0.0
        if end <= start:
            return 1.0
        return min(1.0, (self.iter_step - start) / (end - start))

    def file_backup(self):
        """Snapshot the source and config into ``<base_exp_dir>/recording``.

        For each directory listed in ``general.recording``, its ``.py`` files
        are copied into ``recording/<dir>``. The config is copied to
        ``recording/config.conf``.
        """
        recording_root = os.path.join(self.base_exp_dir, 'recording')
        os.makedirs(recording_root, exist_ok=True)
        for dir_name in self.conf['general.recording']:
            cur_dir = os.path.join(recording_root, dir_name)
            os.makedirs(cur_dir, exist_ok=True)
            for f_name in os.listdir(dir_name):
                if f_name.endswith('.py'):
                    copyfile(os.path.join(dir_name, f_name), os.path.join(cur_dir, f_name))

        copyfile(self.conf_path, os.path.join(recording_root, 'config.conf'))

    @staticmethod
    def _load_state_dict(network, checkpoint, comment):
        """Best-effort load of ``checkpoint[comment]`` into ``network``.

        Entries whose keys the network does not have are ignored. Keys missing
        from the checkpoint keep their current values: the filtered weights
        are merged into the network's own state dict, and that merged dict is
        loaded. On failure (missing entry, shape mismatch) a message is
        printed and the network is left unchanged. ``None`` networks are
        skipped.
        """
        if network is None:
            return
        try:
            pretrained_dict = checkpoint[comment]
            model_dict = network.state_dict()
            pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
            model_dict.update(pretrained_dict)
            # Load the merged dict. Loading only the filtered subset would
            # fail strict loading whenever any key is absent from the file.
            network.load_state_dict(model_dict)
        except (KeyError, AttributeError, TypeError, RuntimeError) as err:
            print(comment + " load fails" + ": {}".format(err))

    def load_checkpoint(self, checkpoint_name):
        """Restore network weights (and, when resuming, optimizer/step state).

        Args:
            checkpoint_name: File name inside ``<base_exp_dir>/checkpoints``.

        With ``restore_lod0``, the lod1 SDF, feature and rendering networks
        are additionally initialised from their lod0 counterparts, and the
        optimizer and step counters are not restored.
        """
        checkpoint = torch.load(os.path.join(self.base_exp_dir, 'checkpoints', checkpoint_name),
                                map_location=self.device)

        for key, network in self._named_networks():
            self._load_state_dict(network, checkpoint, key)

        if self.restore_lod0:
            self._load_state_dict(self.sdf_network_lod1, checkpoint, 'sdf_network_lod0')
            self._load_state_dict(self.pyramid_feature_network_lod1, checkpoint, 'pyramid_feature_network')
            self._load_state_dict(self.rendering_network_lod1, checkpoint, 'rendering_network_lod0')

        if self.is_continue and (not self.restore_lod0):
            try:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            except (KeyError, ValueError, RuntimeError):
                # Best effort: a param-group mismatch should not block resuming.
                print("load optimizer fails")
            self.iter_step = checkpoint['iter_step']
            self.val_step = checkpoint.get('val_step', 0)

        self.logger.info('End')

    def save_checkpoint(self):
        """Save optimizer state, step counters and all built networks.

        The file is written to ``<base_exp_dir>/checkpoints/ckpt_{iter_step:06d}.pth``.
        """
        checkpoint = {
            'optimizer': self.optimizer.state_dict(),
            'iter_step': self.iter_step,
            'val_step': self.val_step,
        }
        for key, network in self._named_networks():
            if network is not None:
                checkpoint[key] = network.state_dict()

        ckpt_dir = os.path.join(self.base_exp_dir, 'checkpoints')
        os.makedirs(ckpt_dir, exist_ok=True)
        torch.save(checkpoint, os.path.join(ckpt_dir, 'ckpt_{:0>6d}.pth'.format(self.iter_step)))

    def _next_val_batch(self):
        """Return the next validation batch, restarting the val loader when exhausted."""
        try:
            return next(self.val_dataloader_iterator)
        except StopIteration:
            self.val_dataloader_iterator = iter(self.val_dataloader)
            return next(self.val_dataloader_iterator)

    def _eval_step(self, mode):
        """Run the trainer with ``save_vis=True`` on the next val batch in ``mode``."""
        print("iter_step: ", self.iter_step)
        self.logger.info('Validate begin')
        self.val_step += 1

        batch = self._next_val_batch()
        batch['batch_idx'] = torch.arange(self.batch_size)

        alpha_inter_ratio_lod0, alpha_inter_ratio_lod1 = self._alpha_inter_ratios()

        self.trainer(
            batch,
            background_rgb=self._background_rgb(),
            alpha_inter_ratio_lod0=alpha_inter_ratio_lod0,
            alpha_inter_ratio_lod1=alpha_inter_ratio_lod1,
            iter_step=self.iter_step,
            save_vis=True,
            mode=mode,
        )

    def validate(self, resolution_level=-1):
        """Run one validation batch through the trainer in ``'val'`` mode.

        ``resolution_level`` is accepted for API compatibility and is unused.
        """
        self._eval_step('val')

    def export_mesh(self, resolution_level=-1):
        """Run one validation batch through the trainer in ``'export_mesh'`` mode.

        ``resolution_level`` is accepted for API compatibility and is unused.
        """
        self._eval_step('export_mesh')
|
|
|
|
|
if __name__ == '__main__':
    # Command-line entry point: parse arguments, build a Runner and dispatch on --mode.
    torch.set_default_dtype(torch.float32)
    FORMAT = "[%(filename)s:%(lineno)s - %(funcName)20s() ] %(message)s"
    logging.basicConfig(level=logging.INFO, format=FORMAT)

    parser = argparse.ArgumentParser()
    parser.add_argument('--conf', type=str, default='./confs/base.conf')
    # Restrict to the modes dispatched below. Any other value used to build
    # the whole experiment and then silently do nothing.
    parser.add_argument('--mode', type=str, default='train',
                        choices=['train', 'val', 'export_mesh'])
    # --threshold, --is_finetune and --train_from_scratch are accepted for CLI
    # compatibility; nothing else in this script reads them.
    parser.add_argument('--threshold', type=float, default=0.0)
    parser.add_argument('--is_continue', default=False, action="store_true")
    parser.add_argument('--is_restore', default=False, action="store_true")
    parser.add_argument('--is_finetune', default=False, action="store_true")
    parser.add_argument('--train_from_scratch', default=False, action="store_true")
    parser.add_argument('--restore_lod0', default=False, action="store_true")
    parser.add_argument('--local_rank', type=int, default=0)
    parser.add_argument('--specific_dataset_name', type=str, default='GSO')

    # Kept as a module-level global: Runner may read args.specific_dataset_name.
    args = parser.parse_args()

    torch.cuda.set_device(args.local_rank)
    torch.backends.cudnn.benchmark = True

    runner = Runner(args.conf, args.mode, args.is_continue, args.is_restore, args.restore_lod0,
                    args.local_rank)

    if args.mode == 'train':
        runner.train()
    else:
        # One evaluation step per validation sample.
        eval_step = runner.validate if args.mode == 'val' else runner.export_mesh
        for _ in range(len(runner.val_dataset)):
            eval_step()
|
|