---
license: mit
---

The model can be loaded as follows:

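```python
import torch
import torchvision.transforms.v2 as T
import segmentation_models_pytorch as smp

checkpoint = "unet-resnet18.pt"

# U-Net with a ResNet-18 encoder, 4 input bands and 2 output classes.
# encoder_weights=None: no pretrained encoder weights are used, because the
# trained weights are loaded from the checkpoint below.
model = smp.Unet(
    encoder_name="resnet18",
    encoder_weights=None,
    in_channels=4,
    classes=2,
)

# load_state_dict expects a state dict, not a file path, so read the file first.
state_dict = torch.load(checkpoint, map_location="cpu")
model.load_state_dict(state_dict)
model.eval()

# Preprocessing: divide raw pixel values by 255 (mean 0, std 255, applied to every band).
transforms = torch.nn.Sequential(T.Normalize(mean=[0.0], std=[255.0]))
```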