"""
Configurations can be overwritten by adding: key=value
Use debug.wandb=False to disable logging to wandb.
"""
import datetime
from datetime import timedelta
import os
import random
import socket
import time
from glob import glob
import hydra
import ipdb # noqa: F401
import numpy as np
import omegaconf
import torch
import wandb
from accelerate import Accelerator, DistributedDataParallelKwargs, InitProcessGroupKwargs
from pytorch3d.renderer import PerspectiveCameras
from diffusionsfm.dataset.co3d_v2 import Co3dDataset, unnormalize_image_for_vis
# from diffusionsfm.dataset.multiloader import get_multiloader, MultiDataset
from diffusionsfm.eval.eval_category import evaluate
from diffusionsfm.model.diffuser import RayDiffuser
from diffusionsfm.model.diffuser_dpt import RayDiffuserDPT
from diffusionsfm.model.scheduler import NoiseScheduler
from diffusionsfm.utils.rays import cameras_to_rays, normalize_cameras_batch, compute_ndc_coordinates
from diffusionsfm.utils.visualization import (
create_training_visualizations,
view_color_coded_images_from_tensor,
)

os.umask(0o000)  # Default to 777 permissions


class Trainer(object):
def __init__(self, cfg):
seed = cfg.training.seed
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
self.cfg = cfg
self.debug = cfg.debug
self.resume = cfg.training.resume
self.pretrain_path = cfg.training.pretrain_path
self.batch_size = cfg.training.batch_size
self.max_iterations = cfg.training.max_iterations
self.mixed_precision = cfg.training.mixed_precision
self.interval_visualize = cfg.training.interval_visualize
self.interval_save_checkpoint = cfg.training.interval_save_checkpoint
self.interval_delete_checkpoint = cfg.training.interval_delete_checkpoint
self.interval_evaluate = cfg.training.interval_evaluate
self.delete_all = cfg.training.delete_all_checkpoints_after_training
self.freeze_encoder = cfg.training.freeze_encoder
self.translation_scale = cfg.training.translation_scale
self.regression = cfg.training.regression
self.prob_unconditional = cfg.training.prob_unconditional
self.load_extra_cameras = cfg.training.load_extra_cameras
self.calculate_intrinsics = cfg.training.calculate_intrinsics
self.distort = cfg.training.distort
self.diffuse_origins_and_endpoints = cfg.training.diffuse_origins_and_endpoints
self.diffuse_depths = cfg.training.diffuse_depths
self.depth_resolution = cfg.training.depth_resolution
self.dpt_head = cfg.training.dpt_head
self.full_num_patches_x = cfg.training.full_num_patches_x
self.full_num_patches_y = cfg.training.full_num_patches_y
self.dpt_encoder_features = cfg.training.dpt_encoder_features
self.nearest_neighbor = cfg.training.nearest_neighbor
self.no_bg_targets = cfg.training.no_bg_targets
self.unit_normalize_scene = cfg.training.unit_normalize_scene
self.sd_scale = cfg.training.sd_scale
self.bfloat = cfg.training.bfloat
self.first_cam_mediod = cfg.training.first_cam_mediod
self.normalize_first_camera = cfg.training.normalize_first_camera
self.gradient_clipping = cfg.training.gradient_clipping
self.l1_loss = cfg.training.l1_loss
self.reinit = cfg.training.reinit
        if self.first_cam_mediod:
            assert self.normalize_first_camera, (
                "first_cam_mediod requires normalize_first_camera"
            )
self.pred_x0 = cfg.model.pred_x0
self.num_patches_x = cfg.model.num_patches_x
self.num_patches_y = cfg.model.num_patches_y
self.depth = cfg.model.depth
self.num_images = cfg.model.num_images
self.num_visualize = min(self.batch_size, 2)
self.random_num_images = cfg.model.random_num_images
self.feature_extractor = cfg.model.feature_extractor
self.append_ndc = cfg.model.append_ndc
self.use_homogeneous = cfg.model.use_homogeneous
self.freeze_transformer = cfg.model.freeze_transformer
self.cond_depth_mask = cfg.model.cond_depth_mask
self.dataset_name = cfg.dataset.name
self.shape = cfg.dataset.shape
self.apply_augmentation = cfg.dataset.apply_augmentation
self.mask_holes = cfg.dataset.mask_holes
self.image_size = cfg.dataset.image_size
if not self.regression and (self.diffuse_origins_and_endpoints or self.diffuse_depths):
assert self.mask_holes or self.cond_depth_mask
if self.regression:
assert self.pred_x0
self.start_time = None
self.iteration = 0
self.epoch = 0
self.wandb_id = None
self.hostname = socket.gethostname()
        # The DPT-head model can have parameters that receive no gradients in a
        # given step, which DDP must be told to expect.
        ddp_scaler = DistributedDataParallelKwargs(
            find_unused_parameters=self.dpt_head
        )
init_kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=5400))
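        # device_placement is handled manually: the model and each batch are
        # moved with .to(self.device) explicitly.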
self.accelerator = Accelerator(
even_batches=False,
device_placement=False,
kwargs_handlers=[ddp_scaler, init_kwargs],
)
self.device = self.accelerator.device
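        # Diffusion noise schedule shared by both model variants (it is also
        # plotted to wandb further below).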
scheduler = NoiseScheduler(
type=cfg.noise_scheduler.type,
max_timesteps=cfg.noise_scheduler.max_timesteps,
beta_start=cfg.noise_scheduler.beta_start,
beta_end=cfg.noise_scheduler.beta_end,
)
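        # Two model variants: RayDiffuserDPT adds a dense-prediction (DPT) head
        # over the full patch grid; the base RayDiffuser predicts per patch.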
if self.dpt_head:
self.model = RayDiffuserDPT(
depth=self.depth,
width=self.num_patches_x,
P=1,
max_num_images=self.num_images,
noise_scheduler=scheduler,
freeze_encoder=self.freeze_encoder,
feature_extractor=self.feature_extractor,
append_ndc=self.append_ndc,
use_unconditional=self.prob_unconditional > 0,
diffuse_depths=self.diffuse_depths,
depth_resolution=self.depth_resolution,
encoder_features=self.dpt_encoder_features,
use_homogeneous=self.use_homogeneous,
freeze_transformer=self.freeze_transformer,
cond_depth_mask=self.cond_depth_mask,
).to(self.device)
else:
self.model = RayDiffuser(
depth=self.depth,
width=self.num_patches_x,
P=1,
max_num_images=self.num_images,
noise_scheduler=scheduler,
freeze_encoder=self.freeze_encoder,
feature_extractor=self.feature_extractor,
append_ndc=self.append_ndc,
use_unconditional=self.prob_unconditional > 0,
diffuse_depths=self.diffuse_depths,
depth_resolution=self.depth_resolution,
use_homogeneous=self.use_homogeneous,
cond_depth_mask=self.cond_depth_mask,
).to(self.device)
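        # Resolution of the dense targets: the DPT head predicts on the full
        # patch grid; otherwise optionally upsampled by depth_resolution.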
if self.dpt_head:
depth_size = self.full_num_patches_x
elif self.depth_resolution > 1:
depth_size = self.num_patches_x * self.depth_resolution
else:
depth_size = self.num_patches_x
self.depth_size = depth_size
        if self.dataset_name == "multi":
            # Imported lazily so the co3d path does not depend on the
            # multiloader module (its top-level import is commented out above).
            from diffusionsfm.dataset.multiloader import get_multiloader

            self.dataset, self.train_dataloader, self.test_dataset = get_multiloader(
num_images=self.num_images,
apply_augmentation=self.apply_augmentation,
load_extra_cameras=self.load_extra_cameras,
distort_image=self.distort,
center_crop=self.diffuse_origins_and_endpoints or self.diffuse_depths,
crop_images=not (self.diffuse_origins_and_endpoints or self.diffuse_depths),
load_depths=self.diffuse_origins_and_endpoints or self.diffuse_depths,
depth_size=depth_size,
mask_holes=self.mask_holes,
img_size=self.image_size,
batch_size=self.batch_size,
num_workers=cfg.training.num_workers,
dust3r_pairs=True,
)
elif self.dataset_name == "co3d":
self.dataset = Co3dDataset(
category=self.shape,
split="train",
num_images=self.num_images,
apply_augmentation=self.apply_augmentation,
load_extra_cameras=self.load_extra_cameras,
distort_image=self.distort,
center_crop=self.diffuse_origins_and_endpoints or self.diffuse_depths,
crop_images=not (self.diffuse_origins_and_endpoints or self.diffuse_depths),
load_depths=self.diffuse_origins_and_endpoints or self.diffuse_depths,
depth_size=depth_size,
mask_holes=self.mask_holes,
img_size=self.image_size,
)
self.train_dataloader = torch.utils.data.DataLoader(
self.dataset,
batch_size=self.batch_size,
shuffle=True,
num_workers=cfg.training.num_workers,
pin_memory=True,
drop_last=True,
)
self.test_dataset = Co3dDataset(
category=self.shape,
split="test",
num_images=self.num_images,
apply_augmentation=False,
load_extra_cameras=self.load_extra_cameras,
distort_image=self.distort,
center_crop=self.diffuse_origins_and_endpoints or self.diffuse_depths,
crop_images=not (self.diffuse_origins_and_endpoints or self.diffuse_depths),
load_depths=self.diffuse_origins_and_endpoints or self.diffuse_depths,
depth_size=depth_size,
mask_holes=self.mask_holes,
img_size=self.image_size,
)
else:
raise NotImplementedError(f"Dataset '{self.dataset_name}' is not supported.")
        self.lr = 1e-4  # Base learning rate (hard-coded rather than configured)
self.output_dir = hydra.core.hydra_config.HydraConfig.get().runtime.output_dir
self.checkpoint_dir = os.path.join(self.output_dir, "checkpoints")
if self.accelerator.is_main_process:
name = os.path.basename(self.output_dir)
name += f"_{self.debug.run_name}"
print("Output dir:", self.output_dir)
with open(os.path.join(self.output_dir, name), "w"):
# Create empty tag with name
pass
self.name = name
conf_dict = omegaconf.OmegaConf.to_container(
cfg, resolve=True, throw_on_missing=True
)
conf_dict["output_dir"] = self.output_dir
conf_dict["hostname"] = self.hostname
if self.dpt_head:
self.init_optimizer_with_separate_lrs()
else:
self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
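        # A very large growth_interval keeps the loss scale nearly constant
        # once training is stable.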
self.gradscaler = torch.cuda.amp.GradScaler(growth_interval=100000, enabled=self.mixed_precision)
self.model, self.optimizer, self.train_dataloader = self.accelerator.prepare(
self.model, self.optimizer, self.train_dataloader
)
        if self.resume:
            checkpoint_files = sorted(glob(os.path.join(self.checkpoint_dir, "*.pth")))
            assert len(checkpoint_files) > 0, f"No checkpoints found in {self.checkpoint_dir}"
            last_checkpoint = checkpoint_files[-1]
print("Resuming from checkpoint:", last_checkpoint)
self.load_model(last_checkpoint, load_metadata=True)
elif self.pretrain_path != "":
print("Loading pretrained model:", self.pretrain_path)
self.load_model(self.pretrain_path, load_metadata=False)
if self.accelerator.is_main_process:
mode = "online" if cfg.debug.wandb else "disabled"
if self.wandb_id is None:
self.wandb_id = wandb.util.generate_id()
self.wandb_run = wandb.init(
mode=mode,
name=name,
project=cfg.debug.project_name,
config=conf_dict,
resume=self.resume,
id=self.wandb_id,
)
wandb.define_metric("iteration")
noise_schedule = self.get_module().noise_scheduler.plot_schedule(
return_image=True
)
wandb.log(
{"Schedule": wandb.Image(noise_schedule, caption="Noise Schedule")}
)

    def get_module(self):
if isinstance(self.model, torch.nn.parallel.DistributedDataParallel):
model = self.model.module
else:
model = self.model
return model

    def init_optimizer_with_separate_lrs(self):
        print("Using separate LRs for the DINOv2 encoder and the DiT!")
        feature_extractor_params = [
            p for _, p in self.model.feature_extractor.named_parameters()
        ]
        feature_extractor_param_names = [
            "feature_extractor." + n
            for n, _ in self.model.feature_extractor.named_parameters()
        ]
        ray_predictor_params = [
            p for _, p in self.model.ray_predictor.named_parameters()
        ]
        ray_predictor_param_names = [
            "ray_predictor." + n
            for n, _ in self.model.ray_predictor.named_parameters()
        ]
        other_params = [
            p for n, p in self.model.named_parameters()
            if n not in feature_extractor_param_names + ray_predictor_param_names
        ]
        self.optimizer = torch.optim.Adam([
            {"params": feature_extractor_params, "lr": self.lr * 0.1},  # lower LR for the feature extractor
            {"params": ray_predictor_params, "lr": self.lr * 0.1},  # lower LR for the DiT (ray_predictor)
            {"params": other_params, "lr": self.lr},  # base LR for the rest of the model
        ])

    def train(self):
while self.iteration < self.max_iterations:
for batch in self.train_dataloader:
t0 = time.time()
self.optimizer.zero_grad()
float_type = torch.bfloat16 if self.bfloat else torch.float16
with torch.cuda.amp.autocast(
enabled=self.mixed_precision, dtype=float_type
):
images = batch["image"].to(self.device)
focal_lengths = batch["focal_length"].to(self.device)
crop_params = batch["crop_parameters"].to(self.device)
principal_points = batch["principal_point"].to(self.device)
R = batch["R"].to(self.device)
T = batch["T"].to(self.device)
if "distortion_coefficients" in batch:
distortion_coefficients = batch["distortion_coefficients"]
else:
distortion_coefficients = [None for _ in range(R.shape[0])]
depths = batch["depth"].to(self.device)
                    if self.no_bg_targets or self.cond_depth_mask:
                        masks = batch["depth_masks"].to(self.device).bool()
                    else:
                        # Placeholder mask so the slicing and visualization code
                        # below still has a tensor when backgrounds are kept.
                        masks = torch.ones_like(depths).bool()
cameras_og = [
PerspectiveCameras(
focal_length=focal_lengths[b],
principal_point=principal_points[b],
R=R[b],
T=T[b],
device=self.device,
)
for b in range(self.batch_size)
]
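                    # Normalize the cameras (and, when predicting geometry, the
                    # depths) into a canonical scene frame.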
cameras, _ = normalize_cameras_batch(
cameras=cameras_og,
scale=self.translation_scale,
normalize_first_camera=self.normalize_first_camera,
depths=(
None
if not (self.diffuse_origins_and_endpoints or self.diffuse_depths)
else depths
),
first_cam_mediod=self.first_cam_mediod,
crop_parameters=crop_params,
num_patches_x=self.depth_size,
num_patches_y=self.depth_size,
distortion_coeffs=distortion_coefficients,
)
# Now that cameras are normalized, fix shapes of camera parameters
if self.load_extra_cameras or self.random_num_images:
if self.random_num_images:
num_images = torch.randint(2, self.num_images + 1, (1,))
else:
num_images = self.num_images
# The correct number of images is already loaded.
# Only need to modify these camera parameters shapes.
focal_lengths = focal_lengths[:, :num_images]
crop_params = crop_params[:, :num_images]
R = R[:, :num_images]
T = T[:, :num_images]
images = images[:, :num_images]
depths = depths[:, :num_images]
masks = masks[:, :num_images]
cameras = [
PerspectiveCameras(
focal_length=cameras[b].focal_length[:num_images],
principal_point=cameras[b].principal_point[:num_images],
R=cameras[b].R[:num_images],
T=cameras[b].T[:num_images],
device=self.device,
)
for b in range(self.batch_size)
]
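                    # Sample a diffusion timestep per example; regression mode
                    # pins t to the final (noisiest) timestep.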
if self.regression:
low = self.get_module().noise_scheduler.max_timesteps - 1
else:
low = 0
t = torch.randint(
low=low,
high=self.get_module().noise_scheduler.max_timesteps,
size=(self.batch_size,),
device=self.device,
)
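                    # Occasionally drop conditioning to support unconditional
                    # (classifier-free-guidance-style) training.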
if self.prob_unconditional > 0:
unconditional_mask = (
(torch.rand(self.batch_size) < self.prob_unconditional)
.float()
.to(self.device)
)
else:
unconditional_mask = None
                    if self.distort:
                        raise NotImplementedError(
                            "Training with image distortion is not implemented."
                        )
else:
gt_rays = []
rays_dirs = []
rays = []
for i, (camera, crop_param, depth) in enumerate(
zip(cameras, crop_params, depths)
):
                            mode = (
                                "segment"
                                if self.diffuse_origins_and_endpoints
                                else "plucker"
                            )
r = cameras_to_rays(
cameras=camera,
num_patches_x=self.full_num_patches_x,
num_patches_y=self.full_num_patches_y,
crop_parameters=crop_param,
depths=depth,
mode=mode,
depth_resolution=self.depth_resolution,
nearest_neighbor=self.nearest_neighbor,
distortion_coefficients=distortion_coefficients[i],
)
rays_dirs.append(r.get_directions())
gt_rays.append(r)
if self.diffuse_origins_and_endpoints:
assert r.mode == "segment"
elif self.diffuse_depths:
assert r.mode == "plucker"
                            if self.unit_normalize_scene:
                                # Rescale the scene so the diffusion targets have
                                # std of 1 / sd_scale (e.g. 0.5 when sd_scale is 2).
                                if self.diffuse_origins_and_endpoints:
                                    assert r.mode == "segment"
                                    scale = r.get_segments().std() * self.sd_scale
                                    assert not scale.isnan().any(), "NaN scene scale"
                                    camera.T /= scale
                                    r.rays /= scale
                                    depths[i] /= scale
                                else:
                                    assert r.mode == "plucker"
                                    scale = r.depths.std() * self.sd_scale
                                    assert not scale.isnan().any(), "NaN scene scale"
                                    camera.T /= scale
                                    r.depths /= scale
                                    depths[i] /= scale
rays.append(
r.to_spatial(
include_ndc_coordinates=self.append_ndc,
include_depths=self.diffuse_depths,
use_homogeneous=self.use_homogeneous,
)
)
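                        # Stack the per-example spatial ray encodings into a
                        # single batch tensor.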
rays_tensor = torch.stack(rays, dim=0)
if self.append_ndc:
ndc_coordinates = rays_tensor[..., -2:, :, :]
rays_tensor = rays_tensor[..., :-2, :, :]
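                        # The DPT head takes NDC coordinates on the coarser
                        # transformer patch grid (1/16 of the dense resolution).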
if self.dpt_head:
xy_grid = compute_ndc_coordinates(
crop_params,
num_patches_x=self.depth_size // 16,
num_patches_y=self.depth_size // 16,
distortion_coeffs=distortion_coefficients,
)[..., :2]
ndc_coordinates = xy_grid.permute(0, 1, 4, 2, 3).contiguous()
else:
ndc_coordinates = None
if self.cond_depth_mask:
condition_mask = masks
else:
condition_mask = None
                    if rays_tensor.isnan().any():
                        # Dump the offending batch for offline debugging.
                        import pickle

                        with open("bad_batch.pkl", "wb") as f:
                            pickle.dump(batch, f)
                        ipdb.set_trace()
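                    # Forward pass: the model predicts eps (or x0 when pred_x0
                    # is set) from the noisy rays and image features.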
eps_pred, eps = self.model(
images=images,
rays=rays_tensor,
t=t,
ndc_coordinates=ndc_coordinates,
unconditional_mask=unconditional_mask,
depth_mask=condition_mask,
)
if self.pred_x0:
target = rays_tensor
else:
target = eps
if self.no_bg_targets:
C = eps_pred.shape[2]
loss_masks = masks.unsqueeze(2).repeat(1, 1, C, 1, 1)
eps_pred = loss_masks * eps_pred
target = loss_masks * target
loss = 0
if self.l1_loss:
loss_reconstruction = torch.mean(torch.abs(eps_pred - target))
else:
loss_reconstruction = torch.mean((eps_pred - target) ** 2)
loss += loss_reconstruction
                # Gradient norms are tracked for logging; default to 0 so both
                # branches (and the no-clipping case) can be logged.
                scaled_norm = unscaled_norm = clipped_norm = 0.0

                def grad_norm():
                    # L2 norm over all gradients currently on the parameters.
                    total = 0.0
                    for p in self.model.parameters():
                        if p.requires_grad and p.grad is not None:
                            total += p.grad.data.norm(2).item() ** 2
                    return total ** 0.5

                if self.mixed_precision:
                    self.gradscaler.scale(loss).backward()
                    scaled_norm = grad_norm()
                    if self.gradient_clipping and self.accelerator.sync_gradients:
                        # Note: clipping happens on the still-scaled gradients.
                        self.accelerator.clip_grad_norm_(
                            self.get_module().parameters(), 1
                        )
                        clipped_norm = grad_norm()
                    self.gradscaler.unscale_(self.optimizer)
                    unscaled_norm = grad_norm()
                    self.gradscaler.step(self.optimizer)
                    self.gradscaler.update()
                else:
                    self.accelerator.backward(loss)
                    if self.gradient_clipping and self.accelerator.sync_gradients:
                        self.accelerator.clip_grad_norm_(
                            self.get_module().parameters(), 10
                        )
                    self.optimizer.step()
if self.accelerator.is_main_process:
if self.iteration % 10 == 0:
self.log_info(
loss_reconstruction,
t0,
self.lr,
scaled_norm,
unscaled_norm,
clipped_norm,
)
if self.iteration % self.interval_visualize == 0:
self.visualize(
images=unnormalize_image_for_vis(images.clone()),
cameras_gt=cameras,
depths=depths,
crop_parameters=crop_params,
distortion_coefficients=distortion_coefficients,
depth_mask=masks,
)
if self.iteration % self.interval_save_checkpoint == 0 and self.iteration != 0:
self.save_model()
if self.iteration % self.interval_delete_checkpoint == 0:
self.clear_old_checkpoints(self.checkpoint_dir)
if (
self.iteration % self.interval_evaluate == 0
and self.iteration > 0
):
self.evaluate_train_acc()
if self.iteration >= self.max_iterations + 1:
if self.delete_all:
self.clear_old_checkpoints(
self.checkpoint_dir, clear_all_old=True
)
return
self.iteration += 1
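                # Optionally re-initialize training mid-run with all modules
                # unfrozen, reusing the current weights (DPT variant assumed).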
if self.reinit and self.iteration >= 50000:
state_dict = self.get_module().state_dict()
self.model = RayDiffuserDPT(
depth=self.depth,
width=self.num_patches_x,
P=1,
max_num_images=self.num_images,
noise_scheduler=self.get_module().noise_scheduler,
freeze_encoder=False,
feature_extractor=self.feature_extractor,
append_ndc=self.append_ndc,
use_unconditional=self.prob_unconditional > 0,
diffuse_depths=self.diffuse_depths,
depth_resolution=self.depth_resolution,
encoder_features=self.dpt_encoder_features,
use_homogeneous=self.use_homogeneous,
freeze_transformer=False,
cond_depth_mask=self.cond_depth_mask,
).to(self.device)
self.init_optimizer_with_separate_lrs()
self.gradscaler = torch.cuda.amp.GradScaler(growth_interval=100000, enabled=self.mixed_precision)
self.model, self.optimizer = self.accelerator.prepare(
self.model, self.optimizer
)
msg = self.get_module().load_state_dict(
state_dict,
strict=True,
)
print(msg)
self.reinit = False
self.epoch += 1

    def load_model(self, path, load_metadata=True):
        save_dict = torch.load(path, map_location=self.device)
        # Drop the image positional table (its size depends on the number of
        # images) so checkpoints trained with a different count can be loaded.
        save_dict["state_dict"].pop("ray_predictor.x_pos_enc.image_pos_table", None)
if not self.resume:
            if self.dpt_head and len(save_dict["state_dict"]["scratch.input_conv.weight"].shape) == 2:
                print("Initializing conv layer weights from the linear layer!")
                # Replicate the linear weights over the 16x16 kernel and divide
                # by the kernel area, so the conv applies the linear layer to
                # the patchwise mean of its input.
                C = save_dict["state_dict"]["scratch.input_conv.weight"].shape[1]
                input_conv_weight = (
                    save_dict["state_dict"]["scratch.input_conv.weight"]
                    .view(384, C, 1, 1)
                    .repeat(1, 1, 16, 16)
                    / 256.0
                )
                input_conv_bias = save_dict["state_dict"]["scratch.input_conv.bias"]
                self.get_module().scratch.input_conv.weight.data = input_conv_weight
                self.get_module().scratch.input_conv.bias.data = input_conv_bias
                del save_dict["state_dict"]["scratch.input_conv.weight"]
                del save_dict["state_dict"]["scratch.input_conv.bias"]
missing, unexpected = self.get_module().load_state_dict(
save_dict["state_dict"],
strict=False,
)
print(f"Missing keys: {missing}")
print(f"Unexpected keys: {unexpected}")
if load_metadata:
self.iteration = save_dict["iteration"]
self.epoch = save_dict["epoch"]
time_elapsed = save_dict["elapsed"]
self.start_time = time.time() - time_elapsed
if "wandb_id" in save_dict:
self.wandb_id = save_dict["wandb_id"]
self.optimizer.load_state_dict(save_dict["optimizer"])
self.gradscaler.load_state_dict(save_dict["gradscaler"])

    def save_model(self):
path = os.path.join(self.checkpoint_dir, f"ckpt_{self.iteration:08d}.pth")
os.makedirs(os.path.dirname(path), exist_ok=True)
elapsed = time.time() - self.start_time if self.start_time is not None else 0
save_dict = {
"epoch": self.epoch,
"elapsed": elapsed,
"gradscaler": self.gradscaler.state_dict(),
"iteration": self.iteration,
"state_dict": self.get_module().state_dict(),
"optimizer": self.optimizer.state_dict(),
"wandb_id": self.wandb_id,
}
torch.save(save_dict, path)

    def clear_old_checkpoints(self, checkpoint_dir, clear_all_old=False):
        # Keep only the latest checkpoint (clear_all_old) or those saved at
        # multiples of interval_delete_checkpoint.
        print("Clearing old checkpoints")
checkpoint_files = sorted(glob(os.path.join(checkpoint_dir, "ckpt_*.pth")))
if clear_all_old:
for checkpoint_file in checkpoint_files[:-1]:
os.remove(checkpoint_file)
else:
for checkpoint_file in checkpoint_files:
checkpoint = os.path.basename(checkpoint_file)
checkpoint_iteration = int("".join(filter(str.isdigit, checkpoint)))
if checkpoint_iteration % self.interval_delete_checkpoint != 0:
os.remove(checkpoint_file)

    def log_info(
self,
loss,
t0,
lr,
scaled_norm,
unscaled_norm,
clipped_norm,
):
if self.start_time is None:
self.start_time = time.time()
time_elapsed = round(time.time() - self.start_time)
time_remaining = round(
(time.time() - self.start_time)
/ (self.iteration + 1)
* (self.max_iterations - self.iteration)
)
disp = [
f"Iter: {self.iteration}/{self.max_iterations}",
f"Epoch: {self.epoch}",
f"Loss: {loss.item():.4f}",
f"LR: {lr:.7f}",
f"Grad Norm: {scaled_norm:.4f}/{unscaled_norm:.4f}/{clipped_norm:.4f}",
f"Elap: {str(datetime.timedelta(seconds=time_elapsed))}",
f"Rem: {str(datetime.timedelta(seconds=time_remaining))}",
self.hostname,
self.name,
]
print(", ".join(disp), flush=True)
wandb_log = {
"loss": loss.item(),
"iter_time": time.time() - t0,
"lr": lr,
"iteration": self.iteration,
"hours_remaining": time_remaining / 3600,
"gradient norm": scaled_norm,
"unscaled norm": unscaled_norm,
"clipped norm": clipped_norm,
}
wandb.log(wandb_log)

    def visualize(
self,
images,
cameras_gt,
crop_parameters=None,
depths=None,
distortion_coefficients=None,
depth_mask=None,
high_loss=False,
):
self.get_module().eval()
for camera in cameras_gt:
# AMP may not cast back to float
camera.R = camera.R.float()
camera.T = camera.T.float()
loss_tag = "" if not high_loss else " HIGH LOSS"
for i in range(self.num_visualize):
imgs = view_color_coded_images_from_tensor(images[i].cpu(), depth=False)
im = wandb.Image(imgs, caption=f"iteration {self.iteration} example {i}")
wandb.log({f"Vis images {i}{loss_tag}": im})
if self.cond_depth_mask:
imgs = view_color_coded_images_from_tensor(
depth_mask[i].cpu(), depth=True
)
im = wandb.Image(
imgs, caption=f"iteration {self.iteration} example {i}"
)
wandb.log({f"Vis masks {i}{loss_tag}": im})
vis_depths, _, _ = create_training_visualizations(
model=self.get_module(),
images=images[: self.num_visualize],
device=self.device,
cameras_gt=cameras_gt,
pred_x0=self.pred_x0,
num_images=images.shape[1],
crop_parameters=crop_parameters[: self.num_visualize],
visualize_pred=self.regression,
return_first=self.regression,
calculate_intrinsics=self.calculate_intrinsics,
mode="segment" if self.diffuse_origins_and_endpoints else "plucker",
depths=depths[: self.num_visualize],
diffuse_depths=self.diffuse_depths,
full_num_patches_x=self.full_num_patches_x,
full_num_patches_y=self.full_num_patches_y,
use_homogeneous=self.use_homogeneous,
distortion_coefficients=distortion_coefficients,
)
        for i, vis_image in enumerate(vis_depths):
            im = wandb.Image(
                vis_image, caption=f"iteration {self.iteration} example {i}"
            )
            wandb.log({f"Vis origins and endpoints {i}{loss_tag}": im})
self.get_module().train()

    def evaluate_train_acc(self, num_evaluate=10):
print("Evaluating train accuracy")
model = self.get_module()
model.eval()
additional_timesteps = [0, 10, 20, 30, 40, 50, 60, 70, 80, 90]
num_images = self.num_images
for split in ["train", "test"]:
if split == "train":
if self.dataset_name != "co3d":
to_evaluate = self.dataset.datasets
names = self.dataset.names
else:
to_evaluate = [self.dataset]
names = ["co3d"]
elif split == "test":
if self.dataset_name != "co3d":
to_evaluate = self.test_dataset.datasets
names = self.test_dataset.names
else:
to_evaluate = [self.test_dataset]
names = ["co3d"]
for name, dataset in zip(names, to_evaluate):
results = evaluate(
cfg=self.cfg,
model=model,
dataset=dataset,
num_images=num_images,
device=self.device,
additional_timesteps=additional_timesteps,
num_evaluate=num_evaluate,
use_pbar=True,
mode="segment" if self.diffuse_origins_and_endpoints else "plucker",
metrics=False,
)
R_err = []
CC_err = []
for key in results.keys():
R_err.append([v["R_error"] for v in results[key]])
CC_err.append([v["CC_error"] for v in results[key]])
R_err = np.array(R_err)
CC_err = np.array(CC_err)
R_acc_15 = np.mean(R_err < 15, (0, 2)).max()
CC_acc = np.mean(CC_err < 0.1, (0, 2)).max()
wandb.log(
{
f"R_acc_15_{name}_{split}": R_acc_15,
"iteration": self.iteration,
}
)
wandb.log(
{
f"CC_acc_0.1_{name}_{split}": CC_acc,
"iteration": self.iteration,
}
)
model.train()


@hydra.main(config_path="./conf", config_name="config", version_base="1.3")
def main(cfg):
print(cfg)
torch.autograd.set_detect_anomaly(cfg.debug.anomaly_detection)
torch.set_float32_matmul_precision(cfg.training.matmul_precision)
trainer = Trainer(cfg=cfg)
trainer.train()


if __name__ == "__main__":
    main()