from submodules.mast3r.dust3r.dust3r.losses import *
from torchmetrics import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure, JaccardIndex, Accuracy
import lpips
from src.utils.gaussian_model import GaussianModel
from src.utils.cuda_splatting import render, DummyPipeline
from einops import rearrange
from src.utils.camera_utils import get_scaled_camera
from torchvision.utils import save_image
from dust3r.inference import make_batch_symmetric


class L2Loss (LLoss):
    """ Euclidean distance between 3d points """

    def distance(self, a, b):
        return torch.norm(a - b, dim=-1)  # L2 distance, one value per point


class L1Loss (LLoss):
    """ Manhattan distance between 3d points """

    def distance(self, a, b):
        return torch.abs(a - b).sum(dim=-1)  # L1 distance, one value per point (reduce only the xyz axis)


L2 = L2Loss()
L1 = L1Loss()
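

# Minimal shape sanity check (illustrative only, not used by the training code):
# both distance functions are expected to reduce only the last (xyz) axis, so
# they return one value per 3D point.
def _example_pointwise_distance():
    a = torch.zeros(2, 8, 8, 3)
    b = torch.ones(2, 8, 8, 3)
    assert L2.distance(a, b).shape == (2, 8, 8)
    assert L1.distance(a, b).shape == (2, 8, 8)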


def merge_and_split_predictions(pred1, pred2):
    """Stack the two per-view prediction dicts along a new view axis, flatten the view
    and spatial dimensions together, and return one merged prediction dict per batch element."""
    merged = {}
    for key in pred1.keys():
        merged_pred = torch.stack([pred1[key], pred2[key]], dim=1)
        merged_pred = rearrange(merged_pred, 'b v h w ... -> b (v h w) ...')
        merged[key] = merged_pred

    # Split along the batch dimension
    batch_size = next(iter(merged.values())).shape[0]
    split = [{key: value[i] for key, value in merged.items()} for i in range(batch_size)]

    return split
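

# Illustrative sketch (hypothetical shapes, not part of the pipeline): assuming
# each prediction tensor is laid out as (B, H, W, ...), the merged output holds
# both views flattened into a single point list per batch element.
def _example_merge_and_split():
    B, H, W = 2, 4, 4
    pred1 = {'means': torch.zeros(B, H, W, 3)}
    pred2 = {'means': torch.zeros(B, H, W, 3)}
    per_sample = merge_and_split_predictions(pred1, pred2)
    assert len(per_sample) == B
    assert per_sample[0]['means'].shape == (2 * H * W, 3)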


class GaussianLoss(MultiLoss):
    def __init__(self, ssim_weight=0.2):
        super().__init__()
        self.ssim_weight = ssim_weight
        self.ssim = StructuralSimilarityIndexMeasure(data_range=1.0).cuda()
        self.psnr = PeakSignalNoiseRatio(data_range=1.0).cuda()
        self.lpips_vgg = lpips.LPIPS(net='vgg').cuda()
        self.pipeline = DummyPipeline()
        # bg_color
        self.register_buffer('bg_color', torch.tensor([0.0, 0.0, 0.0]).cuda())

    def get_name(self):
        return f'GaussianLoss(ssim_weight={self.ssim_weight})'

    # Earlier variant, kept for reference: renders only `target_view` and weights the feature loss by 100.
    # def compute_loss(self, gt1, gt2, target_view, pred1, pred2, model):
    #     # render images
    #     # 1. merge predictions
    #     pred = merge_and_split_predictions(pred1, pred2)
    #
    #     # 2. calculate optimal scaling
    #     pred_pts1 = pred1['means']
    #     pred_pts2 = pred2['means']
    #     # convert to camera1 coordinates
    #     # everything is normalized w.r.t. camera of view1
    #     valid1 = gt1['valid_mask'].clone()
    #     valid2 = gt2['valid_mask'].clone()
    #     in_camera1 = inv(gt1['camera_pose'])
    #     gt_pts1 = geotrf(in_camera1, gt1['pts3d'].to(in_camera1.device))  # B,H,W,3
    #     gt_pts2 = geotrf(in_camera1, gt2['pts3d'].to(in_camera1.device))  # B,H,W,3
    #     scaling = find_opt_scaling(gt_pts1, gt_pts2, pred_pts1, pred_pts2, valid1=valid1, valid2=valid2)
    #
    #     # 3. render images (need gaussian model, camera, pipeline)
    #     rendered_images = []
    #     rendered_feats = []
    #     for i in range(len(pred)):
    #         # get gaussian model
    #         gaussians = GaussianModel.from_predictions(pred[i], sh_degree=3)
    #         # get camera
    #         ref_camera_extrinsics = gt1['camera_pose'][i]
    #         target_extrinsics = target_view['camera_pose'][i]
    #         target_intrinsics = target_view['camera_intrinsics'][i]
    #         image_shape = target_view['true_shape'][i]
    #         scale = scaling[i]
    #         camera = get_scaled_camera(ref_camera_extrinsics, target_extrinsics, target_intrinsics, scale, image_shape)
    #         # render (image and features)
    #         rendered_output = render(camera, gaussians, self.pipeline, self.bg_color)
    #         rendered_images.append(rendered_output['render'])
    #         rendered_feats.append(rendered_output['feature_map'])
    #     rendered_images = torch.stack(rendered_images, dim=0)  # B, 3, H, W
    #     rendered_feats = torch.stack(rendered_feats, dim=0)  # B, d_feats, H, W
    #     rendered_feats = model.feature_expansion(rendered_feats)  # B, 512, H//2, W//2
    #
    #     gt_images = target_view['img'] * 0.5 + 0.5
    #     gt_feats = model.lseg_feature_extractor.extract_features(target_view['img'])  # B, 512, H//2, W//2
    #
    #     image_loss = torch.abs(rendered_images - gt_images).mean()
    #     feature_loss = torch.abs(rendered_feats - gt_feats).mean()
    #     loss = image_loss + 100 * feature_loss
    #
    #     # # temp
    #     # gt_logits = model.lseg_feature_extractor.decode_feature(gt_feats, ['wall', 'floor', 'others'])
    #     # gt_labels = torch.argmax(gt_logits, dim=1, keepdim=True)
    #     # rendered_logits = model.lseg_feature_extractor.decode_feature(rendered_feats, ['wall', 'floor', 'others'])
    #     # rendered_labels = torch.argmax(rendered_logits, dim=1, keepdim=True)
    #
    #     # calculate metrics
    #     with torch.no_grad():
    #         ssim = self.ssim(rendered_images, gt_images)
    #         psnr = self.psnr(rendered_images, gt_images)
    #         lpips_score = self.lpips_vgg(rendered_images, gt_images).mean()
    #
    #     return loss, {'ssim': ssim, 'psnr': psnr, 'lpips': lpips_score, 'image_loss': image_loss, 'feature_loss': feature_loss}

    def compute_loss(self, gt1, gt2, target_view, pred1, pred2, model):
        # render images
        # 1. merge predictions
        pred = merge_and_split_predictions(pred1, pred2)

        # 2. calculate optimal scaling
        pred_pts1 = pred1['means']
        pred_pts2 = pred2['means']
        # convert to camera1 coordinates
        # everything is normalized w.r.t. camera of view1
        valid1 = gt1['valid_mask'].clone()
        valid2 = gt2['valid_mask'].clone()
        in_camera1 = inv(gt1['camera_pose'])
        gt_pts1 = geotrf(in_camera1, gt1['pts3d'].to(in_camera1.device))  # B,H,W,3
        gt_pts2 = geotrf(in_camera1, gt2['pts3d'].to(in_camera1.device))  # B,H,W,3
        scaling = find_opt_scaling(gt_pts1, gt_pts2, pred_pts1, pred_pts2, valid1=valid1, valid2=valid2)

        # 3. render images (need gaussian model, camera, pipeline)
        rendered_images = []
        rendered_feats = []
        gt_images = []
        for i in range(len(pred)):
            # get gaussian model
            gaussians = GaussianModel.from_predictions(pred[i], sh_degree=3)
            # reference camera of view1; render from view1, view2, and the held-out target view
            ref_camera_extrinsics = gt1['camera_pose'][i]
            target_view_list = [gt1, gt2, target_view]
            for j in range(len(target_view_list)):
                target_extrinsics = target_view_list[j]['camera_pose'][i]
                target_intrinsics = target_view_list[j]['camera_intrinsics'][i]
                image_shape = target_view_list[j]['true_shape'][i]
                scale = scaling[i]
                camera = get_scaled_camera(ref_camera_extrinsics, target_extrinsics, target_intrinsics, scale, image_shape)
                # render (image and features)
                rendered_output = render(camera, gaussians, self.pipeline, self.bg_color)
                rendered_images.append(rendered_output['render'])
                rendered_feats.append(rendered_output['feature_map'])
                gt_images.append(target_view_list[j]['img'][i] * 0.5 + 0.5)  # [-1, 1] -> [0, 1]

        rendered_images = torch.stack(rendered_images, dim=0)     # (B*V), 3, H, W
        gt_images = torch.stack(gt_images, dim=0)                 # (B*V), 3, H, W
        rendered_feats = torch.stack(rendered_feats, dim=0)       # (B*V), d_feats, H, W
        rendered_feats = model.feature_expansion(rendered_feats)  # (B*V), 512, H//2, W//2
        gt_feats = model.lseg_feature_extractor.extract_features(gt_images)  # (B*V), 512, H//2, W//2

        image_loss = torch.abs(rendered_images - gt_images).mean()
        feature_loss = torch.abs(rendered_feats - gt_feats).mean()
        loss = image_loss + feature_loss

        # calculate metrics
        with torch.no_grad():
            ssim = self.ssim(rendered_images, gt_images)
            psnr = self.psnr(rendered_images, gt_images)
            lpips_score = self.lpips_vgg(rendered_images, gt_images).mean()  # avoid shadowing the lpips module

        return loss, {'ssim': ssim, 'psnr': psnr, 'lpips': lpips_score, 'image_loss': image_loss, 'feature_loss': feature_loss}
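

# Minimal usage sketch (illustrative only, not called anywhere in this module):
# GaussianLoss follows dust3r's MultiLoss interface, which is assumed here to
# return a (loss, details) pair when the criterion is called. A CUDA device is
# required because the metric modules above are created with .cuda().
def _example_gaussian_loss_call(gt1, gt2, target_view, pred1, pred2, model):
    criterion = GaussianLoss(ssim_weight=0.2)
    loss, metrics = criterion(gt1, gt2, target_view, pred1, pred2, model)
    # metrics keys: 'ssim', 'psnr', 'lpips', 'image_loss', 'feature_loss'
    return loss, metrics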


# loss for one batch
def loss_of_one_batch(batch, model, criterion, device, symmetrize_batch=False, use_amp=False, ret=None):
    view1, view2, target_view = batch
    ignore_keys = set(['depthmap', 'dataset', 'label', 'instance', 'idx', 'true_shape', 'rng', 'pts3d'])
    for view in batch:
        for name in view.keys():  # pseudo_focal
            if name in ignore_keys:
                continue
            view[name] = view[name].to(device, non_blocking=True)

    if symmetrize_batch:
        # make_batch_symmetric expects only the two input views, not the target view
        view1, view2 = make_batch_symmetric((view1, view2))

    # Get the actual model if it's distributed
    actual_model = model.module if hasattr(model, 'module') else model

    with torch.cuda.amp.autocast(enabled=bool(use_amp)):
        pred1, pred2 = actual_model(view1, view2)

        # loss is supposed to be symmetric
        with torch.cuda.amp.autocast(enabled=False):
            loss = criterion(view1, view2, target_view, pred1, pred2, actual_model) if criterion is not None else None

    result = dict(view1=view1, view2=view2, target_view=target_view, pred1=pred1, pred2=pred2, loss=loss)
    return result[ret] if ret else result
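

# Illustrative training-step sketch (hypothetical helper, not part of the
# original training loop): assumes a dataloader yielding (view1, view2,
# target_view) dict triplets and that the criterion returns a (loss, metrics)
# pair, as GaussianLoss.compute_loss does above.
def _example_training_step(batch, model, criterion, optimizer, device='cuda'):
    out = loss_of_one_batch(batch, model, criterion, device, use_amp=False)
    loss, metrics = out['loss']
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return float(loss), {k: float(v) for k, v in metrics.items()}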