unet_plus_plus / hf_model.py
voitl's picture
Upload HFUnetPlusPlus
1150876
raw
history blame contribute delete
630 Bytes
import segmentation_models_pytorch as smp
from .hf_config import UnetConfig
from transformers import PreTrainedModel
class HFUnetPlusPlus(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around ``smp.UnetPlusPlus``.

    The wrapped network is stored on ``self.model`` and is configured from a
    ``UnetConfig`` instance, which lets the model plug into the
    ``transformers`` model/config machinery.
    """

    # Config class associated with this model type.
    config_class = UnetConfig

    def __init__(self, config):
        """Build the wrapped U-Net++ network from ``config``.

        Args:
            config: Configuration object. Must provide ``encoder_name``,
                ``decoder_channels``, ``input_channels`` and ``num_classes``.
                May optionally provide ``encoder_weights`` (defaults to
                ``"imagenet"``) and ``decoder_attention_type`` (defaults to
                ``"scse"``).
        """
        super().__init__(config)

        # Both options used to be hard-coded. They are now read from the
        # config with the old literals as fallbacks, so existing configs
        # behave exactly as before, while a config that sets
        # ``encoder_weights=None`` can skip fetching ImageNet encoder
        # weights (e.g. when a checkpoint is going to overwrite them).
        encoder_weights = getattr(config, "encoder_weights", "imagenet")
        decoder_attention_type = getattr(
            config, "decoder_attention_type", "scse"
        )

        self.model = smp.UnetPlusPlus(
            encoder_name=config.encoder_name,
            encoder_weights=encoder_weights,
            decoder_channels=config.decoder_channels,
            in_channels=config.input_channels,
            classes=config.num_classes,
            decoder_attention_type=decoder_attention_type,
        )

    def forward(self, tensor):
        """Run ``tensor`` through the wrapped network and return its output.

        Args:
            tensor: Input passed unchanged to ``self.model``.
                NOTE(review): presumably an image batch with
                ``config.input_channels`` channels -- confirm with callers.

        Returns:
            Whatever ``self.model(tensor)`` returns.
        """
        return self.model(tensor)