import os
import glob
import tqdm
import math
import imageio
import random
import warnings
import tensorboardX

import numpy as np
import pandas as pd

import time
from datetime import datetime

import cv2
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.distributed as dist
from torch.utils.data import Dataset, DataLoader

import trimesh
from rich.console import Console
from torch_ema import ExponentialMovingAverage

from packaging import version as pver


def custom_meshgrid(*args):
    """Call ``torch.meshgrid`` with 'ij' indexing on every torch version.

    torch < 1.10 has no ``indexing`` keyword, so it is only passed on newer
    versions.
    """
    # ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html?highlight=meshgrid#torch.meshgrid
    if pver.parse(torch.__version__) < pver.parse('1.10'):
        return torch.meshgrid(*args)
    else:
        return torch.meshgrid(*args, indexing='ij')


def safe_normalize(x, eps=1e-20):
    """Normalize ``x`` along its last dim, clamping the squared norm to ``eps``."""
    return x / torch.sqrt(torch.clamp(torch.sum(x * x, -1, keepdim=True), min=eps))


@torch.cuda.amp.autocast(enabled=False)
def get_rays(poses, intrinsics, H, W, N=-1, error_map=None):
    ''' get rays
    Args:
        poses: [B, 4, 4], cam2world
        intrinsics: [4], unpacked as (fx, fy, cx, cy)
        H, W, N: int; N <= 0 means "use every pixel"
        error_map: [B, 128 * 128], sample probability based on training error
    Returns:
        dict with:
            rays_o, rays_d: [B, N, 3]
            inds: [B, N] (only when N > 0)
            inds_coarse: [B, N] (only when N > 0 and error_map is given)
    '''

    device = poses.device
    B = poses.shape[0]
    fx, fy, cx, cy = intrinsics

    i, j = custom_meshgrid(torch.linspace(0, W-1, W, device=device), torch.linspace(0, H-1, H, device=device))
    # +0.5 moves each sample to the pixel center
    i = i.t().reshape([1, H*W]).expand([B, H*W]) + 0.5
    j = j.t().reshape([1, H*W]).expand([B, H*W]) + 0.5

    results = {}

    if N > 0:
        N = min(N, H*W)

        if error_map is None:
            inds = torch.randint(0, H*W, size=[N], device=device)  # may duplicate
            inds = inds.expand([B, N])
        else:
            # weighted sample on a low-reso grid
            inds_coarse = torch.multinomial(error_map.to(device), N, replacement=False)  # [B, N], but in [0, 128*128)

            # map to the original resolution with random perturb.
            inds_x, inds_y = inds_coarse // 128, inds_coarse % 128  # `//` will throw a warning in torch 1.10... anyway.
            sx, sy = H / 128, W / 128
            inds_x = (inds_x * sx + torch.rand(B, N, device=device) * sx).long().clamp(max=H - 1)
            inds_y = (inds_y * sy + torch.rand(B, N, device=device) * sy).long().clamp(max=W - 1)
            inds = inds_x * W + inds_y

            results['inds_coarse'] = inds_coarse  # need this when updating error_map

        i = torch.gather(i, -1, inds)
        j = torch.gather(j, -1, inds)

        results['inds'] = inds

    else:
        inds = torch.arange(H*W, device=device).expand([B, H*W])

    zs = torch.ones_like(i)
    xs = (i - cx) / fx * zs
    ys = (j - cy) / fy * zs
    directions = torch.stack((xs, ys, zs), dim=-1)
    directions = safe_normalize(directions)
    rays_d = directions @ poses[:, :3, :3].transpose(-1, -2)  # (B, N, 3)

    rays_o = poses[..., :3, 3]  # [B, 3]
    rays_o = rays_o[..., None, :].expand_as(rays_d)  # [B, N, 3]

    results['rays_o'] = rays_o
    results['rays_d'] = rays_d

    return results


def seed_everything(seed):
    """Seed Python's ``random``, ``PYTHONHASHSEED``, NumPy and torch (CPU + CUDA)."""
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    #torch.backends.cudnn.deterministic = True
    #torch.backends.cudnn.benchmark = True


def torch_vis_2d(x, renormalize=False):
    """Debug helper: print basic stats of ``x`` and show it with matplotlib.

    Args:
        x: [3, H, W] or [1, H, W] or [H, W]; a torch tensor is converted to numpy.
        renormalize: if True, min-max rescale along axis 0 before display.
    """
    import matplotlib.pyplot as plt
    import numpy as np
    import torch

    if isinstance(x, torch.Tensor):
        if len(x.shape) == 3:
            x = x.permute(1,2,0).squeeze()
        x = x.detach().cpu().numpy()

    print(f'[torch_vis_2d] {x.shape}, {x.dtype}, {x.min()} ~ {x.max()}')

    x = x.astype(np.float32)

    # renormalize
    if renormalize:
        x = (x - x.min(axis=0, keepdims=True)) / (x.max(axis=0, keepdims=True) - x.min(axis=0, keepdims=True) + 1e-8)

    plt.imshow(x)
    plt.show()


