| from typing import Any |
| import torch.nn as nn |
|
|
| from .backbone import Backbone |
| from .backbone_croco_multiview import AsymmetricCroCoMulti |
| from .backbone_dino import BackboneDino, BackboneDinoCfg |
| from .backbone_resnet import BackboneResnet, BackboneResnetCfg |
| from .backbone_croco import AsymmetricCroCo, BackboneCrocoCfg |
|
|
# Registry mapping a backbone config's `name` field to the class that
# implements it. The values are classes, not instances: `get_backbone`
# constructs one by calling `BACKBONES[cfg.name](cfg, d_in)`.
BACKBONES: dict[str, type[Backbone[Any]]] = {
    "resnet": BackboneResnet,
    "dino": BackboneDino,
    "croco": AsymmetricCroCo,
    "croco_multi": AsymmetricCroCoMulti,
}


# Union of the config types accepted by `get_backbone`.
# NOTE(review): "croco_multi" has no dedicated config type in this union;
# presumably AsymmetricCroCoMulti reuses BackboneCrocoCfg -- confirm.
BackboneCfg = BackboneResnetCfg | BackboneDinoCfg | BackboneCrocoCfg
|
|
|
|
def get_backbone(cfg: BackboneCfg, d_in: int = 3) -> nn.Module:
    """Instantiate the backbone registered under ``cfg.name``.

    Args:
        cfg: Backbone configuration. Its ``name`` attribute selects the
            implementation class from ``BACKBONES``.
        d_in: Forwarded as the second positional argument to the backbone
            constructor (presumably the number of input channels -- confirm).

    Returns:
        The backbone built as ``BACKBONES[cfg.name](cfg, d_in)``.

    Raises:
        KeyError: If ``cfg.name`` is not a key of ``BACKBONES``. The message
            lists the registered names.
    """
    # Keep the try narrow so a KeyError raised inside a backbone constructor
    # is not misreported as an unknown backbone name.
    try:
        backbone_cls = BACKBONES[cfg.name]
    except KeyError:
        available = ", ".join(sorted(BACKBONES))
        raise KeyError(
            f"Unknown backbone {cfg.name!r}; expected one of: {available}"
        ) from None
    return backbone_cls(cfg, d_in)
|
|