| | from abc import ABC, abstractmethod |
| | from dataclasses import dataclass |
| | from typing import Generic, Literal, TypeVar |
| |
|
| | from jaxtyping import Float |
| | from torch import Tensor, nn |
| |
|
| | from ..types import Gaussians |
| |
|
# Allowed values for the ``depth_mode`` argument of ``Decoder.forward``.
# The concrete transform behind each option is chosen by the decoder
# implementations, which are not shown here. Judging by the names, they are
# presumably raw depth, log-depth, inverse depth, and a normalized inverse
# depth — TODO confirm against the renderer.
DepthRenderingMode = Literal["depth", "log", "disparity", "relative_disparity"]
| |
|
| |
|
@dataclass
class DecoderOutput:
    """Result bundle returned by ``Decoder.forward``.

    ``color`` is always populated. ``depth``, ``alpha`` and ``lod_rendering``
    are typed as optional, so consumers must check them for ``None`` before
    using them.
    """

    # Rendered images: one 3-channel image per (batch, view) pair.
    color: Float[Tensor, "batch view 3 height width"]
    # Per-pixel depth map, or None when not produced. Presumably None when
    # ``depth_mode`` is None — confirm in the concrete decoders.
    depth: Float[Tensor, "batch view height width"] | None
    # Per-pixel alpha map, or None when not produced.
    alpha: Float[Tensor, "batch view height width"] | None
    # Level-of-detail rendering extras. The key/value schema is
    # decoder-specific and not visible here — TODO document.
    lod_rendering: dict | None
| |
|
# Configuration type that a concrete Decoder subclass is parameterized over.
T = TypeVar("T")


class Decoder(nn.Module, ABC, Generic[T]):
    """Abstract ``nn.Module`` that renders Gaussians into per-view outputs.

    Subclasses specialize the generic parameter ``T`` with their own config
    type and must implement :meth:`forward`.
    """

    # Configuration object supplied at construction time.
    cfg: T

    def __init__(self, cfg: T) -> None:
        """Initialize the module and keep a reference to ``cfg``.

        Args:
            cfg: Decoder configuration. It is stored as-is on ``self.cfg``.
        """
        # Set up nn.Module's internal bookkeeping before storing state on self.
        super().__init__()
        self.cfg = cfg

    @abstractmethod
    def forward(
        self,
        gaussians: Gaussians,
        extrinsics: Float[Tensor, "batch view 4 4"],
        intrinsics: Float[Tensor, "batch view 3 3"],
        near: Float[Tensor, "batch view"],
        far: Float[Tensor, "batch view"],
        image_shape: tuple[int, int],
        depth_mode: DepthRenderingMode | None = None,
    ) -> DecoderOutput:
        """Render ``gaussians`` for every (batch, view) camera.

        Args:
            gaussians: Scene primitives to render.
            extrinsics: One 4x4 camera matrix per view. The convention
                (camera-to-world vs. world-to-camera) is not visible here —
                TODO confirm.
            intrinsics: One 3x3 intrinsic matrix per view. Whether these are
                normalized or in pixel units is not visible here — TODO
                confirm.
            near: Per-view near bound (presumably a clipping distance).
            far: Per-view far bound (presumably a clipping distance).
            image_shape: Output image size. Presumably (height, width),
                matching the ``DecoderOutput`` tensor layout — confirm.
            depth_mode: How rendered depth should be encoded. Presumably
                ``None`` skips depth rendering.

        Returns:
            A ``DecoderOutput`` holding the rendered tensors.
        """
| |
|