# Copyright (c) 2023-2024, Zexin He
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
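"""Training runner for the LRM model, registered as 'train.lrm'.

Builds the model, AdamW optimizer, dataloaders, LR scheduler and loss
functions from the experiment config, then drives the train/validation
loops on top of the shared `Trainer` base class (HuggingFace Accelerate).

Illustrative config fragment covering keys read below (values are made up
for the sketch; see the repository's config files for real settings):

    experiment: {type: lrm}
    train:
        epochs: 60
        accum_steps: 1
        optim: {lr: 4.0e-4, beta1: 0.9, beta2: 0.95, weight_decay: 0.05, clip_grad_norm: 1.0}
        scheduler: {type: cosine, warmup_real_iters: 3000}
        loss: {pixel_weight: 1.0, perceptual_weight: 1.0, tv_weight: 5.0e-4}
"""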
import os
import math
from tqdm.auto import tqdm
import torch
import torch.nn as nn
from torchvision.utils import make_grid
from accelerate.logging import get_logger
from .base_trainer import Trainer
from openlrm.utils.profiler import DummyProfiler
from openlrm.runners import REGISTRY_RUNNERS
logger = get_logger(__name__)
@REGISTRY_RUNNERS.register('train.lrm')
class LRMTrainer(Trainer):
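    """Concrete `Trainer` for LRM.

    `__init__` builds the model, optimizer, dataloaders, scheduler and loss
    functions; `train()` / `evaluate()` implement the step-based loop with
    gradient accumulation, periodic checkpointing and image logging.
    """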
def __init__(self):
super().__init__()
self.model = self._build_model(self.cfg)
self.optimizer = self._build_optimizer(self.model, self.cfg)
self.train_loader, self.val_loader = self._build_dataloader(self.cfg)
self.scheduler = self._build_scheduler(self.optimizer, self.cfg)
self.pixel_loss_fn, self.perceptual_loss_fn, self.tv_loss_fn = self._build_loss_fn(self.cfg)
def _build_model(self, cfg):
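        """Instantiate `ModelLRM` from `cfg.model` after checking the experiment type."""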
assert cfg.experiment.type == 'lrm', \
f"Config type {cfg.experiment.type} does not match with runner {self.__class__.__name__}"
from openlrm.models import ModelLRM
model = ModelLRM(**cfg.model)
return model
def _build_optimizer(self, model: nn.Module, cfg):
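        """Build AdamW with two parameter groups.

        Biases and LayerNorm parameters are excluded from weight decay; every
        other trainable parameter uses `cfg.train.optim.weight_decay`.
        """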
decay_params, no_decay_params = [], []
# add all bias and LayerNorm params to no_decay_params
for name, module in model.named_modules():
if isinstance(module, nn.LayerNorm):
no_decay_params.extend([p for p in module.parameters()])
elif hasattr(module, 'bias') and module.bias is not None:
no_decay_params.append(module.bias)
# add remaining parameters to decay_params
_no_decay_ids = set(map(id, no_decay_params))
decay_params = [p for p in model.parameters() if id(p) not in _no_decay_ids]
# filter out parameters with no grad
decay_params = list(filter(lambda p: p.requires_grad, decay_params))
no_decay_params = list(filter(lambda p: p.requires_grad, no_decay_params))
# monitor this to make sure we don't miss any parameters
logger.info("======== Weight Decay Parameters ========")
logger.info(f"Total: {len(decay_params)}")
logger.info("======== No Weight Decay Parameters ========")
logger.info(f"Total: {len(no_decay_params)}")
# Optimizer
opt_groups = [
{'params': decay_params, 'weight_decay': cfg.train.optim.weight_decay},
{'params': no_decay_params, 'weight_decay': 0.0},
]
optimizer = torch.optim.AdamW(
opt_groups,
lr=cfg.train.optim.lr,
betas=(cfg.train.optim.beta1, cfg.train.optim.beta2),
)
return optimizer
def _build_scheduler(self, optimizer, cfg):
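        """Build the LR scheduler.

        The scheduler is stepped once per global (gradient-sync) step, so its
        horizon is the number of optimizer updates:
        epochs * ceil(local batches per process / accumulation steps).
        """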
local_batches_per_epoch = math.floor(len(self.train_loader) / self.accelerator.num_processes)
total_global_batches = cfg.train.epochs * math.ceil(local_batches_per_epoch / self.cfg.train.accum_steps)
effective_warmup_iters = cfg.train.scheduler.warmup_real_iters
logger.debug(f"======== Scheduler effective max iters: {total_global_batches} ========")
logger.debug(f"======== Scheduler effective warmup iters: {effective_warmup_iters} ========")
if cfg.train.scheduler.type == 'cosine':
from openlrm.utils.scheduler import CosineWarmupScheduler
scheduler = CosineWarmupScheduler(
optimizer=optimizer,
warmup_iters=effective_warmup_iters,
max_iters=total_global_batches,
)
else:
raise NotImplementedError(f"Scheduler type {cfg.train.scheduler.type} not implemented")
return scheduler
def _build_dataloader(self, cfg):
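        """Build the train/val `MixerDataset` loaders from `cfg.dataset`.

        Both splits share the same rendering and source-image settings; only
        the split name, batch size and shuffle/worker options differ.
        """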
# dataset class
from openlrm.datasets import MixerDataset
# build dataset
train_dataset = MixerDataset(
split="train",
subsets=cfg.dataset.subsets,
sample_side_views=cfg.dataset.sample_side_views,
render_image_res_low=cfg.dataset.render_image.low,
render_image_res_high=cfg.dataset.render_image.high,
render_region_size=cfg.dataset.render_image.region,
source_image_res=cfg.dataset.source_image_res,
normalize_camera=cfg.dataset.normalize_camera,
normed_dist_to_center=cfg.dataset.normed_dist_to_center,
)
val_dataset = MixerDataset(
split="val",
subsets=cfg.dataset.subsets,
sample_side_views=cfg.dataset.sample_side_views,
render_image_res_low=cfg.dataset.render_image.low,
render_image_res_high=cfg.dataset.render_image.high,
render_region_size=cfg.dataset.render_image.region,
source_image_res=cfg.dataset.source_image_res,
normalize_camera=cfg.dataset.normalize_camera,
normed_dist_to_center=cfg.dataset.normed_dist_to_center,
)
# build data loader
train_loader = torch.utils.data.DataLoader(
train_dataset,
batch_size=cfg.train.batch_size,
shuffle=True,
drop_last=True,
num_workers=cfg.dataset.num_train_workers,
pin_memory=cfg.dataset.pin_mem,
persistent_workers=True,
)
val_loader = torch.utils.data.DataLoader(
val_dataset,
batch_size=cfg.val.batch_size,
shuffle=False,
drop_last=False,
num_workers=cfg.dataset.num_val_workers,
pin_memory=cfg.dataset.pin_mem,
persistent_workers=False,
)
return train_loader, val_loader
def _build_loss_fn(self, cfg):
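        """Build the pixel, LPIPS (perceptual) and total-variation losses.

        LPIPS weights are prefetched on the main process first so the other
        ranks reuse the cached download.
        """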
from openlrm.losses import PixelLoss, LPIPSLoss, TVLoss
pixel_loss_fn = PixelLoss()
with self.accelerator.main_process_first():
perceptual_loss_fn = LPIPSLoss(device=self.device, prefech=True)
tv_loss_fn = TVLoss()
return pixel_loss_fn, perceptual_loss_fn, tv_loss_fn
def register_hooks(self):
pass
def forward_loss_local_step(self, data):
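        """Run one forward pass and compute the weighted training loss.

        Returns the model outputs together with the total loss and its pixel,
        perceptual and TV components (a component is None when its weight is zero).
        """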
source_camera = data['source_camera']
render_camera = data['render_camera']
source_image = data['source_image']
render_image = data['render_image']
if 'source_image_back' in data:
            source_image_back = data['source_image_back']  # optional back-view source image provided by the dataset
else:
source_image_back = None
render_anchors = data['render_anchors']
render_full_resolutions = data['render_full_resolutions']
render_bg_colors = data['render_bg_colors']
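        # render_image: (N, M, C, H, W) = (batch, render views, channels, height, width)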
N, M, C, H, W = render_image.shape
# forward
outputs = self.model(
image=source_image,
source_camera=source_camera,
render_cameras=render_camera,
render_anchors=render_anchors,
render_resolutions=render_full_resolutions,
render_bg_colors=render_bg_colors,
render_region_size=self.cfg.dataset.render_image.region,
            image_back=source_image_back,  # forwarded only when the batch carries a back view (None otherwise)
)
# loss calculation
loss = 0.
loss_pixel = None
loss_perceptual = None
loss_tv = None
if self.cfg.train.loss.pixel_weight > 0.:
loss_pixel = self.pixel_loss_fn(outputs['images_rgb'], render_image)
loss += loss_pixel * self.cfg.train.loss.pixel_weight
if self.cfg.train.loss.perceptual_weight > 0.:
loss_perceptual = self.perceptual_loss_fn(outputs['images_rgb'], render_image)
loss += loss_perceptual * self.cfg.train.loss.perceptual_weight
if self.cfg.train.loss.tv_weight > 0.:
loss_tv = self.tv_loss_fn(outputs['planes'])
loss += loss_tv * self.cfg.train.loss.tv_weight
return outputs, loss, loss_pixel, loss_perceptual, loss_tv
def train_epoch(self, pbar: tqdm, loader: torch.utils.data.DataLoader, profiler: torch.profiler.profile):
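        """Train for one epoch with gradient accumulation.

        Local losses are gathered across processes at every gradient-sync
        step; checkpointing, validation and image logging are triggered by
        their respective global-step periods.
        """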
self.model.train()
local_step_losses = []
global_step_losses = []
logger.debug(f"======== Starting epoch {self.current_epoch} ========")
for data in loader:
logger.debug(f"======== Starting global step {self.global_step} ========")
with self.accelerator.accumulate(self.model):
# forward to loss
outs, loss, loss_pixel, loss_perceptual, loss_tv = self.forward_loss_local_step(data)
# backward
self.accelerator.backward(loss)
if self.accelerator.sync_gradients and self.cfg.train.optim.clip_grad_norm > 0.:
self.accelerator.clip_grad_norm_(self.model.parameters(), self.cfg.train.optim.clip_grad_norm)
self.optimizer.step()
self.optimizer.zero_grad()
# track local losses
local_step_losses.append(torch.stack([
_loss.detach() if _loss is not None else torch.tensor(float('nan'), device=self.device)
for _loss in [loss, loss_pixel, loss_perceptual, loss_tv]
]))
# track global step
if self.accelerator.sync_gradients:
profiler.step()
self.scheduler.step()
logger.debug(f"======== Scheduler step ========")
self.global_step += 1
global_step_loss = self.accelerator.gather(torch.stack(local_step_losses)).mean(dim=0).cpu()
loss, loss_pixel, loss_perceptual, loss_tv = global_step_loss.unbind()
loss_kwargs = {
'loss': loss.item(),
'loss_pixel': loss_pixel.item(),
'loss_perceptual': loss_perceptual.item(),
'loss_tv': loss_tv.item(),
}
self.log_scalar_kwargs(
step=self.global_step, split='train',
**loss_kwargs
)
self.log_optimizer(step=self.global_step, attrs=['lr'], group_ids=[0, 1])
local_step_losses = []
global_step_losses.append(global_step_loss)
# manage display
pbar.update(1)
description = {
**loss_kwargs,
'lr': self.optimizer.param_groups[0]['lr'],
}
description = '[TRAIN STEP]' + \
', '.join(f'{k}={tqdm.format_num(v)}' for k, v in description.items() if not math.isnan(v))
pbar.set_description(description)
# periodic actions
if self.global_step % self.cfg.saver.checkpoint_global_steps == 0:
self.save_checkpoint()
if self.global_step % self.cfg.val.global_step_period == 0:
self.evaluate()
self.model.train()
if self.global_step % self.cfg.logger.image_monitor.train_global_steps == 0:
self.log_image_monitor(
step=self.global_step, split='train',
renders=outs['images_rgb'].detach()[:self.cfg.logger.image_monitor.samples_per_log].cpu(),
gts=data['render_image'][:self.cfg.logger.image_monitor.samples_per_log].cpu(),
)
# progress control
if self.global_step >= self.N_max_global_steps:
self.accelerator.set_trigger()
break
# track epoch
self.current_epoch += 1
epoch_losses = torch.stack(global_step_losses).mean(dim=0)
epoch_loss, epoch_loss_pixel, epoch_loss_perceptual, epoch_loss_tv = epoch_losses.unbind()
epoch_loss_dict = {
'loss': epoch_loss.item(),
'loss_pixel': epoch_loss_pixel.item(),
'loss_perceptual': epoch_loss_perceptual.item(),
'loss_tv': epoch_loss_tv.item(),
}
self.log_scalar_kwargs(
epoch=self.current_epoch, split='train',
**epoch_loss_dict,
)
logger.info(
f'[TRAIN EPOCH] {self.current_epoch}/{self.cfg.train.epochs}: ' + \
', '.join(f'{k}={tqdm.format_num(v)}' for k, v in epoch_loss_dict.items() if not math.isnan(v))
)
def train(self):
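        """Run the full training loop.

        Resumes mid-epoch by skipping already-consumed batches, shows a
        global-step progress bar on the main process, and optionally wraps
        the loop in a torch profiler.
        """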
starting_local_step_in_epoch = self.global_step_in_epoch * self.cfg.train.accum_steps
skipped_loader = self.accelerator.skip_first_batches(self.train_loader, starting_local_step_in_epoch)
logger.info(f"======== Skipped {starting_local_step_in_epoch} local batches ========")
with tqdm(
range(0, self.N_max_global_steps),
initial=self.global_step,
disable=(not self.accelerator.is_main_process),
) as pbar:
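            # Optional torch profiler writing TensorBoard traces under the tracker root;
            # a no-op DummyProfiler is used when profiling is disabled in the config.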
profiler = torch.profiler.profile(
activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
schedule=torch.profiler.schedule(
wait=10, warmup=10, active=100,
),
on_trace_ready=torch.profiler.tensorboard_trace_handler(os.path.join(
self.cfg.logger.tracker_root,
self.cfg.experiment.parent, self.cfg.experiment.child,
)),
record_shapes=True,
profile_memory=True,
with_stack=True,
) if self.cfg.logger.enable_profiler else DummyProfiler()
with profiler:
self.optimizer.zero_grad()
for _ in range(self.current_epoch, self.cfg.train.epochs):
loader = skipped_loader or self.train_loader
skipped_loader = None
self.train_epoch(pbar=pbar, loader=loader, profiler=profiler)
if self.accelerator.check_trigger():
break
logger.info(f"======== Training finished at global step {self.global_step} ========")
# final checkpoint and evaluation
self.save_checkpoint()
self.evaluate()
@torch.no_grad()
@torch.compiler.disable
def evaluate(self, epoch: int = None):
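        """Run validation and log aggregated losses plus sample renders.

        `cfg.val.debug_batches` (if set) caps the number of validation batches;
        losses are gathered across processes before logging.
        """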
self.model.eval()
max_val_batches = self.cfg.val.debug_batches or len(self.val_loader)
running_losses = []
sample_data, sample_outs = None, None
for data in tqdm(self.val_loader, disable=(not self.accelerator.is_main_process), total=max_val_batches):
if len(running_losses) >= max_val_batches:
logger.info(f"======== Early stop validation at {len(running_losses)} batches ========")
break
outs, loss, loss_pixel, loss_perceptual, loss_tv = self.forward_loss_local_step(data)
sample_data, sample_outs = data, outs
running_losses.append(torch.stack([
_loss if _loss is not None else torch.tensor(float('nan'), device=self.device)
for _loss in [loss, loss_pixel, loss_perceptual, loss_tv]
]))
total_losses = self.accelerator.gather(torch.stack(running_losses)).mean(dim=0).cpu()
total_loss, total_loss_pixel, total_loss_perceptual, total_loss_tv = total_losses.unbind()
total_loss_dict = {
'loss': total_loss.item(),
'loss_pixel': total_loss_pixel.item(),
'loss_perceptual': total_loss_perceptual.item(),
'loss_tv': total_loss_tv.item(),
}
if epoch is not None:
self.log_scalar_kwargs(
epoch=epoch, split='val',
**total_loss_dict,
)
logger.info(
f'[VAL EPOCH] {epoch}/{self.cfg.train.epochs}: ' + \
', '.join(f'{k}={tqdm.format_num(v)}' for k, v in total_loss_dict.items() if not math.isnan(v))
)
self.log_image_monitor(
epoch=epoch, split='val',
renders=sample_outs['images_rgb'][:self.cfg.logger.image_monitor.samples_per_log].cpu(),
gts=sample_data['render_image'][:self.cfg.logger.image_monitor.samples_per_log].cpu(),
)
else:
self.log_scalar_kwargs(
step=self.global_step, split='val',
**total_loss_dict,
)
logger.info(
f'[VAL STEP] {self.global_step}/{self.N_max_global_steps}: ' + \
', '.join(f'{k}={tqdm.format_num(v)}' for k, v in total_loss_dict.items() if not math.isnan(v))
)
self.log_image_monitor(
step=self.global_step, split='val',
renders=sample_outs['images_rgb'][:self.cfg.logger.image_monitor.samples_per_log].cpu(),
gts=sample_data['render_image'][:self.cfg.logger.image_monitor.samples_per_log].cpu(),
)
@Trainer.control('on_main_process')
def log_image_monitor(
self, epoch: int = None, step: int = None, split: str = None,
renders: torch.Tensor = None, gts: torch.Tensor = None,
):
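        """Log rendered vs. ground-truth views as image grids (main process only).

        Each grid places a sample's M views on one row; `merged` stacks the
        first sample's renders above its ground truths.
        """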
M = renders.shape[1]
merged = torch.stack([renders, gts], dim=1)[0].view(-1, *renders.shape[2:])
renders, gts = renders.view(-1, *renders.shape[2:]), gts.view(-1, *gts.shape[2:])
renders, gts, merged = make_grid(renders, nrow=M), make_grid(gts, nrow=M), make_grid(merged, nrow=M)
log_type, log_progress = self._get_str_progress(epoch, step)
split = f'/{split}' if split else ''
self.log_images({
f'Images_split{split}/rendered': renders.unsqueeze(0),
f'Images_split{split}/gt': gts.unsqueeze(0),
f'Images_merged{split}': merged.unsqueeze(0),
}, log_progress)