|
import pathlib |
|
import safetensors.torch |
|
import segmentation_models_pytorch as smp |
|
import matplotlib.pyplot as plt |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
|
|
class SegformerBranch(nn.Module):
    """Segmentation branch wrapping ``smp.Segformer`` with a MobileNetV2 encoder.

    The encoder is requested with ``encoder_weights=None``, i.e. no pretrained
    encoder weights are fetched by smp.

    Args:
        in_channels: Number of channels of the input tensor.
        classes: Number of output channels produced by the Segformer head.
    """

    def __init__(self, in_channels=4, classes=4):
        super().__init__()
        segformer_kwargs = {
            "encoder_name": "mobilenet_v2",
            "encoder_weights": None,
            "in_channels": in_channels,
            "classes": classes,
        }
        # Kept as ``segformer``: the attribute name is part of the state_dict
        # keys that saved weights are loaded against.
        self.segformer = smp.Segformer(**segformer_kwargs)

    def forward(self, x):
        """Run the wrapped Segformer on ``x`` and return its output unchanged."""
        out = self.segformer(x)
        return out
|
|
|
|
|
class PixelWiseNet(nn.Module):
    """Per-pixel network built only from 1x1 convolutions.

    Structure:
        - two ``conv(1x1) -> BatchNorm -> ReLU`` stages of width ``base_channels``;
        - a final bias-free 1x1 projection to ``out_channels``.

    Because every kernel is 1x1, the convolutions never combine neighbouring
    pixels. BatchNorm in training mode still uses batch statistics.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels of the returned tensor.
        base_channels: Width of the two hidden stages.
    """

    def __init__(self, in_channels=4, out_channels=4, base_channels=32):
        super().__init__()
        width = base_channels
        # Attribute names (and their order) are kept stable: they are the
        # state_dict keys used when loading saved weights.
        self.conv1 = nn.Conv2d(in_channels, width, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(width)
        self.conv2 = nn.Conv2d(width, width, kernel_size=1, bias=False)
        self.bn2 = nn.BatchNorm2d(width)
        self.conv3 = nn.Conv2d(width, out_channels, kernel_size=1, bias=False)

    def forward(self, x):
        """Return un-activated scores with the same spatial size as ``x``."""
        hidden_stages = ((self.conv1, self.bn1), (self.conv2, self.bn2))
        for conv, norm in hidden_stages:
            x = F.relu(norm(conv(x)))
        return self.conv3(x)
|
|
|
class CombinedNet(nn.Module):
    """Fuse a Segformer branch and a per-pixel 1x1-conv branch into one map.

    Both branches receive the same input. Their outputs are summed
    element-wise and then mixed by a bias-free 1x1 convolution.

    Args:
        in_channels: Channels of the input tensor, forwarded to both branches.
        classes: Output channels of each branch and of the fused result.
        base_channels: Hidden width of the per-pixel branch.
        benchmark: When True, ``forward`` applies a sigmoid to its output.
            Otherwise the raw fused scores are returned.
    """

    def __init__(self, in_channels=4, classes=4, base_channels=32, benchmark=False):
        super().__init__()
        # Sub-module attribute names define the state_dict keys that saved
        # weights are loaded against, so they must not be renamed.
        self.seg_branch = SegformerBranch(in_channels=in_channels, classes=classes)
        self.pixel_branch = PixelWiseNet(
            in_channels=in_channels,
            out_channels=classes,
            base_channels=base_channels,
        )
        self.fusion_conv = nn.Conv2d(classes, classes, kernel_size=1, bias=False)
        self.benchmark = benchmark

    def forward(self, x):
        """Return fused scores, passed through a sigmoid when ``benchmark`` is set."""
        # The element-wise sum assumes the Segformer output matches the
        # pixel branch (same spatial size as x). TODO confirm for inputs whose
        # H/W are not multiples of the encoder's downsampling factor.
        fused = self.seg_branch(x) + self.pixel_branch(x)
        scores = self.fusion_conv(fused)
        return torch.sigmoid(scores) if self.benchmark else scores
|
|
|
|
|
|
|
|
|
def example_data(path: pathlib.Path, device: str = "cpu", *args, **kwargs):
    """Load the bundled example image as a batched float tensor.

    Args:
        path: Directory containing ``example_data.safetensor``. A plain ``str``
            is also accepted and converted to ``pathlib.Path``.
        device: Device the returned tensor is moved to.
        *args, **kwargs: Accepted and ignored. Presumably they exist so every
            loader in this module shares one call signature.

    Returns:
        The ``"image"`` tensor stored in the file, cast to float32 and given a
        leading batch dimension of size 1.
    """
    # Coerce so callers may pass a str; ``str / str`` would raise TypeError.
    data_f = pathlib.Path(path) / "example_data.safetensor"
    sample = safetensors.torch.load_file(data_f)
    return sample["image"].float().unsqueeze(0).to(device)
|
|
|
def trainable_model(path, device: str = "cpu", *args, **kwargs):
    """Build a single-class ``CombinedNet`` and load its saved weights.

    The network is built with the default ``benchmark=False``, so its forward
    pass returns raw scores without a sigmoid. Parameters keep their default
    ``requires_grad`` setting.

    Args:
        path: Directory containing ``model.safetensor`` (``str`` or
            ``pathlib.Path``).
        device: Device the model is moved to. It was previously accepted but
            ignored, which left the model on CPU.
        *args, **kwargs: Accepted and ignored.

    Returns:
        The loaded model on ``device``, in eval mode.
    """
    # Coerce so callers may pass a str; ``str / str`` would raise TypeError.
    weights_f = pathlib.Path(path) / "model.safetensor"

    cloud_model_weights = safetensors.torch.load_file(weights_f)
    cloud_model = CombinedNet(classes=1)
    cloud_model.load_state_dict(cloud_model_weights)

    # Returned in eval mode, as before. The network contains BatchNorm layers,
    # so call ``.train()`` before fine-tuning if batch statistics should update.
    return cloud_model.eval().to(device)
|
|
|
|
|
def compiled_model(path, device: str = "cpu", *args, **kwargs):
    """Build a frozen, inference-ready single-class ``CombinedNet``.

    The network is built with ``benchmark=True``, so its forward pass applies a
    sigmoid to the fused output.

    Args:
        path: Directory containing ``model.safetensor`` (``str`` or
            ``pathlib.Path``).
        device: Device the model is moved to.
        *args, **kwargs: Accepted and ignored.

    Returns:
        The loaded model on ``device``, in eval mode, with every parameter's
        ``requires_grad`` set to False.
    """
    # Coerce so callers may pass a str; ``str / str`` would raise TypeError.
    weights_f = pathlib.Path(path) / "model.safetensor"

    cloud_model_weights = safetensors.torch.load_file(weights_f)
    cloud_model = CombinedNet(classes=1, benchmark=True)
    cloud_model.load_state_dict(cloud_model_weights)

    cloud_model = cloud_model.eval().to(device)

    # Freeze all parameters in one call (same effect as looping over
    # ``parameters()`` and setting ``requires_grad = False`` on each).
    cloud_model.requires_grad_(False)

    return cloud_model
|
|
|
|
|
def display_results(path: pathlib.Path, device: str = "cpu", *args, **kwargs):
    """Run the frozen model on the example image and plot input next to output.

    Args:
        path: Directory holding ``model.safetensor`` and
            ``example_data.safetensor``.
        device: Device used for inference. The example input is loaded onto
            this same device. It previously stayed on CPU, so any non-CPU
            device failed with a device mismatch.
        *args, **kwargs: Accepted and ignored.

    Returns:
        A matplotlib Figure with two panels titled "Input" and "Output".
    """
    # compiled_model already constructs the network with benchmark=True
    # (sigmoid output); the former ``benchmark=True`` kwarg here was ignored.
    model = compiled_model(path, device)

    # Input must live on the same device as the model's weights.
    probav = example_data(path, device)

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        # Batch size 1 and classes=1, so squeeze() leaves an (H, W) map.
        cloudprobs = model(probav).squeeze().cpu()

    # Channels 2, 1, 0 of the first sample are shown as an RGB composite.
    # Presumably these are the R, G, B bands; confirm against the data.
    rgb = probav[0, [2, 1, 0]].cpu().numpy().transpose(1, 2, 0)

    fig, ax = plt.subplots(1, 2, figsize=(8, 4))
    ax[0].imshow(rgb)
    ax[0].set_title("Input")
    ax[1].imshow(cloudprobs.numpy(), cmap="gray")
    ax[1].set_title("Output")
    for a in ax:
        a.axis("off")
    fig.tight_layout()
    return fig