@torch.jit.script
def linear_to_srgb(x):
    return torch.where(x < 0.0031308, 12.92 * x, 1.055 * x ** 0.41666 - 0.055)


@torch.jit.script
def srgb_to_linear(x):
    return torch.where(x < 0.04045, x / 12.92, ((x + 0.055) / 1.055) ** 2.4)


class Trainer(object):
    """Training / evaluation / checkpointing driver for a text-guided NeRF model.

    The model is rendered from sampled rays, and the rendered image is scored
    by ``guidance.train_step`` together with optional opacity / entropy /
    orientation regularisers (see ``train_step``).
    """

    def __init__(self,
                 name,  # name of this experiment
                 opt,  # extra conf
                 model,  # network
                 guidance,  # guidance network
                 criterion=None,  # loss function, if None, assume inline implementation in train_step
                 optimizer=None,  # optimizer
                 ema_decay=None,  # if use EMA, set the decay
                 lr_scheduler=None,  # scheduler
                 metrics=None,  # metrics for evaluation, if None, use val_loss to measure performance, else use the first metric.
                 local_rank=0,  # which GPU am I
                 world_size=1,  # total num of GPUs
                 device=None,  # device to use, usually setting to None is OK. (auto choose device)
                 mute=False,  # whether to mute all print
                 fp16=False,  # amp optimize level
                 eval_interval=1,  # eval once every $ epoch
                 max_keep_ckpt=2,  # max num of saved ckpts in disk
                 workspace='workspace',  # workspace to save logs & ckpts
                 best_mode='min',  # the smaller/larger result, the better
                 use_loss_as_metric=True,  # use loss as the first metric
                 report_metric_at_train=False,  # also report metrics at training
                 use_checkpoint="latest",  # which ckpt to use at init time
                 use_tensorboardX=True,  # whether to use tensorboard for logging
                 scheduler_update_every_step=False,  # whether to call scheduler.step() after every train step
                 ):

        # A fresh list per instance; a `[]` default would be shared by all Trainers.
        if metrics is None:
            metrics = []

        self.name = name
        self.opt = opt
        self.mute = mute
        self.metrics = metrics
        self.local_rank = local_rank
        self.world_size = world_size
        self.workspace = workspace
        self.ema_decay = ema_decay
        self.fp16 = fp16
        self.best_mode = best_mode
        self.use_loss_as_metric = use_loss_as_metric
        self.report_metric_at_train = report_metric_at_train
        self.max_keep_ckpt = max_keep_ckpt
        self.eval_interval = eval_interval
        self.use_checkpoint = use_checkpoint
        self.use_tensorboardX = use_tensorboardX
        self.time_stamp = time.strftime("%Y-%m-%d_%H-%M-%S")
        self.scheduler_update_every_step = scheduler_update_every_step
        self.device = device if device is not None else torch.device(f'cuda:{local_rank}' if torch.cuda.is_available() else 'cpu')
        self.console = Console()

        # Must exist before anything calls self.log(): prepare_text_embeddings()
        # below logs a warning when no prompt is given, and log() reads log_ptr.
        self.log_ptr = None

        model.to(self.device)
        if self.world_size > 1:
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])
        self.model = model

        # guide model
        self.guidance = guidance

        # text prompt
        if self.guidance is not None:

            # the guidance network is frozen; only the NeRF model is optimized
            for p in self.guidance.parameters():
                p.requires_grad = False

            self.prepare_text_embeddings()

        else:
            self.text_z = None

        if isinstance(criterion, nn.Module):
            criterion.to(self.device)
        self.criterion = criterion

        # `optimizer` / `lr_scheduler` are factories: called with the model /
        # the optimizer respectively.
        if optimizer is None:
            self.optimizer = optim.Adam(self.model.parameters(), lr=0.001, weight_decay=5e-4)  # naive adam
        else:
            self.optimizer = optimizer(self.model)

        if lr_scheduler is None:
            self.lr_scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lambda epoch: 1)  # fake scheduler
        else:
            self.lr_scheduler = lr_scheduler(self.optimizer)

        if ema_decay is not None:
            self.ema = ExponentialMovingAverage(self.model.parameters(), decay=ema_decay)
        else:
            self.ema = None

        self.scaler = torch.cuda.amp.GradScaler(enabled=self.fp16)

        # variable init
        self.epoch = 0
        self.global_step = 0
        self.local_step = 0
        self.stats = {
            "loss": [],
            "valid_loss": [],
            "results": [],  # metrics[0], or valid_loss
            "checkpoints": [],  # record path of saved ckpt, to automatically remove old ckpt
            "best_result": None,
        }

        # auto fix
        if len(metrics) == 0 or self.use_loss_as_metric:
            self.best_mode = 'min'

        # workspace prepare
        if self.workspace is not None:
            os.makedirs(self.workspace, exist_ok=True)
            self.log_path = os.path.join(workspace, f"log_{self.name}.txt")
            self.log_ptr = open(self.log_path, "a+")

            self.ckpt_path = os.path.join(self.workspace, 'checkpoints')
            self.best_path = f"{self.ckpt_path}/{self.name}.pth"
            os.makedirs(self.ckpt_path, exist_ok=True)

        self.log(f'[INFO] Trainer: {self.name} | {self.time_stamp} | {self.device} | {"fp16" if self.fp16 else "fp32"} | {self.workspace}')
        self.log(f'[INFO] #parameters: {sum([p.numel() for p in model.parameters() if p.requires_grad])}')

        if self.workspace is not None:
            if self.use_checkpoint == "scratch":
                self.log("[INFO] Training from scratch ...")
            elif self.use_checkpoint == "latest":
                self.log("[INFO] Loading latest checkpoint ...")
                self.load_checkpoint()
            elif self.use_checkpoint == "latest_model":
                self.log("[INFO] Loading latest checkpoint (model only)...")
                self.load_checkpoint(model_only=True)
            elif self.use_checkpoint == "best":
                if os.path.exists(self.best_path):
                    self.log("[INFO] Loading best checkpoint ...")
                    self.load_checkpoint(self.best_path)
                else:
                    self.log(f"[INFO] {self.best_path} not found, loading latest ...")
                    self.load_checkpoint()
            else:  # path to ckpt
                self.log(f"[INFO] Loading {self.use_checkpoint} ...")
                self.load_checkpoint(self.use_checkpoint)

    # calculate the text embs.
    def prepare_text_embeddings(self):
        """Compute ``self.text_z`` from ``opt.text``.

        Sets ``text_z`` to None when no prompt is given, to a single embedding
        when ``opt.dir_text`` is falsy, and otherwise to a list with one
        embedding per view-direction suffix.
        """

        if self.opt.text is None:
            self.log(f"[WARN] text prompt is not provided.")
            self.text_z = None
            return

        if not self.opt.dir_text:
            self.text_z = self.guidance.get_text_embeds([self.opt.text])
        else:
            self.text_z = []
            for d in ['front', 'side', 'back', 'side', 'overhead', 'bottom']:
                text = f"{self.opt.text}, {d} view"
                text_z = self.guidance.get_text_embeds([text])
                self.text_z.append(text_z)

    def __del__(self):
        # getattr: __init__ may have raised before log_ptr was assigned.
        log_ptr = getattr(self, 'log_ptr', None)
        if log_ptr:
            log_ptr.close()

    def log(self, *args, **kwargs):
        """Print to the console (unless muted) and append to the log file; rank 0 only."""
        if self.local_rank == 0:
            if not self.mute:
                #print(*args)
                self.console.print(*args, **kwargs)
            if self.log_ptr:
                print(*args, file=self.log_ptr)
                self.log_ptr.flush()  # write immediately to file

    ### ------------------------------

    def train_step(self, data):
        """Render one batch and compute the guidance loss plus regularisers.

        Returns:
            (pred_rgb [B, 3, H, W], pred_ws [B, 1, H, W], loss)
        """

        rays_o = data['rays_o']  # [B, N, 3]
        rays_d = data['rays_d']  # [B, N, 3]

        B, N = rays_o.shape[:2]
        H, W = data['H'], data['W']

        # TODO: shading is not working right now...
        if self.global_step < self.opt.albedo_iters:
            shading = 'albedo'
            ambient_ratio = 1.0
        else:
            rand = random.random()
            if rand > 0.8:
                shading = 'albedo'
                ambient_ratio = 1.0
            # elif rand > 0.4:
            #     shading = 'textureless'
            #     ambient_ratio = 0.1
            else:
                shading = 'lambertian'
                ambient_ratio = 0.1

        # _t = time.time()
        bg_color = torch.rand((B * N, 3), device=rays_o.device)  # pixel-wise random
        outputs = self.model.render(rays_o, rays_d, staged=False, perturb=True, bg_color=bg_color, ambient_ratio=ambient_ratio, shading=shading, force_all_rays=True, **vars(self.opt))
        pred_rgb = outputs['image'].reshape(B, H, W, 3).permute(0, 3, 1, 2).contiguous()  # [1, 3, H, W]
        # torch.cuda.synchronize(); print(f'[TIME] nerf render {time.time() - _t:.4f}s')

        # print(shading)
        # torch_vis_2d(pred_rgb[0])

        # text embeddings
        if self.opt.dir_text:
            dirs = data['dir']  # [B,]
            text_z = self.text_z[dirs]
        else:
            text_z = self.text_z

        # encode pred_rgb to latents
        # _t = time.time()
        loss = self.guidance.train_step(text_z, pred_rgb)
        # torch.cuda.synchronize(); print(f'[TIME] total guiding {time.time() - _t:.4f}s')

        # occupancy loss
        pred_ws = outputs['weights_sum'].reshape(B, 1, H, W)

        if self.opt.lambda_opacity > 0:
            loss_opacity = (pred_ws ** 2).mean()
            loss = loss + self.opt.lambda_opacity * loss_opacity

        if self.opt.lambda_entropy > 0:
            # clamp keeps both log2 terms finite
            alphas = (pred_ws).clamp(1e-5, 1 - 1e-5)
            # alphas = alphas ** 2 # skewed entropy, favors 0 over 1
            loss_entropy = (- alphas * torch.log2(alphas) - (1 - alphas) * torch.log2(1 - alphas)).mean()

            loss = loss + self.opt.lambda_entropy * loss_entropy

        if self.opt.lambda_orient > 0 and 'loss_orient' in outputs:
            loss_orient = outputs['loss_orient']
            loss = loss + self.opt.lambda_orient * loss_orient

        return pred_rgb, pred_ws, loss

    def eval_step(self, data):
        """Render one validation batch.

        Returns:
            (pred_rgb [B, H, W, 3], pred_depth [B, H, W], loss) where loss is
            the entropy of the accumulated weights scaled by ``opt.lambda_entropy``.
        """

        rays_o = data['rays_o']  # [B, N, 3]
        rays_d = data['rays_d']  # [B, N, 3]

        B, N = rays_o.shape[:2]
        H, W = data['H'], data['W']

        shading = data['shading'] if 'shading' in data else 'albedo'
        ambient_ratio = data['ambient_ratio'] if 'ambient_ratio' in data else 1.0
        light_d = data['light_d'] if 'light_d' in data else None

        outputs = self.model.render(rays_o, rays_d, staged=True, perturb=False, bg_color=None, light_d=light_d, ambient_ratio=ambient_ratio, shading=shading, force_all_rays=True, **vars(self.opt))
        pred_rgb = outputs['image'].reshape(B, H, W, 3)
        pred_depth = outputs['depth'].reshape(B, H, W)
        pred_ws = outputs['weights_sum'].reshape(B, H, W)
        # mask_ws = outputs['mask'].reshape(B, H, W) # near < far

        # loss_ws = pred_ws.sum() / mask_ws.sum()
        # loss_ws = pred_ws.mean()

        alphas = (pred_ws).clamp(1e-5, 1 - 1e-5)
        # alphas = alphas ** 2 # skewed entropy, favors 0 over 1
        loss_entropy = (- alphas * torch.log2(alphas) - (1 - alphas) * torch.log2(1 - alphas)).mean()

        loss = self.opt.lambda_entropy * loss_entropy

        return pred_rgb, pred_depth, loss

    def test_step(self, data, bg_color=None, perturb=False):
        """Render one test batch; returns (pred_rgb [B, H, W, 3], pred_depth [B, H, W]).

        ``bg_color`` defaults to all-ones (a [3] tensor) when not given.
        """
        rays_o = data['rays_o']  # [B, N, 3]
        rays_d = data['rays_d']  # [B, N, 3]

        B, N = rays_o.shape[:2]
        H, W = data['H'], data['W']

        if bg_color is not None:
            bg_color = bg_color.to(rays_o.device)
        else:
            bg_color = torch.ones(3, device=rays_o.device)  # [3]

        shading = data['shading'] if 'shading' in data else 'albedo'
        ambient_ratio = data['ambient_ratio'] if 'ambient_ratio' in data else 1.0
        light_d = data['light_d'] if 'light_d' in data else None

        outputs = self.model.render(rays_o, rays_d, staged=True, perturb=perturb, light_d=light_d, ambient_ratio=ambient_ratio, shading=shading, force_all_rays=True, bg_color=bg_color, **vars(self.opt))

        pred_rgb = outputs['image'].reshape(B, H, W, 3)
        pred_depth = outputs['depth'].reshape(B, H, W)

        return pred_rgb, pred_depth

    def save_mesh(self, save_path=None, resolution=128):
        """Export a mesh via ``model.export_mesh`` into ``save_path`` (default: <workspace>/mesh)."""

        if save_path is None:
            save_path = os.path.join(self.workspace, 'mesh')

        self.log(f"==> Saving mesh to {save_path}")

        os.makedirs(save_path, exist_ok=True)

        self.model.export_mesh(save_path, resolution=resolution)

        self.log(f"==> Finished saving mesh.")

    ### ------------------------------

    def train(self, train_loader, valid_loader, max_epochs):
        """Train from the current epoch up to ``max_epochs``.

        Saves a full checkpoint after every epoch (rank 0, when a workspace is
        set) and evaluates + saves the best checkpoint every ``eval_interval``
        epochs.
        """

        assert self.text_z is not None, 'Training must provide a text prompt!'

        if self.use_tensorboardX and self.local_rank == 0:
            self.writer = tensorboardX.SummaryWriter(os.path.join(self.workspace, "run", self.name))

        start_t = time.time()

        for epoch in range(self.epoch + 1, max_epochs + 1):
            self.epoch = epoch

            self.train_one_epoch(train_loader)

            if self.workspace is not None and self.local_rank == 0:
                self.save_checkpoint(full=True, best=False)

            if self.epoch % self.eval_interval == 0:
                self.evaluate_one_epoch(valid_loader)
                self.save_checkpoint(full=False, best=True)

        end_t = time.time()

        self.log(f"[INFO] training takes {(end_t - start_t)/ 60:.4f} minutes.")

        if self.use_tensorboardX and self.local_rank == 0:
            self.writer.close()

    def evaluate(self, loader, name=None):
        """Run one evaluation pass with tensorboard logging temporarily disabled."""
        self.use_tensorboardX, use_tensorboardX = False, self.use_tensorboardX
        try:
            self.evaluate_one_epoch(loader, name)
        finally:
            # restore the flag even if evaluation raises
            self.use_tensorboardX = use_tensorboardX

    def test(self, loader, save_path=None, name=None, write_video=True):
        """Render every batch of ``loader`` and save the results.

        Writes ``<name>_rgb.mp4`` / ``<name>_depth.mp4`` when ``write_video`` is
        True, otherwise one rgb and one depth PNG per batch. Only the first
        sample of each batch is saved.
        """

        if save_path is None:
            save_path = os.path.join(self.workspace, 'results')

        if name is None:
            name = f'{self.name}_ep{self.epoch:04d}'

        os.makedirs(save_path, exist_ok=True)

        self.log(f"==> Start Test, save results to {save_path}")

        pbar = tqdm.tqdm(total=len(loader) * loader.batch_size, bar_format='{percentage:3.0f}% {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]')
        self.model.eval()

        if write_video:
            all_preds = []
            all_preds_depth = []

        with torch.no_grad():

            for i, data in enumerate(loader):

                with torch.cuda.amp.autocast(enabled=self.fp16):
                    preds, preds_depth = self.test_step(data)

                pred = preds[0].detach().cpu().numpy()
                pred = (pred * 255).astype(np.uint8)

                pred_depth = preds_depth[0].detach().cpu().numpy()
                pred_depth = (pred_depth * 255).astype(np.uint8)

                if write_video:
                    all_preds.append(pred)
                    all_preds_depth.append(pred_depth)
                else:
                    cv2.imwrite(os.path.join(save_path, f'{name}_{i:04d}_rgb.png'), cv2.cvtColor(pred, cv2.COLOR_RGB2BGR))
                    cv2.imwrite(os.path.join(save_path, f'{name}_{i:04d}_depth.png'), pred_depth)

                pbar.update(loader.batch_size)

        pbar.close()

        if write_video:
            all_preds = np.stack(all_preds, axis=0)
            all_preds_depth = np.stack(all_preds_depth, axis=0)
            imageio.mimwrite(os.path.join(save_path, f'{name}_rgb.mp4'), all_preds, fps=25, quality=8, macro_block_size=1)
            imageio.mimwrite(os.path.join(save_path, f'{name}_depth.mp4'), all_preds_depth, fps=25, quality=8, macro_block_size=1)

        self.log(f"==> Finished Test.")

    # [GUI] train text step.
    def train_gui(self, train_loader, step=16):
        """Run ``step`` optimisation steps and return ``{'loss': avg, 'lr': current}``."""

        self.model.train()

        total_loss = torch.tensor([0], dtype=torch.float32, device=self.device)

        loader = iter(train_loader)

        for _ in range(step):

            # mimic an infinite loop dataloader (in case the total dataset is smaller than step)
            try:
                data = next(loader)
            except StopIteration:
                loader = iter(train_loader)
                data = next(loader)

            # update grid every 16 steps
            if self.model.cuda_ray and self.global_step % self.opt.update_extra_interval == 0:
                with torch.cuda.amp.autocast(enabled=self.fp16):
                    self.model.update_extra_state()

            self.global_step += 1

            self.optimizer.zero_grad()

            with torch.cuda.amp.autocast(enabled=self.fp16):
                pred_rgbs, pred_ws, loss = self.train_step(data)

            self.scaler.scale(loss).backward()
            self.scaler.step(self.optimizer)
            self.scaler.update()

            if self.scheduler_update_every_step:
                self.lr_scheduler.step()

            total_loss += loss.detach()

        if self.ema is not None:
            self.ema.update()

        average_loss = total_loss.item() / step

        if not self.scheduler_update_every_step:
            if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                self.lr_scheduler.step(average_loss)
            else:
                self.lr_scheduler.step()

        outputs = {
            'loss': average_loss,
            'lr': self.optimizer.param_groups[0]['lr'],
        }

        return outputs

    # [GUI] test on a single image
    def test_gui(self, pose, intrinsics, W, H, bg_color=None, spp=1, downscale=1, light_d=None, ambient_ratio=1.0, shading='albedo'):
        """Render a single view for the GUI.

        Args:
            pose: numpy cam2world matrix (a batch dim is added here).
            intrinsics: scaled by ``downscale`` before ray generation.
            W, H: output resolution; rendering happens at the downscaled size.
            light_d: (theta, phi) in degrees, or None to pass no light direction.
        Returns:
            dict with numpy 'image' and 'depth' at the original resolution.
        """

        # render resolution (may need downscale to for better frame rate)
        rH = int(H * downscale)
        rW = int(W * downscale)
        intrinsics = intrinsics * downscale

        pose = torch.from_numpy(pose).unsqueeze(0).to(self.device)

        rays = get_rays(pose, intrinsics, rH, rW, -1)

        # from degree theta/phi to 3D normalized vec
        # (skipped for the default None: np.deg2rad(None) raises TypeError)
        if light_d is not None:
            light_d = np.deg2rad(light_d)
            light_d = np.array([
                np.sin(light_d[0]) * np.sin(light_d[1]),
                np.cos(light_d[0]),
                np.sin(light_d[0]) * np.cos(light_d[1]),
            ], dtype=np.float32)
            light_d = torch.from_numpy(light_d).to(self.device)

        data = {
            'rays_o': rays['rays_o'],
            'rays_d': rays['rays_d'],
            'H': rH,
            'W': rW,
            'light_d': light_d,
            'ambient_ratio': ambient_ratio,
            'shading': shading,
        }

        self.model.eval()

        if self.ema is not None:
            self.ema.store()
            self.ema.copy_to()

        with torch.no_grad():
            with torch.cuda.amp.autocast(enabled=self.fp16):
                # here spp is used as perturb random seed!
                preds, preds_depth = self.test_step(data, bg_color=bg_color, perturb=spp)

        if self.ema is not None:
            self.ema.restore()

        # interpolation to the original resolution
        if downscale != 1:
            # have to permute twice with torch...
            preds = F.interpolate(preds.permute(0, 3, 1, 2), size=(H, W), mode='nearest').permute(0, 2, 3, 1).contiguous()
            preds_depth = F.interpolate(preds_depth.unsqueeze(1), size=(H, W), mode='nearest').squeeze(1)

        outputs = {
            'image': preds[0].detach().cpu().numpy(),
            'depth': preds_depth[0].detach().cpu().numpy(),
        }

        return outputs

    def train_one_epoch(self, loader):
        """Run one pass over ``loader``, appending the mean loss to ``stats['loss']``."""
        self.log(f"==> Start Training {self.workspace} Epoch {self.epoch}, lr={self.optimizer.param_groups[0]['lr']:.6f} ...")

        total_loss = 0
        if self.local_rank == 0 and self.report_metric_at_train:
            for metric in self.metrics:
                metric.clear()

        self.model.train()

        # distributedSampler: must call set_epoch() to shuffle indices across multiple epochs
        # ref: https://pytorch.org/docs/stable/data.html
        if self.world_size > 1:
            loader.sampler.set_epoch(self.epoch)

        if self.local_rank == 0:
            pbar = tqdm.tqdm(total=len(loader) * loader.batch_size, bar_format='{desc}: {percentage:3.0f}% {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]')

        self.local_step = 0

        for data in loader:

            # update grid every 16 steps
            if self.model.cuda_ray and self.global_step % self.opt.update_extra_interval == 0:
                with torch.cuda.amp.autocast(enabled=self.fp16):
                    self.model.update_extra_state()

            self.local_step += 1
            self.global_step += 1

            self.optimizer.zero_grad()

            with torch.cuda.amp.autocast(enabled=self.fp16):
                pred_rgbs, pred_ws, loss = self.train_step(data)

            self.scaler.scale(loss).backward()
            self.scaler.step(self.optimizer)
            self.scaler.update()

            if self.scheduler_update_every_step:
                self.lr_scheduler.step()

            loss_val = loss.item()
            total_loss += loss_val

            if self.local_rank == 0:
                # if self.report_metric_at_train:
                #     for metric in self.metrics:
                #         metric.update(preds, truths)

                if self.use_tensorboardX:
                    self.writer.add_scalar("train/loss", loss_val, self.global_step)
                    self.writer.add_scalar("train/lr", self.optimizer.param_groups[0]['lr'], self.global_step)

                if self.scheduler_update_every_step:
                    pbar.set_description(f"loss={loss_val:.4f} ({total_loss/self.local_step:.4f}), lr={self.optimizer.param_groups[0]['lr']:.6f}")
                else:
                    pbar.set_description(f"loss={loss_val:.4f} ({total_loss/self.local_step:.4f})")
                pbar.update(loader.batch_size)

        if self.ema is not None:
            self.ema.update()

        average_loss = total_loss / self.local_step
        self.stats["loss"].append(average_loss)

        if self.local_rank == 0:
            pbar.close()
            if self.report_metric_at_train:
                for metric in self.metrics:
                    self.log(metric.report(), style="red")
                    if self.use_tensorboardX:
                        metric.write(self.writer, self.epoch, prefix="train")
                    metric.clear()

        if not self.scheduler_update_every_step:
            if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                self.lr_scheduler.step(average_loss)
            else:
                self.lr_scheduler.step()

        self.log(f"==> Finished Epoch {self.epoch}.")

    def evaluate_one_epoch(self, loader, name=None):
        """Evaluate on ``loader`` (with EMA weights if enabled) and save validation images.

        Appends the mean loss to ``stats['valid_loss']`` and, on rank 0, the
        selection score to ``stats['results']``.
        """
        self.log(f"++> Evaluate {self.workspace} at epoch {self.epoch} ...")

        if name is None:
            name = f'{self.name}_ep{self.epoch:04d}'

        total_loss = 0
        if self.local_rank == 0:
            for metric in self.metrics:
                metric.clear()

        self.model.eval()

        if self.ema is not None:
            self.ema.store()
            self.ema.copy_to()

        if self.local_rank == 0:
            pbar = tqdm.tqdm(total=len(loader) * loader.batch_size, bar_format='{desc}: {percentage:3.0f}% {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]')

        with torch.no_grad():
            self.local_step = 0

            for data in loader:
                self.local_step += 1

                with torch.cuda.amp.autocast(enabled=self.fp16):
                    preds, preds_depth, loss = self.eval_step(data)

                # all_gather/reduce the statistics (NCCL only support all_*)
                if self.world_size > 1:
                    dist.all_reduce(loss, op=dist.ReduceOp.SUM)
                    loss = loss / self.world_size

                    preds_list = [torch.zeros_like(preds).to(self.device) for _ in range(self.world_size)]  # [[B, ...], [B, ...], ...]
                    dist.all_gather(preds_list, preds)
                    preds = torch.cat(preds_list, dim=0)

                    preds_depth_list = [torch.zeros_like(preds_depth).to(self.device) for _ in range(self.world_size)]  # [[B, ...], [B, ...], ...]
                    dist.all_gather(preds_depth_list, preds_depth)
                    preds_depth = torch.cat(preds_depth_list, dim=0)

                loss_val = loss.item()
                total_loss += loss_val

                # only rank = 0 will perform evaluation.
                if self.local_rank == 0:

                    # save image
                    save_path = os.path.join(self.workspace, 'validation', f'{name}_{self.local_step:04d}_rgb.png')
                    save_path_depth = os.path.join(self.workspace, 'validation', f'{name}_{self.local_step:04d}_depth.png')

                    #self.log(f"==> Saving validation image to {save_path}")
                    os.makedirs(os.path.dirname(save_path), exist_ok=True)

                    pred = preds[0].detach().cpu().numpy()
                    pred = (pred * 255).astype(np.uint8)

                    pred_depth = preds_depth[0].detach().cpu().numpy()
                    pred_depth = (pred_depth * 255).astype(np.uint8)

                    cv2.imwrite(save_path, cv2.cvtColor(pred, cv2.COLOR_RGB2BGR))
                    cv2.imwrite(save_path_depth, pred_depth)

                    pbar.set_description(f"loss={loss_val:.4f} ({total_loss/self.local_step:.4f})")
                    pbar.update(loader.batch_size)

        average_loss = total_loss / self.local_step
        self.stats["valid_loss"].append(average_loss)

        if self.local_rank == 0:
            pbar.close()
            if not self.use_loss_as_metric and len(self.metrics) > 0:
                result = self.metrics[0].measure()
                self.stats["results"].append(result if self.best_mode == 'min' else - result)  # if max mode, use -result
            else:
                self.stats["results"].append(average_loss)  # if no metric, choose best by min loss

            for metric in self.metrics:
                self.log(metric.report(), style="blue")
                if self.use_tensorboardX:
                    metric.write(self.writer, self.epoch, prefix="evaluate")
                metric.clear()

        if self.ema is not None:
            self.ema.restore()

        self.log(f"++> Evaluate epoch {self.epoch} Finished.")

    def save_checkpoint(self, name=None, full=False, best=False):
        """Save a checkpoint.

        Args:
            name: file stem for regular checkpoints (default: <name>_ep<epoch>).
            full: also store optimizer / scheduler / scaler / EMA state.
            best: if False, write ``<ckpt_path>/<name>.pth`` and drop the oldest
                checkpoint beyond ``max_keep_ckpt``; if True, write ``best_path``
                only when the latest result improves on ``stats['best_result']``.
        """

        if name is None:
            name = f'{self.name}_ep{self.epoch:04d}'

        state = {
            'epoch': self.epoch,
            'global_step': self.global_step,
            'stats': self.stats,
        }

        if self.model.cuda_ray:
            state['mean_count'] = self.model.mean_count
            state['mean_density'] = self.model.mean_density

        if full:
            state['optimizer'] = self.optimizer.state_dict()
            state['lr_scheduler'] = self.lr_scheduler.state_dict()
            state['scaler'] = self.scaler.state_dict()
            if self.ema is not None:
                state['ema'] = self.ema.state_dict()

        if not best:

            state['model'] = self.model.state_dict()

            file_path = f"{name}.pth"

            self.stats["checkpoints"].append(file_path)

            if len(self.stats["checkpoints"]) > self.max_keep_ckpt:
                old_ckpt = os.path.join(self.ckpt_path, self.stats["checkpoints"].pop(0))
                if os.path.exists(old_ckpt):
                    os.remove(old_ckpt)

            torch.save(state, os.path.join(self.ckpt_path, file_path))

        else:
            if len(self.stats["results"]) > 0:
                if self.stats["best_result"] is None or self.stats["results"][-1] < self.stats["best_result"]:
                    self.log(f"[INFO] New best result: {self.stats['best_result']} --> {self.stats['results'][-1]}")
                    self.stats["best_result"] = self.stats["results"][-1]

                    # save ema results
                    if self.ema is not None:
                        self.ema.store()
                        self.ema.copy_to()

                    state['model'] = self.model.state_dict()

                    if self.ema is not None:
                        self.ema.restore()

                    torch.save(state, self.best_path)
            else:
                self.log(f"[WARN] no evaluated results found, skip saving best checkpoint.")

    def load_checkpoint(self, checkpoint=None, model_only=False):
        """Load model (and optionally training) state from a checkpoint.

        Args:
            checkpoint: path to a ``.pth`` file; None picks the last entry of the
                sorted ``<ckpt_path>/*.pth`` listing.
            model_only: stop after restoring model / EMA / density-grid state.

        Optimizer, scheduler, scaler and EMA restores are best-effort: a
        failure is logged as a warning and loading continues.
        """
        if checkpoint is None:
            checkpoint_list = sorted(glob.glob(f'{self.ckpt_path}/*.pth'))
            if checkpoint_list:
                checkpoint = checkpoint_list[-1]
                self.log(f"[INFO] Latest checkpoint is {checkpoint}")
            else:
                self.log("[WARN] No checkpoint found, model randomly initialized.")
                return

        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from trusted sources.
        checkpoint_dict = torch.load(checkpoint, map_location=self.device)

        # a file without a 'model' key is treated as a bare state_dict
        if 'model' not in checkpoint_dict:
            self.model.load_state_dict(checkpoint_dict)
            self.log("[INFO] loaded model.")
            return

        missing_keys, unexpected_keys = self.model.load_state_dict(checkpoint_dict['model'], strict=False)
        self.log("[INFO] loaded model.")
        if len(missing_keys) > 0:
            self.log(f"[WARN] missing keys: {missing_keys}")
        if len(unexpected_keys) > 0:
            self.log(f"[WARN] unexpected keys: {unexpected_keys}")

        if self.ema is not None and 'ema' in checkpoint_dict:
            try:
                self.ema.load_state_dict(checkpoint_dict['ema'])
                self.log("[INFO] loaded EMA.")
            except Exception:
                self.log("[WARN] failed to loaded EMA.")

        if self.model.cuda_ray:
            if 'mean_count' in checkpoint_dict:
                self.model.mean_count = checkpoint_dict['mean_count']
            if 'mean_density' in checkpoint_dict:
                self.model.mean_density = checkpoint_dict['mean_density']

        if model_only:
            return

        self.stats = checkpoint_dict['stats']
        self.epoch = checkpoint_dict['epoch']
        self.global_step = checkpoint_dict['global_step']
        self.log(f"[INFO] load at epoch {self.epoch}, global step {self.global_step}")

        if self.optimizer and 'optimizer' in checkpoint_dict:
            try:
                self.optimizer.load_state_dict(checkpoint_dict['optimizer'])
                self.log("[INFO] loaded optimizer.")
            except Exception:
                self.log("[WARN] Failed to load optimizer.")

        if self.lr_scheduler and 'lr_scheduler' in checkpoint_dict:
            try:
                self.lr_scheduler.load_state_dict(checkpoint_dict['lr_scheduler'])
                self.log("[INFO] loaded scheduler.")
            except Exception:
                self.log("[WARN] Failed to load scheduler.")

        if self.scaler and 'scaler' in checkpoint_dict:
            try:
                self.scaler.load_state_dict(checkpoint_dict['scaler'])
                self.log("[INFO] loaded scaler.")
            except Exception:
                self.log("[WARN] Failed to load scaler.")