Spaces:
Runtime error
Runtime error
# | |
# Copyright (C) 2023, Inria | |
# GRAPHDECO research group, https://team.inria.fr/graphdeco | |
# All rights reserved. | |
# | |
# This software is free for non-commercial, research and evaluation use | |
# under the terms of the LICENSE.md file. | |
# | |
# For inquiries contact george.drettakis@inria.fr | |
# | |
import os | |
import torch | |
import numpy as np | |
from torchvision.transforms.functional import pil_to_tensor, to_tensor | |
from torchvision.utils import make_grid, save_image | |
from random import randint | |
from utils.loss_utils import l1_loss, ssim, lpips | |
from gaussian_renderer import render, network_gui | |
import sys | |
from scene import Scene, GaussianModel | |
from utils.general_utils import safe_state | |
import uuid | |
from tqdm import tqdm | |
from utils.image_utils import psnr | |
from argparse import ArgumentParser, Namespace | |
from arguments import ModelParams, PipelineParams, OptimizationParams | |
from scripts.sampling.simple_mv_sample import sample_one | |
try: | |
from torch.utils.tensorboard import SummaryWriter | |
TENSORBOARD_FOUND = True | |
except ImportError: | |
TENSORBOARD_FOUND = False | |
def _resample_training_views(
    scene,
    dataset,
    gaussians,
    pipe,
    opt,
    background,
    model,
    iteration,
    resample_sigma,
    image_path,
    ckpt_path,
    out_dir="tmp",
):
    """Regenerate the training images from the current renders and rebuild the scene.

    Every training camera of ``scene`` is rendered. The renders are mapped from
    [0, 1] to [-1, 1] and encoded with ``model.encode_first_stage``. The latents
    are then blended with Gaussian noise as
    ``sqrt(1 - sigma**2) * z + sigma * eps`` and passed to ``sample_one`` as its
    starting noise. The resampled images replace ``dataset.images``, and a new
    ``Scene`` (``skip_gaussians=True``) is built from them.

    Debug images ``views_<iteration>.png`` and ``resampled_images_<iteration>.png``
    are written to ``out_dir``, which is created if missing.

    Returns:
        The newly constructed ``Scene``.
    """
    views = []
    for view_cam in scene.getTrainCameras():
        bg = torch.rand((3), device="cuda") if opt.random_background else background
        views.append(render(view_cam, gaussians, pipe, bg)["render"])
    views = torch.stack(views)

    # save_image cannot create missing directories itself.
    os.makedirs(out_dir, exist_ok=True)
    save_image(views, f"{out_dir}/views_{iteration}.png")

    # [0, 1] -> [-1, 1] before encoding; presumably the range the encoder
    # expects (TODO confirm against model.encode_first_stage).
    latents = model.encode_first_stage(views * 2.0 - 1.0)
    keep_scale = float(np.sqrt(1 - resample_sigma**2))
    noisy_latents = keep_scale * latents + torch.randn_like(latents) * resample_sigma

    resampled_images = sample_one(
        image_path,
        ckpt_path,
        noise=noisy_latents,
        cached_model=model,
    )[0]

    dataset.images = resampled_images
    new_scene = Scene(
        dataset,
        gaussians,
        shuffle=False,
        skip_gaussians=True,
    )

    grid = make_grid(
        torch.stack([to_tensor(img) for img in resampled_images]), nrow=3
    )
    save_image(grid, f"{out_dir}/resampled_images_{iteration}.png")
    return new_scene


