# wyysf's picture
# i
# 0f079b2
# raw
# history blame
# 25.5 kB
import os
import time
import logging
import argparse
import numpy as np
import cv2 as cv
import trimesh
import torch
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
from shutil import copyfile
from icecream import ic
from tqdm import tqdm
from pyhocon import ConfigFactory
from models.dataset_mvdiff import Dataset
from models.fields import RenderingNetwork, SDFNetwork, SingleVarianceNetwork, NeRF
from models.renderer import NeuSRenderer
import pdb
import math
def ranking_loss(error, penalize_ratio=0.7, type='mean'):
    """Reduce only the smallest fraction of a 1-D tensor of per-sample errors.

    The errors are sorted in ascending order and the first
    ``int(penalize_ratio * len(error))`` of them are kept, so the largest
    errors (treated as outliers) do not contribute to the loss.

    Args:
        error: 1-D tensor of per-sample errors.
        penalize_ratio: fraction of the samples (the smallest ones) to keep.
        type: ``'mean'`` or ``'sum'`` -- the reduction applied to the kept
            errors. (The name shadows the builtin but is part of the public
            signature, so it is kept.)

    Returns:
        A scalar tensor. When no sample is kept (empty input, or a ratio so
        small that it truncates to zero samples) the result is 0 rather than
        the NaN that a mean over an empty tensor would produce.

    Raises:
        ValueError: if ``type`` is neither ``'mean'`` nor ``'sum'``.
    """
    if type not in ('mean', 'sum'):
        raise ValueError("type must be 'mean' or 'sum', got {!r}".format(type))

    # Slice the *sorted* values directly. The indices returned by torch.sort
    # refer to positions in the unsorted input, so they must not be used to
    # index the sorted tensor.
    sorted_error, _ = torch.sort(error)
    keep = int(penalize_ratio * sorted_error.shape[0])
    s_error = sorted_error[:keep]

    if type == 'sum' or keep == 0:
        # The sum of an empty selection is a well-defined zero that stays
        # attached to the autograd graph.
        return torch.sum(s_error)
    return torch.mean(s_error)
