File size: 6,234 Bytes
039daa1
 
 
 
 
567a505
039daa1
 
 
 
 
 
 
 
 
 
 
 
 
 
a09c701
039daa1
 
 
 
 
 
a09c701
 
039daa1
34c4cb0
 
039daa1
 
 
 
 
 
 
 
 
 
 
 
 
 
a09c701
 
039daa1
 
 
 
 
 
567a505
 
 
 
 
039daa1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a09c701
039daa1
 
 
 
 
 
 
a09c701
039daa1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a09c701
 
 
7cdd27c
a09c701
 
 
 
 
 
 
 
 
 
 
039daa1
 
a09c701
039daa1
 
 
a09c701
 
039daa1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
import torch
import pathlib
import importlib.util
import safetensors.torch
import matplotlib.pyplot as plt
import math

from typing import Literal


def load_model_module(model_path: pathlib.Path):
    """Execute the Python file at ``model_path`` and return it as a module named ``"model"``.

    The path is resolved to an absolute path first.  The file is executed anew on
    every call and the resulting module object is not added to ``sys.modules``.
    """
    module_spec = importlib.util.spec_from_file_location("model", model_path.resolve())
    loaded = importlib.util.module_from_spec(module_spec)
    module_spec.loader.exec_module(loaded)
    return loaded

class EnsembleModel(torch.nn.Module):
    """Combine six member models into one prediction plus a normalized spread map.

    Every member is called on the same input.  Their outputs are concatenated
    along dim 1 and reduced across that dimension according to ``mode``:
    ``"max"``, ``"mean"``, ``"min"``, or ``"none"`` (return the raw concatenation).
    """

    def __init__(self, model1, model2, model3, model4, model5, model6, mode="max"):
        super(EnsembleModel, self).__init__()
        if mode not in ["min", "mean", "max", "none"]:
            raise ValueError("Mode must be 'none', 'min', 'mean', or 'max'.")
        # Assigning members as attributes registers nn.Module members as
        # submodules, so .to(), .eval() and state_dict() reach them.
        self.model1 = model1
        self.model2 = model2
        self.model3 = model3
        self.model4 = model4
        self.model5 = model5
        self.model6 = model6
        # Plain list used only to iterate members in order in forward(); the
        # registration above is what makes them visible to nn.Module machinery.
        self.models = [model1, model2, model3, model4, model5, model6]
        self.mode = mode

    def forward(self, x):
        """Run every member on ``x`` and reduce their outputs.

        Returns:
            If ``mode == "none"``: all member outputs concatenated along dim 1.
            Otherwise a tuple ``(output_probs, std_output)``:

            - ``output_probs`` is the max/mean/min across dim 1.
            - ``std_output`` is the standard deviation across dim 1, divided
              by its largest attainable value.

            Both have singleton dimensions squeezed, so a batch of one yields
            per-pixel maps and a larger batch keeps its batch dimension.

        Raises:
            ValueError: if ``self.mode`` was changed to an unsupported value.
        """
        outputs = [model(x) for model in self.models]
        stacked = torch.cat(outputs, dim=1)

        if self.mode == "none":
            return stacked

        # Reduce across the member dimension (dim 1).  torch.max/torch.min with
        # ``dim`` return (values, indices), hence ``.values``.  torch.mean and
        # torch.std return a plain tensor, so they must not be indexed: ``[0]``
        # there would select the first batch element instead.
        if self.mode == "max":
            output_probs = stacked.max(dim=1).values
        elif self.mode == "mean":
            output_probs = stacked.mean(dim=1)
        elif self.mode == "min":
            output_probs = stacked.min(dim=1).values
        else:
            raise ValueError("Mode must be 'none', 'min', 'mean', or 'max'.")
        output_probs = output_probs.squeeze()

        # Disagreement between members as an uncertainty proxy (unbiased std).
        std_output = stacked.std(dim=1).squeeze()

        # Normalize to [0, 1]: for N values in [0, 1] the unbiased std peaks at
        # sqrt(0.25 * N / (N - 1)), reached when half are 0 and half are 1
        # (N even).  Assumes member outputs are probabilities in [0, 1].
        N = len(outputs)
        std_max = math.sqrt(0.25 * N / (N - 1))
        std_output = std_output / std_max

        return output_probs, std_output


# MLSTAC API -----------------------------------------------------------------------
def example_data(path: pathlib.Path, device = "cpu", *args, **kwargs):
    """Load the bundled example image and return it as a batch of one.

    Reads the ``"image"`` tensor from ``example_data.safetensor`` inside
    ``path``, casts it to float, prepends a batch dimension and moves it to
    ``device``.  Extra positional/keyword arguments are accepted and ignored.
    """
    tensors = safetensors.torch.load_file(path / "example_data.safetensor")
    image = tensors["image"].float()
    return image.unsqueeze(0).to(device)

