import copy
import cv2
import einops
from collections import defaultdict
import matplotlib.pyplot as plt
import random
# import emd
import pytorch3d.loss
# import imageio.v3
import functools
import json
import os
from pathlib import Path
from pdb import set_trace as st
from einops import rearrange
import webdataset as wds
import point_cloud_utils as pcu
import traceback
import blobfile as bf
import math
import imageio
import numpy as np
import torch
from torch.autograd import Function
import torch.nn.functional as F
import torch as th
import open3d as o3d
import torch.distributed as dist
import torchvision
from PIL import Image
from torch.nn.parallel.distributed import DistributedDataParallel as DDP
from torch.optim import AdamW
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
import pytorch3d.ops
from torch.profiler import profile, record_function, ProfilerActivity
import kornia

from guided_diffusion import dist_util, logger
from guided_diffusion.fp16_util import MixedPrecisionTrainer
from guided_diffusion.nn import update_ema
from guided_diffusion.resample import LossAwareSampler, UniformSampler
from guided_diffusion.train_util import (calc_average_loss,
                                         find_ema_checkpoint,
                                         find_resume_checkpoint,
                                         get_blob_logdir, log_rec3d_loss_dict,
                                         parse_resume_step_from_filename)
from datasets.g_buffer_objaverse import (unity2blender, unity2blender_th,
                                         PostProcess, focal2fov, fov2focal)
from nsr.camera_utils import (generate_input_camera, uni_mesh_path,
                              sample_uniform_cameras_on_sphere)
from nsr.volumetric_rendering.ray_sampler import RaySampler
from utils.mesh_util import post_process_mesh, to_cam_open3d_compat, smooth_mesh
from utils.gs_utils.graphics_utils import getWorld2View2, getProjectionMatrix, getView2World
from utils.general_utils import matrix_to_quaternion
from .camera_utils import LookAtPoseSampler, FOV_to_intrinsics
from .train_util import TrainLoop3DRec

def psnr(input, target, max_val):
    return kornia.metrics.psnr(input, target, max_val)


def calc_emd(output, gt, eps=0.005, iterations=50):
    # lazy import: the EMD CUDA extension is only needed when this is called
    import utils.emd.emd_module as emd
    emd_loss = emd.emdModule()
    dist, _ = emd_loss(output, gt, eps, iterations)
    emd_out = torch.sqrt(dist).mean(1)
    return emd_out
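

# A minimal usage sketch (not part of the training pipeline). `psnr` compares
# images that share a value range; `calc_emd` assumes (B, N, 3) point clouds
# and needs the CUDA EMD extension under utils/emd. The shapes below are
# illustrative assumptions, not values used by the trainer.
def _demo_metrics():
    img_pred = th.rand(2, 3, 64, 64)
    img_gt = th.rand(2, 3, 64, 64)
    print('psnr:', psnr(img_pred, img_gt, max_val=1.0))

    if th.cuda.is_available():  # the EMD kernel is CUDA-only
        pts_pred = th.rand(2, 1024, 3, device='cuda')
        pts_gt = th.rand(2, 1024, 3, device='cuda')
        print('emd:', calc_emd(pts_pred, pts_gt).mean().item())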


class TrainLoop3DRecNV(TrainLoop3DRec):
    # supervise the training of novel view

    def __init__(self,
                 *,
                 rec_model,
                 loss_class,
                 data,
                 eval_data,
                 batch_size,
                 microbatch,
                 lr,
                 ema_rate,
                 log_interval,
                 eval_interval,
                 save_interval,
                 resume_checkpoint,
                 use_fp16=False,
                 fp16_scale_growth=0.001,
                 weight_decay=0,
                 lr_anneal_steps=0,
                 iterations=10001,
                 load_submodule_name='',
                 ignore_resume_opt=False,
                 model_name='rec',
                 use_amp=False,
                 **kwargs):
        super().__init__(rec_model=rec_model,
                         loss_class=loss_class,
                         data=data,
                         eval_data=eval_data,
                         batch_size=batch_size,
                         microbatch=microbatch,
                         lr=lr,
                         ema_rate=ema_rate,
                         log_interval=log_interval,
                         eval_interval=eval_interval,
                         save_interval=save_interval,
                         resume_checkpoint=resume_checkpoint,
                         use_fp16=use_fp16,
                         fp16_scale_growth=fp16_scale_growth,
                         weight_decay=weight_decay,
                         lr_anneal_steps=lr_anneal_steps,
                         iterations=iterations,
                         load_submodule_name=load_submodule_name,
                         ignore_resume_opt=ignore_resume_opt,
                         model_name=model_name,
                         use_amp=use_amp,
                         **kwargs)
        self.rec_cano = True

    def forward_backward(self, batch, *args, **kwargs):
        self.mp_trainer_rec.zero_grad()
        batch_size = batch['img_to_encoder'].shape[0]

        for i in range(0, batch_size, self.microbatch):
            micro = {
                k: v[i:i + self.microbatch].to(dist_util.dev())
                for k, v in batch.items()
            }

            # ! concat novel-view? next version. also add self reconstruction,
            # patch-based loss in the next version. verify novel-view prediction first.

            # wrap forward within amp
            with th.autocast(device_type='cuda',
                             dtype=th.float16,
                             enabled=self.mp_trainer_rec.use_amp):

                target_nvs = {}
                target_cano = {}

                latent = self.rec_model(img=micro['img_to_encoder'],
                                        behaviour='enc_dec_wo_triplane')

                pred = self.rec_model(
                    latent=latent,
                    c=micro['nv_c'],  # predict novel view here
                    behaviour='triplane_dec')

                for k, v in micro.items():
                    if k[:2] == 'nv':
                        orig_key = k.replace('nv_', '')
                        target_nvs[orig_key] = v
                        target_cano[orig_key] = micro[orig_key]

                with self.rec_model.no_sync():  # type: ignore
                    loss, loss_dict, fg_mask = self.loss_class(
                        pred,
                        target_nvs,
                        step=self.step + self.resume_step,
                        test_mode=False,
                        return_fg_mask=True,
                        conf_sigma_l1=None,
                        conf_sigma_percl=None)
                    log_rec3d_loss_dict(loss_dict)

                if self.rec_cano:
                    pred_cano = self.rec_model(latent=latent,
                                               c=micro['c'],
                                               behaviour='triplane_dec')

                    with self.rec_model.no_sync():  # type: ignore
                        fg_mask = target_cano['depth_mask'].unsqueeze(
                            1).repeat_interleave(3, 1).float()

                        loss_cano, loss_cano_dict = self.loss_class.calc_2d_rec_loss(
                            pred_cano['image_raw'],
                            target_cano['img'],
                            fg_mask,
                            step=self.step + self.resume_step,
                            test_mode=False,
                        )

                        loss = loss + loss_cano

                        # remove redundant log
                        log_rec3d_loss_dict({
                            f'cano_{k}': v
                            for k, v in loss_cano_dict.items()
                        })

            self.mp_trainer_rec.backward(loss)

            if dist_util.get_rank() == 0 and self.step % 500 == 0:
                if self.rec_cano:
                    self.log_img(micro, pred, pred_cano)
                else:
                    self.log_img(micro, pred, None)

    def log_img(self, micro, pred, pred_cano):

        def norm_depth(pred_depth):  # to [-1,1]
            pred_depth = (pred_depth - pred_depth.min()) / (pred_depth.max() -
                                                            pred_depth.min())
            return -(pred_depth * 2 - 1)

        pred_img = pred['image_raw']
        gt_img = micro['img']

        gt_depth = micro['depth']
        if gt_depth.ndim == 3:
            gt_depth = gt_depth.unsqueeze(1)
        gt_depth = norm_depth(gt_depth)

        fg_mask = pred['image_mask'] * 2 - 1  # [0,1] -> [-1,1]
        input_fg_mask = pred_cano['image_mask'] * 2 - 1  # [0,1] -> [-1,1]
        if 'image_depth' in pred:
            pred_depth = norm_depth(pred['image_depth'])
            pred_nv_depth = norm_depth(pred_cano['image_depth'])
        else:
            pred_depth = th.zeros_like(gt_depth)
            pred_nv_depth = th.zeros_like(gt_depth)

        if 'image_sr' in pred:
            if pred['image_sr'].shape[-1] == 512:
                pred_img = th.cat([self.pool_512(pred_img), pred['image_sr']],
                                  dim=-1)
                gt_img = th.cat([self.pool_512(micro['img']), micro['img_sr']],
                                dim=-1)
                pred_depth = self.pool_512(pred_depth)
                gt_depth = self.pool_512(gt_depth)
            elif pred['image_sr'].shape[-1] == 256:
                pred_img = th.cat([self.pool_256(pred_img), pred['image_sr']],
                                  dim=-1)
                gt_img = th.cat([self.pool_256(micro['img']), micro['img_sr']],
                                dim=-1)
                pred_depth = self.pool_256(pred_depth)
                gt_depth = self.pool_256(gt_depth)
            else:
                pred_img = th.cat([self.pool_128(pred_img), pred['image_sr']],
                                  dim=-1)
                gt_img = th.cat([self.pool_128(micro['img']), micro['img_sr']],
                                dim=-1)
                gt_depth = self.pool_128(gt_depth)
                pred_depth = self.pool_128(pred_depth)
        else:
            gt_img = self.pool_64(gt_img)
            gt_depth = self.pool_64(gt_depth)

        pred_vis = th.cat([
            pred_img,
            pred_depth.repeat_interleave(3, dim=1),
            fg_mask.repeat_interleave(3, dim=1),
        ],
                          dim=-1)  # B, 3, H, W

        pred_vis_nv = th.cat([
            pred_cano['image_raw'],
            pred_nv_depth.repeat_interleave(3, dim=1),
            input_fg_mask.repeat_interleave(3, dim=1),
        ],
                             dim=-1)  # B, 3, H, W
        pred_vis = th.cat([pred_vis, pred_vis_nv], dim=-2)  # cat in H dim

        gt_vis = th.cat([
            gt_img,
            gt_depth.repeat_interleave(3, dim=1),
            th.zeros_like(gt_img)
        ],
                        dim=-1)  # TODO, fails to load depth. range [0, 1]

        if 'conf_sigma' in pred:
            gt_vis = th.cat([gt_vis, fg_mask], dim=-1)  # placeholder

        vis = th.cat([gt_vis, pred_vis], dim=-2)

        vis_tensor = torchvision.utils.make_grid(vis, nrow=vis.shape[-1] //
                                                 64)  # HWC

        torchvision.utils.save_image(
            vis_tensor,
            f'{logger.get_dir()}/{self.step+self.resume_step}.jpg',
            value_range=(-1, 1),
            normalize=True)

        logger.log('log vis to: ',
                   f'{logger.get_dir()}/{self.step+self.resume_step}.jpg')
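

# A minimal sketch of the `nv_` key convention consumed by forward_backward
# above: each batch carries canonical-view entries ('img', 'c', ...) plus
# novel-view counterparts prefixed with 'nv_'; splitting them back into a
# novel-view target and a canonical target is a plain dict walk. The string
# values below are stand-ins for tensors.
def _demo_split_nv_keys():
    batch = {'img': 'cano_img', 'c': 'cano_cam',
             'nv_img': 'novel_img', 'nv_c': 'novel_cam'}
    target_nvs, target_cano = {}, {}
    for k, v in batch.items():
        if k.startswith('nv_'):
            orig_key = k[len('nv_'):]
            target_nvs[orig_key] = v
            target_cano[orig_key] = batch[orig_key]
    assert target_nvs == {'img': 'novel_img', 'c': 'novel_cam'}
    assert target_cano == {'img': 'cano_img', 'c': 'cano_cam'}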


class TrainLoop3DRecNVPatch(TrainLoop3DRecNV):
    # add patch rendering

    def __init__(self,
                 *,
                 rec_model,
                 loss_class,
                 data,
                 eval_data,
                 batch_size,
                 microbatch,
                 lr,
                 ema_rate,
                 log_interval,
                 eval_interval,
                 save_interval,
                 resume_checkpoint,
                 use_fp16=False,
                 fp16_scale_growth=0.001,
                 weight_decay=0,
                 lr_anneal_steps=0,
                 iterations=10001,
                 load_submodule_name='',
                 ignore_resume_opt=False,
                 model_name='rec',
                 use_amp=False,
                 **kwargs):
        super().__init__(rec_model=rec_model,
                         loss_class=loss_class,
                         data=data,
                         eval_data=eval_data,
                         batch_size=batch_size,
                         microbatch=microbatch,
                         lr=lr,
                         ema_rate=ema_rate,
                         log_interval=log_interval,
                         eval_interval=eval_interval,
                         save_interval=save_interval,
                         resume_checkpoint=resume_checkpoint,
                         use_fp16=use_fp16,
                         fp16_scale_growth=fp16_scale_growth,
                         weight_decay=weight_decay,
                         lr_anneal_steps=lr_anneal_steps,
                         iterations=iterations,
                         load_submodule_name=load_submodule_name,
                         ignore_resume_opt=ignore_resume_opt,
                         model_name=model_name,
                         use_amp=use_amp,
                         **kwargs)
        # the renderer
        self.eg3d_model = self.rec_model.module.decoder.triplane_decoder  # type: ignore
        # self.rec_cano = False
        self.rec_cano = True

    def forward_backward(self, batch, *args, **kwargs):
        # add patch sampling
        self.mp_trainer_rec.zero_grad()
        batch_size = batch['img_to_encoder'].shape[0]

        for i in range(0, batch_size, self.microbatch):
            micro = {
                k: v[i:i + self.microbatch].to(dist_util.dev())
                for k, v in batch.items()
            }

            # ! sample rendering patch
            target = {
                **self.eg3d_model(
                    c=micro['nv_c'],  # type: ignore
                    ws=None,
                    planes=None,
                    sample_ray_only=True,
                    fg_bbox=micro['nv_bbox']),  # rays o / dir
            }

            patch_rendering_resolution = self.eg3d_model.rendering_kwargs[
                'patch_rendering_resolution']  # type: ignore
            cropped_target = {
                k:
                th.empty_like(v)
                [..., :patch_rendering_resolution, :patch_rendering_resolution]
                if k not in [
                    'ins_idx', 'img_to_encoder', 'img_sr', 'nv_img_to_encoder',
                    'nv_img_sr', 'c'
                ] else v
                for k, v in micro.items()
            }

            # crop according to uv sampling
            for j in range(micro['img'].shape[0]):
                top, left, height, width = target['ray_bboxes'][
                    j]  # list of tuple
                for key in ('img', 'depth_mask', 'depth'):  # type: ignore
                    cropped_target[f'{key}'][  # ! no nv_ here
                        j:j + 1] = torchvision.transforms.functional.crop(
                            micro[f'nv_{key}'][j:j + 1], top, left, height,
                            width)

            # wrap forward within amp
            with th.autocast(device_type='cuda',
                             dtype=th.float16,
                             enabled=self.mp_trainer_rec.use_amp):

                latent = self.rec_model(img=micro['img_to_encoder'],
                                        behaviour='enc_dec_wo_triplane')

                pred_nv = self.rec_model(
                    latent=latent,
                    c=micro['nv_c'],  # predict novel view here
                    behaviour='triplane_dec',
                    ray_origins=target['ray_origins'],
                    ray_directions=target['ray_directions'],
                )

                with self.rec_model.no_sync():  # type: ignore
                    loss, loss_dict, _ = self.loss_class(pred_nv,
                                                         cropped_target,
                                                         step=self.step +
                                                         self.resume_step,
                                                         test_mode=False,
                                                         return_fg_mask=True,
                                                         conf_sigma_l1=None,
                                                         conf_sigma_percl=None)
                    log_rec3d_loss_dict(loss_dict)

                if self.rec_cano:

                    cano_target = {
                        **self.eg3d_model(
                            c=micro['c'],  # type: ignore
                            ws=None,
                            planes=None,
                            sample_ray_only=True,
                            fg_bbox=micro['bbox']),  # rays o / dir
                    }

                    cano_cropped_target = {
                        k: th.empty_like(v)
                        for k, v in cropped_target.items()
                    }

                    for j in range(micro['img'].shape[0]):
                        top, left, height, width = cano_target['ray_bboxes'][
                            j]  # list of tuple
                        for key in ('img', 'depth_mask',
                                    'depth'):  # type: ignore
                            cano_cropped_target[key][
                                j:j +
                                1] = torchvision.transforms.functional.crop(
                                    micro[key][j:j + 1], top, left, height,
                                    width)

                    pred_cano = self.rec_model(
                        latent=latent,
                        c=micro['c'],
                        behaviour='triplane_dec',
                        ray_origins=cano_target['ray_origins'],
                        ray_directions=cano_target['ray_directions'],
                    )

                    with self.rec_model.no_sync():  # type: ignore

                        fg_mask = cano_cropped_target['depth_mask'].unsqueeze(
                            1).repeat_interleave(3, 1).float()

                        loss_cano, loss_cano_dict = self.loss_class.calc_2d_rec_loss(
                            pred_cano['image_raw'],
                            cano_cropped_target['img'],
                            fg_mask,
                            step=self.step + self.resume_step,
                            test_mode=False,
                        )

                        loss = loss + loss_cano

                        # remove redundant log
                        log_rec3d_loss_dict({
                            f'cano_{k}': v
                            for k, v in loss_cano_dict.items()
                        })

            self.mp_trainer_rec.backward(loss)

            if dist_util.get_rank() == 0 and self.step % 500 == 0:
                self.log_patch_img(cropped_target, pred_nv, pred_cano)

    def log_patch_img(self, micro, pred, pred_cano):

        def norm_depth(pred_depth):  # to [-1,1]
            pred_depth = (pred_depth - pred_depth.min()) / (pred_depth.max() -
                                                            pred_depth.min())
            return -(pred_depth * 2 - 1)

        pred_img = pred['image_raw']
        gt_img = micro['img']

        gt_depth = micro['depth']
        if gt_depth.ndim == 3:
            gt_depth = gt_depth.unsqueeze(1)
        gt_depth = norm_depth(gt_depth)

        fg_mask = pred['image_mask'] * 2 - 1  # [0,1] -> [-1,1]
        input_fg_mask = pred_cano['image_mask'] * 2 - 1  # [0,1] -> [-1,1]
        if 'image_depth' in pred:
            pred_depth = norm_depth(pred['image_depth'])
            pred_cano_depth = norm_depth(pred_cano['image_depth'])
        else:
            pred_depth = th.zeros_like(gt_depth)
            pred_cano_depth = th.zeros_like(gt_depth)

        pred_vis = th.cat([
            pred_img,
            pred_depth.repeat_interleave(3, dim=1),
            fg_mask.repeat_interleave(3, dim=1),
        ],
                          dim=-1)  # B, 3, H, W

        pred_vis_nv = th.cat([
            pred_cano['image_raw'],
            pred_cano_depth.repeat_interleave(3, dim=1),
            input_fg_mask.repeat_interleave(3, dim=1),
        ],
                             dim=-1)  # B, 3, H, W
        pred_vis = th.cat([pred_vis, pred_vis_nv], dim=-2)  # cat in H dim

        gt_vis = th.cat([
            gt_img,
            gt_depth.repeat_interleave(3, dim=1),
            th.zeros_like(gt_img)
        ],
                        dim=-1)  # TODO, fails to load depth. range [0, 1]

        vis = th.cat([gt_vis, pred_vis], dim=-2)

        vis_tensor = torchvision.utils.make_grid(vis, nrow=vis.shape[-1] //
                                                 64)  # HWC

        torchvision.utils.save_image(
            vis_tensor,
            f'{logger.get_dir()}/{self.step+self.resume_step}.jpg',
            value_range=(-1, 1),
            normalize=True)

        logger.log('log vis to: ',
                   f'{logger.get_dir()}/{self.step+self.resume_step}.jpg')
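

# A minimal sketch of the patch supervision used by the class above: the
# renderer reports a (top, left, height, width) bbox per sample for the rays
# it sampled, and the ground truth is cropped with the same bbox so the
# rendered patch and the GT patch stay pixel-aligned. The toy bbox below is
# an assumption for illustration.
def _demo_patch_crop():
    gt = th.arange(3 * 8 * 8, dtype=th.float32).reshape(1, 3, 8, 8)
    top, left, height, width = 2, 3, 4, 4  # as in target['ray_bboxes'][j]
    patch = torchvision.transforms.functional.crop(gt, top, left, height, width)
    assert patch.shape == (1, 3, 4, 4)
    assert th.equal(patch, gt[..., top:top + height, left:left + width])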


class TrainLoop3DRecNVPatchSingleForward(TrainLoop3DRecNVPatch):

    def __init__(self,
                 *,
                 rec_model,
                 loss_class,
                 data,
                 eval_data,
                 batch_size,
                 microbatch,
                 lr,
                 ema_rate,
                 log_interval,
                 eval_interval,
                 save_interval,
                 resume_checkpoint,
                 use_fp16=False,
                 fp16_scale_growth=0.001,
                 weight_decay=0,
                 lr_anneal_steps=0,
                 iterations=10001,
                 load_submodule_name='',
                 ignore_resume_opt=False,
                 model_name='rec',
                 use_amp=False,
                 **kwargs):
        super().__init__(rec_model=rec_model,
                         loss_class=loss_class,
                         data=data,
                         eval_data=eval_data,
                         batch_size=batch_size,
                         microbatch=microbatch,
                         lr=lr,
                         ema_rate=ema_rate,
                         log_interval=log_interval,
                         eval_interval=eval_interval,
                         save_interval=save_interval,
                         resume_checkpoint=resume_checkpoint,
                         use_fp16=use_fp16,
                         fp16_scale_growth=fp16_scale_growth,
                         weight_decay=weight_decay,
                         lr_anneal_steps=lr_anneal_steps,
                         iterations=iterations,
                         load_submodule_name=load_submodule_name,
                         ignore_resume_opt=ignore_resume_opt,
                         model_name=model_name,
                         use_amp=use_amp,
                         **kwargs)

    def forward_backward(self, batch, *args, **kwargs):
        # add patch sampling
        self.mp_trainer_rec.zero_grad()
        batch_size = batch['img_to_encoder'].shape[0]

        batch.pop('caption')  # not required
        batch.pop('ins')  # not required
        batch.pop('nv_caption')  # not required
        batch.pop('nv_ins')  # not required

        for i in range(0, batch_size, self.microbatch):
            micro = {
                k:
                v[i:i + self.microbatch].to(dist_util.dev()) if isinstance(
                    v, th.Tensor) else v[i:i + self.microbatch]
                for k, v in batch.items()
            }

            # ! sample rendering patch
            target = {
                **self.eg3d_model(
                    c=micro['nv_c'],  # type: ignore
                    ws=None,
                    planes=None,
                    sample_ray_only=True,
                    fg_bbox=micro['nv_bbox']),  # rays o / dir
            }

            patch_rendering_resolution = self.eg3d_model.rendering_kwargs[
                'patch_rendering_resolution']  # type: ignore
            cropped_target = {
                k:
                th.empty_like(v)
                [..., :patch_rendering_resolution, :patch_rendering_resolution]
                if k not in [
                    'ins_idx', 'img_to_encoder', 'img_sr', 'nv_img_to_encoder',
                    'nv_img_sr', 'c', 'caption', 'nv_caption'
                ] else v
                for k, v in micro.items()
            }

            # crop according to uv sampling
            for j in range(micro['img'].shape[0]):
                top, left, height, width = target['ray_bboxes'][
                    j]  # list of tuple
                for key in ('img', 'depth_mask', 'depth'):  # type: ignore
                    cropped_target[f'{key}'][  # ! no nv_ here
                        j:j + 1] = torchvision.transforms.functional.crop(
                            micro[f'nv_{key}'][j:j + 1], top, left, height,
                            width)

            # ! vit no amp
            latent = self.rec_model(img=micro['img_to_encoder'].to(self.dtype),
                                    behaviour='enc_dec_wo_triplane')

            # wrap forward within amp
            with th.autocast(device_type='cuda',
                             dtype=th.bfloat16,  # avoid NaN
                             enabled=self.mp_trainer_rec.use_amp):

                instance_mv_num = batch_size // 4  # 4 pairs by default

                # ! roll views for multi-view supervision
                c = th.cat([
                    micro['nv_c'].roll(instance_mv_num * i, dims=0)
                    for i in range(1, 4)
                ])  # predict novel view here

                ray_origins = th.cat([
                    target['ray_origins'].roll(instance_mv_num * i, dims=0)
                    for i in range(1, 4)
                ], 0)

                ray_directions = th.cat([
                    target['ray_directions'].roll(instance_mv_num * i, dims=0)
                    for i in range(1, 4)
                ])

                pred_nv_cano = self.rec_model(
                    latent={
                        'latent_after_vit':  # ! triplane for rendering
                        latent['latent_after_vit'].repeat(3, 1, 1, 1)
                    },
                    c=c,
                    behaviour='triplane_dec',
                    ray_origins=ray_origins,
                    ray_directions=ray_directions,
                )

                pred_nv_cano.update(
                    latent
                )  # torchvision.utils.save_image(pred_nv_cano['image_raw'], 'pred.png', normalize=True)

                gt = {
                    k:
                    th.cat([
                        v.roll(instance_mv_num * i, dims=0)
                        for i in range(1, 4)
                    ], 0)
                    for k, v in cropped_target.items()
                }  # torchvision.utils.save_image(gt['img'], 'gt.png', normalize=True)

                with self.rec_model.no_sync():  # type: ignore
                    loss, loss_dict, _ = self.loss_class(
                        pred_nv_cano,
                        gt,  # prepare merged data
                        step=self.step + self.resume_step,
                        test_mode=False,
                        return_fg_mask=True,
                        conf_sigma_l1=None,
                        conf_sigma_percl=None)
                    log_rec3d_loss_dict(loss_dict)

            self.mp_trainer_rec.backward(loss)

            if dist_util.get_rank() == 0 and self.step % 500 == 0:
                micro_bs = micro['img_to_encoder'].shape[0]
                self.log_patch_img(  # record one cano view and one novel view
                    cropped_target,
                    {
                        k: pred_nv_cano[k][-micro_bs:]
                        for k in ['image_raw', 'image_depth', 'image_mask']
                    },
                    {
                        k: pred_nv_cano[k][:micro_bs]
                        for k in ['image_raw', 'image_depth', 'image_mask']
                    },
                )

    def eval_loop(self):
        return super().eval_loop()

    def eval_novelview_loop_old(self, camera=None):
        # novel view synthesis given evaluation camera trajectory
        all_loss_dict = []
        novel_view_micro = {}

        # ! randomly inference an instance
        export_mesh = True
        if export_mesh:
            Path(f'{logger.get_dir()}/FID_Cals/').mkdir(parents=True,
                                                        exist_ok=True)

        batch = {}

        for eval_idx, render_reference in enumerate(tqdm(self.eval_data)):
            if eval_idx > 500:
                break

            video_out = imageio.get_writer(
                f'{logger.get_dir()}/video_novelview_{self.step+self.resume_step}_{eval_idx}.mp4',
                mode='I',
                fps=25,
                codec='libx264')

            with open(
                    f'{logger.get_dir()}/triplane_{self.step+self.resume_step}_{eval_idx}_caption.txt',
                    'w') as f:
                f.write(render_reference['caption'])

            for key in ['ins', 'bbox', 'caption']:
                if key in render_reference:
                    render_reference.pop(key)

            real_flag = False
            mv_flag = False  # TODO, use full instances for evaluation? Calculate the metrics.
            if render_reference['c'].shape[:2] == (1, 40):
                real_flag = True
                # real img monocular reconstruction
                # compat lst for enumerate
                render_reference = [{
                    k: v[0][idx:idx + 1]
                    for k, v in render_reference.items()
                } for idx in range(40)]

            elif render_reference['c'].shape[0] == 8:
                mv_flag = True

                render_reference = {
                    k: v[:4]
                    for k, v in render_reference.items()
                }

                # save gt
                torchvision.utils.save_image(
                    render_reference['img'][0:4],
                    logger.get_dir() + '/FID_Cals/{}_inp.png'.format(eval_idx),
                    padding=0,
                    normalize=True,
                    value_range=(-1, 1),
                )

            else:
                # compat lst for enumerate
                render_reference = [{
                    k: v[idx:idx + 1]
                    for k, v in render_reference.items()
                } for idx in range(40)]

                # ! single-view version
                render_reference[0]['img_to_encoder'] = render_reference[14][
                    'img_to_encoder']  # encode side view
                render_reference[0]['img'] = render_reference[14][
                    'img']  # encode side view

                # save gt
                torchvision.utils.save_image(
                    render_reference[0]['img'],
                    logger.get_dir() + '/FID_Cals/{}_gt.png'.format(eval_idx),
                    padding=0,
                    normalize=True,
                    value_range=(-1, 1))

            # ! TODO, merge with render_video_given_triplane later
            for i, batch in enumerate(render_reference):
                micro = {k: v.to(dist_util.dev()) for k, v in batch.items()}

                if i == 0:
                    if mv_flag:
                        novel_view_micro = None
                    else:
                        novel_view_micro = {
                            k:
                            v[0:1].to(dist_util.dev()).repeat_interleave(
                                micro['img'].shape[0],
                                0) if isinstance(v, th.Tensor) else v[0:1]
                            for k, v in batch.items()
                        }
                else:
                    if i == 1:
                        # ! output mesh
                        if export_mesh:
                            mesh_size = 384
                            mesh_thres = 10  # TODO, requires tuning
                            import mcubes
                            import trimesh
                            dump_path = f'{logger.get_dir()}/mesh/'
                            os.makedirs(dump_path, exist_ok=True)

                            grid_out = self.rec_model(
                                latent=pred,
                                grid_size=mesh_size,
                                behaviour='triplane_decode_grid',
                            )

                            vtx, faces = mcubes.marching_cubes(
                                grid_out['sigma'].squeeze(0).squeeze(
                                    -1).cpu().numpy(), mesh_thres)
                            vtx = vtx / (mesh_size - 1) * 2 - 1

                            # (optional) query vertex colors from the triplane
                            # for a colored mesh
                            mesh = trimesh.Trimesh(
                                vertices=vtx,
                                faces=faces,
                            )

                            mesh_dump_path = os.path.join(
                                dump_path, f'{eval_idx}.ply')
                            mesh.export(mesh_dump_path, 'ply')

                            print(f"Mesh dumped to {dump_path}")
                            del grid_out, mesh
                            th.cuda.empty_cache()

                    novel_view_micro = {
                        k:
                        v[0:1].to(dist_util.dev()).repeat_interleave(
                            micro['img'].shape[0], 0)
                        for k, v in novel_view_micro.items()
                    }

                pred = self.rec_model(img=novel_view_micro['img_to_encoder'],
                                      c=micro['c'])  # pred: (B, 3, 64, 64)

                if not real_flag:
                    _, loss_dict = self.loss_class(pred, micro, test_mode=True)
                    all_loss_dict.append(loss_dict)

                # ! move to other places, add tensorboard

                # normalize depth
                pred_depth = pred['image_depth']
                pred_depth = (pred_depth - pred_depth.min()) / (
                    pred_depth.max() - pred_depth.min())
                # hoisted so the mesh-export depth dump below always has it
                pooled_depth = self.pool_128(pred_depth).repeat_interleave(
                    3, dim=1)
                if 'image_sr' in pred:
                    if pred['image_sr'].shape[-1] == 512:
                        pred_vis = th.cat([
                            micro['img_sr'],
                            self.pool_512(pred['image_raw']), pred['image_sr'],
                            self.pool_512(pred_depth).repeat_interleave(3,
                                                                        dim=1)
                        ],
                                          dim=-1)
                    elif pred['image_sr'].shape[-1] == 256:
                        pred_vis = th.cat([
                            micro['img_sr'],
                            self.pool_256(pred['image_raw']), pred['image_sr'],
                            self.pool_256(pred_depth).repeat_interleave(3,
                                                                        dim=1)
                        ],
                                          dim=-1)
                    else:
                        pred_vis = th.cat([
                            micro['img_sr'],
                            self.pool_128(pred['image_raw']),
                            self.pool_128(pred['image_sr']),
                            self.pool_128(pred_depth).repeat_interleave(3,
                                                                        dim=1)
                        ],
                                          dim=-1)
                else:
                    pred_vis = th.cat(
                        [
                            self.pool_128(novel_view_micro['img']
                                          ),  # use the input here
                            self.pool_128(pred['image_raw']),
                            pooled_depth,
                        ],
                        dim=-1)  # B, 3, H, W

                vis = pred_vis.permute(0, 2, 3, 1).cpu().numpy()
                vis = vis * 127.5 + 127.5
                vis = vis.clip(0, 255).astype(np.uint8)

                if export_mesh:
                    # save image
                    torchvision.utils.save_image(
                        pred['image_raw'],
                        logger.get_dir() +
                        '/FID_Cals/{}_{}.png'.format(eval_idx, i),
                        padding=0,
                        normalize=True,
                        value_range=(-1, 1))

                    torchvision.utils.save_image(
                        pooled_depth,
                        logger.get_dir() +
                        '/FID_Cals/{}_{}_depth.png'.format(eval_idx, i),
                        padding=0,
                        normalize=True,
                        value_range=(0, 1))

                for j in range(vis.shape[0]):
                    video_out.append_data(vis[j])

            video_out.close()

            if not real_flag or mv_flag:
                val_scores_for_logging = calc_average_loss(all_loss_dict)
                with open(
                        os.path.join(logger.get_dir(),
                                     'scores_novelview.json'), 'a') as f:
                    json.dump({'step': self.step, **val_scores_for_logging}, f)

                # * log to tensorboard
                for k, v in val_scores_for_logging.items():
                    self.writer.add_scalar(f'Eval/NovelView/{k}', v,
                                           self.step + self.resume_step)

        del video_out
        th.cuda.empty_cache()

    def eval_novelview_loop(self, camera=None, save_latent=False):
        # novel view synthesis given evaluation camera trajectory
        if save_latent:  # for diffusion learning
            latent_dir = Path(f'{logger.get_dir()}/latent_dir')
            latent_dir.mkdir(exist_ok=True, parents=True)

        eval_batch_size = 40  # ! for i23d
        latent_rec_statistics = False

        for eval_idx, micro in enumerate(tqdm(self.eval_data)):

            latent = self.rec_model(
                img=micro['img_to_encoder'],
                behaviour='encoder_vae')  # pred: (B, 3, 64, 64)
            if micro['img'].shape[0] == 40:
                assert eval_batch_size == 40

            if save_latent:
                latent_save_dir = f'{logger.get_dir()}/latent_dir/{micro["ins"][0]}'
                Path(latent_save_dir).mkdir(parents=True, exist_ok=True)

                np.save(f'{latent_save_dir}/latent.npy',
                        latent[self.latent_name][0].cpu().numpy())

                assert all([
                    micro['ins'][0] == micro['ins'][i]
                    for i in range(micro['c'].shape[0])
                ])  # ! assert same instance

            if latent_rec_statistics:
                gen_imgs = self.render_video_given_triplane(
                    latent[self.latent_name],
                    self.rec_model,  # compatible with join_model
                    name_prefix=f'{self.step + self.resume_step}_{eval_idx}',
                    save_img=False,
                    render_reference={'c': micro['c']},
                    save_mesh=False,
                    render_reference_length=4,
                    return_gen_imgs=True)
                rec_psnr = psnr((micro['img'] / 2 + 0.5),
                                (gen_imgs.cpu() / 2 + 0.5), 1.0)
                # save to json
                with open(
                        os.path.join(logger.get_dir(),
                                     'four_view_rec_psnr.json'), 'a') as f:
                    json.dump(
                        {
                            f'{eval_idx}': {
                                'ins': micro["ins"][0],
                                'psnr': rec_psnr.item(),
                            }
                        }, f)

            elif eval_idx < 30:
                self.render_video_given_triplane(
                    latent[self.latent_name],
                    self.rec_model,  # compatible with join_model
                    name_prefix=
                    f'{self.step + self.resume_step}_{micro["ins"][0].split("/")[0]}_{eval_idx}',
                    save_img=False,
                    render_reference={'c': camera},
                    save_mesh=True)
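

# A minimal sketch of the `.roll`-based pairing used in forward_backward of
# the class above: the novel-view cameras are rolled by multiples of
# batch_size // 4 and concatenated, so the (3x repeated) latents are rendered
# under target views taken from elsewhere in the batch ordering. The exact
# latent/camera pairing depends on how the loader interleaves views; the
# integers below are stand-ins for camera tensors.
def _demo_roll_views():
    nv_c = th.arange(8)       # stand-in for micro['nv_c'] with batch_size = 8
    instance_mv_num = 8 // 4  # "4 pairs by default"
    c = th.cat([nv_c.roll(instance_mv_num * i, dims=0) for i in range(1, 4)])
    assert c.shape[0] == 3 * 8  # matches latent.repeat(3, 1, 1, 1)
    assert c[0] == nv_c[-instance_mv_num]  # first copy is shifted by 2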


class TrainLoop3DRecNVPatchSingleForwardMV(TrainLoop3DRecNVPatchSingleForward):

    def __init__(self,
                 *,
                 rec_model,
                 loss_class,
                 data,
                 eval_data,
                 batch_size,
                 microbatch,
                 lr,
                 ema_rate,
                 log_interval,
                 eval_interval,
                 save_interval,
                 resume_checkpoint,
                 use_fp16=False,
                 fp16_scale_growth=0.001,
                 weight_decay=0,
                 lr_anneal_steps=0,
                 iterations=10001,
                 load_submodule_name='',
                 ignore_resume_opt=False,
                 model_name='rec',
                 use_amp=False,
                 **kwargs):
        super().__init__(rec_model=rec_model,
                         loss_class=loss_class,
                         data=data,
                         eval_data=eval_data,
                         batch_size=batch_size,
                         microbatch=microbatch,
                         lr=lr,
                         ema_rate=ema_rate,
                         log_interval=log_interval,
                         eval_interval=eval_interval,
                         save_interval=save_interval,
                         resume_checkpoint=resume_checkpoint,
                         use_fp16=use_fp16,
                         fp16_scale_growth=fp16_scale_growth,
                         weight_decay=weight_decay,
                         lr_anneal_steps=lr_anneal_steps,
                         iterations=iterations,
                         load_submodule_name=load_submodule_name,
                         ignore_resume_opt=ignore_resume_opt,
                         model_name=model_name,
                         use_amp=use_amp,
                         **kwargs)

    def forward_backward(self, batch, behaviour='g_step', *args, **kwargs):
        # add patch sampling
        if behaviour == 'g_step':
            self.mp_trainer_rec.zero_grad()
        else:
            self.mp_trainer_disc.zero_grad()

        batch_size = batch['img_to_encoder'].shape[0]

        batch.pop('caption')  # not required
        batch.pop('ins')  # not required
        batch.pop('nv_caption')  # not required
        batch.pop('nv_ins')  # not required
        if '__key__' in batch.keys():
            batch.pop('__key__')

        for i in range(0, batch_size, self.microbatch):
            micro = {
                k:
                v[i:i + self.microbatch].to(dist_util.dev()) if isinstance(
                    v, th.Tensor) else v[i:i + self.microbatch]
                for k, v in batch.items()
            }

            # ! sample rendering patch
            nv_c = th.cat([micro['nv_c'], micro['c']])
            target = {
                **self.eg3d_model(
                    c=nv_c,  # type: ignore
                    ws=None,
                    planes=None,
                    sample_ray_only=True,
                    fg_bbox=th.cat([micro['nv_bbox'],
                                    micro['bbox']])),  # rays o / dir
            }

            patch_rendering_resolution = self.eg3d_model.rendering_kwargs[
                'patch_rendering_resolution']  # type: ignore
            cropped_target = {
                k:
                th.empty_like(v).repeat_interleave(2, 0)
                [..., :patch_rendering_resolution, :patch_rendering_resolution]
                if k not in [
                    'ins_idx', 'img_to_encoder', 'img_sr', 'nv_img_to_encoder',
                    'nv_img_sr', 'c', 'caption', 'nv_caption'
                ] else v
                for k, v in micro.items()
            }

            # crop according to uv sampling; first half novel views, second half cano views
            for j in range(2 * self.microbatch):
                top, left, height, width = target['ray_bboxes'][
                    j]  # list of tuple
                for key in ('img', 'depth_mask', 'depth'):  # type: ignore

                    if j < self.microbatch:
                        cropped_target[f'{key}'][  # ! no nv_ here
                            j:j + 1] = torchvision.transforms.functional.crop(
                                micro[f'nv_{key}'][j:j + 1], top, left, height,
                                width)
                    else:
                        cropped_target[f'{key}'][  # ! no nv_ here
                            j:j + 1] = torchvision.transforms.functional.crop(
                                micro[f'{key}'][j - self.microbatch:j -
                                                self.microbatch + 1], top,
                                left, height, width)

            # wrap forward within amp
            with th.autocast(device_type='cuda',
                             dtype=self.dtype,
                             enabled=self.mp_trainer_rec.use_amp):

                # ! vit no amp
                latent = self.rec_model(
                    img=micro['img_to_encoder'].to(self.dtype),
                    behaviour='enc_dec_wo_triplane')

                ray_origins = target['ray_origins']
                ray_directions = target['ray_directions']

                pred_nv_cano = self.rec_model(
                    latent={
                        'latent_after_vit':  # ! triplane for rendering
                        latent['latent_after_vit'].repeat_interleave(
                            6, dim=0).repeat(2, 1, 1, 1)  # NV=6
                    },
                    c=nv_c,
                    behaviour='triplane_dec',
                    ray_origins=ray_origins,
                    ray_directions=ray_directions,
                )

                pred_nv_cano.update(
                    latent
                )  # torchvision.utils.save_image(pred_nv_cano['image_raw'], 'pred.png', normalize=True)
                gt = cropped_target

                with self.rec_model.no_sync():  # type: ignore
                    loss, loss_dict, _ = self.loss_class(
                        pred_nv_cano,
                        gt,  # prepare merged data
                        step=self.step + self.resume_step,
                        test_mode=False,
                        return_fg_mask=True,
                        behaviour=behaviour,
                        conf_sigma_l1=None,
                        conf_sigma_percl=None,
                    )
                    log_rec3d_loss_dict(loss_dict)

            if behaviour == 'g_step':
                self.mp_trainer_rec.backward(loss)
            else:
                self.mp_trainer_disc.backward(loss)

            if dist_util.get_rank() == 0 and self.step % 500 == 0 and i == 0 and behaviour == 'g_step':
                try:
                    torchvision.utils.save_image(
                        th.cat([
                            cropped_target['img'], pred_nv_cano['image_raw']
                        ]),
                        f'{logger.get_dir()}/{self.step+self.resume_step}.jpg',
                        normalize=True,
                        nrow=6 * 2)
                    logger.log(
                        'log vis to: ',
                        f'{logger.get_dir()}/{self.step+self.resume_step}.jpg')
                except Exception as e:
                    logger.log(e)


class TrainLoop3DRecNVPatchSingleForwardMVAdvLoss(
        TrainLoop3DRecNVPatchSingleForwardMV):

    def __init__(self,
                 *,
                 rec_model,
                 loss_class,
                 data,
                 eval_data,
                 batch_size,
                 microbatch,
                 lr,
                 ema_rate,
                 log_interval,
                 eval_interval,
                 save_interval,
                 resume_checkpoint,
                 use_fp16=False,
                 fp16_scale_growth=0.001,
                 weight_decay=0,
                 lr_anneal_steps=0,
                 iterations=10001,
                 load_submodule_name='',
                 ignore_resume_opt=False,
                 model_name='rec',
                 use_amp=False,
                 **kwargs):
        super().__init__(rec_model=rec_model,
                         loss_class=loss_class,
                         data=data,
                         eval_data=eval_data,
                         batch_size=batch_size,
                         microbatch=microbatch,
                         lr=lr,
                         ema_rate=ema_rate,
                         log_interval=log_interval,
                         eval_interval=eval_interval,
                         save_interval=save_interval,
                         resume_checkpoint=resume_checkpoint,
                         use_fp16=use_fp16,
                         fp16_scale_growth=fp16_scale_growth,
                         weight_decay=weight_decay,
                         lr_anneal_steps=lr_anneal_steps,
                         iterations=iterations,
                         load_submodule_name=load_submodule_name,
                         ignore_resume_opt=ignore_resume_opt,
                         model_name=model_name,
                         use_amp=use_amp,
                         **kwargs)

        # create discriminator
        disc_params = self.loss_class.get_trainable_parameters()
        self.mp_trainer_disc = MixedPrecisionTrainer(
            model=self.loss_class.discriminator,
            use_fp16=self.use_fp16,
            fp16_scale_growth=fp16_scale_growth,
            model_name='disc',
            use_amp=use_amp,
            model_params=disc_params)

        self.opt_disc = AdamW(
            self.mp_trainer_disc.master_params,
            lr=self.lr,  # follow sd code base
            betas=(0, 0.999),
            eps=1e-8)

        # TODO, is the loss class already wrapped in DDP?
        if self.use_ddp:
            self.ddp_disc = DDP(
                self.loss_class.discriminator,
                device_ids=[dist_util.dev()],
                output_device=dist_util.dev(),
                broadcast_buffers=False,
                bucket_cap_mb=128,
                find_unused_parameters=False,
            )
        else:
            self.ddp_disc = self.loss_class.discriminator

    def save(self, mp_trainer=None, model_name='rec'):
        if mp_trainer is None:
            mp_trainer = self.mp_trainer_rec

        def save_checkpoint(rate, params):
            state_dict = mp_trainer.master_params_to_state_dict(params)
            if dist_util.get_rank() == 0:
                logger.log(f"saving model {model_name} {rate}...")
                if not rate:
                    filename = f"model_{model_name}{(self.step+self.resume_step):07d}.pt"
                else:
                    filename = f"ema_{model_name}_{rate}_{(self.step+self.resume_step):07d}.pt"
                with bf.BlobFile(bf.join(get_blob_logdir(), filename),
                                 "wb") as f:
                    th.save(state_dict, f)

        save_checkpoint(0, mp_trainer.master_params)

        dist.barrier()

    # ! load disc
    def _load_and_sync_parameters(self, submodule_name=''):
        super()._load_and_sync_parameters(submodule_name)
        # load disc
        resume_checkpoint = self.resume_checkpoint.replace(
            'rec', 'disc')  # * default behaviour
        if os.path.exists(resume_checkpoint):
            if dist_util.get_rank() == 0:
                logger.log(
                    f"loading disc model from checkpoint: {resume_checkpoint}..."
                )
                map_location = {
                    'cuda:%d' % 0: 'cuda:%d' % dist_util.get_rank()
                }  # configure map_location properly

                resume_state_dict = dist_util.load_state_dict(
                    resume_checkpoint, map_location=map_location)

                model_state_dict = self.loss_class.discriminator.state_dict()

                for k, v in resume_state_dict.items():
                    if k in model_state_dict.keys():
                        if v.size() == model_state_dict[k].size():
                            model_state_dict[k] = v
                        else:
                            logger.log('!!!! partially load: ', k, ": ",
                                       v.size(), "state_dict: ",
                                       model_state_dict[k].size())

                # actually apply the merged weights; updating the state_dict
                # copy alone does not touch the module parameters
                self.loss_class.discriminator.load_state_dict(
                    model_state_dict)

            if dist_util.get_world_size() > 1:
                dist_util.sync_params(
                    self.loss_class.get_trainable_parameters())
                logger.log('synced disc params')

    def run_step(self, batch, step='g_step'):

        if step == 'g_step':
            self.forward_backward(batch, behaviour='g_step')
            took_step_g_rec = self.mp_trainer_rec.optimize(self.opt)

            if took_step_g_rec:
                self._update_ema()  # g_ema

        elif step == 'd_step':
            self.forward_backward(batch, behaviour='d_step')
            _ = self.mp_trainer_disc.optimize(self.opt_disc)

        self._anneal_lr()
        self.log_step()

    def run_loop(self, batch=None):
        while (not self.lr_anneal_steps
               or self.step + self.resume_step < self.lr_anneal_steps):

            batch = next(self.data)
            self.run_step(batch, 'g_step')

            batch = next(self.data)
            self.run_step(batch, 'd_step')

            if self.step % 1000 == 0:
                dist_util.synchronize()
                if self.step % 5000 == 0:
                    th.cuda.empty_cache()  # avoid memory leak

            if self.step % self.log_interval == 0 and dist_util.get_rank(
            ) == 0:
                out = logger.dumpkvs()
                # * log to tensorboard
                for k, v in out.items():
                    self.writer.add_scalar(f'Loss/{k}', v,
                                           self.step + self.resume_step)

            if self.step % self.eval_interval == 0 and self.step != 0:
                if dist_util.get_rank() == 0:
                    try:
                        self.eval_loop()
                    except Exception as e:
                        logger.log(e)
                dist_util.synchronize()

            if self.step % self.save_interval == 0:
                self.save()
                self.save(self.mp_trainer_disc,
                          self.mp_trainer_disc.model_name)
                dist_util.synchronize()

            # Run for a finite amount of time in integration tests.
            if os.environ.get("DIFFUSION_TRAINING_TEST",
                              "") and self.step > 0:
                return

            self.step += 1

            if self.step > self.iterations:
                logger.log('reached maximum iterations, exiting')

                # Save the last checkpoint if it wasn't already saved.
                if (self.step -
                        1) % self.save_interval != 0 and self.step != 1:
                    self.save()

                exit()

        # Save the last checkpoint if it wasn't already saved.
        if (self.step - 1) % self.save_interval != 0:
            self.save()  # save rec
            self.save(self.mp_trainer_disc, self.mp_trainer_disc.model_name)
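

# A minimal sketch of the alternating schedule run_loop implements above: one
# generator (reconstruction) step on a fresh batch, then one discriminator
# step on the next batch, each driving its own optimizer. Everything here is
# a toy stand-in for the real trainer, under the assumption that the loss
# callables close over their respective models.
def _demo_gan_schedule(data_iter, g_opt, d_opt, g_loss_fn, d_loss_fn, steps=2):
    for _ in range(steps):
        g_opt.zero_grad()
        g_loss_fn(next(data_iter)).backward()  # 'g_step'
        g_opt.step()

        d_opt.zero_grad()
        d_loss_fn(next(data_iter)).backward()  # 'd_step'
        d_opt.step()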


class TrainLoop3DRecNVPatchSingleForwardMV_NoCrop(
        TrainLoop3DRecNVPatchSingleForwardMV):

    def __init__(self,
                 *,
                 rec_model,
                 loss_class,
                 data,
                 eval_data,
                 batch_size,
                 microbatch,
                 lr,
                 ema_rate,
                 log_interval,
                 eval_interval,
                 save_interval,
                 resume_checkpoint,
                 use_fp16=False,
                 fp16_scale_growth=0.001,
                 weight_decay=0,
                 lr_anneal_steps=0,
                 iterations=10001,
                 load_submodule_name='',
                 ignore_resume_opt=False,
                 model_name='rec',
                 use_amp=False,
                 num_frames=4,
                 **kwargs):
        super().__init__(rec_model=rec_model,
                         loss_class=loss_class,
                         data=data,
                         eval_data=eval_data,
                         batch_size=batch_size,
                         microbatch=microbatch,
                         lr=lr,
                         ema_rate=ema_rate,
                         log_interval=log_interval,
                         eval_interval=eval_interval,
                         save_interval=save_interval,
                         resume_checkpoint=resume_checkpoint,
                         use_fp16=use_fp16,
                         fp16_scale_growth=fp16_scale_growth,
                         weight_decay=weight_decay,
                         lr_anneal_steps=lr_anneal_steps,
                         iterations=iterations,
                         load_submodule_name=load_submodule_name,
                         ignore_resume_opt=ignore_resume_opt,
                         model_name=model_name,
                         use_amp=use_amp,
                         **kwargs)
        self.num_frames = num_frames
        self.ray_sampler = RaySampler()
        print(self.opt)

        # ! requires tuning
        N = 768  # hyper-param, overfitting now
        # self.scale_expected_threshold = (1 / (N / 2)) ** 0.5 * 0.45
        self.scale_expected_threshold = 0.0075

        self.latent_name = 'latent_normalized'  # normalized triplane latent

        # to transform to 3dgs
        self.gs_bg_color = th.tensor([1, 1, 1],
                                     dtype=th.float32,
                                     device=dist_util.dev())
        self.post_process = PostProcess(
            384,
            384,
            imgnet_normalize=True,
            plucker_embedding=True,
            decode_encode_img_only=False,
            mv_input=True,
            split_chunk_input=16,
            duplicate_sample=True,
            append_depth=False,
            append_xyz=False,
            gs_cam_format=True,
            orthog_duplicate=False,
            frame_0_as_canonical=False,
            pcd_path='pcd_path',
            load_pcd=True,
            split_chunk_size=16,
        )
        self.zfar = 100.0
        self.znear = 0.01

    # def _init_optim_groups(self, kwargs):
    #     return super()._init_optim_groups({**kwargs, 'ignore_encoder': True})  # freeze MVEncoder to accelerate training.
def forward_backward(self, batch, behaviour='g_step', *args, **kwargs): | |
# add patch sampling | |
self.mp_trainer_rec.zero_grad() | |
batch_size = batch['img_to_encoder'].shape[0] | |
batch.pop('caption') # not required | |
ins = batch.pop('ins') # not required | |
if '__key__' in batch.keys(): | |
batch.pop('__key__') | |
assert isinstance(batch['c'], dict) | |
for i in range(0, batch_size, self.microbatch): | |
micro = {} | |
for k, v in batch.items(): # grad acc | |
if isinstance(v, th.Tensor): | |
micro[k] = v[i:i + self.microbatch].to(dist_util.dev()) | |
elif isinstance(v, list): | |
micro[k] = v[i:i + self.microbatch] | |
elif isinstance(v, dict): # | |
assert k in ['c', 'nv_c'] | |
micro[k] = { | |
key: | |
value[i:i + self.microbatch].to(dist_util.dev()) if | |
isinstance(value, th.Tensor) else value # can be float | |
for key, value in v.items() | |
} | |
assert micro['img_to_encoder'].shape[1] == 15 | |
micro['normal'] = micro['img_to_encoder'][:, 3:6] | |
micro['nv_normal'] = micro['nv_img_to_encoder'][:, 3:6] | |
# ! concat nv_c to render N+N views | |
indices = np.random.permutation(self.num_frames) | |
indices, indices_nv = indices[:4], indices[-4:] # make sure thorough pose converage. | |
# indices, indices_nv = indices[:2], indices[-2:] # ! 2+2 views for supervision, as in gs-lrm. | |
# indices_nv = np.random.permutation(self.num_frames)[:6] # randomly pick 4+4 views for supervision. | |
# indices = np.arange(self.num_frames) | |
# indices_nv = np.arange(self.num_frames) | |
nv_c = {} | |
for key in micro['c'].keys(): | |
if isinstance(micro['c'][key], th.Tensor): | |
nv_c[key] = th.cat([micro['c'][key][:, indices], micro['nv_c'][key][:, indices_nv]], | |
1) # B 2V ... | |
else: | |
nv_c[key] = micro['c'][key] # float, will remove later | |
target = {} | |
for key in ('img', 'depth_mask', 'depth', 'normal',): # type: ignore | |
# st() | |
target[key] = th.cat([ | |
rearrange(micro[key], '(B V) ... -> B V ...', V=self.num_frames)[:, indices], | |
rearrange(micro[f'nv_{key}'], '(B V) ... -> B V ...', V=self.num_frames)[:, indices_nv] | |
# rearrange(micro[key][:, indices], '(B V) ... -> B V ...', V=4), | |
# rearrange(micro[f'nv_{key}'][:, indices], '(B V) ... -> B V ...', V=4) | |
], 1) # B 2*V H W | |
target[key] = rearrange(target[key], | |
'B V ... -> (B V) ...') # concat | |
# st() | |
# wrap forward within amp | |
with th.autocast(device_type='cuda', | |
dtype=self.dtype, | |
enabled=self.mp_trainer_rec.use_amp): | |
# ! vit no amp | |
# with profile(activities=[ | |
# ProfilerActivity.CUDA], record_shapes=True) as prof: | |
# with record_function("get_gs"): | |
latent = self.rec_model( | |
img=micro['img_to_encoder'].to(self.dtype), | |
behaviour='enc_dec_wo_triplane', | |
c=micro['c'], | |
pcd=micro['fps_pcd'], # send in pcd for surface reference. | |
) # send in input-view C since pixel-aligned gaussians required | |
# print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10)) | |
gaussians, query_pcd_xyz = latent['gaussians'], latent['query_pcd_xyz'] | |
# query_pcd_xyz = latent['query_pcd_xyz'] | |
# if self.loss_class.opt.rand_aug_bg and random.random()>0.9: | |
if self.loss_class.opt.rand_aug_bg: | |
bg_color=torch.randint(0,255,(3,), device=dist_util.dev()) / 255.0 | |
else: | |
bg_color=torch.tensor([1,1,1], dtype=torch.float32, device=dist_util.dev()) | |
def visualize_latent_activations(latent, b_idx=0, write=False, ): | |
def normalize_latent_plane(latent_plane): | |
avg_p1 = latent_plane.detach().cpu().numpy().mean(0, keepdims=0) | |
avg_p1 = (avg_p1 - avg_p1.min()) / (avg_p1.max() - avg_p1.min()) | |
# return avg_p1 | |
return ((avg_p1).clip(0,1)*255.0).astype(np.uint8) | |
p1, p2, p3 = (normalize_latent_plane(latent_plane) for latent_plane in (latent[b_idx, 0:4], latent[b_idx,4:8], latent[b_idx,8:12])) | |
if write: | |
plt.imsave(os.path.join(logger.get_dir(), f'{self.step}_{b_idx}_1.jpg'), p1) | |
plt.imsave(os.path.join(logger.get_dir(), f'{self.step}_{b_idx}_2.jpg'), p2) | |
plt.imsave(os.path.join(logger.get_dir(), f'{self.step}_{b_idx}_3.jpg'), p3) | |
# imageio.imwrite(os.path.join(logger.get_dir(), f'{b_idx}_1.jpg'), p1) | |
# imageio.imwrite(os.path.join(logger.get_dir(), f'{b_idx}_2.jpg'), p2) | |
# imageio.imwrite(os.path.join(logger.get_dir(), f'{b_idx}_3.jpg'), p3) | |
return p1, p2, p3 | |
# with profile(activities=[ProfilerActivity.CUDA, ProfilerActivity.CPU,], record_shapes=True) as prof: | |
# # ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof: | |
# with record_function("rendering"): | |
pred_nv_cano = self.rec_model( | |
latent=latent, | |
# latent={ | |
# 'gaussians': latent['gaussians'].repeat_interleave(2,0) | |
# }, | |
c=nv_c, | |
behaviour='triplane_dec', | |
bg_color=bg_color, | |
) | |
fine_scale_key = list(pred_nv_cano.keys())[-1] | |
# st() | |
fine_gaussians = latent[fine_scale_key] | |
fine_gaussians_opa = fine_gaussians[..., 3:4] | |
# print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20)) | |
# st() # torchvision.utils.save_image(pred_nv_cano['image_raw'][0], 'pred.jpg', normalize=True, value_range=(-1,1)) | |
if self.loss_class.opt.rand_aug_bg: | |
# bg_color | |
alpha_mask = target['depth_mask'].float().unsqueeze(1) # B 1 H W | |
target['img'] = target['img'] * alpha_mask + (bg_color.reshape(1,3,1,1) * 2 - 1) * (1-alpha_mask) | |
target['depth_mask'] = target['depth_mask'].unsqueeze(1) | |
target['depth'] = target['depth'].unsqueeze(1) | |
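# Compositing sketch: images live in [-1, 1] while bg_color lives in [0, 1], | |
# hence the `bg_color * 2 - 1` remap before alpha blending above, e.g. | |
#   >>> img = th.zeros(1, 3, 2, 2)                          # mid-gray foreground | |
#   >>> alpha = th.ones(1, 1, 2, 2); alpha[..., 0, 0] = 0.  # one background pixel | |
#   >>> bg = th.tensor([1., 1., 1.])                        # white background | |
#   >>> out = img * alpha + (bg.reshape(1, 3, 1, 1) * 2 - 1) * (1 - alpha) | |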
multiscale_target = defaultdict(dict) | |
multiscale_pred = defaultdict(dict) | |
for idx, (gaussian_wavelet_key, gaussian_wavelet) in enumerate(pred_nv_cano.items()): | |
gs_output_size = pred_nv_cano[gaussian_wavelet_key]['image_raw'].shape[-1] | |
for k in gaussian_wavelet.keys(): | |
pred_nv_cano[gaussian_wavelet_key][k] = rearrange( | |
gaussian_wavelet[k], 'B V ... -> (B V) ...') # match GT shape order | |
# if idx == 0: # only KL calculation in scale 0 | |
if gaussian_wavelet_key == fine_scale_key: | |
pred_nv_cano[gaussian_wavelet_key].update( | |
{ | |
k: latent[k] for k in ['posterior'] | |
} | |
) # ! for KL supervision | |
# ! prepare target according to the wavelet size | |
for k in target.keys(): | |
if target[k].shape[-1] == gs_output_size: # index with the inner loop variable k, not the stale outer key | |
multiscale_target[gaussian_wavelet_key][k] = target[k] | |
else: | |
if k in ('depth', 'normal'): | |
mode = 'nearest' | |
else: | |
mode='bilinear' | |
multiscale_target[gaussian_wavelet_key][k] = F.interpolate(target[k], size=(gs_output_size, gs_output_size), mode=mode) | |
# st() | |
# st() | |
# torchvision.utils.save_image(target['img'], 'gt.jpg', normalize=True, value_range=(-1,1)) | |
# torchvision.utils.save_image(pred_nv_cano['image_raw'], 'pred.jpg', normalize=True, value_range=(-1,1)) | |
# torchvision.utils.save_image(micro['img'], 'inp_gt.jpg', normalize=True, value_range=(-1,1)) | |
# torchvision.utils.save_image(micro['nv_img'], 'nv_gt.jpg', normalize=True, value_range=(-1,1)) | |
# if self.loss_class.opt.rand_aug_bg: | |
# # bg_color | |
# alpha_mask = target['depth_mask'].float().unsqueeze(1) # B 1 H W | |
# target['img'] = target['img'] * alpha_mask + (bg_color.reshape(1,3,1,1) * 2 - 1) * (1-alpha_mask) | |
lod_num = len(pred_nv_cano.keys()) | |
random_scale_for_lpips = random.choice(list(pred_nv_cano.keys())) | |
with self.rec_model.no_sync(): # type: ignore | |
# with profile(activities=[ | |
# ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof: | |
# with record_function("loss"): | |
loss = th.tensor(0., device=dist_util.dev()) | |
loss_dict = {} | |
if behaviour == 'd_step': | |
loss_scale, loss_dict_scale, _ = self.loss_class( | |
pred_nv_cano[random_scale_for_lpips], | |
multiscale_target[random_scale_for_lpips], # prepare merged data | |
step=self.step + self.resume_step, | |
test_mode=False, | |
return_fg_mask=True, | |
behaviour=behaviour, | |
conf_sigma_l1=None, | |
conf_sigma_percl=None, | |
ignore_kl=True, # only calculate once | |
ignore_lpips=True, # lpips on each lod | |
ignore_d_loss=False) | |
loss = loss + loss_scale | |
loss_dict.update( | |
{ | |
f"{gaussian_wavelet_key.replace('gaussians_', '')}/{loss_key}": loss_v for loss_key, loss_v in loss_dict_scale.items() | |
} | |
) | |
else: | |
for scale_idx, gaussian_wavelet_key in enumerate(pred_nv_cano.keys()): # ! multi-scale gs rendering supervision | |
loss_scale, loss_dict_scale, _ = self.loss_class( | |
pred_nv_cano[gaussian_wavelet_key], | |
multiscale_target[gaussian_wavelet_key], # prepare merged data | |
step=self.step + self.resume_step, | |
test_mode=False, | |
return_fg_mask=True, | |
behaviour=behaviour, | |
conf_sigma_l1=None, | |
conf_sigma_percl=None, | |
ignore_kl=gaussian_wavelet_key!=fine_scale_key, # only calculate once | |
ignore_lpips=gaussian_wavelet_key!=random_scale_for_lpips, # lpips on each lod | |
ignore_d_loss=gaussian_wavelet_key!=fine_scale_key) | |
loss = loss + loss_scale | |
loss_dict.update( | |
{ | |
f"{gaussian_wavelet_key.replace('gaussians_', '')}/{loss_key}": loss_v for loss_key, loss_v in loss_dict_scale.items() | |
} | |
) | |
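# Loss scheduling across LoDs, as wired above: reconstruction terms run on | |
# every scale, KL and the adversarial terms only on the finest scale, and | |
# LPIPS only on one randomly chosen scale per step to bound its cost. | |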
pos = latent['pos'] | |
opacity = gaussians[..., 3:4] | |
scaling = gaussians[..., 4:6] # 2dgs here | |
if self.step % self.log_interval == 0 and dist_util.get_rank( | |
) == 0: | |
with th.no_grad(): # save idx 0 here | |
try: | |
self.writer.add_histogram("scene/opacity_hist", | |
opacity[0][:], | |
self.step + self.resume_step) | |
self.writer.add_histogram("scene/scale_hist", | |
scaling[0][:], | |
self.step + self.resume_step) | |
except Exception as e: | |
logger.log(e) | |
if behaviour == 'g_step': | |
# ! 2dgs loss | |
# debugging now, open it from the beginning | |
# if (self.step + self.resume_step) >= 2000 and self.loss_class.opt.lambda_normal > 0: | |
surf_normal = multiscale_target[fine_scale_key]['normal'] * multiscale_target[fine_scale_key]['depth_mask'] # foreground supervision only. | |
# ! hard-coded | |
# rend_normal = pred_nv_cano['rend_normal'] # ! supervise disk normal with GT normal here instead; | |
# st() | |
rend_normal = pred_nv_cano[fine_scale_key]['rend_normal'] | |
rend_dist = pred_nv_cano[fine_scale_key]['dist'] | |
if self.loss_class.opt.lambda_scale_reg > 0: | |
scale_reg = (scaling-self.scale_expected_threshold).square().mean() * self.loss_class.opt.lambda_scale_reg | |
loss = loss + scale_reg | |
loss_dict.update({'loss_scale_reg': scale_reg}) | |
if self.loss_class.opt.lambda_opa_reg > 0: | |
# small_base_opa = latent['gaussians_base_opa'] | |
opa_reg = (-self.loss_class.beta_mvp_base_dist.log_prob(latent['gaussians_base_opa'].clamp(min=1/255, max=0.99)).mean()) * self.loss_class.opt.lambda_opa_reg | |
# ! also on the fine stage | |
opa_reg_fine = (-self.loss_class.beta_mvp_base_dist.log_prob(fine_gaussians_opa.clamp(min=1/255, max=0.99)).mean()) * self.loss_class.opt.lambda_opa_reg | |
# opa_reg = (1-latent['gaussians_base_opa'].mean() ) * self.loss_class.opt.lambda_opa_reg | |
loss = loss + opa_reg + opa_reg_fine | |
loss_dict.update({'loss_opa_reg': opa_reg, 'loss_opa_reg_fine': opa_reg_fine}) | |
if (self.step + self.resume_step) >= 35000 and self.loss_class.opt.lambda_normal > 0: | |
# if (self.step + self.resume_step) >= 2000 and self.loss_class.opt.lambda_normal > 0: | |
# surf_normal = unity2blender_th(surf_normal) # ! g-buffer normal system is different | |
normal_error = (1 - (rend_normal * surf_normal).sum(dim=1)) # B H W | |
# normal_loss = self.loss_class.opt.lambda_normal * (normal_error.sum() / target['depth_mask'].sum()) # average with fg area ratio | |
normal_loss = self.loss_class.opt.lambda_normal * normal_error.mean() | |
loss = loss + normal_loss | |
loss_dict.update({'loss_normal': normal_loss}) | |
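# Normal-consistency sketch (2DGS-style): the per-pixel error is | |
# 1 - <n_rend, n_surf>, i.e. 0 for aligned unit normals and 2 for opposed ones: | |
#   >>> n = F.normalize(th.randn(1, 3, 4, 4), dim=1) | |
#   >>> (1 - (n * n).sum(dim=1)).abs().max() < 1e-5         # self-agreement -> ~0 | |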
# if (self.step + self.resume_step) >= 1500 and self.loss_class.opt.lambda_dist > 0: | |
if (self.step + self.resume_step) >= 15000 and self.loss_class.opt.lambda_dist > 0: | |
# if (self.step + self.resume_step) >= 300 and self.loss_class.opt.lambda_dist > 0: | |
dist_loss = self.loss_class.opt.lambda_dist * (rend_dist).mean() | |
loss = loss + dist_loss | |
loss_dict.update({'loss_dist': dist_loss}) | |
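# The distortion term follows the 2DGS regularizer: `rend_dist` accumulates the | |
# pairwise depth spread of splats along each ray, so minimizing its mean | |
# concentrates ray weights onto a thin surface shell. | |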
if self.loss_class.opt.pruning_ot_lambda > 0: | |
# for now, save and analyze first | |
# selected_pts_mask_scaling = th.where(th.max(scaling, dim=-1).values < 0.01 * 0.9, True, False) | |
selected_pts_mask_scaling = th.where( | |
th.max(scaling, dim=-1).values > 0.05 * 0.9, True, | |
False) | |
# selected_pts_mask_opacity = th.where(opacity[..., 0] < 0.1, True, False) # B N | |
selected_pts_mask_opacity = th.where( | |
opacity[..., 0] < 0.01, True, | |
False) # 0.005 in the original 3dgs setting | |
selected_scaling_pts = pos[0][selected_pts_mask_scaling[0]] | |
selected_opacity_pts = pos[0][selected_pts_mask_opacity[0]] | |
pcu.save_mesh_v( | |
'tmp/voxel/cd/10/scaling_masked_pts_0.05.ply', | |
selected_scaling_pts.detach().cpu().numpy(), | |
) | |
pcu.save_mesh_v( | |
'tmp/voxel/cd/10/opacity_masked_pts_0.01.ply', | |
selected_opacity_pts.detach().cpu().numpy(), | |
) | |
# st() | |
# pass | |
if self.loss_class.opt.cd_lambda > 0: | |
# fuse depth to 3D point cloud to supervise the gaussians | |
B = latent['pos'].shape[0] | |
# c = micro['c'] | |
# H = micro['depth'].shape[-1] | |
# V = 4 | |
# # ! prepare 3D xyz ground truth | |
# cam2world_matrix = c['orig_c2w'][:, :, :16].reshape( | |
# B * V, 4, 4) | |
# intrinsics = c['orig_pose'][:, :, | |
# 16:25].reshape(B * V, 3, 3) | |
# # ! already in the world space after ray_sampler() | |
# ray_origins, ray_directions = self.ray_sampler( # shape: | |
# cam2world_matrix, intrinsics, H // 2)[:2] | |
# # depth = th.nn.functional.interpolate(micro['depth'].unsqueeze(1), (128,128), mode='nearest')[:, 0] # since each view has 128x128 Gaussians | |
# # depth = th.nn.functional.interpolate(micro['depth'].unsqueeze(1), (128,128), mode='nearest')[:, 0] # since each view has 128x128 Gaussians | |
# depth_128 = th.nn.functional.interpolate( | |
# micro['depth'].unsqueeze(1), (128, 128), | |
# mode='nearest' | |
# )[:, 0] # since each view has 128x128 Gaussians | |
# depth = depth_128.reshape(B * V, -1).unsqueeze(-1) | |
# # depth = micro['depth'].reshape(B*V, -1).unsqueeze(-1) | |
# gt_pos = ray_origins + depth * ray_directions # BV HW 3, already in the world space | |
# gt_pos = rearrange(gt_pos, | |
# '(B V) N C -> B (V N) C', | |
# B=B, | |
# V=V) | |
# gt_pos = gt_pos.clip(-0.45, 0.45) | |
# TODO | |
gt_pos = micro[ | |
'fps_pcd'] # all the same, will update later. | |
# ! use online here | |
# gt_pos = query_pcd_xyz | |
cd_loss = pytorch3d.loss.chamfer_distance( | |
gt_pos, latent['pos'] | |
)[0] * self.loss_class.opt.cd_lambda # V=4 GT for now. Test with V=8 GT later. | |
# st() | |
# for vis | |
if False: | |
torchvision.utils.save_image(micro['img'], | |
'gt.jpg', | |
value_range=(-1, 1), | |
normalize=True) | |
with th.no_grad(): | |
for b in range(B): | |
pcu.save_mesh_v( | |
f'tmp/voxel/cd/10/again_pred-{b}.ply', | |
latent['pos'][b].detach().cpu().numpy(), | |
) | |
# pcu.save_mesh_v( | |
# f'tmp/voxel/cd/10/again-gt-{b}.ply', | |
# gt_pos[b].detach().cpu().numpy(), | |
# ) | |
# st() | |
loss = loss + cd_loss | |
loss_dict.update({'loss_cd': cd_loss}) | |
elif self.loss_class.opt.xyz_lambda > 0: | |
''' | |
B = latent['per_view_pos'].shape[0] // 4 | |
V = 4 | |
c = micro['c'] | |
H = micro['depth'].shape[-1] | |
# ! prepare 3D xyz ground truth | |
cam2world_matrix = c['orig_c2w'][:, :, :16].reshape( | |
B * V, 4, 4) | |
intrinsics = c['orig_pose'][:, :, | |
16:25].reshape(B * V, 3, 3) | |
# ! already in the world space after ray_sampler() | |
ray_origins, ray_directions = self.ray_sampler( # shape: | |
cam2world_matrix, intrinsics, H // 2)[:2] | |
# self.gs.output_size,)[:2] | |
# depth = rearrange(micro['depth'], '(B V) H W -> ') | |
depth_128 = th.nn.functional.interpolate( | |
micro['depth'].unsqueeze(1), (128, 128), | |
mode='nearest' | |
)[:, 0] # since each view has 128x128 Gaussians | |
depth = depth_128.reshape(B * V, -1).unsqueeze(-1) | |
fg_mask = th.nn.functional.interpolate( | |
micro['depth_mask'].unsqueeze(1).to(th.uint8), | |
(128, 128), | |
mode='nearest').squeeze(1) # B*V H W | |
fg_mask = fg_mask.reshape(B * V, -1).unsqueeze(-1) | |
gt_pos = ray_origins + depth * ray_directions # BV HW 3, already in the world space | |
# st() | |
gt_pos = fg_mask * gt_pos.clip( | |
-0.45, 0.45) # g-buffer objaverse range | |
pred = fg_mask * latent['per_view_pos'] | |
# for vis | |
if True: | |
torchvision.utils.save_image(micro['img'], | |
'gt.jpg', | |
value_range=(-1, 1), | |
normalize=True) | |
with th.no_grad(): | |
gt_pos_vis = rearrange(gt_pos, | |
'(B V) N C -> B V N C', | |
B=B, | |
V=V) | |
pred_pos_vis = rearrange(pred, | |
'(B V) N C -> B V N C', | |
B=B, | |
V=V) | |
# save | |
for b in range(B): | |
for v in range(V): | |
# pcu.save_mesh_v(f'tmp/dust3r/add3dsupp-pred-{b}-{v}.ply', | |
# pred_pos_vis[b][v].detach().cpu().numpy(),) | |
# pcu.save_mesh_v(f'tmp/dust3r/add3dsupp-gt-{b}-{v}.ply', | |
# gt_pos_vis[b][v].detach().cpu().numpy(),) | |
pcu.save_mesh_v( | |
f'tmp/lambda50/no3dsupp-pred-{b}-{v}.ply', | |
pred_pos_vis[b] | |
[v].detach().cpu().numpy(), | |
) | |
pcu.save_mesh_v( | |
f'tmp/lambda50/no3dsupp-gt-{b}-{v}.ply', | |
gt_pos_vis[b] | |
[v].detach().cpu().numpy(), | |
) | |
st() | |
xyz_loss = th.nn.functional.mse_loss( | |
gt_pos, pred | |
) * self.loss_class.opt.xyz_lambda # ! 15% nonzero points | |
loss = loss + xyz_loss | |
''' | |
# ! directly gs center supervision with l1 loss, follow LION VAE | |
# xyz_loss = th.nn.functional.l1_loss( | |
# query_pcd_xyz, pred | |
# ) * self.loss_class.opt.xyz_lambda # ! 15% nonzero points | |
xyz_loss = self.loss_class.criterion_xyz(query_pcd_xyz, latent['pos']) * self.loss_class.opt.xyz_lambda | |
loss = loss + xyz_loss | |
# only calculate foreground gt_pos here? | |
loss_dict.update({'loss_xyz': xyz_loss}) | |
elif self.loss_class.opt.emd_lambda > 0: | |
# rand_pt_size = 4096 # K value. Input Error! The size of the point clouds should be a multiple of 1024. | |
pred = latent['pos'] | |
rand_pt_size = min(2048, max(pred.shape[1], 1024)) # K; the EMD kernel requires point counts that are multiples of 1024 | |
if micro['fps_pcd'].shape[0] == pred.shape[0]: | |
gt_point = micro['fps_pcd'] | |
else: # overfit memory dataset | |
gt_point = micro[ | |
'fps_pcd'][:: | |
4] # consecutive 4 views are from the same ID | |
B, gt_point_N = gt_point.shape[:2] | |
# random sample pred points | |
# sampled_pred = | |
# rand_pt_idx = torch.randint(high=pred.shape[1]-gt_point_N, size=(B,)) | |
# pcu.save_mesh_v( f'tmp/voxel/emd/gt-half.ply', gt_point[0, ::4].detach().cpu().numpy(),) | |
# for b in range(gt_point.shape[0]): | |
# pcu.save_mesh_v( f'{logger.get_dir()}/gt-{b}.ply', gt_point[b].detach().cpu().numpy(),) | |
# pcu.save_mesh_v( f'{logger.get_dir()}/pred-{b}.ply', pred[b].detach().cpu().numpy(),) | |
# pcu.save_mesh_v( f'0.ply', latent['pos'][0].detach().cpu().numpy()) | |
# st() | |
if self.loss_class.opt.fps_sampling: # O(N*K). reduce K later. | |
if self.loss_class.opt.subset_fps_sampling: | |
rand_pt_size = 1024 # for faster calculation | |
# ! uniform sampling with randomness | |
# sampled_gt_pts_for_emd_loss = gt_point[:, random.randint(0,9)::9][:, :1024] # direct uniform downsample to the K size | |
# sampled_gt_pts_for_emd_loss = gt_point[:, random.randint(0,9)::4][:, :1024] # direct uniform downsample to the K size | |
rand_perm = torch.randperm( | |
gt_point.shape[1] | |
)[:rand_pt_size] # shuffle the xyz before downsample - fps sampling | |
sampled_gt_pts_for_emd_loss = gt_point[:, rand_perm] | |
# sampled_gt_pts_for_emd_loss = gt_point[:, ::4] | |
# sampled_pred_pts_for_emd_loss = pred[:, ::32] | |
# sampled_gt_pts_for_emd_loss = pytorch3d.ops.sample_farthest_points( | |
# gt_point[:, ::4], K=rand_pt_size)[0] # V4 | |
if self.loss_class.opt.subset_half_fps_sampling: | |
# sampled_pred_pts_for_emd_loss = pytorch3d.ops.sample_farthest_points( | |
# pred, K=rand_pt_size)[0] # V5 | |
rand_perm = torch.randperm( | |
pred.shape[1] | |
)[:4096] # shuffle the xyz before downsample - fps sampling | |
sampled_pred_pts_for_emd_loss = pytorch3d.ops.sample_farthest_points( | |
pred[:, rand_perm], | |
K=rand_pt_size)[0] # improve randomness | |
else: | |
sampled_pred_pts_for_emd_loss = pytorch3d.ops.sample_farthest_points( | |
pred[:, ::4], K=rand_pt_size)[0] # V5 | |
# rand_perm = torch.randperm(pred.shape[1]) # shuffle the xyz before downsample - fps sampling | |
# sampled_pred_pts_for_emd_loss = pytorch3d.ops.sample_farthest_points( | |
# pred[:, rand_perm][:, ::4], K=rand_pt_size)[0] # rand perm before downsampling, V6 | |
# sampled_pred_pts_for_emd_loss = pytorch3d.ops.sample_farthest_points( | |
# pred[:, rand_perm][:, ::8], K=rand_pt_size)[0] # rand perm before downsampling, V7 | |
# sampled_pred_pts_for_emd_loss = pytorch3d.ops.sample_farthest_points( | |
# pred[:, self.step%2::4], K=rand_pt_size)[0] # rand perm before downsampling, V8, based on V50 | |
else: | |
sampled_gt_pts_for_emd_loss = pytorch3d.ops.sample_farthest_points( | |
gt_point, K=rand_pt_size)[0] | |
# if self.loss_class.opt.subset_half_fps_sampling: | |
# rand_pt_size = 4096 # K value. Input Error! The size of the point clouds should be a multiple of 1024. | |
sampled_pred_pts_for_emd_loss = pytorch3d.ops.sample_farthest_points( | |
pred, K=rand_pt_size)[0] | |
# else: | |
# sampled_pred_pts_for_emd_loss = pytorch3d.ops.sample_farthest_points( | |
# pred, K=rand_pt_size)[0] | |
else: # random sampling | |
rand_pt_idx_pred = torch.randint(high=pred.shape[1] - | |
rand_pt_size, | |
size=(1, ))[0] | |
rand_pt_idx_gt = torch.randint(high=gt_point.shape[1] - | |
rand_pt_size, | |
size=(1, ))[0] | |
sampled_pred_pts_for_emd_loss = pred[:, | |
rand_pt_idx_pred: | |
rand_pt_idx_pred + | |
rand_pt_size, ...] | |
sampled_gt_pts_for_emd_loss = gt_point[:, | |
rand_pt_idx_gt: | |
rand_pt_idx_gt + | |
rand_pt_size, | |
...] | |
# only calculate foreground gt_pos here? | |
emd_loss = calc_emd(sampled_gt_pts_for_emd_loss, | |
sampled_pred_pts_for_emd_loss).mean( | |
) * self.loss_class.opt.emd_lambda | |
loss = loss + emd_loss | |
loss_dict.update({'loss_emd': emd_loss}) | |
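# EMD sketch: calc_emd (a CUDA extension, per the error message quoted above) | |
# expects (B, N, 3) clouds with N a multiple of 1024 and returns one value per sample: | |
#   >>> d = calc_emd(th.rand(2, 1024, 3).cuda(), th.rand(2, 1024, 3).cuda()) | |
#   >>> d.shape                                             # -> torch.Size([2]) | |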
if self.loss_class.opt.commitment_loss_lambda > 0: | |
ellipsoid_vol = torch.prod(scaling, dim=-1, keepdim=True) / ((0.01 * 0.9)**3) # normalized ellipsoid volume; the constant 4/3*pi factor is dropped | |
commitment = ellipsoid_vol * opacity | |
to_be_pruned_ellipsoid_idx = commitment < (3/4)**3 * 0.9 # mask of low-contribution splats that should raise their vol*opacity score | |
commitment_loss = -commitment[to_be_pruned_ellipsoid_idx].mean() * self.loss_class.opt.commitment_loss_lambda | |
loss = loss + commitment_loss | |
loss_dict.update({'loss_commitment': commitment_loss}) | |
loss_dict.update({'loss_commitment_opacity': opacity.mean()}) | |
loss_dict.update({'loss_commitment_vol': ellipsoid_vol.mean()}) | |
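# Commitment sketch: each splat is scored by normalized ellipsoid volume times | |
# opacity; for splats under the (3/4)**3 * 0.9 threshold, minimizing -mean(score) | |
# pushes them to raise their contribution above the pruning threshold. | |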
log_rec3d_loss_dict(loss_dict) | |
# self.mp_trainer_rec.backward(loss) | |
if behaviour == 'g_step': | |
self.mp_trainer_rec.backward(loss) | |
else: | |
self.mp_trainer_disc.backward(loss) | |
# for name, p in self.rec_model.named_parameters(): | |
# if p.grad is None: | |
# logger.log(f"found rec unused param: {name}") | |
# print(name, p.grad.mean(), p.grad.abs().max()) | |
if dist_util.get_rank() == 0 and self.step % 500 == 0 and i == 0 and behaviour=='g_step': | |
# if dist_util.get_rank() == 0 and self.step % 1 == 0 and i == 0: | |
try: | |
torchvision.utils.save_image( | |
th.cat([target['img'][::1], pred_nv_cano[fine_scale_key]['image_raw'][::1]], ), | |
f'{logger.get_dir()}/{self.step+self.resume_step}.jpg', | |
normalize=True, value_range=(-1,1),nrow=len(indices)*2) | |
# save depth and normal and alpha | |
torchvision.utils.save_image( | |
th.cat([surf_normal[::1], rend_normal[::1]], ), | |
f'{logger.get_dir()}/{self.step+self.resume_step}_normal_new.jpg', | |
normalize=True, value_range=(-1,1), nrow=len(indices)*2) | |
torchvision.utils.save_image( | |
th.cat([target['depth'][::1], pred_nv_cano[fine_scale_key]['image_depth'][::1]], ), | |
f'{logger.get_dir()}/{self.step+self.resume_step}_depth.jpg', | |
normalize=True, nrow=len(indices)*2) | |
# torchvision.utils.save_image( pred_nv_cano['image_depth'][::1], f'{logger.get_dir()}/{self.step+self.resume_step}_depth.jpg', normalize=True, nrow=len(indices)*2) | |
torchvision.utils.save_image( | |
th.cat([target['depth_mask'][::1], pred_nv_cano[fine_scale_key]['image_mask'][::1]], ), | |
f'{logger.get_dir()}/{self.step+self.resume_step}_alpha.jpg', | |
normalize=True, value_range=(0,1), nrow=len(indices)*2) | |
logger.log( | |
'log vis to: ', | |
f'{logger.get_dir()}/{self.step+self.resume_step}.jpg') | |
except Exception as e: | |
logger.log('Exception when saving log: ', e) | |
# if self.step % 2500 == 0: | |
# th.cuda.empty_cache() # free vram | |
def export_mesh_from_2dgs(self, all_rgbs, all_depths, all_alphas, cam_pathes, latent_save_dir): | |
# https://github.com/autonomousvision/LaRa/blob/main/evaluation.py | |
n_thread = 1 # avoid TSDF cpu hanging bug. | |
os.environ["MKL_NUM_THREADS"] = f"{n_thread}" | |
os.environ["NUMEXPR_NUM_THREADS"] = f"{n_thread}" | |
os.environ["OMP_NUM_THREADS"] = f"4" | |
os.environ["VECLIB_MAXIMUM_THREADS"] = f"{n_thread}" | |
os.environ["OPENBLAS_NUM_THREADS"] = f"{n_thread}" | |
# copied from: https://github.com/hbb1/2d-gaussian-splatting/blob/19eb5f1e091a582e911b4282fe2832bac4c89f0f/render.py#L23 | |
logger.log("exporting mesh ...") | |
# for g-objv | |
aabb = [-0.45,-0.45,-0.45,0.45,0.45,0.45] | |
self.aabb = np.array(aabb).reshape(2,3)*1.1 | |
name = f'{latent_save_dir}/mesh_raw.obj' | |
mesh = self.extract_mesh_bounded(all_rgbs, all_depths, all_alphas, cam_pathes) | |
o3d.io.write_triangle_mesh(name, mesh) | |
logger.log("mesh saved at {}".format(name)) | |
mesh_post = smooth_mesh(mesh) | |
o3d.io.write_triangle_mesh(name.replace('_raw.obj', '.obj'), mesh_post) | |
logger.log("post-processed mesh saved at {}".format(name.replace('_raw.obj', '.obj'))) | |
def get_source_cw2wT(self, source_cameras_view_to_world): | |
return matrix_to_quaternion( | |
source_cameras_view_to_world[:3, :3].transpose(0, 1)) | |
def c_to_3dgs_format(self, pose): | |
# TODO, switch to torch version (batched later) | |
c2w = pose[:16].reshape(4, 4) # 4x4 homogeneous cam2world | |
# ! load cam | |
w2c = np.linalg.inv(c2w) | |
R = np.transpose( | |
w2c[:3, :3]) # R is stored transposed due to 'glm' in CUDA code | |
T = w2c[:3, 3] | |
fx = pose[16] | |
FovX = focal2fov(fx, 1) | |
FovY = focal2fov(fx, 1) | |
tanfovx = math.tan(FovX * 0.5) | |
tanfovy = math.tan(FovY * 0.5) | |
assert tanfovx == tanfovy | |
trans = np.array([0.0, 0.0, 0.0]) | |
scale = 1.0 | |
world_view_transform = torch.tensor(getWorld2View2(R, T, trans, | |
scale)).transpose( | |
0, 1) | |
projection_matrix = getProjectionMatrix(znear=self.znear, | |
zfar=self.zfar, | |
fovX=FovX, | |
fovY=FovY).transpose(0, 1) | |
full_proj_transform = (world_view_transform.unsqueeze(0).bmm( | |
projection_matrix.unsqueeze(0))).squeeze(0) | |
camera_center = world_view_transform.inverse()[3, :3] | |
view_world_transform = torch.tensor(getView2World(R, T, trans, | |
scale)).transpose( | |
0, 1) | |
# item.update(viewpoint_cam=[viewpoint_cam]) | |
c = {} | |
c["source_cv2wT_quat"] = self.get_source_cw2wT(view_world_transform) | |
c.update( | |
projection_matrix=projection_matrix, # K | |
cam_view=world_view_transform, # world_view_transform | |
cam_view_proj=full_proj_transform, # full_proj_transform | |
cam_pos=camera_center, | |
tanfov=tanfovx, # TODO, fix in the renderer | |
# orig_c2w=c2w, | |
# orig_w2c=w2c, | |
orig_pose=torch.from_numpy(pose), | |
orig_c2w=torch.from_numpy(c2w), | |
orig_w2c=torch.from_numpy(w2c), | |
# tanfovy=tanfovy, | |
) | |
return c # dict for gs rendering | |
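# Convention sketch: the matrices above are stored transposed (glm / row-vector | |
# style), so a homogeneous world point p maps to clip space as | |
#   p_clip = p_world @ cam_view @ projection_matrix = p_world @ cam_view_proj, | |
# and `cam_pos` is read off the last row of the inverse view transform. | |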
def extract_mesh_bounded(self, rgbmaps, depthmaps, alpha_maps, cam_pathes, voxel_size=0.004, sdf_trunc=0.02, depth_trunc=3, alpha_thres=0.08, mask_backgrond=False): | |
""" | |
Perform TSDF fusion given a fixed depth range, used in the paper. | |
voxel_size: the voxel size of the volume | |
sdf_trunc: truncation value | |
depth_trunc: maximum depth range, should depended on the scene's scales | |
mask_backgrond: whether to mask backgroud, only works when the dataset have masks | |
return o3d.mesh | |
""" | |
if self.aabb is not None: # as in lara. | |
center = self.aabb.mean(0) | |
radius = np.linalg.norm(self.aabb[1] - self.aabb[0]) * 0.5 | |
# voxel_size = radius / 256 | |
voxel_size = radius / 192 # less holes | |
# sdf_trunc = voxel_size * 16 # less holes, slower integration | |
sdf_trunc = voxel_size * 12 # | |
print("using aabb") | |
volume = o3d.pipelines.integration.ScalableTSDFVolume( | |
voxel_length= voxel_size, | |
sdf_trunc=sdf_trunc, | |
color_type=o3d.pipelines.integration.TSDFVolumeColorType.RGB8 | |
) | |
print("Running tsdf volume integration ...") | |
print(f'voxel_size: {voxel_size}') | |
print(f'sdf_trunc: {sdf_trunc}') | |
# print(f'depth_truc: {depth_trunc}') | |
# render_reference = th.load('eval_pose.pt', map_location='cpu').numpy() | |
# ! use uni_mesh_path, from Lara, Chen et al, ECCV 24' | |
# ''' | |
# for i, cam_o3d in tqdm(enumerate(to_cam_open3d(self.viewpoint_stack)), desc="TSDF integration progress"): | |
for i, cam in tqdm(enumerate(cam_pathes), desc="TSDF integration progress"): | |
# rgb = self.rgbmaps[i] | |
# depth = self.depthmaps[i] | |
cam = self.c_to_3dgs_format(cam) | |
cam_o3d = to_cam_open3d_compat(cam) | |
rgb = rgbmaps[i][0] | |
depth = depthmaps[i][0] | |
alpha = alpha_maps[i][0] | |
# if we have mask provided, use it | |
# if mask_backgrond and (self.viewpoint_stack[i].gt_alpha_mask is not None): | |
# depth[(self.viewpoint_stack[i].gt_alpha_mask < 0.5)] = 0 | |
depth[(alpha < alpha_thres)] = 0 | |
if self.aabb is not None: | |
campos = cam['cam_pos'].cpu().numpy() | |
depth_trunc = np.linalg.norm(campos - center, axis=-1) + radius | |
# make open3d rgbd | |
rgbd = o3d.geometry.RGBDImage.create_from_color_and_depth( | |
o3d.geometry.Image(np.asarray(np.clip(rgb.permute(1,2,0).cpu().numpy(), 0.0, 1.0) * 255, order="C", dtype=np.uint8)), | |
o3d.geometry.Image(np.asarray(depth.permute(1,2,0).cpu().numpy(), order="C")), | |
depth_trunc = depth_trunc, | |
convert_rgb_to_intensity=False, | |
depth_scale = 1.0 | |
) | |
volume.integrate(rgbd, intrinsic=cam_o3d.intrinsic, extrinsic=cam_o3d.extrinsic) | |
mesh = volume.extract_triangle_mesh() | |
return mesh | |
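# Usage note: `export_mesh_from_2dgs` above is the intended entry point; it | |
# wires the per-view renders straight into this method, roughly | |
#   >>> mesh = self.extract_mesh_bounded(all_rgbs, all_depths, all_alphas, cam_pathes) | |
#   >>> o3d.io.write_triangle_mesh('fused_raw.obj', mesh)  # hypothetical output path | |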
def eval_novelview_loop(self, camera=None, save_latent=False): | |
# novel view synthesis given evaluation camera trajectory | |
if save_latent: # for diffusion learning | |
latent_dir = Path(f'{logger.get_dir()}/latent_dir') | |
latent_dir.mkdir(exist_ok=True, parents=True) | |
render_reference=uni_mesh_path(10) | |
for eval_idx, micro in enumerate(tqdm(self.eval_data)): | |
latent_save_dir = f'{logger.get_dir()}/latent_dir/{micro["ins"][0]}' | |
all_latent_file = sorted(Path(latent_save_dir).glob('*.npz') ) | |
if len(all_latent_file) == 0: | |
save_prefix = 0 | |
else: | |
save_prefix = int(all_latent_file[-1].stem.split('-')[-1]) + 1 # parse the full numeric suffix, not just the last digit | |
Path(latent_save_dir).mkdir(parents=True, exist_ok=True) | |
with th.autocast(device_type='cuda', | |
dtype=self.dtype, | |
enabled=self.mp_trainer_rec.use_amp): | |
# st() # check whether more c info available | |
latent = self.rec_model( | |
img=micro['img_to_encoder'].to(self.dtype), | |
behaviour='enc_dec_wo_triplane', | |
c=micro['c'], | |
pcd=micro['fps_pcd'], # send in pcd for surface reference. | |
) # send in input-view c, since pixel-aligned Gaussians are required | |
# fine_scale_key = list(pred.keys())[-1] | |
# fine_scale_key = 'gaussians_upsampled_2' | |
fine_scale_key = 'gaussians_upsampled_3' | |
export_mesh = True # for debug | |
if True: | |
# if eval_idx < 1500 and eval_idx % 3 == 0: | |
if eval_idx < 1500: | |
all_rgbs, all_depths, all_alphas=self.render_gs_video_given_latent( | |
latent, | |
self.rec_model, # compatible with join_model | |
name_prefix=f'{self.step + self.resume_step}_{micro["ins"][0].split("/")[0]}_{eval_idx}', | |
save_img=False, | |
render_reference=render_reference, | |
export_mesh=False) | |
if export_mesh: | |
self.export_mesh_from_2dgs(all_rgbs, all_depths, all_alphas, render_reference, latent_save_dir) | |
# ! B=2 here | |
np.savez_compressed(f'{latent_save_dir}/latent-{save_prefix}.npz', | |
latent_normalized=latent['latent_normalized'].cpu().numpy(), | |
query_pcd_xyz=latent['query_pcd_xyz'].cpu().numpy() | |
) | |
# st() | |
for scale in ['gaussians_upsampled', 'gaussians_base', 'gaussians_upsampled_2', 'gaussians_upsampled_3']: | |
np.save(f'{latent_save_dir}/{scale}.npy', latent[scale].cpu().numpy()) | |
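# Reading the dumps back (sketch): each instance directory holds the KL latent, | |
# its query point cloud, and one .npy per Gaussian LoD: | |
#   >>> z = np.load(f'{latent_save_dir}/latent-0.npz') | |
#   >>> z['latent_normalized'].shape, z['query_pcd_xyz'].shape | |
#   >>> gs_base = np.load(f'{latent_save_dir}/gaussians_base.npy') | |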
def render_gs_video_given_latent(self, | |
ddpm_latent, | |
rec_model, | |
name_prefix='0', | |
save_img=False, | |
render_reference=None, | |
export_mesh=False): | |
all_rgbs, all_depths, all_alphas = [], [], [] | |
# batch_size, L, C = planes.shape | |
# ddpm_latent = { self.latent_name: planes[..., :-3] * self.triplane_scaling_divider, # kl-reg latent | |
# 'query_pcd_xyz': self.pcd_unnormalize_fn(planes[..., -3:]) } | |
# ddpm_latent.update(rec_model(latent=ddpm_latent, behaviour='decode_gs_after_vae_no_render')) | |
# assert render_reference is None | |
# render_reference = self.eval_data # compat | |
# else: # use train_traj | |
# for key in ['ins', 'bbox', 'caption']: | |
# if key in render_reference: | |
# render_reference.pop(key) | |
# render_reference = [ { k:v[idx:idx+1] for k, v in render_reference.items() } for idx in range(40) ] | |
video_out = imageio.get_writer( | |
f'{logger.get_dir()}/gs_{name_prefix}.mp4', | |
mode='I', | |
fps=15, | |
codec='libx264') | |
# for i, batch in enumerate(tqdm(self.eval_data)): | |
for i, micro_c in enumerate(tqdm(render_reference)): | |
# micro = { | |
# k: v.to(dist_util.dev()) if isinstance(v, th.Tensor) else v | |
# for k, v in batch.items() | |
# } | |
c = self.post_process.c_to_3dgs_format(micro_c) | |
for k in c.keys(): # to cuda | |
if isinstance(c[k], th.Tensor) and k != 'tanfov': | |
c[k] = c[k].unsqueeze(0).unsqueeze(0).to(dist_util.dev()) # actually, could render 40 views together. | |
c['tanfov'] = th.tensor(c['tanfov']).to(dist_util.dev()) | |
pred = rec_model( | |
img=None, | |
c=c, # TODO, to dict | |
latent=ddpm_latent, # render gs | |
behaviour='triplane_dec', | |
bg_color=self.gs_bg_color, | |
render_all_scale=True, | |
) | |
fine_scale_key = list(pred.keys())[-1] | |
all_rgbs.append(einops.rearrange(pred[fine_scale_key]['image'], 'B V ... -> (B V) ...')) | |
all_depths.append(einops.rearrange(pred[fine_scale_key]['depth'], 'B V ... -> (B V) ...')) | |
all_alphas.append(einops.rearrange(pred[fine_scale_key]['alpha'], 'B V ... -> (B V) ...')) | |
# st() | |
# fine_scale_key = list(pred.keys())[-1] | |
all_pred_vis = {} | |
for key in pred.keys(): | |
pred_scale = pred[key] # only show finest result here | |
for k in pred_scale.keys(): | |
pred_scale[k] = einops.rearrange(pred_scale[k], 'B V ... -> (B V) ...') # merge | |
pred_vis = self._make_vis_img(pred_scale) | |
vis = pred_vis.permute(0, 2, 3, 1).cpu().numpy() | |
vis = vis * 127.5 + 127.5 | |
vis = vis.clip(0, 255).astype(np.uint8) | |
all_pred_vis[key] = vis | |
# all_pred_vis_concat = np.concatenate([cv2.resize(all_pred_vis[k][0], (384*3, 384)) for k in ['gaussians_base', 'gaussians_upsampled', 'gaussians_upsampled_2']], axis=0) | |
# all_pred_vis_concat = np.concatenate([cv2.resize(all_pred_vis[k][0], (256*3, 256)) for k in ['gaussians_base', 'gaussians_upsampled',]], axis=0) | |
# all_pred_vis_concat = np.concatenate([cv2.resize(all_pred_vis[k][0], (384*3, 384)) for k in all_pred_vis.keys()], axis=0) | |
all_pred_vis_concat = np.concatenate([cv2.resize(all_pred_vis[k][0], (512*3, 512)) for k in all_pred_vis.keys()], axis=0) | |
# for j in range(vis.shape[0]): | |
video_out.append_data(all_pred_vis_concat) | |
video_out.close() | |
print('logged video to: ', | |
f'{logger.get_dir()}/gs_{name_prefix}.mp4') | |
del video_out, pred, pred_vis, vis | |
return all_rgbs, all_depths, all_alphas | |
def _make_vis_img(self, pred): | |
# if True: | |
pred_depth = pred['image_depth'] | |
pred_depth = (pred_depth - pred_depth.min()) / (pred_depth.max() - | |
pred_depth.min()) | |
pred_depth = pred_depth.cpu()[0].permute(1, 2, 0).numpy() | |
pred_depth = (plt.cm.viridis(pred_depth[..., 0])[..., :3]) * 2 - 1 | |
pred_depth = th.from_numpy(pred_depth).to( | |
pred['image_raw'].device).permute(2, 0, 1).unsqueeze(0) | |
gen_img = pred['image_raw'] | |
rend_normal = pred['rend_normal'] | |
pred_vis = th.cat( | |
[ | |
gen_img, | |
rend_normal, | |
pred_depth, | |
], | |
dim=-1) # B, 3, H, W | |
return pred_vis | |
class TrainLoop3DRecNVPatchSingleForwardMV_NoCrop_adv(TrainLoop3DRecNVPatchSingleForwardMV_NoCrop): | |
def __init__(self, *, rec_model, loss_class, data, eval_data, batch_size, | |
microbatch, lr, ema_rate, log_interval, eval_interval, | |
save_interval, resume_checkpoint, use_fp16=False, | |
fp16_scale_growth=0.001, weight_decay=0, lr_anneal_steps=0, | |
iterations=10001, load_submodule_name='', ignore_resume_opt=False, | |
model_name='rec', use_amp=False, num_frames=4, **kwargs): | |
super().__init__(rec_model=rec_model, loss_class=loss_class, data=data, | |
eval_data=eval_data, batch_size=batch_size, | |
microbatch=microbatch, lr=lr, ema_rate=ema_rate, | |
log_interval=log_interval, eval_interval=eval_interval, | |
save_interval=save_interval, resume_checkpoint=resume_checkpoint, | |
use_fp16=use_fp16, fp16_scale_growth=fp16_scale_growth, | |
weight_decay=weight_decay, lr_anneal_steps=lr_anneal_steps, | |
iterations=iterations, load_submodule_name=load_submodule_name, | |
ignore_resume_opt=ignore_resume_opt, model_name=model_name, | |
use_amp=use_amp, num_frames=num_frames, **kwargs) | |
# create discriminator | |
# ! copied from ln3diff tri-plane version | |
disc_params = self.loss_class.get_trainable_parameters() | |
self.mp_trainer_disc = MixedPrecisionTrainer( | |
model=self.loss_class.discriminator, | |
use_fp16=self.use_fp16, | |
fp16_scale_growth=fp16_scale_growth, | |
model_name='disc', | |
use_amp=use_amp, | |
model_params=disc_params) | |
# st() # check self.lr | |
self.opt_disc = AdamW( | |
self.mp_trainer_disc.master_params, | |
lr=self.lr, # follow sd code base | |
betas=(0, 0.999), | |
eps=1e-8) | |
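# Note: beta1=0 disables momentum for the discriminator updates, a common | |
# choice in adversarial training to keep the discriminator reactive; the lr | |
# is shared with the generator here (see the `follow sd code base` note). | |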
# TODO, is loss cls already in the DDP? | |
if self.use_ddp: | |
self.ddp_disc = DDP( | |
self.loss_class.discriminator, | |
device_ids=[dist_util.dev()], | |
output_device=dist_util.dev(), | |
broadcast_buffers=False, | |
bucket_cap_mb=128, | |
find_unused_parameters=False, | |
) | |
else: | |
self.ddp_disc = self.loss_class.discriminator | |
def save(self, mp_trainer=None, model_name='rec'): | |
if mp_trainer is None: | |
mp_trainer = self.mp_trainer_rec | |
def save_checkpoint(rate, params): | |
state_dict = mp_trainer.master_params_to_state_dict(params) | |
if dist_util.get_rank() == 0: | |
logger.log(f"saving model {model_name} {rate}...") | |
if not rate: | |
filename = f"model_{model_name}{(self.step+self.resume_step):07d}.pt" | |
else: | |
filename = f"ema_{model_name}_{rate}_{(self.step+self.resume_step):07d}.pt" | |
with bf.BlobFile(bf.join(get_blob_logdir(), filename), | |
"wb") as f: | |
th.save(state_dict, f) | |
save_checkpoint(0, mp_trainer.master_params) | |
dist.barrier() | |
def run_step(self, batch, step='g_step'): | |
# self.forward_backward(batch) | |
if step == 'g_step': | |
self.forward_backward(batch, behaviour='g_step') | |
took_step_g_rec = self.mp_trainer_rec.optimize(self.opt) | |
if took_step_g_rec: | |
self._update_ema() # g_ema | |
elif step == 'd_step': | |
self.forward_backward(batch, behaviour='d_step') | |
_ = self.mp_trainer_disc.optimize(self.opt_disc) | |
self._anneal_lr() | |
self.log_step() | |
def run_loop(self, batch=None): | |
while (not self.lr_anneal_steps | |
or self.step + self.resume_step < self.lr_anneal_steps): | |
batch = next(self.data) | |
self.run_step(batch, 'g_step') | |
batch = next(self.data) | |
self.run_step(batch, 'd_step') | |
if self.step % 1000 == 0: | |
dist_util.synchronize() | |
if self.step % 5000 == 0: | |
th.cuda.empty_cache() # avoid memory leak | |
if self.step % self.log_interval == 0 and dist_util.get_rank( | |
) == 0: | |
out = logger.dumpkvs() | |
# * log to tensorboard | |
for k, v in out.items(): | |
self.writer.add_scalar(f'Loss/{k}', v, | |
self.step + self.resume_step) | |
if self.step % self.eval_interval == 0 and self.step != 0: | |
if dist_util.get_rank() == 0: | |
try: | |
self.eval_loop() | |
except Exception as e: | |
logger.log(e) | |
dist_util.synchronize() | |
# if self.step % self.save_interval == 0 and self.step != 0: | |
if self.step % self.save_interval == 0: | |
self.save() | |
self.save(self.mp_trainer_disc, | |
self.mp_trainer_disc.model_name) | |
dist_util.synchronize() | |
# Run for a finite amount of time in integration tests. | |
if os.environ.get("DIFFUSION_TRAINING_TEST", | |
"") and self.step > 0: | |
return | |
self.step += 1 | |
if self.step > self.iterations: | |
logger.log('reached maximum iterations, exiting') | |
# Save the last checkpoint if it wasn't already saved. | |
if (self.step - | |
1) % self.save_interval != 0 and self.step != 1: | |
self.save() | |
exit() | |
# Save the last checkpoint if it wasn't already saved. | |
# if (self.step - 1) % self.save_interval != 0 and self.step != 1: | |
if (self.step - 1) % self.save_interval != 0: | |
try: | |
self.save() # save rec | |
self.save(self.mp_trainer_disc, self.mp_trainer_disc.model_name) | |
except Exception as e: | |
logger.log(e) | |
# ! load disc | |
def _load_and_sync_parameters(self, submodule_name=''): | |
super()._load_and_sync_parameters(submodule_name) | |
# load disc | |
resume_checkpoint = self.resume_checkpoint.replace( | |
'rec', 'disc') # * default behaviour | |
if os.path.exists(resume_checkpoint): | |
if dist_util.get_rank() == 0: | |
logger.log( | |
f"loading disc model from checkpoint: {resume_checkpoint}..." | |
) | |
map_location = { | |
'cuda:%d' % 0: 'cuda:%d' % dist_util.get_rank() | |
} # configure map_location properly | |
resume_state_dict = dist_util.load_state_dict( | |
resume_checkpoint, map_location=map_location) | |
model_state_dict = self.loss_class.discriminator.state_dict() | |
for k, v in resume_state_dict.items(): | |
if k in model_state_dict.keys(): | |
if v.size() == model_state_dict[k].size(): | |
model_state_dict[k] = v | |
# model_state_dict[k].copy_(v) | |
else: | |
logger.log('!!!! partially load: ', k, ": ", | |
v.size(), "state_dict: ", | |
model_state_dict[k].size()) | |
if dist_util.get_world_size() > 1: | |
# dist_util.sync_params(self.model.named_parameters()) | |
dist_util.sync_params( | |
self.loss_class.get_trainable_parameters()) | |
logger.log('synced disc params') | |