class Runner:
    """Drives training, validation and export for one reconstruction case.

    The runner reads a HOCON configuration file, builds the dataset, the
    four networks (outside NeRF, SDF, variance, colour) and the
    ``NeuSRenderer`` that combines them, and exposes the training loop plus
    several export utilities (validation images, per-view maps, meshes and
    interpolated videos). All outputs are written below
    ``general.base_exp_dir`` from the configuration.
    """

    def __init__(self, conf_path, mode='train', case='CASE_NAME', is_continue=False, data_dir=None):
        """Build dataset, networks, optimizer and renderer.

        Args:
            conf_path: path of the HOCON configuration file.
            mode: run mode; modes starting with ``'train'`` trigger a backup
                of the code directories listed in ``general.recording``.
            case: string substituted for every ``CASE_NAME`` placeholder in
                the configuration text.
            is_continue: when true, resume from the latest checkpoint whose
                iteration does not exceed ``train.end_iter``.
            data_dir: overrides ``dataset.data_dir`` from the configuration.
                ``None`` keeps the configured value.
        """
        self.device = torch.device('cuda')

        # Configuration
        self.conf_path = conf_path
        with open(self.conf_path) as f:
            conf_text = f.read()
        conf_text = conf_text.replace('CASE_NAME', case)

        self.conf = ConfigFactory.parse_string(conf_text)
        # Only override the configured data directory when one was actually
        # supplied; writing None here would make the replace() below fail.
        if data_dir is not None:
            self.conf['dataset']['data_dir'] = data_dir
        self.conf['dataset.data_dir'] = self.conf['dataset.data_dir'].replace('CASE_NAME', case)
        self.base_exp_dir = self.conf['general.base_exp_dir']
        os.makedirs(self.base_exp_dir, exist_ok=True)
        self.dataset = Dataset(self.conf['dataset'])
        self.dataloader = torch.utils.data.DataLoader(
            self.dataset,
            batch_size=self.conf['train']['batch_size'],
            shuffle=True,
            # Optional key; 64 is the value that used to be hard-coded.
            num_workers=self.conf.get_int('train.num_workers', default=64),
        )
        self.iter_step = 1

        # Training parameters
        self.end_iter = self.conf.get_int('train.end_iter')
        self.save_freq = self.conf.get_int('train.save_freq')
        self.report_freq = self.conf.get_int('train.report_freq')
        self.val_freq = self.conf.get_int('train.val_freq')
        self.val_mesh_freq = self.conf.get_int('train.val_mesh_freq')
        self.batch_size = self.conf.get_int('train.batch_size')
        self.validate_resolution_level = self.conf.get_int('train.validate_resolution_level')
        self.learning_rate = self.conf.get_float('train.learning_rate')
        self.learning_rate_alpha = self.conf.get_float('train.learning_rate_alpha')
        self.use_white_bkgd = self.conf.get_bool('train.use_white_bkgd')
        self.warm_up_end = self.conf.get_float('train.warm_up_end', default=0.0)
        self.anneal_end = self.conf.get_float('train.anneal_end', default=0.0)

        # Weights
        self.color_weight = self.conf.get_float('train.color_weight')
        self.igr_weight = self.conf.get_float('train.igr_weight')
        self.mask_weight = self.conf.get_float('train.mask_weight')
        self.normal_weight = self.conf.get_float('train.normal_weight')
        self.sparse_weight = self.conf.get_float('train.sparse_weight')
        self.is_continue = is_continue
        self.mode = mode
        self.model_list = []
        self.writer = None

        # Networks
        params_to_train_slow = []
        self.nerf_outside = NeRF(**self.conf['model.nerf']).to(self.device)
        self.sdf_network = SDFNetwork(**self.conf['model.sdf_network']).to(self.device)
        self.deviation_network = SingleVarianceNetwork(**self.conf['model.variance_network']).to(self.device)
        self.color_network = RenderingNetwork(**self.conf['model.rendering_network']).to(self.device)
        # The outside NeRF is not handed to the optimizer; only the SDF,
        # variance and colour networks are trained.
        params_to_train_slow += list(self.sdf_network.parameters())
        params_to_train_slow += list(self.deviation_network.parameters())

        self.optimizer = torch.optim.Adam(
            [{'params': params_to_train_slow}, {'params': self.color_network.parameters(), 'lr': self.learning_rate * 2}], lr=self.learning_rate
        )

        self.renderer = NeuSRenderer(
            self.nerf_outside, self.sdf_network, self.deviation_network, self.color_network, **self.conf['model.neus_renderer']
        )

        # Load checkpoint
        latest_model_name = None
        if is_continue:
            checkpoint_dir = os.path.join(self.base_exp_dir, 'checkpoints')
            model_list_raw = os.listdir(checkpoint_dir) if os.path.isdir(checkpoint_dir) else []
            # File names look like 'ckpt_000123.pth'; [5:-4] is the iteration.
            model_list = [
                model_name
                for model_name in model_list_raw
                if model_name[-3:] == 'pth' and int(model_name[5:-4]) <= self.end_iter
            ]
            model_list.sort()
            if model_list:
                latest_model_name = model_list[-1]
            else:
                logging.warning('No usable checkpoint found in {}; starting from scratch'.format(checkpoint_dir))

        if latest_model_name is not None:
            logging.info('Find checkpoint: {}'.format(latest_model_name))
            self.load_checkpoint(latest_model_name)

        # Backup codes and configs for debug
        if self.mode[:5] == 'train':
            self.file_backup()

    @staticmethod
    def _feasible(render_out, key):
        """Return True when ``key`` is present in ``render_out`` and not None."""
        return (key in render_out) and (render_out[key] is not None)

    def _weighted_normals(self, render_out):
        """Accumulate the per-sample gradients into one normal per ray.

        Gradients are weighted by the rendering weights of the first
        ``n_samples + n_importance`` samples, masked by ``inside_sphere``
        when the renderer provides it, and summed along the sample axis.
        """
        n_samples = self.renderer.n_samples + self.renderer.n_importance
        normals = render_out['gradients'] * render_out['weights'][:, :n_samples, None]
        if self._feasible(render_out, 'inside_sphere'):
            normals = normals * render_out['inside_sphere'][..., None]
        return normals.sum(dim=1)

    def _render_batches(self, rays_o, rays_d, collect_geometry=True):
        """Render a full set of rays in chunks of ``self.batch_size``.

        Args:
            rays_o: ray origins; reshaped to ``(-1, 3)`` before splitting.
            rays_d: ray directions; reshaped to ``(-1, 3)`` before splitting.
            collect_geometry: when false only colours are gathered and
                ``color_fine`` must be present in every render output.

        Returns:
            Three lists of per-chunk numpy arrays:
            ``(out_rgb_fine, out_normal_fine, out_mask)``. A list stays empty
            when the renderer does not provide the corresponding output.
        """
        out_rgb_fine = []
        out_normal_fine = []
        out_mask = []
        rays_o = rays_o.reshape(-1, 3).split(self.batch_size)
        rays_d = rays_d.reshape(-1, 3).split(self.batch_size)

        for rays_o_batch, rays_d_batch in zip(rays_o, rays_d):
            near, far = self.dataset.get_near_far()
            # Keep the background on the same device as the rays.
            background_rgb = torch.ones([1, 3], device=rays_o_batch.device) if self.use_white_bkgd else None
            render_out = self.renderer.render(
                rays_o_batch, rays_d_batch, near, far, cos_anneal_ratio=self.get_cos_anneal_ratio(), background_rgb=background_rgb
            )

            if not collect_geometry:
                out_rgb_fine.append(render_out['color_fine'].detach().cpu().numpy())
                del render_out
                continue

            if self._feasible(render_out, 'color_fine'):
                out_rgb_fine.append(render_out['color_fine'].detach().cpu().numpy())
            if self._feasible(render_out, 'gradients') and self._feasible(render_out, 'weights'):
                out_normal_fine.append(self._weighted_normals(render_out).detach().cpu().numpy())
            if self._feasible(render_out, 'weight_sum'):
                out_mask.append(render_out['weight_sum'].detach().clip(0, 1).cpu().numpy())
            # Drop the reference before rendering the next chunk.
            del render_out

        return out_rgb_fine, out_normal_fine, out_mask

    def train(self):
        """Run the optimisation loop until ``self.iter_step`` reaches ``end_iter``.

        Each batch from the data loader is a 2-D tensor whose columns are
        sliced into ray origin (0:3), ray direction (3:6), target colour
        (6:9), mask (9:10), target normal (10:13) and cosines (13:).
        Scalars are logged to TensorBoard every step; progress is printed,
        meshes/images are validated and checkpoints are saved at the
        frequencies read from the configuration.
        """
        self.writer = SummaryWriter(log_dir=os.path.join(self.base_exp_dir, 'logs'))
        self.update_learning_rate()
        res_step = self.end_iter - self.iter_step
        # The permutation itself is not consumed below; it is kept (and
        # periodically redrawn) so the random stream is unchanged.
        image_perm = self.get_image_perm()

        num_train_epochs = math.ceil(res_step / len(self.dataloader))
        print("training ", num_train_epochs, " epoches")

        for epoch in range(num_train_epochs):
            print("epoch ", epoch)
            for data in self.dataloader:
                data = data.cuda()
                rays_o, rays_d, true_rgb, mask, true_normal, cosines = (
                    data[:, :3],
                    data[:, 3:6],
                    data[:, 6:9],
                    data[:, 9:10],
                    data[:, 10:13],
                    data[:, 13:],
                )
                near, far = self.dataset.get_near_far()

                background_rgb = None
                if self.use_white_bkgd:
                    # Keep the background on the same device as the rays.
                    background_rgb = torch.ones([1, 3], device=data.device)

                if self.mask_weight > 0.0:
                    mask = (mask > 0.5).float()
                else:
                    mask = torch.ones_like(mask)

                # Rays whose cosine is above -0.1 are zeroed and removed
                # from the supervision mask.
                cosines[cosines > -0.1] = 0
                mask = ((mask > 0) & (cosines < -0.1)).to(torch.float32)
                mask_sum = mask.sum() + 1e-5

                render_out = self.renderer.render(
                    rays_o, rays_d, near, far, background_rgb=background_rgb, cos_anneal_ratio=self.get_cos_anneal_ratio()
                )

                color_fine = render_out['color_fine']
                s_val = render_out['s_val']
                cdf_fine = render_out['cdf_fine']
                gradient_error = render_out['gradient_error']
                weight_max = render_out['weight_max']
                weight_sum = render_out['weight_sum']

                # Colour loss: per-ray L1 error on masked rays, outliers dropped.
                color_errors = (color_fine - true_rgb).abs().sum(dim=1)
                color_fine_loss = ranking_loss(color_errors[mask[:, 0] > 0])
                psnr = 20.0 * torch.log10(1.0 / (((color_fine - true_rgb) ** 2 * mask).sum() / (mask_sum * 3.0)).sqrt())

                eikonal_loss = gradient_error

                mask_errors = F.binary_cross_entropy(weight_sum.clip(1e-3, 1.0 - 1e-3), mask, reduction='none')
                mask_loss = ranking_loss(mask_errors[:, 0], penalize_ratio=0.8)

                # Normal loss: cosine distance between rendered and target
                # normals, re-weighted by exp(|cosine|) normalised over the batch.
                normals = self._weighted_normals(render_out)
                normal_errors = 1 - F.cosine_similarity(normals, true_normal, dim=1)
                normal_errors = normal_errors * torch.exp(cosines.abs()[:, 0]) / torch.exp(cosines.abs()).sum()
                normal_loss = ranking_loss(normal_errors[mask[:, 0] > 0], penalize_ratio=0.9, type='sum')

                sparse_loss = render_out['sparse_loss']

                loss = (
                    color_fine_loss * self.color_weight
                    + eikonal_loss * self.igr_weight
                    + sparse_loss * self.sparse_weight
                    + mask_loss * self.mask_weight
                    + normal_loss * self.normal_weight
                )

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                self.writer.add_scalar('Loss/loss', loss, self.iter_step)
                self.writer.add_scalar('Loss/color_loss', color_fine_loss, self.iter_step)
                self.writer.add_scalar('Loss/eikonal_loss', eikonal_loss, self.iter_step)
                self.writer.add_scalar('Statistics/s_val', s_val.mean(), self.iter_step)
                self.writer.add_scalar('Statistics/cdf', (cdf_fine[:, :1] * mask).sum() / mask_sum, self.iter_step)
                self.writer.add_scalar('Statistics/weight_max', (weight_max * mask).sum() / mask_sum, self.iter_step)
                self.writer.add_scalar('Statistics/psnr', psnr, self.iter_step)

                if self.iter_step % self.report_freq == 0:
                    print(self.base_exp_dir)
                    print(
                        'iter:{:8>d} loss = {:4>f} color_ls = {:4>f} eik_ls = {:4>f} normal_ls = {:4>f} mask_ls = {:4>f} sparse_ls = {:4>f} lr={:5>f}'.format(
                            self.iter_step,
                            loss,
                            color_fine_loss,
                            eikonal_loss,
                            normal_loss,
                            mask_loss,
                            sparse_loss,
                            self.optimizer.param_groups[0]['lr'],
                        )
                    )
                    print('iter:{:8>d} s_val = {:4>f}'.format(self.iter_step, s_val.mean()))

                if self.iter_step % self.val_mesh_freq == 0:
                    self.validate_mesh(resolution=256)

                self.update_learning_rate()

                self.iter_step += 1

                if self.iter_step % self.val_freq == 0:
                    self.validate_image(idx=0)
                    self.validate_image(idx=1)
                    self.validate_image(idx=2)
                    self.validate_image(idx=3)

                if self.iter_step % self.save_freq == 0:
                    self.save_checkpoint()

                if self.iter_step % len(image_perm) == 0:
                    image_perm = self.get_image_perm()

                # The epoch count is rounded up, so without this check the
                # last epoch would run past end_iter and past the end of the
                # learning-rate schedule.
                if self.iter_step >= self.end_iter:
                    break

            if self.iter_step >= self.end_iter:
                break

    def get_image_perm(self):
        """Return a random permutation of the dataset's image indices."""
        return torch.randperm(self.dataset.n_images)

    def get_cos_anneal_ratio(self):
        """Return the cosine annealing ratio for the current iteration.

        The ratio grows linearly with ``iter_step`` and saturates at 1.0
        after ``anneal_end`` iterations; annealing is disabled (ratio 1.0)
        when ``anneal_end`` is 0.
        """
        if self.anneal_end == 0.0:
            return 1.0
        else:
            return np.min([1.0, self.iter_step / self.anneal_end])

    def update_learning_rate(self):
        """Set the learning rate of every parameter group for this iteration.

        Linear warm-up until ``warm_up_end``, then cosine decay from
        ``learning_rate`` down to ``learning_rate * learning_rate_alpha`` at
        ``end_iter``. Progress is clamped at 1.0 and a zero-length decay
        phase is treated as already finished.

        NOTE(review): every group receives the same rate, which overwrites
        the doubled rate given to the colour network when the optimizer is
        constructed -- confirm whether that is intended.
        """
        if self.iter_step < self.warm_up_end:
            learning_factor = self.iter_step / self.warm_up_end
        else:
            alpha = self.learning_rate_alpha
            decay_steps = self.end_iter - self.warm_up_end
            if decay_steps <= 0:
                # No decay phase configured; avoid dividing by zero.
                progress = 1.0
            else:
                progress = min(1.0, (self.iter_step - self.warm_up_end) / decay_steps)
            learning_factor = (np.cos(np.pi * progress) + 1.0) * 0.5 * (1 - alpha) + alpha

        for g in self.optimizer.param_groups:
            g['lr'] = self.learning_rate * learning_factor

    def file_backup(self):
        """Copy the recorded source directories and the config into the run dir.

        Every ``.py`` file directly inside each directory listed under
        ``general.recording`` is copied to ``<base_exp_dir>/recording/<dir>``,
        and the configuration file is stored as ``recording/config.conf``.
        """
        dir_lis = self.conf['general.recording']
        os.makedirs(os.path.join(self.base_exp_dir, 'recording'), exist_ok=True)
        for dir_name in dir_lis:
            cur_dir = os.path.join(self.base_exp_dir, 'recording', dir_name)
            os.makedirs(cur_dir, exist_ok=True)
            for f_name in os.listdir(dir_name):
                if f_name.endswith('.py'):
                    copyfile(os.path.join(dir_name, f_name), os.path.join(cur_dir, f_name))

        copyfile(self.conf_path, os.path.join(self.base_exp_dir, 'recording', 'config.conf'))

    def load_checkpoint(self, checkpoint_name):
        """Restore networks, optimizer and iteration counter from a checkpoint.

        Args:
            checkpoint_name: file name inside ``<base_exp_dir>/checkpoints``.
        """
        checkpoint = torch.load(os.path.join(self.base_exp_dir, 'checkpoints', checkpoint_name), map_location=self.device)
        self.nerf_outside.load_state_dict(checkpoint['nerf'])
        self.sdf_network.load_state_dict(checkpoint['sdf_network_fine'])
        self.deviation_network.load_state_dict(checkpoint['variance_network_fine'])
        self.color_network.load_state_dict(checkpoint['color_network_fine'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.iter_step = checkpoint['iter_step']

        logging.info('End')

    def save_checkpoint(self):
        """Write networks, optimizer and iteration counter to ``ckpt_<iter>.pth``."""
        checkpoint = {
            'nerf': self.nerf_outside.state_dict(),
            'sdf_network_fine': self.sdf_network.state_dict(),
            'variance_network_fine': self.deviation_network.state_dict(),
            'color_network_fine': self.color_network.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'iter_step': self.iter_step,
        }

        os.makedirs(os.path.join(self.base_exp_dir, 'checkpoints'), exist_ok=True)
        torch.save(checkpoint, os.path.join(self.base_exp_dir, 'checkpoints', 'ckpt_{:0>6d}.pth'.format(self.iter_step)))

    def validate_image(self, idx=-1, resolution_level=-1):
        """Render one training view and save it next to the reference data.

        Args:
            idx: camera index; a negative value picks a random camera.
            resolution_level: resolution level passed to the dataset; a
                negative value uses ``train.validate_resolution_level``.

        Colour images go to ``validations_fine``; normal and mask images go
        to ``normals``. Whatever subset of colour / normal / mask the
        renderer provides is written.
        """
        if idx < 0:
            idx = np.random.randint(self.dataset.n_images)

        print('Validate: iter: {}, camera: {}'.format(self.iter_step, idx))

        if resolution_level < 0:
            resolution_level = self.validate_resolution_level
        rays_o, rays_d = self.dataset.gen_rays_at(idx, resolution_level=resolution_level)
        H, W, _ = rays_o.shape

        out_rgb_fine, out_normal_fine, out_mask = self._render_batches(rays_o, rays_d)

        img_fine = None
        if len(out_rgb_fine) > 0:
            img_fine = (np.concatenate(out_rgb_fine, axis=0).reshape([H, W, 3, -1]) * 256).clip(0, 255)

        mask_map = None
        if len(out_mask) > 0:
            mask_map = (np.concatenate(out_mask, axis=0).reshape([H, W, -1]) * 256).clip(0, 255)

        normal_img = None
        if len(out_normal_fine) > 0:
            normal_img = np.concatenate(out_normal_fine, axis=0)
            # Rotate the normals with the inverse of the camera pose rotation
            # before mapping them to the 0..255 range.
            rot = np.linalg.inv(self.dataset.pose_all[idx, :3, :3].detach().cpu().numpy())
            normal_img = (np.matmul(rot[None, :, :], normal_img[:, :, None]).reshape([H, W, 3, -1]) * 128 + 128).clip(0, 255)

        os.makedirs(os.path.join(self.base_exp_dir, 'validations_fine'), exist_ok=True)
        os.makedirs(os.path.join(self.base_exp_dir, 'normals'), exist_ok=True)

        # Take the number of maps from whichever output exists, so a missing
        # colour output no longer aborts the normal and mask export.
        n_maps = next((m.shape[-1] for m in (img_fine, normal_img, mask_map) if m is not None), 0)

        for i in range(n_maps):
            if img_fine is not None:
                cv.imwrite(
                    os.path.join(self.base_exp_dir, 'validations_fine', '{:0>8d}_{}_{}.png'.format(self.iter_step, i, idx)),
                    np.concatenate(
                        [
                            img_fine[..., i],
                            self.dataset.image_at(idx, resolution_level=resolution_level),
                            self.dataset.mask_at(idx, resolution_level=resolution_level),
                        ]
                    ),
                )
            if normal_img is not None:
                cv.imwrite(
                    os.path.join(self.base_exp_dir, 'normals', '{:0>8d}_{}_{}.png'.format(self.iter_step, i, idx)),
                    np.concatenate([normal_img[..., i], self.dataset.normal_cam_at(idx, resolution_level=resolution_level)])[:, :, ::-1],
                )
            if mask_map is not None:
                cv.imwrite(os.path.join(self.base_exp_dir, 'normals', '{:0>8d}_{}_{}_mask.png'.format(self.iter_step, i, idx)), mask_map[..., i])

    def save_maps(self, idx, img_idx, resolution_level=1):
        """Render one of the four named views and save colour and normal maps.

        Args:
            idx: camera index, also used to select the view name from
                ``['front', 'back', 'left', 'right']``.
            img_idx: number embedded in the output file names.
            resolution_level: resolution level passed to the dataset.

        Both maps are written to ``<base_exp_dir>/coarse_maps`` as 4-channel
        images with the rendered mask as the last channel.

        Raises:
            RuntimeError: if the renderer did not provide colour, normal and
                mask outputs, all of which are required here.
        """
        view_types = ['front', 'back', 'left', 'right']

        print('Validate: iter: {}, camera: {}'.format(self.iter_step, idx))

        rays_o, rays_d = self.dataset.gen_rays_at(idx, resolution_level=resolution_level)
        H, W, _ = rays_o.shape

        out_rgb_fine, out_normal_fine, out_mask = self._render_batches(rays_o, rays_d)
        if not (out_rgb_fine and out_normal_fine and out_mask):
            raise RuntimeError('save_maps needs color_fine, gradients/weights and weight_sum from the renderer')

        img_fine = (np.concatenate(out_rgb_fine, axis=0).reshape([H, W, 3]) * 256).clip(0, 255)
        mask_map = (np.concatenate(out_mask, axis=0).reshape([H, W, 1]) * 256).clip(0, 255)
        # Unlike validate_image, the normals are stored without applying the
        # camera rotation.
        normal_img = np.concatenate(out_normal_fine, axis=0)
        world_normal_img = (normal_img.reshape([H, W, 3]) * 128 + 128).clip(0, 255)

        os.makedirs(os.path.join(self.base_exp_dir, 'coarse_maps'), exist_ok=True)
        img_rgba = np.concatenate([img_fine[:, :, ::-1], mask_map], axis=-1)
        normal_rgba = np.concatenate([world_normal_img[:, :, ::-1], mask_map], axis=-1)

        cv.imwrite(os.path.join(self.base_exp_dir, 'coarse_maps', "normals_mlp_%03d_%s.png" % (img_idx, view_types[idx])), img_rgba)
        cv.imwrite(os.path.join(self.base_exp_dir, 'coarse_maps', "normals_grad_%03d_%s.png" % (img_idx, view_types[idx])), normal_rgba)

    def render_novel_image(self, idx_0, idx_1, ratio, resolution_level):
        """
        Interpolate view between two cameras.

        Returns the rendered colour image as an ``(H, W, 3)`` uint8 array.
        """
        rays_o, rays_d = self.dataset.gen_rays_between(idx_0, idx_1, ratio, resolution_level=resolution_level)
        H, W, _ = rays_o.shape

        out_rgb_fine, _, _ = self._render_batches(rays_o, rays_d, collect_geometry=False)

        img_fine = (np.concatenate(out_rgb_fine, axis=0).reshape([H, W, 3]) * 256).clip(0, 255).astype(np.uint8)
        return img_fine

    def validate_mesh(self, world_space=False, resolution=64, threshold=0.0):
        """Extract a coloured mesh and export it to ``meshes/tmp.glb``.

        Args:
            world_space: when true, scale and translate the vertices with
                the first entry of the dataset's ``scale_mats_np``.
            resolution: grid resolution passed to the geometry extraction.
            threshold: level-set threshold passed to the geometry extraction.
        """
        bound_min = torch.tensor(self.dataset.object_bbox_min, dtype=torch.float32)
        bound_max = torch.tensor(self.dataset.object_bbox_max, dtype=torch.float32)

        vertices, triangles, vertex_colors = self.renderer.extract_geometry(bound_min, bound_max, resolution=resolution, threshold=threshold)
        os.makedirs(os.path.join(self.base_exp_dir, 'meshes'), exist_ok=True)

        if world_space:
            vertices = vertices * self.dataset.scale_mats_np[0][0, 0] + self.dataset.scale_mats_np[0][:3, 3][None]

        mesh = trimesh.Trimesh(vertices, triangles, vertex_colors=vertex_colors)
        # The file name is fixed, so each export overwrites the previous one.
        mesh.export(os.path.join(self.base_exp_dir, 'meshes', 'tmp.glb'))

        logging.info('End')

    def interpolate_view(self, img_idx_0, img_idx_1):
        """Render a back-and-forth video interpolating between two cameras.

        60 frames are rendered with a sinusoidal ease between the two views,
        then appended again in reverse order, and the 120 frames are written
        as a 30 fps mp4 to ``<base_exp_dir>/render``.
        """
        images = []
        n_frames = 60
        for i in range(n_frames):
            print(i)
            images.append(self.render_novel_image(img_idx_0, img_idx_1, np.sin(((i / n_frames) - 0.5) * np.pi) * 0.5 + 0.5, resolution_level=4))
        for i in range(n_frames):
            images.append(images[n_frames - i - 1])

        fourcc = cv.VideoWriter_fourcc(*'mp4v')
        video_dir = os.path.join(self.base_exp_dir, 'render')
        os.makedirs(video_dir, exist_ok=True)
        h, w, _ = images[0].shape
        writer = cv.VideoWriter(os.path.join(video_dir, '{:0>8d}_{}_{}.mp4'.format(self.iter_step, img_idx_0, img_idx_1)), fourcc, 30, (w, h))

        for image in images:
            writer.write(image)

        writer.release()
if __name__ == '__main__':
    # Command-line entry point: parse the options, build a Runner and
    # dispatch on --mode.
    print('Hello Wooden')

    torch.set_default_tensor_type('torch.FloatTensor')

    log_format = "[%(filename)s:%(lineno)s - %(funcName)20s() ] %(message)s"
    logging.basicConfig(level=logging.DEBUG, format=log_format)

    # (flag, add_argument keyword arguments), registered in this order.
    cli_options = (
        ('--conf', dict(type=str, default='./confs/base.conf')),
        ('--mode', dict(type=str, default='train')),
        ('--mcube_threshold', dict(type=float, default=0.0)),
        ('--is_continue', dict(default=False, action="store_true")),
        ('--gpu', dict(type=int, default=0)),
        ('--case', dict(type=str, default='')),
        ('--data_dir', dict(type=str, default='')),
    )
    parser = argparse.ArgumentParser()
    for flag, options in cli_options:
        parser.add_argument(flag, **options)
    args = parser.parse_args()

    torch.cuda.set_device(args.gpu)
    runner = Runner(args.conf, args.mode, args.case, args.is_continue, args.data_dir)

    run_mode = args.mode
    if run_mode == 'train':
        # Train, then extract a mesh from the final networks.
        runner.train()
        runner.validate_mesh(world_space=False, resolution=256, threshold=args.mcube_threshold)
    elif run_mode == 'save_maps':
        # One export per view index 0..3.
        for view_idx in range(4):
            runner.save_maps(idx=view_idx, img_idx=runner.dataset.object_viewidx)
    elif run_mode == 'validate_mesh':
        runner.validate_mesh(world_space=False, resolution=512, threshold=args.mcube_threshold)
    elif run_mode.startswith('interpolate'):
        # Mode has the form 'interpolate_<idx0>_<idx1>'.
        _, first_view, second_view = run_mode.split('_')
        first_view = int(first_view)
        second_view = int(second_view)
        runner.interpolate_view(first_view, second_view)