|
|
"""Contains the base class for encoders. |
|
|
|
|
|
For licensing see accompanying LICENSE file. |
|
|
Copyright (C) 2025 Apple Inc. All Rights Reserved. |
|
|
""" |
|
|
|
|
|
import abc |
|
|
|
|
|
import torch |
|
|
from torch import nn |
|
|
|
|
|
|
|
|
class BaseEncoder(nn.Module, abc.ABC):
    """Abstract base class for image encoders.

    Concrete subclasses must implement :meth:`forward`. The attribute
    annotations below only declare what subclasses are expected to provide;
    this base class never assigns them itself.
    """

    # Presumably the number of input channels the encoder accepts -- confirm
    # against concrete subclasses.
    dim_in: int

    # Presumably one entry per encoding returned by ``forward`` (e.g. the
    # channel width of each resolution level) -- confirm against subclasses.
    output_dims: list[int]

    @abc.abstractmethod
    def forward(self, image: torch.Tensor) -> list[torch.Tensor]:
        """Encode ``image`` into multi-resolution encodings.

        Args:
            image: Input image tensor.

        Returns:
            A list of encoding tensors at multiple resolutions.
        """

    def internal_resolution(self) -> int:
        """Return the encoder's internal working resolution.

        The base implementation always reports 1536.
        """
        default_resolution = 1536
        return default_resolution
|
|
|