import os
import math
from tqdm.auto import tqdm
import torch
import torch.nn as nn
from torchvision.utils import make_grid
from accelerate.logging import get_logger

from .base_trainer import Trainer
from openlrm.utils.profiler import DummyProfiler
from openlrm.runners import REGISTRY_RUNNERS


logger = get_logger(__name__)


@REGISTRY_RUNNERS.register('train.lrm')
class LRMTrainer(Trainer):
    def __init__(self):
        super().__init__()

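        # build every training component from the experiment config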
        self.model = self._build_model(self.cfg)
        self.optimizer = self._build_optimizer(self.model, self.cfg)
        self.train_loader, self.val_loader = self._build_dataloader(self.cfg)
        self.scheduler = self._build_scheduler(self.optimizer, self.cfg)
        self.pixel_loss_fn, self.perceptual_loss_fn, self.tv_loss_fn = self._build_loss_fn(self.cfg)

    def _build_model(self, cfg):
        assert cfg.experiment.type == 'lrm', \
            f"Config type {cfg.experiment.type} does not match with runner {self.__class__.__name__}"
        from openlrm.models import ModelLRM
        model = ModelLRM(**cfg.model)
        return model

    def _build_optimizer(self, model: nn.Module, cfg):
        decay_params, no_decay_params = [], []

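        # LayerNorm weights and all bias terms are exempt from weight decay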
        for name, module in model.named_modules():
            if isinstance(module, nn.LayerNorm):
                no_decay_params.extend(module.parameters())
            elif hasattr(module, 'bias') and module.bias is not None:
                no_decay_params.append(module.bias)

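        # every parameter not tagged above goes in the weight-decay group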
        _no_decay_ids = set(map(id, no_decay_params))
        decay_params = [p for p in model.parameters() if id(p) not in _no_decay_ids]

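        # drop frozen parameters from both groups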
        decay_params = list(filter(lambda p: p.requires_grad, decay_params))
        no_decay_params = list(filter(lambda p: p.requires_grad, no_decay_params))

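        # log group sizes to make silently missed parameters easy to spot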
        logger.info("======== Weight Decay Parameters ========")
        logger.info(f"Total: {len(decay_params)}")
        logger.info("======== No Weight Decay Parameters ========")
        logger.info(f"Total: {len(no_decay_params)}")

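        # AdamW applies weight decay per parameter group; the second group is exempt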
        opt_groups = [
            {'params': decay_params, 'weight_decay': cfg.train.optim.weight_decay},
            {'params': no_decay_params, 'weight_decay': 0.0},
        ]
        optimizer = torch.optim.AdamW(
            opt_groups,
            lr=cfg.train.optim.lr,
            betas=(cfg.train.optim.beta1, cfg.train.optim.beta2),
        )
        return optimizer

    def _build_scheduler(self, optimizer, cfg):
        # the scheduler steps once per global (gradient-synced) step,
        # so iteration counts are measured in effective batches
        local_batches_per_epoch = math.floor(len(self.train_loader) / self.accelerator.num_processes)
        total_global_batches = cfg.train.epochs * math.ceil(local_batches_per_epoch / cfg.train.accum_steps)
        effective_warmup_iters = cfg.train.scheduler.warmup_real_iters
        logger.debug(f"======== Scheduler effective max iters: {total_global_batches} ========")
        logger.debug(f"======== Scheduler effective warmup iters: {effective_warmup_iters} ========")
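        # only cosine-with-warmup is currently supported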
        if cfg.train.scheduler.type == 'cosine':
            from openlrm.utils.scheduler import CosineWarmupScheduler
            scheduler = CosineWarmupScheduler(
                optimizer=optimizer,
                warmup_iters=effective_warmup_iters,
                max_iters=total_global_batches,
            )
        else:
            raise NotImplementedError(f"Scheduler type {cfg.train.scheduler.type} not implemented")
        return scheduler

    def _build_dataloader(self, cfg):
        from openlrm.datasets import MixerDataset

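        # train and val share subsets and rendering options; only the split differs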
        train_dataset = MixerDataset(
            split="train",
            subsets=cfg.dataset.subsets,
            sample_side_views=cfg.dataset.sample_side_views,
            render_image_res_low=cfg.dataset.render_image.low,
            render_image_res_high=cfg.dataset.render_image.high,
            render_region_size=cfg.dataset.render_image.region,
            source_image_res=cfg.dataset.source_image_res,
            normalize_camera=cfg.dataset.normalize_camera,
            normed_dist_to_center=cfg.dataset.normed_dist_to_center,
        )
        val_dataset = MixerDataset(
            split="val",
            subsets=cfg.dataset.subsets,
            sample_side_views=cfg.dataset.sample_side_views,
            render_image_res_low=cfg.dataset.render_image.low,
            render_image_res_high=cfg.dataset.render_image.high,
            render_region_size=cfg.dataset.render_image.region,
            source_image_res=cfg.dataset.source_image_res,
            normalize_camera=cfg.dataset.normalize_camera,
            normed_dist_to_center=cfg.dataset.normed_dist_to_center,
        )

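        # the train loader shuffles and drops the last partial batch; val keeps every sample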
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=cfg.train.batch_size,
            shuffle=True,
            drop_last=True,
            num_workers=cfg.dataset.num_train_workers,
            pin_memory=cfg.dataset.pin_mem,
            persistent_workers=True,
        )
        val_loader = torch.utils.data.DataLoader(
            val_dataset,
            batch_size=cfg.val.batch_size,
            shuffle=False,
            drop_last=False,
            num_workers=cfg.dataset.num_val_workers,
            pin_memory=cfg.dataset.pin_mem,
            persistent_workers=False,
        )

        return train_loader, val_loader

    def _build_loss_fn(self, cfg):
        from openlrm.losses import PixelLoss, LPIPSLoss, TVLoss
        pixel_loss_fn = PixelLoss()
        with self.accelerator.main_process_first():
            # let the main process fetch LPIPS weights first; 'prefech' (sic)
            # is the keyword spelling defined by LPIPSLoss
            perceptual_loss_fn = LPIPSLoss(device=self.device, prefech=True)
        tv_loss_fn = TVLoss()
        return pixel_loss_fn, perceptual_loss_fn, tv_loss_fn

    def register_hooks(self):
        pass

    def forward_loss_local_step(self, data):
        source_camera = data['source_camera']
        render_camera = data['render_camera']
        source_image = data['source_image']
        render_image = data['render_image']
        source_image_back = data.get('source_image_back', None)
        render_anchors = data['render_anchors']
        render_full_resolutions = data['render_full_resolutions']
        render_bg_colors = data['render_bg_colors']

        N, M, C, H, W = render_image.shape

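        # a single forward call renders all M supervision views per sample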
        outputs = self.model(
            image=source_image,
            source_camera=source_camera,
            render_cameras=render_camera,
            render_anchors=render_anchors,
            render_resolutions=render_full_resolutions,
            render_bg_colors=render_bg_colors,
            render_region_size=self.cfg.dataset.render_image.region,
            image_back=source_image_back,
        )

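        # sum the enabled loss terms; disabled terms stay None and are logged as NaN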
        loss = 0.
        loss_pixel = None
        loss_perceptual = None
        loss_tv = None

        if self.cfg.train.loss.pixel_weight > 0.:
            loss_pixel = self.pixel_loss_fn(outputs['images_rgb'], render_image)
            loss += loss_pixel * self.cfg.train.loss.pixel_weight
        if self.cfg.train.loss.perceptual_weight > 0.:
            loss_perceptual = self.perceptual_loss_fn(outputs['images_rgb'], render_image)
            loss += loss_perceptual * self.cfg.train.loss.perceptual_weight
        if self.cfg.train.loss.tv_weight > 0.:
            loss_tv = self.tv_loss_fn(outputs['planes'])
            loss += loss_tv * self.cfg.train.loss.tv_weight

        return outputs, loss, loss_pixel, loss_perceptual, loss_tv

    def train_epoch(self, pbar: tqdm, loader: torch.utils.data.DataLoader, profiler: torch.profiler.profile):
        self.model.train()

        local_step_losses = []
        global_step_losses = []

        logger.debug(f"======== Starting epoch {self.current_epoch} ========")
        for data in loader:
            logger.debug(f"======== Starting global step {self.global_step} ========")
            with self.accelerator.accumulate(self.model):
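                # forward through the model and compute the weighted training loss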
                outs, loss, loss_pixel, loss_perceptual, loss_tv = self.forward_loss_local_step(data)

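                # backward; accelerate defers gradient sync until the last accumulation micro-step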
                self.accelerator.backward(loss)
                if self.accelerator.sync_gradients and self.cfg.train.optim.clip_grad_norm > 0.:
                    self.accelerator.clip_grad_norm_(self.model.parameters(), self.cfg.train.optim.clip_grad_norm)
                self.optimizer.step()
                self.optimizer.zero_grad()

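                # track local losses; NaN placeholders keep the stacked tensor shape fixed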
                local_step_losses.append(torch.stack([
                    _loss.detach() if _loss is not None else torch.tensor(float('nan'), device=self.device)
                    for _loss in [loss, loss_pixel, loss_perceptual, loss_tv]
                ]))

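            # everything below runs once per effective (gradient-synced) global step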
            if self.accelerator.sync_gradients:
                profiler.step()
                self.scheduler.step()
                logger.debug("======== Scheduler step ========")
                self.global_step += 1
                global_step_loss = self.accelerator.gather(torch.stack(local_step_losses)).mean(dim=0).cpu()
                loss, loss_pixel, loss_perceptual, loss_tv = global_step_loss.unbind()
                loss_kwargs = {
                    'loss': loss.item(),
                    'loss_pixel': loss_pixel.item(),
                    'loss_perceptual': loss_perceptual.item(),
                    'loss_tv': loss_tv.item(),
                }
                self.log_scalar_kwargs(
                    step=self.global_step, split='train',
                    **loss_kwargs,
                )
                self.log_optimizer(step=self.global_step, attrs=['lr'], group_ids=[0, 1])
                local_step_losses = []
                global_step_losses.append(global_step_loss)

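                # refresh the progress bar with the freshly gathered losses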
                pbar.update(1)
                description = {
                    **loss_kwargs,
                    'lr': self.optimizer.param_groups[0]['lr'],
                }
                description = '[TRAIN STEP] ' + \
                    ', '.join(f'{k}={tqdm.format_num(v)}' for k, v in description.items() if not math.isnan(v))
                pbar.set_description(description)

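                # periodic checkpointing, validation, and image monitoring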
                if self.global_step % self.cfg.saver.checkpoint_global_steps == 0:
                    self.save_checkpoint()
                if self.global_step % self.cfg.val.global_step_period == 0:
                    self.evaluate()
                    self.model.train()
                if self.global_step % self.cfg.logger.image_monitor.train_global_steps == 0:
                    self.log_image_monitor(
                        step=self.global_step, split='train',
                        renders=outs['images_rgb'].detach()[:self.cfg.logger.image_monitor.samples_per_log].cpu(),
                        gts=data['render_image'][:self.cfg.logger.image_monitor.samples_per_log].cpu(),
                    )

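                # request a coordinated stop across processes once the step budget is reached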
                if self.global_step >= self.N_max_global_steps:
                    self.accelerator.set_trigger()
                    break

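        # end of epoch: log the mean losses over all global steps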
        self.current_epoch += 1
        epoch_losses = torch.stack(global_step_losses).mean(dim=0)
        epoch_loss, epoch_loss_pixel, epoch_loss_perceptual, epoch_loss_tv = epoch_losses.unbind()
        epoch_loss_dict = {
            'loss': epoch_loss.item(),
            'loss_pixel': epoch_loss_pixel.item(),
            'loss_perceptual': epoch_loss_perceptual.item(),
            'loss_tv': epoch_loss_tv.item(),
        }
        self.log_scalar_kwargs(
            epoch=self.current_epoch, split='train',
            **epoch_loss_dict,
        )
        logger.info(
            f'[TRAIN EPOCH] {self.current_epoch}/{self.cfg.train.epochs}: ' +
            ', '.join(f'{k}={tqdm.format_num(v)}' for k, v in epoch_loss_dict.items() if not math.isnan(v))
        )

    def train(self):
        # when resuming mid-epoch, skip the local batches already consumed
        starting_local_step_in_epoch = self.global_step_in_epoch * self.cfg.train.accum_steps
        skipped_loader = self.accelerator.skip_first_batches(self.train_loader, starting_local_step_in_epoch)
        logger.info(f"======== Skipped {starting_local_step_in_epoch} local batches ========")

        with tqdm(
            range(0, self.N_max_global_steps),
            initial=self.global_step,
            disable=(not self.accelerator.is_main_process),
        ) as pbar:

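            # profile a short window of steps to TensorBoard when enabled; otherwise a no-op stub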
            profiler = torch.profiler.profile(
                activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
                schedule=torch.profiler.schedule(
                    wait=10, warmup=10, active=100,
                ),
                on_trace_ready=torch.profiler.tensorboard_trace_handler(os.path.join(
                    self.cfg.logger.tracker_root,
                    self.cfg.experiment.parent, self.cfg.experiment.child,
                )),
                record_shapes=True,
                profile_memory=True,
                with_stack=True,
            ) if self.cfg.logger.enable_profiler else DummyProfiler()

            with profiler:
                self.optimizer.zero_grad()
                for _ in range(self.current_epoch, self.cfg.train.epochs):
                    # use the partially skipped loader only for the first resumed epoch
                    loader = skipped_loader or self.train_loader
                    skipped_loader = None
                    self.train_epoch(pbar=pbar, loader=loader, profiler=profiler)
                    if self.accelerator.check_trigger():
                        break

        logger.info(f"======== Training finished at global step {self.global_step} ========")

        # final checkpoint and evaluation
        self.save_checkpoint()
        self.evaluate()

    @torch.no_grad()
    @torch.compiler.disable
    def evaluate(self, epoch: int = None):
        self.model.eval()

        max_val_batches = self.cfg.val.debug_batches or len(self.val_loader)
        running_losses = []
        sample_data, sample_outs = None, None

        for data in tqdm(self.val_loader, disable=(not self.accelerator.is_main_process), total=max_val_batches):
            if len(running_losses) >= max_val_batches:
                logger.info(f"======== Early stop validation at {len(running_losses)} batches ========")
                break

            outs, loss, loss_pixel, loss_perceptual, loss_tv = self.forward_loss_local_step(data)
            sample_data, sample_outs = data, outs

            running_losses.append(torch.stack([
                _loss if _loss is not None else torch.tensor(float('nan'), device=self.device)
                for _loss in [loss, loss_pixel, loss_perceptual, loss_tv]
            ]))

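        # gather per-batch losses from all processes and average them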
        total_losses = self.accelerator.gather(torch.stack(running_losses)).mean(dim=0).cpu()
        total_loss, total_loss_pixel, total_loss_perceptual, total_loss_tv = total_losses.unbind()
        total_loss_dict = {
            'loss': total_loss.item(),
            'loss_pixel': total_loss_pixel.item(),
            'loss_perceptual': total_loss_perceptual.item(),
            'loss_tv': total_loss_tv.item(),
        }

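        # log against the epoch counter at epoch end, otherwise against the global step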
        if epoch is not None:
            self.log_scalar_kwargs(
                epoch=epoch, split='val',
                **total_loss_dict,
            )
            logger.info(
                f'[VAL EPOCH] {epoch}/{self.cfg.train.epochs}: ' +
                ', '.join(f'{k}={tqdm.format_num(v)}' for k, v in total_loss_dict.items() if not math.isnan(v))
            )
            self.log_image_monitor(
                epoch=epoch, split='val',
                renders=sample_outs['images_rgb'][:self.cfg.logger.image_monitor.samples_per_log].cpu(),
                gts=sample_data['render_image'][:self.cfg.logger.image_monitor.samples_per_log].cpu(),
            )
        else:
            self.log_scalar_kwargs(
                step=self.global_step, split='val',
                **total_loss_dict,
            )
            logger.info(
                f'[VAL STEP] {self.global_step}/{self.N_max_global_steps}: ' +
                ', '.join(f'{k}={tqdm.format_num(v)}' for k, v in total_loss_dict.items() if not math.isnan(v))
            )
            self.log_image_monitor(
                step=self.global_step, split='val',
                renders=sample_outs['images_rgb'][:self.cfg.logger.image_monitor.samples_per_log].cpu(),
                gts=sample_data['render_image'][:self.cfg.logger.image_monitor.samples_per_log].cpu(),
            )

    @Trainer.control('on_main_process')
    def log_image_monitor(
        self, epoch: int = None, step: int = None, split: str = None,
        renders: torch.Tensor = None, gts: torch.Tensor = None,
    ):
        M = renders.shape[1]
        # merged panel shows the first sample only: its M renders stacked over its M GTs
        merged = torch.stack([renders, gts], dim=1)[0].view(-1, *renders.shape[2:])
        renders, gts = renders.view(-1, *renders.shape[2:]), gts.view(-1, *gts.shape[2:])
        renders, gts, merged = make_grid(renders, nrow=M), make_grid(gts, nrow=M), make_grid(merged, nrow=M)
        log_type, log_progress = self._get_str_progress(epoch, step)
        split = f'/{split}' if split else ''
        self.log_images({
            f'Images_{log_type}{split}/rendered': renders.unsqueeze(0),
            f'Images_{log_type}{split}/gt': gts.unsqueeze(0),
            f'Images_{log_type}{split}/merged': merged.unsqueeze(0),
        }, log_progress)