# lgm/main.py: LGM training entry point
import tyro
import torch

from core.options import AllConfigs
from core.models import LGM

from accelerate import Accelerator, DistributedDataParallelKwargs
from safetensors.torch import load_file

import kiui
def main():
    opt = tyro.cli(AllConfigs)

    # ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
    accelerator = Accelerator(
        mixed_precision=opt.mixed_precision,
        gradient_accumulation_steps=opt.gradient_accumulation_steps,
        # kwargs_handlers=[ddp_kwargs],
    )
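    # NOTE: Accelerator takes care of device placement, mixed precision, and
    # gradient accumulation (see accelerator.accumulate below). Uncommenting the
    # ddp_kwargs lines above enables find_unused_parameters for DDP, which is
    # only needed if some parameters receive no gradient in a forward pass.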
    # model
    model = LGM(opt)

    # resume
    if opt.resume is not None:
        if opt.resume.endswith('safetensors'):
            ckpt = load_file(opt.resume, device='cpu')
        else:
            ckpt = torch.load(opt.resume, map_location='cpu')

        # tolerant load (only load matching shapes)
        # model.load_state_dict(ckpt, strict=False)
        state_dict = model.state_dict()
        for k, v in ckpt.items():
            if k in state_dict:
                if state_dict[k].shape == v.shape:
                    state_dict[k].copy_(v)
                else:
                    accelerator.print(f'[WARN] mismatching shape for param {k}: ckpt {v.shape} != model {state_dict[k].shape}, ignored.')
            else:
                accelerator.print(f'[WARN] unexpected param {k}: {v.shape}')
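        # The in-place copy_ above is sufficient: state_dict() returns tensors
        # that share storage with the live parameters, so the checkpoint weights
        # land directly in the model and no follow-up load_state_dict is needed.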
    # data
    if opt.data_mode == 's3':
        from core.provider_objaverse import ObjaverseDataset as Dataset
    else:
        raise NotImplementedError

    train_dataset = Dataset(opt, training=True)
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=opt.batch_size,
        shuffle=True,
        num_workers=opt.num_workers,
        pin_memory=True,
        drop_last=True,
    )

    test_dataset = Dataset(opt, training=False)
    test_dataloader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=opt.batch_size,
        shuffle=False,
        num_workers=0,
        pin_memory=True,
        drop_last=False,
    )
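    # drop_last=True keeps every training batch full, so len(train_dataloader)
    # is an exact per-epoch step count for the scheduler configured below.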
    # optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=opt.lr, weight_decay=0.05, betas=(0.9, 0.95))

    # scheduler (per-iteration)
    # scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=3000, eta_min=1e-6)
    total_steps = opt.num_epochs * len(train_dataloader)
    pct_start = 3000 / total_steps
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=opt.lr, total_steps=total_steps, pct_start=pct_start)
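    # NOTE: pct_start assumes total_steps > 3000 (roughly 3000 warmup steps);
    # for shorter runs a guard such as pct_start = min(3000 / total_steps, 0.3)
    # would be needed (a hypothetical tweak, not part of the original script).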
    # accelerate
    model, optimizer, train_dataloader, test_dataloader, scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, test_dataloader, scheduler
    )
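    # NOTE: prepare() shards the dataloaders across processes, so with more than
    # one GPU each process steps the scheduler only len(train_dataloader)/world_size
    # times per epoch and the one-cycle schedule computed above is traversed
    # only partially; the schedule is exact for single-GPU training.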
    # loop
    for epoch in range(opt.num_epochs):
        # train
        model.train()
        total_loss = 0
        total_psnr = 0
        for i, data in enumerate(train_dataloader):
            with accelerator.accumulate(model):

                optimizer.zero_grad()

                step_ratio = (epoch + i / len(train_dataloader)) / opt.num_epochs

                out = model(data, step_ratio)
                loss = out['loss']
                psnr = out['psnr']
                accelerator.backward(loss)

                # gradient clipping
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(model.parameters(), opt.gradient_clip)

                optimizer.step()
                scheduler.step()

                total_loss += loss.detach()
                total_psnr += psnr.detach()

            if accelerator.is_main_process:
                # logging
                if i % 100 == 0:
                    mem_free, mem_total = torch.cuda.mem_get_info()
                    print(f"[INFO] {i}/{len(train_dataloader)} mem: {(mem_total-mem_free)/1024**3:.2f}/{mem_total/1024**3:.2f}G lr: {scheduler.get_last_lr()[0]:.7f} step_ratio: {step_ratio:.4f} loss: {loss.item():.6f}")

                # save log images
                if i % 500 == 0:
                    gt_images = data['images_output'].detach().cpu().numpy() # [B, V, 3, output_size, output_size]
                    gt_images = gt_images.transpose(0, 3, 1, 4, 2).reshape(-1, gt_images.shape[1] * gt_images.shape[3], 3) # [B*output_size, V*output_size, 3]
                    kiui.write_image(f'{opt.workspace}/train_gt_images_{epoch}_{i}.jpg', gt_images)

                    # gt_alphas = data['masks_output'].detach().cpu().numpy() # [B, V, 1, output_size, output_size]
                    # gt_alphas = gt_alphas.transpose(0, 3, 1, 4, 2).reshape(-1, gt_alphas.shape[1] * gt_alphas.shape[3], 1)
                    # kiui.write_image(f'{opt.workspace}/train_gt_alphas_{epoch}_{i}.jpg', gt_alphas)

                    pred_images = out['images_pred'].detach().cpu().numpy() # [B, V, 3, output_size, output_size]
                    pred_images = pred_images.transpose(0, 3, 1, 4, 2).reshape(-1, pred_images.shape[1] * pred_images.shape[3], 3)
                    kiui.write_image(f'{opt.workspace}/train_pred_images_{epoch}_{i}.jpg', pred_images)

                    # pred_alphas = out['alphas_pred'].detach().cpu().numpy() # [B, V, 1, output_size, output_size]
                    # pred_alphas = pred_alphas.transpose(0, 3, 1, 4, 2).reshape(-1, pred_alphas.shape[1] * pred_alphas.shape[3], 1)
                    # kiui.write_image(f'{opt.workspace}/train_pred_alphas_{epoch}_{i}.jpg', pred_alphas)
        total_loss = accelerator.gather_for_metrics(total_loss).mean()
        total_psnr = accelerator.gather_for_metrics(total_psnr).mean()
        if accelerator.is_main_process:
            total_loss /= len(train_dataloader)
            total_psnr /= len(train_dataloader)
            accelerator.print(f"[train] epoch: {epoch} loss: {total_loss.item():.6f} psnr: {total_psnr.item():.4f}")

        # checkpoint
        # if epoch % 10 == 0 or epoch == opt.num_epochs - 1:
        accelerator.wait_for_everyone()
        accelerator.save_model(model, opt.workspace)
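        # save_model unwraps any DDP wrapper and writes the weights into
        # opt.workspace (as model.safetensors by default), overwriting the
        # previous epoch's checkpoint.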
        # eval
        with torch.no_grad():
            model.eval()
            total_psnr = 0
            for i, data in enumerate(test_dataloader):

                out = model(data)

                psnr = out['psnr']
                total_psnr += psnr.detach()

                # save some images
                if accelerator.is_main_process:
                    gt_images = data['images_output'].detach().cpu().numpy() # [B, V, 3, output_size, output_size]
                    gt_images = gt_images.transpose(0, 3, 1, 4, 2).reshape(-1, gt_images.shape[1] * gt_images.shape[3], 3) # [B*output_size, V*output_size, 3]
                    kiui.write_image(f'{opt.workspace}/eval_gt_images_{epoch}_{i}.jpg', gt_images)

                    pred_images = out['images_pred'].detach().cpu().numpy() # [B, V, 3, output_size, output_size]
                    pred_images = pred_images.transpose(0, 3, 1, 4, 2).reshape(-1, pred_images.shape[1] * pred_images.shape[3], 3)
                    kiui.write_image(f'{opt.workspace}/eval_pred_images_{epoch}_{i}.jpg', pred_images)

                    # pred_alphas = out['alphas_pred'].detach().cpu().numpy() # [B, V, 1, output_size, output_size]
                    # pred_alphas = pred_alphas.transpose(0, 3, 1, 4, 2).reshape(-1, pred_alphas.shape[1] * pred_alphas.shape[3], 1)
                    # kiui.write_image(f'{opt.workspace}/eval_pred_alphas_{epoch}_{i}.jpg', pred_alphas)

            torch.cuda.empty_cache()

            total_psnr = accelerator.gather_for_metrics(total_psnr).mean()
            if accelerator.is_main_process:
                total_psnr /= len(test_dataloader)
                # report the gathered epoch average, not the last batch's psnr
                accelerator.print(f"[eval] epoch: {epoch} psnr: {total_psnr.item():.4f}")

if __name__ == "__main__":
    main()