| from abc import ABC, abstractmethod |
| from typing import Generic, TypeVar |
|
|
| from torch import nn |
| from dataclasses import dataclass |
| from src.dataset.types import BatchedViews, DataShim |
| from ..types import Gaussians |
| from jaxtyping import Float |
| from torch import Tensor, nn |
|
|
# Type parameter for the encoder's configuration object: `Encoder` is
# generic over it and stores the instance on `Encoder.cfg`.
T = TypeVar("T")
|
|
@dataclass
class EncoderOutput:
    """Structured result container for an encoder.

    ``gaussians`` is the only field not annotated as nullable. The remaining
    five are typed ``... | None``. No defaults are declared, so all six fields
    must be supplied on construction, either positionally in the order below
    or by keyword.

    NOTE(review): ``Encoder.forward`` in this module is annotated as returning
    ``Gaussians`` rather than this class. Confirm which one concrete encoders
    actually return.
    """

    # Encoder Gaussians (project type ``..types.Gaussians``).
    gaussians: Gaussians
    # Optional list of tensors, each annotated with shape (batch, view, 6).
    # Presumably predicted per-view pose encodings (inferred from the name);
    # the meaning of the 6 channels is not defined here. TODO confirm.
    pred_pose_enc_list: list[Float[Tensor, "batch view 6"]] | None
    # Optional dict, presumably predicted poses for the context views.
    # Its key schema is not defined in this module.
    pred_context_pose: dict | None
    # Optional dict of depth-related outputs; keys are not defined here.
    depth_dict: dict | None
    # Optional free-form auxiliary information; schema not defined here.
    infos: dict | None
    # Optional dict, presumably consumed by a distillation objective.
    # Verify against callers.
    distill_infos: dict | None
|
|
class Encoder(nn.Module, ABC, Generic[T]):
    """Abstract ``nn.Module`` base for encoders, generic over a config type ``T``.

    The config passed to the constructor is stored on ``self.cfg``. Subclasses
    must implement :meth:`forward`. They may override :meth:`get_data_shim`,
    whose default leaves the batch unchanged.
    """

    # Configuration object handed to ``__init__``.
    cfg: T

    def __init__(self, cfg: T) -> None:
        """Initialise the underlying ``nn.Module`` and keep ``cfg``."""
        super().__init__()
        self.cfg = cfg

    @abstractmethod
    def forward(self, context: BatchedViews) -> Gaussians:
        """Run the encoder on a batch of context views.

        Args:
            context: Batched input views (``src.dataset.types.BatchedViews``).

        Returns:
            Annotated as ``Gaussians``.

        NOTE(review): this module also defines ``EncoderOutput``, which wraps
        ``Gaussians`` together with extra fields. Confirm which type concrete
        encoders actually return.
        """

    def get_data_shim(self) -> DataShim:
        """Return the callable used to adapt a batch for this encoder.

        The default is the identity function: it returns its argument
        unchanged.
        """

        def identity_shim(x):
            return x

        return identity_shim
|
|