| from abc import ABC, abstractmethod |
| from dataclasses import dataclass |
| from typing import Generic, Literal, TypeVar |
|
|
| from jaxtyping import Float |
| from torch import Tensor, nn |
|
|
| from ..types import Gaussians |
|
|
# Closed set of depth rendering mode names; this alias is the type of the
# optional ``depth_mode`` argument of ``Decoder.forward``.
DepthRenderingMode = Literal[
    "depth",
    "log",
    "disparity",
    "relative_disparity",
]
|
|
|
|
@dataclass
class DecoderOutput:
    """Bundle of outputs returned by ``Decoder.forward``.

    ``color`` is always present; ``depth``, ``alpha`` and ``lod_rendering``
    may each be ``None``. No field has a default, so all four must be
    supplied when constructing an instance.
    """

    # Color output, annotated as (batch, view, 3, height, width).
    color: Float[Tensor, "batch view 3 height width"]
    # Optional depth output, annotated as (batch, view, height, width).
    depth: Float[Tensor, "batch view height width"] | None
    # Optional alpha output, annotated as (batch, view, height, width).
    alpha: Float[Tensor, "batch view height width"] | None
    # Optional free-form dictionary; presumably level-of-detail rendering
    # results -- its keys and values are not specified here (TODO confirm).
    lod_rendering: dict | None
|
|
# Unconstrained type variable used to parameterise ``Decoder`` over the type
# of its configuration object.
T = TypeVar("T")
|
|
|
|
class Decoder(nn.Module, ABC, Generic[T]):
    """Abstract base ``nn.Module`` for decoders, generic over the config type.

    The configuration passed to the constructor is kept on ``self.cfg``.
    ``forward`` is abstract, so a subclass must override it before it can be
    instantiated.
    """

    cfg: T

    def __init__(self, cfg: T) -> None:
        """Initialise the underlying ``nn.Module`` and store ``cfg``.

        Args:
            cfg: Configuration object, kept unchanged as ``self.cfg``.
        """
        super().__init__()
        self.cfg = cfg

    @abstractmethod
    def forward(
        self,
        gaussians: Gaussians,
        extrinsics: Float[Tensor, "batch view 4 4"],
        intrinsics: Float[Tensor, "batch view 3 3"],
        near: Float[Tensor, "batch view"],
        far: Float[Tensor, "batch view"],
        image_shape: tuple[int, int],
        depth_mode: DepthRenderingMode | None = None,
    ) -> DecoderOutput:
        """Produce a ``DecoderOutput`` from ``gaussians`` for the given views.

        This base implementation is intentionally empty; subclasses provide
        the actual behaviour.

        Args:
            gaussians: The ``Gaussians`` instance to decode.
            extrinsics: Per-view matrices, annotated (batch, view, 4, 4).
            intrinsics: Per-view matrices, annotated (batch, view, 3, 3).
            near: Per-view values, annotated (batch, view).
            far: Per-view values, annotated (batch, view).
            image_shape: Pair of ints giving the output image size;
                presumably (height, width) -- TODO confirm against callers.
            depth_mode: Optional ``DepthRenderingMode`` name; defaults to
                ``None``.

        Returns:
            A ``DecoderOutput``.
        """
|
|