AnySplat / src /model /encoder /visualization /encoder_visualizer.py
alexnasa's picture
Upload 243 files
2568013 verified
raw
history blame
561 Bytes
from abc import ABC, abstractmethod
from typing import Generic, TypeVar
from jaxtyping import Float
from torch import Tensor
# Type parameters for EncoderVisualizer: the configuration type and the
# encoder type a concrete visualizer is specialised for.
T_cfg, T_encoder = TypeVar("T_cfg"), TypeVar("T_encoder")
class EncoderVisualizer(ABC, Generic[T_cfg, T_encoder]):
    """Abstract interface for objects that visualize an encoder.

    An instance keeps the configuration object and the encoder it was built
    with. Concrete subclasses must implement :meth:`visualize`; the class
    cannot be instantiated directly because of the abstract method.
    """

    # Configuration object supplied at construction time.
    cfg: T_cfg
    # Encoder instance that this visualizer inspects.
    encoder: T_encoder

    def __init__(self, cfg: T_cfg, encoder: T_encoder) -> None:
        # Keep references to both collaborators as given (no copy, no checks).
        # Tuple targets are bound left to right: cfg first, then encoder.
        self.cfg, self.encoder = cfg, encoder

    @abstractmethod
    def visualize(
        self,
        context: dict,
        global_step: int,
    ) -> dict[str, Float[Tensor, "3 _ _"]]:
        """Produce named visualization tensors; implemented by subclasses.

        Args:
            context: Input dictionary handed to the visualizer. Its key
                schema is not defined here — presumably the encoder's
                context batch; confirm against concrete implementations.
            global_step: Integer step counter supplied by the caller
                (assumed to be the training step — TODO confirm).

        Returns:
            A mapping from string names to tensors annotated with the
            jaxtyping shape ``"3 _ _"``: a leading axis of size 3 followed
            by two unchecked axes (likely channel-first RGB images — verify).
        """