| import torch | |
| import torch.nn as nn | |
| from segmentation_models_pytorch import MAnet | |
| from segmentation_models_pytorch.base.modules import Activation | |
class SegformerGH(MAnet):
    """MAnet-based network with two dense prediction heads.

    Despite the name, this subclasses ``segmentation_models_pytorch.MAnet``;
    the default encoder is ``"mit_b5"``. After the parent constructor has built
    ``self.encoder`` and ``self.decoder``, every ``nn.ReLU`` inside them is
    swapped for ``nn.Mish``. Two extra heads are then attached to the last
    decoder stage:

    * ``gradflow_head`` -> 2 output channels
    * ``cellprob_head`` -> 1 output channel

    ``forward`` concatenates them as ``[gradflow, cellprob]`` along dim 1.

    NOTE(review): ``classes`` is only forwarded to the MAnet constructor;
    ``forward`` never uses any MAnet head, so the output channel count is
    2 + 1 regardless of ``classes``. Presumably it is kept so that saved
    state dicts still match the parent's parameters — confirm before removing.
    """

    def __init__(
        self,
        encoder_name: str = "mit_b5",
        encoder_weights="imagenet",
        decoder_channels=(256, 128, 64, 32, 32),
        decoder_pab_channels=256,
        in_channels: int = 3,
        classes: int = 3,
    ):
        super().__init__(
            encoder_name=encoder_name,
            encoder_weights=encoder_weights,
            decoder_channels=decoder_channels,
            decoder_pab_channels=decoder_pab_channels,
            in_channels=in_channels,
            classes=classes,
        )

        # Replace ReLU activations in backbone and decoder with Mish
        # (encoder first, then decoder, as before).
        for submodule in (self.encoder, self.decoder):
            convert_relu_to_mish(submodule)

        # Both heads consume the last entry of ``decoder_channels``.
        head_in = decoder_channels[-1]
        self.cellprob_head = DeepSegmantationHead(
            in_channels=head_in, out_channels=1, kernel_size=3
        )
        self.gradflow_head = DeepSegmantationHead(
            in_channels=head_in, out_channels=2, kernel_size=3
        )

    def forward(self, x):
        """Run ``x`` through the encoder, the decoder and both heads.

        Returns the gradflow prediction followed by the cellprob prediction,
        concatenated along dim 1.
        """
        self.check_input_shape(x)
        decoded = self.decoder(*self.encoder(x))
        # List elements are evaluated left to right: gradflow head runs first.
        return torch.cat(
            [self.gradflow_head(decoded), self.cellprob_head(decoded)], dim=1
        )
class DeepSegmantationHead(nn.Sequential):
    """Two-convolution prediction head: conv -> Mish -> BatchNorm -> conv.

    The submodule order (conv, Mish, BatchNorm, conv, upsample, activation) is
    kept exactly as before so existing ``state_dict`` keys (``0.*``, ``2.*``,
    ``3.*``) still line up. Note that Mish is applied *before* BatchNorm here.

    Args:
        in_channels: channels of the incoming feature map. The hidden layer
            uses ``in_channels // 2`` channels, so this must be at least 2.
        out_channels: channels produced by the final convolution.
        kernel_size: kernel size of both convolutions. Padding is
            ``kernel_size // 2``, which keeps height/width unchanged only for
            odd sizes, so it must be a positive odd integer.
        activation: forwarded to smp's ``Activation`` wrapper, applied last.
            NOTE(review): presumably ``None`` yields an identity — confirm
            against the installed smp version.
        upsampling: bilinear scale factor applied after the second conv;
            values ``<= 1`` use ``nn.Identity`` instead.

    Raises:
        ValueError: if ``in_channels < 2`` or ``kernel_size`` is not a
            positive odd integer.
    """

    def __init__(
        self, in_channels, out_channels, kernel_size=3, activation=None, upsampling=1
    ):
        # Fail fast: in_channels == 1 would otherwise build a zero-width hidden
        # layer, and an even kernel would grow the spatial size by 1 per conv.
        if in_channels < 2:
            raise ValueError(
                "in_channels must be >= 2 (hidden width is in_channels // 2), "
                f"got {in_channels}"
            )
        if kernel_size < 1 or kernel_size % 2 == 0:
            raise ValueError(
                f"kernel_size must be a positive odd integer, got {kernel_size}"
            )

        hidden_channels = in_channels // 2
        padding = kernel_size // 2

        # Creation order (conv1 before conv2) is unchanged, so random weight
        # initialisation consumes the RNG in the same sequence as before.
        conv2d_1 = nn.Conv2d(
            in_channels,
            hidden_channels,
            kernel_size=kernel_size,
            padding=padding,
        )
        bn = nn.BatchNorm2d(hidden_channels)
        conv2d_2 = nn.Conv2d(
            hidden_channels,
            out_channels,
            kernel_size=kernel_size,
            padding=padding,
        )
        mish = nn.Mish(inplace=True)
        upsample = (
            nn.UpsamplingBilinear2d(scale_factor=upsampling)
            if upsampling > 1
            else nn.Identity()
        )
        final_activation = Activation(activation)
        super().__init__(conv2d_1, mish, bn, conv2d_2, upsample, final_activation)
def convert_relu_to_mish(model):
    """Recursively replace every ``nn.ReLU`` submodule of ``model`` with ``nn.Mish``.

    The replacement keeps the original ReLU's ``inplace`` flag. A ReLU that was
    deliberately out-of-place therefore does not silently become an in-place
    op; that could overwrite a tensor still needed elsewhere, e.g. a residual.
    The conversion mutates ``model`` in place.

    Only *children* are inspected: if ``model`` itself is an ``nn.ReLU``, it is
    left unchanged.

    Args:
        model: ``nn.Module`` whose descendants should be converted.

    Returns:
        ``model`` itself, for convenience. Callers that ignore the return value
        behave exactly as before.
    """
    # Snapshot the children so replacing entries cannot disturb the iteration.
    for child_name, child in list(model.named_children()):
        if isinstance(child, nn.ReLU):
            setattr(model, child_name, nn.Mish(inplace=child.inplace))
        else:
            convert_relu_to_mish(child)
    return model
if __name__ == "__main__":
    # Rebuild the architecture used for training, load its weights from the
    # state-dict checkpoint, and export the whole module with torch.save.
    model = SegformerGH(
        encoder_name="mit_b5",
        encoder_weights=None,  # all weights come from the checkpoint below
        decoder_channels=(1024, 512, 256, 128, 64),
        decoder_pab_channels=256,
        in_channels=3,
        classes=3,
    )
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    state_dict = torch.load("./main_model.pth", map_location="cpu")
    model.load_state_dict(state_dict)
    # Export in inference mode so BatchNorm layers use their running statistics
    # for whoever loads "main_model.pt" without calling .eval() themselves.
    model.eval()
    # Saving the full module pickles it: loading "main_model.pt" later requires
    # the classes defined in this file to be importable.
    torch.save(model, "main_model.pt")