import copy # import imageio.v3 import functools import json import os from pathlib import Path from pdb import set_trace as st from einops import rearrange import webdataset as wds import traceback import blobfile as bf import imageio import numpy as np # from sympy import O import torch as th import torch.distributed as dist import torchvision from PIL import Image from torch.nn.parallel.distributed import DistributedDataParallel as DDP from torch.optim import AdamW from torch.utils.tensorboard import SummaryWriter from tqdm import tqdm from guided_diffusion import dist_util, logger from guided_diffusion.fp16_util import MixedPrecisionTrainer from guided_diffusion.nn import update_ema from guided_diffusion.resample import LossAwareSampler, UniformSampler from guided_diffusion.train_util import (calc_average_loss, find_ema_checkpoint, find_resume_checkpoint, get_blob_logdir, log_rec3d_loss_dict, parse_resume_step_from_filename) from .camera_utils import LookAtPoseSampler, FOV_to_intrinsics from .train_util import TrainLoop3DRec class TrainLoop3DRecNV(TrainLoop3DRec): # supervise the training of novel view def __init__(self, *, rec_model, loss_class, data, eval_data, batch_size, microbatch, lr, ema_rate, log_interval, eval_interval, save_interval, resume_checkpoint, use_fp16=False, fp16_scale_growth=0.001, weight_decay=0, lr_anneal_steps=0, iterations=10001, load_submodule_name='', ignore_resume_opt=False, model_name='rec', use_amp=False, **kwargs): super().__init__(rec_model=rec_model, loss_class=loss_class, data=data, eval_data=eval_data, batch_size=batch_size, microbatch=microbatch, lr=lr, ema_rate=ema_rate, log_interval=log_interval, eval_interval=eval_interval, save_interval=save_interval, resume_checkpoint=resume_checkpoint, use_fp16=use_fp16, fp16_scale_growth=fp16_scale_growth, weight_decay=weight_decay, lr_anneal_steps=lr_anneal_steps, iterations=iterations, load_submodule_name=load_submodule_name, ignore_resume_opt=ignore_resume_opt, model_name=model_name, use_amp=use_amp, **kwargs) self.rec_cano = True def forward_backward(self, batch, *args, **kwargs): # return super().forward_backward(batch, *args, **kwargs) self.mp_trainer_rec.zero_grad() batch_size = batch['img_to_encoder'].shape[0] for i in range(0, batch_size, self.microbatch): # st() micro = { k: v[i:i + self.microbatch].to(dist_util.dev()) for k, v in batch.items() } # ! concat novel-view? next version. also add self reconstruction, patch-based loss in the next version. verify novel-view prediction first. # wrap forward within amp with th.autocast(device_type='cuda', dtype=th.float16, enabled=self.mp_trainer_rec.use_amp): target_nvs = {} target_cano = {} latent = self.rec_model(img=micro['img_to_encoder'], behaviour='enc_dec_wo_triplane') pred = self.rec_model( latent=latent, c=micro['nv_c'], # predict novel view here behaviour='triplane_dec') for k, v in micro.items(): if k[:2] == 'nv': orig_key = k.replace('nv_', '') target_nvs[orig_key] = v target_cano[orig_key] = micro[orig_key] with self.rec_model.no_sync(): # type: ignore loss, loss_dict, fg_mask = self.loss_class( pred, target_nvs, step=self.step + self.resume_step, test_mode=False, return_fg_mask=True, conf_sigma_l1=None, conf_sigma_percl=None) log_rec3d_loss_dict(loss_dict) if self.rec_cano: pred_cano = self.rec_model(latent=latent, c=micro['c'], behaviour='triplane_dec') with self.rec_model.no_sync(): # type: ignore fg_mask = target_cano['depth_mask'].unsqueeze( 1).repeat_interleave(3, 1).float() loss_cano, loss_cano_dict = self.loss_class.calc_2d_rec_loss( pred_cano['image_raw'], target_cano['img'], fg_mask, step=self.step + self.resume_step, test_mode=False, ) loss = loss + loss_cano # remove redundant log log_rec3d_loss_dict({ f'cano_{k}': v for k, v in loss_cano_dict.items() # if "loss" in k }) self.mp_trainer_rec.backward(loss) if dist_util.get_rank() == 0 and self.step % 500 == 0: if self.rec_cano: self.log_img(micro, pred, pred_cano) else: self.log_img(micro, pred, None) @th.inference_mode() def log_img(self, micro, pred, pred_cano): # gt_vis = th.cat([batch['img'], batch['depth']], dim=-1) def norm_depth(pred_depth): # to [-1,1] # pred_depth = pred['image_depth'] pred_depth = (pred_depth - pred_depth.min()) / (pred_depth.max() - pred_depth.min()) return -(pred_depth * 2 - 1) pred_img = pred['image_raw'] gt_img = micro['img'] # infer novel view also # if self.loss_class.opt.symmetry_loss: # pred_nv_img = nvs_pred # else: # ! replace with novel view prediction # ! log another novel-view prediction # pred_nv_img = self.rec_model( # img=micro['img_to_encoder'], # c=self.novel_view_poses) # pred: (B, 3, 64, 64) # if 'depth' in micro: gt_depth = micro['depth'] if gt_depth.ndim == 3: gt_depth = gt_depth.unsqueeze(1) gt_depth = norm_depth(gt_depth) # gt_depth = (gt_depth - gt_depth.min()) / (gt_depth.max() - # gt_depth.min()) # if True: fg_mask = pred['image_mask'] * 2 - 1 # 0-1 input_fg_mask = pred_cano['image_mask'] * 2 - 1 # 0-1 if 'image_depth' in pred: pred_depth = norm_depth(pred['image_depth']) pred_nv_depth = norm_depth(pred_cano['image_depth']) else: pred_depth = th.zeros_like(gt_depth) pred_nv_depth = th.zeros_like(gt_depth) if 'image_sr' in pred: if pred['image_sr'].shape[-1] == 512: pred_img = th.cat([self.pool_512(pred_img), pred['image_sr']], dim=-1) gt_img = th.cat([self.pool_512(micro['img']), micro['img_sr']], dim=-1) pred_depth = self.pool_512(pred_depth) gt_depth = self.pool_512(gt_depth) elif pred['image_sr'].shape[-1] == 256: pred_img = th.cat([self.pool_256(pred_img), pred['image_sr']], dim=-1) gt_img = th.cat([self.pool_256(micro['img']), micro['img_sr']], dim=-1) pred_depth = self.pool_256(pred_depth) gt_depth = self.pool_256(gt_depth) else: pred_img = th.cat([self.pool_128(pred_img), pred['image_sr']], dim=-1) gt_img = th.cat([self.pool_128(micro['img']), micro['img_sr']], dim=-1) gt_depth = self.pool_128(gt_depth) pred_depth = self.pool_128(pred_depth) else: gt_img = self.pool_64(gt_img) gt_depth = self.pool_64(gt_depth) pred_vis = th.cat([ pred_img, pred_depth.repeat_interleave(3, dim=1), fg_mask.repeat_interleave(3, dim=1), ], dim=-1) # B, 3, H, W pred_vis_nv = th.cat([ pred_cano['image_raw'], pred_nv_depth.repeat_interleave(3, dim=1), input_fg_mask.repeat_interleave(3, dim=1), ], dim=-1) # B, 3, H, W pred_vis = th.cat([pred_vis, pred_vis_nv], dim=-2) # cat in H dim gt_vis = th.cat([ gt_img, gt_depth.repeat_interleave(3, dim=1), th.zeros_like(gt_img) ], dim=-1) # TODO, fail to load depth. range [0, 1] if 'conf_sigma' in pred: gt_vis = th.cat([gt_vis, fg_mask], dim=-1) # placeholder # vis = th.cat([gt_vis, pred_vis], dim=-2)[0].permute( vis = th.cat([gt_vis, pred_vis], dim=-2) # .permute( # 0, 2, 3, 1).cpu() vis_tensor = torchvision.utils.make_grid(vis, nrow=vis.shape[-1] // 64) # HWC torchvision.utils.save_image( vis_tensor, f'{logger.get_dir()}/{self.step+self.resume_step}.jpg', value_range=(-1, 1), normalize=True) # vis = vis.numpy() * 127.5 + 127.5 # vis = vis.clip(0, 255).astype(np.uint8) # Image.fromarray(vis).save( # f'{logger.get_dir()}/{self.step+self.resume_step}.jpg') logger.log('log vis to: ', f'{logger.get_dir()}/{self.step+self.resume_step}.jpg') # self.writer.add_image(f'images', # vis, # self.step + self.resume_step, # dataformats='HWC') # return pred class TrainLoop3DRecNVPatch(TrainLoop3DRecNV): # add patch rendering def __init__(self, *, rec_model, loss_class, data, eval_data, batch_size, microbatch, lr, ema_rate, log_interval, eval_interval, save_interval, resume_checkpoint, use_fp16=False, fp16_scale_growth=0.001, weight_decay=0, lr_anneal_steps=0, iterations=10001, load_submodule_name='', ignore_resume_opt=False, model_name='rec', use_amp=False, **kwargs): super().__init__(rec_model=rec_model, loss_class=loss_class, data=data, eval_data=eval_data, batch_size=batch_size, microbatch=microbatch, lr=lr, ema_rate=ema_rate, log_interval=log_interval, eval_interval=eval_interval, save_interval=save_interval, resume_checkpoint=resume_checkpoint, use_fp16=use_fp16, fp16_scale_growth=fp16_scale_growth, weight_decay=weight_decay, lr_anneal_steps=lr_anneal_steps, iterations=iterations, load_submodule_name=load_submodule_name, ignore_resume_opt=ignore_resume_opt, model_name=model_name, use_amp=use_amp, **kwargs) # the rendrer self.eg3d_model = self.rec_model.module.decoder.triplane_decoder # type: ignore # self.rec_cano = False self.rec_cano = True def forward_backward(self, batch, *args, **kwargs): # add patch sampling self.mp_trainer_rec.zero_grad() batch_size = batch['img_to_encoder'].shape[0] for i in range(0, batch_size, self.microbatch): micro = { k: v[i:i + self.microbatch].to(dist_util.dev()) for k, v in batch.items() } # ! sample rendering patch target = { **self.eg3d_model( c=micro['nv_c'], # type: ignore ws=None, planes=None, sample_ray_only=True, fg_bbox=micro['nv_bbox']), # rays o / dir } patch_rendering_resolution = self.eg3d_model.rendering_kwargs[ 'patch_rendering_resolution'] # type: ignore cropped_target = { k: th.empty_like(v) [..., :patch_rendering_resolution, :patch_rendering_resolution] if k not in [ 'ins_idx', 'img_to_encoder', 'img_sr', 'nv_img_to_encoder', 'nv_img_sr', 'c' ] else v for k, v in micro.items() } # crop according to uv sampling for j in range(micro['img'].shape[0]): top, left, height, width = target['ray_bboxes'][ j] # list of tuple # for key in ('img', 'depth_mask', 'depth', 'depth_mask_sr'): # type: ignore for key in ('img', 'depth_mask', 'depth'): # type: ignore # target[key][i:i+1] = torchvision.transforms.functional.crop( # cropped_target[key][ # j:j + 1] = torchvision.transforms.functional.crop( # micro[key][j:j + 1], top, left, height, width) cropped_target[f'{key}'][ # ! no nv_ here j:j + 1] = torchvision.transforms.functional.crop( micro[f'nv_{key}'][j:j + 1], top, left, height, width) # target.update(cropped_target) # wrap forward within amp with th.autocast(device_type='cuda', dtype=th.float16, enabled=self.mp_trainer_rec.use_amp): # target_nvs = {} # target_cano = {} latent = self.rec_model(img=micro['img_to_encoder'], behaviour='enc_dec_wo_triplane') pred_nv = self.rec_model( latent=latent, c=micro['nv_c'], # predict novel view here behaviour='triplane_dec', ray_origins=target['ray_origins'], ray_directions=target['ray_directions'], ) # ! directly retrieve from target # for k, v in target.items(): # if k[:2] == 'nv': # orig_key = k.replace('nv_', '') # target_nvs[orig_key] = v # target_cano[orig_key] = target[orig_key] with self.rec_model.no_sync(): # type: ignore loss, loss_dict, _ = self.loss_class(pred_nv, cropped_target, step=self.step + self.resume_step, test_mode=False, return_fg_mask=True, conf_sigma_l1=None, conf_sigma_percl=None) log_rec3d_loss_dict(loss_dict) if self.rec_cano: cano_target = { **self.eg3d_model( c=micro['c'], # type: ignore ws=None, planes=None, sample_ray_only=True, fg_bbox=micro['bbox']), # rays o / dir } cano_cropped_target = { k: th.empty_like(v) for k, v in cropped_target.items() } for j in range(micro['img'].shape[0]): top, left, height, width = cano_target['ray_bboxes'][ j] # list of tuple # for key in ('img', 'depth_mask', 'depth', 'depth_mask_sr'): # type: ignore for key in ('img', 'depth_mask', 'depth'): # type: ignore # target[key][i:i+1] = torchvision.transforms.functional.crop( cano_cropped_target[key][ j:j + 1] = torchvision.transforms.functional.crop( micro[key][j:j + 1], top, left, height, width) # cano_target.update(cano_cropped_target) pred_cano = self.rec_model( latent=latent, c=micro['c'], behaviour='triplane_dec', ray_origins=cano_target['ray_origins'], ray_directions=cano_target['ray_directions'], ) with self.rec_model.no_sync(): # type: ignore fg_mask = cano_cropped_target['depth_mask'].unsqueeze( 1).repeat_interleave(3, 1).float() loss_cano, loss_cano_dict = self.loss_class.calc_2d_rec_loss( pred_cano['image_raw'], cano_cropped_target['img'], fg_mask, step=self.step + self.resume_step, test_mode=False, ) loss = loss + loss_cano # remove redundant log log_rec3d_loss_dict({ f'cano_{k}': v for k, v in loss_cano_dict.items() # if "loss" in k }) self.mp_trainer_rec.backward(loss) if dist_util.get_rank() == 0 and self.step % 500 == 0: self.log_patch_img(cropped_target, pred_nv, pred_cano) @th.inference_mode() def log_patch_img(self, micro, pred, pred_cano): # gt_vis = th.cat([batch['img'], batch['depth']], dim=-1) def norm_depth(pred_depth): # to [-1,1] # pred_depth = pred['image_depth'] pred_depth = (pred_depth - pred_depth.min()) / (pred_depth.max() - pred_depth.min()) return -(pred_depth * 2 - 1) pred_img = pred['image_raw'] gt_img = micro['img'] # infer novel view also # if self.loss_class.opt.symmetry_loss: # pred_nv_img = nvs_pred # else: # ! replace with novel view prediction # ! log another novel-view prediction # pred_nv_img = self.rec_model( # img=micro['img_to_encoder'], # c=self.novel_view_poses) # pred: (B, 3, 64, 64) # if 'depth' in micro: gt_depth = micro['depth'] if gt_depth.ndim == 3: gt_depth = gt_depth.unsqueeze(1) gt_depth = norm_depth(gt_depth) # gt_depth = (gt_depth - gt_depth.min()) / (gt_depth.max() - # gt_depth.min()) # if True: fg_mask = pred['image_mask'] * 2 - 1 # 0-1 input_fg_mask = pred_cano['image_mask'] * 2 - 1 # 0-1 if 'image_depth' in pred: pred_depth = norm_depth(pred['image_depth']) pred_cano_depth = norm_depth(pred_cano['image_depth']) else: pred_depth = th.zeros_like(gt_depth) pred_cano_depth = th.zeros_like(gt_depth) # if 'image_sr' in pred: # if pred['image_sr'].shape[-1] == 512: # pred_img = th.cat([self.pool_512(pred_img), pred['image_sr']], # dim=-1) # gt_img = th.cat([self.pool_512(micro['img']), micro['img_sr']], # dim=-1) # pred_depth = self.pool_512(pred_depth) # gt_depth = self.pool_512(gt_depth) # elif pred['image_sr'].shape[-1] == 256: # pred_img = th.cat([self.pool_256(pred_img), pred['image_sr']], # dim=-1) # gt_img = th.cat([self.pool_256(micro['img']), micro['img_sr']], # dim=-1) # pred_depth = self.pool_256(pred_depth) # gt_depth = self.pool_256(gt_depth) # else: # pred_img = th.cat([self.pool_128(pred_img), pred['image_sr']], # dim=-1) # gt_img = th.cat([self.pool_128(micro['img']), micro['img_sr']], # dim=-1) # gt_depth = self.pool_128(gt_depth) # pred_depth = self.pool_128(pred_depth) # else: # gt_img = self.pool_64(gt_img) # gt_depth = self.pool_64(gt_depth) pred_vis = th.cat([ pred_img, pred_depth.repeat_interleave(3, dim=1), fg_mask.repeat_interleave(3, dim=1), ], dim=-1) # B, 3, H, W pred_vis_nv = th.cat([ pred_cano['image_raw'], pred_cano_depth.repeat_interleave(3, dim=1), input_fg_mask.repeat_interleave(3, dim=1), ], dim=-1) # B, 3, H, W pred_vis = th.cat([pred_vis, pred_vis_nv], dim=-2) # cat in H dim gt_vis = th.cat([ gt_img, gt_depth.repeat_interleave(3, dim=1), th.zeros_like(gt_img) ], dim=-1) # TODO, fail to load depth. range [0, 1] # if 'conf_sigma' in pred: # gt_vis = th.cat([gt_vis, fg_mask], dim=-1) # placeholder # vis = th.cat([gt_vis, pred_vis], dim=-2)[0].permute( # st() vis = th.cat([gt_vis, pred_vis], dim=-2) # .permute( # 0, 2, 3, 1).cpu() vis_tensor = torchvision.utils.make_grid(vis, nrow=vis.shape[-1] // 64) # HWC torchvision.utils.save_image( vis_tensor, f'{logger.get_dir()}/{self.step+self.resume_step}.jpg', value_range=(-1, 1), normalize=True) logger.log('log vis to: ', f'{logger.get_dir()}/{self.step+self.resume_step}.jpg') # self.writer.add_image(f'images', # vis, # self.step + self.resume_step, # dataformats='HWC') class TrainLoop3DRecNVPatchSingleForward(TrainLoop3DRecNVPatch): def __init__(self, *, rec_model, loss_class, data, eval_data, batch_size, microbatch, lr, ema_rate, log_interval, eval_interval, save_interval, resume_checkpoint, use_fp16=False, fp16_scale_growth=0.001, weight_decay=0, lr_anneal_steps=0, iterations=10001, load_submodule_name='', ignore_resume_opt=False, model_name='rec', use_amp=False, **kwargs): super().__init__(rec_model=rec_model, loss_class=loss_class, data=data, eval_data=eval_data, batch_size=batch_size, microbatch=microbatch, lr=lr, ema_rate=ema_rate, log_interval=log_interval, eval_interval=eval_interval, save_interval=save_interval, resume_checkpoint=resume_checkpoint, use_fp16=use_fp16, fp16_scale_growth=fp16_scale_growth, weight_decay=weight_decay, lr_anneal_steps=lr_anneal_steps, iterations=iterations, load_submodule_name=load_submodule_name, ignore_resume_opt=ignore_resume_opt, model_name=model_name, use_amp=use_amp, **kwargs) def forward_backward(self, batch, *args, **kwargs): # add patch sampling self.mp_trainer_rec.zero_grad() batch_size = batch['img_to_encoder'].shape[0] batch.pop('caption') # not required batch.pop('ins') # not required # batch.pop('nv_caption') # not required for i in range(0, batch_size, self.microbatch): micro = { k: v[i:i + self.microbatch].to(dist_util.dev()) if isinstance( v, th.Tensor) else v[i:i + self.microbatch] for k, v in batch.items() } # ! sample rendering patch target = { **self.eg3d_model( c=micro['nv_c'], # type: ignore ws=None, planes=None, sample_ray_only=True, fg_bbox=micro['nv_bbox']), # rays o / dir } patch_rendering_resolution = self.eg3d_model.rendering_kwargs[ 'patch_rendering_resolution'] # type: ignore cropped_target = { k: th.empty_like(v) [..., :patch_rendering_resolution, :patch_rendering_resolution] if k not in [ 'ins_idx', 'img_to_encoder', 'img_sr', 'nv_img_to_encoder', 'nv_img_sr', 'c', 'caption', 'nv_caption' ] else v for k, v in micro.items() } # crop according to uv sampling for j in range(micro['img'].shape[0]): top, left, height, width = target['ray_bboxes'][ j] # list of tuple # for key in ('img', 'depth_mask', 'depth', 'depth_mask_sr'): # type: ignore for key in ('img', 'depth_mask', 'depth'): # type: ignore # target[key][i:i+1] = torchvision.transforms.functional.crop( # cropped_target[key][ # j:j + 1] = torchvision.transforms.functional.crop( # micro[key][j:j + 1], top, left, height, width) cropped_target[f'{key}'][ # ! no nv_ here j:j + 1] = torchvision.transforms.functional.crop( micro[f'nv_{key}'][j:j + 1], top, left, height, width) # ! cano view loss cano_target = { **self.eg3d_model( c=micro['c'], # type: ignore ws=None, planes=None, sample_ray_only=True, fg_bbox=micro['bbox']), # rays o / dir } # cano_cropped_target = { # k: th.empty_like(v) # for k, v in cropped_target.items() # } # for j in range(micro['img'].shape[0]): # top, left, height, width = cano_target['ray_bboxes'][ # j] # list of tuple # # for key in ('img', 'depth_mask', 'depth', 'depth_mask_sr'): # type: ignore # for key in ('img', 'depth_mask', 'depth'): # type: ignore # # target[key][i:i+1] = torchvision.transforms.functional.crop( # cano_cropped_target[key][ # j:j + 1] = torchvision.transforms.functional.crop( # micro[key][j:j + 1], top, left, height, width) # ! vit no amp latent = self.rec_model(img=micro['img_to_encoder'], behaviour='enc_dec_wo_triplane') # wrap forward within amp with th.autocast(device_type='cuda', dtype=th.float16, enabled=self.mp_trainer_rec.use_amp): # c = th.cat([micro['nv_c'], micro['c']]), # predict novel view here # c = th.cat([micro['nv_c'].repeat(3, 1), micro['c']]), # predict novel view here instance_mv_num = batch_size // 4 # 4 pairs by default # instance_mv_num = 4 # ! roll views for multi-view supervision c = th.cat([ micro['nv_c'].roll(instance_mv_num * i, dims=0) for i in range(1, 4) ] # + [micro['c']] ) # predict novel view here ray_origins = th.cat( [ target['ray_origins'].roll(instance_mv_num * i, dims=0) for i in range(1, 4) ] # + [cano_target['ray_origins'] ] , 0) ray_directions = th.cat([ target['ray_directions'].roll(instance_mv_num * i, dims=0) for i in range(1, 4) ] # + [cano_target['ray_directions'] ] ) pred_nv_cano = self.rec_model( # latent=latent.expand(2,), latent={ 'latent_after_vit': # ! triplane for rendering # latent['latent_after_vit'].repeat(2, 1, 1, 1) latent['latent_after_vit'].repeat(3, 1, 1, 1) }, c=c, behaviour='triplane_dec', # ray_origins=target['ray_origins'], # ray_directions=target['ray_directions'], ray_origins=ray_origins, ray_directions=ray_directions, ) pred_nv_cano.update( latent ) # torchvision.utils.save_image(pred_nv_cano['image_raw'], 'pred.png', normalize=True) # gt = { # k: th.cat([v, cano_cropped_target[k]], 0) # for k, v in cropped_target.items() # } gt = { k: th.cat( [ v.roll(instance_mv_num * i, dims=0) for i in range(1, 4) ] # + [cano_cropped_target[k] ] , 0) for k, v in cropped_target.items() } # torchvision.utils.save_image(gt['img'], 'gt.png', normalize=True) with self.rec_model.no_sync(): # type: ignore loss, loss_dict, _ = self.loss_class( pred_nv_cano, gt, # prepare merged data step=self.step + self.resume_step, test_mode=False, return_fg_mask=True, conf_sigma_l1=None, conf_sigma_percl=None) log_rec3d_loss_dict(loss_dict) self.mp_trainer_rec.backward(loss) # for name, p in self.rec_model.named_parameters(): # if p.grad is None: # logger.log(f"found rec unused param: {name}") if dist_util.get_rank() == 0 and self.step % 500 == 0: micro_bs = micro['img_to_encoder'].shape[0] self.log_patch_img( # record one cano view and one novel view cropped_target, { k: pred_nv_cano[k][-micro_bs:] for k in ['image_raw', 'image_depth', 'image_mask'] }, { k: pred_nv_cano[k][:micro_bs] for k in ['image_raw', 'image_depth', 'image_mask'] }, ) def eval_loop(self): return super().eval_loop() @th.inference_mode() # def eval_loop(self, c_list:list): def eval_novelview_loop_old(self, camera=None): # novel view synthesis given evaluation camera trajectory all_loss_dict = [] novel_view_micro = {} # ! randomly inference an instance export_mesh = True if export_mesh: Path(f'{logger.get_dir()}/FID_Cals/').mkdir(parents=True, exist_ok=True) # for i in range(0, len(c_list), 1): # TODO, larger batch size for eval batch = {} # if camera is not None: # # batch['c'] = camera.to(batch['c'].device()) # batch['c'] = camera.clone() # else: # batch = for eval_idx, render_reference in enumerate(tqdm(self.eval_data)): if eval_idx > 500: break video_out = imageio.get_writer( f'{logger.get_dir()}/video_novelview_{self.step+self.resume_step}_{eval_idx}.mp4', mode='I', fps=25, codec='libx264') with open( f'{logger.get_dir()}/triplane_{self.step+self.resume_step}_{eval_idx}_caption.txt', 'w') as f: f.write(render_reference['caption']) for key in ['ins', 'bbox', 'caption']: if key in render_reference: render_reference.pop(key) real_flag = False mv_flag = False # TODO, use full-instance for evaluation? Calculate the metrics. if render_reference['c'].shape[:2] == (1, 40): real_flag = True # real img monocular reconstruction # compat lst for enumerate render_reference = [{ k: v[0][idx:idx + 1] for k, v in render_reference.items() } for idx in range(40)] elif render_reference['c'].shape[0] == 8: mv_flag = True render_reference = { k: v[:4] for k, v in render_reference.items() } # save gt torchvision.utils.save_image( render_reference[0:4]['img'], logger.get_dir() + '/FID_Cals/{}_inp.png'.format(eval_idx), padding=0, normalize=True, value_range=(-1, 1), ) # torchvision.utils.save_image(render_reference[4:8]['img'], # logger.get_dir() + '/FID_Cals/{}_inp2.png'.format(eval_idx), # padding=0, # normalize=True, # value_range=(-1,1), # ) else: # compat lst for enumerate st() render_reference = [{ k: v[idx:idx + 1] for k, v in render_reference.items() } for idx in range(40)] # ! single-view version render_reference[0]['img_to_encoder'] = render_reference[14][ 'img_to_encoder'] # encode side view render_reference[0]['img'] = render_reference[14][ 'img'] # encode side view # save gt torchvision.utils.save_image( render_reference[0]['img'], logger.get_dir() + '/FID_Cals/{}_gt.png'.format(eval_idx), padding=0, normalize=True, value_range=(-1, 1)) # ! TODO, merge with render_video_given_triplane later for i, batch in enumerate(render_reference): # for i in range(0, 8, self.microbatch): # c = c_list[i].to(dist_util.dev()).reshape(1, -1) micro = {k: v.to(dist_util.dev()) for k, v in batch.items()} st() if i == 0: if mv_flag: novel_view_micro = None else: novel_view_micro = { k: v[0:1].to(dist_util.dev()).repeat_interleave( # v[14:15].to(dist_util.dev()).repeat_interleave( micro['img'].shape[0], 0) if isinstance(v, th.Tensor) else v[0:1] for k, v in batch.items() } else: if i == 1: # ! output mesh if export_mesh: # ! get planes first # self.latent_name = 'latent_normalized' # normalized triplane latent # ddpm_latent = { # self.latent_name: planes, # } # ddpm_latent.update(self.rec_model(latent=ddpm_latent, behaviour='decode_after_vae_no_render')) # mesh_size = 512 # mesh_size = 256 mesh_size = 384 # mesh_size = 320 # mesh_thres = 3 # TODO, requires tuning # mesh_thres = 5 # TODO, requires tuning mesh_thres = 10 # TODO, requires tuning import mcubes import trimesh dump_path = f'{logger.get_dir()}/mesh/' os.makedirs(dump_path, exist_ok=True) grid_out = self.rec_model( latent=pred, grid_size=mesh_size, behaviour='triplane_decode_grid', ) vtx, faces = mcubes.marching_cubes( grid_out['sigma'].squeeze(0).squeeze( -1).cpu().numpy(), mesh_thres) vtx = vtx / (mesh_size - 1) * 2 - 1 # vtx_tensor = th.tensor(vtx, dtype=th.float32, device=dist_util.dev()).unsqueeze(0) # vtx_colors = self.model.synthesizer.forward_points(planes, vtx_tensor)['rgb'].squeeze(0).cpu().numpy() # (0, 1) # vtx_colors = (vtx_colors * 255).astype(np.uint8) # mesh = trimesh.Trimesh(vertices=vtx, faces=faces, vertex_colors=vtx_colors) mesh = trimesh.Trimesh( vertices=vtx, faces=faces, ) mesh_dump_path = os.path.join( dump_path, f'{eval_idx}.ply') mesh.export(mesh_dump_path, 'ply') print(f"Mesh dumped to {dump_path}") del grid_out, mesh th.cuda.empty_cache() # return # st() # if novel_view_micro['c'].shape[0] < micro['img'].shape[0]: novel_view_micro = { k: v[0:1].to(dist_util.dev()).repeat_interleave( micro['img'].shape[0], 0) for k, v in novel_view_micro.items() } pred = self.rec_model(img=novel_view_micro['img_to_encoder'], c=micro['c']) # pred: (B, 3, 64, 64) # target = { # 'img': micro['img'], # 'depth': micro['depth'], # 'depth_mask': micro['depth_mask'] # } # targe # if not export_mesh: if not real_flag: _, loss_dict = self.loss_class(pred, micro, test_mode=True) all_loss_dict.append(loss_dict) # ! move to other places, add tensorboard # pred_vis = th.cat([ # pred['image_raw'], # -pred['image_depth'].repeat_interleave(3, dim=1) # ], # dim=-1) # normalize depth # if True: pred_depth = pred['image_depth'] pred_depth = (pred_depth - pred_depth.min()) / ( pred_depth.max() - pred_depth.min()) if 'image_sr' in pred: if pred['image_sr'].shape[-1] == 512: pred_vis = th.cat([ micro['img_sr'], self.pool_512(pred['image_raw']), pred['image_sr'], self.pool_512(pred_depth).repeat_interleave(3, dim=1) ], dim=-1) elif pred['image_sr'].shape[-1] == 256: pred_vis = th.cat([ micro['img_sr'], self.pool_256(pred['image_raw']), pred['image_sr'], self.pool_256(pred_depth).repeat_interleave(3, dim=1) ], dim=-1) else: pred_vis = th.cat([ micro['img_sr'], self.pool_128(pred['image_raw']), self.pool_128(pred['image_sr']), self.pool_128(pred_depth).repeat_interleave(3, dim=1) ], dim=-1) else: # pred_vis = th.cat([ # self.pool_64(micro['img']), pred['image_raw'], # pred_depth.repeat_interleave(3, dim=1) # ], # dim=-1) # B, 3, H, W pooled_depth = self.pool_128(pred_depth).repeat_interleave( 3, dim=1) pred_vis = th.cat( [ # self.pool_128(micro['img']), self.pool_128(novel_view_micro['img'] ), # use the input here self.pool_128(pred['image_raw']), pooled_depth, ], dim=-1) # B, 3, H, W vis = pred_vis.permute(0, 2, 3, 1).cpu().numpy() vis = vis * 127.5 + 127.5 vis = vis.clip(0, 255).astype(np.uint8) if export_mesh: # save image torchvision.utils.save_image( pred['image_raw'], logger.get_dir() + '/FID_Cals/{}_{}.png'.format(eval_idx, i), padding=0, normalize=True, value_range=(-1, 1)) torchvision.utils.save_image( pooled_depth, logger.get_dir() + '/FID_Cals/{}_{}_dpeth.png'.format(eval_idx, i), padding=0, normalize=True, value_range=(0, 1)) # st() for j in range(vis.shape[0]): video_out.append_data(vis[j]) video_out.close() # if not export_mesh: if not real_flag or mv_flag: val_scores_for_logging = calc_average_loss(all_loss_dict) with open(os.path.join(logger.get_dir(), 'scores_novelview.json'), 'a') as f: json.dump({'step': self.step, **val_scores_for_logging}, f) # * log to tensorboard for k, v in val_scores_for_logging.items(): self.writer.add_scalar(f'Eval/NovelView/{k}', v, self.step + self.resume_step) del video_out # del pred_vis # del pred th.cuda.empty_cache() @th.inference_mode() # def eval_loop(self, c_list:list): def eval_novelview_loop(self, camera=None, save_latent=False): # novel view synthesis given evaluation camera trajectory if save_latent: # for diffusion learning latent_dir = Path(f'{logger.get_dir()}/latent_dir') latent_dir.mkdir(exist_ok=True, parents=True) # wds_path = os.path.join(logger.get_dir(), 'latent_dir', # f'wds-%06d.tar') # sink = wds.ShardWriter(wds_path, start_shard=0) # eval_batch_size = 20 # eval_batch_size = 1 eval_batch_size = 40 # ! for i23d for eval_idx, micro in enumerate(tqdm(self.eval_data)): latent = self.rec_model( img=micro['img_to_encoder'], behaviour='encoder_vae') # pred: (B, 3, 64, 64) # torchvision.utils.save_image(micro['img'], 'inp.jpg') if save_latent: # np.save(f'{logger.get_dir()}/latent_dir/{eval_idx}.npy', latent[self.latent_name].cpu().numpy()) latent_save_dir = f'{logger.get_dir()}/latent_dir/{micro["ins"][0]}' Path(latent_save_dir).mkdir(parents=True, exist_ok=True) np.save(f'{latent_save_dir}/latent.npy', latent[self.latent_name][0].cpu().numpy()) assert all([ micro['ins'][0] == micro['ins'][i] for i in range(micro['c'].shape[0]) ]) # ! assert same instance if eval_idx < 50: # if False: self.render_video_given_triplane( latent[self.latent_name], # B 12 32 32 self.rec_model, # compatible with join_model name_prefix=f'{self.step + self.resume_step}_{eval_idx}', save_img=False, render_reference={'c': camera}, save_mesh=True) class TrainLoop3DRecNVPatchSingleForwardMV(TrainLoop3DRecNVPatchSingleForward): def __init__(self, *, rec_model, loss_class, data, eval_data, batch_size, microbatch, lr, ema_rate, log_interval, eval_interval, save_interval, resume_checkpoint, use_fp16=False, fp16_scale_growth=0.001, weight_decay=0, lr_anneal_steps=0, iterations=10001, load_submodule_name='', ignore_resume_opt=False, model_name='rec', use_amp=False, **kwargs): super().__init__(rec_model=rec_model, loss_class=loss_class, data=data, eval_data=eval_data, batch_size=batch_size, microbatch=microbatch, lr=lr, ema_rate=ema_rate, log_interval=log_interval, eval_interval=eval_interval, save_interval=save_interval, resume_checkpoint=resume_checkpoint, use_fp16=use_fp16, fp16_scale_growth=fp16_scale_growth, weight_decay=weight_decay, lr_anneal_steps=lr_anneal_steps, iterations=iterations, load_submodule_name=load_submodule_name, ignore_resume_opt=ignore_resume_opt, model_name=model_name, use_amp=use_amp, **kwargs) def forward_backward(self, batch, behaviour='g_step', *args, **kwargs): # add patch sampling self.mp_trainer_rec.zero_grad() batch_size = batch['img_to_encoder'].shape[0] batch.pop('caption') # not required batch.pop('nv_caption') # not required batch.pop('ins') # not required batch.pop('nv_ins') # not required if '__key__' in batch.keys(): batch.pop('__key__') for i in range(0, batch_size, self.microbatch): micro = { k: v[i:i + self.microbatch].to(dist_util.dev()) if isinstance( v, th.Tensor) else v[i:i + self.microbatch] for k, v in batch.items() } # ! sample rendering patch # nv_c = th.cat([micro['nv_c'], micro['c']]) nv_c = th.cat([micro['nv_c'], micro['c']]) # nv_c = micro['nv_c'] target = { **self.eg3d_model( c=nv_c, # type: ignore ws=None, planes=None, sample_ray_only=True, fg_bbox=th.cat([micro['nv_bbox'], micro['bbox']])), # rays o / dir } patch_rendering_resolution = self.eg3d_model.rendering_kwargs[ 'patch_rendering_resolution'] # type: ignore cropped_target = { k: th.empty_like(v).repeat_interleave(2, 0) [..., :patch_rendering_resolution, :patch_rendering_resolution] if k not in [ 'ins_idx', 'img_to_encoder', 'img_sr', 'nv_img_to_encoder', 'nv_img_sr', 'c', 'caption', 'nv_caption' ] else v for k, v in micro.items() } # crop according to uv sampling for j in range(2 * self.microbatch): top, left, height, width = target['ray_bboxes'][ j] # list of tuple # for key in ('img', 'depth_mask', 'depth', 'depth_mask_sr'): # type: ignore for key in ('img', 'depth_mask', 'depth'): # type: ignore if j < self.microbatch: cropped_target[f'{key}'][ # ! no nv_ here j:j + 1] = torchvision.transforms.functional.crop( micro[f'nv_{key}'][j:j + 1], top, left, height, width) else: cropped_target[f'{key}'][ # ! no nv_ here j:j + 1] = torchvision.transforms.functional.crop( micro[f'{key}'][j - self.microbatch:j - self.microbatch + 1], top, left, height, width) # ! vit no amp latent = self.rec_model(img=micro['img_to_encoder'], behaviour='enc_dec_wo_triplane') # wrap forward within amp with th.autocast(device_type='cuda', dtype=th.float16, enabled=self.mp_trainer_rec.use_amp): # c = th.cat([micro['nv_c'], micro['c']]), # predict novel view here # c = th.cat([micro['nv_c'].repeat(3, 1), micro['c']]), # predict novel view here # instance_mv_num = batch_size // 4 # 4 pairs by default # instance_mv_num = 4 # ! roll views for multi-view supervision # c = micro['nv_c'] ray_origins = target['ray_origins'] ray_directions = target['ray_directions'] pred_nv_cano = self.rec_model( # latent=latent.expand(2,), latent={ 'latent_after_vit': # ! triplane for rendering latent['latent_after_vit'].repeat_interleave(4, dim=0).repeat(2,1,1,1) # NV=4 # latent['latent_after_vit'].repeat_interleave(8, dim=0) # NV=4 }, c=nv_c, behaviour='triplane_dec', ray_origins=ray_origins, ray_directions=ray_directions, ) pred_nv_cano.update( latent ) # torchvision.utils.save_image(pred_nv_cano['image_raw'], 'pred.png', normalize=True) gt = cropped_target with self.rec_model.no_sync(): # type: ignore loss, loss_dict, _ = self.loss_class( pred_nv_cano, gt, # prepare merged data step=self.step + self.resume_step, test_mode=False, return_fg_mask=True, behaviour=behaviour, conf_sigma_l1=None, conf_sigma_percl=None) log_rec3d_loss_dict(loss_dict) self.mp_trainer_rec.backward(loss) # for name, p in self.rec_model.named_parameters(): # if p.grad is None: # logger.log(f"found rec unused param: {name}") # torchvision.utils.save_image(cropped_target['img'], 'gt.png', normalize=True) # torchvision.utils.save_image( pred_nv_cano['image_raw'], 'pred.png', normalize=True) if dist_util.get_rank() == 0 and self.step % 500 == 0 and i == 0: try: torchvision.utils.save_image( th.cat( [cropped_target['img'], pred_nv_cano['image_raw'] ], ), f'{logger.get_dir()}/{self.step+self.resume_step}.jpg', normalize=True) logger.log( 'log vis to: ', f'{logger.get_dir()}/{self.step+self.resume_step}.jpg') except Exception as e: logger.log(e) # micro_bs = micro['img_to_encoder'].shape[0] # self.log_patch_img( # record one cano view and one novel view # cropped_target, # { # k: pred_nv_cano[k][0:1] # for k in ['image_raw', 'image_depth', 'image_mask'] # }, # { # k: pred_nv_cano[k][1:2] # for k in ['image_raw', 'image_depth', 'image_mask'] # }, # ) # def save(self): # return super().save() class TrainLoop3DRecNVPatchSingleForwardMVAdvLoss( TrainLoop3DRecNVPatchSingleForwardMV): def __init__(self, *, rec_model, loss_class, data, eval_data, batch_size, microbatch, lr, ema_rate, log_interval, eval_interval, save_interval, resume_checkpoint, use_fp16=False, fp16_scale_growth=0.001, weight_decay=0, lr_anneal_steps=0, iterations=10001, load_submodule_name='', ignore_resume_opt=False, model_name='rec', use_amp=False, **kwargs): super().__init__(rec_model=rec_model, loss_class=loss_class, data=data, eval_data=eval_data, batch_size=batch_size, microbatch=microbatch, lr=lr, ema_rate=ema_rate, log_interval=log_interval, eval_interval=eval_interval, save_interval=save_interval, resume_checkpoint=resume_checkpoint, use_fp16=use_fp16, fp16_scale_growth=fp16_scale_growth, weight_decay=weight_decay, lr_anneal_steps=lr_anneal_steps, iterations=iterations, load_submodule_name=load_submodule_name, ignore_resume_opt=ignore_resume_opt, model_name=model_name, use_amp=use_amp, **kwargs) # create discriminator disc_params = self.loss_class.get_trainable_parameters() self.mp_trainer_disc = MixedPrecisionTrainer( model=self.loss_class.discriminator, use_fp16=self.use_fp16, fp16_scale_growth=fp16_scale_growth, model_name='disc', use_amp=use_amp, model_params=disc_params) # st() # check self.lr self.opt_disc = AdamW( self.mp_trainer_disc.master_params, lr=self.lr, # follow sd code base betas=(0, 0.999), eps=1e-8) # TODO, is loss cls already in the DDP? if self.use_ddp: self.ddp_disc = DDP( self.loss_class.discriminator, device_ids=[dist_util.dev()], output_device=dist_util.dev(), broadcast_buffers=False, bucket_cap_mb=128, find_unused_parameters=False, ) else: self.ddp_disc = self.loss_class.discriminator # def run_st # def run_step(self, batch, *args): # self.forward_backward(batch) # took_step = self.mp_trainer_rec.optimize(self.opt) # if took_step: # self._update_ema() # self._anneal_lr() # self.log_step() def save(self, mp_trainer=None, model_name='rec'): if mp_trainer is None: mp_trainer = self.mp_trainer_rec def save_checkpoint(rate, params): state_dict = mp_trainer.master_params_to_state_dict(params) if dist_util.get_rank() == 0: logger.log(f"saving model {model_name} {rate}...") if not rate: filename = f"model_{model_name}{(self.step+self.resume_step):07d}.pt" else: filename = f"ema_{model_name}_{rate}_{(self.step+self.resume_step):07d}.pt" with bf.BlobFile(bf.join(get_blob_logdir(), filename), "wb") as f: th.save(state_dict, f) save_checkpoint(0, mp_trainer.master_params) dist.barrier() def run_step(self, batch, step='g_step'): # self.forward_backward(batch) if step == 'g_step': self.forward_backward(batch, behaviour='g_step') took_step_g_rec = self.mp_trainer_rec.optimize(self.opt) if took_step_g_rec: self._update_ema() # g_ema elif step == 'd_step': self.forward_backward(batch, behaviour='d_step') _ = self.mp_trainer_disc.optimize(self.opt_disc) self._anneal_lr() self.log_step() def run_loop(self, batch=None): while (not self.lr_anneal_steps or self.step + self.resume_step < self.lr_anneal_steps): batch = next(self.data) self.run_step(batch, 'g_step') batch = next(self.data) self.run_step(batch, 'd_step') if self.step % 1000 == 0: dist_util.synchronize() if self.step % 10000 == 0: th.cuda.empty_cache() # avoid memory leak if self.step % self.log_interval == 0 and dist_util.get_rank( ) == 0: out = logger.dumpkvs() # * log to tensorboard for k, v in out.items(): self.writer.add_scalar(f'Loss/{k}', v, self.step + self.resume_step) if self.step % self.eval_interval == 0 and self.step != 0: if dist_util.get_rank() == 0: try: self.eval_loop() except Exception as e: logger.log(e) dist_util.synchronize() # if self.step % self.save_interval == 0 and self.step != 0: if self.step % self.save_interval == 0: self.save() self.save(self.mp_trainer_disc, self.mp_trainer_disc.model_name) dist_util.synchronize() # Run for a finite amount of time in integration tests. if os.environ.get("DIFFUSION_TRAINING_TEST", "") and self.step > 0: return self.step += 1 if self.step > self.iterations: logger.log('reached maximum iterations, exiting') # Save the last checkpoint if it wasn't already saved. if (self.step - 1) % self.save_interval != 0 and self.step != 1: self.save() exit() # Save the last checkpoint if it wasn't already saved. # if (self.step - 1) % self.save_interval != 0 and self.step != 1: if (self.step - 1) % self.save_interval != 0: self.save() # save rec self.save(self.mp_trainer_disc, self.mp_trainer_disc.model_name)