Spaces:
Runtime error
Runtime error
| from torch import nn | |
| import segmentation_models_pytorch as smp | |
| from segmentation_models_pytorch.losses import DiceLoss | |
# Name of the encoder backbone passed to smp.Unet as ``encoder_name``.
ENCODER = 'timm-efficientnet-b0'
# Identifier of the pretrained weights passed to smp.Unet as ``encoder_weights``.
WEIGHTS = 'imagenet'
class SegmentationModel(nn.Module):
    """U-Net segmentation network that can also compute its training loss.

    Wraps ``smp.Unet`` configured with the module-level ``ENCODER`` backbone
    and ``WEIGHTS`` pretrained weights, ``in_channels=3``, ``classes=1`` and
    ``activation=None`` (so the network output is used as raw logits).
    """

    def __init__(self):
        """Build the underlying ``smp.Unet`` and store it as ``self.arc``."""
        super().__init__()
        self.arc = smp.Unet(
            encoder_name=ENCODER,
            encoder_weights=WEIGHTS,
            in_channels=3,
            classes=1,
            activation=None,
        )

    def forward(self, images, masks=None):
        """Run the network and, when masks are given, compute the loss.

        Args:
            images: Input batch forwarded to the U-Net. The network is built
                with ``in_channels=3``, so presumably shaped
                ``(batch, 3, H, W)`` -- TODO confirm against the data loader.
            masks: Optional ground-truth masks. They are passed directly to
                ``DiceLoss(mode='binary')`` and ``nn.BCEWithLogitsLoss``
                together with the logits, so they presumably need to be
                float tensors matching the logits' shape -- TODO confirm.

        Returns:
            ``logits`` alone when ``masks`` is ``None``; otherwise the tuple
            ``(logits, loss)`` where ``loss`` is the sum of the binary Dice
            loss and the BCE-with-logits loss.
        """
        logits = self.arc(images)

        # Identity test on purpose: ``masks != None`` would be routed through
        # the tensor's overloaded comparison operator instead of a plain
        # None check.
        if masks is None:
            return logits

        # Both criteria receive the raw logits (no activation is applied by
        # the network, see ``activation=None`` above).
        dice_loss = DiceLoss(mode='binary')(logits, masks)
        bce_loss = nn.BCEWithLogitsLoss()(logits, masks)
        return logits, dice_loss + bce_loss