Unet / src /models /unet /resunet.py
dawn17's picture
Upload 35 files
bcc0f94
import torch.nn as nn
from .encoder import ResnetEncoder as Encoder
from .decoder import CustomDecoder as Decoder
class UNet(nn.Module):
def __init__(self, decoder_config, nclasses, input_shape=(224, 224)):
super(UNet, self).__init__()
self.encoder = Encoder()
self.decoder = Decoder(config=decoder_config)
self.output = nn.Sequential(
nn.Conv2d(
in_channels=decoder_config["block1"]["out_channels"],
out_channels=nclasses,
kernel_size=1,
),
nn.UpsamplingBilinear2d(size=input_shape),
)
def forward(self, x):
x, encoder_step_output = self.encoder(x)
x = self.decoder(x, encoder_step_output)
x = self.output(x)
return x
if __name__ == "__main__":
import torch
import yaml
from easydict import EasyDict
from torchinfo import summary
# load config
config_path = "./config/resnet_config.yml"
with open(config_path, "r") as file:
yaml_data = yaml.safe_load(file)
config = EasyDict(yaml_data)
# input shape
input_shape = (224, 224)
# device
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
# model definition
model = UNet(
decoder_config=config["decoder_config"], nclasses=1, input_shape=input_shape
).to(device)
summary(
model,
input_data=torch.rand((1, 3, input_shape[0], input_shape[1])),
device=device,
)
# load weights (if any)
model_path = None
if model_path is not None:
checkpoint = torch.load(model_path, map_location=device)
model.decoder.load_state_dict(checkpoint["decoder_state_dict"], strict=False)
model.output.load_state_dict(checkpoint["output_state_dict"], strict=False)