def trainable_model(*args, **kwargs):
    """Report that no trainable variant is provided and return ``None``.

    All arguments are accepted and ignored; a notice is printed to stdout.
    """
    notice = "The model is not available in training mode."
    print(notice)

def _load_frozen(model, weights_file, device):
    """Load safetensors weights into ``model``; return it on ``device``, frozen, in eval mode."""
    model.load_state_dict(safetensors.torch.load_file(weights_file))
    model = model.to(device)
    # Inference-only: no parameter should track gradients.
    model.requires_grad_(False)
    return model.eval()


def compiled_model(path, device: str = "cpu", mode: Literal["min", "mean", "max", "none"] = "max", *args, **kwargs):
    """Build the frozen six-member cloud-detection ensemble from files in ``path``.

    Architectures come from ``path / "model.py"`` (members 1-5) and
    ``path / "c2r1km.py"`` (member 6).  Weights come from the matching
    ``*.safetensor`` files in ``path``.  Every member is moved to ``device``,
    has ``requires_grad`` disabled and is put in eval mode.

    Args:
        path: directory (a ``pathlib.Path``) holding the model code and weights.
        device: torch device string the members are moved to.
        mode: reduction passed to ``EnsembleModel`` (``"min"``, ``"mean"``,
            ``"max"``, or ``"none"`` for the raw concatenated outputs).
        *args, **kwargs: accepted for API compatibility and ignored.

    Returns:
        An ``EnsembleModel`` wrapping the six members.
    """
    # Execute each architecture file once and reuse its classes, instead of
    # re-running model.py for every member.
    arch = load_model_module(path / "model.py")
    c2r1km = load_model_module(path / "c2r1km.py")

    # (untrained network, weights file name), in ensemble member order.
    members = [
        # Model 1 (DeepLabV3Branch + PixelWise)
        (arch.CombinedNet4(classes=1, benchmark=True, in_channels=4), "1dpwdeeplabv3.safetensor"),
        # Model 2 (SegformerBranch + PixelWise)
        (arch.CombinedNet(classes=1, benchmark=True), "1dpwseg.safetensor"),
        # Model 3 (UNetPlusPlusBranch + PixelWise)
        (arch.CombinedNet3(classes=1, benchmark=True), "1dpwunetpp.safetensor"),
        # Model 4 (UNetBranch)
        (arch.UNetBranch(classes=1, benchmark=True), "unet.safetensor"),
        # Model 5 (UNetPlusPlusBranch)
        (arch.UNetPlusPlusBranch(classes=1, benchmark=True), "unetpp.safetensor"),
        # Model 6 (C2R1KM)
        (
            c2r1km.CloudMaskOne(
                hidden_layer_sizes=(128, 112),
                activation='relu',
                last_activation='sigmoid',
                dropout_rate=0.1,
                input_dim=40,
                batch_norm=False,
            ),
            "c2r1km.safetensor",
        ),
    ]

    # Each weight dict is dropped right after loading rather than all six
    # being held in memory at once.
    frozen = [_load_frozen(net, path / weights_name, device) for net, weights_name in members]

    return EnsembleModel(*frozen, mode=mode)



def display_results(path: pathlib.Path, device: str = "cpu", mode: Literal["min", "mean", "max"] ="max", *args, **kwargs):
    """Run the ensemble on the bundled example and plot input, probability and uncertainty.

    Builds the model with ``compiled_model(path, device, mode=mode)``.  It then
    runs the model on ``example_data(path)`` and returns a 1x3 matplotlib
    figure: the input shown with channels ``[2, 1, 0]`` as RGB, then the two
    model outputs in grayscale.  Extra arguments are accepted and ignored.
    """
    ensemble = compiled_model(path, device, mode=mode)
    image = example_data(path)
    probability, spread = ensemble(image.float().to(device))

    # First sample, channels reordered [2, 1, 0] and moved to HWC for imshow.
    rgb = image[0, [2, 1, 0]].cpu().detach().numpy().transpose(1, 2, 0)

    figure, axes = plt.subplots(1, 3, figsize=(12, 4))
    axes[0].imshow(rgb)
    axes[0].set_title("Input")
    panels = zip(axes[1:], (probability, spread), ("Cloud Probability", "Uncertainty"))
    for axis, tensor, title in panels:
        axis.imshow(tensor.cpu().detach().numpy(), cmap="gray")
        axis.set_title(title)
    for axis in axes:
        axis.axis("off")
    figure.tight_layout()
    return figure