def training(
    dataset,
    opt,
    pipe,
    testing_iterations,
    saving_iterations,
    checkpoint_iterations,
    checkpoint,
    debug_from,
    resample_period=500,
    resample_sigma=0.1,
    resample_start=1000,
    model=None,
    image_path=None,
    ckpt_path=None,
):
    """Optimize the Gaussian model against ``dataset``'s images, with periodic resampling.

    Each iteration does the following:
      1. Render a randomly drawn training camera (without replacement until the
         camera list is exhausted).
      2. Compute the L1 / D-SSIM loss, plus LPIPS when ``opt.lambda_lpips > 0``,
         and backpropagate it.
      3. Log, save, densify and prune, then step the optimizer.

    Resampling runs when ``model`` is given, ``resample_period > 0``,
    ``iteration % resample_period == 0`` and ``iteration > resample_start``.
    It regenerates the training images from the current renders via
    ``sample_one`` and rebuilds the ``Scene`` from them.

    Args:
        dataset: Model/dataset parameter group. Its ``images`` attribute is
            replaced on every resample.
        opt: Optimization parameter group.
        pipe: Pipeline parameter group. ``pipe.debug`` is switched on at
            iteration ``debug_from + 1``.
        testing_iterations: Iterations at which ``training_report`` evaluates.
        saving_iterations: Iterations at which ``scene.save`` is called.
        checkpoint_iterations: Iterations at which ``(gaussians.capture(),
            iteration)`` is written to ``<model_path>/chkpnt<iter>.pth``.
        checkpoint: Path of such a checkpoint to resume from, or a falsy value.
        debug_from: See ``pipe`` above.
        resample_period: Iterations between resamples. A value of 0 or less
            disables resampling.
        resample_sigma: Noise level used when blending the encoded renders.
        resample_start: Resampling only happens strictly after this iteration.
        model: Cached diffusion model passed to ``sample_one``. ``None``
            disables resampling.
        image_path: Conditioning image for ``sample_one``. Defaults to the
            module-level ``args.image``.
        ckpt_path: Checkpoint for ``sample_one``. Defaults to the module-level
            ``args.ckpt_path``.
    """
    first_iter = 0
    tb_writer = prepare_output_and_logger(dataset)
    gaussians = GaussianModel(dataset.sh_degree)
    scene = Scene(dataset, gaussians, shuffle=False)
    gaussians.training_setup(opt)
    if checkpoint:
        (model_params, first_iter) = torch.load(checkpoint)
        gaussians.restore(model_params, opt)

    bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]
    background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")

    iter_start = torch.cuda.Event(enable_timing=True)
    iter_end = torch.cuda.Event(enable_timing=True)

    # Resampling needs the diffusion model. A non-positive period would also
    # make the modulo below divide by zero.
    resample_enabled = model is not None and resample_period > 0

    viewpoint_stack = None
    ema_loss_for_log = 0.0
    progress_bar = tqdm(range(first_iter, opt.iterations), desc="Training progress")
    first_iter += 1
    for iteration in range(first_iter, opt.iterations + 1):
        iter_start.record()

        gaussians.update_learning_rate(iteration)

        # Every 1000 its we increase the levels of SH up to a maximum degree
        if iteration % 1000 == 0:
            gaussians.oneupSHdegree()

        if (
            resample_enabled
            and iteration % resample_period == 0
            and iteration > resample_start
        ):
            with torch.no_grad():
                scene = _resample_training_views(
                    scene,
                    dataset,
                    gaussians,
                    pipe,
                    opt,
                    background,
                    model,
                    iteration,
                    resample_sigma,
                    image_path if image_path is not None else args.image,
                    ckpt_path if ckpt_path is not None else args.ckpt_path,
                )
            # Any pending cameras belong to the old scene and carry the old
            # ground-truth images. Drop them so the next draw uses the new scene.
            viewpoint_stack = None

        # Pick a random Camera
        if not viewpoint_stack:
            viewpoint_stack = scene.getTrainCameras().copy()
        viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack) - 1))

        # Render
        if (iteration - 1) == debug_from:
            pipe.debug = True

        bg = torch.rand((3), device="cuda") if opt.random_background else background

        render_pkg = render(viewpoint_cam, gaussians, pipe, bg)
        image, viewspace_point_tensor, visibility_filter, radii = (
            render_pkg["render"],
            render_pkg["viewspace_points"],
            render_pkg["visibility_filter"],
            render_pkg["radii"],
        )

        # Loss
        gt_image = viewpoint_cam.original_image.cuda()
        Ll1 = l1_loss(image, gt_image)
        loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (
            1.0 - ssim(image, gt_image)
        )
        if opt.lambda_lpips > 0:
            loss += opt.lambda_lpips * lpips(image, gt_image)
        loss.backward()

        iter_end.record()

        with torch.no_grad():
            # Progress bar
            ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log
            if iteration % 10 == 0:
                progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.{7}f}"})
                progress_bar.update(10)
            if iteration == opt.iterations:
                progress_bar.close()

            # Log and save
            training_report(
                tb_writer,
                iteration,
                Ll1,
                loss,
                l1_loss,
                iter_start.elapsed_time(iter_end),
                testing_iterations,
                scene,
                render,
                (pipe, background),
            )
            if iteration in saving_iterations:
                print("\n[ITER {}] Saving Gaussians".format(iteration))
                scene.save(iteration)

            # Densification
            if iteration < opt.densify_until_iter:
                # Keep track of max radii in image-space for pruning
                gaussians.max_radii2D[visibility_filter] = torch.max(
                    gaussians.max_radii2D[visibility_filter], radii[visibility_filter]
                )
                gaussians.add_densification_stats(
                    viewspace_point_tensor, visibility_filter
                )

                if (
                    iteration > opt.densify_from_iter
                    and iteration % opt.densification_interval == 0
                ):
                    size_threshold = (
                        20 if iteration > opt.opacity_reset_interval else None
                    )
                    gaussians.densify_and_prune(
                        opt.densify_grad_threshold,
                        0.005,
                        scene.cameras_extent,
                        size_threshold,
                    )

                if iteration % opt.opacity_reset_interval == 0 or (
                    dataset.white_background and iteration == opt.densify_from_iter
                ):
                    gaussians.reset_opacity()

            # Optimizer step
            if iteration < opt.iterations:
                gaussians.optimizer.step()
                gaussians.optimizer.zero_grad(set_to_none=True)

            if iteration in checkpoint_iterations:
                print("\n[ITER {}] Saving Checkpoint".format(iteration))
                torch.save(
                    (gaussians.capture(), iteration),
                    scene.model_path + "/chkpnt" + str(iteration) + ".pth",
                )
