import torch
import torch.nn as nn
import segmentation_models_pytorch as smp
from segmentation_models_pytorch import utils
from SemanticModel.encoder_management import initialize_encoder
from SemanticModel.custom_losses import FocalLossFunction, TverskyLossFunction, EnhancedCrossEntropy
from SemanticModel.image_preprocessing import get_preprocessing_pipeline


class SegmentationModel:

    def __init__(self, classes=['background', 'foreground'], architecture='unet',
                 encoder='timm-regnety_120', weights='imagenet', loss=None):
        self._initialize_classes(classes)
        self.architecture = architecture
        self.encoder = encoder
        self.weights = weights
        self._setup_loss_function(loss)
        self._initialize_model()
    def _initialize_classes(self, classes):
        """Sets up class configuration."""
        if len(classes) <= 2:
            # Binary case: drop 'background' so the model predicts a single channel.
            self.classes = [c for c in classes if c.lower() != 'background']
            self.class_values = [i for i, c in enumerate(classes) if c.lower() != 'background']
            self.background_flag = 'background' in classes
        else:
            self.classes = classes
            self.class_values = list(range(len(classes)))
            self.background_flag = False
        self.n_classes = len(self.classes)
    def _setup_loss_function(self, loss):
        """Configures model's loss function."""
        if not loss:
            loss = 'bce_with_logits' if self.n_classes > 1 else 'dice'

        if loss.lower() not in ['dice', 'bce_with_logits', 'focal', 'tversky']:
            print(f'Invalid loss: {loss}, defaulting to dice')
            loss = 'dice'

        loss_configs = {
            'bce_with_logits': {
                'activation': None,
                'loss': EnhancedCrossEntropy() if self.n_classes > 1 else utils.losses.BCEWithLogitsLoss()
            },
            'dice': {
                'activation': 'softmax' if self.n_classes > 1 else 'sigmoid',
                'loss': utils.losses.DiceLoss()
            },
            'focal': {
                'activation': None,
                'loss': FocalLossFunction()
            },
            'tversky': {
                'activation': None,
                'loss': TverskyLossFunction()
            }
        }

        config = loss_configs[loss.lower()]
        self.activation = config['activation']
        self.loss = config['loss']
        self.loss_name = loss
    def _initialize_model(self):
        """Initializes the segmentation model architecture."""
        if self.weights.endswith('pth'):
            self._load_pretrained_model()
        else:
            self._create_new_model()

    def _load_pretrained_model(self):
        """Loads model from pretrained weights."""
        print('Loading pretrained model...')
        self.model = torch.load(self.weights)
        if isinstance(self.model, torch.nn.DataParallel):
            # Unwrap DataParallel so the underlying module is used directly.
            self.model = self.model.module
        try:
            preprocessing_fn = smp.encoders.get_preprocessing_fn(self.encoder, 'imagenet')
            self.preprocessing = get_preprocessing_pipeline(preprocessing_fn)
        except Exception:
            print('Failed to configure preprocessing. Setting to None.')
            self.preprocessing = None
    def _create_new_model(self):
        """Creates new model with specified architecture."""
        preprocessing_fn = smp.encoders.get_preprocessing_fn(self.encoder, 'imagenet')
        self.preprocessing = get_preprocessing_pipeline(preprocessing_fn)
        initialize_encoder(name=self.encoder, weights=self.weights)

        architectures = {
            'unet': smp.Unet,
            'unet++': smp.UnetPlusPlus,
            'deeplabv3': smp.DeepLabV3,
            'deeplabv3+': smp.DeepLabV3Plus,
            'fpn': smp.FPN,
            'linknet': smp.Linknet,
            'manet': smp.MAnet,
            'pan': smp.PAN,
            'pspnet': smp.PSPNet
        }

        if self.architecture not in architectures:
            raise ValueError(f'Unsupported architecture: {self.architecture}')

        self.model = architectures[self.architecture](
            encoder_name=self.encoder,
            encoder_weights=self.weights,
            classes=self.n_classes,
            activation=self.activation
        )
    @property
    def config_data(self):
        """Returns model configuration data."""
        return {
            'architecture': self.architecture,
            'encoder': self.encoder,
            'weights': self.weights,
            'activation': self.activation,
            'loss': self.loss_name,
            'classes': (['background'] + self.classes) if self.background_flag else self.classes
        }


def list_architectures():
    """Returns available architecture options."""
    return ['unet', 'unet++', 'deeplabv3', 'deeplabv3+', 'fpn',
            'linknet', 'manet', 'pan', 'pspnet']
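

# Usage sketch (not part of the original module; assumes the SemanticModel
# package and its dependencies are installed). Shows how the wrapper might be
# constructed with its ImageNet-pretrained defaults and then inspected.
if __name__ == '__main__':
    wrapper = SegmentationModel(
        classes=['background', 'foreground'],
        architecture='unet',
        encoder='timm-regnety_120',
        weights='imagenet',
        loss='dice',
    )
    print(wrapper.config_data)      # architecture, encoder, loss, and class list
    print(list_architectures())     # supported architecture keys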