|
from .unet import Unet |
|
from .unetplusplus import UnetPlusPlus |
|
from .manet import MAnet |
|
from .linknet import Linknet |
|
from .fpn import FPN |
|
from .pspnet import PSPNet |
|
from .deeplabv3 import DeepLabV3, DeepLabV3Plus |
|
from .pan import PAN |
|
|
|
from . import encoders |
|
from . import utils |
|
from . import losses |
|
|
|
from .__version__ import __version__ |
|
|
|
from typing import Optional |
|
import torch |
|
|
|
|
|
def create_model(
    arch: str,
    encoder_name: str = "resnet34",
    encoder_weights: Optional[str] = "imagenet",
    in_channels: int = 3,
    classes: int = 1,
    **kwargs,
) -> torch.nn.Module:
    """Create a segmentation model by architecture name.

    Looks up ``arch`` case-insensitively against the lower-cased class names
    of the supported architectures (Unet, UnetPlusPlus, MAnet, Linknet, FPN,
    PSPNet, DeepLabV3, DeepLabV3Plus, PAN) and instantiates the matching
    class.

    Args:
        arch: Architecture name; compared after ``str.lower()``, so e.g.
            ``"unet"`` and ``"Unet"`` both select ``Unet``.
        encoder_name: Forwarded unchanged to the model constructor.
        encoder_weights: Forwarded unchanged to the model constructor.
        in_channels: Forwarded unchanged to the model constructor.
        classes: Forwarded unchanged to the model constructor.
        **kwargs: Any extra keyword arguments, forwarded unchanged to the
            model constructor.

    Returns:
        A new instance of the selected architecture class.

    Raises:
        KeyError: If ``arch`` does not match any supported architecture. The
            message lists the accepted (lower-cased) names.
    """
    archs = [Unet, UnetPlusPlus, MAnet, Linknet, FPN, PSPNet, DeepLabV3, DeepLabV3Plus, PAN]
    # Key each class by its lower-cased class name so lookup is case-insensitive.
    archs_dict = {a.__name__.lower(): a for a in archs}

    try:
        model_class = archs_dict[arch.lower()]
    except KeyError:
        # `from None` suppresses the internal dict-lookup KeyError as implicit
        # context, so the traceback shows only this user-facing message.
        raise KeyError(
            "Wrong architecture type `{}`. Available options are: {}".format(
                arch, list(archs_dict.keys()),
            )
        ) from None

    return model_class(
        encoder_name=encoder_name,
        encoder_weights=encoder_weights,
        in_channels=in_channels,
        classes=classes,
        **kwargs,
    )