def prepare_output_and_logger(args):
    """Create the output folder, record the run configuration, and open a TensorBoard writer.

    If ``args.model_path`` is empty, it is set in place to ``./output/<id>``.
    ``<id>`` is the first 10 characters of the ``OAR_JOB_ID`` environment
    variable, or of a random UUID when that variable is unset or empty.
    The stringified ``Namespace`` of ``args`` is written to
    ``<model_path>/cfg_args``.

    Returns:
        A ``SummaryWriter`` on ``args.model_path`` when TensorBoard imported
        successfully, otherwise ``None``.
    """
    if not args.model_path:
        run_id = os.getenv("OAR_JOB_ID") or str(uuid.uuid4())
        args.model_path = os.path.join("./output/", run_id[:10])

    # Set up output folder
    print("Output folder: {}".format(args.model_path))
    os.makedirs(args.model_path, exist_ok=True)
    config_file = os.path.join(args.model_path, "cfg_args")
    with open(config_file, "w") as handle:
        handle.write(str(Namespace(**vars(args))))

    # Create Tensorboard writer
    if not TENSORBOARD_FOUND:
        print("Tensorboard not available: not logging progress")
        return None
    return SummaryWriter(args.model_path)
def training_report(
    tb_writer,
    iteration,
    Ll1,
    loss,
    l1_loss,
    elapsed,
    testing_iterations,
    scene: Scene,
    renderFunc,
    renderArgs,
):
    """Log per-iteration training scalars and, on test iterations, evaluate the scene.

    The L1 loss, total loss and iteration time are always sent to
    ``tb_writer`` when it is set. When ``iteration`` is in
    ``testing_iterations``, evaluation runs on:
      * all test cameras, and
      * up to five training cameras, at indices 5, 10, ..., 25 wrapped
        modulo the camera count.

    Each camera is rendered with ``renderFunc(viewpoint, scene.gaussians,
    *renderArgs)``. Mean L1 and PSNR are then printed and logged per split.
    Camera sets that are empty are skipped.

    Args:
        tb_writer: TensorBoard ``SummaryWriter`` or ``None``.
        iteration: Current iteration number.
        Ll1: L1 loss tensor for this iteration.
        loss: Total loss tensor for this iteration.
        l1_loss: Callable ``(image, gt_image) -> tensor``.
        elapsed: Iteration time, as reported by ``torch.cuda.Event.elapsed_time``.
        testing_iterations: Iterations at which evaluation runs.
        scene: Scene providing cameras and ``gaussians``.
        renderFunc: Render callable returning a dict with a ``"render"`` image.
        renderArgs: Extra positional arguments for ``renderFunc``.
    """
    if tb_writer:
        tb_writer.add_scalar("train_loss_patches/l1_loss", Ll1.item(), iteration)
        tb_writer.add_scalar("train_loss_patches/total_loss", loss.item(), iteration)
        tb_writer.add_scalar("iter_time", elapsed, iteration)

    # Report test and samples of training set
    if iteration not in testing_iterations:
        return

    torch.cuda.empty_cache()
    train_cameras = scene.getTrainCameras()
    # Guard the modulo: with no training cameras, idx % 0 would raise.
    train_samples = (
        [train_cameras[idx % len(train_cameras)] for idx in range(5, 30, 5)]
        if train_cameras
        else []
    )
    validation_configs = (
        {"name": "test", "cameras": scene.getTestCameras()},
        {"name": "train", "cameras": train_samples},
    )

    for config in validation_configs:
        cameras = config["cameras"]
        if not cameras:
            continue
        l1_test = 0.0
        psnr_test = 0.0
        for idx, viewpoint in enumerate(cameras):
            image = torch.clamp(
                renderFunc(viewpoint, scene.gaussians, *renderArgs)["render"],
                0.0,
                1.0,
            )
            gt_image = torch.clamp(viewpoint.original_image.to("cuda"), 0.0, 1.0)
            if tb_writer and (idx < 5):
                tb_writer.add_images(
                    config["name"] + "_view_{}/render".format(viewpoint.image_name),
                    image[None],
                    global_step=iteration,
                )
                # Ground truth never changes, so it is logged only once.
                if iteration == testing_iterations[0]:
                    tb_writer.add_images(
                        config["name"]
                        + "_view_{}/ground_truth".format(viewpoint.image_name),
                        gt_image[None],
                        global_step=iteration,
                    )
            # Accumulate as Python floats so the printed summary shows plain
            # numbers rather than CUDA tensor reprs.
            l1_test += l1_loss(image, gt_image).mean().double().item()
            psnr_test += psnr(image, gt_image).mean().double().item()
        psnr_test /= len(cameras)
        l1_test /= len(cameras)
        print(
            "\n[ITER {}] Evaluating {}: L1 {} PSNR {}".format(
                iteration, config["name"], l1_test, psnr_test
            )
        )
        if tb_writer:
            tb_writer.add_scalar(
                config["name"] + "/loss_viewpoint - l1_loss", l1_test, iteration
            )
            tb_writer.add_scalar(
                config["name"] + "/loss_viewpoint - psnr", psnr_test, iteration
            )

    if tb_writer:
        tb_writer.add_histogram(
            "scene/opacity_histogram", scene.gaussians.get_opacity, iteration
        )
        tb_writer.add_scalar("total_points", scene.gaussians.get_xyz.shape[0], iteration)
    torch.cuda.empty_cache()
