File size: 6,234 Bytes
039daa1 567a505 039daa1 a09c701 039daa1 a09c701 039daa1 34c4cb0 039daa1 a09c701 039daa1 567a505 039daa1 a09c701 039daa1 a09c701 039daa1 a09c701 7cdd27c a09c701 039daa1 a09c701 039daa1 a09c701 039daa1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 |
import torch
import pathlib
import importlib.util
import safetensors.torch
import matplotlib.pyplot as plt
import math
from typing import Literal
def load_model_module(model_path: pathlib.Path):
    """Execute a Python source file and return it as a module object.

    The module is created under the fixed name ``"model"`` and is not added
    to ``sys.modules``, so every call executes the file afresh.

    Args:
        model_path: Path to a ``.py`` file; a ``str`` is also accepted.

    Returns:
        The executed module.

    Raises:
        ImportError: If no loader can be determined for the file (for example
            when it does not have a Python source suffix).
    """
    model_path = pathlib.Path(model_path).resolve()
    spec = importlib.util.spec_from_file_location("model", model_path)
    # spec_from_file_location returns None when it cannot pick a loader for
    # the file; fail with a clear ImportError instead of an AttributeError.
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load a Python module from {model_path}")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module
class EnsembleModel(torch.nn.Module):
    """Combine six member models into one prediction plus an uncertainty map.

    Every member is called on the same input, and the outputs are concatenated
    along dim 1 and then reduced according to ``mode``.

    Args:
        model1, model2, model3, model4, model5, model6: Callables applied to
            the input in ``forward``.
        mode: ``"max"``, ``"mean"`` or ``"min"`` reduce the concatenated
            outputs across dim 1. ``"none"`` returns the concatenation as-is.

    Raises:
        ValueError: If ``mode`` is not one of the supported values.
    """

    _MODES = ("min", "mean", "max", "none")

    def __init__(self, model1, model2, model3, model4, model5, model6, mode="max"):
        # Validate first so an invalid mode fails before any state is set up.
        if mode not in self._MODES:
            raise ValueError("Mode must be 'none', 'min', 'mean', or 'max'.")
        super().__init__()
        # Assigning each member as an attribute registers nn.Module members as
        # submodules, so .to()/.eval()/state_dict() reach them. The plain list
        # below is not registered and is only used for iteration in forward.
        self.model1 = model1
        self.model2 = model2
        self.model3 = model3
        self.model4 = model4
        self.model5 = model5
        self.model6 = model6
        self.models = [model1, model2, model3, model4, model5, model6]
        self.mode = mode

    def forward(self, x):
        """Run every member on ``x`` and reduce their outputs.

        Returns:
            For ``mode="none"``, the member outputs concatenated along dim 1.
            Otherwise a tuple ``(output_probs, std_output)``:

            - ``output_probs`` is the max/mean/min across dim 1.
            - ``std_output`` is the standard deviation across dim 1, divided
              by its largest possible value for inputs in [0, 1].

            Both are ``squeeze()``d, so size-1 dimensions (such as a batch of
            one) are dropped.
        """
        stacked = torch.cat([model(x) for model in self.models], dim=1)
        if self.mode == "none":
            return stacked

        # torch.max/min with dim return (values, indices); torch.mean/std
        # return a plain tensor. Indexing the latter with [0] would select the
        # first batch item, so take .values only from max/min.
        if self.mode == "max":
            output_probs = torch.max(stacked, dim=1).values.squeeze()
        elif self.mode == "mean":
            output_probs = torch.mean(stacked, dim=1).squeeze()
        elif self.mode == "min":
            output_probs = torch.min(stacked, dim=1).values.squeeze()
        else:
            raise ValueError("Mode must be 'none', 'min', 'mean', or 'max'.")

        # Disagreement between members serves as a rough uncertainty measure.
        std_output = torch.std(stacked, dim=1).squeeze()
        # For N values in [0, 1], the Bessel-corrected std (torch's default)
        # is at most sqrt(0.25 * N / (N - 1)). For even N this bound is reached
        # with half the values at 0 and half at 1. Dividing by it maps the std
        # onto [0, 1]; this assumes member outputs are probabilities in [0, 1].
        n_models = len(self.models)
        std_max = math.sqrt(0.25 * n_models / (n_models - 1))
        std_output = std_output / std_max
        return output_probs, std_output
# MLSTAC API -----------------------------------------------------------------------
def example_data(path: pathlib.Path, device="cpu", *args, **kwargs):
    """Load the bundled example image as a one-item batch.

    Args:
        path: Directory containing ``example_data.safetensor``.
        device: Device the returned tensor is moved to.

    Returns:
        The ``"image"`` tensor from that file, cast to float32, with a new
        leading batch dimension, placed on ``device``.
    """
    tensors = safetensors.torch.load_file(path / "example_data.safetensor")
    image = tensors["image"]
    batch = image.float().unsqueeze(0)
    return batch.to(device)
