|
import torch.nn as nn |
|
from .encoder import ResnetEncoder as Encoder |
|
from .decoder import CustomDecoder as Decoder |
|
|
|
|
|
class UNet(nn.Module):
    """Encoder/decoder network finished by a 1x1 convolutional output head.

    The input goes through ``Encoder``, then ``Decoder``, and finally through
    ``self.output``, which projects the decoder features to ``nclasses``
    channels and bilinearly resizes them to ``input_shape``.
    """

    def __init__(self, decoder_config, nclasses, input_shape=(224, 224)):
        """Build the encoder, the decoder and the output head.

        Args:
            decoder_config: Mapping passed to ``Decoder`` as ``config``.
                ``decoder_config["block1"]["out_channels"]`` is also read here
                and used as the channel count entering the output head.
            nclasses: Number of channels produced by the output head.
            input_shape: Spatial size the output head resizes to.

        Raises:
            KeyError: If ``decoder_config`` is a plain mapping without
                ``["block1"]["out_channels"]``.
        """
        super().__init__()

        self.encoder = Encoder()
        self.decoder = Decoder(config=decoder_config)

        # Presumably "block1" is the last decoder stage, so its width is what
        # the head receives -- NOTE(review): confirm against CustomDecoder.
        head_in_channels = decoder_config["block1"]["out_channels"]

        self.output = nn.Sequential(
            nn.Conv2d(
                in_channels=head_in_channels,
                out_channels=nclasses,
                kernel_size=1,
            ),
            # nn.UpsamplingBilinear2d is deprecated; this is its documented
            # equivalent (bilinear interpolation with align_corners=True).
            # The layer has no parameters, so state_dict keys are unchanged.
            nn.Upsample(size=input_shape, mode="bilinear", align_corners=True),
        )

    def forward(self, x):
        """Run encoder, decoder and output head in sequence.

        Args:
            x: Input passed to the encoder.

        Returns:
            The result of ``self.output`` applied to the decoder output.
        """
        # The encoder returns its final features together with a second value
        # that is handed to the decoder unchanged (presumably intermediate
        # feature maps for skip connections -- verify against ResnetEncoder).
        x, encoder_step_output = self.encoder(x)
        x = self.decoder(x, encoder_step_output)
        return self.output(x)
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: build the model from the YAML config, print a layer summary
    # and, when a checkpoint path is set, restore decoder and head weights.
    import torch
    import yaml
    from easydict import EasyDict
    from torchinfo import summary

    # Read the YAML configuration and wrap it in an EasyDict.
    config_path = "./config/resnet_config.yml"
    with open(config_path, "r") as stream:
        config = EasyDict(yaml.safe_load(stream))

    # Spatial size used for both the dummy input and the model output.
    input_shape = (224, 224)

    # Use the GPU when one is available, otherwise fall back to the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = UNet(
        decoder_config=config["decoder_config"],
        nclasses=1,
        input_shape=input_shape,
    )
    model = model.to(device)

    # One random 3-channel sample is enough for torchinfo to trace the layers.
    dummy_batch = torch.rand((1, 3, *input_shape))
    summary(model, input_data=dummy_batch, device=device)

    # Set to a checkpoint file to restore the decoder and output-head weights.
    model_path = None

    if model_path is not None:
        checkpoint = torch.load(model_path, map_location=device)
        # strict=False: missing or unexpected keys in the checkpoint are
        # tolerated rather than raising.
        restore_plan = (
            (model.decoder, "decoder_state_dict"),
            (model.output, "output_state_dict"),
        )
        for module, state_key in restore_plan:
            module.load_state_dict(checkpoint[state_key], strict=False)
|
|