|
|
|
|
|
|
|
|
|
|
|
|
|
from enum import Enum |
|
|
|
from PIL.ImageOps import scale |
|
from matplotlib.scale import scale_factory |
|
from wandb.wandb_torch import torch |
|
|
|
from .optimizer import PointCloudOptimizer |
|
from .modular_optimizer import ModularPointCloudOptimizer |
|
from .pair_viewer import PairViewer |
|
from ..viz import pts3d_to_trimesh |
|
|
|
|
|
class GlobalAlignerMode(Enum):
    """Selects which optimizer class ``global_aligner`` constructs.

    Each member's value is the string spelling of its own name, which is also
    the name of the class it selects.
    """

    # builds PointCloudOptimizer (the only mode that receives the mono-depth arguments)
    PointCloudOptimizer = "PointCloudOptimizer"

    # builds ModularPointCloudOptimizer
    ModularPointCloudOptimizer = "ModularPointCloudOptimizer"

    # builds PairViewer
    PairViewer = "PairViewer"
|
|
|
import torch.nn.functional as F |
|
|
|
def global_aligner(dust3r_output, if_use_mono, mono_depths, device,
                   mode=GlobalAlignerMode.PointCloudOptimizer, **optim_kw):
    """Construct a global-alignment optimizer from a DUSt3R output container.

    Args:
        dust3r_output: Subscriptable container (e.g. a dict) holding the
            entries ``'view1'``, ``'view2'``, ``'pred1'`` and ``'pred2'``.
            These are forwarded positionally, in that order, to the selected
            optimizer class.
        if_use_mono: Forwarded to ``PointCloudOptimizer`` only; ignored by the
            other modes.
        mono_depths: Forwarded to ``PointCloudOptimizer`` only; ignored by the
            other modes.
        device: Argument passed to ``.to(...)`` on the constructed optimizer.
        mode: A ``GlobalAlignerMode`` member, or its string value
            (e.g. ``"PairViewer"``), choosing which optimizer to build.
        **optim_kw: Extra keyword arguments forwarded to the optimizer
            constructor.

    Returns:
        The result of calling ``.to(device)`` on the constructed optimizer.

    Raises:
        KeyError: If ``dust3r_output`` lacks one of the four required entries.
        NotImplementedError: If ``mode`` does not correspond to a known
            ``GlobalAlignerMode``.
    """
    view1, view2, pred1, pred2 = [dust3r_output[k] for k in 'view1 view2 pred1 pred2'.split()]

    # Normalise the mode so plain string values (e.g. read from a CLI flag or
    # config file) work as well as enum members. Enum members map to
    # themselves; unknown values keep raising NotImplementedError with the
    # same message as before.
    try:
        mode = GlobalAlignerMode(mode)
    except ValueError:
        raise NotImplementedError(f'Unknown mode {mode}') from None

    if mode == GlobalAlignerMode.PointCloudOptimizer:
        net = PointCloudOptimizer(view1, view2, pred1, pred2, if_use_mono, mono_depths, **optim_kw).to(device)
    elif mode == GlobalAlignerMode.ModularPointCloudOptimizer:
        # if_use_mono / mono_depths are intentionally not forwarded here
        net = ModularPointCloudOptimizer(view1, view2, pred1, pred2, **optim_kw).to(device)
    elif mode == GlobalAlignerMode.PairViewer:
        # if_use_mono / mono_depths are intentionally not forwarded here
        net = PairViewer(view1, view2, pred1, pred2, **optim_kw).to(device)
    else:
        # Reached only if GlobalAlignerMode gains a member not handled above.
        raise NotImplementedError(f'Unknown mode {mode}')

    return net
|
|