from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional


@dataclass
class LossConfiguration:
    """Settings for the training loss.

    Only ``num_classes`` is required; every other field has a default.
    """

    num_classes: int
    # Term weights; both default to 1.0. Presumably cross-entropy and Dice
    # terms respectively (inferred from names) -- TODO confirm in the loss code.
    xent_weight: float = 1.0
    dice_weight: float = 1.0
    # Focal-loss toggle and its gamma (default 2.0); the consumer decides how
    # these are applied.
    focal_loss: bool = False
    focal_loss_gamma: float = 2.0
    # NOTE(review): spelling "frustrum" is kept as-is because callers and
    # config files refer to this exact field name.
    requires_frustrum: bool = True
    requires_flood_mask: bool = False
    # Optional per-class weights; the expected type/shape is not visible here
    # (annotated as Any) -- verify against the loss implementation.
    class_weights: Optional[Any] = None
    label_smoothing: float = 0.1


@dataclass
class BackboneConfigurationBase:
    """Fields shared by all backbone configurations (no defaults here)."""

    pretrained: bool
    frozen: bool
    # Integer dimension; was previously mis-annotated as ``bool`` even though
    # the subclasses use it as an int (e.g. DINOConfiguration defaults it to 128).
    output_dim: int


@dataclass
class DINOConfiguration(BackboneConfigurationBase):
    """Backbone configuration whose shared fields all have defaults."""

    pretrained: bool = True
    frozen: bool = False
    output_dim: int = 128


@dataclass
class ResNetConfiguration(BackboneConfigurationBase):
    """Backbone configuration with ResNet-specific options.

    All fields (including the inherited ``pretrained``, ``frozen`` and
    ``output_dim``) are required; positional order is the inherited fields
    first, then the fields declared below.
    """

    input_dim: int
    encoder: str
    remove_stride_from_first_conv: bool
    # May be None; how None is interpreted is up to the consumer.
    num_downsample: Optional[int]
    decoder_norm: str
    do_average_pooling: bool
    checkpointed: bool


@dataclass
class ImageEncoderConfiguration:
    """Names an image encoder and carries its backbone settings."""

    name: str
    # Annotated as Any; presumably one of the backbone configuration classes
    # above (or a raw mapping) -- confirm against the encoder factory.
    backbone: Any


@dataclass
class ModelConfiguration:
    """Top-level model configuration aggregating encoder, head and loss.

    Note: ``num_classes`` here and ``loss.num_classes`` are independent fields;
    no consistency check between them is performed in this class.
    """

    # Free-form settings for the segmentation head; schema not defined here.
    segmentation_head: Dict[str, Any]
    image_encoder: ImageEncoderConfiguration
    name: str
    num_classes: int
    latent_dim: int
    # Presumably spatial extents / resolution of the output grid (inferred from
    # names) -- TODO confirm units with the model code.
    z_max: int
    x_max: int
    pixel_per_meter: int
    num_scale_bins: int
    loss: LossConfiguration
    # default_factory gives each instance its own list, so mutating one
    # instance's scale_range never affects another instance.
    scale_range: List[int] = field(default_factory=lambda: [0, 9])
    z_min: Optional[int] = None