joselobenitezg's picture
add files
5e4b3a1
raw
history blame contribute delete
712 Bytes
from torch import nn
import segmentation_models_pytorch as smp
from segmentation_models_pytorch.losses import DiceLoss
# Encoder backbone identifier handed to smp.Unet as `encoder_name`.
ENCODER = "timm-efficientnet-b0"
# Pretrained-weight set handed to smp.Unet as `encoder_weights`.
WEIGHTS = "imagenet"
class SegmentationModel(nn.Module):
    """U-Net wrapper producing single-channel logits, with an optional loss.

    The network is an ``smp.Unet`` built with the module-level ``ENCODER``
    and ``WEIGHTS`` settings, 3 input channels, 1 output class and no final
    activation, so ``forward`` returns raw logits.
    """

    def __init__(self):
        """Build the underlying U-Net architecture."""
        super().__init__()

        self.arc = smp.Unet(
            encoder_name=ENCODER,
            encoder_weights=WEIGHTS,
            in_channels=3,
            classes=1,
            activation=None,
        )

    def forward(self, images, masks=None):
        """Run the network and, when masks are supplied, compute the loss.

        Args:
            images: Input batch forwarded to the U-Net (constructed with
                ``in_channels=3``).
            masks: Optional ground-truth masks. Presumably the same shape
                as the logits with float dtype, as the BCE loss expects;
                confirm against the data loader.

        Returns:
            ``logits`` alone when ``masks`` is ``None``; otherwise a tuple
            ``(logits, loss)`` where ``loss`` is the sum of the binary Dice
            loss and the BCE-with-logits loss.
        """
        logits = self.arc(images)

        # Identity check: `!= None` would dispatch to the operand's __ne__,
        # which is not a reliable "was an argument given" test for
        # array-like inputs.
        if masks is not None:
            # Loss modules are built per call (they are stateless here) so
            # that instances restored without running __init__ still work.
            dice = DiceLoss(mode='binary')(logits, masks)
            bce = nn.BCEWithLogitsLoss()(logits, masks)
            return logits, dice + bce

        return logits