|
import torch |
|
import torch.nn as nn |
|
import segmentation_models_pytorch as smp |
|
from segmentation_models_pytorch import utils |
|
|
|
from SemanticModel.encoder_management import initialize_encoder |
|
from SemanticModel.custom_losses import FocalLossFunction, TverskyLossFunction, EnhancedCrossEntropy |
|
from SemanticModel.image_preprocessing import get_preprocessing_pipeline |
|
|
|
class SegmentationModel:
    """Build or load a segmentation_models_pytorch (smp) network plus its training config.

    An instance holds the network (``self.model``), the selected loss object
    (``self.loss``), the output activation passed to smp (``self.activation``),
    the preprocessing pipeline (``self.preprocessing``) and the class
    configuration, exposed for rebuilding via :attr:`config_data`.
    """

    def __init__(self, classes=('background', 'foreground'), architecture='unet',
                 encoder='timm-regnety_120', weights='imagenet', loss=None):
        """Configure classes, loss and network.

        Args:
            classes: Ordered class labels. With two or fewer labels, a
                ``'background'`` entry (compared case-insensitively) is not
                predicted as its own channel. The default is an immutable tuple
                so no list is shared between calls.
            architecture: Name of the smp architecture to build, matched
                case-insensitively (see ``_create_new_model``).
            encoder: smp encoder name; also selects the preprocessing function.
            weights: A path ending in ``'pth'`` to load a whole pickled model;
                any other value (including ``None``) is forwarded unchanged to
                ``initialize_encoder`` and to smp as ``encoder_weights``.
            loss: ``'dice'``, ``'bce_with_logits'``, ``'focal'`` or
                ``'tversky'`` (case-insensitive). A falsy value picks a
                default; an unknown name falls back to ``'dice'``.

        Raises:
            ValueError: if no class remains to predict, or the architecture
                is not supported.
        """
        self._initialize_classes(classes)
        self.architecture = architecture
        self.encoder = encoder
        self.weights = weights
        self._setup_loss_function(loss)
        self._initialize_model()

    def _initialize_classes(self, classes):
        """Derive the predicted classes and their integer label values.

        With two or fewer labels, ``'background'`` entries (case-insensitive)
        are removed from ``self.classes``; ``self.class_values`` keeps each
        remaining label's index in the given list and ``self.background_flag``
        records whether background was present. With more than two labels,
        every label is predicted and ``self.class_values`` is ``0..n-1``.

        Raises:
            ValueError: if no class would be predicted (e.g. only
                ``['background']`` was given).
        """
        # Copy so neither the caller's list nor the default is ever aliased.
        labels = list(classes)
        # Kept verbatim (order and spelling) so config_data round-trips.
        self._class_labels = labels

        if len(labels) <= 2:
            is_background = [c.lower() == 'background' for c in labels]
            self.classes = [c for c, bg in zip(labels, is_background) if not bg]
            self.class_values = [i for i, bg in enumerate(is_background) if not bg]
            # Must use the same case-insensitive test as the filter above,
            # otherwise a label like 'Background' is dropped but not flagged.
            self.background_flag = any(is_background)
        else:
            self.classes = list(labels)
            self.class_values = list(range(len(labels)))
            self.background_flag = False

        self.n_classes = len(self.classes)
        if self.n_classes == 0:
            # A 0-channel head would build silently and never train.
            raise ValueError(f'No classes to predict in {labels!r}')

    def _setup_loss_function(self, loss):
        """Select the loss object and the output activation.

        A falsy ``loss`` becomes ``'bce_with_logits'`` for multi-class models
        and ``'dice'`` otherwise; unknown names print a warning and fall back
        to ``'dice'``. Sets ``self.activation``, ``self.loss`` and
        ``self.loss_name`` (the name as given, original case preserved).
        """
        multiclass = self.n_classes > 1
        if not loss:
            loss = 'bce_with_logits' if multiclass else 'dice'

        # name -> (activation, loss class). Classes rather than instances,
        # so only the selected loss is ever constructed.
        loss_configs = {
            'bce_with_logits': (
                None,
                EnhancedCrossEntropy if multiclass else utils.losses.BCEWithLogitsLoss,
            ),
            'dice': ('softmax' if multiclass else 'sigmoid', utils.losses.DiceLoss),
            'focal': (None, FocalLossFunction),
            'tversky': (None, TverskyLossFunction),
        }

        if loss.lower() not in loss_configs:
            print(f'Invalid loss: {loss}, defaulting to dice')
            loss = 'dice'

        self.activation, loss_cls = loss_configs[loss.lower()]
        self.loss = loss_cls()
        self.loss_name = loss

    def _initialize_model(self):
        """Load a pickled model when ``weights`` ends in ``'pth'``, else build a new one."""
        weights = self.weights
        # weights=None is a legitimate "no pretrained encoder" value; it must
        # not reach str methods.
        if weights is not None and str(weights).endswith('pth'):
            self._load_pretrained_model()
        else:
            self._create_new_model()

    def _build_preprocessing(self):
        """Return the preprocessing pipeline for ``self.encoder`` ('imagenet' settings)."""
        preprocessing_fn = smp.encoders.get_preprocessing_fn(self.encoder, 'imagenet')
        return get_preprocessing_pipeline(preprocessing_fn)

    def _load_pretrained_model(self):
        """Load a whole pickled model from ``self.weights``.

        A ``DataParallel`` / ``DistributedDataParallel`` wrapper is stripped.
        Preprocessing is best-effort: if it cannot be configured for
        ``self.encoder``, the reason is printed and ``self.preprocessing`` is
        set to ``None``.
        """
        print('Loading pretrained model...')
        # SECURITY: this unpickles a full nn.Module (not a state_dict), which
        # can execute arbitrary code. Only load trusted checkpoint files.
        # torch>=2.6 defaults to weights_only=True, which rejects full modules,
        # so opt out explicitly; torch<1.13 rejects the keyword with TypeError.
        try:
            model = torch.load(self.weights, weights_only=False)
        except TypeError:
            model = torch.load(self.weights)

        if isinstance(model, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
            model = model.module
        self.model = model

        try:
            self.preprocessing = self._build_preprocessing()
        except Exception as err:  # best-effort: the model itself is already loaded
            print(f'Failed to configure preprocessing ({err}). Setting to None.')
            self.preprocessing = None

    def _create_new_model(self):
        """Build a fresh smp model for ``self.architecture``.

        Raises:
            ValueError: if the architecture is unsupported. This is checked
                before any preprocessing/encoder work so a typo fails fast.
        """
        architectures = {
            'unet': smp.Unet,
            'unet++': smp.UnetPlusPlus,
            'deeplabv3': smp.DeepLabV3,
            'deeplabv3+': smp.DeepLabV3Plus,
            'fpn': smp.FPN,
            'linknet': smp.Linknet,
            'manet': smp.MAnet,
            'pan': smp.PAN,
            'pspnet': smp.PSPNet,
        }

        model_cls = architectures.get(str(self.architecture).lower())
        if model_cls is None:
            raise ValueError(
                f'Unsupported architecture: {self.architecture} '
                f'(expected one of: {", ".join(architectures)})'
            )

        self.preprocessing = self._build_preprocessing()
        initialize_encoder(name=self.encoder, weights=self.weights)

        self.model = model_cls(
            encoder_name=self.encoder,
            encoder_weights=self.weights,
            classes=self.n_classes,
            activation=self.activation,
        )

    @property
    def config_data(self):
        """Return the configuration needed to rebuild this model.

        ``'classes'`` is the label list exactly as given to the constructor
        (including any background entry, in its original position), so
        passing it back as ``classes=`` reproduces the same ``classes`` and
        ``class_values``.
        """
        return {
            'architecture': self.architecture,
            'encoder': self.encoder,
            'weights': self.weights,
            'activation': self.activation,
            'loss': self.loss_name,
            'classes': list(self._class_labels),
        }
|
|
|
def list_architectures():
    """Return the names of the supported segmentation architectures.

    A new list is built on every call, so callers may modify it freely.
    """
    names = (
        'unet', 'unet++',
        'deeplabv3', 'deeplabv3+',
        'fpn',
        'linknet',
        'manet',
        'pan',
        'pspnet',
    )
    return list(names)