GaussianAnything-AIGC3D / nsr /train_nv_util.py
yslan's picture
init
7f51798
import copy
import cv2
import einops
from collections import defaultdict
import matplotlib.pyplot as plt
import random
# import emd
import pytorch3d.loss
# import imageio.v3
import functools
import json
import os
from pathlib import Path
from pdb import set_trace as st
from einops import rearrange
import webdataset as wds
from nsr.camera_utils import generate_input_camera, uni_mesh_path
import point_cloud_utils as pcu
import traceback
import blobfile as bf
from datasets.g_buffer_objaverse import focal2fov, fov2focal
import math
import imageio
import numpy as np
# from sympy import O
import torch
from torch.autograd import Function
import torch.nn.functional as F
import torch as th
import open3d as o3d
import torch.distributed as dist
import torchvision
from PIL import Image
from torch.nn.parallel.distributed import DistributedDataParallel as DDP
from torch.optim import AdamW
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
import pytorch3d.ops
from torch.profiler import profile, record_function, ProfilerActivity
from guided_diffusion import dist_util, logger
from guided_diffusion.fp16_util import MixedPrecisionTrainer
from guided_diffusion.nn import update_ema
from guided_diffusion.resample import LossAwareSampler, UniformSampler
from guided_diffusion.train_util import (calc_average_loss,
find_ema_checkpoint,
find_resume_checkpoint,
get_blob_logdir, log_rec3d_loss_dict,
parse_resume_step_from_filename)
from datasets.g_buffer_objaverse import unity2blender, unity2blender_th, PostProcess
from nsr.volumetric_rendering.ray_sampler import RaySampler
from utils.mesh_util import post_process_mesh, to_cam_open3d_compat
from .camera_utils import LookAtPoseSampler, FOV_to_intrinsics
from nsr.camera_utils import generate_input_camera, uni_mesh_path, sample_uniform_cameras_on_sphere
from utils.gs_utils.graphics_utils import getWorld2View2, getProjectionMatrix, getView2World
from utils.general_utils import matrix_to_quaternion
from utils.mesh_util import post_process_mesh, to_cam_open3d_compat, smooth_mesh
from datasets.g_buffer_objaverse import focal2fov, fov2focal
from .train_util import TrainLoop3DRec
import kornia
@th.autocast(device_type='cuda', dtype=th.float16, enabled=False)
def psnr(input, target, max_val):
return kornia.metrics.psnr(input, target, max_val)
def calc_emd(output, gt, eps=0.005, iterations=50):
import utils.emd.emd_module as emd
emd_loss = emd.emdModule()
dist, _ = emd_loss(output, gt, eps, iterations)
emd_out = torch.sqrt(dist).mean(1)
return emd_out
class TrainLoop3DRecNV(TrainLoop3DRec):
# supervise the training of novel view
def __init__(self,
*,
rec_model,
loss_class,
data,
eval_data,
batch_size,
microbatch,
lr,
ema_rate,
log_interval,
eval_interval,
save_interval,
resume_checkpoint,
use_fp16=False,
fp16_scale_growth=0.001,
weight_decay=0,
lr_anneal_steps=0,
iterations=10001,
load_submodule_name='',
ignore_resume_opt=False,
model_name='rec',
use_amp=False,
**kwargs):
super().__init__(rec_model=rec_model,
loss_class=loss_class,
data=data,
eval_data=eval_data,
batch_size=batch_size,
microbatch=microbatch,
lr=lr,
ema_rate=ema_rate,
log_interval=log_interval,
eval_interval=eval_interval,
save_interval=save_interval,
resume_checkpoint=resume_checkpoint,
use_fp16=use_fp16,
fp16_scale_growth=fp16_scale_growth,
weight_decay=weight_decay,
lr_anneal_steps=lr_anneal_steps,
iterations=iterations,
load_submodule_name=load_submodule_name,
ignore_resume_opt=ignore_resume_opt,
model_name=model_name,
use_amp=use_amp,
**kwargs)
self.rec_cano = True
def forward_backward(self, batch, *args, **kwargs):
# return super().forward_backward(batch, *args, **kwargs)
self.mp_trainer_rec.zero_grad()
batch_size = batch['img_to_encoder'].shape[0]
for i in range(0, batch_size, self.microbatch):
# st()
micro = {
k: v[i:i + self.microbatch].to(dist_util.dev())
for k, v in batch.items()
}
# ! concat novel-view? next version. also add self reconstruction, patch-based loss in the next version. verify novel-view prediction first.
# wrap forward within amp
with th.autocast(device_type='cuda',
dtype=th.float16,
enabled=self.mp_trainer_rec.use_amp):
target_nvs = {}
target_cano = {}
latent = self.rec_model(img=micro['img_to_encoder'],
behaviour='enc_dec_wo_triplane')
pred = self.rec_model(
latent=latent,
c=micro['nv_c'], # predict novel view here
behaviour='triplane_dec')
for k, v in micro.items():
if k[:2] == 'nv':
orig_key = k.replace('nv_', '')
target_nvs[orig_key] = v
target_cano[orig_key] = micro[orig_key]
with self.rec_model.no_sync(): # type: ignore
loss, loss_dict, fg_mask = self.loss_class(
pred,
target_nvs,
step=self.step + self.resume_step,
test_mode=False,
return_fg_mask=True,
conf_sigma_l1=None,
conf_sigma_percl=None)
log_rec3d_loss_dict(loss_dict)
if self.rec_cano:
pred_cano = self.rec_model(latent=latent,
c=micro['c'],
behaviour='triplane_dec')
with self.rec_model.no_sync(): # type: ignore
fg_mask = target_cano['depth_mask'].unsqueeze(
1).repeat_interleave(3, 1).float()
loss_cano, loss_cano_dict = self.loss_class.calc_2d_rec_loss(
pred_cano['image_raw'],
target_cano['img'],
fg_mask,
step=self.step + self.resume_step,
test_mode=False,
)
loss = loss + loss_cano
# remove redundant log
log_rec3d_loss_dict({
f'cano_{k}': v
for k, v in loss_cano_dict.items()
# if "loss" in k
})
self.mp_trainer_rec.backward(loss)
if dist_util.get_rank() == 0 and self.step % 500 == 0:
if self.rec_cano:
self.log_img(micro, pred, pred_cano)
else:
self.log_img(micro, pred, None)
@th.inference_mode()
def log_img(self, micro, pred, pred_cano):
# gt_vis = th.cat([batch['img'], batch['depth']], dim=-1)
def norm_depth(pred_depth): # to [-1,1]
# pred_depth = pred['image_depth']
pred_depth = (pred_depth - pred_depth.min()) / (pred_depth.max() -
pred_depth.min())
return -(pred_depth * 2 - 1)
pred_img = pred['image_raw']
gt_img = micro['img']
# infer novel view also
# if self.loss_class.opt.symmetry_loss:
# pred_nv_img = nvs_pred
# else:
# ! replace with novel view prediction
# ! log another novel-view prediction
# pred_nv_img = self.rec_model(
# img=micro['img_to_encoder'],
# c=self.novel_view_poses) # pred: (B, 3, 64, 64)
# if 'depth' in micro:
gt_depth = micro['depth']
if gt_depth.ndim == 3:
gt_depth = gt_depth.unsqueeze(1)
gt_depth = norm_depth(gt_depth)
# gt_depth = (gt_depth - gt_depth.min()) / (gt_depth.max() -
# gt_depth.min())
# if True:
fg_mask = pred['image_mask'] * 2 - 1 # 0-1
input_fg_mask = pred_cano['image_mask'] * 2 - 1 # 0-1
if 'image_depth' in pred:
pred_depth = norm_depth(pred['image_depth'])
pred_nv_depth = norm_depth(pred_cano['image_depth'])
else:
pred_depth = th.zeros_like(gt_depth)
pred_nv_depth = th.zeros_like(gt_depth)
if 'image_sr' in pred:
if pred['image_sr'].shape[-1] == 512:
pred_img = th.cat([self.pool_512(pred_img), pred['image_sr']],
dim=-1)
gt_img = th.cat([self.pool_512(micro['img']), micro['img_sr']],
dim=-1)
pred_depth = self.pool_512(pred_depth)
gt_depth = self.pool_512(gt_depth)
elif pred['image_sr'].shape[-1] == 256:
pred_img = th.cat([self.pool_256(pred_img), pred['image_sr']],
dim=-1)
gt_img = th.cat([self.pool_256(micro['img']), micro['img_sr']],
dim=-1)
pred_depth = self.pool_256(pred_depth)
gt_depth = self.pool_256(gt_depth)
else:
pred_img = th.cat([self.pool_128(pred_img), pred['image_sr']],
dim=-1)
gt_img = th.cat([self.pool_128(micro['img']), micro['img_sr']],
dim=-1)
gt_depth = self.pool_128(gt_depth)
pred_depth = self.pool_128(pred_depth)
else:
gt_img = self.pool_64(gt_img)
gt_depth = self.pool_64(gt_depth)
pred_vis = th.cat([
pred_img,
pred_depth.repeat_interleave(3, dim=1),
fg_mask.repeat_interleave(3, dim=1),
],
dim=-1) # B, 3, H, W
pred_vis_nv = th.cat([
pred_cano['image_raw'],
pred_nv_depth.repeat_interleave(3, dim=1),
input_fg_mask.repeat_interleave(3, dim=1),
],
dim=-1) # B, 3, H, W
pred_vis = th.cat([pred_vis, pred_vis_nv], dim=-2) # cat in H dim
gt_vis = th.cat([
gt_img,
gt_depth.repeat_interleave(3, dim=1),
th.zeros_like(gt_img)
],
dim=-1) # TODO, fail to load depth. range [0, 1]
if 'conf_sigma' in pred:
gt_vis = th.cat([gt_vis, fg_mask], dim=-1) # placeholder
# vis = th.cat([gt_vis, pred_vis], dim=-2)[0].permute(
vis = th.cat([gt_vis, pred_vis], dim=-2)
# .permute(
# 0, 2, 3, 1).cpu()
vis_tensor = torchvision.utils.make_grid(vis, nrow=vis.shape[-1] //
64) # HWC
torchvision.utils.save_image(
vis_tensor,
f'{logger.get_dir()}/{self.step+self.resume_step}.jpg',
value_range=(-1, 1),
normalize=True)
# vis = vis.numpy() * 127.5 + 127.5
# vis = vis.clip(0, 255).astype(np.uint8)
# Image.fromarray(vis).save(
# f'{logger.get_dir()}/{self.step+self.resume_step}.jpg')
logger.log('log vis to: ',
f'{logger.get_dir()}/{self.step+self.resume_step}.jpg')
# self.writer.add_image(f'images',
# vis,
# self.step + self.resume_step,
# dataformats='HWC')
# return pred
class TrainLoop3DRecNVPatch(TrainLoop3DRecNV):
# add patch rendering
def __init__(self,
*,
rec_model,
loss_class,
data,
eval_data,
batch_size,
microbatch,
lr,
ema_rate,
log_interval,
eval_interval,
save_interval,
resume_checkpoint,
use_fp16=False,
fp16_scale_growth=0.001,
weight_decay=0,
lr_anneal_steps=0,
iterations=10001,
load_submodule_name='',
ignore_resume_opt=False,
model_name='rec',
use_amp=False,
**kwargs):
super().__init__(rec_model=rec_model,
loss_class=loss_class,
data=data,
eval_data=eval_data,
batch_size=batch_size,
microbatch=microbatch,
lr=lr,
ema_rate=ema_rate,
log_interval=log_interval,
eval_interval=eval_interval,
save_interval=save_interval,
resume_checkpoint=resume_checkpoint,
use_fp16=use_fp16,
fp16_scale_growth=fp16_scale_growth,
weight_decay=weight_decay,
lr_anneal_steps=lr_anneal_steps,
iterations=iterations,
load_submodule_name=load_submodule_name,
ignore_resume_opt=ignore_resume_opt,
model_name=model_name,
use_amp=use_amp,
**kwargs)
# the rendrer
self.eg3d_model = self.rec_model.module.decoder.triplane_decoder # type: ignore
# self.rec_cano = False
self.rec_cano = True
def forward_backward(self, batch, *args, **kwargs):
# add patch sampling
self.mp_trainer_rec.zero_grad()
batch_size = batch['img_to_encoder'].shape[0]
for i in range(0, batch_size, self.microbatch):
micro = {
k: v[i:i + self.microbatch].to(dist_util.dev())
for k, v in batch.items()
}
# ! sample rendering patch
target = {
**self.eg3d_model(
c=micro['nv_c'], # type: ignore
ws=None,
planes=None,
sample_ray_only=True,
fg_bbox=micro['nv_bbox']), # rays o / dir
}
patch_rendering_resolution = self.eg3d_model.rendering_kwargs[
'patch_rendering_resolution'] # type: ignore
cropped_target = {
k:
th.empty_like(v)
[..., :patch_rendering_resolution, :patch_rendering_resolution]
if k not in [
'ins_idx', 'img_to_encoder', 'img_sr', 'nv_img_to_encoder',
'nv_img_sr', 'c'
] else v
for k, v in micro.items()
}
# crop according to uv sampling
for j in range(micro['img'].shape[0]):
top, left, height, width = target['ray_bboxes'][
j] # list of tuple
# for key in ('img', 'depth_mask', 'depth', 'depth_mask_sr'): # type: ignore
for key in ('img', 'depth_mask', 'depth'): # type: ignore
# target[key][i:i+1] = torchvision.transforms.functional.crop(
# cropped_target[key][
# j:j + 1] = torchvision.transforms.functional.crop(
# micro[key][j:j + 1], top, left, height, width)
cropped_target[f'{key}'][ # ! no nv_ here
j:j + 1] = torchvision.transforms.functional.crop(
micro[f'nv_{key}'][j:j + 1], top, left, height,
width)
# target.update(cropped_target)
# wrap forward within amp
with th.autocast(device_type='cuda',
dtype=th.float16,
enabled=self.mp_trainer_rec.use_amp):
# target_nvs = {}
# target_cano = {}
latent = self.rec_model(img=micro['img_to_encoder'],
behaviour='enc_dec_wo_triplane')
pred_nv = self.rec_model(
latent=latent,
c=micro['nv_c'], # predict novel view here
behaviour='triplane_dec',
ray_origins=target['ray_origins'],
ray_directions=target['ray_directions'],
)
# ! directly retrieve from target
# for k, v in target.items():
# if k[:2] == 'nv':
# orig_key = k.replace('nv_', '')
# target_nvs[orig_key] = v
# target_cano[orig_key] = target[orig_key]
with self.rec_model.no_sync(): # type: ignore
loss, loss_dict, _ = self.loss_class(pred_nv,
cropped_target,
step=self.step +
self.resume_step,
test_mode=False,
return_fg_mask=True,
conf_sigma_l1=None,
conf_sigma_percl=None)
log_rec3d_loss_dict(loss_dict)
if self.rec_cano:
cano_target = {
**self.eg3d_model(
c=micro['c'], # type: ignore
ws=None,
planes=None,
sample_ray_only=True,
fg_bbox=micro['bbox']), # rays o / dir
}
cano_cropped_target = {
k: th.empty_like(v)
for k, v in cropped_target.items()
}
for j in range(micro['img'].shape[0]):
top, left, height, width = cano_target['ray_bboxes'][
j] # list of tuple
# for key in ('img', 'depth_mask', 'depth', 'depth_mask_sr'): # type: ignore
for key in ('img', 'depth_mask',
'depth'): # type: ignore
# target[key][i:i+1] = torchvision.transforms.functional.crop(
cano_cropped_target[key][
j:j +
1] = torchvision.transforms.functional.crop(
micro[key][j:j + 1], top, left, height,
width)
# cano_target.update(cano_cropped_target)
pred_cano = self.rec_model(
latent=latent,
c=micro['c'],
behaviour='triplane_dec',
ray_origins=cano_target['ray_origins'],
ray_directions=cano_target['ray_directions'],
)
with self.rec_model.no_sync(): # type: ignore
fg_mask = cano_cropped_target['depth_mask'].unsqueeze(
1).repeat_interleave(3, 1).float()
loss_cano, loss_cano_dict = self.loss_class.calc_2d_rec_loss(
pred_cano['image_raw'],
cano_cropped_target['img'],
fg_mask,
step=self.step + self.resume_step,
test_mode=False,
)
loss = loss + loss_cano
# remove redundant log
log_rec3d_loss_dict({
f'cano_{k}': v
for k, v in loss_cano_dict.items()
# if "loss" in k
})
self.mp_trainer_rec.backward(loss)
if dist_util.get_rank() == 0 and self.step % 500 == 0:
self.log_patch_img(cropped_target, pred_nv, pred_cano)
@th.inference_mode()
def log_patch_img(self, micro, pred, pred_cano):
# gt_vis = th.cat([batch['img'], batch['depth']], dim=-1)
def norm_depth(pred_depth): # to [-1,1]
# pred_depth = pred['image_depth']
pred_depth = (pred_depth - pred_depth.min()) / (pred_depth.max() -
pred_depth.min())
return -(pred_depth * 2 - 1)
pred_img = pred['image_raw']
gt_img = micro['img']
# infer novel view also
# if self.loss_class.opt.symmetry_loss:
# pred_nv_img = nvs_pred
# else:
# ! replace with novel view prediction
# ! log another novel-view prediction
# pred_nv_img = self.rec_model(
# img=micro['img_to_encoder'],
# c=self.novel_view_poses) # pred: (B, 3, 64, 64)
# if 'depth' in micro:
gt_depth = micro['depth']
if gt_depth.ndim == 3:
gt_depth = gt_depth.unsqueeze(1)
gt_depth = norm_depth(gt_depth)
# gt_depth = (gt_depth - gt_depth.min()) / (gt_depth.max() -
# gt_depth.min())
# if True:
fg_mask = pred['image_mask'] * 2 - 1 # 0-1
input_fg_mask = pred_cano['image_mask'] * 2 - 1 # 0-1
if 'image_depth' in pred:
pred_depth = norm_depth(pred['image_depth'])
pred_cano_depth = norm_depth(pred_cano['image_depth'])
else:
pred_depth = th.zeros_like(gt_depth)
pred_cano_depth = th.zeros_like(gt_depth)
# if 'image_sr' in pred:
# if pred['image_sr'].shape[-1] == 512:
# pred_img = th.cat([self.pool_512(pred_img), pred['image_sr']],
# dim=-1)
# gt_img = th.cat([self.pool_512(micro['img']), micro['img_sr']],
# dim=-1)
# pred_depth = self.pool_512(pred_depth)
# gt_depth = self.pool_512(gt_depth)
# elif pred['image_sr'].shape[-1] == 256:
# pred_img = th.cat([self.pool_256(pred_img), pred['image_sr']],
# dim=-1)
# gt_img = th.cat([self.pool_256(micro['img']), micro['img_sr']],
# dim=-1)
# pred_depth = self.pool_256(pred_depth)
# gt_depth = self.pool_256(gt_depth)
# else:
# pred_img = th.cat([self.pool_128(pred_img), pred['image_sr']],
# dim=-1)
# gt_img = th.cat([self.pool_128(micro['img']), micro['img_sr']],
# dim=-1)
# gt_depth = self.pool_128(gt_depth)
# pred_depth = self.pool_128(pred_depth)
# else:
# gt_img = self.pool_64(gt_img)
# gt_depth = self.pool_64(gt_depth)
pred_vis = th.cat([
pred_img,
pred_depth.repeat_interleave(3, dim=1),
fg_mask.repeat_interleave(3, dim=1),
],
dim=-1) # B, 3, H, W
pred_vis_nv = th.cat([
pred_cano['image_raw'],
pred_cano_depth.repeat_interleave(3, dim=1),
input_fg_mask.repeat_interleave(3, dim=1),
],
dim=-1) # B, 3, H, W
pred_vis = th.cat([pred_vis, pred_vis_nv], dim=-2) # cat in H dim
gt_vis = th.cat([
gt_img,
gt_depth.repeat_interleave(3, dim=1),
th.zeros_like(gt_img)
],
dim=-1) # TODO, fail to load depth. range [0, 1]
# if 'conf_sigma' in pred:
# gt_vis = th.cat([gt_vis, fg_mask], dim=-1) # placeholder
# vis = th.cat([gt_vis, pred_vis], dim=-2)[0].permute(
# st()
vis = th.cat([gt_vis, pred_vis], dim=-2)
# .permute(
# 0, 2, 3, 1).cpu()
vis_tensor = torchvision.utils.make_grid(vis, nrow=vis.shape[-1] //
64) # HWC
torchvision.utils.save_image(
vis_tensor,
f'{logger.get_dir()}/{self.step+self.resume_step}.jpg',
value_range=(-1, 1),
normalize=True)
logger.log('log vis to: ',
f'{logger.get_dir()}/{self.step+self.resume_step}.jpg')
# self.writer.add_image(f'images',
# vis,
# self.step + self.resume_step,
# dataformats='HWC')
class TrainLoop3DRecNVPatchSingleForward(TrainLoop3DRecNVPatch):
def __init__(self,
*,
rec_model,
loss_class,
data,
eval_data,
batch_size,
microbatch,
lr,
ema_rate,
log_interval,
eval_interval,
save_interval,
resume_checkpoint,
use_fp16=False,
fp16_scale_growth=0.001,
weight_decay=0,
lr_anneal_steps=0,
iterations=10001,
load_submodule_name='',
ignore_resume_opt=False,
model_name='rec',
use_amp=False,
**kwargs):
super().__init__(rec_model=rec_model,
loss_class=loss_class,
data=data,
eval_data=eval_data,
batch_size=batch_size,
microbatch=microbatch,
lr=lr,
ema_rate=ema_rate,
log_interval=log_interval,
eval_interval=eval_interval,
save_interval=save_interval,
resume_checkpoint=resume_checkpoint,
use_fp16=use_fp16,
fp16_scale_growth=fp16_scale_growth,
weight_decay=weight_decay,
lr_anneal_steps=lr_anneal_steps,
iterations=iterations,
load_submodule_name=load_submodule_name,
ignore_resume_opt=ignore_resume_opt,
model_name=model_name,
use_amp=use_amp,
**kwargs)
def forward_backward(self, batch, *args, **kwargs):
# add patch sampling
self.mp_trainer_rec.zero_grad()
batch_size = batch['img_to_encoder'].shape[0]
batch.pop('caption') # not required
batch.pop('ins') # not required
batch.pop('nv_caption') # not required
batch.pop('nv_ins') # not required
for i in range(0, batch_size, self.microbatch):
micro = {
k:
v[i:i + self.microbatch].to(dist_util.dev()) if isinstance(
v, th.Tensor) else v[i:i + self.microbatch]
for k, v in batch.items()
}
# ! sample rendering patch
target = {
**self.eg3d_model(
c=micro['nv_c'], # type: ignore
ws=None,
planes=None,
sample_ray_only=True,
fg_bbox=micro['nv_bbox']), # rays o / dir
}
patch_rendering_resolution = self.eg3d_model.rendering_kwargs[
'patch_rendering_resolution'] # type: ignore
cropped_target = {
k:
th.empty_like(v)
[..., :patch_rendering_resolution, :patch_rendering_resolution]
if k not in [
'ins_idx', 'img_to_encoder', 'img_sr', 'nv_img_to_encoder',
'nv_img_sr', 'c', 'caption', 'nv_caption'
] else v
for k, v in micro.items()
}
# crop according to uv sampling
for j in range(micro['img'].shape[0]):
top, left, height, width = target['ray_bboxes'][
j] # list of tuple
# for key in ('img', 'depth_mask', 'depth', 'depth_mask_sr'): # type: ignore
for key in ('img', 'depth_mask', 'depth'): # type: ignore
# target[key][i:i+1] = torchvision.transforms.functional.crop(
# cropped_target[key][
# j:j + 1] = torchvision.transforms.functional.crop(
# micro[key][j:j + 1], top, left, height, width)
cropped_target[f'{key}'][ # ! no nv_ here
j:j + 1] = torchvision.transforms.functional.crop(
micro[f'nv_{key}'][j:j + 1], top, left, height,
width)
# ! cano view loss
# cano_target = {
# **self.eg3d_model(
# c=micro['c'], # type: ignore
# ws=None,
# planes=None,
# sample_ray_only=True,
# fg_bbox=micro['bbox']), # rays o / dir
# }
# cano_cropped_target = {
# k: th.empty_like(v)
# for k, v in cropped_target.items()
# }
# for j in range(micro['img'].shape[0]):
# top, left, height, width = cano_target['ray_bboxes'][
# j] # list of tuple
# # for key in ('img', 'depth_mask', 'depth', 'depth_mask_sr'): # type: ignore
# for key in ('img', 'depth_mask', 'depth'): # type: ignore
# # target[key][i:i+1] = torchvision.transforms.functional.crop(
# cano_cropped_target[key][
# j:j + 1] = torchvision.transforms.functional.crop(
# micro[key][j:j + 1], top, left, height, width)
# ! vit no amp
latent = self.rec_model(img=micro['img_to_encoder'].to(self.dtype),
behaviour='enc_dec_wo_triplane')
# wrap forward within amp
with th.autocast(device_type='cuda',
# dtype=th.float16,
dtype=th.bfloat16, # avoid NAN
enabled=self.mp_trainer_rec.use_amp):
# c = th.cat([micro['nv_c'], micro['c']]), # predict novel view here
# c = th.cat([micro['nv_c'].repeat(3, 1), micro['c']]), # predict novel view here
instance_mv_num = batch_size // 4 # 4 pairs by default
# instance_mv_num = 4
# ! roll views for multi-view supervision
c = th.cat([
micro['nv_c'].roll(instance_mv_num * i, dims=0)
for i in range(1, 4)
]
# + [micro['c']]
) # predict novel view here
ray_origins = th.cat(
[
target['ray_origins'].roll(instance_mv_num * i, dims=0)
for i in range(1, 4)
]
# + [cano_target['ray_origins'] ]
,
0)
ray_directions = th.cat([
target['ray_directions'].roll(instance_mv_num * i, dims=0)
for i in range(1, 4)
]
# + [cano_target['ray_directions'] ]
)
pred_nv_cano = self.rec_model(
# latent=latent.expand(2,),
latent={
'latent_after_vit': # ! triplane for rendering
# latent['latent_after_vit'].repeat(2, 1, 1, 1)
latent['latent_after_vit'].repeat(3, 1, 1, 1)
},
c=c,
behaviour='triplane_dec',
# ray_origins=target['ray_origins'],
# ray_directions=target['ray_directions'],
ray_origins=ray_origins,
ray_directions=ray_directions,
)
pred_nv_cano.update(
latent
) # torchvision.utils.save_image(pred_nv_cano['image_raw'], 'pred.png', normalize=True)
# gt = {
# k: th.cat([v, cano_cropped_target[k]], 0)
# for k, v in cropped_target.items()
# }
gt = {
k:
th.cat(
[
v.roll(instance_mv_num * i, dims=0)
for i in range(1, 4)
]
# + [cano_cropped_target[k] ]
,
0)
for k, v in cropped_target.items()
} # torchvision.utils.save_image(gt['img'], 'gt.png', normalize=True)
with self.rec_model.no_sync(): # type: ignore
loss, loss_dict, _ = self.loss_class(
pred_nv_cano,
gt, # prepare merged data
step=self.step + self.resume_step,
test_mode=False,
return_fg_mask=True,
conf_sigma_l1=None,
conf_sigma_percl=None)
log_rec3d_loss_dict(loss_dict)
self.mp_trainer_rec.backward(loss)
# for name, p in self.rec_model.named_parameters():
# if p.grad is None:
# logger.log(f"found rec unused param: {name}")
if dist_util.get_rank() == 0 and self.step % 500 == 0:
micro_bs = micro['img_to_encoder'].shape[0]
self.log_patch_img( # record one cano view and one novel view
cropped_target,
{
k: pred_nv_cano[k][-micro_bs:]
for k in ['image_raw', 'image_depth', 'image_mask']
},
{
k: pred_nv_cano[k][:micro_bs]
for k in ['image_raw', 'image_depth', 'image_mask']
},
)
def eval_loop(self):
return super().eval_loop()
@th.inference_mode()
# def eval_loop(self, c_list:list):
def eval_novelview_loop_old(self, camera=None):
# novel view synthesis given evaluation camera trajectory
all_loss_dict = []
novel_view_micro = {}
# ! randomly inference an instance
export_mesh = True
if export_mesh:
Path(f'{logger.get_dir()}/FID_Cals/').mkdir(parents=True,
exist_ok=True)
# for i in range(0, len(c_list), 1): # TODO, larger batch size for eval
batch = {}
# if camera is not None:
# # batch['c'] = camera.to(batch['c'].device())
# batch['c'] = camera.clone()
# else:
# batch =
for eval_idx, render_reference in enumerate(tqdm(self.eval_data)):
if eval_idx > 500:
break
video_out = imageio.get_writer(
f'{logger.get_dir()}/video_novelview_{self.step+self.resume_step}_{eval_idx}.mp4',
mode='I',
fps=25,
codec='libx264')
with open(
f'{logger.get_dir()}/triplane_{self.step+self.resume_step}_{eval_idx}_caption.txt',
'w') as f:
f.write(render_reference['caption'])
for key in ['ins', 'bbox', 'caption']:
if key in render_reference:
render_reference.pop(key)
real_flag = False
mv_flag = False # TODO, use full-instance for evaluation? Calculate the metrics.
if render_reference['c'].shape[:2] == (1, 40):
real_flag = True
# real img monocular reconstruction
# compat lst for enumerate
render_reference = [{
k: v[0][idx:idx + 1]
for k, v in render_reference.items()
} for idx in range(40)]
elif render_reference['c'].shape[0] == 8:
mv_flag = True
render_reference = {
k: v[:4]
for k, v in render_reference.items()
}
# save gt
torchvision.utils.save_image(
render_reference[0:4]['img'],
logger.get_dir() + '/FID_Cals/{}_inp.png'.format(eval_idx),
padding=0,
normalize=True,
value_range=(-1, 1),
)
# torchvision.utils.save_image(render_reference[4:8]['img'],
# logger.get_dir() + '/FID_Cals/{}_inp2.png'.format(eval_idx),
# padding=0,
# normalize=True,
# value_range=(-1,1),
# )
else:
# compat lst for enumerate
st()
render_reference = [{
k: v[idx:idx + 1]
for k, v in render_reference.items()
} for idx in range(40)]
# ! single-view version
render_reference[0]['img_to_encoder'] = render_reference[14][
'img_to_encoder'] # encode side view
render_reference[0]['img'] = render_reference[14][
'img'] # encode side view
# save gt
torchvision.utils.save_image(
render_reference[0]['img'],
logger.get_dir() + '/FID_Cals/{}_gt.png'.format(eval_idx),
padding=0,
normalize=True,
value_range=(-1, 1))
# ! TODO, merge with render_video_given_triplane later
for i, batch in enumerate(render_reference):
# for i in range(0, 8, self.microbatch):
# c = c_list[i].to(dist_util.dev()).reshape(1, -1)
micro = {k: v.to(dist_util.dev()) for k, v in batch.items()}
st()
if i == 0:
if mv_flag:
novel_view_micro = None
else:
novel_view_micro = {
k:
v[0:1].to(dist_util.dev()).repeat_interleave(
# v[14:15].to(dist_util.dev()).repeat_interleave(
micro['img'].shape[0],
0) if isinstance(v, th.Tensor) else v[0:1]
for k, v in batch.items()
}
else:
if i == 1:
# ! output mesh
if export_mesh:
# ! get planes first
# self.latent_name = 'latent_normalized' # normalized triplane latent
# ddpm_latent = {
# self.latent_name: planes,
# }
# ddpm_latent.update(self.rec_model(latent=ddpm_latent, behaviour='decode_after_vae_no_render'))
# mesh_size = 512
# mesh_size = 256
mesh_size = 384
# mesh_size = 320
# mesh_thres = 3 # TODO, requires tuning
# mesh_thres = 5 # TODO, requires tuning
mesh_thres = 10 # TODO, requires tuning
import mcubes
import trimesh
dump_path = f'{logger.get_dir()}/mesh/'
os.makedirs(dump_path, exist_ok=True)
grid_out = self.rec_model(
latent=pred,
grid_size=mesh_size,
behaviour='triplane_decode_grid',
)
vtx, faces = mcubes.marching_cubes(
grid_out['sigma'].squeeze(0).squeeze(
-1).cpu().numpy(), mesh_thres)
vtx = vtx / (mesh_size - 1) * 2 - 1
# vtx_tensor = th.tensor(vtx, dtype=th.float32, device=dist_util.dev()).unsqueeze(0)
# vtx_colors = self.model.synthesizer.forward_points(planes, vtx_tensor)['rgb'].squeeze(0).cpu().numpy() # (0, 1)
# vtx_colors = (vtx_colors * 255).astype(np.uint8)
# mesh = trimesh.Trimesh(vertices=vtx, faces=faces, vertex_colors=vtx_colors)
mesh = trimesh.Trimesh(
vertices=vtx,
faces=faces,
)
mesh_dump_path = os.path.join(
dump_path, f'{eval_idx}.ply')
mesh.export(mesh_dump_path, 'ply')
print(f"Mesh dumped to {dump_path}")
del grid_out, mesh
th.cuda.empty_cache()
# return
# st()
# if novel_view_micro['c'].shape[0] < micro['img'].shape[0]:
novel_view_micro = {
k:
v[0:1].to(dist_util.dev()).repeat_interleave(
micro['img'].shape[0], 0)
for k, v in novel_view_micro.items()
}
pred = self.rec_model(img=novel_view_micro['img_to_encoder'],
c=micro['c']) # pred: (B, 3, 64, 64)
# target = {
# 'img': micro['img'],
# 'depth': micro['depth'],
# 'depth_mask': micro['depth_mask']
# }
# targe
# if not export_mesh:
if not real_flag:
_, loss_dict = self.loss_class(pred, micro, test_mode=True)
all_loss_dict.append(loss_dict)
# ! move to other places, add tensorboard
# pred_vis = th.cat([
# pred['image_raw'],
# -pred['image_depth'].repeat_interleave(3, dim=1)
# ],
# dim=-1)
# normalize depth
# if True:
pred_depth = pred['image_depth']
pred_depth = (pred_depth - pred_depth.min()) / (
pred_depth.max() - pred_depth.min())
if 'image_sr' in pred:
if pred['image_sr'].shape[-1] == 512:
pred_vis = th.cat([
micro['img_sr'],
self.pool_512(pred['image_raw']), pred['image_sr'],
self.pool_512(pred_depth).repeat_interleave(3,
dim=1)
],
dim=-1)
elif pred['image_sr'].shape[-1] == 256:
pred_vis = th.cat([
micro['img_sr'],
self.pool_256(pred['image_raw']), pred['image_sr'],
self.pool_256(pred_depth).repeat_interleave(3,
dim=1)
],
dim=-1)
else:
pred_vis = th.cat([
micro['img_sr'],
self.pool_128(pred['image_raw']),
self.pool_128(pred['image_sr']),
self.pool_128(pred_depth).repeat_interleave(3,
dim=1)
],
dim=-1)
else:
# pred_vis = th.cat([
# self.pool_64(micro['img']), pred['image_raw'],
# pred_depth.repeat_interleave(3, dim=1)
# ],
# dim=-1) # B, 3, H, W
pooled_depth = self.pool_128(pred_depth).repeat_interleave(
3, dim=1)
pred_vis = th.cat(
[
# self.pool_128(micro['img']),
self.pool_128(novel_view_micro['img']
), # use the input here
self.pool_128(pred['image_raw']),
pooled_depth,
],
dim=-1) # B, 3, H, W
vis = pred_vis.permute(0, 2, 3, 1).cpu().numpy()
vis = vis * 127.5 + 127.5
vis = vis.clip(0, 255).astype(np.uint8)
if export_mesh:
# save image
torchvision.utils.save_image(
pred['image_raw'],
logger.get_dir() +
'/FID_Cals/{}_{}.png'.format(eval_idx, i),
padding=0,
normalize=True,
value_range=(-1, 1))
torchvision.utils.save_image(
pooled_depth,
logger.get_dir() +
'/FID_Cals/{}_{}_dpeth.png'.format(eval_idx, i),
padding=0,
normalize=True,
value_range=(0, 1))
# st()
for j in range(vis.shape[0]):
video_out.append_data(vis[j])
video_out.close()
# if not export_mesh:
if not real_flag or mv_flag:
val_scores_for_logging = calc_average_loss(all_loss_dict)
with open(os.path.join(logger.get_dir(), 'scores_novelview.json'),
'a') as f:
json.dump({'step': self.step, **val_scores_for_logging}, f)
# * log to tensorboard
for k, v in val_scores_for_logging.items():
self.writer.add_scalar(f'Eval/NovelView/{k}', v,
self.step + self.resume_step)
del video_out
# del pred_vis
# del pred
th.cuda.empty_cache()
@th.inference_mode()
# def eval_loop(self, c_list:list):
def eval_novelview_loop(self, camera=None, save_latent=False):
# novel view synthesis given evaluation camera trajectory
if save_latent: # for diffusion learning
latent_dir = Path(f'{logger.get_dir()}/latent_dir')
latent_dir.mkdir(exist_ok=True, parents=True)
# wds_path = os.path.join(logger.get_dir(), 'latent_dir',
# f'wds-%06d.tar')
# sink = wds.ShardWriter(wds_path, start_shard=0)
# eval_batch_size = 20
# eval_batch_size = 1
eval_batch_size = 40 # ! for i23d
latent_rec_statistics = False
for eval_idx, micro in enumerate(tqdm(self.eval_data)):
# if eval_idx > 500:
# break
latent = self.rec_model(
img=micro['img_to_encoder'],
behaviour='encoder_vae') # pred: (B, 3, 64, 64)
# torchvision.utils.save_image(micro['img'], 'inp.jpg')
if micro['img'].shape[0] == 40:
assert eval_batch_size == 40
if save_latent:
# np.save(f'{logger.get_dir()}/latent_dir/{eval_idx}.npy', latent[self.latent_name].cpu().numpy())
latent_save_dir = f'{logger.get_dir()}/latent_dir/{micro["ins"][0]}'
Path(latent_save_dir).mkdir(parents=True, exist_ok=True)
np.save(f'{latent_save_dir}/latent.npy',
latent[self.latent_name][0].cpu().numpy())
assert all([
micro['ins'][0] == micro['ins'][i]
for i in range(micro['c'].shape[0])
]) # ! assert same instance
# for i in range(micro['img'].shape[0]):
# compressed_sample = {
# 'latent':latent[self.latent_name][0].cpu().numpy(), # 12 32 32
# 'caption': micro['caption'][0].encode('utf-8'),
# 'ins': micro['ins'][0].encode('utf-8'),
# 'c': micro['c'][i].cpu().numpy(),
# 'img': micro['img'][i].cpu().numpy() # 128x128, for diffusion log
# }
# sink.write({
# "__key__": f"sample_{eval_idx*eval_batch_size+i:07d}",
# 'sample.pyd': compressed_sample
# })
if latent_rec_statistics:
gen_imgs = self.render_video_given_triplane(
latent[self.latent_name],
self.rec_model, # compatible with join_model
name_prefix=f'{self.step + self.resume_step}_{eval_idx}',
save_img=False,
render_reference={'c': micro['c']},
save_mesh=False,
render_reference_length=4,
return_gen_imgs=True)
rec_psnr = psnr((micro['img'] / 2 + 0.5),
(gen_imgs.cpu() / 2 + 0.5), 1.0)
with open(
os.path.join(logger.get_dir(),
'four_view_rec_psnr.json'), 'a') as f:
json.dump(
{
f'{eval_idx}': {
'ins': micro["ins"][0],
'psnr': rec_psnr.item(),
}
}, f)
# save to json
elif eval_idx < 30:
# if False:
self.render_video_given_triplane(
latent[self.latent_name],
self.rec_model, # compatible with join_model
name_prefix=f'{self.step + self.resume_step}_{micro["ins"][0].split("/")[0]}_{eval_idx}',
save_img=False,
render_reference={'c': camera},
save_mesh=True)
class TrainLoop3DRecNVPatchSingleForwardMV(TrainLoop3DRecNVPatchSingleForward):
def __init__(self,
*,
rec_model,
loss_class,
data,
eval_data,
batch_size,
microbatch,
lr,
ema_rate,
log_interval,
eval_interval,
save_interval,
resume_checkpoint,
use_fp16=False,
fp16_scale_growth=0.001,
weight_decay=0,
lr_anneal_steps=0,
iterations=10001,
load_submodule_name='',
ignore_resume_opt=False,
model_name='rec',
use_amp=False,
**kwargs):
super().__init__(rec_model=rec_model,
loss_class=loss_class,
data=data,
eval_data=eval_data,
batch_size=batch_size,
microbatch=microbatch,
lr=lr,
ema_rate=ema_rate,
log_interval=log_interval,
eval_interval=eval_interval,
save_interval=save_interval,
resume_checkpoint=resume_checkpoint,
use_fp16=use_fp16,
fp16_scale_growth=fp16_scale_growth,
weight_decay=weight_decay,
lr_anneal_steps=lr_anneal_steps,
iterations=iterations,
load_submodule_name=load_submodule_name,
ignore_resume_opt=ignore_resume_opt,
model_name=model_name,
use_amp=use_amp,
**kwargs)
def forward_backward(self, batch, behaviour='g_step', *args, **kwargs):
# add patch sampling
if behaviour == 'g_step':
self.mp_trainer_rec.zero_grad()
else:
self.mp_trainer_disc.zero_grad()
batch_size = batch['img_to_encoder'].shape[0]
batch.pop('caption') # not required
batch.pop('ins') # not required
batch.pop('nv_caption') # not required
batch.pop('nv_ins') # not required
if '__key__' in batch.keys():
batch.pop('__key__')
for i in range(0, batch_size, self.microbatch):
micro = {
k:
v[i:i + self.microbatch].to(dist_util.dev()) if isinstance(
v, th.Tensor) else v[i:i + self.microbatch]
for k, v in batch.items()
}
# ! sample rendering patch
# nv_c = th.cat([micro['nv_c'], micro['c']])
nv_c = th.cat([micro['nv_c'], micro['c']])
# nv_c = micro['nv_c']
target = {
**self.eg3d_model(
c=nv_c, # type: ignore
ws=None,
planes=None,
sample_ray_only=True,
fg_bbox=th.cat([micro['nv_bbox'], micro['bbox']])), # rays o / dir
}
patch_rendering_resolution = self.eg3d_model.rendering_kwargs[
'patch_rendering_resolution'] # type: ignore
cropped_target = {
k:
th.empty_like(v).repeat_interleave(2, 0)
# th.empty_like(v).repeat_interleave(1, 0)
[..., :patch_rendering_resolution, :patch_rendering_resolution]
if k not in [
'ins_idx', 'img_to_encoder', 'img_sr', 'nv_img_to_encoder',
'nv_img_sr', 'c', 'caption', 'nv_caption'
] else v
for k, v in micro.items()
}
# crop according to uv sampling
for j in range(2 * self.microbatch):
top, left, height, width = target['ray_bboxes'][
j] # list of tuple
# for key in ('img', 'depth_mask', 'depth', 'depth_mask_sr'): # type: ignore
for key in ('img', 'depth_mask', 'depth'): # type: ignore
if j < self.microbatch:
cropped_target[f'{key}'][ # ! no nv_ here
j:j + 1] = torchvision.transforms.functional.crop(
micro[f'nv_{key}'][j:j + 1], top, left, height,
width)
else:
cropped_target[f'{key}'][ # ! no nv_ here
j:j + 1] = torchvision.transforms.functional.crop(
micro[f'{key}'][j - self.microbatch:j -
self.microbatch + 1], top,
left, height, width)
# for j in range(batch_size, 2*batch_size, 1):
# top, left, height, width = target['ray_bboxes'][
# j] # list of tuple
# # for key in ('img', 'depth_mask', 'depth', 'depth_mask_sr'): # type: ignore
# for key in ('img', 'depth_mask', 'depth'): # type: ignore
# cropped_target[f'{key}'][ # ! no nv_ here
# j:j + 1] = torchvision.transforms.functional.crop(
# micro[f'{key}'][j-batch_size:j-batch_size + 1], top, left, height,
# width)
# wrap forward within amp
with th.autocast(device_type='cuda',
dtype=self.dtype,
enabled=self.mp_trainer_rec.use_amp):
# c = th.cat([micro['nv_c'], micro['c']]), # predict novel view here
# c = th.cat([micro['nv_c'].repeat(3, 1), micro['c']]), # predict novel view here
# instance_mv_num = batch_size // 4 # 4 pairs by default
# instance_mv_num = 4
# ! roll views for multi-view supervision
# c = micro['nv_c']
# ! vit no amp
latent = self.rec_model(img=micro['img_to_encoder'].to(self.dtype),
behaviour='enc_dec_wo_triplane')
# # ! disable amp in rendering and loss
# with th.autocast(device_type='cuda',
# dtype=th.float16,
# enabled=False):
ray_origins = target['ray_origins']
ray_directions = target['ray_directions']
pred_nv_cano = self.rec_model(
# latent=latent.expand(2,),
latent={
'latent_after_vit': # ! triplane for rendering
# latent['latent_after_vit'].repeat_interleave(4, dim=0).repeat(2,1,1,1) # NV=4
latent['latent_after_vit'].repeat_interleave(6, dim=0).repeat(2,1,1,1) # NV=6
# latent['latent_after_vit'].repeat_interleave(10, dim=0).repeat(2,1,1,1) # NV=4
# latent['latent_after_vit'].repeat_interleave(8, dim=0) # NV=4
},
c=nv_c,
behaviour='triplane_dec',
ray_origins=ray_origins,
ray_directions=ray_directions,
)
pred_nv_cano.update(
latent
) # torchvision.utils.save_image(pred_nv_cano['image_raw'], 'pred.png', normalize=True)
gt = cropped_target
with self.rec_model.no_sync(): # type: ignore
loss, loss_dict, _ = self.loss_class(
pred_nv_cano,
gt, # prepare merged data
step=self.step + self.resume_step,
test_mode=False,
return_fg_mask=True,
behaviour=behaviour,
conf_sigma_l1=None,
conf_sigma_percl=None,
# dtype=self.dtype
)
log_rec3d_loss_dict(loss_dict)
if behaviour == 'g_step':
self.mp_trainer_rec.backward(loss)
else:
self.mp_trainer_disc.backward(loss)
# for name, p in self.rec_model.named_parameters():
# if p.grad is None:
# logger.log(f"found rec unused param: {name}")
# torchvision.utils.save_image(cropped_target['img'], 'gt.png', normalize=True)
# torchvision.utils.save_image( pred_nv_cano['image_raw'], 'pred.png', normalize=True)
if dist_util.get_rank() == 0 and self.step % 500 == 0 and i == 0 and behaviour == 'g_step':
try:
torchvision.utils.save_image(
th.cat(
[cropped_target['img'], pred_nv_cano['image_raw']
], ),
f'{logger.get_dir()}/{self.step+self.resume_step}.jpg',
normalize=True, nrow=6*2)
logger.log(
'log vis to: ',
f'{logger.get_dir()}/{self.step+self.resume_step}.jpg')
except Exception as e:
logger.log(e)
# micro_bs = micro['img_to_encoder'].shape[0]
# self.log_patch_img( # record one cano view and one novel view
# cropped_target,
# {
# k: pred_nv_cano[k][0:1]
# for k in ['image_raw', 'image_depth', 'image_mask']
# },
# {
# k: pred_nv_cano[k][1:2]
# for k in ['image_raw', 'image_depth', 'image_mask']
# },
# )
# def save(self):
# return super().save()
class TrainLoop3DRecNVPatchSingleForwardMVAdvLoss(
TrainLoop3DRecNVPatchSingleForwardMV):
def __init__(self,
*,
rec_model,
loss_class,
data,
eval_data,
batch_size,
microbatch,
lr,
ema_rate,
log_interval,
eval_interval,
save_interval,
resume_checkpoint,
use_fp16=False,
fp16_scale_growth=0.001,
weight_decay=0,
lr_anneal_steps=0,
iterations=10001,
load_submodule_name='',
ignore_resume_opt=False,
model_name='rec',
use_amp=False,
**kwargs):
super().__init__(rec_model=rec_model,
loss_class=loss_class,
data=data,
eval_data=eval_data,
batch_size=batch_size,
microbatch=microbatch,
lr=lr,
ema_rate=ema_rate,
log_interval=log_interval,
eval_interval=eval_interval,
save_interval=save_interval,
resume_checkpoint=resume_checkpoint,
use_fp16=use_fp16,
fp16_scale_growth=fp16_scale_growth,
weight_decay=weight_decay,
lr_anneal_steps=lr_anneal_steps,
iterations=iterations,
load_submodule_name=load_submodule_name,
ignore_resume_opt=ignore_resume_opt,
model_name=model_name,
use_amp=use_amp,
**kwargs)
# create discriminator
disc_params = self.loss_class.get_trainable_parameters()
self.mp_trainer_disc = MixedPrecisionTrainer(
model=self.loss_class.discriminator,
use_fp16=self.use_fp16,
fp16_scale_growth=fp16_scale_growth,
model_name='disc',
use_amp=use_amp,
model_params=disc_params)
# st() # check self.lr
self.opt_disc = AdamW(
self.mp_trainer_disc.master_params,
lr=self.lr, # follow sd code base
betas=(0, 0.999),
eps=1e-8)
# TODO, is loss cls already in the DDP?
if self.use_ddp:
self.ddp_disc = DDP(
self.loss_class.discriminator,
device_ids=[dist_util.dev()],
output_device=dist_util.dev(),
broadcast_buffers=False,
bucket_cap_mb=128,
find_unused_parameters=False,
)
else:
self.ddp_disc = self.loss_class.discriminator
# def run_st
# def run_step(self, batch, *args):
# self.forward_backward(batch)
# took_step = self.mp_trainer_rec.optimize(self.opt)
# if took_step:
# self._update_ema()
# self._anneal_lr()
# self.log_step()
def save(self, mp_trainer=None, model_name='rec'):
if mp_trainer is None:
mp_trainer = self.mp_trainer_rec
def save_checkpoint(rate, params):
state_dict = mp_trainer.master_params_to_state_dict(params)
if dist_util.get_rank() == 0:
logger.log(f"saving model {model_name} {rate}...")
if not rate:
filename = f"model_{model_name}{(self.step+self.resume_step):07d}.pt"
else:
filename = f"ema_{model_name}_{rate}_{(self.step+self.resume_step):07d}.pt"
with bf.BlobFile(bf.join(get_blob_logdir(), filename),
"wb") as f:
th.save(state_dict, f)
save_checkpoint(0, mp_trainer.master_params)
dist.barrier()
# ! load disc
def _load_and_sync_parameters(self, submodule_name=''):
super()._load_and_sync_parameters(submodule_name)
# load disc
resume_checkpoint = self.resume_checkpoint.replace(
'rec', 'disc') # * default behaviour
if os.path.exists(resume_checkpoint):
if dist_util.get_rank() == 0:
logger.log(
f"loading disc model from checkpoint: {resume_checkpoint}..."
)
map_location = {
'cuda:%d' % 0: 'cuda:%d' % dist_util.get_rank()
} # configure map_location properly
resume_state_dict = dist_util.load_state_dict(
resume_checkpoint, map_location=map_location)
model_state_dict = self.loss_class.discriminator.state_dict()
for k, v in resume_state_dict.items():
if k in model_state_dict.keys():
if v.size() == model_state_dict[k].size():
model_state_dict[k] = v
# model_state_dict[k].copy_(v)
else:
logger.log('!!!! partially load: ', k, ": ",
v.size(), "state_dict: ",
model_state_dict[k].size())
if dist_util.get_world_size() > 1:
# dist_util.sync_params(self.model.named_parameters())
dist_util.sync_params(
self.loss_class.get_trainable_parameters())
logger.log('synced disc params')
def run_step(self, batch, step='g_step'):
# self.forward_backward(batch)
if step == 'g_step':
self.forward_backward(batch, behaviour='g_step')
took_step_g_rec = self.mp_trainer_rec.optimize(self.opt)
if took_step_g_rec:
self._update_ema() # g_ema
elif step == 'd_step':
self.forward_backward(batch, behaviour='d_step')
_ = self.mp_trainer_disc.optimize(self.opt_disc)
self._anneal_lr()
self.log_step()
def run_loop(self, batch=None):
while (not self.lr_anneal_steps
or self.step + self.resume_step < self.lr_anneal_steps):
batch = next(self.data)
self.run_step(batch, 'g_step')
batch = next(self.data)
self.run_step(batch, 'd_step')
if self.step % 1000 == 0:
dist_util.synchronize()
if self.step % 5000 == 0:
th.cuda.empty_cache() # avoid memory leak
if self.step % self.log_interval == 0 and dist_util.get_rank(
) == 0:
out = logger.dumpkvs()
# * log to tensorboard
for k, v in out.items():
self.writer.add_scalar(f'Loss/{k}', v,
self.step + self.resume_step)
if self.step % self.eval_interval == 0 and self.step != 0:
if dist_util.get_rank() == 0:
try:
self.eval_loop()
except Exception as e:
logger.log(e)
dist_util.synchronize()
# if self.step % self.save_interval == 0 and self.step != 0:
if self.step % self.save_interval == 0:
self.save()
self.save(self.mp_trainer_disc,
self.mp_trainer_disc.model_name)
dist_util.synchronize()
# Run for a finite amount of time in integration tests.
if os.environ.get("DIFFUSION_TRAINING_TEST",
"") and self.step > 0:
return
self.step += 1
if self.step > self.iterations:
logger.log('reached maximum iterations, exiting')
# Save the last checkpoint if it wasn't already saved.
if (self.step -
1) % self.save_interval != 0 and self.step != 1:
self.save()
exit()
# Save the last checkpoint if it wasn't already saved.
# if (self.step - 1) % self.save_interval != 0 and self.step != 1:
if (self.step - 1) % self.save_interval != 0:
self.save() # save rec
self.save(self.mp_trainer_disc, self.mp_trainer_disc.model_name)
class TrainLoop3DRecNVPatchSingleForwardMV_NoCrop(
TrainLoop3DRecNVPatchSingleForwardMV):
def __init__(self,
*,
rec_model,
loss_class,
data,
eval_data,
batch_size,
microbatch,
lr,
ema_rate,
log_interval,
eval_interval,
save_interval,
resume_checkpoint,
use_fp16=False,
fp16_scale_growth=0.001,
weight_decay=0,
lr_anneal_steps=0,
iterations=10001,
load_submodule_name='',
ignore_resume_opt=False,
model_name='rec',
use_amp=False,
num_frames=4,
**kwargs):
super().__init__(rec_model=rec_model,
loss_class=loss_class,
data=data,
eval_data=eval_data,
batch_size=batch_size,
microbatch=microbatch,
lr=lr,
ema_rate=ema_rate,
log_interval=log_interval,
eval_interval=eval_interval,
save_interval=save_interval,
resume_checkpoint=resume_checkpoint,
use_fp16=use_fp16,
fp16_scale_growth=fp16_scale_growth,
weight_decay=weight_decay,
lr_anneal_steps=lr_anneal_steps,
iterations=iterations,
load_submodule_name=load_submodule_name,
ignore_resume_opt=ignore_resume_opt,
model_name=model_name,
use_amp=use_amp,
**kwargs)
self.num_frames = num_frames
self.ray_sampler = RaySampler()
print(self.opt)
# ! requires tuning
N = 768 # hyp param, overfitting now
# self.scale_expected_threshold = (1 / (N/2)) ** 0.5 * 0.45
self.scale_expected_threshold = 0.0075
self.latent_name = 'latent_normalized' # normalized triplane latent
# to transform to 3dgs
self.gs_bg_color=th.tensor([1,1,1], dtype=th.float32, device=dist_util.dev())
self.post_process = PostProcess(
384,
384,
imgnet_normalize=True,
plucker_embedding=True,
decode_encode_img_only=False,
mv_input=True,
split_chunk_input=16,
duplicate_sample=True,
append_depth=False,
append_xyz=False,
gs_cam_format=True,
orthog_duplicate=False,
frame_0_as_canonical=False,
pcd_path='pcd_path',
load_pcd=True,
split_chunk_size=16,
)
self.zfar = 100.0
self.znear = 0.01
# def _init_optim_groups(self, kwargs):
# return super()._init_optim_groups({**kwargs, 'ignore_encoder': True}) # freeze MVEncoder to accelerate training.
def forward_backward(self, batch, behaviour='g_step', *args, **kwargs):
# add patch sampling
self.mp_trainer_rec.zero_grad()
batch_size = batch['img_to_encoder'].shape[0]
batch.pop('caption') # not required
ins = batch.pop('ins') # not required
if '__key__' in batch.keys():
batch.pop('__key__')
assert isinstance(batch['c'], dict)
for i in range(0, batch_size, self.microbatch):
micro = {}
for k, v in batch.items(): # grad acc
if isinstance(v, th.Tensor):
micro[k] = v[i:i + self.microbatch].to(dist_util.dev())
elif isinstance(v, list):
micro[k] = v[i:i + self.microbatch]
elif isinstance(v, dict): #
assert k in ['c', 'nv_c']
micro[k] = {
key:
value[i:i + self.microbatch].to(dist_util.dev()) if
isinstance(value, th.Tensor) else value # can be float
for key, value in v.items()
}
assert micro['img_to_encoder'].shape[1] == 15
micro['normal'] = micro['img_to_encoder'][:, 3:6]
micro['nv_normal'] = micro['nv_img_to_encoder'][:, 3:6]
# ! concat nv_c to render N+N views
indices = np.random.permutation(self.num_frames)
indices, indices_nv = indices[:4], indices[-4:] # make sure thorough pose converage.
# indices, indices_nv = indices[:2], indices[-2:] # ! 2+2 views for supervision, as in gs-lrm.
# indices_nv = np.random.permutation(self.num_frames)[:6] # randomly pick 4+4 views for supervision.
# indices = np.arange(self.num_frames)
# indices_nv = np.arange(self.num_frames)
nv_c = {}
for key in micro['c'].keys():
if isinstance(micro['c'][key], th.Tensor):
nv_c[key] = th.cat([micro['c'][key][:, indices], micro['nv_c'][key][:, indices_nv]],
1) # B 2V ...
else:
nv_c[key] = micro['c'][key] # float, will remove later
target = {}
for key in ('img', 'depth_mask', 'depth', 'normal',): # type: ignore
# st()
target[key] = th.cat([
rearrange(micro[key], '(B V) ... -> B V ...', V=self.num_frames)[:, indices],
rearrange(micro[f'nv_{key}'], '(B V) ... -> B V ...', V=self.num_frames)[:, indices_nv]
# rearrange(micro[key][:, indices], '(B V) ... -> B V ...', V=4),
# rearrange(micro[f'nv_{key}'][:, indices], '(B V) ... -> B V ...', V=4)
], 1) # B 2*V H W
target[key] = rearrange(target[key],
'B V ... -> (B V) ...') # concat
# st()
# wrap forward within amp
with th.autocast(device_type='cuda',
dtype=self.dtype,
enabled=self.mp_trainer_rec.use_amp):
# ! vit no amp
# with profile(activities=[
# ProfilerActivity.CUDA], record_shapes=True) as prof:
# with record_function("get_gs"):
latent = self.rec_model(
img=micro['img_to_encoder'].to(self.dtype),
behaviour='enc_dec_wo_triplane',
c=micro['c'],
pcd=micro['fps_pcd'], # send in pcd for surface reference.
) # send in input-view C since pixel-aligned gaussians required
# print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
gaussians, query_pcd_xyz = latent['gaussians'], latent['query_pcd_xyz']
# query_pcd_xyz = latent['query_pcd_xyz']
# if self.loss_class.opt.rand_aug_bg and random.random()>0.9:
if self.loss_class.opt.rand_aug_bg:
bg_color=torch.randint(0,255,(3,), device=dist_util.dev()) / 255.0
else:
bg_color=torch.tensor([1,1,1], dtype=torch.float32, device=dist_util.dev())
def visualize_latent_activations(latent, b_idx=0, write=False, ):
def normalize_latent_plane(latent_plane):
avg_p1 = latent_plane.detach().cpu().numpy().mean(0, keepdims=0)
avg_p1 = (avg_p1 - avg_p1.min()) / (avg_p1.max() - avg_p1.min())
# return avg_p1
return ((avg_p1).clip(0,1)*255.0).astype(np.uint8)
p1, p2, p3 = (normalize_latent_plane(latent_plane) for latent_plane in (latent[b_idx, 0:4], latent[b_idx,4:8], latent[b_idx,8:12]))
if write:
plt.imsave(os.path.join(logger.get_dir(), f'{self.step}_{b_idx}_1.jpg'), p1)
plt.imsave(os.path.join(logger.get_dir(), f'{self.step}_{b_idx}_2.jpg'), p2)
plt.imsave(os.path.join(logger.get_dir(), f'{self.step}_{b_idx}_3.jpg'), p3)
# imageio.imwrite(os.path.join(logger.get_dir(), f'{b_idx}_1.jpg'), p1)
# imageio.imwrite(os.path.join(logger.get_dir(), f'{b_idx}_2.jpg'), p2)
# imageio.imwrite(os.path.join(logger.get_dir(), f'{b_idx}_3.jpg'), p3)
return p1, p2, p3
# with profile(activities=[ProfilerActivity.CUDA, ProfilerActivity.CPU,], record_shapes=True) as prof:
# # ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof:
# with record_function("rendering"):
pred_nv_cano = self.rec_model(
latent=latent,
# latent={
# 'gaussians': latent['gaussians'].repeat_interleave(2,0)
# },
c=nv_c,
behaviour='triplane_dec',
bg_color=bg_color,
)
fine_scale_key = list(pred_nv_cano.keys())[-1]
# st()
fine_gaussians = latent[fine_scale_key]
fine_gaussians_opa = fine_gaussians[..., 3:4]
# print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))
# st() # torchvision.utils.save_image(pred_nv_cano['image_raw'][0], 'pred.jpg', normalize=True, value_range=(-1,1))
if self.loss_class.opt.rand_aug_bg:
# bg_color
alpha_mask = target['depth_mask'].float().unsqueeze(1) # B 1 H W
target['img'] = target['img'] * alpha_mask + (bg_color.reshape(1,3,1,1) * 2 - 1) * (1-alpha_mask)
target['depth_mask'] = target['depth_mask'].unsqueeze(1)
target['depth'] = target['depth'].unsqueeze(1)
multiscale_target = defaultdict(dict)
multiscale_pred = defaultdict(dict)
for idx, (gaussian_wavelet_key, gaussian_wavelet) in enumerate(pred_nv_cano.items()):
gs_output_size = pred_nv_cano[gaussian_wavelet_key]['image_raw'].shape[-1]
for k in gaussian_wavelet.keys():
pred_nv_cano[gaussian_wavelet_key][k] = rearrange(
gaussian_wavelet[k], 'B V ... -> (B V) ...') # match GT shape order
# if idx == 0: # only KL calculation in scale 0
if gaussian_wavelet_key == fine_scale_key:
pred_nv_cano[gaussian_wavelet_key].update(
{
k: latent[k] for k in ['posterior']
}
) # ! for KL supervision
# ! prepare target according to the wavelet size
for k in target.keys():
if target[key].shape[-1] == gs_output_size:
multiscale_target[gaussian_wavelet_key][k] = target[k]
else:
if k in ('depth', 'normal'):
mode = 'nearest'
else:
mode='bilinear'
multiscale_target[gaussian_wavelet_key][k] = F.interpolate(target[k], size=(gs_output_size, gs_output_size), mode=mode)
# st()
# st()
# torchvision.utils.save_image(target['img'], 'gt.jpg', normalize=True, value_range=(-1,1))
# torchvision.utils.save_image(pred_nv_cano['image_raw'], 'pred.jpg', normalize=True, value_range=(-1,1))
# torchvision.utils.save_image(micro['img'], 'inp_gt.jpg', normalize=True, value_range=(-1,1))
# torchvision.utils.save_image(micro['nv_img'], 'nv_gt.jpg', normalize=True, value_range=(-1,1))
# if self.loss_class.opt.rand_aug_bg:
# # bg_color
# alpha_mask = target['depth_mask'].float().unsqueeze(1) # B 1 H W
# target['img'] = target['img'] * alpha_mask + (bg_color.reshape(1,3,1,1) * 2 - 1) * (1-alpha_mask)
lod_num = len(pred_nv_cano.keys())
random_scale_for_lpips = random.choice(list(pred_nv_cano.keys()))
with self.rec_model.no_sync(): # type: ignore
# with profile(activities=[
# ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof:
# with record_function("loss"):
loss = th.tensor(0., device=dist_util.dev())
loss_dict = {}
if behaviour == 'd_step':
loss_scale, loss_dict_scale, _ = self.loss_class(
pred_nv_cano[random_scale_for_lpips],
multiscale_target[random_scale_for_lpips], # prepare merged data
step=self.step + self.resume_step,
test_mode=False,
return_fg_mask=True,
behaviour=behaviour,
conf_sigma_l1=None,
conf_sigma_percl=None,
ignore_kl=True, # only calculate once
ignore_lpips=True, # lpips on each lod
ignore_d_loss=False)
loss = loss + loss_scale
loss_dict.update(
{
f"{gaussian_wavelet_key.replace('gaussians_', '')}/{loss_key}": loss_v for loss_key, loss_v in loss_dict_scale.items()
}
)
else:
for scale_idx, gaussian_wavelet_key in enumerate(pred_nv_cano.keys()): # ! multi-scale gs rendering supervision
loss_scale, loss_dict_scale, _ = self.loss_class(
pred_nv_cano[gaussian_wavelet_key],
multiscale_target[gaussian_wavelet_key], # prepare merged data
step=self.step + self.resume_step,
test_mode=False,
return_fg_mask=True,
behaviour=behaviour,
conf_sigma_l1=None,
conf_sigma_percl=None,
ignore_kl=gaussian_wavelet_key!=fine_scale_key, # only calculate once
ignore_lpips=gaussian_wavelet_key!=random_scale_for_lpips, # lpips on each lod
ignore_d_loss=gaussian_wavelet_key!=fine_scale_key)
loss = loss + loss_scale
loss_dict.update(
{
f"{gaussian_wavelet_key.replace('gaussians_', '')}/{loss_key}": loss_v for loss_key, loss_v in loss_dict_scale.items()
}
)
pos = latent['pos']
opacity = gaussians[..., 3:4]
scaling = gaussians[..., 4:6] # 2dgs here
if self.step % self.log_interval == 0 and dist_util.get_rank(
) == 0:
with th.no_grad(): # save idx 0 here
try:
self.writer.add_histogram("scene/opacity_hist",
opacity[0][:],
self.step + self.resume_step)
self.writer.add_histogram("scene/scale_hist",
scaling[0][:],
self.step + self.resume_step)
except Exception as e:
logger.log(e)
if behaviour == 'g_step':
# ! 2dgs loss
# debugging now, open it from the beginning
# if (self.step + self.resume_step) >= 2000 and self.loss_class.opt.lambda_normal > 0:
surf_normal = multiscale_target[fine_scale_key]['normal'] * multiscale_target[fine_scale_key]['depth_mask'] # foreground supervision only.
# ! hard-coded
# rend_normal = pred_nv_cano['rend_normal'] # ! supervise disk normal with GT normal here instead;
# st()
rend_normal = pred_nv_cano[fine_scale_key]['rend_normal']
rend_dist = pred_nv_cano[fine_scale_key]['dist']
if self.loss_class.opt.lambda_scale_reg > 0:
scale_reg = (scaling-self.scale_expected_threshold).square().mean() * self.loss_class.opt.lambda_scale_reg
loss = loss + scale_reg
loss_dict.update({'loss_scale_reg': scale_reg})
if self.loss_class.opt.lambda_opa_reg > 0:
# small_base_opa = latent['gaussians_base_opa']
opa_reg = (-self.loss_class.beta_mvp_base_dist.log_prob(latent['gaussians_base_opa'].clamp(min=1/255, max=0.99)).mean()) * self.loss_class.opt.lambda_opa_reg
# ! also on the fine stage
opa_reg_fine = (-self.loss_class.beta_mvp_base_dist.log_prob(fine_gaussians_opa.clamp(min=1/255, max=0.99)).mean()) * self.loss_class.opt.lambda_opa_reg
# opa_reg = (1-latent['gaussians_base_opa'].mean() ) * self.loss_class.opt.lambda_opa_reg
loss = loss + opa_reg + opa_reg_fine
loss_dict.update({'loss_opa_reg': opa_reg, 'loss_opa_reg_fine': opa_reg_fine})
if (self.step + self.resume_step) >= 35000 and self.loss_class.opt.lambda_normal > 0:
# if (self.step + self.resume_step) >= 2000 and self.loss_class.opt.lambda_normal > 0:
# surf_normal = unity2blender_th(surf_normal) # ! g-buffer normal system is different
normal_error = (1 - (rend_normal * surf_normal).sum(dim=1)) # B H W
# normal_loss = self.loss_class.opt.lambda_normal * (normal_error.sum() / target['depth_mask'].sum()) # average with fg area ratio
normal_loss = self.loss_class.opt.lambda_normal * normal_error.mean()
loss = loss + normal_loss
loss_dict.update({'loss_normal': normal_loss})
# if (self.step + self.resume_step) >= 1500 and self.loss_class.opt.lambda_dist > 0:
if (self.step + self.resume_step) >= 15000 and self.loss_class.opt.lambda_dist > 0:
# if (self.step + self.resume_step) >= 300 and self.loss_class.opt.lambda_dist > 0:
dist_loss = self.loss_class.opt.lambda_dist * (rend_dist).mean()
loss = loss + dist_loss
loss_dict.update({'loss_dist': dist_loss})
if self.loss_class.opt.pruning_ot_lambda > 0:
# for now, save and analyze first
# selected_pts_mask_scaling = th.where(th.max(scaling, dim=-1).values < 0.01 * 0.9, True, False)
selected_pts_mask_scaling = th.where(
th.max(scaling, dim=-1).values > 0.05 * 0.9, True,
False)
# selected_pts_mask_opacity = th.where(opacity[..., 0] < 0.1, True, False) # B N
selected_pts_mask_opacity = th.where(
opacity[..., 0] < 0.01, True,
False) # 0.005 in the original 3dgs setting
selected_scaling_pts = pos[0][selected_pts_mask_scaling[0]]
selected_opacity_pts = pos[0][selected_pts_mask_opacity[0]]
pcu.save_mesh_v(
f'tmp/voxel/cd/10/scaling_masked_pts_0.05.ply',
selected_scaling_pts.detach().cpu().numpy(),
)
pcu.save_mesh_v(
f'tmp/voxel/cd/10/opacity_masked_pts_0.01.ply',
selected_opacity_pts.detach().cpu().numpy(),
)
# st()
# pass
if self.loss_class.opt.cd_lambda > 0:
# fuse depth to 3D point cloud to supervise the gaussians
B = latent['pos'].shape[0]
# c = micro['c']
# H = micro['depth'].shape[-1]
# V = 4
# # ! prepare 3D xyz ground truth
# cam2world_matrix = c['orig_c2w'][:, :, :16].reshape(
# B * V, 4, 4)
# intrinsics = c['orig_pose'][:, :,
# 16:25].reshape(B * V, 3, 3)
# # ! already in the world space after ray_sampler()
# ray_origins, ray_directions = self.ray_sampler( # shape:
# cam2world_matrix, intrinsics, H // 2)[:2]
# # depth = th.nn.functional.interpolate(micro['depth'].unsqueeze(1), (128,128), mode='nearest')[:, 0] # since each view has 128x128 Gaussians
# # depth = th.nn.functional.interpolate(micro['depth'].unsqueeze(1), (128,128), mode='nearest')[:, 0] # since each view has 128x128 Gaussians
# depth_128 = th.nn.functional.interpolate(
# micro['depth'].unsqueeze(1), (128, 128),
# mode='nearest'
# )[:, 0] # since each view has 128x128 Gaussians
# depth = depth_128.reshape(B * V, -1).unsqueeze(-1)
# # depth = micro['depth'].reshape(B*V, -1).unsqueeze(-1)
# gt_pos = ray_origins + depth * ray_directions # BV HW 3, already in the world space
# gt_pos = rearrange(gt_pos,
# '(B V) N C -> B (V N) C',
# B=B,
# V=V)
# gt_pos = gt_pos.clip(-0.45, 0.45)
# TODO
gt_pos = micro[
'fps_pcd'] # all the same, will update later.
# ! use online here
# gt_pos = query_pcd_xyz
cd_loss = pytorch3d.loss.chamfer_distance(
gt_pos, latent['pos']
)[0] * self.loss_class.opt.cd_lambda # V=4 GT for now. Test with V=8 GT later.
# st()
# for vis
if False:
torchvision.utils.save_image(micro['img'],
'gt.jpg',
value_range=(-1, 1),
normalize=True)
with th.no_grad():
for b in range(B):
pcu.save_mesh_v(
f'tmp/voxel/cd/10/again_pred-{b}.ply',
latent['pos'][b].detach().cpu().numpy(),
)
# pcu.save_mesh_v(
# f'tmp/voxel/cd/10/again-gt-{b}.ply',
# gt_pos[b].detach().cpu().numpy(),
# )
# st()
loss = loss + cd_loss
loss_dict.update({'loss_cd': cd_loss})
elif self.loss_class.opt.xyz_lambda > 0:
'''
B = latent['per_view_pos'].shape[0] // 4
V = 4
c = micro['c']
H = micro['depth'].shape[-1]
# ! prepare 3D xyz ground truth
cam2world_matrix = c['orig_c2w'][:, :, :16].reshape(
B * V, 4, 4)
intrinsics = c['orig_pose'][:, :,
16:25].reshape(B * V, 3, 3)
# ! already in the world space after ray_sampler()
ray_origins, ray_directions = self.ray_sampler( # shape:
cam2world_matrix, intrinsics, H // 2)[:2]
# self.gs.output_size,)[:2]
# depth = rearrange(micro['depth'], '(B V) H W -> ')
depth_128 = th.nn.functional.interpolate(
micro['depth'].unsqueeze(1), (128, 128),
mode='nearest'
)[:, 0] # since each view has 128x128 Gaussians
depth = depth_128.reshape(B * V, -1).unsqueeze(-1)
fg_mask = th.nn.functional.interpolate(
micro['depth_mask'].unsqueeze(1).to(th.uint8),
(128, 128),
mode='nearest').squeeze(1) # B*V H W
fg_mask = fg_mask.reshape(B * V, -1).unsqueeze(-1)
gt_pos = ray_origins + depth * ray_directions # BV HW 3, already in the world space
# st()
gt_pos = fg_mask * gt_pos.clip(
-0.45, 0.45) # g-buffer objaverse range
pred = fg_mask * latent['per_view_pos']
# for vis
if True:
torchvision.utils.save_image(micro['img'],
'gt.jpg',
value_range=(-1, 1),
normalize=True)
with th.no_grad():
gt_pos_vis = rearrange(gt_pos,
'(B V) N C -> B V N C',
B=B,
V=V)
pred_pos_vis = rearrange(pred,
'(B V) N C -> B V N C',
B=B,
V=V)
# save
for b in range(B):
for v in range(V):
# pcu.save_mesh_v(f'tmp/dust3r/add3dsupp-pred-{b}-{v}.ply',
# pred_pos_vis[b][v].detach().cpu().numpy(),)
# pcu.save_mesh_v(f'tmp/dust3r/add3dsupp-gt-{b}-{v}.ply',
# gt_pos_vis[b][v].detach().cpu().numpy(),)
pcu.save_mesh_v(
f'tmp/lambda50/no3dsupp-pred-{b}-{v}.ply',
pred_pos_vis[b]
[v].detach().cpu().numpy(),
)
pcu.save_mesh_v(
f'tmp/lambda50/no3dsupp-gt-{b}-{v}.ply',
gt_pos_vis[b]
[v].detach().cpu().numpy(),
)
st()
xyz_loss = th.nn.functional.mse_loss(
gt_pos, pred
) * self.loss_class.opt.xyz_lambda # ! 15% nonzero points
loss = loss + xyz_loss
'''
# ! directly gs center supervision with l1 loss, follow LION VAE
# xyz_loss = th.nn.functional.l1_loss(
# query_pcd_xyz, pred
# ) * self.loss_class.opt.xyz_lambda # ! 15% nonzero points
xyz_loss = self.loss_class.criterion_xyz(query_pcd_xyz, latent['pos']) * self.loss_class.opt.xyz_lambda
loss = loss + xyz_loss
# only calculate foreground gt_pos here?
loss_dict.update({'loss_xyz': xyz_loss})
elif self.loss_class.opt.emd_lambda > 0:
# rand_pt_size = 4096 # K value. Input Error! The size of the point clouds should be a multiple of 1024.
pred = latent['pos']
rand_pt_size = min(2048, max(pred.shape[1], 1024)) # K value. Input Error! The size of the point clouds should be a multiple of 1024.
if micro['fps_pcd'].shape[0] == pred.shape[0]:
gt_point = micro['fps_pcd']
else: # overfit memory dataset
gt_point = micro[
'fps_pcd'][::
4] # consecutive 4 views are from the same ID
B, gt_point_N = gt_point.shape[:2]
# random sample pred points
# sampled_pred =
# rand_pt_idx = torch.randint(high=pred.shape[1]-gt_point_N, size=(B,))
# pcu.save_mesh_v( f'tmp/voxel/emd/gt-half.ply', gt_point[0, ::4].detach().cpu().numpy(),)
# for b in range(gt_point.shape[0]):
# pcu.save_mesh_v( f'{logger.get_dir()}/gt-{b}.ply', gt_point[b].detach().cpu().numpy(),)
# pcu.save_mesh_v( f'{logger.get_dir()}/pred-{b}.ply', pred[b].detach().cpu().numpy(),)
# pcu.save_mesh_v( f'0.ply', latent['pos'][0].detach().cpu().numpy())
# st()
if self.loss_class.opt.fps_sampling: # O(N*K). reduce K later.
if self.loss_class.opt.subset_fps_sampling:
rand_pt_size = 1024 # for faster calculation
# ! uniform sampling with randomness
# sampled_gt_pts_for_emd_loss = gt_point[:, random.randint(0,9)::9][:, :1024] # direct uniform downsample to the K size
# sampled_gt_pts_for_emd_loss = gt_point[:, random.randint(0,9)::4][:, :1024] # direct uniform downsample to the K size
rand_perm = torch.randperm(
gt_point.shape[1]
)[:rand_pt_size] # shuffle the xyz before downsample - fps sampling
sampled_gt_pts_for_emd_loss = gt_point[:, rand_perm]
# sampled_gt_pts_for_emd_loss = gt_point[:, ::4]
# sampled_pred_pts_for_emd_loss = pred[:, ::32]
# sampled_gt_pts_for_emd_loss = pytorch3d.ops.sample_farthest_points(
# gt_point[:, ::4], K=rand_pt_size)[0] # V4
if self.loss_class.opt.subset_half_fps_sampling:
# sampled_pred_pts_for_emd_loss = pytorch3d.ops.sample_farthest_points(
# pred, K=rand_pt_size)[0] # V5
rand_perm = torch.randperm(
pred.shape[1]
)[:4096] # shuffle the xyz before downsample - fps sampling
sampled_pred_pts_for_emd_loss = pytorch3d.ops.sample_farthest_points(
pred[:, rand_perm],
K=rand_pt_size)[0] # improve randomness
else:
sampled_pred_pts_for_emd_loss = pytorch3d.ops.sample_farthest_points(
pred[:, ::4], K=rand_pt_size)[0] # V5
# rand_perm = torch.randperm(pred.shape[1]) # shuffle the xyz before downsample - fps sampling
# sampled_pred_pts_for_emd_loss = pytorch3d.ops.sample_farthest_points(
# pred[:, rand_perm][:, ::4], K=rand_pt_size)[0] # rand perm before downsampling, V6
# sampled_pred_pts_for_emd_loss = pytorch3d.ops.sample_farthest_points(
# pred[:, rand_perm][:, ::8], K=rand_pt_size)[0] # rand perm before downsampling, V7
# sampled_pred_pts_for_emd_loss = pytorch3d.ops.sample_farthest_points(
# pred[:, self.step%2::4], K=rand_pt_size)[0] # rand perm before downsampling, V8, based on V50
else:
sampled_gt_pts_for_emd_loss = pytorch3d.ops.sample_farthest_points(
gt_point, K=rand_pt_size)[0]
# if self.loss_class.opt.subset_half_fps_sampling:
# rand_pt_size = 4096 # K value. Input Error! The size of the point clouds should be a multiple of 1024.
sampled_pred_pts_for_emd_loss = pytorch3d.ops.sample_farthest_points(
pred, K=rand_pt_size)[0]
# else:
# sampled_pred_pts_for_emd_loss = pytorch3d.ops.sample_farthest_points(
# pred, K=rand_pt_size)[0]
else: # random sampling
rand_pt_idx_pred = torch.randint(high=pred.shape[1] -
rand_pt_size,
size=(1, ))[0]
rand_pt_idx_gt = torch.randint(high=gt_point.shape[1] -
rand_pt_size,
size=(1, ))[0]
sampled_pred_pts_for_emd_loss = pred[:,
rand_pt_idx_pred:
rand_pt_idx_pred +
rand_pt_size, ...]
sampled_gt_pts_for_emd_loss = gt_point[:,
rand_pt_idx_gt:
rand_pt_idx_gt +
rand_pt_size,
...]
# only calculate foreground gt_pos here?
emd_loss = calc_emd(sampled_gt_pts_for_emd_loss,
sampled_pred_pts_for_emd_loss).mean(
) * self.loss_class.opt.emd_lambda
loss = loss + emd_loss
loss_dict.update({'loss_emd': emd_loss})
if self.loss_class.opt.commitment_loss_lambda > 0:
ellipsoid_vol = torch.prod(scaling, dim=-1, keepdim=True) / ((0.01 * 0.9)**3) # * (4/3*torch.pi). normalized vol
commitment = ellipsoid_vol * opacity
to_be_pruned_ellipsoid_idx = commitment < (3/4)**3 * 0.9 # those points shall have larger vol*opacity contribution
commitment_loss = -commitment[to_be_pruned_ellipsoid_idx].mean() * self.loss_class.opt.commitment_loss_lambda
loss = loss + commitment_loss
loss_dict.update({'loss_commitment': commitment_loss})
loss_dict.update({'loss_commitment_opacity': opacity.mean()})
loss_dict.update({'loss_commitment_vol': ellipsoid_vol.mean()})
log_rec3d_loss_dict(loss_dict)
# self.mp_trainer_rec.backward(loss)
if behaviour == 'g_step':
self.mp_trainer_rec.backward(loss)
else:
self.mp_trainer_disc.backward(loss)
# for name, p in self.rec_model.named_parameters():
# if p.grad is None:
# logger.log(f"found rec unused param: {name}")
# print(name, p.grad.mean(), p.grad.abs().max())
if dist_util.get_rank() == 0 and self.step % 500 == 0 and i == 0 and behaviour=='g_step':
# if dist_util.get_rank() == 0 and self.step % 1 == 0 and i == 0:
try:
torchvision.utils.save_image(
th.cat([target['img'][::1], pred_nv_cano[fine_scale_key]['image_raw'][::1]], ),
f'{logger.get_dir()}/{self.step+self.resume_step}.jpg',
normalize=True, value_range=(-1,1),nrow=len(indices)*2)
# save depth and normal and alpha
torchvision.utils.save_image(
th.cat([surf_normal[::1], rend_normal[::1]], ),
f'{logger.get_dir()}/{self.step+self.resume_step}_normal_new.jpg',
normalize=True, value_range=(-1,1), nrow=len(indices)*2)
torchvision.utils.save_image(
th.cat([target['depth'][::1], pred_nv_cano[fine_scale_key]['image_depth'][::1]], ),
f'{logger.get_dir()}/{self.step+self.resume_step}_depth.jpg',
normalize=True, nrow=len(indices)*2)
# torchvision.utils.save_image( pred_nv_cano['image_depth'][::1], f'{logger.get_dir()}/{self.step+self.resume_step}_depth.jpg', normalize=True, nrow=len(indices)*2)
torchvision.utils.save_image(
th.cat([target['depth_mask'][::1], pred_nv_cano[fine_scale_key]['image_mask'][::1]], ),
f'{logger.get_dir()}/{self.step+self.resume_step}_alpha.jpg',
normalize=True, value_range=(0,1), nrow=len(indices)*2)
logger.log(
'log vis to: ',
f'{logger.get_dir()}/{self.step+self.resume_step}.jpg')
except Exception as e:
logger.log('Exception when saving log: ', e)
# if self.step % 2500 == 0:
# th.cuda.empty_cache() # free vram
@torch.no_grad()
def export_mesh_from_2dgs(self, all_rgbs, all_depths, all_alphas, cam_pathes, latent_save_dir):
# https://github.com/autonomousvision/LaRa/blob/main/evaluation.py
n_thread = 1 # avoid TSDF cpu hanging bug.
os.environ["MKL_NUM_THREADS"] = f"{n_thread}"
os.environ["NUMEXPR_NUM_THREADS"] = f"{n_thread}"
os.environ["OMP_NUM_THREADS"] = f"4"
os.environ["VECLIB_MAXIMUM_THREADS"] = f"{n_thread}"
os.environ["OPENBLAS_NUM_THREADS"] = f"{n_thread}"
# copied from: https://github.com/hbb1/2d-gaussian-splatting/blob/19eb5f1e091a582e911b4282fe2832bac4c89f0f/render.py#L23
logger.log("exporting mesh ...")
# for g-objv
aabb = [-0.45,-0.45,-0.45,0.45,0.45,0.45]
self.aabb = np.array(aabb).reshape(2,3)*1.1
name = f'{latent_save_dir}/mesh_raw.obj'
mesh = self.extract_mesh_bounded(all_rgbs, all_depths, all_alphas, cam_pathes)
o3d.io.write_triangle_mesh(name, mesh)
logger.log("mesh saved at {}".format(name))
mesh_post = smooth_mesh(mesh)
o3d.io.write_triangle_mesh(name.replace('_raw.obj', '.obj'), mesh_post)
logger.log("mesh post processed saved at {}".format(name.replace('.obj', '_post.obj')))
def get_source_cw2wT(self, source_cameras_view_to_world):
return matrix_to_quaternion(
source_cameras_view_to_world[:3, :3].transpose(0, 1))
def c_to_3dgs_format(self, pose):
# TODO, switch to torch version (batched later)
c2w = pose[:16].reshape(4, 4) # 3x4
# ! load cam
w2c = np.linalg.inv(c2w)
R = np.transpose(
w2c[:3, :3]) # R is stored transposed due to 'glm' in CUDA code
T = w2c[:3, 3]
fx = pose[16]
FovX = focal2fov(fx, 1)
FovY = focal2fov(fx, 1)
tanfovx = math.tan(FovX * 0.5)
tanfovy = math.tan(FovY * 0.5)
assert tanfovx == tanfovy
trans = np.array([0.0, 0.0, 0.0])
scale = 1.0
world_view_transform = torch.tensor(getWorld2View2(R, T, trans,
scale)).transpose(
0, 1)
projection_matrix = getProjectionMatrix(znear=self.znear,
zfar=self.zfar,
fovX=FovX,
fovY=FovY).transpose(0, 1)
full_proj_transform = (world_view_transform.unsqueeze(0).bmm(
projection_matrix.unsqueeze(0))).squeeze(0)
camera_center = world_view_transform.inverse()[3, :3]
view_world_transform = torch.tensor(getView2World(R, T, trans,
scale)).transpose(
0, 1)
# item.update(viewpoint_cam=[viewpoint_cam])
c = {}
c["source_cv2wT_quat"] = self.get_source_cw2wT(view_world_transform)
c.update(
projection_matrix=projection_matrix, # K
cam_view=world_view_transform, # world_view_transform
cam_view_proj=full_proj_transform, # full_proj_transform
cam_pos=camera_center,
tanfov=tanfovx, # TODO, fix in the renderer
# orig_c2w=c2w,
# orig_w2c=w2c,
orig_pose=torch.from_numpy(pose),
orig_c2w=torch.from_numpy(c2w),
orig_w2c=torch.from_numpy(w2c),
# tanfovy=tanfovy,
)
return c # dict for gs rendering
@torch.no_grad()
def extract_mesh_bounded(self, rgbmaps, depthmaps, alpha_maps, cam_pathes, voxel_size=0.004, sdf_trunc=0.02, depth_trunc=3, alpha_thres=0.08, mask_backgrond=False):
"""
Perform TSDF fusion given a fixed depth range, used in the paper.
voxel_size: the voxel size of the volume
sdf_trunc: truncation value
depth_trunc: maximum depth range, should depended on the scene's scales
mask_backgrond: whether to mask backgroud, only works when the dataset have masks
return o3d.mesh
"""
if self.aabb is not None: # as in lara.
center = self.aabb.mean(0)
# radius = np.linalg.norm(self.aabb[1] - self.aabb[0]) * 0.5
radius = np.linalg.norm(self.aabb[1] - self.aabb[0]) * 0.5
# voxel_size = radius / 256
voxel_size = radius / 192 # less holes
# sdf_trunc = voxel_size * 16 # less holes, slower integration
sdf_trunc = voxel_size * 12 #
print("using aabb")
volume = o3d.pipelines.integration.ScalableTSDFVolume(
voxel_length= voxel_size,
sdf_trunc=sdf_trunc,
color_type=o3d.pipelines.integration.TSDFVolumeColorType.RGB8
)
print("Running tsdf volume integration ...")
print(f'voxel_size: {voxel_size}')
print(f'sdf_trunc: {sdf_trunc}')
# print(f'depth_truc: {depth_trunc}')
# render_reference = th.load('eval_pose.pt', map_location='cpu').numpy()
# ! use uni_mesh_path, from Lara, Chen et al, ECCV 24'
# '''
# for i, cam_o3d in tqdm(enumerate(to_cam_open3d(self.viewpoint_stack)), desc="TSDF integration progress"):
for i, cam in tqdm(enumerate(cam_pathes), desc="TSDF integration progress"):
# rgb = self.rgbmaps[i]
# depth = self.depthmaps[i]
cam = self.c_to_3dgs_format(cam)
cam_o3d = to_cam_open3d_compat(cam)
rgb = rgbmaps[i][0]
depth = depthmaps[i][0]
alpha = alpha_maps[i][0]
# if we have mask provided, use it
# if mask_backgrond and (self.viewpoint_stack[i].gt_alpha_mask is not None):
# depth[(self.viewpoint_stack[i].gt_alpha_mask < 0.5)] = 0
depth[(alpha < alpha_thres)] = 0
if self.aabb is not None:
campos = cam['cam_pos'].cpu().numpy()
depth_trunc = np.linalg.norm(campos - center, axis=-1) + radius
# make open3d rgbd
rgbd = o3d.geometry.RGBDImage.create_from_color_and_depth(
o3d.geometry.Image(np.asarray(np.clip(rgb.permute(1,2,0).cpu().numpy(), 0.0, 1.0) * 255, order="C", dtype=np.uint8)),
o3d.geometry.Image(np.asarray(depth.permute(1,2,0).cpu().numpy(), order="C")),
depth_trunc = depth_trunc,
convert_rgb_to_intensity=False,
depth_scale = 1.0
)
volume.integrate(rgbd, intrinsic=cam_o3d.intrinsic, extrinsic=cam_o3d.extrinsic)
mesh = volume.extract_triangle_mesh()
return mesh
@th.inference_mode()
def eval_novelview_loop(self, camera=None, save_latent=False):
# novel view synthesis given evaluation camera trajectory
if save_latent: # for diffusion learning
latent_dir = Path(f'{logger.get_dir()}/latent_dir')
latent_dir.mkdir(exist_ok=True, parents=True)
render_reference=uni_mesh_path(10)
for eval_idx, micro in enumerate(tqdm(self.eval_data)):
latent_save_dir = f'{logger.get_dir()}/latent_dir/{micro["ins"][0]}'
all_latent_file = sorted(Path(latent_save_dir).glob('*.npz') )
if len(all_latent_file) == 0:
save_prefix = 0
else:
save_prefix = int(all_latent_file[-1].stem[-1] ) + 1
Path(latent_save_dir).mkdir(parents=True, exist_ok=True)
with th.autocast(device_type='cuda',
dtype=self.dtype,
enabled=self.mp_trainer_rec.use_amp):
# st() # check whether more c info available
latent = self.rec_model(
img=micro['img_to_encoder'].to(self.dtype),
behaviour='enc_dec_wo_triplane',
c=micro['c'],
pcd=micro['fps_pcd'], # send in pcd for surface reference.
) # send in input-view C since pixel-aligned gaussians required
# fine_scale_key = list(pred.keys())[-1]
# fine_scale_key = 'gaussians_upsampled_2'
fine_scale_key = 'gaussians_upsampled_3'
export_mesh = True # for debug
if True:
# if eval_idx < 1500 and eval_idx % 3 == 0:
if eval_idx < 1500:
all_rgbs, all_depths, all_alphas=self.render_gs_video_given_latent(
latent,
self.rec_model, # compatible with join_model
name_prefix=f'{self.step + self.resume_step}_{micro["ins"][0].split("/")[0]}_{eval_idx}',
save_img=False,
render_reference=render_reference,
export_mesh=False)
if export_mesh:
self.export_mesh_from_2dgs(all_rgbs, all_depths, all_alphas, render_reference, latent_save_dir)
# ! B=2 here
np.savez_compressed(f'{latent_save_dir}/latent-{save_prefix}.npz',
latent_normalized=latent['latent_normalized'].cpu().numpy(),
query_pcd_xyz=latent['query_pcd_xyz'].cpu().numpy()
)
# st()
for scale in ['gaussians_upsampled', 'gaussians_base', 'gaussians_upsampled_2', 'gaussians_upsampled_3']:
np.save(f'{latent_save_dir}/{scale}.npy', latent[scale].cpu().numpy())
@th.inference_mode()
def render_gs_video_given_latent(self,
ddpm_latent,
rec_model,
name_prefix='0',
save_img=False,
render_reference=None,
export_mesh=False):
all_rgbs, all_depths, all_alphas = [], [], []
# batch_size, L, C = planes.shape
# ddpm_latent = { self.latent_name: planes[..., :-3] * self.triplane_scaling_divider, # kl-reg latent
# 'query_pcd_xyz': self.pcd_unnormalize_fn(planes[..., -3:]) }
# ddpm_latent.update(rec_model(latent=ddpm_latent, behaviour='decode_gs_after_vae_no_render'))
# assert render_reference is None
# render_reference = self.eval_data # compat
# else: # use train_traj
# for key in ['ins', 'bbox', 'caption']:
# if key in render_reference:
# render_reference.pop(key)
# render_reference = [ { k:v[idx:idx+1] for k, v in render_reference.items() } for idx in range(40) ]
video_out = imageio.get_writer(
f'{logger.get_dir()}/gs_{name_prefix}.mp4',
mode='I',
fps=15,
codec='libx264')
# for i, batch in enumerate(tqdm(self.eval_data)):
for i, micro_c in enumerate(tqdm(render_reference)):
# micro = {
# k: v.to(dist_util.dev()) if isinstance(v, th.Tensor) else v
# for k, v in batch.items()
# }
c = self.post_process.c_to_3dgs_format(micro_c)
for k in c.keys(): # to cuda
if isinstance(c[k], th.Tensor) and k != 'tanfov':
c[k] = c[k].unsqueeze(0).unsqueeze(0).to(dist_util.dev()) # actually, could render 40 views together.
c['tanfov'] = th.tensor(c['tanfov']).to(dist_util.dev())
pred = rec_model(
img=None,
c=c, # TODO, to dict
latent=ddpm_latent, # render gs
behaviour='triplane_dec',
bg_color=self.gs_bg_color,
render_all_scale=True,
)
fine_scale_key = list(pred.keys())[-1]
all_rgbs.append(einops.rearrange(pred[fine_scale_key]['image'], 'B V ... -> (B V) ...'))
all_depths.append(einops.rearrange(pred[fine_scale_key]['depth'], 'B V ... -> (B V) ...'))
all_alphas.append(einops.rearrange(pred[fine_scale_key]['alpha'], 'B V ... -> (B V) ...'))
# st()
# fine_scale_key = list(pred.keys())[-1]
all_pred_vis = {}
for key in pred.keys():
pred_scale = pred[key] # only show finest result here
for k in pred_scale.keys():
pred_scale[k] = einops.rearrange(pred_scale[k], 'B V ... -> (B V) ...') # merge
pred_vis = self._make_vis_img(pred_scale)
vis = pred_vis.permute(0, 2, 3, 1).cpu().numpy()
vis = vis * 127.5 + 127.5
vis = vis.clip(0, 255).astype(np.uint8)
all_pred_vis[key] = vis
# all_pred_vis_concat = np.concatenate([cv2.resize(all_pred_vis[k][0], (384*3, 384)) for k in ['gaussians_base', 'gaussians_upsampled', 'gaussians_upsampled_2']], axis=0)
# all_pred_vis_concat = np.concatenate([cv2.resize(all_pred_vis[k][0], (256*3, 256)) for k in ['gaussians_base', 'gaussians_upsampled',]], axis=0)
# all_pred_vis_concat = np.concatenate([cv2.resize(all_pred_vis[k][0], (384*3, 384)) for k in all_pred_vis.keys()], axis=0)
all_pred_vis_concat = np.concatenate([cv2.resize(all_pred_vis[k][0], (512*3, 512)) for k in all_pred_vis.keys()], axis=0)
# for j in range(vis.shape[0]):
video_out.append_data(all_pred_vis_concat)
video_out.close()
print('logged video to: ',
f'{logger.get_dir()}/triplane_{name_prefix}.mp4')
del video_out, pred, pred_vis, vis
return all_rgbs, all_depths, all_alphas
@th.no_grad()
def _make_vis_img(self, pred):
# if True:
pred_depth = pred['image_depth']
pred_depth = (pred_depth - pred_depth.min()) / (pred_depth.max() -
pred_depth.min())
pred_depth = pred_depth.cpu()[0].permute(1, 2, 0).numpy()
pred_depth = (plt.cm.viridis(pred_depth[..., 0])[..., :3]) * 2 - 1
pred_depth = th.from_numpy(pred_depth).to(
pred['image_raw'].device).permute(2, 0, 1).unsqueeze(0)
gen_img = pred['image_raw']
rend_normal = pred['rend_normal']
pred_vis = th.cat(
[
gen_img,
rend_normal,
pred_depth,
],
dim=-1) # B, 3, H, W
return pred_vis
class TrainLoop3DRecNVPatchSingleForwardMV_NoCrop_adv(TrainLoop3DRecNVPatchSingleForwardMV_NoCrop):
def __init__(self, *, rec_model, loss_class, data, eval_data, batch_size, microbatch, lr, ema_rate, log_interval, eval_interval, save_interval, resume_checkpoint, use_fp16=False, fp16_scale_growth=0.001, weight_decay=0, lr_anneal_steps=0, iterations=10001, load_submodule_name='', ignore_resume_opt=False, model_name='rec', use_amp=False, num_frames=4, **kwargs):
super().__init__(rec_model=rec_model, loss_class=loss_class, data=data, eval_data=eval_data, batch_size=batch_size, microbatch=microbatch, lr=lr, ema_rate=ema_rate, log_interval=log_interval, eval_interval=eval_interval, save_interval=save_interval, resume_checkpoint=resume_checkpoint, use_fp16=use_fp16, fp16_scale_growth=fp16_scale_growth, weight_decay=weight_decay, lr_anneal_steps=lr_anneal_steps, iterations=iterations, load_submodule_name=load_submodule_name, ignore_resume_opt=ignore_resume_opt, model_name=model_name, use_amp=use_amp, num_frames=num_frames, **kwargs)
# create discriminator
# ! copied from ln3diff tri-plane version
disc_params = self.loss_class.get_trainable_parameters()
self.mp_trainer_disc = MixedPrecisionTrainer(
model=self.loss_class.discriminator,
use_fp16=self.use_fp16,
fp16_scale_growth=fp16_scale_growth,
model_name='disc',
use_amp=use_amp,
model_params=disc_params)
# st() # check self.lr
self.opt_disc = AdamW(
self.mp_trainer_disc.master_params,
lr=self.lr, # follow sd code base
betas=(0, 0.999),
eps=1e-8)
# TODO, is loss cls already in the DDP?
if self.use_ddp:
self.ddp_disc = DDP(
self.loss_class.discriminator,
device_ids=[dist_util.dev()],
output_device=dist_util.dev(),
broadcast_buffers=False,
bucket_cap_mb=128,
find_unused_parameters=False,
)
else:
self.ddp_disc = self.loss_class.discriminator
def save(self, mp_trainer=None, model_name='rec'):
if mp_trainer is None:
mp_trainer = self.mp_trainer_rec
def save_checkpoint(rate, params):
state_dict = mp_trainer.master_params_to_state_dict(params)
if dist_util.get_rank() == 0:
logger.log(f"saving model {model_name} {rate}...")
if not rate:
filename = f"model_{model_name}{(self.step+self.resume_step):07d}.pt"
else:
filename = f"ema_{model_name}_{rate}_{(self.step+self.resume_step):07d}.pt"
with bf.BlobFile(bf.join(get_blob_logdir(), filename),
"wb") as f:
th.save(state_dict, f)
save_checkpoint(0, mp_trainer.master_params)
dist.barrier()
def run_step(self, batch, step='g_step'):
# self.forward_backward(batch)
if step == 'g_step':
self.forward_backward(batch, behaviour='g_step')
took_step_g_rec = self.mp_trainer_rec.optimize(self.opt)
if took_step_g_rec:
self._update_ema() # g_ema
elif step == 'd_step':
self.forward_backward(batch, behaviour='d_step')
_ = self.mp_trainer_disc.optimize(self.opt_disc)
self._anneal_lr()
self.log_step()
def run_loop(self, batch=None):
while (not self.lr_anneal_steps
or self.step + self.resume_step < self.lr_anneal_steps):
batch = next(self.data)
self.run_step(batch, 'g_step')
batch = next(self.data)
self.run_step(batch, 'd_step')
if self.step % 1000 == 0:
dist_util.synchronize()
if self.step % 5000 == 0:
th.cuda.empty_cache() # avoid memory leak
if self.step % self.log_interval == 0 and dist_util.get_rank(
) == 0:
out = logger.dumpkvs()
# * log to tensorboard
for k, v in out.items():
self.writer.add_scalar(f'Loss/{k}', v,
self.step + self.resume_step)
if self.step % self.eval_interval == 0 and self.step != 0:
if dist_util.get_rank() == 0:
try:
self.eval_loop()
except Exception as e:
logger.log(e)
dist_util.synchronize()
# if self.step % self.save_interval == 0 and self.step != 0:
if self.step % self.save_interval == 0:
self.save()
self.save(self.mp_trainer_disc,
self.mp_trainer_disc.model_name)
dist_util.synchronize()
# Run for a finite amount of time in integration tests.
if os.environ.get("DIFFUSION_TRAINING_TEST",
"") and self.step > 0:
return
self.step += 1
if self.step > self.iterations:
logger.log('reached maximum iterations, exiting')
# Save the last checkpoint if it wasn't already saved.
if (self.step -
1) % self.save_interval != 0 and self.step != 1:
self.save()
exit()
# Save the last checkpoint if it wasn't already saved.
# if (self.step - 1) % self.save_interval != 0 and self.step != 1:
if (self.step - 1) % self.save_interval != 0:
try:
self.save() # save rec
self.save(self.mp_trainer_disc, self.mp_trainer_disc.model_name)
except Exception as e:
logger.log(e)
# ! load disc
def _load_and_sync_parameters(self, submodule_name=''):
super()._load_and_sync_parameters(submodule_name)
# load disc
resume_checkpoint = self.resume_checkpoint.replace(
'rec', 'disc') # * default behaviour
if os.path.exists(resume_checkpoint):
if dist_util.get_rank() == 0:
logger.log(
f"loading disc model from checkpoint: {resume_checkpoint}..."
)
map_location = {
'cuda:%d' % 0: 'cuda:%d' % dist_util.get_rank()
} # configure map_location properly
resume_state_dict = dist_util.load_state_dict(
resume_checkpoint, map_location=map_location)
model_state_dict = self.loss_class.discriminator.state_dict()
for k, v in resume_state_dict.items():
if k in model_state_dict.keys():
if v.size() == model_state_dict[k].size():
model_state_dict[k] = v
# model_state_dict[k].copy_(v)
else:
logger.log('!!!! partially load: ', k, ": ",
v.size(), "state_dict: ",
model_state_dict[k].size())
if dist_util.get_world_size() > 1:
# dist_util.sync_params(self.model.named_parameters())
dist_util.sync_params(
self.loss_class.get_trainable_parameters())
logger.log('synced disc params')