import pathlib
import safetensors.torch
import segmentation_models_pytorch as smp
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
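
# MLSTAC snippet loader: defines the cloud-masking network (a Segformer branch
# fused with a 1x1-convolution pixel-wise branch) and the MLSTAC API entry
# points (example_data, trainable_model, compiled_model, display_results).
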
class SegformerBranch(nn.Module):
    """Segmentation branch: an smp Segformer with a MobileNetV2 encoder."""

    def __init__(self, in_channels=4, classes=4):
        super(SegformerBranch, self).__init__()
        self.segformer = smp.Segformer(
            encoder_name="mobilenet_v2",
            encoder_weights=None,
            in_channels=in_channels,
            classes=classes,
        )

    def forward(self, x):
        return self.segformer(x)


class PixelWiseNet(nn.Module):
    """Pixel-wise branch: 1x1 convolutions only, so each pixel is scored from
    its own spectral values without spatial context."""

    def __init__(self, in_channels=4, out_channels=4, base_channels=32):
        super(PixelWiseNet, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, base_channels, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(base_channels)
        self.conv2 = nn.Conv2d(base_channels, base_channels, kernel_size=1, bias=False)
        self.bn2 = nn.BatchNorm2d(base_channels)
        self.conv3 = nn.Conv2d(base_channels, out_channels, kernel_size=1, bias=False)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = self.conv3(x)
        return x


class CombinedNet(nn.Module):
    """Sums the Segformer and pixel-wise branch outputs, then applies a 1x1
    fusion convolution. When ``benchmark`` is True, logits are mapped to
    probabilities with a sigmoid."""

    def __init__(self, in_channels=4, classes=4, base_channels=32, benchmark=False):
        super(CombinedNet, self).__init__()
        self.seg_branch = SegformerBranch(in_channels=in_channels, classes=classes)
        self.pixel_branch = PixelWiseNet(
            in_channels=in_channels, out_channels=classes, base_channels=base_channels
        )
        self.fusion_conv = nn.Conv2d(classes, classes, kernel_size=1, bias=False)
        self.benchmark = benchmark

    def forward(self, x):
        seg_out = self.seg_branch(x)
        pixel_out = self.pixel_branch(x)
        fused = seg_out + pixel_out
        out = self.fusion_conv(fused)
        if self.benchmark:
            out = torch.sigmoid(out)
        return out
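
# Shape sketch (illustrative): with classes=1 and benchmark=True, an input of
# shape (B, 4, H, W) yields per-pixel cloud probabilities of shape (B, 1, H, W)
# in [0, 1]; H and W should suit the MobileNetV2 encoder (e.g. multiples of 32).
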
# MLSTAC API -----------------------------------------------------------------------
def example_data(path: pathlib.Path, device: str = "cpu", *args, **kwargs):
    """Load the example image shipped with the snippet and add a batch dimension."""
    data_f = path / "example_data.safetensor"
    sample = safetensors.torch.load_file(data_f)
    return sample["image"].float().unsqueeze(0).to(device)


def trainable_model(path: pathlib.Path, device: str = "cpu", *args, **kwargs):
    """Load the model with trainable weights (gradients enabled) for fine-tuning."""
    trainable_f = path / "model.safetensor"

    # Load model parameters
    cloud_model_weights = safetensors.torch.load_file(trainable_f)
    cloud_model = CombinedNet(classes=1)
    cloud_model.load_state_dict(cloud_model_weights)
    cloud_model = cloud_model.eval()

    return cloud_model.to(device)


def compiled_model(path: pathlib.Path, device: str = "cpu", *args, **kwargs):
    """Load the inference-ready model: weights frozen, sigmoid output, moved to device."""
    trainable_f = path / "model.safetensor"

    # Load model parameters
    cloud_model_weights = safetensors.torch.load_file(trainable_f)
    cloud_model = CombinedNet(classes=1, benchmark=True)
    cloud_model.load_state_dict(cloud_model_weights)
    cloud_model = cloud_model.eval()

    # Move model to device
    cloud_model = cloud_model.to(device)

    # Deactivate gradients
    for param in cloud_model.parameters():
        param.requires_grad = False

    return cloud_model


def display_results(path: pathlib.Path, device: str = "cpu", *args, **kwargs):
    """Run the compiled model on the example image and plot input vs. cloud mask."""
    # Load the model and the example data on the same device
    model = compiled_model(path, device)
    probav = example_data(path, device)

    # Run model
    cloudprobs = model(probav).squeeze().cpu()

    # Display results: bands 2, 1, 0 of the input as RGB next to the cloud probabilities
    fig, ax = plt.subplots(1, 2, figsize=(8, 4))
    ax[0].imshow(probav[0, [2, 1, 0]].cpu().detach().numpy().transpose(1, 2, 0))
    ax[0].set_title("Input")
    ax[1].imshow(cloudprobs.detach().numpy(), cmap="gray")
    ax[1].set_title("Output")
    for a in ax:
        a.axis("off")
    fig.tight_layout()
    return fig
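

# Minimal local usage sketch (not part of the MLSTAC API). It assumes the snippet
# files (model.safetensor, example_data.safetensor) are available in a local
# directory; the directory and output filename below are hypothetical.
if __name__ == "__main__":
    snippet_dir = pathlib.Path(".")  # hypothetical location of the snippet files
    fig = display_results(snippet_dir, device="cpu")
    fig.savefig("cloudmask_preview.png")  # illustrative output path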