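# Training entry point for LGM on Objaverse-style multi-view data with 4D
# (multi-frame) supervision. Distributed training, mixed precision, and
# gradient accumulation are handled by HuggingFace Accelerate.
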
import os
import json
import time

import numpy as np
import imageio
import torch
import tyro
import kiui

from accelerate import Accelerator
from safetensors.torch import load_file

from core.options import AllConfigs
from core.models import LGM
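
# Typical launch (illustrative only; the config subcommand and flags are
# whatever core.options.AllConfigs defines, parsed by tyro):
#
#   accelerate launch <this_script>.py big --workspace workspace_4d
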
def main():
    opt = tyro.cli(AllConfigs)

    accelerator = Accelerator(
        mixed_precision=opt.mixed_precision,
        gradient_accumulation_steps=opt.gradient_accumulation_steps,
    )
    if accelerator.is_main_process:
        print(opt)

    model = LGM(opt)
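
    # Auto-resume: if the workspace already contains a checkpoint and its
    # metadata, pick them up and continue from the following epoch.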
    epoch_start = 0
    if os.path.exists(f'{opt.workspace}/model.safetensors') and os.path.exists(f'{opt.workspace}/metadata.json'):
        opt.resume = f'{opt.workspace}/model.safetensors'
        with open(f'{opt.workspace}/metadata.json', 'r') as f:
            metadata = json.load(f)
        epoch_start = metadata['epoch'] + 1
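
    # Load weights from opt.resume (safetensors or a regular torch checkpoint).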
    if opt.resume is not None and opt.resume != 'None':
        if opt.resume.endswith('safetensors'):
            ckpt = load_file(opt.resume, device='cpu')
        else:
            ckpt = torch.load(opt.resume, map_location='cpu')

        # Tolerant load: copy matching tensors in place, and warn on shape
        # mismatches and unexpected keys instead of raising.
        state_dict = model.state_dict()
        for k, v in ckpt.items():
            if k in state_dict:
                if state_dict[k].shape == v.shape:
                    state_dict[k].copy_(v)
                else:
                    accelerator.print(f'[WARN] mismatching shape for param {k}: ckpt {v.shape} != model {state_dict[k].shape}, ignored.')
            else:
                accelerator.print(f'[WARN] unexpected param {k}: {v.shape}')
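
    # Dataset selection: only the 4D Objaverse providers are wired up here,
    # chosen by opt.data_mode.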
    if opt.data_mode == '4d':
        from core.provider_objaverse_4d import ObjaverseDataset as Dataset
    elif opt.data_mode == '4d_interp':
        from core.provider_objaverse_4d_interp import ObjaverseDataset as Dataset
    else:
        raise NotImplementedError
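
    # The train loader shuffles and drops the last partial batch; the test
    # loader keeps every sample and loads in the main process (num_workers=0).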
    train_dataset = Dataset(opt, training=True)
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=opt.batch_size,
        shuffle=True,
        num_workers=opt.num_workers,
        pin_memory=True,
        drop_last=True,
    )

    test_dataset = Dataset(opt, training=False)
    test_dataloader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=opt.batch_size,
        shuffle=False,
        num_workers=0,
        pin_memory=True,
        drop_last=False,
    )
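
    # AdamW with a one-cycle (cosine) schedule and a short warmup.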
    optimizer = torch.optim.AdamW(model.parameters(), lr=opt.lr, weight_decay=0.05, betas=(0.9, 0.95))

    # Warmup of roughly 3000 steps; the min() keeps pct_start valid (< 1)
    # when the whole run is shorter than 3000 steps.
    total_steps = opt.num_epochs * len(train_dataloader)
    pct_start = min(3000 / total_steps, 0.3)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=opt.lr, total_steps=total_steps, pct_start=pct_start)

    # When auto-resuming, restore optimizer and scheduler state as well.
    if epoch_start > 0:
        optimizer.load_state_dict(torch.load(os.path.join(opt.workspace, 'optimizer.pth'), map_location='cpu'))
        scheduler.load_state_dict(torch.load(os.path.join(opt.workspace, 'scheduler.pth')))
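
    # Hand everything to Accelerate for DDP / mixed-precision wrapping.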
    model, optimizer, train_dataloader, test_dataloader, scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, test_dataloader, scheduler
    )
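
    # Main loop: one training pass and one evaluation pass per epoch.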
    os.makedirs(opt.workspace, exist_ok=True)
    end_time = time.time()  # timestamp of the previous console log
    for epoch in range(epoch_start, opt.num_epochs):

        # train
        model.train()
        total_loss = 0
        total_psnr = 0
        for i, data in enumerate(train_dataloader):
            with accelerator.accumulate(model):

                optimizer.zero_grad()

                # fraction of total training elapsed, forwarded to the model
                step_ratio = (epoch + i / len(train_dataloader)) / opt.num_epochs

                out = model(data, step_ratio)
                loss = out['loss']
                psnr = out['psnr']
                accelerator.backward(loss)

                # clip gradients only on steps where they are synchronized
                # (once per accumulation cycle)
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(model.parameters(), opt.gradient_clip)

                optimizer.step()
                scheduler.step()

                total_loss += loss.detach()
                total_psnr += psnr.detach()

            if accelerator.is_main_process:
                # console log every 10 iterations, with GPU memory usage
                if i % 10 == 0:
                    mem_free, mem_total = torch.cuda.mem_get_info()
                    print(f"[INFO] {i}/{len(train_dataloader)} mem: {(mem_total-mem_free)/1024**3:.2f}/{mem_total/1024**3:.2f}G lr: {scheduler.get_last_lr()[0]:.7f} step_ratio: {step_ratio:.4f} loss: {loss.item():.6f} time: {time.time() - end_time:.6f}")
                    end_time = time.time()
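
                # Every 500 iterations, dump GT/predicted renderings for a
                # quick visual check (video in 4D mode, image grid in 3D mode).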
                if i % 500 == 0:
                    if '4d' in opt.data_mode:
                        B, T, V = opt.batch_size, opt.num_frames, opt.num_views

                        # [B, T*V, C, H, W] -> [B, T, V, C, H, W]
                        gt_images = data['images_output'].reshape(B, T, V, *data['images_output'].shape[2:]).detach()
                        pred_images = out['images_pred'].reshape(B, T, V, *out['images_pred'].shape[2:]).detach()

                        # per frame, tile the V views side by side, then stack
                        # frames along the video time axis
                        train_gt_images = []
                        train_pred_images = []
                        for t in range(T):
                            train_gt_images_V = []
                            train_pred_images_V = []
                            for v in range(V):
                                train_gt_images_V.append((gt_images[:, t, v].permute(0,2,3,1).contiguous().float().cpu().numpy() * 255).astype(np.uint8))
                                train_pred_images_V.append((pred_images[:, t, v].permute(0,2,3,1).contiguous().float().cpu().numpy() * 255).astype(np.uint8))
                            train_gt_images.append(np.concatenate(train_gt_images_V, axis=2))
                            train_pred_images.append(np.concatenate(train_pred_images_V, axis=2))
                        train_gt_images = np.concatenate(train_gt_images, axis=0)
                        train_pred_images = np.concatenate(train_pred_images, axis=0)
                        imageio.mimwrite(f'{opt.workspace}/train_gt_images_{epoch}_{i}.mp4', train_gt_images, fps=8)
                        imageio.mimwrite(f'{opt.workspace}/train_pred_images_{epoch}_{i}.mp4', train_pred_images, fps=8)
                    elif '3d' in opt.data_mode:
                        # tile views into a [B*H, V*W, 3] grid (assumes square renders)
                        gt_images = data['images_output'].detach().cpu().numpy()
                        gt_images = gt_images.transpose(0, 3, 1, 4, 2).reshape(-1, gt_images.shape[1] * gt_images.shape[3], 3)
                        kiui.write_image(f'{opt.workspace}/train_gt_images_{epoch}_{i}.jpg', gt_images)

                        pred_images = out['images_pred'].detach().cpu().numpy()
                        pred_images = pred_images.transpose(0, 3, 1, 4, 2).reshape(-1, pred_images.shape[1] * pred_images.shape[3], 3)
                        kiui.write_image(f'{opt.workspace}/train_pred_images_{epoch}_{i}.jpg', pred_images)
                    else:
                        raise NotImplementedError
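
        # Reduce the running loss/PSNR across processes, then average over
        # batches for the epoch summary.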
        total_loss = accelerator.gather_for_metrics(total_loss).mean()
        total_psnr = accelerator.gather_for_metrics(total_psnr).mean()
        if accelerator.is_main_process:
            total_loss /= len(train_dataloader)
            total_psnr /= len(train_dataloader)
            accelerator.print(f"[train] epoch: {epoch} loss: {total_loss.item():.6f} psnr: {total_psnr.item():.4f}")
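
        # Checkpoint: model weights (plus a backup copy), optimizer, scheduler,
        # and epoch metadata so the run can auto-resume.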
        accelerator.wait_for_everyone()
        accelerator.save_model(model, opt.workspace)
        accelerator.save_model(model, os.path.join(opt.workspace, 'backup'))
        if accelerator.is_main_process:
            torch.save(optimizer.state_dict(), os.path.join(opt.workspace, 'optimizer.pth'))
            torch.save(scheduler.state_dict(), os.path.join(opt.workspace, 'scheduler.pth'))
            with open(f'{opt.workspace}/metadata.json', 'w') as f:
                json.dump({'epoch': epoch}, f)

            torch.save(optimizer.state_dict(), os.path.join(opt.workspace, 'backup', 'optimizer.pth'))
            torch.save(scheduler.state_dict(), os.path.join(opt.workspace, 'backup', 'scheduler.pth'))
            with open(f'{opt.workspace}/backup/metadata.json', 'w') as f:
                json.dump({'epoch': epoch}, f)
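
        # Evaluation pass: PSNR only, no gradients.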
        with torch.no_grad():
            model.eval()
            total_psnr = 0
            for i, data in enumerate(test_dataloader):
                out = model(data)
                psnr = out['psnr']
                total_psnr += psnr.detach()
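
                # Save eval visualizations from the main process for every
                # test batch.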
                if accelerator.is_main_process:
                    if '4d' in opt.data_mode:
                        B, T, V = opt.batch_size, opt.num_frames, opt.num_views

                        # batch dim is -1 here: with drop_last=False the last
                        # test batch can be smaller than opt.batch_size
                        gt_images = data['images_output'].reshape(-1, T, V, *data['images_output'].shape[2:]).detach()
                        pred_images = out['images_pred'].reshape(-1, T, V, *out['images_pred'].shape[2:]).detach()

                        eval_gt_images = []
                        eval_pred_images = []
                        for t in range(T):
                            eval_gt_images_V = []
                            eval_pred_images_V = []
                            for v in range(V):
                                eval_gt_images_V.append((gt_images[:, t, v].permute(0,2,3,1).contiguous().float().cpu().numpy() * 255).astype(np.uint8))
                                eval_pred_images_V.append((pred_images[:, t, v].permute(0,2,3,1).contiguous().float().cpu().numpy() * 255).astype(np.uint8))
                            eval_gt_images.append(np.concatenate(eval_gt_images_V, axis=2))
                            eval_pred_images.append(np.concatenate(eval_pred_images_V, axis=2))
                        eval_gt_images = np.concatenate(eval_gt_images, axis=0)
                        eval_pred_images = np.concatenate(eval_pred_images, axis=0)
                        imageio.mimwrite(f'{opt.workspace}/eval_gt_images_{epoch}_{i}.mp4', eval_gt_images, fps=8)
                        imageio.mimwrite(f'{opt.workspace}/eval_pred_images_{epoch}_{i}.mp4', eval_pred_images, fps=8)
                    elif '3d' in opt.data_mode:
                        gt_images = data['images_output'].detach().cpu().numpy()
                        gt_images = gt_images.transpose(0, 3, 1, 4, 2).reshape(-1, gt_images.shape[1] * gt_images.shape[3], 3)
                        kiui.write_image(f'{opt.workspace}/eval_gt_images_{epoch}_{i}.jpg', gt_images)

                        pred_images = out['images_pred'].detach().cpu().numpy()
                        pred_images = pred_images.transpose(0, 3, 1, 4, 2).reshape(-1, pred_images.shape[1] * pred_images.shape[3], 3)
                        kiui.write_image(f'{opt.workspace}/eval_pred_images_{epoch}_{i}.jpg', pred_images)
                    else:
                        raise NotImplementedError
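
                # Release cached GPU memory between eval batches.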
                torch.cuda.empty_cache()

            total_psnr = accelerator.gather_for_metrics(total_psnr).mean()
            if accelerator.is_main_process:
                total_psnr /= len(test_dataloader)
                accelerator.print(f"[eval] epoch: {epoch} psnr: {total_psnr.item():.4f}")


if __name__ == "__main__":
    main()