File size: 432 Bytes
ab4b0b8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
from torch import nn
import segmentation_models_pytorch as smp


class Net(nn.Module):
    """Thin wrapper around segmentation_models_pytorch's DeepLabV3+ model.

    Args:
        class_num: Number of output classes; passed to smp as ``classes``.
        in_channels: Number of channels in the input tensor. Defaults to 4.
            NOTE(review): presumably a 4-band input (e.g. RGB plus one extra
            band) — confirm against the data pipeline.
        encoder_name: Name of the smp encoder backbone, e.g. ``"resnet34"``.
        **kwargs: Extra keyword arguments (e.g. ``encoder_weights``,
            ``activation``) forwarded unchanged to ``smp.DeepLabV3Plus``.
            When none are given, smp's own defaults apply exactly as before.
    """

    def __init__(
        self,
        class_num,
        in_channels=4,
        encoder_name="resnet34",
        **kwargs,
    ):
        super().__init__()
        # Use the top-level export rather than ``smp.deeplabv3.DeepLabV3Plus``.
        # The ``deeplabv3`` submodule is only exposed as a package attribute in
        # older smp layouts (newer releases keep decoders under
        # ``smp.decoders``), whereas ``smp.DeepLabV3Plus`` is exported by both.
        self.net = smp.DeepLabV3Plus(
            in_channels=in_channels,
            classes=class_num,
            encoder_name=encoder_name,
            **kwargs,
        )

    def forward(self, x):
        """Run the wrapped DeepLabV3+ model on ``x`` and return its output as-is.

        No activation is applied here. Whether the result is raw logits depends
        on the ``activation`` setting of the wrapped smp model (smp's default,
        unless overridden via ``**kwargs``).
        """
        return self.net(x)