|
import copy |
|
import functools |
|
import json |
|
import os |
|
from pathlib import Path |
|
from pdb import set_trace as st |
|
|
|
import blobfile as bf |
|
import imageio |
|
import numpy as np |
|
import torch as th |
|
import torch.distributed as dist |
|
import torchvision |
|
from PIL import Image |
|
from torch.nn.parallel.distributed import DistributedDataParallel as DDP |
|
from torch.optim import AdamW |
|
from torch.utils.tensorboard import SummaryWriter |
|
from tqdm import tqdm |
|
|
|
from guided_diffusion import dist_util, logger |
|
from guided_diffusion.fp16_util import MixedPrecisionTrainer |
|
from guided_diffusion.nn import update_ema |
|
from guided_diffusion.resample import LossAwareSampler, UniformSampler |
|
from guided_diffusion.train_util import (calc_average_loss, |
|
find_ema_checkpoint, |
|
find_resume_checkpoint, |
|
get_blob_logdir, log_rec3d_loss_dict, |
|
parse_resume_step_from_filename) |
|
|
|
from .train_util import TrainLoop3DRec |
|
|
|
|
|
class TrainLoop3DRecEG3D(TrainLoop3DRec):
    """Reconstruction training loop supervised by samples drawn from ``G``.

    Each training step samples images, depth and shape outputs from the
    generator ``G`` (always under ``no_grad``) and trains ``rec_model`` to
    reproduce them from the pooled generator image and the camera ``c``.

    NOTE(review): ``self.pool_512``, ``self.pool_128``, ``self.writer``,
    ``self.opt``, ``self.mp_trainer_rec`` and ``self.use_ddp`` are not set in
    this class; they are presumably provided by ``TrainLoop3DRec`` -- confirm.
    """

    def __init__(self,
                 *,
                 G,
                 rec_model,
                 loss_class,
                 data,
                 eval_data,
                 batch_size,
                 microbatch,
                 lr,
                 ema_rate,
                 log_interval,
                 eval_interval,
                 save_interval,
                 resume_checkpoint,
                 use_fp16=False,
                 fp16_scale_growth=0.001,
                 weight_decay=0,
                 lr_anneal_steps=0,
                 iterations=10001,
                 load_submodule_name='',
                 ignore_resume_opt=False,
                 model_name='rec',
                 use_amp=False,
                 **kwargs):
        """Build the loop.

        Args:
            G: Generator exposing ``mapping`` and ``synthesis``; used only to
                sample supervision (see ``run_G``).
            **kwargs: Forwarded, together with every other keyword argument,
                unchanged to ``TrainLoop3DRec.__init__``.
        """
        super().__init__(rec_model=rec_model,
                         loss_class=loss_class,
                         data=data,
                         eval_data=eval_data,
                         batch_size=batch_size,
                         microbatch=microbatch,
                         lr=lr,
                         ema_rate=ema_rate,
                         log_interval=log_interval,
                         eval_interval=eval_interval,
                         save_interval=save_interval,
                         resume_checkpoint=resume_checkpoint,
                         use_fp16=use_fp16,
                         fp16_scale_growth=fp16_scale_growth,
                         weight_decay=weight_decay,
                         lr_anneal_steps=lr_anneal_steps,
                         iterations=iterations,
                         load_submodule_name=load_submodule_name,
                         ignore_resume_opt=ignore_resume_opt,
                         model_name=model_name,
                         use_amp=use_amp,
                         **kwargs)
        self.G = G

        # Pools the generator's full-resolution image to the encoder input size.
        self.pool_224 = th.nn.AdaptiveAvgPool2d((224, 224))

    @th.no_grad()
    def run_G(
        self,
        z,
        c,
        swapping_prob,
        neural_rendering_resolution,
        update_emas=False,
        return_raw_only=False,
        truncation_psi=0.7,
    ):
        """Sample the generator with truncation and render it from camera ``c``.

        The mapping network is conditioned on an all-zero tensor shaped like
        ``c``; the real ``c`` is only passed to the synthesis network.

        Args:
            z: Latent codes passed to ``G.mapping``.
            c: Camera/conditioning tensor passed to ``G.synthesis``.
            swapping_prob: Accepted for interface compatibility; unused here.
            neural_rendering_resolution: Forwarded to ``G.synthesis``.
            update_emas (bool, optional): Forwarded to both generator calls.
                Defaults to False.
            return_raw_only (bool, optional): Forwarded to ``G.synthesis``.
                Defaults to False.
            truncation_psi (float, optional): Truncation strength forwarded to
                ``G.mapping``. Defaults to 0.7 (the previously hard-coded
                value).

        Returns:
            tuple: ``(gen_output, ws)`` -- the synthesis output and the
            mapped latent codes.
        """
        c_gen_conditioning = th.zeros_like(c)

        ws = self.G.mapping(
            z,
            c_gen_conditioning,
            truncation_psi=truncation_psi,
            truncation_cutoff=None,
            update_emas=update_emas,
        )

        gen_output = self.G.synthesis(
            ws,
            c,
            neural_rendering_resolution=neural_rendering_resolution,
            update_emas=update_emas,
            noise_mode='const',
            return_raw_only=return_raw_only,
        )

        return gen_output, ws

    def _sample_generator(self, c):
        """Draw one generator sample per row of ``c``.

        Uses a 512-dim normal latent and a neural rendering resolution of
        128, the values this loop has always used.

        Returns:
            tuple: ``(eg3d_batch, ws)`` as returned by ``run_G``.
        """
        z = th.randn(c.shape[0], 512).to(dist_util.dev())
        return self.run_G(z=z,
                          c=c.to(dist_util.dev()),
                          swapping_prob=0,
                          neural_rendering_resolution=128)

    @staticmethod
    def _normalize_depth(depth):
        """Min-max normalise ``depth`` over the whole tensor.

        The denominator is clamped so a constant depth map yields zeros
        instead of NaNs.
        """
        lo = depth.min()
        return (depth - lo) / (depth.max() - lo).clamp_min(1e-8)

    def run_loop(self, batch=None):
        """Run training until ``lr_anneal_steps`` or ``iterations`` is reached.

        Args:
            batch: Ignored; a fresh batch is pulled from ``self.data`` on
                every step. Kept for interface compatibility.

        Raises:
            SystemExit: Once ``self.step`` exceeds ``self.iterations`` (after
                saving a final checkpoint if needed).
        """
        while (not self.lr_anneal_steps
               or self.step + self.resume_step < self.lr_anneal_steps):

            # Let all processes line up before starting the next step.
            dist_util.synchronize()

            batch = next(self.data)
            self.run_step(batch)

            if (self.step % self.log_interval == 0
                    and dist_util.get_rank() == 0):
                out = logger.dumpkvs()
                # Mirror the dumped scalars to tensorboard.
                for k, v in out.items():
                    self.writer.add_scalar(f'Loss/{k}', v,
                                           self.step + self.resume_step)

            if self.step % self.eval_interval == 0 and self.step != 0:
                # No evaluation is run here; only a barrier remains.
                dist_util.synchronize()

            if self.step % self.save_interval == 0:
                self.save()
                dist_util.synchronize()
                # Run for a finite amount of time in integration tests.
                if os.environ.get("DIFFUSION_TRAINING_TEST",
                                  "") and self.step > 0:
                    return

            self.step += 1

            if self.step > self.iterations:
                print('reached maximum iterations, exiting')

                # Save the last checkpoint if it wasn't already saved.
                if (self.step - 1) % self.save_interval != 0:
                    self.save()

                # Equivalent to the interactive ``exit()`` helper but does
                # not depend on the ``site`` module being loaded.
                raise SystemExit(0)

        # Save the last checkpoint if it wasn't already saved.
        if (self.step - 1) % self.save_interval != 0:
            self.save()

    def run_step(self, batch, *args):
        """Run one optimisation step: forward/backward, optimise, EMA, LR, log."""
        self.forward_backward(batch)
        took_step = self.mp_trainer_rec.optimize(self.opt)
        if took_step:
            self._update_ema()
        self._anneal_lr()
        self.log_step()

    def forward_backward(self, batch, *args, **kwargs):
        """Accumulate gradients of the reconstruction losses over microbatches.

        Only ``batch['c']`` is read from ``batch``; images, depth and shape
        targets are sampled from the generator for each microbatch.

        Args:
            batch (dict): Must contain key ``'c'``.

        Returns:
            The ``rec_model`` output dict of the last processed microbatch,
            or ``None`` when the batch is empty.
        """
        self.mp_trainer_rec.zero_grad()

        total = batch['c'].shape[0]
        pred_gen_output = None

        for i in range(0, total, self.microbatch):

            # Slice the microbatch; previously the full batch was used on
            # every iteration.
            micro = {
                'c': batch['c'][i:i + self.microbatch].to(dist_util.dev())
            }

            # run_G is decorated with no_grad, so no extra context is needed.
            eg3d_batch, ws = self._sample_generator(micro['c'])

            micro.update({
                'img': eg3d_batch['image_raw'],
                'img_to_encoder': self.pool_224(eg3d_batch['image']),
                'depth': eg3d_batch['image_depth'],
                'img_sr': eg3d_batch['image'],
            })

            last_batch = (i + self.microbatch) >= total

            with th.autocast(device_type='cuda',
                             dtype=th.float16,
                             enabled=self.mp_trainer_rec.use_amp):

                pred_gen_output = self.rec_model(img=micro['img_to_encoder'],
                                                 c=micro['c'])

                target = dict(
                    img=eg3d_batch['image_raw'],
                    shape_synthesized=eg3d_batch['shape_synthesized'],
                    img_sr=eg3d_batch['image'],
                )

                pred_gen_output['shape_synthesized_query'] = {
                    'coarse_densities':
                    pred_gen_output['shape_synthesized']['coarse_densities'],
                    'image_depth':
                    pred_gen_output['image_depth'],
                }

                # The shape loss reads the depth from the shape dict.
                eg3d_batch['shape_synthesized']['image_depth'] = eg3d_batch[
                    'image_depth']

                # Use dedicated names so the outer batch size driving
                # ``last_batch`` is not overwritten.
                num_samples, num_rays, _, _ = pred_gen_output[
                    'shape_synthesized']['coarse_densities'].shape

                for coord_key in ['fine_coords']:

                    # Query the predicted triplane at the generator's sample
                    # coordinates; directions are random.
                    sigma = self.rec_model(
                        latent=pred_gen_output['latent_denormalized'],
                        coordinates=eg3d_batch['shape_synthesized'][coord_key],
                        directions=th.randn_like(
                            eg3d_batch['shape_synthesized'][coord_key]),
                        behaviour='triplane_renderer',
                    )['sigma']

                    rendering_kwargs = self.rec_model(
                        behaviour='get_rendering_kwargs')

                    sigma = sigma.reshape(
                        num_samples, num_rays,
                        rendering_kwargs['depth_resolution_importance'], 1)

                    pred_gen_output['shape_synthesized_query'][
                        f"{coord_key.split('_')[0]}_densities"] = sigma

                # 2D reconstruction loss.
                # NOTE(review): no_sync only wraps the loss computation, not
                # the model forward/backward -- confirm this is intended.
                if last_batch or not self.use_ddp:
                    loss, loss_dict = self.loss_class(pred_gen_output,
                                                      target,
                                                      test_mode=False)
                else:
                    with self.rec_model.no_sync():
                        loss, loss_dict = self.loss_class(pred_gen_output,
                                                          target,
                                                          test_mode=False)

                # 3D geometry loss.
                loss_shape = self.calc_shape_rec_loss(
                    pred_gen_output['shape_synthesized_query'],
                    eg3d_batch['shape_synthesized'])
                loss += loss_shape.mean()

                # Feature-volume loss.
                loss_feature_volume = th.nn.functional.mse_loss(
                    eg3d_batch['feature_volume'],
                    pred_gen_output['feature_volume'])
                loss += loss_feature_volume * 0.1

                # Loss between the last generator w code and 'sr_w_code'.
                loss_ws = th.nn.functional.mse_loss(
                    ws[:, -1:, :], pred_gen_output['sr_w_code'])
                loss += loss_ws * 0.1

                loss_dict.update(
                    dict(loss_feature_volume=loss_feature_volume,
                         loss=loss,
                         loss_shape=loss_shape,
                         loss_ws=loss_ws))

                log_rec3d_loss_dict(loss_dict)

            self.mp_trainer_rec.backward(loss)

            if dist_util.get_rank() == 0 and self.step % 500 == 0:
                self._log_train_visualization(micro, pred_gen_output)

        return pred_gen_output

    def _log_train_visualization(self, micro, pred_gen_output):
        """Save a ground-truth (top) vs. prediction (bottom) image to the log dir.

        Only the first sample of the microbatch is written.
        """
        with th.no_grad():
            pred_img = pred_gen_output['image_raw']

            gt_depth = micro['depth']
            if gt_depth.ndim == 3:
                gt_depth = gt_depth.unsqueeze(1)
            gt_depth = self._normalize_depth(gt_depth)
            pred_depth = self._normalize_depth(pred_gen_output['image_depth'])

            if 'image_sr' in pred_gen_output:
                # Bring every panel to the resolution produced by pool_512.
                pred_img = th.cat(
                    [self.pool_512(pred_img), pred_gen_output['image_sr']],
                    dim=-1)
                pred_depth = self.pool_512(pred_depth)
                gt_depth = self.pool_512(gt_depth)
                gt_vis = th.cat([
                    self.pool_512(micro['img']), micro['img_sr'],
                    gt_depth.repeat_interleave(3, dim=1)
                ],
                                dim=-1)
            else:
                gt_vis = th.cat(
                    [micro['img'],
                     gt_depth.repeat_interleave(3, dim=1)],
                    dim=-1)

            pred_vis = th.cat(
                [pred_img, pred_depth.repeat_interleave(3, dim=1)], dim=-1)

            vis = th.cat([gt_vis, pred_vis],
                         dim=-2)[0].permute(1, 2, 0).cpu()
            vis = vis.numpy() * 127.5 + 127.5
            vis = vis.clip(0, 255).astype(np.uint8)

            out_path = f'{logger.get_dir()}/{self.step+self.resume_step}.jpg'
            Image.fromarray(vis).save(out_path)
            print('log vis to: ', out_path)

    def calc_shape_rec_loss(
        self,
        pred_shape: dict,
        gt_shape: dict,
    ):
        """Compute the shape reconstruction loss and log its components.

        Delegates to ``self.loss_class.calc_shape_rec_loss`` and logs every
        returned entry under ``Loss/3D/<name>``.

        Returns:
            The loss value returned by the loss class.
        """
        loss_shape, loss_shape_dict = self.loss_class.calc_shape_rec_loss(
            pred_shape,
            gt_shape,
            dist_util.dev(),
        )

        for loss_k, loss_v in loss_shape_dict.items():
            log_rec3d_loss_dict({f'Loss/3D/{loss_k}': loss_v})

        return loss_shape

    def _append_novelview_frames(self, video_out, pred, micro):
        """Render one frame per sample of ``pred`` and append it to ``video_out``.

        Panels are concatenated along the width: reference image,
        prediction(s), then the normalised predicted depth.
        """
        pred_depth = self._normalize_depth(pred['image_depth'])

        if 'image_sr' in pred:
            panels = [
                micro['img_sr'],
                self.pool_512(pred['image_raw']),
                pred['image_sr'],
                self.pool_512(pred_depth).repeat_interleave(3, dim=1),
            ]
        else:
            panels = [
                self.pool_128(micro['img']),
                pred['image_raw'],
                pred_depth.repeat_interleave(3, dim=1),
            ]

        vis = th.cat(panels, dim=-1).permute(0, 2, 3, 1).cpu().numpy()
        vis = vis * 127.5 + 127.5
        vis = vis.clip(0, 255).astype(np.uint8)

        for frame in vis:
            video_out.append_data(frame)

    @th.no_grad()
    def eval_novelview_loop(self):
        """Write a novel-view video driven by the first sample of ``eval_data``.

        The first sample of the first batch is repeated to the batch size and
        fed to the encoder, while the camera ``c`` comes from each batch.
        Runs under ``no_grad`` so predictions can be converted to NumPy.
        """
        video_out = imageio.get_writer(
            f'{logger.get_dir()}/video_novelview_real_{self.step+self.resume_step}.mp4',
            mode='I',
            fps=60,
            codec='libx264')

        novel_view_micro = {}

        try:
            for i, batch in enumerate(tqdm(self.eval_data)):
                micro = {k: v.to(dist_util.dev()) for k, v in batch.items()}

                # Seed from the first batch, then keep re-expanding the same
                # first sample to the current batch size.
                source = batch if i == 0 else novel_view_micro
                novel_view_micro = {
                    k: v[0:1].to(dist_util.dev()).repeat_interleave(
                        micro['img'].shape[0], 0)
                    for k, v in source.items()
                }

                pred = self.rec_model(img=novel_view_micro['img_to_encoder'],
                                      c=micro['c'])

                self._append_novelview_frames(video_out, pred, micro)
        finally:
            # Always finalise the file, even if rendering fails midway.
            video_out.close()

        th.cuda.empty_cache()

    @th.inference_mode()
    def eval_novelview_loop_eg3d(self):
        """Write a novel-view video whose encoder input is a generator sample.

        On the first batch a sample is drawn from ``G`` using that batch's
        cameras; on later batches its first entry is repeated to the batch
        size. The camera ``c`` comes from each ``eval_data`` batch.
        """
        video_out = imageio.get_writer(
            f'{logger.get_dir()}/video_novelview_synthetic_{self.step+self.resume_step}.mp4',
            mode='I',
            fps=60,
            codec='libx264')

        novel_view_micro = {}

        try:
            for i, batch in enumerate(tqdm(self.eval_data)):
                micro = {k: v.to(dist_util.dev()) for k, v in batch.items()}

                if i == 0:
                    eg3d_batch, _ = self._sample_generator(micro['c'])

                    novel_view_micro.update({
                        'img': eg3d_batch['image_raw'],
                        'img_to_encoder': self.pool_224(eg3d_batch['image']),
                        'depth': eg3d_batch['image_depth'],
                    })
                else:
                    novel_view_micro = {
                        k: v[0:1].to(dist_util.dev()).repeat_interleave(
                            micro['img'].shape[0], 0)
                        for k, v in novel_view_micro.items()
                    }

                pred = self.rec_model(img=novel_view_micro['img_to_encoder'],
                                      c=micro['c'])

                self._append_novelview_frames(video_out, pred, micro)
        finally:
            # Always finalise the file, even if rendering fails midway.
            video_out.close()

        th.cuda.empty_cache()