|
import segmentation_models_pytorch as smp |
|
from .hf_config import UnetConfig |
|
from transformers import PreTrainedModel |
|
|
|
|
|
class HFUnetPlusPlus(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around ``smp.UnetPlusPlus``.

    Wrapping the segmentation_models_pytorch network in ``PreTrainedModel``
    lets it be stored and restored with ``save_pretrained`` /
    ``from_pretrained`` using :class:`UnetConfig`.

    The network is built from these config attributes:

    * ``encoder_name`` -- smp encoder backbone identifier.
    * ``decoder_channels`` -- channel widths of the decoder blocks.
    * ``input_channels`` -- number of channels of the input tensor.
    * ``num_classes`` -- number of output channels (classes).
    * ``encoder_weights`` (optional, defaults to ``"imagenet"``) --
      pretrained weights smp loads into the encoder at construction time.
      Set it to ``None`` to skip that download, e.g. when the full model
      weights are about to be restored by ``from_pretrained`` anyway or
      when running without network access.
    * ``decoder_attention_type`` (optional, defaults to ``"scse"``) --
      attention module used in the decoder blocks.
    """

    config_class = UnetConfig

    def __init__(self, config):
        super().__init__(config)

        # The two optional settings fall back to the values that used to be
        # hard-coded, so configs that do not define them behave as before.
        self.model = smp.UnetPlusPlus(
            encoder_name=config.encoder_name,
            encoder_weights=getattr(config, "encoder_weights", "imagenet"),
            decoder_channels=config.decoder_channels,
            in_channels=config.input_channels,
            classes=config.num_classes,
            decoder_attention_type=getattr(
                config, "decoder_attention_type", "scse"
            ),
        )

    def forward(self, tensor):
        """Run the wrapped UNet++ on ``tensor`` and return its raw output.

        Args:
            tensor: Input batch passed straight to ``smp.UnetPlusPlus``.
                Presumably shaped (batch, input_channels, H, W) -- TODO
                confirm against callers.

        Returns:
            Whatever ``smp.UnetPlusPlus.forward`` returns. No activation is
            applied here.
        """
        return self.model(tensor)
|
|