if __name__ == "__main__":
    # Command line: the ModelParams / OptimizationParams / PipelineParams
    # groups register their own flags on the parser; script-level flags follow.
    parser = ArgumentParser(description="Training script parameters")
    model_params = ModelParams(parser)
    optim_params = OptimizationParams(parser)
    pipeline_params = PipelineParams(parser)

    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--image", type=str, default="assets/images/ceramic.png")
    parser.add_argument("--ckpt_path", type=str, required=True)
    parser.add_argument("--ip", type=str, default="127.0.0.1")
    parser.add_argument("--port", type=int, default=6009)
    parser.add_argument("--debug_from", type=int, default=-1)
    parser.add_argument("--detect_anomaly", action="store_true", default=False)
    # The list literal is rebuilt on each pass, so each flag gets its own
    # default list (save_iterations is appended to below).
    for iteration_flag in ("--test_iterations", "--save_iterations"):
        parser.add_argument(
            iteration_flag, nargs="+", type=int, default=[7_000, 30_000]
        )
    parser.add_argument("--quiet", action="store_true")
    parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=[])
    parser.add_argument("--start_checkpoint", type=str, default=None)
    parser.add_argument("--resample_period", type=int, default=500)
    parser.add_argument("--resample_sigma", type=float, default=0.1)
    parser.add_argument("--resample_start", type=int, default=500)

    # `args` must remain a module-level name: training() reads
    # args.image and args.ckpt_path from it.
    args = parser.parse_args(sys.argv[1:])
    args.save_iterations.append(args.iterations)

    print("Optimizing " + args.model_path)

    # Initialize system state (RNG)
    safe_state(args.quiet)

    # Start GUI server, configure and run training
    network_gui.init(args.ip, args.port)
    torch.autograd.set_detect_anomaly(args.detect_anomaly)

    print("=====Start generating MV Images=====")
    images, model = sample_one(args.image, args.ckpt_path, seed=args.seed)
    print("=====Finish generating MV Images=====")

    dataset_params = model_params.extract(args)
    dataset_params.images = images
    training(
        dataset_params,
        optim_params.extract(args),
        pipeline_params.extract(args),
        args.test_iterations,
        args.save_iterations,
        args.checkpoint_iterations,
        args.start_checkpoint,
        args.debug_from,
        args.resample_period,
        args.resample_sigma,
        args.resample_start,
        model,
    )

    # All done
    print("\nTraining complete.")