# mapper/models/schema.py (Mapper project)
# Commit fd01725 ("Initial upload") by Cherie Ho.
from dataclasses import dataclass, field
from typing import Any, Dict, Optional
@dataclass
class LossConfiguration:
    """Settings for the training loss.

    Only ``num_classes`` is required; every other field has a default.
    Declaration order defines the positional constructor signature, so it
    is kept exactly as-is.
    """

    num_classes: int

    # Presumably the weights of the cross-entropy ("xent") and Dice loss
    # terms -- the loss implementation is not visible here; TODO confirm.
    xent_weight: float = 1.0
    dice_weight: float = 1.0

    # Toggle for a focal-loss variant plus its gamma exponent.
    focal_loss: bool = False
    focal_loss_gamma: float = 2.0

    # NOTE: "frustrum" is the existing (misspelled) field name; renaming it
    # would break callers that use the current name.
    requires_frustrum: bool = True
    requires_flood_mask: bool = False

    # Container type not visible here (presumably per-class weights as a
    # sequence or tensor -- TODO confirm); None means "not set".
    class_weights: Optional[Any] = None
    label_smoothing: float = 0.1
@dataclass
class BackboneConfigurationBase:
    """Fields shared by every image-backbone configuration.

    None of the fields has a default here. Subclasses either re-declare them
    with defaults or append further required fields after them.
    """

    # Presumably whether pretrained weights are loaded -- TODO confirm with
    # the backbone builder.
    pretrained: bool
    # Presumably whether the backbone's parameters are kept fixed.
    frozen: bool
    # Integer size of the backbone output (a subclass in this module
    # defaults it to 128). Previously mis-annotated as ``bool``; the wrong
    # annotation was inherited by subclasses that do not re-declare the
    # field, and annotation-driven validators would coerce e.g. 128 -> True.
    output_dim: int
@dataclass
class DINOConfiguration(BackboneConfigurationBase):
    """Backbone settings for the DINO encoder.

    Every inherited field is re-declared with a default, so
    ``DINOConfiguration()`` can be built without arguments.
    """

    pretrained: bool = True   # on by default
    frozen: bool = False      # trainable by default
    output_dim: int = 128
@dataclass
class ResNetConfiguration(BackboneConfigurationBase):
    """Backbone settings for a ResNet-style encoder.

    Every field is required. The positional constructor order is the three
    inherited base fields (``pretrained``, ``frozen``, ``output_dim``)
    followed by the fields below, in declaration order.
    """

    input_dim: int
    encoder: str  # presumably an encoder variant name -- TODO confirm accepted values
    remove_stride_from_first_conv: bool
    num_downsample: Optional[int]  # may be None, but must still be passed explicitly
    decoder_norm: str  # presumably a normalization-layer name -- TODO confirm
    do_average_pooling: bool
    checkpointed: bool  # presumably enables gradient checkpointing -- TODO confirm
@dataclass
class ImageEncoderConfiguration:
    """Names an image encoder and carries its backbone settings.

    ``backbone`` is typed ``Any``. Presumably it holds a backbone
    configuration object or an equivalent mapping -- TODO confirm with the
    consumer.
    """

    name: str
    backbone: Any
@dataclass
class ModelConfiguration:
    """Top-level model configuration.

    Every field through ``loss`` is required; ``scale_range`` and ``z_min``
    have defaults. Declaration order defines the positional constructor
    signature and must not change.
    """

    segmentation_head: Dict[str, Any]
    image_encoder: ImageEncoderConfiguration
    name: str
    num_classes: int
    latent_dim: int

    # Units are not visible here; presumably the *_max extents are in
    # metres, given the pixel_per_meter field -- TODO confirm.
    z_max: int
    x_max: int
    pixel_per_meter: int
    num_scale_bins: int

    loss: LossConfiguration

    # default_factory gives each instance its own fresh [0, 9] list rather
    # than one shared mutable default.
    scale_range: list[int] = field(default_factory=lambda: [0, 9])
    z_min: Optional[int] = None