import torch
import torch.nn as nn
import segmentation_models_pytorch as smp
from segmentation_models_pytorch import utils

from SemanticModel.encoder_management import initialize_encoder
from SemanticModel.custom_losses import FocalLossFunction, TverskyLossFunction, EnhancedCrossEntropy
from SemanticModel.image_preprocessing import get_preprocessing_pipeline

class SegmentationModel:
    def __init__(self, classes=['background', 'foreground'], architecture='unet',
                 encoder='timm-regnety_120', weights='imagenet', loss=None):
        self._initialize_classes(classes)
        self.architecture = architecture
        self.encoder = encoder
        self.weights = weights
        self._setup_loss_function(loss)
        self._initialize_model()

    def _initialize_classes(self, classes):
        """Sets up class configuration."""
        if len(classes) <= 2:
            # Binary case: drop 'background' and predict only the foreground class.
            self.classes = [c for c in classes if c.lower() != 'background']
            self.class_values = [i for i, c in enumerate(classes) if c.lower() != 'background']
            self.background_flag = 'background' in classes
        else:
            # Multiclass case: every listed class, background included, gets its own channel.
            self.classes = classes
            self.class_values = list(range(len(classes)))
            self.background_flag = False
        self.n_classes = len(self.classes)

    def _setup_loss_function(self, loss):
        """Configures model's loss function."""
        if not loss:
            # Default: cross-entropy-style loss for multiclass, Dice for binary.
            loss = 'bce_with_logits' if self.n_classes > 1 else 'dice'
        
        if loss.lower() not in ['dice', 'bce_with_logits', 'focal', 'tversky']:
            print(f'Invalid loss: {loss}, defaulting to dice')
            loss = 'dice'

        loss_configs = {
            'bce_with_logits': {
                'activation': None,
                'loss': EnhancedCrossEntropy() if self.n_classes > 1 else utils.losses.BCEWithLogitsLoss()
            },
            'dice': {
                'activation': 'softmax' if self.n_classes > 1 else 'sigmoid',
                'loss': utils.losses.DiceLoss()
            },
            'focal': {
                'activation': None,
                'loss': FocalLossFunction()
            },
            'tversky': {
                'activation': None,
                'loss': TverskyLossFunction()
            }
        }

        config = loss_configs[loss.lower()]
        self.activation = config['activation']
        self.loss = config['loss']
        self.loss_name = loss

    def _initialize_model(self):
        """Initializes the segmentation model architecture."""
        if self.weights.endswith('pth'):
            self._load_pretrained_model()
        else:
            self._create_new_model()

    def _load_pretrained_model(self):
        """Loads model from pretrained weights."""
        print('Loading pretrained model...')
        # Expects a fully pickled model object (not a state_dict) in the .pth file.
        self.model = torch.load(self.weights)
        if isinstance(self.model, torch.nn.DataParallel):
            self.model = self.model.module

        try:
            preprocessing_fn = smp.encoders.get_preprocessing_fn(self.encoder, 'imagenet')
            self.preprocessing = get_preprocessing_pipeline(preprocessing_fn)
        except Exception:
            print('Failed to configure preprocessing. Setting to None.')
            self.preprocessing = None

    def _create_new_model(self):
        """Creates new model with specified architecture."""
        preprocessing_fn = smp.encoders.get_preprocessing_fn(self.encoder, 'imagenet')
        self.preprocessing = get_preprocessing_pipeline(preprocessing_fn)
        initialize_encoder(name=self.encoder, weights=self.weights)

        architectures = {
            'unet': smp.Unet,
            'unet++': smp.UnetPlusPlus,
            'deeplabv3': smp.DeepLabV3,
            'deeplabv3+': smp.DeepLabV3Plus,
            'fpn': smp.FPN,
            'linknet': smp.Linknet,
            'manet': smp.MAnet,
            'pan': smp.PAN,
            'pspnet': smp.PSPNet
        }

        if self.architecture not in architectures:
            raise ValueError(f'Unsupported architecture: {self.architecture}')

        self.model = architectures[self.architecture](
            encoder_name=self.encoder,
            encoder_weights=self.weights,
            classes=self.n_classes,
            activation=self.activation
        )

    @property
    def config_data(self):
        """Returns model configuration data."""
        return {
            'architecture': self.architecture,
            'encoder': self.encoder,
            'weights': self.weights,
            'activation': self.activation,
            'loss': self.loss_name,
            'classes': (['background'] + self.classes) if self.background_flag else self.classes
        }

def list_architectures():
    """Returns available architecture options."""
    return ['unet', 'unet++', 'deeplabv3', 'deeplabv3+', 'fpn', 
            'linknet', 'manet', 'pan', 'pspnet']
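

# Minimal usage sketch: instantiating a fresh model from ImageNet-pretrained
# encoder weights and inspecting its configuration. Encoder weights are
# downloaded on first use; the 'building' class name is a placeholder chosen
# purely for illustration.
if __name__ == '__main__':
    print('Available architectures:', list_architectures())

    model = SegmentationModel(
        classes=['background', 'building'],
        architecture='unet',
        encoder='timm-regnety_120',
        weights='imagenet',
        loss='dice',
    )
    print(model.config_data)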