import inspect

import torch
import torch.nn as nn
import segmentation_models_pytorch as smp
from segmentation_models_pytorch import utils

from SemanticModel.encoder_management import initialize_encoder
from SemanticModel.custom_losses import FocalLossFunction, TverskyLossFunction, EnhancedCrossEntropy
from SemanticModel.image_preprocessing import get_preprocessing_pipeline

# User-facing architecture names -> attribute name of the model class on the
# ``segmentation_models_pytorch`` module. Single source of truth for both
# ``SegmentationModel._create_new_model`` and ``list_architectures``. Classes are
# looked up lazily (at model-creation time), as in the original inline table.
_ARCHITECTURES = {
    'unet': 'Unet',
    'unet++': 'UnetPlusPlus',
    'deeplabv3': 'DeepLabV3',
    'deeplabv3+': 'DeepLabV3Plus',
    'fpn': 'FPN',
    'linknet': 'Linknet',
    'manet': 'MAnet',
    'pan': 'PAN',
    'pspnet': 'PSPNet',
}

# Loss names accepted (case-insensitively) by SegmentationModel.
_VALID_LOSSES = ('dice', 'bce_with_logits', 'focal', 'tversky')


class SegmentationModel:
    """Builds or loads a segmentation model together with its loss and preprocessing.

    Args:
        classes: Class names. When at most two names are given, any name equal to
            'background' (case-insensitive) is removed from the model's output
            classes; ``class_values`` records the original indices of the kept names.
        architecture: One of the names returned by ``list_architectures()``.
        encoder: smp encoder name; also used to look up 'imagenet' preprocessing.
        weights: Encoder weights passed to smp (e.g. 'imagenet'), or a path ending
            in 'pth' to a fully pickled model that is loaded instead of built.
        loss: One of 'dice', 'bce_with_logits', 'focal', 'tversky'. When falsy,
            'bce_with_logits' is used for more than one output class, else 'dice'.
    """

    def __init__(self, classes=('background', 'foreground'), architecture='unet',
                 encoder='timm-regnety_120', weights='imagenet', loss=None):
        self._initialize_classes(classes)
        self.architecture = architecture
        self.encoder = encoder
        self.weights = weights
        self._setup_loss_function(loss)
        self._initialize_model()

    def _initialize_classes(self, classes):
        """Sets ``classes``, ``class_values``, ``background_flag`` and ``n_classes``.

        With two or fewer names, 'background' (case-insensitive) is treated as an
        implicit class: it is dropped from ``self.classes`` and ``class_values``
        holds the original indices of the remaining names. With more names, every
        name is an output class and ``class_values`` is ``0..len-1``.
        """
        names = list(classes)  # copy so the caller's sequence is never aliased
        if len(names) <= 2:
            kept = [(index, name) for index, name in enumerate(names)
                    if name.lower() != 'background']
            self.classes = [name for _, name in kept]
            self.class_values = [index for index, _ in kept]
            # Must use the same case-insensitive test as the filter above, or a
            # 'Background' entry would be dropped without being remembered.
            self.background_flag = any(name.lower() == 'background' for name in names)
        else:
            self.classes = names
            self.class_values = list(range(len(names)))
            self.background_flag = False
        self.n_classes = len(self.classes)

    def _setup_loss_function(self, loss):
        """Sets ``activation``, ``loss`` and ``loss_name`` from the requested loss.

        An unrecognized loss name prints a message and falls back to 'dice'.
        """
        if not loss:
            loss = 'bce_with_logits' if self.n_classes > 1 else 'dice'
        if loss.lower() not in _VALID_LOSSES:
            print(f'Invalid loss: {loss}, defaulting to dice')
            loss = 'dice'

        multiclass = self.n_classes > 1
        key = loss.lower()
        # Only the selected loss object is instantiated. An activation of None
        # means the smp model applies no output activation.
        if key == 'bce_with_logits':
            self.activation = None
            self.loss = EnhancedCrossEntropy() if multiclass else utils.losses.BCEWithLogitsLoss()
        elif key == 'dice':
            self.activation = 'softmax' if multiclass else 'sigmoid'
            self.loss = utils.losses.DiceLoss()
        elif key == 'focal':
            self.activation = None
            self.loss = FocalLossFunction()
        else:  # 'tversky'
            self.activation = None
            self.loss = TverskyLossFunction()
        self.loss_name = loss

    def _initialize_model(self):
        """Sets ``self.model``: loads a pickled model when ``weights`` is a string
        ending in 'pth', otherwise builds a new smp model."""
        # isinstance guard: weights may be None (no pretrained encoder weights),
        # which has no .endswith.
        if isinstance(self.weights, str) and self.weights.endswith('pth'):
            self._load_pretrained_model()
        else:
            self._create_new_model()

    def _load_pretrained_model(self):
        """Loads a fully pickled model from ``self.weights`` and sets preprocessing.

        Preprocessing setup is best-effort: on failure ``self.preprocessing`` is None.
        """
        print('Loading pretrained model...')
        # NOTE(security): this unpickles a whole nn.Module, which can execute
        # arbitrary code -- only load checkpoint files from trusted sources.
        # torch versions whose torch.load defaults to weights_only=True reject
        # full-model pickles, so opt out explicitly when the keyword exists;
        # older torch.load has no such keyword.
        if 'weights_only' in inspect.signature(torch.load).parameters:
            model = torch.load(self.weights, weights_only=False)
        else:
            model = torch.load(self.weights)
        # Unwrap multi-GPU wrappers so self.model is always the bare network.
        if isinstance(model, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
            model = model.module
        self.model = model
        try:
            preprocessing_fn = smp.encoders.get_preprocessing_fn(self.encoder, 'imagenet')
            self.preprocessing = get_preprocessing_pipeline(preprocessing_fn)
        except Exception:
            print('Failed to configure preprocessing. Setting to None.')
            self.preprocessing = None

    def _create_new_model(self):
        """Builds a new smp model for the configured architecture and encoder.

        Raises:
            ValueError: If ``self.architecture`` is not a supported name.
        """
        # Validate first so an unsupported name fails before any encoder work.
        if self.architecture not in _ARCHITECTURES:
            raise ValueError(f'Unsupported architecture: {self.architecture}')
        preprocessing_fn = smp.encoders.get_preprocessing_fn(self.encoder, 'imagenet')
        self.preprocessing = get_preprocessing_pipeline(preprocessing_fn)
        initialize_encoder(name=self.encoder, weights=self.weights)

        model_cls = getattr(smp, _ARCHITECTURES[self.architecture])
        self.model = model_cls(
            encoder_name=self.encoder,
            encoder_weights=self.weights,
            classes=self.n_classes,
            activation=self.activation,
        )

    @property
    def config_data(self):
        """Returns the settings needed to rebuild this model.

        ``classes`` is rebuilt in the original index order (an implicit background
        is written as 'background'). Passing it back to ``SegmentationModel``
        therefore reproduces the same ``class_values``.
        """
        if self.background_flag:
            # Every slot not occupied by a kept class was a background entry.
            classes = ['background'] * (len(self.classes) + 1)
            for value, name in zip(self.class_values, self.classes):
                classes[value] = name
        else:
            classes = list(self.classes)
        return {
            'architecture': self.architecture,
            'encoder': self.encoder,
            'weights': self.weights,
            'activation': self.activation,
            'loss': self.loss_name,
            'classes': classes,
        }


def list_architectures():
    """Returns the architecture names accepted by ``SegmentationModel``."""
    return list(_ARCHITECTURES)