import torch
import torch.nn as nn
import segmentation_models_pytorch as smp
from segmentation_models_pytorch import utils
from SemanticModel.encoder_management import initialize_encoder
from SemanticModel.custom_losses import FocalLossFunction, TverskyLossFunction, EnhancedCrossEntropy
from SemanticModel.image_preprocessing import get_preprocessing_pipeline


class SegmentationModel:
    """Wraps a segmentation_models_pytorch model, configuring its classes,
    loss function, activation, and preprocessing pipeline."""

    def __init__(self, classes=['background', 'foreground'], architecture='unet',
                 encoder='timm-regnety_120', weights='imagenet', loss=None):
        self._initialize_classes(classes)
        self.architecture = architecture
        self.encoder = encoder
        self.weights = weights
        self._setup_loss_function(loss)
        self._initialize_model()

    def _initialize_classes(self, classes):
        """Sets up class configuration."""
        if len(classes) <= 2:
            self.classes = [c for c in classes if c.lower() != 'background']
            self.class_values = [i for i, c in enumerate(classes) if c.lower() != 'background']
            self.background_flag = 'background' in classes
        else:
            self.classes = classes
            self.class_values = list(range(len(classes)))
            self.background_flag = False
        self.n_classes = len(self.classes)

    def _setup_loss_function(self, loss):
        """Configures model's loss function."""
        if not loss:
            loss = 'bce_with_logits' if self.n_classes > 1 else 'dice'
        if loss.lower() not in ['dice', 'bce_with_logits', 'focal', 'tversky']:
            print(f'Invalid loss: {loss}, defaulting to dice')
            loss = 'dice'

        loss_configs = {
            'bce_with_logits': {
                'activation': None,
                'loss': EnhancedCrossEntropy() if self.n_classes > 1 else utils.losses.BCEWithLogitsLoss()
            },
            'dice': {
                'activation': 'softmax' if self.n_classes > 1 else 'sigmoid',
                'loss': utils.losses.DiceLoss()
            },
            'focal': {
                'activation': None,
                'loss': FocalLossFunction()
            },
            'tversky': {
                'activation': None,
                'loss': TverskyLossFunction()
            }
        }

        config = loss_configs[loss.lower()]
        self.activation = config['activation']
        self.loss = config['loss']
        self.loss_name = loss

    def _initialize_model(self):
        """Initializes the segmentation model architecture."""
        if self.weights.endswith('pth'):
            self._load_pretrained_model()
        else:
            self._create_new_model()

    def _load_pretrained_model(self):
        """Loads model from pretrained weights."""
        print('Loading pretrained model...')
        self.model = torch.load(self.weights)
        if isinstance(self.model, torch.nn.DataParallel):
            self.model = self.model.module
        try:
            preprocessing_fn = smp.encoders.get_preprocessing_fn(self.encoder, 'imagenet')
            self.preprocessing = get_preprocessing_pipeline(preprocessing_fn)
        except Exception:
            print('Failed to configure preprocessing. Setting to None.')
            self.preprocessing = None

    def _create_new_model(self):
        """Creates new model with specified architecture."""
        preprocessing_fn = smp.encoders.get_preprocessing_fn(self.encoder, 'imagenet')
        self.preprocessing = get_preprocessing_pipeline(preprocessing_fn)
        initialize_encoder(name=self.encoder, weights=self.weights)

        architectures = {
            'unet': smp.Unet,
            'unet++': smp.UnetPlusPlus,
            'deeplabv3': smp.DeepLabV3,
            'deeplabv3+': smp.DeepLabV3Plus,
            'fpn': smp.FPN,
            'linknet': smp.Linknet,
            'manet': smp.MAnet,
            'pan': smp.PAN,
            'pspnet': smp.PSPNet
        }

        if self.architecture not in architectures:
            raise ValueError(f'Unsupported architecture: {self.architecture}')

        self.model = architectures[self.architecture](
            encoder_name=self.encoder,
            encoder_weights=self.weights,
            classes=self.n_classes,
            activation=self.activation
        )

    @property
    def config_data(self):
        """Returns model configuration data."""
        return {
            'architecture': self.architecture,
            'encoder': self.encoder,
            'weights': self.weights,
            'activation': self.activation,
            'loss': self.loss_name,
            'classes': (['background'] + self.classes) if self.background_flag else self.classes
        }


def list_architectures():
    """Returns available architecture options."""
    return ['unet', 'unet++', 'deeplabv3', 'deeplabv3+', 'fpn',
            'linknet', 'manet', 'pan', 'pspnet']
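

# Example usage: a minimal sketch, not part of the original module. It builds a
# binary background/road U-Net with ImageNet encoder weights and prints the
# resulting configuration. The 'road' class name and this __main__ guard are
# illustrative assumptions; any foreground label works the same way.
if __name__ == '__main__':
    segmentation = SegmentationModel(
        classes=['background', 'road'],
        architecture='unet',
        encoder='timm-regnety_120',
        weights='imagenet',
        loss='dice',
    )
    print(segmentation.config_data)   # architecture, encoder, loss, classes, etc.
    print(list_architectures())       # all supported architecture keys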