import torch
import torch.nn as nn
from segmentation_models_pytorch import MAnet
from segmentation_models_pytorch.base.modules import Activation


class SegformerGH(MAnet):
    """MAnet model with separate gradient-flow and cell-probability heads.

    The smp ``MAnet`` encoder/decoder are reused, every ``nn.ReLU`` module in
    them is swapped for ``nn.Mish``, and the decoder output is fed to two
    independent :class:`DeepSegmantationHead` modules whose outputs are
    concatenated along the channel dimension in :meth:`forward`.

    NOTE(review): ``classes`` is forwarded to ``MAnet``, which presumably
    builds its own ``segmentation_head``; ``forward`` below never calls it.
    Its parameters are still part of the state_dict, so removing it would
    break strict loading of existing checkpoints — confirm before changing.
    """

    def __init__(
        self,
        encoder_name: str = "mit_b5",
        encoder_weights="imagenet",
        decoder_channels=(256, 128, 64, 32, 32),
        decoder_pab_channels=256,
        in_channels: int = 3,
        classes: int = 3,
    ):
        super().__init__(
            encoder_name=encoder_name,
            encoder_weights=encoder_weights,
            decoder_channels=decoder_channels,
            decoder_pab_channels=decoder_pab_channels,
            in_channels=in_channels,
            classes=classes,
        )
        convert_relu_to_mish(self.encoder)
        convert_relu_to_mish(self.decoder)

        # Both heads consume the decoder output; this assumes the decoder's
        # last stage emits ``decoder_channels[-1]`` channels.
        head_in_channels = decoder_channels[-1]
        self.cellprob_head = DeepSegmantationHead(
            in_channels=head_in_channels,
            out_channels=1,
            kernel_size=3,
        )
        self.gradflow_head = DeepSegmantationHead(
            in_channels=head_in_channels,
            out_channels=2,
            kernel_size=3,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Sequentially pass ``x`` through the model's encoder, decoder and heads.

        Returns one tensor whose channels are the 2 gradient-flow outputs
        followed by the 1 cell-probability output (3 channels in total).
        """
        # Inherited from the smp base model; presumably validates spatial size.
        self.check_input_shape(x)

        features = self.encoder(x)
        decoder_output = self.decoder(*features)

        # Each head starts with a convolution, so its in-place Mish never
        # overwrites the shared ``decoder_output`` seen by the other head.
        gradflow_mask = self.gradflow_head(decoder_output)
        cellprob_mask = self.cellprob_head(decoder_output)

        # Channel order is part of the output contract: [gradflow(2), cellprob(1)].
        return torch.cat([gradflow_mask, cellprob_mask], dim=1)


class DeepSegmantationHead(nn.Sequential):
    """Two-convolution prediction head.

    Layout: conv(in -> in // 2) -> Mish -> BatchNorm -> conv(in // 2 -> out)
    -> bilinear upsampling (or identity) -> ``Activation(activation)``.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the head.
        kernel_size: kernel size of both convolutions; padding is
            ``kernel_size // 2``, so odd kernels keep the spatial size.
        activation: forwarded to smp's ``Activation`` wrapper
            (``None`` presumably means identity — confirm).
        upsampling: bilinear scale factor; values ``<= 1`` disable it.
    """

    def __init__(
        self, in_channels, out_channels, kernel_size=3, activation=None, upsampling=1
    ):
        hidden_channels = in_channels // 2
        padding = kernel_size // 2

        conv2d_1 = nn.Conv2d(
            in_channels,
            hidden_channels,
            kernel_size=kernel_size,
            padding=padding,
        )
        mish = nn.Mish(inplace=True)
        bn = nn.BatchNorm2d(hidden_channels)
        conv2d_2 = nn.Conv2d(
            hidden_channels,
            out_channels,
            kernel_size=kernel_size,
            padding=padding,
        )
        upsampling = (
            nn.UpsamplingBilinear2d(scale_factor=upsampling)
            if upsampling > 1
            else nn.Identity()
        )
        activation = Activation(activation)
        # The positional order fixes the state_dict keys ("0.weight",
        # "2.running_mean", ...). Keep it (including Mish before BatchNorm),
        # or existing checkpoints will no longer load.
        super().__init__(conv2d_1, mish, bn, conv2d_2, upsampling, activation)


def convert_relu_to_mish(model: nn.Module) -> None:
    """Recursively replace every ``nn.ReLU`` submodule of ``model`` with ``nn.Mish``.

    ``model`` is modified in place. Each new ``nn.Mish`` copies the
    ``inplace`` flag of the ``nn.ReLU`` it replaces, so a ReLU that
    deliberately avoided in-place writes is not turned into an in-place op.
    Only registered ``nn.ReLU`` modules are converted; functional ``relu``
    calls inside ``forward`` methods are left untouched.
    """
    # Snapshot the children so that replacing entries does not mutate the
    # mapping while it is being iterated.
    for child_name, child in list(model.named_children()):
        if isinstance(child, nn.ReLU):
            setattr(model, child_name, nn.Mish(inplace=child.inplace))
        else:
            convert_relu_to_mish(child)


if __name__ == "__main__":
    # Rebuild the architecture, load the trained weights and re-save the whole
    # pickled module. The constructor arguments must match the checkpoint,
    # because load_state_dict is strict by default.
    model = SegformerGH(
        encoder_name="mit_b5",
        encoder_weights=None,  # every parameter is overwritten by the checkpoint
        decoder_channels=(1024, 512, 256, 128, 64),
        decoder_pab_channels=256,
        in_channels=3,
        classes=3,
    )
    # NOTE(review): torch.load unpickles the file, so only load trusted
    # checkpoints. Consider weights_only=True where the torch version allows it.
    state_dict = torch.load("./main_model.pth", map_location="cpu")
    model.load_state_dict(state_dict)
    torch.save(model, "main_model.pt")