def trainable_model(*args, **kwargs):
    """Report that no trainable variant of the model is provided.

    Any positional or keyword arguments are accepted and ignored.

    Returns:
        ``None``, after printing a notice to stdout.
    """
    notice = "The model is not available in training mode."
    print(notice)
def _load_frozen(model, weights_file, device):
    """Load weights into *model*, move it to *device*, freeze it, and set eval mode."""
    model.load_state_dict(safetensors.torch.load_file(weights_file))
    model = model.to(device)
    # Gradients are never needed: the ensemble is only used for inference.
    model.requires_grad_(False)
    return model.eval()


def compiled_model(
    path,
    device: str = "cpu",
    mode: Literal["min", "mean", "max", "none"] = "max",
    *args,
    **kwargs,
):
    """Build the frozen six-member cloud-mask ensemble from files under *path*.

    Args:
        path: Directory holding ``model.py``, ``c2r1km.py`` and the six
            ``*.safetensor`` weight files. A ``str`` is converted to a Path.
        device: Device every member model is moved to.
        mode: Reduction passed to ``EnsembleModel``.

    Returns:
        An ``EnsembleModel`` whose members have gradients disabled and are in
        eval mode.
    """
    path = pathlib.Path(path)
    # Execute model.py once and reuse it for the five architectures it
    # defines, instead of re-running the file for every member.
    arch = load_model_module(path / "model.py")
    c2r1km = load_model_module(path / "c2r1km.py")

    members = [
        # Model 1 (DeepLabV3Branch + PixelWise)
        _load_frozen(
            arch.CombinedNet4(classes=1, benchmark=True, in_channels=4),
            path / "1dpwdeeplabv3.safetensor",
            device,
        ),
        # Model 2 (SegformerBranch + PixelWise)
        _load_frozen(
            arch.CombinedNet(classes=1, benchmark=True),
            path / "1dpwseg.safetensor",
            device,
        ),
        # Model 3 (UNetPlusPlusBranch + PixelWise)
        _load_frozen(
            arch.CombinedNet3(classes=1, benchmark=True),
            path / "1dpwunetpp.safetensor",
            device,
        ),
        # Model 4 (UNetBranch)
        _load_frozen(
            arch.UNetBranch(classes=1, benchmark=True),
            path / "unet.safetensor",
            device,
        ),
        # Model 5 (UNetPlusPlusBranch)
        _load_frozen(
            arch.UNetPlusPlusBranch(classes=1, benchmark=True),
            path / "unetpp.safetensor",
            device,
        ),
        # Model 6 (C2R1KM)
        _load_frozen(
            c2r1km.CloudMaskOne(
                hidden_layer_sizes=(128, 112),
                activation="relu",
                last_activation="sigmoid",
                dropout_rate=0.1,
                input_dim=40,
                batch_norm=False,
            ),
            path / "c2r1km.safetensor",
            device,
        ),
    ]
    return EnsembleModel(*members, mode=mode)
def display_results(
    path: pathlib.Path,
    device: str = "cpu",
    mode: Literal["min", "mean", "max"] = "max",
    *args,
    **kwargs,
):
    """Run the ensemble on the bundled example image and plot the results.

    Args:
        path: Directory holding the model files and ``example_data.safetensor``.
        device: Device used for loading the data and running the models.
        mode: Ensemble reduction (``"min"``, ``"mean"`` or ``"max"``).

    Returns:
        A matplotlib Figure with three panels:

        - the input, with channels 2, 1, 0 shown as RGB;
        - the cloud probability;
        - the normalised uncertainty.
    """
    model = compiled_model(path, device, mode=mode)
    # Load the sample directly onto the target device.
    probav = example_data(path, device)

    # Inference only: no autograd bookkeeping is needed.
    with torch.no_grad():
        cloudprob, uncertainty = model(probav)

    fig, ax = plt.subplots(1, 3, figsize=(12, 4))
    ax[0].imshow(probav[0, [2, 1, 0]].cpu().numpy().transpose(1, 2, 0))
    ax[0].set_title("Input")
    ax[1].imshow(cloudprob.cpu().numpy(), cmap="gray")
    ax[1].set_title("Cloud Probability")
    ax[2].imshow(uncertainty.cpu().numpy(), cmap="gray")
    ax[2].set_title("Uncertainty")
    for a in ax:
        a.axis("off")
    fig.tight_layout()
    return fig