|
from dataclasses import dataclass, field |
|
from typing import Any, Dict, Optional |
|
|
|
@dataclass
class LossConfiguration:
    """Hyper-parameters for the training loss.

    Only ``num_classes`` is required; every other field carries a default.
    How each value is used is decided by the loss implementation, which is
    not part of this module.
    """

    # Number of classes (required, no default).
    num_classes: int

    # Relative weights of two loss terms -- presumably cross-entropy and
    # Dice, judging by the names; confirm against the loss implementation.
    xent_weight: float = 1.0
    dice_weight: float = 1.0

    # Focal-loss switch and its gamma parameter (gamma presumably only
    # matters when the switch is on -- verify in the loss code).
    focal_loss: bool = False
    focal_loss_gamma: float = 2.0

    # Flags telling the loss which auxiliary masks it needs. Note: the
    # "frustrum" spelling is part of the public field name and must stay.
    requires_frustrum: bool = True
    requires_flood_mask: bool = False

    # Optional per-class weighting; typed loosely (Any). Presumably a
    # sequence/array of length num_classes -- TODO confirm with callers.
    class_weights: Optional[Any] = None

    label_smoothing: float = 0.1
|
|
|
@dataclass
class BackboneConfigurationBase:
    """Fields shared by every image-backbone configuration.

    All three fields are required at this level; subclasses may re-declare
    them with defaults.
    """

    # Whether pretrained weights should be used (presumably loaded by the
    # backbone builder -- confirm there).
    pretrained: bool
    # Whether the backbone is kept frozen (assumed from the name -- confirm).
    frozen: bool
    # Output feature dimension of the backbone (presumably its channel
    # count). Annotated ``int``: subclasses store integer values here
    # (DINOConfiguration defaults it to 128), and a ``bool`` annotation would
    # let schema-enforcing loaders such as OmegaConf structured configs
    # coerce or reject those integers for subclasses that inherit the field.
    output_dim: int
|
|
|
@dataclass
class DINOConfiguration(BackboneConfigurationBase):
    """Backbone settings for the DINO image encoder.

    Every inherited field is re-declared here with a default, so the class
    can be instantiated with no arguments.
    """

    pretrained: bool = True  # use pretrained weights by default
    frozen: bool = False  # not frozen by default
    output_dim: int = 128  # default output feature dimension
|
|
|
@dataclass
class ResNetConfiguration(BackboneConfigurationBase):
    """Backbone settings for a ResNet-style encoder.

    Adds the fields below to the inherited ``pretrained`` / ``frozen`` /
    ``output_dim``. No field has a default, so all ten must be supplied;
    positionally, the three inherited fields come first.
    """

    input_dim: int
    # Encoder identifier; its accepted values are defined by the consumer.
    encoder: str
    remove_stride_from_first_conv: bool
    # May be None; what None means is decided by the consumer.
    num_downsample: Optional[int]
    # Name of the decoder normalization; valid names are defined elsewhere.
    decoder_norm: str
    do_average_pooling: bool
    # Presumably toggles gradient checkpointing -- confirm in the model code.
    checkpointed: bool
|
|
|
@dataclass
class ImageEncoderConfiguration:
    """Names an image encoder and carries its backbone settings."""

    # Encoder identifier; how it is resolved is up to the model builder.
    name: str
    # Backbone settings, typed loosely (Any). Presumably a
    # BackboneConfigurationBase subclass or a plain mapping -- confirm
    # against callers.
    backbone: Any
|
|
|
@dataclass
class ModelConfiguration:
    """Top-level model configuration.

    Bundles sub-configurations (segmentation head, image encoder, loss) with
    scalar model and grid settings. Every field except ``scale_range`` and
    ``z_min`` is required.
    """

    # -- sub-module settings --
    segmentation_head: Dict[str, Any]  # free-form options for the head
    image_encoder: ImageEncoderConfiguration

    # -- identity and sizes --
    name: str
    num_classes: int
    latent_dim: int

    # -- spatial extent; units not stated here (presumably meters -- confirm) --
    z_max: int
    x_max: int

    # -- resolution / scale discretisation --
    pixel_per_meter: int
    num_scale_bins: int

    loss: LossConfiguration

    # default_factory gives each instance its own list; dataclasses reject a
    # plain mutable default such as ``= [0, 9]``.
    scale_range: list[int] = field(default_factory=lambda: [0, 9])
    # Optional counterpart to z_max (presumably a lower bound); None if unset.
    z_min: Optional[int] = None