"""Training entry point for LGM.

Parses ``AllConfigs`` from the command line, optionally resumes from a
checkpoint, then trains with HuggingFace Accelerate. After every epoch the
model is saved to ``opt.workspace`` and evaluated on the test split.
"""

import tyro
import time
import random

import torch
from core.options import AllConfigs
from core.models import LGM
from accelerate import Accelerator, DistributedDataParallelKwargs
from safetensors.torch import load_file

import kiui


def _write_image_grid(path, images):
    """Tile a batch of multi-view images into one picture and write it to ``path``.

    ``images`` is a tensor; per the shape notes of the original code it is
    expected to be [B, V, C, output_size, output_size]. The result written is
    [B*output_size, V*output_size, C]: batch items stacked vertically, views
    laid out horizontally.
    """
    grid = images.detach().cpu().numpy()
    num_views, channels, size = grid.shape[1], grid.shape[2], grid.shape[3]
    # [B, V, C, H, W] -> [B, H, V, W, C] -> [B*H, V*W, C]
    grid = grid.transpose(0, 3, 1, 4, 2).reshape(-1, num_views * size, channels)
    kiui.write_image(path, grid)


def main():
    """Run the full train / checkpoint / eval loop configured by the CLI options."""
    opt = tyro.cli(AllConfigs)

    # ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)

    accelerator = Accelerator(
        mixed_precision=opt.mixed_precision,
        gradient_accumulation_steps=opt.gradient_accumulation_steps,
        # kwargs_handlers=[ddp_kwargs],
    )

    # model
    model = LGM(opt)

    # resume
    if opt.resume is not None:
        if opt.resume.endswith('safetensors'):
            ckpt = load_file(opt.resume, device='cpu')
        else:
            # NOTE(security): torch.load unpickles the file; only resume from
            # checkpoints you trust.
            ckpt = torch.load(opt.resume, map_location='cpu')

        # tolerant load (only load matching shapes)
        # model.load_state_dict(ckpt, strict=False)
        state_dict = model.state_dict()
        for k, v in ckpt.items():
            if k not in state_dict:
                accelerator.print(f'[WARN] unexpected param {k}: {v.shape}')
            elif state_dict[k].shape != v.shape:
                accelerator.print(f'[WARN] mismatching shape for param {k}: ckpt {v.shape} != model {state_dict[k].shape}, ignored.')
            else:
                # state_dict tensors alias the model's storage, so an
                # in-place copy updates the model itself.
                state_dict[k].copy_(v)

    # data
    if opt.data_mode == 's3':
        from core.provider_objaverse import ObjaverseDataset as Dataset
    else:
        raise NotImplementedError

    train_dataset = Dataset(opt, training=True)
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=opt.batch_size,
        shuffle=True,
        num_workers=opt.num_workers,
        pin_memory=True,
        drop_last=True,
    )

    test_dataset = Dataset(opt, training=False)
    test_dataloader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=opt.batch_size,
        shuffle=False,
        num_workers=0,
        pin_memory=True,
        drop_last=False,
    )

    # optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=opt.lr, weight_decay=0.05, betas=(0.9, 0.95))

    # scheduler (per-iteration)
    # scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=3000, eta_min=1e-6)
    total_steps = opt.num_epochs * len(train_dataloader)
    # Warm up for 3000 steps; clamp so runs shorter than that do not hand
    # OneCycleLR a pct_start above 1.0 (which it rejects with ValueError).
    pct_start = min(3000 / total_steps, 1.0)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=opt.lr, total_steps=total_steps, pct_start=pct_start)

    # accelerate
    model, optimizer, train_dataloader, test_dataloader, scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, test_dataloader, scheduler
    )

    # loop
    for epoch in range(opt.num_epochs):
        # train
        model.train()
        total_loss = 0
        total_psnr = 0
        for i, data in enumerate(train_dataloader):
            with accelerator.accumulate(model):

                # fraction of the whole run completed so far, in [0, 1)
                step_ratio = (epoch + i / len(train_dataloader)) / opt.num_epochs

                out = model(data, step_ratio)
                loss = out['loss']
                psnr = out['psnr']
                accelerator.backward(loss)

                # gradient clipping
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(model.parameters(), opt.gradient_clip)

                optimizer.step()
                scheduler.step()
                # Clear gradients only AFTER the step. Zeroing at the top of
                # the accumulate block would wipe the gradients accumulated
                # from earlier micro-batches on every sync step.
                optimizer.zero_grad()

                total_loss += loss.detach()
                total_psnr += psnr.detach()

            if accelerator.is_main_process:
                # logging
                if i % 100 == 0:
                    mem_free, mem_total = torch.cuda.mem_get_info()
                    print(f"[INFO] {i}/{len(train_dataloader)} mem: {(mem_total-mem_free)/1024**3:.2f}/{mem_total/1024**3:.2f}G lr: {scheduler.get_last_lr()[0]:.7f} step_ratio: {step_ratio:.4f} loss: {loss.item():.6f}")

                # save log images
                if i % 500 == 0:
                    _write_image_grid(f'{opt.workspace}/train_gt_images_{epoch}_{i}.jpg', data['images_output'])
                    # _write_image_grid(f'{opt.workspace}/train_gt_alphas_{epoch}_{i}.jpg', data['masks_output'])

                    _write_image_grid(f'{opt.workspace}/train_pred_images_{epoch}_{i}.jpg', out['images_pred'])
                    # _write_image_grid(f'{opt.workspace}/train_pred_alphas_{epoch}_{i}.jpg', out['alphas_pred'])

        total_loss = accelerator.gather_for_metrics(total_loss).mean()
        total_psnr = accelerator.gather_for_metrics(total_psnr).mean()
        if accelerator.is_main_process:
            total_loss /= len(train_dataloader)
            total_psnr /= len(train_dataloader)
            accelerator.print(f"[train] epoch: {epoch} loss: {total_loss.item():.6f} psnr: {total_psnr.item():.4f}")

        # checkpoint
        # if epoch % 10 == 0 or epoch == opt.num_epochs - 1:
        accelerator.wait_for_everyone()
        accelerator.save_model(model, opt.workspace)

        # eval
        with torch.no_grad():
            model.eval()
            total_psnr = 0
            for i, data in enumerate(test_dataloader):

                out = model(data)

                psnr = out['psnr']
                total_psnr += psnr.detach()

                # save some images
                if accelerator.is_main_process:
                    _write_image_grid(f'{opt.workspace}/eval_gt_images_{epoch}_{i}.jpg', data['images_output'])

                    _write_image_grid(f'{opt.workspace}/eval_pred_images_{epoch}_{i}.jpg', out['images_pred'])
                    # _write_image_grid(f'{opt.workspace}/eval_pred_alphas_{epoch}_{i}.jpg', out['alphas_pred'])

            torch.cuda.empty_cache()

            total_psnr = accelerator.gather_for_metrics(total_psnr).mean()
            if accelerator.is_main_process:
                total_psnr /= len(test_dataloader)
                # Report the averaged metric, not the last batch's psnr.
                accelerator.print(f"[eval] epoch: {epoch} psnr: {total_psnr.item():.4f}")


if __name__ == "__main__":
    main()