harshinde commited on
Commit
ccf1001
·
verified ·
1 Parent(s): 148e0a0

Upload 14 files

Browse files
src/app.py ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import h5py
3
+ import torch
4
+ import numpy as np
5
+ import matplotlib.pyplot as plt
6
+ import yaml
7
+ import os
8
+
9
+ # Import models
10
+ from mobilenetv2_model import LandslideModel as MobileNetV2Model
11
+ from vgg16_model import LandslideModel as VGG16Model
12
+ from resnet34_model import LandslideModel as ResNet34Model
13
+ from efficientnetb0_model import LandslideModel as EfficientNetB0Model
14
+ from mitb1_model import LandslideModel as MiTB1Model
15
+ from inceptionv4_model import LandslideModel as InceptionV4Model
16
+ from densenet121_model import LandslideModel as DenseNet121Model
17
+ from deeplabv3plus_model import LandslideModel as DeepLabV3PlusModel
18
+ from resnext50_32x4d_model import LandslideModel as ResNeXt50_32X4DModel
19
+ from se_resnet50_model import LandslideModel as SEResNet50Model
20
+ from se_resnext50_32x4d_model import LandslideModel as SEResNeXt50_32X4DModel
21
+ from segformer_model import LandslideModel as SegFormerB2Model
22
+ from inceptionresnetv2_model import LandslideModel as InceptionResNetV2Model
23
+
24
+ # Load the configuration file
25
+ config = """
26
+ model_config:
27
+ model_type: "mobilenet_v2"
28
+ in_channels: 14
29
+ num_classes: 1
30
+ encoder_weights: "imagenet"
31
+ wce_weight: 0.5
32
+
33
+ dataset_config:
34
+ num_classes: 1
35
+ num_channels: 14
36
+ channels: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]
37
+ normalize: False
38
+
39
+ train_config:
40
+ dataset_path: ""
41
+ checkpoint_path: "checkpoints"
42
+ seed: 42
43
+ train_val_split: 0.8
44
+ batch_size: 16
45
+ num_epochs: 100
46
+ lr: 0.001
47
+ device: "cuda:0"
48
+ save_config: True
49
+ experiment_name: "mobilenet_v2"
50
+
51
+ logging_config:
52
+ wandb_project: "l4s"
53
+ wandb_entity: "Silvamillion"
54
+ """
55
+
56
+ config = yaml.safe_load(config)
57
+
58
+ # Model descriptions
59
+ model_descriptions = {
60
+ "MobileNetV2": {"path": "mobilenetv2.pth", "type": "mobilenet_v2", "description": "MobileNetV2 is a lightweight deep learning model for image classification and segmentation."},
61
+ "VGG16": {"path": "vgg16.pth", "type": "vgg16", "description": "VGG16 is a popular deep learning model known for its simplicity and depth."},
62
+ "ResNet34": {"path": "resnet34.pth", "type": "resnet34", "description": "ResNet34 is a deep residual network that helps in training very deep networks."},
63
+ "EfficientNetB0": {"path": "effucientnetb0.pth", "type": "efficientnet_b0", "description": "EfficientNetB0 is part of the EfficientNet family, known for its efficiency and performance."},
64
+ "MiT-B1": {"path": "mitb1.pth", "type": "mit_b1", "description": "MiT-B1 is a transformer-based model designed for segmentation tasks."},
65
+ "InceptionV4": {"path": "inceptionv4.pth", "type": "inceptionv4", "description": "InceptionV4 is a convolutional neural network known for its inception modules."},
66
+ "DeepLabV3+": {"path": "deeplabv3.pth", "type": "deeplabv3+", "description": "DeepLabV3+ is an advanced model for semantic image segmentation."},
67
+ "DenseNet121": {"path": "densenet121.pth", "type": "densenet121", "description": "DenseNet121 is a densely connected convolutional network for image classification and segmentation."},
68
+ "ResNeXt50_32X4D": {"path": "resnext50-32x4d.pth", "type": "resnext50_32x4d", "description": "ResNeXt50_32X4D is a highly modularized network aimed at improving accuracy."},
69
+ "SEResNet50": {"path": "se_resnet50.pth", "type": "se_resnet50", "description": "SEResNet50 is a ResNet model with squeeze-and-excitation blocks for better feature recalibration."},
70
+ "SEResNeXt50_32X4D": {"path": "se_resnext50_32x4d.pth", "type": "se_resnext50_32x4d", "description": "SEResNeXt50_32X4D combines ResNeXt and SE blocks for improved performance."},
71
+ "SegFormerB2": {"path": "segformer.pth", "type": "segformer_b2", "description": "SegFormerB2 is a transformer-based model for semantic segmentation."},
72
+ "InceptionResNetV2": {"path": "inceptionresnetv2.pth", "type": "inceptionresnetv2", "description": "InceptionResNetV2 is a hybrid model combining Inception and ResNet architectures."},
73
+ }
74
+
75
+ # Streamlit app
76
+ st.set_page_config(page_title="Landslide Detection", layout="wide")
77
+
78
+ st.title("Landslide Detection")
79
+ st.markdown("""
80
+ ## Instructions
81
+ 1. Select a model from the sidebar or choose to run all models.
82
+ 2. Upload one or more `.h5` files.
83
+ 3. The app will process the files and display the input image, prediction, and overlay.
84
+ 4. You can download the prediction results.
85
+ """)
86
+
87
+ # Sidebar for model selection
88
+ st.sidebar.title("Model Selection")
89
+ model_option = st.sidebar.radio("Choose an option", ["Select a single model", "Run all models"])
90
+ if model_option == "Select a single model":
91
+ model_type = st.sidebar.selectbox("Select Model", list(model_descriptions.keys()))
92
+ config['model_config']['model_type'] = model_descriptions[model_type]['type']
93
+ if model_type == "DeepLabV3+":
94
+ model_class = DeepLabV3PlusModel
95
+ else:
96
+ model_class = locals()[model_type.replace("-", "") + "Model"]
97
+ model_path = model_descriptions[model_type]['path']
98
+
99
+ # Display model details in the sidebar
100
+ st.sidebar.markdown(f"**Model Type:** {model_descriptions[model_type]['type']}")
101
+ st.sidebar.markdown(f"**Model Path:** {model_descriptions[model_type]['path']}")
102
+ st.sidebar.markdown(f"**Description:** {model_descriptions[model_type]['description']}")
103
+
104
+ # Main content
105
+ st.header("Upload Data")
106
+ uploaded_files = st.file_uploader("Choose .h5 files...", type="h5", accept_multiple_files=True)
107
+ if uploaded_files:
108
+ for uploaded_file in uploaded_files:
109
+ st.write(f"Processing file: {uploaded_file.name}")
110
+ with st.spinner('Classifying...'):
111
+ with h5py.File(uploaded_file, 'r') as hdf:
112
+ data = np.array(hdf.get('img'))
113
+ data[np.isnan(data)] = 0.000001
114
+ channels = config["dataset_config"]["channels"]
115
+ image = np.zeros((128, 128, len(channels)))
116
+ for i, channel in enumerate(channels):
117
+ image[:, :, i] = data[:, :, channel-1]
118
+
119
+ # Transform the image to the required format
120
+ image = image.transpose((2, 0, 1)) # (H, W, C) to (C, H, W)
121
+ image = torch.from_numpy(image).float().unsqueeze(0) # Add batch dimension
122
+
123
+ if model_option == "Select a single model":
124
+ # Process the image with the selected model
125
+ st.write(f"Using model: {model_type}")
126
+
127
+ # Load the model
128
+ model = model_class(config)
129
+ model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')), strict=False)
130
+ model.eval()
131
+
132
+ # Make prediction
133
+ with torch.no_grad():
134
+ prediction = model(image)
135
+ prediction = torch.sigmoid(prediction).cpu().numpy()
136
+
137
+ # Display prediction
138
+ st.header(f"Prediction Results - {model_type}")
139
+ fig, ax = plt.subplots(1, 3, figsize=(15, 5))
140
+ img = image.squeeze().permute(1, 2, 0).numpy()
141
+ img = (img - img.min()) / (img.max() - img.min()) # Normalize the image to [0, 1] range for display
142
+ ax[0].imshow(img[:, :, 1:4]) # Display first three channels as RGB
143
+ ax[0].set_title("Input Image")
144
+ ax[1].imshow(prediction.squeeze() > 0.5, cmap='plasma') # Apply threshold
145
+ ax[1].set_title("Prediction")
146
+ ax[2].imshow(img[:, :, :3]) # Display first three channels as RGB
147
+ ax[2].imshow(prediction.squeeze() > 0.5, cmap='plasma', alpha=0.3) # Overlay prediction
148
+ ax[2].set_title("Overlay")
149
+ st.pyplot(fig)
150
+
151
+ # Option to download the prediction
152
+ st.write(f"Download the prediction as a .npy file for {model_type}:")
153
+ npy_data = prediction.squeeze()
154
+ st.download_button(
155
+ label=f"Download Prediction - {model_type}",
156
+ data=npy_data.tobytes(),
157
+ file_name=f"{uploaded_file.name.split('.')[0]}_{model_type}_prediction.npy",
158
+ mime="application/octet-stream"
159
+ )
160
+
161
+ else:
162
+ # Process the image with each model
163
+ for model_name, model_info in model_descriptions.items():
164
+ st.write(f"Using model: {model_name}")
165
+ if model_name == "DeepLabV3+":
166
+ model_class = DeepLabV3PlusModel
167
+ else:
168
+ model_class = locals()[model_name.replace("-", "") + "Model"]
169
+ model_path = model_info['path']
170
+ config['model_config']['model_type'] = model_info['type']
171
+
172
+ # Load the model
173
+ model = model_class(config)
174
+ model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')), strict=False)
175
+ model.eval()
176
+
177
+ # Make prediction
178
+ with torch.no_grad():
179
+ prediction = model(image)
180
+ prediction = torch.sigmoid(prediction).cpu().numpy()
181
+
182
+ # Display prediction
183
+ st.header(f"Prediction Results - {model_name}")
184
+ fig, ax = plt.subplots(1, 3, figsize=(15, 5))
185
+ img = image.squeeze().permute(1, 2, 0).numpy()
186
+ img = (img - img.min()) / (img.max() - img.min()) # Normalize the image to [0, 1] range for display
187
+ ax[0].imshow(img[:, :, :3]) # Display first three channels as RGB
188
+ ax[0].set_title("Input Image")
189
+ ax[1].imshow(prediction.squeeze() > 0.5, cmap='plasma') # Apply threshold
190
+ ax[1].set_title("Prediction")
191
+ ax[2].imshow(img[:, :, :3]) # Display first three channels as RGB
192
+ ax[2].imshow(prediction.squeeze() > 0.5, cmap='plasma', alpha=0.3) # Overlay prediction
193
+ ax[2].set_title("Overlay")
194
+ st.pyplot(fig)
195
+
196
+ # Option to download the prediction
197
+ st.write(f"Download the prediction as a .npy file for {model_name}:")
198
+ npy_data = prediction.squeeze()
199
+ st.download_button(
200
+ label=f"Download Prediction - {model_name}",
201
+ data=npy_data.tobytes(),
202
+ file_name=f"{uploaded_file.name.split('.')[0]}_{model_name}_prediction.npy",
203
+ mime="application/octet-stream"
204
+ )
205
+
206
+ st.success('Done!')
src/deeplabv3plus_model.py ADDED
@@ -0,0 +1,216 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import segmentation_models_pytorch as smp
4
+ from torchmetrics import F1Score, Precision, Recall, JaccardIndex
5
+ import pytorch_lightning as pl
6
+ import wandb
7
+ from torch.optim import Adam
8
+ from torch.optim.lr_scheduler import StepLR
9
+
10
+ class smp_model(nn.Module):
11
+ def __init__(self, in_channels, out_channels, model_type, num_classes, encoder_weights):
12
+ super(smp_model, self).__init__()
13
+ if model_type == "deeplabv3+":
14
+ self.model = smp.DeepLabV3Plus(
15
+ encoder_name="resnet50", # Change this to a valid encoder
16
+ encoder_weights=encoder_weights,
17
+ in_channels=in_channels,
18
+ classes=num_classes
19
+ )
20
+ elif model_type == "unet":
21
+ self.model = smp.Unet(
22
+ encoder_name="resnet50",
23
+ encoder_weights=encoder_weights,
24
+ in_channels=in_channels,
25
+ classes=num_classes,
26
+ )
27
+ else:
28
+ raise ValueError(f"Model type {model_type} not supported!")
29
+
30
+ def forward(self, x):
31
+ return self.model(x)
32
+
33
+ class LandslideModel(pl.LightningModule):
34
+ def __init__(self, config, alpha=0.5):
35
+ super(LandslideModel, self).__init__()
36
+
37
+ model_type = config['model_config']['model_type']
38
+ in_channels = config['model_config']['in_channels']
39
+ num_classes = config['model_config']['num_classes']
40
+ self.alpha = alpha
41
+ self.lr = config['train_config']['lr']
42
+
43
+ if model_type == 'unet':
44
+ self.model = UNet(in_channels=in_channels, out_channels=num_classes)
45
+ else:
46
+ encoder_weights = config['model_config']['encoder_weights']
47
+ self.model = smp_model(in_channels=in_channels,
48
+ out_channels=num_classes,
49
+ model_type=model_type,
50
+ num_classes=num_classes,
51
+ encoder_weights=encoder_weights)
52
+
53
+ self.weights = torch.tensor([5], dtype=torch.float32).to(self.device)
54
+ self.wce = nn.BCELoss(weight=self.weights)
55
+
56
+ self.train_f1 = F1Score(task='binary')
57
+ self.val_f1 = F1Score(task='binary')
58
+
59
+ self.train_precision = Precision(task='binary')
60
+ self.val_precision = Precision(task='binary')
61
+
62
+ self.train_recall = Recall(task='binary')
63
+ self.val_recall = Recall(task='binary')
64
+
65
+ self.train_iou = JaccardIndex(task='binary')
66
+ self.val_iou = JaccardIndex(task='binary')
67
+
68
+ def forward(self, x):
69
+ return self.model(x)
70
+
71
+ def training_step(self, batch, batch_idx):
72
+ x, y = batch
73
+ y_hat = torch.sigmoid(self(x))
74
+
75
+ wce_loss = self.wce(y_hat, y)
76
+ dice = dice_loss(y_hat, y)
77
+
78
+ combined_loss = (1 - self.alpha) * wce_loss + self.alpha * dice
79
+
80
+ precision = self.train_precision(y_hat, y)
81
+ recall = self.train_recall(y_hat, y)
82
+ iou = self.train_iou(y_hat, y)
83
+ loss_f1 = self.train_f1(y_hat, y)
84
+
85
+ self.log('train_precision', precision)
86
+ self.log('train_recall', recall)
87
+ self.log('train_wce', wce_loss)
88
+ self.log('train_dice', dice)
89
+ self.log('train_iou', iou)
90
+ self.log('train_f1', loss_f1)
91
+ self.log('train_loss', combined_loss)
92
+ return {'loss': combined_loss}
93
+
94
+ def validation_step(self, batch, batch_idx):
95
+ x, y = batch
96
+ y_hat = torch.sigmoid(self(x))
97
+
98
+ wce_loss = self.wce(y_hat, y)
99
+ dice = dice_loss(y_hat, y)
100
+
101
+ combined_loss = (1 - self.alpha) * wce_loss + self.alpha * dice
102
+
103
+ precision = self.val_precision(y_hat, y)
104
+ recall = self.val_recall(y_hat, y)
105
+ iou = self.val_iou(y_hat, y)
106
+ loss_f1 = self.val_f1(y_hat, y)
107
+
108
+ self.log('val_precision', precision)
109
+ self.log('val_recall', recall)
110
+ self.log('val_wce', wce_loss)
111
+ self.log('val_dice', dice)
112
+ self.log('val_iou', iou)
113
+ self.log('val_f1', loss_f1)
114
+ self.log('val_loss', combined_loss)
115
+
116
+ if self.current_epoch % 10 == 0:
117
+ x = (x - x.min()) / (x.max() - x.min())
118
+ x = x[:, 0:3]
119
+ x = x.permute(0, 2, 3, 1)
120
+ y_hat = (y_hat > 0.5).float()
121
+
122
+ class_labels = {0: "no landslide", 1: "landslide"}
123
+
124
+ self.logger.experiment.log({
125
+ "image": wandb.Image(x[0].cpu().detach().numpy(), masks={
126
+ "predictions": {
127
+ "mask_data": y_hat[0][0].cpu().detach().numpy(),
128
+ "class_labels": class_labels
129
+ },
130
+ "ground_truth": {
131
+ "mask_data": y[0][0].cpu().detach().numpy(),
132
+ "class_labels": class_labels
133
+ }
134
+ })
135
+ })
136
+ return {'val_loss': combined_loss}
137
+
138
+ def configure_optimizers(self):
139
+ optimizer = Adam(self.parameters(), lr=self.lr)
140
+ scheduler = StepLR(optimizer, step_size=30, gamma=0.1)
141
+ return [optimizer], [scheduler]
142
+
143
+ class Block(nn.Module):
144
+ def __init__(self, inputs=3, middles=64, outs=64):
145
+ super().__init__()
146
+
147
+ self.conv1 = nn.Conv2d(inputs, middles, 3, 1, 1)
148
+ self.conv2 = nn.Conv2d(middles, outs, 3, 1, 1)
149
+ self.relu = nn.ReLU()
150
+ self.bn = nn.BatchNorm2d(outs)
151
+ self.pool = nn.MaxPool2d(2, 2)
152
+
153
+ def forward(self, x):
154
+ x = self.relu(self.conv1(x))
155
+ x = self.relu(self.bn(self.conv2(x)))
156
+ return self.pool(x), x
157
+
158
+ class UNet(nn.Module):
159
+ def __init__(self, in_channels=3, out_channels=1):
160
+ super().__init__()
161
+
162
+ self.en1 = Block(in_channels, 64, 64)
163
+ self.en2 = Block(64, 128, 128)
164
+ self.en3 = Block(128, 256, 256)
165
+ self.en4 = Block(256, 512, 512)
166
+ self.en5 = Block(512, 1024, 512)
167
+
168
+ self.upsample4 = nn.ConvTranspose2d(512, 512, 2, stride=2)
169
+ self.de4 = Block(1024, 512, 256)
170
+
171
+ self.upsample3 = nn.ConvTranspose2d(256, 256, 2, stride=2)
172
+ self.de3 = Block(512, 256, 128)
173
+
174
+ self.upsample2 = nn.ConvTranspose2d(128, 128, 2, stride=2)
175
+ self.de2 = Block(256, 128, 64)
176
+
177
+ self.upsample1 = nn.ConvTranspose2d(64, 64, 2, stride=2)
178
+ self.de1 = Block(128, 64, 64)
179
+
180
+ self.conv_last = nn.Conv2d(64, out_channels, kernel_size=1, stride=1, padding=0)
181
+
182
+ def forward(self, x):
183
+ x, e1 = self.en1(x)
184
+ x, e2 = self.en2(x)
185
+ x, e3 = self.en3(x)
186
+ x, e4 = self.en4(x)
187
+ _, x = self.en5(x)
188
+
189
+ x = self.upsample4(x)
190
+ x = torch.cat([x, e4], dim=1)
191
+ _, x = self.de4(x)
192
+
193
+ x = self.upsample3(x)
194
+ x = torch.cat([x, e3], dim=1)
195
+ _, x = self.de3(x)
196
+
197
+ x = self.upsample2(x)
198
+ x = torch.cat([x, e2], dim=1)
199
+ _, x = self.de2(x)
200
+
201
+ x = self.upsample1(x)
202
+ x = torch.cat([x, e1], dim=1)
203
+ _, x = self.de1(x)
204
+
205
+ x = self.conv_last(x)
206
+
207
+ return x
208
+
209
+ def dice_loss(y_hat, y):
210
+ smooth = 1e-6
211
+ y_hat = y_hat.view(-1)
212
+ y = y.view(-1)
213
+ intersection = (y_hat * y).sum()
214
+ union = y_hat.sum() + y.sum()
215
+ dice = (2 * intersection + smooth) / (union + smooth)
216
+ return 1 - dice
src/densenet121_model.py ADDED
@@ -0,0 +1,218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import segmentation_models_pytorch as smp
4
+ from torchmetrics import F1Score, Precision, Recall, JaccardIndex
5
+ import pytorch_lightning as pl
6
+ import wandb
7
+ from torch.optim import Adam
8
+ from torch.optim.lr_scheduler import StepLR
9
+
10
+ class smp_model(nn.Module):
11
+ def __init__(self, in_channels, out_channels, model_type, num_classes, encoder_weights):
12
+ super(smp_model, self).__init__()
13
+ self.model = smp.Unet(
14
+ encoder_name="densenet121", # This will be "densenet121"
15
+ encoder_weights=None, # Load weights manually
16
+ in_channels=in_channels,
17
+ classes=num_classes,
18
+ )
19
+
20
+ def load_pretrained_weights(self):
21
+ state_dict = torch.load('/home/hks/MOU/DenseNet121_14C_L4S/densenet121-fbdb23505-trainWeights.pth', map_location='cpu')
22
+ conv1_weight = state_dict['features.conv0.weight']
23
+ new_conv1_weight = torch.zeros((conv1_weight.shape[0], 14, *conv1_weight.shape[2:]))
24
+ new_conv1_weight[:, :3, :, :] = conv1_weight # Copy weights for the first 3 channels
25
+ state_dict['features.conv0.weight'] = new_conv1_weight
26
+ model_dict = self.model.encoder.state_dict()
27
+ model_dict.update(state_dict)
28
+ self.model.encoder.load_state_dict(model_dict)
29
+
30
+ def forward(self, x):
31
+ x = self.model(x)
32
+ return x
33
+
34
+ class LandslideModel(pl.LightningModule):
35
+ def __init__(self, config, alpha=0.5):
36
+ super(LandslideModel, self).__init__()
37
+
38
+ model_type = config['model_config']['model_type']
39
+ in_channels = config['model_config']['in_channels']
40
+ num_classes = config['model_config']['num_classes']
41
+ self.alpha = alpha
42
+ self.lr = config['train_config']['lr']
43
+
44
+ if model_type == 'unet':
45
+ self.model = UNet(in_channels=in_channels, out_channels=num_classes)
46
+ else:
47
+ encoder_weights = config['model_config']['encoder_weights']
48
+ self.model = smp_model(in_channels=in_channels,
49
+ out_channels=num_classes,
50
+ model_type=model_type,
51
+ num_classes=num_classes,
52
+ encoder_weights=encoder_weights)
53
+ self.model.load_pretrained_weights()
54
+
55
+ self.weights = torch.tensor([5], dtype=torch.float32).to(self.device)
56
+ self.wce = nn.BCELoss(weight=self.weights)
57
+
58
+ self.train_f1 = F1Score(task='binary')
59
+ self.val_f1 = F1Score(task='binary')
60
+
61
+ self.train_precision = Precision(task='binary')
62
+ self.val_precision = Precision(task='binary')
63
+
64
+ self.train_recall = Recall(task='binary')
65
+ self.val_recall = Recall(task='binary')
66
+
67
+ self.train_iou = JaccardIndex(task='binary')
68
+ self.val_iou = JaccardIndex(task='binary')
69
+
70
+ def forward(self, x):
71
+ return self.model(x)
72
+
73
+ def training_step(self, batch, batch_idx):
74
+ x, y = batch
75
+ y_hat = torch.sigmoid(self(x))
76
+
77
+ wce_loss = self.wce(y_hat, y)
78
+ dice = dice_loss(y_hat, y)
79
+
80
+ combined_loss = (1 - self.alpha) * wce_loss + self.alpha * dice
81
+
82
+ precision = self.train_precision(y_hat, y)
83
+ recall = self.train_recall(y_hat, y)
84
+ iou = self.train_iou(y_hat, y)
85
+ loss_f1 = self.train_f1(y_hat, y)
86
+
87
+ self.log('train_precision', precision)
88
+ self.log('train_recall', recall)
89
+ self.log('train_wce', wce_loss)
90
+ self.log('train_dice', dice)
91
+ self.log('train_iou', iou)
92
+ self.log('train_f1', loss_f1)
93
+ self.log('train_loss', combined_loss)
94
+ return {'loss': combined_loss}
95
+
96
+ def validation_step(self, batch, batch_idx):
97
+ x, y = batch
98
+ y_hat = torch.sigmoid(self(x))
99
+
100
+ wce_loss = self.wce(y_hat, y)
101
+ dice = dice_loss(y_hat, y)
102
+
103
+ combined_loss = (1 - self.alpha) * wce_loss + self.alpha * dice
104
+
105
+ precision = self.val_precision(y_hat, y)
106
+ recall = self.val_recall(y_hat, y)
107
+ iou = self.val_iou(y_hat, y)
108
+ loss_f1 = self.val_f1(y_hat, y)
109
+
110
+ self.log('val_precision', precision)
111
+ self.log('val_recall', recall)
112
+ self.log('val_wce', wce_loss)
113
+ self.log('val_dice', dice)
114
+ self.log('val_iou', iou)
115
+ self.log('val_f1', loss_f1)
116
+ self.log('val_loss', combined_loss)
117
+
118
+ if self.current_epoch % 10 == 0:
119
+ x = (x - x.min()) / (x.max() - x.min())
120
+ x = x[:, 0:3]
121
+ x = x.permute(0, 2, 3, 1)
122
+ y_hat = (y_hat > 0.5).float()
123
+
124
+ class_labels = {0: "no landslide", 1: "landslide"}
125
+
126
+ self.logger.experiment.log({
127
+ "image": wandb.Image(x[0].cpu().detach().numpy(), masks={
128
+ "predictions": {
129
+ "mask_data": y_hat[0][0].cpu().detach().numpy(),
130
+ "class_labels": class_labels
131
+ },
132
+ "ground_truth": {
133
+ "mask_data": y[0][0].cpu().detach().numpy(),
134
+ "class_labels": class_labels
135
+ }
136
+ })
137
+ })
138
+ return {'val_loss': combined_loss}
139
+
140
+ def configure_optimizers(self):
141
+ optimizer = Adam(self.parameters(), lr=self.lr)
142
+ scheduler = StepLR(optimizer, step_size=30, gamma=0.1)
143
+ return [optimizer], [scheduler]
144
+
145
+ class Block(nn.Module):
146
+ def __init__(self, inputs=3, middles=64, outs=64):
147
+ super().__init__()
148
+
149
+ self.conv1 = nn.Conv2d(inputs, middles, 3, 1, 1)
150
+ self.conv2 = nn.Conv2d(middles, outs, 3, 1, 1)
151
+ self.relu = nn.ReLU()
152
+ self.bn = nn.BatchNorm2d(outs)
153
+ self.pool = nn.MaxPool2d(2, 2)
154
+
155
+ def forward(self, x):
156
+ x = self.relu(self.conv1(x))
157
+ x = self.relu(self.bn(self.conv2(x)))
158
+ return self.pool(x), x
159
+
160
+ class UNet(nn.Module):
161
+ def __init__(self, in_channels=3, out_channels=1):
162
+ super().__init__()
163
+
164
+ self.en1 = Block(in_channels, 64, 64)
165
+ self.en2 = Block(64, 128, 128)
166
+ self.en3 = Block(128, 256, 256)
167
+ self.en4 = Block(256, 512, 512)
168
+ self.en5 = Block(512, 1024, 512)
169
+
170
+ self.upsample4 = nn.ConvTranspose2d(512, 512, 2, stride=2)
171
+ self.de4 = Block(1024, 512, 256)
172
+
173
+ self.upsample3 = nn.ConvTranspose2d(256, 256, 2, stride=2)
174
+ self.de3 = Block(512, 256, 128)
175
+
176
+ self.upsample2 = nn.ConvTranspose2d(128, 128, 2, stride=2)
177
+ self.de2 = Block(256, 128, 64)
178
+
179
+ self.upsample1 = nn.ConvTranspose2d(64, 64, 2, stride=2)
180
+ self.de1 = Block(128, 64, 64)
181
+
182
+ self.conv_last = nn.Conv2d(64, out_channels, kernel_size=1, stride=1, padding=0)
183
+
184
+ def forward(self, x):
185
+ x, e1 = self.en1(x)
186
+ x, e2 = self.en2(x)
187
+ x, e3 = self.en3(x)
188
+ x, e4 = self.en4(x)
189
+ _, x = self.en5(x)
190
+
191
+ x = self.upsample4(x)
192
+ x = torch.cat([x, e4], dim=1)
193
+ _, x = self.de4(x)
194
+
195
+ x = self.upsample3(x)
196
+ x = torch.cat([x, e3], dim=1)
197
+ _, x = self.de3(x)
198
+
199
+ x = self.upsample2(x)
200
+ x = torch.cat([x, e2], dim=1)
201
+ _, x = self.de2(x)
202
+
203
+ x = self.upsample1(x)
204
+ x = torch.cat([x, e1], dim=1)
205
+ _, x = self.de1(x)
206
+
207
+ x = self.conv_last(x)
208
+
209
+ return x
210
+
211
+ def dice_loss(y_hat, y):
212
+ smooth = 1e-6
213
+ y_hat = y_hat.view(-1)
214
+ y = y.view(-1)
215
+ intersection = (y_hat * y).sum()
216
+ union = y_hat.sum() + y.sum()
217
+ dice = (2 * intersection + smooth) / (union + smooth)
218
+ return 1 - dice
src/efficientnetb0_model.py ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import segmentation_models_pytorch as smp
4
+ from torchmetrics import F1Score, Precision, Recall, JaccardIndex
5
+ import pytorch_lightning as pl
6
+ import wandb
7
+ from torch.optim import Adam
8
+ from torch.optim.lr_scheduler import StepLR
9
+
10
+ class smp_model_efficientnetb0(nn.Module):
11
+ def __init__(self, in_channels, out_channels, model_type, num_classes, encoder_weights):
12
+ super(smp_model_efficientnetb0, self).__init__()
13
+
14
+ # Use EfficientNetB0 pre-trained model as encoder
15
+ self.model = smp.Unet(
16
+ encoder_name='efficientnet-b0',
17
+ encoder_weights=encoder_weights,
18
+ in_channels=in_channels, # The number of input channels, which is 14
19
+ classes=num_classes, # Output classes, which is 1
20
+ )
21
+
22
+ def forward(self, x):
23
+ return self.model(x)
24
+
25
+ class LandslideModel(pl.LightningModule):
26
+ def __init__(self, config, alpha=0.5):
27
+ super(LandslideModel, self).__init__()
28
+
29
+ model_type = config['model_config']['model_type']
30
+ in_channels = config['model_config']['in_channels']
31
+ num_classes = config['model_config']['num_classes']
32
+ self.alpha = alpha # Assign the alpha value to the class variable
33
+ self.lr = config['train_config']['lr']
34
+
35
+ if model_type == 'unet':
36
+ self.model = UNet(in_channels=in_channels, out_channels=num_classes)
37
+ else:
38
+ encoder_weights = config['model_config']['encoder_weights']
39
+ # Use the custom smp_model_efficientnetb0 instead of smp_model
40
+ self.model = smp_model_efficientnetb0(in_channels=in_channels,
41
+ out_channels=num_classes,
42
+ model_type=model_type,
43
+ num_classes=num_classes,
44
+ encoder_weights=encoder_weights)
45
+
46
+ self.weights = torch.tensor([5], dtype=torch.float32).to(self.device)
47
+ self.wce = nn.BCELoss(weight=self.weights)
48
+
49
+ self.train_f1 = F1Score(task='binary')
50
+ self.val_f1 = F1Score(task='binary')
51
+
52
+ self.train_precision = Precision(task='binary')
53
+ self.val_precision = Precision(task='binary')
54
+
55
+ self.train_recall = Recall(task='binary')
56
+ self.val_recall = Recall(task='binary')
57
+
58
+ self.train_iou = JaccardIndex(task='binary')
59
+ self.val_iou = JaccardIndex(task='binary')
60
+
61
+ def forward(self, x):
62
+ return self.model(x)
63
+
64
+ def training_step(self, batch, batch_idx):
65
+ x, y = batch
66
+ y_hat = torch.sigmoid(self(x))
67
+
68
+ wce_loss = self.wce(y_hat, y)
69
+ dice = dice_loss(y_hat, y)
70
+
71
+ combined_loss = (1 - self.alpha) * wce_loss + self.alpha * dice
72
+
73
+ precision = self.train_precision(y_hat, y)
74
+ recall = self.train_recall(y_hat, y)
75
+ iou = self.train_iou(y_hat, y)
76
+ loss_f1 = self.train_f1(y_hat, y)
77
+
78
+ self.log('train_precision', precision)
79
+ self.log('train_recall', recall)
80
+ self.log('train_wce', wce_loss)
81
+ self.log('train_dice', dice)
82
+ self.log('train_iou', iou)
83
+ self.log('train_f1', loss_f1)
84
+ self.log('train_loss', combined_loss)
85
+ return {'loss': combined_loss}
86
+
87
+ def validation_step(self, batch, batch_idx):
88
+ x, y = batch
89
+ y_hat = torch.sigmoid(self(x))
90
+
91
+ wce_loss = self.wce(y_hat, y)
92
+ dice = dice_loss(y_hat, y)
93
+
94
+ combined_loss = (1 - self.alpha) * wce_loss + self.alpha * dice
95
+
96
+ precision = self.val_precision(y_hat, y)
97
+ recall = self.val_recall(y_hat, y)
98
+ iou = self.val_iou(y_hat, y)
99
+ loss_f1 = self.val_f1(y_hat, y)
100
+
101
+ self.log('val_precision', precision)
102
+ self.log('val_recall', recall)
103
+ self.log('val_wce', wce_loss)
104
+ self.log('val_dice', dice)
105
+ self.log('val_iou', iou)
106
+ self.log('val_f1', loss_f1)
107
+ self.log('val_loss', combined_loss)
108
+
109
+ if self.current_epoch % 10 == 0:
110
+ x = (x - x.min()) / (x.max() - x.min())
111
+ x = x[:, 0:3]
112
+ x = x.permute(0, 2, 3, 1)
113
+ y_hat = (y_hat > 0.5).float()
114
+
115
+ class_labels = {0: "no landslide", 1: "landslide"} # Define class_labels here
116
+
117
+ self.logger.experiment.log({
118
+ "image": wandb.Image(x[0].cpu().detach().numpy(), masks={
119
+ "predictions": {
120
+ "mask_data": y_hat[0][0].cpu().detach().numpy(),
121
+ "class_labels": class_labels
122
+ },
123
+ "ground_truth": {
124
+ "mask_data": y[0][0].cpu().detach().numpy(),
125
+ "class_labels": class_labels
126
+ }
127
+ })
128
+ })
129
+ return {'val_loss': combined_loss}
130
+
131
+ def configure_optimizers(self):
132
+ optimizer = Adam(self.parameters(), lr=self.lr)
133
+ scheduler = StepLR(optimizer, step_size=30, gamma=0.1)
134
+ return [optimizer], [scheduler]
135
+
136
+ class Block(nn.Module):
137
+ def __init__(self, inputs=3, middles=64, outs=64):
138
+ super().__init__()
139
+
140
+ self.conv1 = nn.Conv2d(inputs, middles, 3, 1, 1)
141
+ self.conv2 = nn.Conv2d(middles, outs, 3, 1, 1)
142
+ self.relu = nn.ReLU()
143
+ self.bn = nn.BatchNorm2d(outs)
144
+ self.pool = nn.MaxPool2d(2, 2)
145
+
146
+ def forward(self, x):
147
+ x = self.relu(self.conv1(x))
148
+ x = self.relu(self.bn(self.conv2(x)))
149
+ return self.pool(x), x
150
+
151
+ class UNet(nn.Module):
152
+ def __init__(self, in_channels=3, out_channels=1):
153
+ super().__init__()
154
+
155
+ self.en1 = Block(in_channels, 64, 64)
156
+ self.en2 = Block(64, 128, 128)
157
+ self.en3 = Block(128, 256, 256)
158
+ self.en4 = Block(256, 512, 512)
159
+ self.en5 = Block(512, 1024, 512)
160
+
161
+ self.upsample4 = nn.ConvTranspose2d(512, 512, 2, stride=2)
162
+ self.de4 = Block(1024, 512, 256)
163
+
164
+ self.upsample3 = nn.ConvTranspose2d(256, 256, 2, stride=2)
165
+ self.de3 = Block(512, 256, 128)
166
+
167
+ self.upsample2 = nn.ConvTranspose2d(128, 128, 2, stride=2)
168
+ self.de2 = Block(256, 128, 64)
169
+
170
+ self.upsample1 = nn.ConvTranspose2d(64, 64, 2, stride=2)
171
+ self.de1 = Block(128, 64, 64)
172
+
173
+ self.conv_last = nn.Conv2d(64, out_channels, kernel_size=1, stride=1, padding=0)
174
+
175
+ def forward(self, x):
176
+ x, e1 = self.en1(x)
177
+ x, e2 = self.en2(x)
178
+ x, e3 = self.en3(x)
179
+ x, e4 = self.en4(x)
180
+ _, x = self.en5(x)
181
+
182
+ x = self.upsample4(x)
183
+ x = torch.cat([x, e4], dim=1)
184
+ _, x = self.de4(x)
185
+
186
+ x = self.upsample3(x)
187
+ x = torch.cat([x, e3], dim=1)
188
+ _, x = self.de3(x)
189
+
190
+ x = self.upsample2(x)
191
+ x = torch.cat([x, e2], dim=1)
192
+ _, x = self.de2(x)
193
+
194
+ x = self.upsample1(x)
195
+ x = torch.cat([x, e1], dim=1)
196
+ _, x = self.de1(x)
197
+
198
+ x = self.conv_last(x)
199
+
200
+ return x
201
+
202
+ def dice_loss(y_hat, y):
203
+ smooth = 1e-6
204
+ y_hat = y_hat.view(-1)
205
+ y = y.view(-1)
206
+ intersection = (y_hat * y).sum()
207
+ union = y_hat.sum() + y.sum()
208
+ dice = (2 * intersection + smooth) / (union + smooth)
209
+ return 1 - dice
src/inceptionresnetv2_model.py ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import segmentation_models_pytorch as smp
4
+ from torchmetrics import F1Score, Precision, Recall, JaccardIndex
5
+ import pytorch_lightning as pl
6
+ import wandb
7
+ from torch.optim import Adam
8
+ from torch.optim.lr_scheduler import StepLR
9
+
10
+ class smp_model(nn.Module):
11
+ def __init__(self, in_channels, out_channels, model_type, num_classes, encoder_weights):
12
+ super(smp_model, self).__init__()
13
+ self.model = smp.Unet(
14
+ encoder_name=model_type,
15
+ encoder_weights=encoder_weights,
16
+ in_channels=in_channels,
17
+ classes=num_classes,
18
+ )
19
+
20
+ def forward(self, x):
21
+ x = self.model(x)
22
+ return x
23
+
24
+ class LandslideModel(pl.LightningModule):
25
+ def __init__(self, config, alpha=0.5):
26
+ super(LandslideModel, self).__init__()
27
+
28
+ model_type = config['model_config']['model_type']
29
+ in_channels = config['model_config']['in_channels']
30
+ num_classes = config['model_config']['num_classes']
31
+ self.alpha = alpha
32
+ self.lr = config['train_config']['lr']
33
+
34
+ if model_type == 'unet':
35
+ self.model = UNet(in_channels=in_channels, out_channels=num_classes)
36
+ else:
37
+ encoder_weights = config['model_config']['encoder_weights']
38
+ self.model = smp_model(in_channels=in_channels,
39
+ out_channels=num_classes,
40
+ model_type=model_type,
41
+ num_classes=num_classes,
42
+ encoder_weights=encoder_weights)
43
+
44
+ self.weights = torch.tensor([5], dtype=torch.float32).to(self.device)
45
+ self.wce = nn.BCELoss(weight=self.weights)
46
+
47
+ self.train_f1 = F1Score(task='binary')
48
+ self.val_f1 = F1Score(task='binary')
49
+
50
+ self.train_precision = Precision(task='binary')
51
+ self.val_precision = Precision(task='binary')
52
+
53
+ self.train_recall = Recall(task='binary')
54
+ self.val_recall = Recall(task='binary')
55
+
56
+ self.train_iou = JaccardIndex(task='binary')
57
+ self.val_iou = JaccardIndex(task='binary')
58
+
59
+ def forward(self, x):
60
+ return self.model(x)
61
+
62
+ def training_step(self, batch, batch_idx):
63
+ x, y = batch
64
+ y_hat = torch.sigmoid(self(x))
65
+
66
+ wce_loss = self.wce(y_hat, y)
67
+ dice = dice_loss(y_hat, y)
68
+
69
+ combined_loss = (1 - self.alpha) * wce_loss + self.alpha * dice
70
+
71
+ precision = self.train_precision(y_hat, y)
72
+ recall = self.train_recall(y_hat, y)
73
+ iou = self.train_iou(y_hat, y)
74
+ loss_f1 = self.train_f1(y_hat, y)
75
+
76
+ self.log('train_precision', precision)
77
+ self.log('train_recall', recall)
78
+ self.log('train_wce', wce_loss)
79
+ self.log('train_dice', dice)
80
+ self.log('train_iou', iou)
81
+ self.log('train_f1', loss_f1)
82
+ self.log('train_loss', combined_loss)
83
+ return {'loss': combined_loss}
84
+
85
+ def validation_step(self, batch, batch_idx):
86
+ x, y = batch
87
+ y_hat = torch.sigmoid(self(x))
88
+
89
+ wce_loss = self.wce(y_hat, y)
90
+ dice = dice_loss(y_hat, y)
91
+
92
+ combined_loss = (1 - self.alpha) * wce_loss + self.alpha * dice
93
+
94
+ precision = self.val_precision(y_hat, y)
95
+ recall = self.val_recall(y_hat, y)
96
+ iou = self.val_iou(y_hat, y)
97
+ loss_f1 = self.val_f1(y_hat, y)
98
+
99
+ self.log('val_precision', precision)
100
+ self.log('val_recall', recall)
101
+ self.log('val_wce', wce_loss)
102
+ self.log('val_dice', dice)
103
+ self.log('val_iou', iou)
104
+ self.log('val_f1', loss_f1)
105
+ self.log('val_loss', combined_loss)
106
+
107
+ if self.current_epoch % 10 == 0:
108
+ x = (x - x.min()) / (x.max() - x.min())
109
+ x = x[:, 0:3]
110
+ x = x.permute(0, 2, 3, 1)
111
+ y_hat = (y_hat > 0.5).float()
112
+
113
+ class_labels = {0: "no landslide", 1: "landslide"}
114
+
115
+ self.logger.experiment.log({
116
+ "image": wandb.Image(x[0].cpu().detach().numpy(), masks={
117
+ "predictions": {
118
+ "mask_data": y_hat[0][0].cpu().detach().numpy(),
119
+ "class_labels": class_labels
120
+ },
121
+ "ground_truth": {
122
+ "mask_data": y[0][0].cpu().detach().numpy(),
123
+ "class_labels": class_labels
124
+ }
125
+ })
126
+ })
127
+ return {'val_loss': combined_loss}
128
+
129
+ def configure_optimizers(self):
130
+ optimizer = Adam(self.parameters(), lr=self.lr)
131
+ scheduler = StepLR(optimizer, step_size=30, gamma=0.1)
132
+ return [optimizer], [scheduler]
133
+
134
+ class Block(nn.Module):
135
+ def __init__(self, inputs=3, middles=64, outs=64):
136
+ super().__init__()
137
+
138
+ self.conv1 = nn.Conv2d(inputs, middles, 3, 1, 1)
139
+ self.conv2 = nn.Conv2d(middles, outs, 3, 1, 1)
140
+ self.relu = nn.ReLU()
141
+ self.bn = nn.BatchNorm2d(outs)
142
+ self.pool = nn.MaxPool2d(2, 2)
143
+
144
+ def forward(self, x):
145
+ x = self.relu(self.conv1(x))
146
+ x = self.relu(self.bn(self.conv2(x)))
147
+ return self.pool(x), x
148
+
149
+ class UNet(nn.Module):
150
+ def __init__(self, in_channels=3, out_channels=1):
151
+ super().__init__()
152
+
153
+ self.en1 = Block(in_channels, 64, 64)
154
+ self.en2 = Block(64, 128, 128)
155
+ self.en3 = Block(128, 256, 256)
156
+ self.en4 = Block(256, 512, 512)
157
+ self.en5 = Block(512, 1024, 512)
158
+
159
+ self.upsample4 = nn.ConvTranspose2d(512, 512, 2, stride=2)
160
+ self.de4 = Block(1024, 512, 256)
161
+
162
+ self.upsample3 = nn.ConvTranspose2d(256, 256, 2, stride=2)
163
+ self.de3 = Block(512, 256, 128)
164
+
165
+ self.upsample2 = nn.ConvTranspose2d(128, 128, 2, stride=2)
166
+ self.de2 = Block(256, 128, 64)
167
+
168
+ self.upsample1 = nn.ConvTranspose2d(64, 64, 2, stride=2)
169
+ self.de1 = Block(128, 64, 64)
170
+
171
+ self.conv_last = nn.Conv2d(64, out_channels, kernel_size=1, stride=1, padding=0)
172
+
173
+ def forward(self, x):
174
+ x, e1 = self.en1(x)
175
+ x, e2 = self.en2(x)
176
+ x, e3 = self.en3(x)
177
+ x, e4 = self.en4(x)
178
+ _, x = self.en5(x)
179
+
180
+ x = self.upsample4(x)
181
+ x = torch.cat([x, e4], dim=1)
182
+ _, x = self.de4(x)
183
+
184
+ x = self.upsample3(x)
185
+ x = torch.cat([x, e3], dim=1)
186
+ _, x = self.de3(x)
187
+
188
+ x = self.upsample2(x)
189
+ x = torch.cat([x, e2], dim=1)
190
+ _, x = self.de2(x)
191
+
192
+ x = self.upsample1(x)
193
+ x = torch.cat([x, e1], dim=1)
194
+ _, x = self.de1(x)
195
+
196
+ x = self.conv_last(x)
197
+
198
+ return x
199
+
200
+ def dice_loss(y_hat, y):
201
+ smooth = 1e-6
202
+ y_hat = y_hat.view(-1)
203
+ y = y.view(-1)
204
+ intersection = (y_hat * y).sum()
205
+ union = y_hat.sum() + y.sum()
206
+ dice = (2 * intersection + smooth) / (union + smooth)
207
+ return 1 - dice
src/inceptionv4_model.py ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import segmentation_models_pytorch as smp
4
+ from torchmetrics import F1Score, Precision, Recall, JaccardIndex
5
+ import pytorch_lightning as pl
6
+ import wandb
7
+ from torch.optim import Adam
8
+ from torch.optim.lr_scheduler import StepLR
9
+
10
+ class smp_model(nn.Module):
11
+ def __init__(self, in_channels, out_channels, model_type, num_classes, encoder_weights):
12
+ super(smp_model, self).__init__()
13
+ self.model = smp.Unet(
14
+ encoder_name="inceptionv4", # Corrected to string "inceptionv4"
15
+ encoder_weights="imagenet", # Corrected to string "imagenet"
16
+ in_channels=in_channels, # Use the original in_channels
17
+ classes=num_classes,
18
+ )
19
+
20
+ def forward(self, x):
21
+ x = self.model(x)
22
+ return x
23
+
24
+ class LandslideModel(pl.LightningModule):
25
+ def __init__(self, config, alpha=0.5):
26
+ super(LandslideModel, self).__init__()
27
+
28
+ model_type = config['model_config']['model_type']
29
+ in_channels = config['model_config']['in_channels']
30
+ num_classes = config['model_config']['num_classes']
31
+ self.alpha = alpha
32
+ self.lr = config['train_config']['lr']
33
+
34
+ if model_type == 'unet':
35
+ self.model = UNet(in_channels=in_channels, out_channels=num_classes)
36
+ else:
37
+ encoder_weights = config['model_config']['encoder_weights']
38
+ self.model = smp_model(in_channels=in_channels,
39
+ out_channels=num_classes,
40
+ model_type=model_type,
41
+ num_classes=num_classes,
42
+ encoder_weights=encoder_weights)
43
+
44
+ self.weights = torch.tensor([5], dtype=torch.float32).to(self.device)
45
+ self.wce = nn.BCELoss(weight=self.weights)
46
+
47
+ self.train_f1 = F1Score(task='binary')
48
+ self.val_f1 = F1Score(task='binary')
49
+
50
+ self.train_precision = Precision(task='binary')
51
+ self.val_precision = Precision(task='binary')
52
+
53
+ self.train_recall = Recall(task='binary')
54
+ self.val_recall = Recall(task='binary')
55
+
56
+ self.train_iou = JaccardIndex(task='binary')
57
+ self.val_iou = JaccardIndex(task='binary')
58
+
59
+ def forward(self, x):
60
+ return self.model(x)
61
+
62
+ def training_step(self, batch, batch_idx):
63
+ x, y = batch
64
+ y_hat = torch.sigmoid(self(x))
65
+
66
+ wce_loss = self.wce(y_hat, y)
67
+ dice = dice_loss(y_hat, y)
68
+
69
+ combined_loss = (1 - self.alpha) * wce_loss + self.alpha * dice
70
+
71
+ precision = self.train_precision(y_hat, y)
72
+ recall = self.train_recall(y_hat, y)
73
+ iou = self.train_iou(y_hat, y)
74
+ loss_f1 = self.train_f1(y_hat, y)
75
+
76
+ self.log('train_precision', precision)
77
+ self.log('train_recall', recall)
78
+ self.log('train_wce', wce_loss)
79
+ self.log('train_dice', dice)
80
+ self.log('train_iou', iou)
81
+ self.log('train_f1', loss_f1)
82
+ self.log('train_loss', combined_loss)
83
+ return {'loss': combined_loss}
84
+
85
+ def validation_step(self, batch, batch_idx):
86
+ x, y = batch
87
+ y_hat = torch.sigmoid(self(x))
88
+
89
+ wce_loss = self.wce(y_hat, y)
90
+ dice = dice_loss(y_hat, y)
91
+
92
+ combined_loss = (1 - self.alpha) * wce_loss + self.alpha * dice
93
+
94
+ precision = self.val_precision(y_hat, y)
95
+ recall = self.val_recall(y_hat, y)
96
+ iou = self.val_iou(y_hat, y)
97
+ loss_f1 = self.val_f1(y_hat, y)
98
+
99
+ self.log('val_precision', precision)
100
+ self.log('val_recall', recall)
101
+ self.log('val_wce', wce_loss)
102
+ self.log('val_dice', dice)
103
+ self.log('val_iou', iou)
104
+ self.log('val_f1', loss_f1)
105
+ self.log('val_loss', combined_loss)
106
+
107
+ if self.current_epoch % 10 == 0:
108
+ x = (x - x.min()) / (x.max() - x.min())
109
+ x = x[:, 0:3]
110
+ x = x.permute(0, 2, 3, 1)
111
+ y_hat = (y_hat > 0.5).float()
112
+
113
+ class_labels = {0: "no landslide", 1: "landslide"}
114
+
115
+ self.logger.experiment.log({
116
+ "image": wandb.Image(x[0].cpu().detach().numpy(), masks={
117
+ "predictions": {
118
+ "mask_data": y_hat[0][0].cpu().detach().numpy(),
119
+ "class_labels": class_labels
120
+ },
121
+ "ground_truth": {
122
+ "mask_data": y[0][0].cpu().detach().numpy(),
123
+ "class_labels": class_labels
124
+ }
125
+ })
126
+ })
127
+ return {'val_loss': combined_loss}
128
+
129
+ def configure_optimizers(self):
130
+ optimizer = Adam(self.parameters(), lr=self.lr)
131
+ scheduler = StepLR(optimizer, step_size=30, gamma=0.1)
132
+ return [optimizer], [scheduler]
133
+
134
+ class Block(nn.Module):
135
+ def __init__(self, inputs=3, middles=64, outs=64):
136
+ super().__init__()
137
+
138
+ self.conv1 = nn.Conv2d(inputs, middles, 3, 1, 1)
139
+ self.conv2 = nn.Conv2d(middles, outs, 3, 1, 1)
140
+ self.relu = nn.ReLU()
141
+ self.bn = nn.BatchNorm2d(outs)
142
+ self.pool = nn.MaxPool2d(2, 2)
143
+
144
+ def forward(self, x):
145
+ x = self.relu(self.conv1(x))
146
+ x = self.relu(self.bn(self.conv2(x)))
147
+ return self.pool(x), x
148
+
149
+ class UNet(nn.Module):
150
+ def __init__(self, in_channels=3, out_channels=1):
151
+ super().__init__()
152
+
153
+ self.en1 = Block(in_channels, 64, 64)
154
+ self.en2 = Block(64, 128, 128)
155
+ self.en3 = Block(128, 256, 256)
156
+ self.en4 = Block(256, 512, 512)
157
+ self.en5 = Block(512, 1024, 512)
158
+
159
+ self.upsample4 = nn.ConvTranspose2d(512, 512, 2, stride=2)
160
+ self.de4 = Block(1024, 512, 256)
161
+
162
+ self.upsample3 = nn.ConvTranspose2d(256, 256, 2, stride=2)
163
+ self.de3 = Block(512, 256, 128)
164
+
165
+ self.upsample2 = nn.ConvTranspose2d(128, 128, 2, stride=2)
166
+ self.de2 = Block(256, 128, 64)
167
+
168
+ self.upsample1 = nn.ConvTranspose2d(64, 64, 2, stride=2)
169
+ self.de1 = Block(128, 64, 64)
170
+
171
+ self.conv_last = nn.Conv2d(64, out_channels, kernel_size=1, stride=1, padding=0)
172
+
173
+ def forward(self, x):
174
+ x, e1 = self.en1(x)
175
+ x, e2 = self.en2(x)
176
+ x, e3 = self.en3(x)
177
+ x, e4 = self.en4(x)
178
+ _, x = self.en5(x)
179
+
180
+ x = self.upsample4(x)
181
+ x = torch.cat([x, e4], dim=1)
182
+ _, x = self.de4(x)
183
+
184
+ x = self.upsample3(x)
185
+ x = torch.cat([x, e3], dim=1)
186
+ _, x = self.de3(x)
187
+
188
+ x = self.upsample2(x)
189
+ x = torch.cat([x, e2], dim=1)
190
+ _, x = self.de2(x)
191
+
192
+ x = self.upsample1(x)
193
+ x = torch.cat([x, e1], dim=1)
194
+ _, x = self.de1(x)
195
+
196
+ x = self.conv_last(x)
197
+
198
+ return x
199
+
200
+ def dice_loss(y_hat, y):
201
+ smooth = 1e-6
202
+ y_hat = y_hat.view(-1)
203
+ y = y.view(-1)
204
+ intersection = (y_hat * y).sum()
205
+ union = y_hat.sum() + y.sum()
206
+ dice = (2 * intersection + smooth) / (union + smooth)
207
+ return 1 - dice
src/mitb1_model.py ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import segmentation_models_pytorch as smp
4
+ from torchmetrics import F1Score, Precision, Recall, JaccardIndex
5
+ import pytorch_lightning as pl
6
+ import wandb
7
+ from torch.optim import Adam
8
+ from torch.optim.lr_scheduler import StepLR
9
+
10
+ class smp_model_mitb1(nn.Module):
11
+ def __init__(self, in_channels, out_channels, model_type, num_classes, encoder_weights):
12
+ super(smp_model_mitb1, self).__init__()
13
+ self.conv = nn.Conv2d(in_channels, 3, kernel_size=1)
14
+ self.model = smp.Unet(
15
+ encoder_name="mit_b1", # Corrected to string "mit_b1"
16
+ encoder_weights="imagenet", # Corrected to string "imagenet"
17
+ in_channels=3, # Set in_channels to 3 for MiT-B1
18
+ classes=num_classes,
19
+ )
20
+
21
+ def forward(self, x):
22
+ x = self.conv(x)
23
+ x = self.model(x)
24
+ return x
25
+
26
+ class LandslideModel(pl.LightningModule):
27
+ def __init__(self, config, alpha=0.5):
28
+ super(LandslideModel, self).__init__()
29
+
30
+ model_type = config['model_config']['model_type']
31
+ in_channels = config['model_config']['in_channels']
32
+ num_classes = config['model_config']['num_classes']
33
+ self.alpha = alpha
34
+ self.lr = config['train_config']['lr']
35
+
36
+ if model_type == 'unet':
37
+ self.model = UNet(in_channels=in_channels, out_channels=num_classes)
38
+ else:
39
+ encoder_weights = config['model_config']['encoder_weights']
40
+ self.model = smp_model_mitb1(in_channels=in_channels,
41
+ out_channels=num_classes,
42
+ model_type=model_type,
43
+ num_classes=num_classes,
44
+ encoder_weights=encoder_weights)
45
+
46
+ self.weights = torch.tensor([5], dtype=torch.float32).to(self.device)
47
+ self.wce = nn.BCELoss(weight=self.weights)
48
+
49
+ self.train_f1 = F1Score(task='binary')
50
+ self.val_f1 = F1Score(task='binary')
51
+
52
+ self.train_precision = Precision(task='binary')
53
+ self.val_precision = Precision(task='binary')
54
+
55
+ self.train_recall = Recall(task='binary')
56
+ self.val_recall = Recall(task='binary')
57
+
58
+ self.train_iou = JaccardIndex(task='binary')
59
+ self.val_iou = JaccardIndex(task='binary')
60
+
61
+ def forward(self, x):
62
+ return self.model(x)
63
+
64
+ def training_step(self, batch, batch_idx):
65
+ x, y = batch
66
+ y_hat = torch.sigmoid(self(x))
67
+
68
+ wce_loss = self.wce(y_hat, y)
69
+ dice = dice_loss(y_hat, y)
70
+
71
+ combined_loss = (1 - self.alpha) * wce_loss + self.alpha * dice
72
+
73
+ precision = self.train_precision(y_hat, y)
74
+ recall = self.train_recall(y_hat, y)
75
+ iou = self.train_iou(y_hat, y)
76
+ loss_f1 = self.train_f1(y_hat, y)
77
+
78
+ self.log('train_precision', precision)
79
+ self.log('train_recall', recall)
80
+ self.log('train_wce', wce_loss)
81
+ self.log('train_dice', dice)
82
+ self.log('train_iou', iou)
83
+ self.log('train_f1', loss_f1)
84
+ self.log('train_loss', combined_loss)
85
+ return {'loss': combined_loss}
86
+
87
+ def validation_step(self, batch, batch_idx):
88
+ x, y = batch
89
+ y_hat = torch.sigmoid(self(x))
90
+
91
+ wce_loss = self.wce(y_hat, y)
92
+ dice = dice_loss(y_hat, y)
93
+
94
+ combined_loss = (1 - self.alpha) * wce_loss + self.alpha * dice
95
+
96
+ precision = self.val_precision(y_hat, y)
97
+ recall = self.val_recall(y_hat, y)
98
+ iou = self.val_iou(y_hat, y)
99
+ loss_f1 = self.val_f1(y_hat, y)
100
+
101
+ self.log('val_precision', precision)
102
+ self.log('val_recall', recall)
103
+ self.log('val_wce', wce_loss)
104
+ self.log('val_dice', dice)
105
+ self.log('val_iou', iou)
106
+ self.log('val_f1', loss_f1)
107
+ self.log('val_loss', combined_loss)
108
+
109
+ if self.current_epoch % 10 == 0:
110
+ x = (x - x.min()) / (x.max() - x.min())
111
+ x = x[:, 0:3]
112
+ x = x.permute(0, 2, 3, 1)
113
+ y_hat = (y_hat > 0.5).float()
114
+
115
+ class_labels = {0: "no landslide", 1: "landslide"}
116
+
117
+ self.logger.experiment.log({
118
+ "image": wandb.Image(x[0].cpu().detach().numpy(), masks={
119
+ "predictions": {
120
+ "mask_data": y_hat[0][0].cpu().detach().numpy(),
121
+ "class_labels": class_labels
122
+ },
123
+ "ground_truth": {
124
+ "mask_data": y[0][0].cpu().detach().numpy(),
125
+ "class_labels": class_labels
126
+ }
127
+ })
128
+ })
129
+ return {'val_loss': combined_loss}
130
+
131
+ def configure_optimizers(self):
132
+ optimizer = Adam(self.parameters(), lr=self.lr)
133
+ scheduler = StepLR(optimizer, step_size=30, gamma=0.1)
134
+ return [optimizer], [scheduler]
135
+
136
+ class Block(nn.Module):
137
+ def __init__(self, inputs=3, middles=64, outs=64):
138
+ super().__init__()
139
+
140
+ self.conv1 = nn.Conv2d(inputs, middles, 3, 1, 1)
141
+ self.conv2 = nn.Conv2d(middles, outs, 3, 1, 1)
142
+ self.relu = nn.ReLU()
143
+ self.bn = nn.BatchNorm2d(outs)
144
+ self.pool = nn.MaxPool2d(2, 2)
145
+
146
+ def forward(self, x):
147
+ x = self.relu(self.conv1(x))
148
+ x = self.relu(self.bn(self.conv2(x)))
149
+ return self.pool(x), x
150
+
151
+ class UNet(nn.Module):
152
+ def __init__(self, in_channels=3, out_channels=1):
153
+ super().__init__()
154
+
155
+ self.en1 = Block(in_channels, 64, 64)
156
+ self.en2 = Block(64, 128, 128)
157
+ self.en3 = Block(128, 256, 256)
158
+ self.en4 = Block(256, 512, 512)
159
+ self.en5 = Block(512, 1024, 512)
160
+
161
+ self.upsample4 = nn.ConvTranspose2d(512, 512, 2, stride=2)
162
+ self.de4 = Block(1024, 512, 256)
163
+
164
+ self.upsample3 = nn.ConvTranspose2d(256, 256, 2, stride=2)
165
+ self.de3 = Block(512, 256, 128)
166
+
167
+ self.upsample2 = nn.ConvTranspose2d(128, 128, 2, stride=2)
168
+ self.de2 = Block(256, 128, 64)
169
+
170
+ self.upsample1 = nn.ConvTranspose2d(64, 64, 2, stride=2)
171
+ self.de1 = Block(128, 64, 64)
172
+
173
+ self.conv_last = nn.Conv2d(64, out_channels, kernel_size=1, stride=1, padding=0)
174
+
175
+ def forward(self, x):
176
+ x, e1 = self.en1(x)
177
+ x, e2 = self.en2(x)
178
+ x, e3 = self.en3(x)
179
+ x, e4 = self.en4(x)
180
+ _, x = self.en5(x)
181
+
182
+ x = self.upsample4(x)
183
+ x = torch.cat([x, e4], dim=1)
184
+ _, x = self.de4(x)
185
+
186
+ x = self.upsample3(x)
187
+ x = torch.cat([x, e3], dim=1)
188
+ _, x = self.de3(x)
189
+
190
+ x = self.upsample2(x)
191
+ x = torch.cat([x, e2], dim=1)
192
+ _, x = self.de2(x)
193
+
194
+ x = self.upsample1(x)
195
+ x = torch.cat([x, e1], dim=1)
196
+ _, x = self.de1(x)
197
+
198
+ x = self.conv_last(x)
199
+
200
+ return x
201
+
202
+ def dice_loss(y_hat, y):
203
+ smooth = 1e-6
204
+ y_hat = y_hat.view(-1)
205
+ y = y.view(-1)
206
+ intersection = (y_hat * y).sum()
207
+ union = y_hat.sum() + y.sum()
208
+ dice = (2 * intersection + smooth) / (union + smooth)
209
+ return 1 - dice
src/mobilenetv2_model.py ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import segmentation_models_pytorch as smp
4
+ from torchmetrics import F1Score, Precision, Recall, JaccardIndex
5
+ import pytorch_lightning as pl
6
+ import wandb
7
+ from torch.optim import Adam
8
+ from torch.optim.lr_scheduler import StepLR
9
+
10
+ class smp_model_v2(nn.Module):
11
+ def __init__(self, in_channels, out_channels, model_type, num_classes, encoder_weights):
12
+ super(smp_model_v2, self).__init__()
13
+
14
+ # Use MobileNetV2 pre-trained model as encoder
15
+ self.model = smp.Unet(
16
+ encoder_name='mobilenet_v2',
17
+ encoder_weights=encoder_weights,
18
+ in_channels=in_channels, # The number of input channels, which is 14
19
+ classes=num_classes, # Output classes, which is 1
20
+ )
21
+
22
+ def forward(self, x):
23
+ return self.model(x)
24
+
25
+ class LandslideModel(pl.LightningModule):
26
+ def __init__(self, config, alpha=0.5):
27
+ super(LandslideModel, self).__init__()
28
+
29
+ model_type = config['model_config']['model_type']
30
+ in_channels = config['model_config']['in_channels']
31
+ num_classes = config['model_config']['num_classes']
32
+ self.alpha = alpha # Assign the alpha value to the class variable
33
+ self.lr = config['train_config']['lr']
34
+
35
+ if model_type == 'unet':
36
+ self.model = UNet(in_channels=in_channels, out_channels=num_classes)
37
+ else:
38
+ encoder_weights = config['model_config']['encoder_weights']
39
+ # Use the custom smp_model_v2 instead of smp_model
40
+ self.model = smp_model_v2(in_channels=in_channels,
41
+ out_channels=num_classes,
42
+ model_type=model_type,
43
+ num_classes=num_classes,
44
+ encoder_weights=encoder_weights)
45
+
46
+ self.weights = torch.tensor([5], dtype=torch.float32).to(self.device)
47
+ self.wce = nn.BCELoss(weight=self.weights)
48
+
49
+ self.train_f1 = F1Score(task='binary')
50
+ self.val_f1 = F1Score(task='binary')
51
+
52
+ self.train_precision = Precision(task='binary')
53
+ self.val_precision = Precision(task='binary')
54
+
55
+ self.train_recall = Recall(task='binary')
56
+ self.val_recall = Recall(task='binary')
57
+
58
+ self.train_iou = JaccardIndex(task='binary')
59
+ self.val_iou = JaccardIndex(task='binary')
60
+
61
+ def forward(self, x):
62
+ return self.model(x)
63
+
64
+ def training_step(self, batch, batch_idx):
65
+ x, y = batch
66
+ y_hat = torch.sigmoid(self(x))
67
+
68
+ wce_loss = self.wce(y_hat, y)
69
+ dice = dice_loss(y_hat, y)
70
+
71
+ combined_loss = (1 - self.alpha) * wce_loss + self.alpha * dice
72
+
73
+ precision = self.train_precision(y_hat, y)
74
+ recall = self.train_recall(y_hat, y)
75
+ iou = self.train_iou(y_hat, y)
76
+ loss_f1 = self.train_f1(y_hat, y)
77
+
78
+ self.log('train_precision', precision)
79
+ self.log('train_recall', recall)
80
+ self.log('train_wce', wce_loss)
81
+ self.log('train_dice', dice)
82
+ self.log('train_iou', iou)
83
+ self.log('train_f1', loss_f1)
84
+ self.log('train_loss', combined_loss)
85
+ return {'loss': combined_loss}
86
+
87
+ def validation_step(self, batch, batch_idx):
88
+ x, y = batch
89
+ y_hat = torch.sigmoid(self(x))
90
+
91
+ wce_loss = self.wce(y_hat, y)
92
+ dice = dice_loss(y_hat, y)
93
+
94
+ combined_loss = (1 - self.alpha) * wce_loss + self.alpha * dice
95
+
96
+ precision = self.val_precision(y_hat, y)
97
+ recall = self.val_recall(y_hat, y)
98
+ iou = self.val_iou(y_hat, y)
99
+ loss_f1 = self.val_f1(y_hat, y)
100
+
101
+ self.log('val_precision', precision)
102
+ self.log('val_recall', recall)
103
+ self.log('val_wce', wce_loss)
104
+ self.log('val_dice', dice)
105
+ self.log('val_iou', iou)
106
+ self.log('val_f1', loss_f1)
107
+ self.log('val_loss', combined_loss)
108
+
109
+ if self.current_epoch % 10 == 0:
110
+ x = (x - x.min()) / (x.max() - x.min())
111
+ x = x[:, 0:3]
112
+ x = x.permute(0, 2, 3, 1)
113
+ y_hat = (y_hat > 0.5).float()
114
+
115
+ class_labels = {0: "no landslide", 1: "landslide"} # Define class_labels here
116
+
117
+ self.logger.experiment.log({
118
+ "image": wandb.Image(x[0].cpu().detach().numpy(), masks={
119
+ "predictions": {
120
+ "mask_data": y_hat[0][0].cpu().detach().numpy(),
121
+ "class_labels": class_labels
122
+ },
123
+ "ground_truth": {
124
+ "mask_data": y[0][0].cpu().detach().numpy(),
125
+ "class_labels": class_labels
126
+ }
127
+ })
128
+ })
129
+ return {'val_loss': combined_loss}
130
+
131
+ def configure_optimizers(self):
132
+ optimizer = Adam(self.parameters(), lr=self.lr)
133
+ scheduler = StepLR(optimizer, step_size=30, gamma=0.1)
134
+ return [optimizer], [scheduler]
135
+
136
+ class Block(nn.Module):
137
+ def __init__(self, inputs=3, middles=64, outs=64):
138
+ super().__init__()
139
+
140
+ self.conv1 = nn.Conv2d(inputs, middles, 3, 1, 1)
141
+ self.conv2 = nn.Conv2d(middles, outs, 3, 1, 1)
142
+ self.relu = nn.ReLU()
143
+ self.bn = nn.BatchNorm2d(outs)
144
+ self.pool = nn.MaxPool2d(2, 2)
145
+
146
+ def forward(self, x):
147
+ x = self.relu(self.conv1(x))
148
+ x = self.relu(self.bn(self.conv2(x)))
149
+ return self.pool(x), x
150
+
151
+ class UNet(nn.Module):
152
+ def __init__(self, in_channels=3, out_channels=1):
153
+ super().__init__()
154
+
155
+ self.en1 = Block(in_channels, 64, 64)
156
+ self.en2 = Block(64, 128, 128)
157
+ self.en3 = Block(128, 256, 256)
158
+ self.en4 = Block(256, 512, 512)
159
+ self.en5 = Block(512, 1024, 512)
160
+
161
+ self.upsample4 = nn.ConvTranspose2d(512, 512, 2, stride=2)
162
+ self.de4 = Block(1024, 512, 256)
163
+
164
+ self.upsample3 = nn.ConvTranspose2d(256, 256, 2, stride=2)
165
+ self.de3 = Block(512, 256, 128)
166
+
167
+ self.upsample2 = nn.ConvTranspose2d(128, 128, 2, stride=2)
168
+ self.de2 = Block(256, 128, 64)
169
+
170
+ self.upsample1 = nn.ConvTranspose2d(64, 64, 2, stride=2)
171
+ self.de1 = Block(128, 64, 64)
172
+
173
+ self.conv_last = nn.Conv2d(64, out_channels, kernel_size=1, stride=1, padding=0)
174
+
175
+ def forward(self, x):
176
+ x, e1 = self.en1(x)
177
+ x, e2 = self.en2(x)
178
+ x, e3 = self.en3(x)
179
+ x, e4 = self.en4(x)
180
+ _, x = self.en5(x)
181
+
182
+ x = self.upsample4(x)
183
+ x = torch.cat([x, e4], dim=1)
184
+ _, x = self.de4(x)
185
+
186
+ x = self.upsample3(x)
187
+ x = torch.cat([x, e3], dim=1)
188
+ _, x = self.de3(x)
189
+
190
+ x = self.upsample2(x)
191
+ x = torch.cat([x, e2], dim=1)
192
+ _, x = self.de2(x)
193
+
194
+ x = self.upsample1(x)
195
+ x = torch.cat([x, e1], dim=1)
196
+ _, x = self.de1(x)
197
+
198
+ x = self.conv_last(x)
199
+
200
+ return x
201
+
202
+ def dice_loss(y_hat, y):
203
+ smooth = 1e-6
204
+ y_hat = y_hat.view(-1)
205
+ y = y.view(-1)
206
+ intersection = (y_hat * y).sum()
207
+ union = y_hat.sum() + y.sum()
208
+ dice = (2 * intersection + smooth) / (union + smooth)
209
+ return 1 - dice
src/resnet34_model.py ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import segmentation_models_pytorch as smp
4
+ from torchmetrics import F1Score, Precision, Recall, JaccardIndex
5
+ import pytorch_lightning as pl
6
+ import wandb
7
+ from torch.optim import Adam
8
+ from torch.optim.lr_scheduler import StepLR
9
+
10
+ class smp_model_resnet34(nn.Module):
11
+ def __init__(self, in_channels, out_channels, model_type, num_classes, encoder_weights):
12
+ super(smp_model_resnet34, self).__init__()
13
+
14
+ # Use ResNet34 pre-trained model as encoder
15
+ self.model = smp.Unet(
16
+ encoder_name='resnet34',
17
+ encoder_weights=encoder_weights,
18
+ in_channels=in_channels, # The number of input channels, which is 14
19
+ classes=num_classes, # Output classes, which is 1
20
+ )
21
+
22
+ def forward(self, x):
23
+ return self.model(x)
24
+
25
+ class LandslideModel(pl.LightningModule):
26
+ def __init__(self, config, alpha=0.5):
27
+ super(LandslideModel, self).__init__()
28
+
29
+ model_type = config['model_config']['model_type']
30
+ in_channels = config['model_config']['in_channels']
31
+ num_classes = config['model_config']['num_classes']
32
+ self.alpha = alpha # Assign the alpha value to the class variable
33
+ self.lr = config['train_config']['lr']
34
+
35
+ if model_type == 'unet':
36
+ self.model = UNet(in_channels=in_channels, out_channels=num_classes)
37
+ else:
38
+ encoder_weights = config['model_config']['encoder_weights']
39
+ # Use the custom smp_model_resnet34 instead of smp_model
40
+ self.model = smp_model_resnet34(in_channels=in_channels,
41
+ out_channels=num_classes,
42
+ model_type=model_type,
43
+ num_classes=num_classes,
44
+ encoder_weights=encoder_weights)
45
+
46
+ self.weights = torch.tensor([5], dtype=torch.float32).to(self.device)
47
+ self.wce = nn.BCELoss(weight=self.weights)
48
+
49
+ self.train_f1 = F1Score(task='binary')
50
+ self.val_f1 = F1Score(task='binary')
51
+
52
+ self.train_precision = Precision(task='binary')
53
+ self.val_precision = Precision(task='binary')
54
+
55
+ self.train_recall = Recall(task='binary')
56
+ self.val_recall = Recall(task='binary')
57
+
58
+ self.train_iou = JaccardIndex(task='binary')
59
+ self.val_iou = JaccardIndex(task='binary')
60
+
61
+ def forward(self, x):
62
+ return self.model(x)
63
+
64
+ def training_step(self, batch, batch_idx):
65
+ x, y = batch
66
+ y_hat = torch.sigmoid(self(x))
67
+
68
+ wce_loss = self.wce(y_hat, y)
69
+ dice = dice_loss(y_hat, y)
70
+
71
+ combined_loss = (1 - self.alpha) * wce_loss + self.alpha * dice
72
+
73
+ precision = self.train_precision(y_hat, y)
74
+ recall = self.train_recall(y_hat, y)
75
+ iou = self.train_iou(y_hat, y)
76
+ loss_f1 = self.train_f1(y_hat, y)
77
+
78
+ self.log('train_precision', precision)
79
+ self.log('train_recall', recall)
80
+ self.log('train_wce', wce_loss)
81
+ self.log('train_dice', dice)
82
+ self.log('train_iou', iou)
83
+ self.log('train_f1', loss_f1)
84
+ self.log('train_loss', combined_loss)
85
+ return {'loss': combined_loss}
86
+
87
+ def validation_step(self, batch, batch_idx):
88
+ x, y = batch
89
+ y_hat = torch.sigmoid(self(x))
90
+
91
+ wce_loss = self.wce(y_hat, y)
92
+ dice = dice_loss(y_hat, y)
93
+
94
+ combined_loss = (1 - self.alpha) * wce_loss + self.alpha * dice
95
+
96
+ precision = self.val_precision(y_hat, y)
97
+ recall = self.val_recall(y_hat, y)
98
+ iou = self.val_iou(y_hat, y)
99
+ loss_f1 = self.val_f1(y_hat, y)
100
+
101
+ self.log('val_precision', precision)
102
+ self.log('val_recall', recall)
103
+ self.log('val_wce', wce_loss)
104
+ self.log('val_dice', dice)
105
+ self.log('val_iou', iou)
106
+ self.log('val_f1', loss_f1)
107
+ self.log('val_loss', combined_loss)
108
+
109
+ if self.current_epoch % 10 == 0:
110
+ x = (x - x.min()) / (x.max() - x.min())
111
+ x = x[:, 0:3]
112
+ x = x.permute(0, 2, 3, 1)
113
+ y_hat = (y_hat > 0.5).float()
114
+
115
+ class_labels = {0: "no landslide", 1: "landslide"} # Define class_labels here
116
+
117
+ self.logger.experiment.log({
118
+ "image": wandb.Image(x[0].cpu().detach().numpy(), masks={
119
+ "predictions": {
120
+ "mask_data": y_hat[0][0].cpu().detach().numpy(),
121
+ "class_labels": class_labels
122
+ },
123
+ "ground_truth": {
124
+ "mask_data": y[0][0].cpu().detach().numpy(),
125
+ "class_labels": class_labels
126
+ }
127
+ })
128
+ })
129
+ return {'val_loss': combined_loss}
130
+
131
+ def configure_optimizers(self):
132
+ optimizer = Adam(self.parameters(), lr=self.lr)
133
+ scheduler = StepLR(optimizer, step_size=30, gamma=0.1)
134
+ return [optimizer], [scheduler]
135
+
136
+ class Block(nn.Module):
137
+ def __init__(self, inputs=3, middles=64, outs=64):
138
+ super().__init__()
139
+
140
+ self.conv1 = nn.Conv2d(inputs, middles, 3, 1, 1)
141
+ self.conv2 = nn.Conv2d(middles, outs, 3, 1, 1)
142
+ self.relu = nn.ReLU()
143
+ self.bn = nn.BatchNorm2d(outs)
144
+ self.pool = nn.MaxPool2d(2, 2)
145
+
146
+ def forward(self, x):
147
+ x = self.relu(self.conv1(x))
148
+ x = self.relu(self.bn(self.conv2(x)))
149
+ return self.pool(x), x
150
+
151
+ class UNet(nn.Module):
152
+ def __init__(self, in_channels=3, out_channels=1):
153
+ super().__init__()
154
+
155
+ self.en1 = Block(in_channels, 64, 64)
156
+ self.en2 = Block(64, 128, 128)
157
+ self.en3 = Block(128, 256, 256)
158
+ self.en4 = Block(256, 512, 512)
159
+ self.en5 = Block(512, 1024, 512)
160
+
161
+ self.upsample4 = nn.ConvTranspose2d(512, 512, 2, stride=2)
162
+ self.de4 = Block(1024, 512, 256)
163
+
164
+ self.upsample3 = nn.ConvTranspose2d(256, 256, 2, stride=2)
165
+ self.de3 = Block(512, 256, 128)
166
+
167
+ self.upsample2 = nn.ConvTranspose2d(128, 128, 2, stride=2)
168
+ self.de2 = Block(256, 128, 64)
169
+
170
+ self.upsample1 = nn.ConvTranspose2d(64, 64, 2, stride=2)
171
+ self.de1 = Block(128, 64, 64)
172
+
173
+ self.conv_last = nn.Conv2d(64, out_channels, kernel_size=1, stride=1, padding=0)
174
+
175
+ def forward(self, x):
176
+ x, e1 = self.en1(x)
177
+ x, e2 = self.en2(x)
178
+ x, e3 = self.en3(x)
179
+ x, e4 = self.en4(x)
180
+ _, x = self.en5(x)
181
+
182
+ x = self.upsample4(x)
183
+ x = torch.cat([x, e4], dim=1)
184
+ _, x = self.de4(x)
185
+
186
+ x = self.upsample3(x)
187
+ x = torch.cat([x, e3], dim=1)
188
+ _, x = self.de3(x)
189
+
190
+ x = self.upsample2(x)
191
+ x = torch.cat([x, e2], dim=1)
192
+ _, x = self.de2(x)
193
+
194
+ x = self.upsample1(x)
195
+ x = torch.cat([x, e1], dim=1)
196
+ _, x = self.de1(x)
197
+
198
+ x = self.conv_last(x)
199
+
200
+ return x
201
+
202
+ def dice_loss(y_hat, y):
203
+ smooth = 1e-6
204
+ y_hat = y_hat.view(-1)
205
+ y = y.view(-1)
206
+ intersection = (y_hat * y).sum()
207
+ union = y_hat.sum() + y.sum()
208
+ dice = (2 * intersection + smooth) / (union + smooth)
209
+ return 1 - dice
src/resnext50_32x4d_model.py ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import segmentation_models_pytorch as smp
4
+ from torchmetrics import F1Score, Precision, Recall, JaccardIndex
5
+ import pytorch_lightning as pl
6
+ import wandb
7
+ from torch.optim import Adam
8
+ from torch.optim.lr_scheduler import StepLR
9
+
10
+ class smp_model(nn.Module):
11
+ def __init__(self, in_channels, out_channels, model_type, num_classes, encoder_weights):
12
+ super(smp_model, self).__init__()
13
+ self.model = smp.Unet(
14
+ encoder_name=model_type,
15
+ encoder_weights=encoder_weights,
16
+ in_channels=in_channels, # Use the original in_channels
17
+ classes=num_classes,
18
+ )
19
+
20
+ def forward(self, x):
21
+ x = self.model(x)
22
+ return x
23
+
24
+ class LandslideModel(pl.LightningModule):
25
+ def __init__(self, config, alpha=0.5):
26
+ super(LandslideModel, self).__init__()
27
+
28
+ model_type = config['model_config']['model_type']
29
+ in_channels = config['model_config']['in_channels']
30
+ num_classes = config['model_config']['num_classes']
31
+ self.alpha = alpha
32
+ self.lr = config['train_config']['lr']
33
+
34
+ if model_type == 'unet':
35
+ self.model = UNet(in_channels=in_channels, out_channels=num_classes)
36
+ else:
37
+ encoder_weights = config['model_config']['encoder_weights']
38
+ self.model = smp_model(in_channels=in_channels,
39
+ out_channels=num_classes,
40
+ model_type=model_type,
41
+ num_classes=num_classes,
42
+ encoder_weights=encoder_weights)
43
+
44
+ self.weights = torch.tensor([5], dtype=torch.float32).to(self.device)
45
+ self.wce = nn.BCELoss(weight=self.weights)
46
+
47
+ self.train_f1 = F1Score(task='binary')
48
+ self.val_f1 = F1Score(task='binary')
49
+
50
+ self.train_precision = Precision(task='binary')
51
+ self.val_precision = Precision(task='binary')
52
+
53
+ self.train_recall = Recall(task='binary')
54
+ self.val_recall = Recall(task='binary')
55
+
56
+ self.train_iou = JaccardIndex(task='binary')
57
+ self.val_iou = JaccardIndex(task='binary')
58
+
59
+ def forward(self, x):
60
+ return self.model(x)
61
+
62
+ def training_step(self, batch, batch_idx):
63
+ x, y = batch
64
+ y_hat = torch.sigmoid(self(x))
65
+
66
+ wce_loss = self.wce(y_hat, y)
67
+ dice = dice_loss(y_hat, y)
68
+
69
+ combined_loss = (1 - self.alpha) * wce_loss + self.alpha * dice
70
+
71
+ precision = self.train_precision(y_hat, y)
72
+ recall = self.train_recall(y_hat, y)
73
+ iou = self.train_iou(y_hat, y)
74
+ loss_f1 = self.train_f1(y_hat, y)
75
+
76
+ self.log('train_precision', precision)
77
+ self.log('train_recall', recall)
78
+ self.log('train_wce', wce_loss)
79
+ self.log('train_dice', dice)
80
+ self.log('train_iou', iou)
81
+ self.log('train_f1', loss_f1)
82
+ self.log('train_loss', combined_loss)
83
+ return {'loss': combined_loss}
84
+
85
+ def validation_step(self, batch, batch_idx):
86
+ x, y = batch
87
+ y_hat = torch.sigmoid(self(x))
88
+
89
+ wce_loss = self.wce(y_hat, y)
90
+ dice = dice_loss(y_hat, y)
91
+
92
+ combined_loss = (1 - self.alpha) * wce_loss + self.alpha * dice
93
+
94
+ precision = self.val_precision(y_hat, y)
95
+ recall = self.val_recall(y_hat, y)
96
+ iou = self.val_iou(y_hat, y)
97
+ loss_f1 = self.val_f1(y_hat, y)
98
+
99
+ self.log('val_precision', precision)
100
+ self.log('val_recall', recall)
101
+ self.log('val_wce', wce_loss)
102
+ self.log('val_dice', dice)
103
+ self.log('val_iou', iou)
104
+ self.log('val_f1', loss_f1)
105
+ self.log('val_loss', combined_loss)
106
+
107
+ if self.current_epoch % 10 == 0:
108
+ x = (x - x.min()) / (x.max() - x.min())
109
+ x = x[:, 0:3]
110
+ x = x.permute(0, 2, 3, 1)
111
+ y_hat = (y_hat > 0.5).float()
112
+
113
+ class_labels = {0: "no landslide", 1: "landslide"}
114
+
115
+ self.logger.experiment.log({
116
+ "image": wandb.Image(x[0].cpu().detach().numpy(), masks={
117
+ "predictions": {
118
+ "mask_data": y_hat[0][0].cpu().detach().numpy(),
119
+ "class_labels": class_labels
120
+ },
121
+ "ground_truth": {
122
+ "mask_data": y[0][0].cpu().detach().numpy(),
123
+ "class_labels": class_labels
124
+ }
125
+ })
126
+ })
127
+ return {'val_loss': combined_loss}
128
+
129
+ def configure_optimizers(self):
130
+ optimizer = Adam(self.parameters(), lr=self.lr)
131
+ scheduler = StepLR(optimizer, step_size=30, gamma=0.1)
132
+ return [optimizer], [scheduler]
133
+
134
+ class Block(nn.Module):
135
+ def __init__(self, inputs=3, middles=64, outs=64):
136
+ super().__init__()
137
+
138
+ self.conv1 = nn.Conv2d(inputs, middles, 3, 1, 1)
139
+ self.conv2 = nn.Conv2d(middles, outs, 3, 1, 1)
140
+ self.relu = nn.ReLU()
141
+ self.bn = nn.BatchNorm2d(outs)
142
+ self.pool = nn.MaxPool2d(2, 2)
143
+
144
+ def forward(self, x):
145
+ x = self.relu(self.conv1(x))
146
+ x = self.relu(self.bn(self.conv2(x)))
147
+ return self.pool(x), x
148
+
149
+ class UNet(nn.Module):
150
+ def __init__(self, in_channels=3, out_channels=1):
151
+ super().__init__()
152
+
153
+ self.en1 = Block(in_channels, 64, 64)
154
+ self.en2 = Block(64, 128, 128)
155
+ self.en3 = Block(128, 256, 256)
156
+ self.en4 = Block(256, 512, 512)
157
+ self.en5 = Block(512, 1024, 512)
158
+
159
+ self.upsample4 = nn.ConvTranspose2d(512, 512, 2, stride=2)
160
+ self.de4 = Block(1024, 512, 256)
161
+
162
+ self.upsample3 = nn.ConvTranspose2d(256, 256, 2, stride=2)
163
+ self.de3 = Block(512, 256, 128)
164
+
165
+ self.upsample2 = nn.ConvTranspose2d(128, 128, 2, stride=2)
166
+ self.de2 = Block(256, 128, 64)
167
+
168
+ self.upsample1 = nn.ConvTranspose2d(64, 64, 2, stride=2)
169
+ self.de1 = Block(128, 64, 64)
170
+
171
+ self.conv_last = nn.Conv2d(64, out_channels, kernel_size=1, stride=1, padding=0)
172
+
173
+ def forward(self, x):
174
+ x, e1 = self.en1(x)
175
+ x, e2 = self.en2(x)
176
+ x, e3 = self.en3(x)
177
+ x, e4 = self.en4(x)
178
+ _, x = self.en5(x)
179
+
180
+ x = self.upsample4(x)
181
+ x = torch.cat([x, e4], dim=1)
182
+ _, x = self.de4(x)
183
+
184
+ x = self.upsample3(x)
185
+ x = torch.cat([x, e3], dim=1)
186
+ _, x = self.de3(x)
187
+
188
+ x = self.upsample2(x)
189
+ x = torch.cat([x, e2], dim=1)
190
+ _, x = self.de2(x)
191
+
192
+ x = self.upsample1(x)
193
+ x = torch.cat([x, e1], dim=1)
194
+ _, x = self.de1(x)
195
+
196
+ x = self.conv_last(x)
197
+
198
+ return x
199
+
200
+ def dice_loss(y_hat, y):
201
+ smooth = 1e-6
202
+ y_hat = y_hat.view(-1)
203
+ y = y.view(-1)
204
+ intersection = (y_hat * y).sum()
205
+ union = y_hat.sum() + y.sum()
206
+ dice = (2 * intersection + smooth) / (union + smooth)
207
+ return 1 - dice
src/se_resnet50_model.py ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import segmentation_models_pytorch as smp
4
+ from torchmetrics import F1Score, Precision, Recall, JaccardIndex
5
+ import pytorch_lightning as pl
6
+ import wandb
7
+ from torch.optim import Adam
8
+ from torch.optim.lr_scheduler import StepLR
9
+
10
+ class smp_model(nn.Module):
11
+ def __init__(self, in_channels, out_channels, model_type, num_classes, encoder_weights):
12
+ super(smp_model, self).__init__()
13
+ self.model = smp.Unet(
14
+ encoder_name=model_type,
15
+ encoder_weights=encoder_weights,
16
+ in_channels=in_channels, # Use the original in_channels
17
+ classes=num_classes,
18
+ )
19
+
20
+ def forward(self, x):
21
+ x = self.model(x)
22
+ return x
23
+
24
+ class LandslideModel(pl.LightningModule):
25
+ def __init__(self, config, alpha=0.5):
26
+ super(LandslideModel, self).__init__()
27
+
28
+ model_type = config['model_config']['model_type']
29
+ in_channels = config['model_config']['in_channels']
30
+ num_classes = config['model_config']['num_classes']
31
+ self.alpha = alpha
32
+ self.lr = config['train_config']['lr']
33
+
34
+ if model_type == 'unet':
35
+ self.model = UNet(in_channels=in_channels, out_channels=num_classes)
36
+ else:
37
+ encoder_weights = config['model_config']['encoder_weights']
38
+ self.model = smp_model(in_channels=in_channels,
39
+ out_channels=num_classes,
40
+ model_type=model_type,
41
+ num_classes=num_classes,
42
+ encoder_weights=encoder_weights)
43
+
44
+ self.weights = torch.tensor([5], dtype=torch.float32).to(self.device)
45
+ self.wce = nn.BCELoss(weight=self.weights)
46
+
47
+ self.train_f1 = F1Score(task='binary')
48
+ self.val_f1 = F1Score(task='binary')
49
+
50
+ self.train_precision = Precision(task='binary')
51
+ self.val_precision = Precision(task='binary')
52
+
53
+ self.train_recall = Recall(task='binary')
54
+ self.val_recall = Recall(task='binary')
55
+
56
+ self.train_iou = JaccardIndex(task='binary')
57
+ self.val_iou = JaccardIndex(task='binary')
58
+
59
+ def forward(self, x):
60
+ return self.model(x)
61
+
62
+ def training_step(self, batch, batch_idx):
63
+ x, y = batch
64
+ y_hat = torch.sigmoid(self(x))
65
+
66
+ wce_loss = self.wce(y_hat, y)
67
+ dice = dice_loss(y_hat, y)
68
+
69
+ combined_loss = (1 - self.alpha) * wce_loss + self.alpha * dice
70
+
71
+ precision = self.train_precision(y_hat, y)
72
+ recall = self.train_recall(y_hat, y)
73
+ iou = self.train_iou(y_hat, y)
74
+ loss_f1 = self.train_f1(y_hat, y)
75
+
76
+ self.log('train_precision', precision)
77
+ self.log('train_recall', recall)
78
+ self.log('train_wce', wce_loss)
79
+ self.log('train_dice', dice)
80
+ self.log('train_iou', iou)
81
+ self.log('train_f1', loss_f1)
82
+ self.log('train_loss', combined_loss)
83
+ return {'loss': combined_loss}
84
+
85
+ def validation_step(self, batch, batch_idx):
86
+ x, y = batch
87
+ y_hat = torch.sigmoid(self(x))
88
+
89
+ wce_loss = self.wce(y_hat, y)
90
+ dice = dice_loss(y_hat, y)
91
+
92
+ combined_loss = (1 - self.alpha) * wce_loss + self.alpha * dice
93
+
94
+ precision = self.val_precision(y_hat, y)
95
+ recall = self.val_recall(y_hat, y)
96
+ iou = self.val_iou(y_hat, y)
97
+ loss_f1 = self.val_f1(y_hat, y)
98
+
99
+ self.log('val_precision', precision)
100
+ self.log('val_recall', recall)
101
+ self.log('val_wce', wce_loss)
102
+ self.log('val_dice', dice)
103
+ self.log('val_iou', iou)
104
+ self.log('val_f1', loss_f1)
105
+ self.log('val_loss', combined_loss)
106
+
107
+ if self.current_epoch % 10 == 0:
108
+ x = (x - x.min()) / (x.max() - x.min())
109
+ x = x[:, 0:3]
110
+ x = x.permute(0, 2, 3, 1)
111
+ y_hat = (y_hat > 0.5).float()
112
+
113
+ class_labels = {0: "no landslide", 1: "landslide"}
114
+
115
+ self.logger.experiment.log({
116
+ "image": wandb.Image(x[0].cpu().detach().numpy(), masks={
117
+ "predictions": {
118
+ "mask_data": y_hat[0][0].cpu().detach().numpy(),
119
+ "class_labels": class_labels
120
+ },
121
+ "ground_truth": {
122
+ "mask_data": y[0][0].cpu().detach().numpy(),
123
+ "class_labels": class_labels
124
+ }
125
+ })
126
+ })
127
+ return {'val_loss': combined_loss}
128
+
129
+ def configure_optimizers(self):
130
+ optimizer = Adam(self.parameters(), lr=self.lr)
131
+ scheduler = StepLR(optimizer, step_size=30, gamma=0.1)
132
+ return [optimizer], [scheduler]
133
+
134
+ class Block(nn.Module):
135
+ def __init__(self, inputs=3, middles=64, outs=64):
136
+ super().__init__()
137
+
138
+ self.conv1 = nn.Conv2d(inputs, middles, 3, 1, 1)
139
+ self.conv2 = nn.Conv2d(middles, outs, 3, 1, 1)
140
+ self.relu = nn.ReLU()
141
+ self.bn = nn.BatchNorm2d(outs)
142
+ self.pool = nn.MaxPool2d(2, 2)
143
+
144
+ def forward(self, x):
145
+ x = self.relu(self.conv1(x))
146
+ x = self.relu(self.bn(self.conv2(x)))
147
+ return self.pool(x), x
148
+
149
+ class UNet(nn.Module):
150
+ def __init__(self, in_channels=3, out_channels=1):
151
+ super().__init__()
152
+
153
+ self.en1 = Block(in_channels, 64, 64)
154
+ self.en2 = Block(64, 128, 128)
155
+ self.en3 = Block(128, 256, 256)
156
+ self.en4 = Block(256, 512, 512)
157
+ self.en5 = Block(512, 1024, 512)
158
+
159
+ self.upsample4 = nn.ConvTranspose2d(512, 512, 2, stride=2)
160
+ self.de4 = Block(1024, 512, 256)
161
+
162
+ self.upsample3 = nn.ConvTranspose2d(256, 256, 2, stride=2)
163
+ self.de3 = Block(512, 256, 128)
164
+
165
+ self.upsample2 = nn.ConvTranspose2d(128, 128, 2, stride=2)
166
+ self.de2 = Block(256, 128, 64)
167
+
168
+ self.upsample1 = nn.ConvTranspose2d(64, 64, 2, stride=2)
169
+ self.de1 = Block(128, 64, 64)
170
+
171
+ self.conv_last = nn.Conv2d(64, out_channels, kernel_size=1, stride=1, padding=0)
172
+
173
+ def forward(self, x):
174
+ x, e1 = self.en1(x)
175
+ x, e2 = self.en2(x)
176
+ x, e3 = self.en3(x)
177
+ x, e4 = self.en4(x)
178
+ _, x = self.en5(x)
179
+
180
+ x = self.upsample4(x)
181
+ x = torch.cat([x, e4], dim=1)
182
+ _, x = self.de4(x)
183
+
184
+ x = self.upsample3(x)
185
+ x = torch.cat([x, e3], dim=1)
186
+ _, x = self.de3(x)
187
+
188
+ x = self.upsample2(x)
189
+ x = torch.cat([x, e2], dim=1)
190
+ _, x = self.de2(x)
191
+
192
+ x = self.upsample1(x)
193
+ x = torch.cat([x, e1], dim=1)
194
+ _, x = self.de1(x)
195
+
196
+ x = self.conv_last(x)
197
+
198
+ return x
199
+
200
+ def dice_loss(y_hat, y):
201
+ smooth = 1e-6
202
+ y_hat = y_hat.view(-1)
203
+ y = y.view(-1)
204
+ intersection = (y_hat * y).sum()
205
+ union = y_hat.sum() + y.sum()
206
+ dice = (2 * intersection + smooth) / (union + smooth)
207
+ return 1 - dice
src/se_resnext50_32x4d_model.py ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import segmentation_models_pytorch as smp
4
+ from torchmetrics import F1Score, Precision, Recall, JaccardIndex
5
+ import pytorch_lightning as pl
6
+ import wandb
7
+ from torch.optim import Adam
8
+ from torch.optim.lr_scheduler import StepLR
9
+
10
+ class smp_model(nn.Module):
11
+ def __init__(self, in_channels, out_channels, model_type, num_classes, encoder_weights):
12
+ super(smp_model, self).__init__()
13
+ self.model = smp.Unet(
14
+ encoder_name=model_type,
15
+ encoder_weights=encoder_weights,
16
+ in_channels=in_channels, # Use the original in_channels
17
+ classes=num_classes,
18
+ )
19
+
20
+ def forward(self, x):
21
+ x = self.model(x)
22
+ return x
23
+
24
+ class LandslideModel(pl.LightningModule):
25
+ def __init__(self, config, alpha=0.5):
26
+ super(LandslideModel, self).__init__()
27
+
28
+ model_type = config['model_config']['model_type']
29
+ in_channels = config['model_config']['in_channels']
30
+ num_classes = config['model_config']['num_classes']
31
+ self.alpha = alpha
32
+ self.lr = config['train_config']['lr']
33
+
34
+ if model_type == 'unet':
35
+ self.model = UNet(in_channels=in_channels, out_channels=num_classes)
36
+ else:
37
+ encoder_weights = config['model_config']['encoder_weights']
38
+ self.model = smp_model(in_channels=in_channels,
39
+ out_channels=num_classes,
40
+ model_type=model_type,
41
+ num_classes=num_classes,
42
+ encoder_weights=encoder_weights)
43
+
44
+ self.weights = torch.tensor([5], dtype=torch.float32).to(self.device)
45
+ self.wce = nn.BCELoss(weight=self.weights)
46
+
47
+ self.train_f1 = F1Score(task='binary')
48
+ self.val_f1 = F1Score(task='binary')
49
+
50
+ self.train_precision = Precision(task='binary')
51
+ self.val_precision = Precision(task='binary')
52
+
53
+ self.train_recall = Recall(task='binary')
54
+ self.val_recall = Recall(task='binary')
55
+
56
+ self.train_iou = JaccardIndex(task='binary')
57
+ self.val_iou = JaccardIndex(task='binary')
58
+
59
+ def forward(self, x):
60
+ return self.model(x)
61
+
62
+ def training_step(self, batch, batch_idx):
63
+ x, y = batch
64
+ y_hat = torch.sigmoid(self(x))
65
+
66
+ wce_loss = self.wce(y_hat, y)
67
+ dice = dice_loss(y_hat, y)
68
+
69
+ combined_loss = (1 - self.alpha) * wce_loss + self.alpha * dice
70
+
71
+ precision = self.train_precision(y_hat, y)
72
+ recall = self.train_recall(y_hat, y)
73
+ iou = self.train_iou(y_hat, y)
74
+ loss_f1 = self.train_f1(y_hat, y)
75
+
76
+ self.log('train_precision', precision)
77
+ self.log('train_recall', recall)
78
+ self.log('train_wce', wce_loss)
79
+ self.log('train_dice', dice)
80
+ self.log('train_iou', iou)
81
+ self.log('train_f1', loss_f1)
82
+ self.log('train_loss', combined_loss)
83
+ return {'loss': combined_loss}
84
+
85
+ def validation_step(self, batch, batch_idx):
86
+ x, y = batch
87
+ y_hat = torch.sigmoid(self(x))
88
+
89
+ wce_loss = self.wce(y_hat, y)
90
+ dice = dice_loss(y_hat, y)
91
+
92
+ combined_loss = (1 - self.alpha) * wce_loss + self.alpha * dice
93
+
94
+ precision = self.val_precision(y_hat, y)
95
+ recall = self.val_recall(y_hat, y)
96
+ iou = self.val_iou(y_hat, y)
97
+ loss_f1 = self.val_f1(y_hat, y)
98
+
99
+ self.log('val_precision', precision)
100
+ self.log('val_recall', recall)
101
+ self.log('val_wce', wce_loss)
102
+ self.log('val_dice', dice)
103
+ self.log('val_iou', iou)
104
+ self.log('val_f1', loss_f1)
105
+ self.log('val_loss', combined_loss)
106
+
107
+ if self.current_epoch % 10 == 0:
108
+ x = (x - x.min()) / (x.max() - x.min())
109
+ x = x[:, 0:3]
110
+ x = x.permute(0, 2, 3, 1)
111
+ y_hat = (y_hat > 0.5).float()
112
+
113
+ class_labels = {0: "no landslide", 1: "landslide"}
114
+
115
+ self.logger.experiment.log({
116
+ "image": wandb.Image(x[0].cpu().detach().numpy(), masks={
117
+ "predictions": {
118
+ "mask_data": y_hat[0][0].cpu().detach().numpy(),
119
+ "class_labels": class_labels
120
+ },
121
+ "ground_truth": {
122
+ "mask_data": y[0][0].cpu().detach().numpy(),
123
+ "class_labels": class_labels
124
+ }
125
+ })
126
+ })
127
+ return {'val_loss': combined_loss}
128
+
129
+ def configure_optimizers(self):
130
+ optimizer = Adam(self.parameters(), lr=self.lr)
131
+ scheduler = StepLR(optimizer, step_size=30, gamma=0.1)
132
+ return [optimizer], [scheduler]
133
+
134
+ class Block(nn.Module):
135
+ def __init__(self, inputs=3, middles=64, outs=64):
136
+ super().__init__()
137
+
138
+ self.conv1 = nn.Conv2d(inputs, middles, 3, 1, 1)
139
+ self.conv2 = nn.Conv2d(middles, outs, 3, 1, 1)
140
+ self.relu = nn.ReLU()
141
+ self.bn = nn.BatchNorm2d(outs)
142
+ self.pool = nn.MaxPool2d(2, 2)
143
+
144
+ def forward(self, x):
145
+ x = self.relu(self.conv1(x))
146
+ x = self.relu(self.bn(self.conv2(x)))
147
+ return self.pool(x), x
148
+
149
+ class UNet(nn.Module):
150
+ def __init__(self, in_channels=3, out_channels=1):
151
+ super().__init__()
152
+
153
+ self.en1 = Block(in_channels, 64, 64)
154
+ self.en2 = Block(64, 128, 128)
155
+ self.en3 = Block(128, 256, 256)
156
+ self.en4 = Block(256, 512, 512)
157
+ self.en5 = Block(512, 1024, 512)
158
+
159
+ self.upsample4 = nn.ConvTranspose2d(512, 512, 2, stride=2)
160
+ self.de4 = Block(1024, 512, 256)
161
+
162
+ self.upsample3 = nn.ConvTranspose2d(256, 256, 2, stride=2)
163
+ self.de3 = Block(512, 256, 128)
164
+
165
+ self.upsample2 = nn.ConvTranspose2d(128, 128, 2, stride=2)
166
+ self.de2 = Block(256, 128, 64)
167
+
168
+ self.upsample1 = nn.ConvTranspose2d(64, 64, 2, stride=2)
169
+ self.de1 = Block(128, 64, 64)
170
+
171
+ self.conv_last = nn.Conv2d(64, out_channels, kernel_size=1, stride=1, padding=0)
172
+
173
+ def forward(self, x):
174
+ x, e1 = self.en1(x)
175
+ x, e2 = self.en2(x)
176
+ x, e3 = self.en3(x)
177
+ x, e4 = self.en4(x)
178
+ _, x = self.en5(x)
179
+
180
+ x = self.upsample4(x)
181
+ x = torch.cat([x, e4], dim=1)
182
+ _, x = self.de4(x)
183
+
184
+ x = self.upsample3(x)
185
+ x = torch.cat([x, e3], dim=1)
186
+ _, x = self.de3(x)
187
+
188
+ x = self.upsample2(x)
189
+ x = torch.cat([x, e2], dim=1)
190
+ _, x = self.de2(x)
191
+
192
+ x = self.upsample1(x)
193
+ x = torch.cat([x, e1], dim=1)
194
+ _, x = self.de1(x)
195
+
196
+ x = self.conv_last(x)
197
+
198
+ return x
199
+
200
+ def dice_loss(y_hat, y):
201
+ smooth = 1e-6
202
+ y_hat = y_hat.view(-1)
203
+ y = y.view(-1)
204
+ intersection = (y_hat * y).sum()
205
+ union = y_hat.sum() + y.sum()
206
+ dice = (2 * intersection + smooth) / (union + smooth)
207
+ return 1 - dice
src/segformer_model.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torchmetrics
4
+ import pytorch_lightning as pl
5
+ import wandb
6
+ from torch.optim import Adam
7
+ from torch.optim.lr_scheduler import StepLR
8
+ from transformers import SegformerForSemanticSegmentation
9
+
10
+ class LandslideModel(pl.LightningModule):
11
+ def __init__(self, config, alpha=0.5):
12
+ super(LandslideModel, self).__init__()
13
+
14
+ self.model_type = config['model_config']['model_type']
15
+ self.in_channels = config['model_config']['in_channels']
16
+ self.num_classes = config['model_config']['num_classes']
17
+ self.alpha = alpha
18
+ self.lr = config['train_config']['lr']
19
+
20
+ self.model = SegformerForSemanticSegmentation.from_pretrained(
21
+ "nvidia/segformer-b2-finetuned-ade-512-512",
22
+ ignore_mismatched_sizes=True,
23
+ num_labels=self.num_classes
24
+ )
25
+
26
+ # Modify the input layer for 14 channels
27
+ self.model.segformer.encoder.patch_embeddings[0].proj = nn.Conv2d(
28
+ in_channels=self.in_channels,
29
+ out_channels=self.model.segformer.encoder.patch_embeddings[0].proj.out_channels,
30
+ kernel_size=self.model.segformer.encoder.patch_embeddings[0].proj.kernel_size,
31
+ stride=self.model.segformer.encoder.patch_embeddings[0].proj.stride,
32
+ padding=self.model.segformer.encoder.patch_embeddings[0].proj.padding
33
+ )
34
+
35
+ self.weights = torch.tensor([5], dtype=torch.float32).to(self.device)
36
+ self.wce = nn.BCELoss(weight=self.weights)
37
+
38
+ self.train_f1 = torchmetrics.F1Score(task='binary')
39
+ self.val_f1 = torchmetrics.F1Score(task='binary')
40
+
41
+ self.train_precision = torchmetrics.Precision(task='binary')
42
+ self.val_precision = torchmetrics.Precision(task='binary')
43
+
44
+ self.train_recall = torchmetrics.Recall(task='binary')
45
+ self.val_recall = torchmetrics.Recall(task='binary')
46
+
47
+ self.train_iou = torchmetrics.JaccardIndex(task='binary')
48
+ self.val_iou = torchmetrics.JaccardIndex(task='binary')
49
+
50
+ def forward(self, x):
51
+ return self.model(x).logits
52
+
53
+ def training_step(self, batch, batch_idx):
54
+ x, y = batch
55
+ y_hat = torch.sigmoid(self(x))
56
+
57
+ # Resize y_hat to match the size of y
58
+ y_hat = nn.functional.interpolate(y_hat, size=y.shape[2:], mode='bilinear', align_corners=False)
59
+
60
+ wce_loss = self.wce(y_hat, y)
61
+ dice = dice_loss(y_hat, y)
62
+
63
+ combined_loss = (1 - self.alpha) * wce_loss + self.alpha * dice
64
+
65
+ precision = self.train_precision(y_hat, y)
66
+ recall = self.train_recall(y_hat, y)
67
+ iou = self.train_iou(y_hat, y)
68
+ loss_f1 = self.train_f1(y_hat, y)
69
+
70
+ self.log('train_precision', precision)
71
+ self.log('train_recall', recall)
72
+ self.log('train_wce', wce_loss)
73
+ self.log('train_dice', dice)
74
+ self.log('train_iou', iou)
75
+ self.log('train_f1', loss_f1)
76
+ self.log('train_loss', combined_loss)
77
+ return {'loss': combined_loss}
78
+
79
+ def validation_step(self, batch, batch_idx):
80
+ x, y = batch
81
+ y_hat = torch.sigmoid(self(x))
82
+
83
+ # Resize y_hat to match the size of y
84
+ y_hat = nn.functional.interpolate(y_hat, size=y.shape[2:], mode='bilinear', align_corners=False)
85
+
86
+ wce_loss = self.wce(y_hat, y)
87
+ dice = dice_loss(y_hat, y)
88
+
89
+ combined_loss = (1 - self.alpha) * wce_loss + self.alpha * dice
90
+
91
+ precision = self.val_precision(y_hat, y)
92
+ recall = self.val_recall(y_hat, y)
93
+ iou = self.val_iou(y_hat, y)
94
+ loss_f1 = self.val_f1(y_hat, y)
95
+
96
+ self.log('val_precision', precision)
97
+ self.log('val_recall', recall)
98
+ self.log('val_wce', wce_loss)
99
+ self.log('val_dice', dice)
100
+ self.log('val_iou', iou)
101
+ self.log('val_f1', loss_f1)
102
+ self.log('val_loss', combined_loss)
103
+
104
+ if self.current_epoch % 10 == 0:
105
+ x = (x - x.min()) / (x.max() - x.min())
106
+ x = x[:, 0:3]
107
+ x = x.permute(0, 2, 3, 1)
108
+ y_hat = (y_hat > 0.5).float()
109
+
110
+ class_labels = {0: "no landslide", 1: "landslide"}
111
+
112
+ self.logger.experiment.log({
113
+ "image": wandb.Image(x[0].cpu().detach().numpy(), masks={
114
+ "predictions": {
115
+ "mask_data": y_hat[0][0].cpu().detach().numpy(),
116
+ "class_labels": class_labels
117
+ },
118
+ "ground_truth": {
119
+ "mask_data": y[0][0].cpu().detach().numpy(),
120
+ "class_labels": class_labels
121
+ }
122
+ })
123
+ })
124
+ return {'val_loss': combined_loss}
125
+
126
+ def configure_optimizers(self):
127
+ optimizer = Adam(self.parameters(), lr=self.lr)
128
+ scheduler = StepLR(optimizer, step_size=30, gamma=0.1)
129
+ return [optimizer], [scheduler]
130
+
131
+ def dice_loss(y_hat, y):
132
+ smooth = 1e-6
133
+ y_hat = y_hat.view(-1)
134
+ y = y.view(-1)
135
+ intersection = (y_hat * y).sum()
136
+ union = y_hat.sum() + y.sum()
137
+ dice = (2 * intersection + smooth) / (union + smooth)
138
+ return 1 - dice
src/vgg16_model.py ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import segmentation_models_pytorch as smp
4
+ from torchmetrics import F1Score, Precision, Recall, JaccardIndex
5
+ import pytorch_lightning as pl
6
+ import wandb
7
+ from torch.optim import Adam
8
+ from torch.optim.lr_scheduler import StepLR
9
+
10
+ class smp_model_vgg16(nn.Module):
11
+ def __init__(self, in_channels, out_channels, model_type, num_classes, encoder_weights):
12
+ super(smp_model_vgg16, self).__init__()
13
+
14
+ # Ensure that model is always initialized
15
+ self.model = smp.Unet(
16
+ encoder_name='vgg16',
17
+ encoder_weights=encoder_weights,
18
+ in_channels=in_channels, # The number of input channels, which is 14
19
+ classes=num_classes, # Output classes, which is 1
20
+ )
21
+
22
+ def forward(self, x):
23
+ return self.model(x)
24
+
25
+ class LandslideModel(pl.LightningModule):
26
+ def __init__(self, config, alpha=0.5):
27
+ super(LandslideModel, self).__init__()
28
+
29
+ model_type = config['model_config']['model_type']
30
+ in_channels = config['model_config']['in_channels']
31
+ num_classes = config['model_config']['num_classes']
32
+ self.alpha = alpha # Assign the alpha value to the class variable
33
+ self.lr = config['train_config']['lr']
34
+
35
+ if model_type == 'unet':
36
+ self.model = UNet(in_channels=in_channels, out_channels=num_classes)
37
+ else:
38
+ encoder_weights = config['model_config']['encoder_weights']
39
+ # Use the custom smp_model_vgg16 instead of smp_model
40
+ self.model = smp_model_vgg16(in_channels=in_channels,
41
+ out_channels=num_classes,
42
+ model_type=model_type,
43
+ num_classes=num_classes,
44
+ encoder_weights=encoder_weights)
45
+
46
+ self.weights = torch.tensor([5], dtype=torch.float32).to(self.device)
47
+ self.wce = nn.BCELoss(weight=self.weights)
48
+
49
+ self.train_f1 = F1Score(task='binary')
50
+ self.val_f1 = F1Score(task='binary')
51
+
52
+ self.train_precision = Precision(task='binary')
53
+ self.val_precision = Precision(task='binary')
54
+
55
+ self.train_recall = Recall(task='binary')
56
+ self.val_recall = Recall(task='binary')
57
+
58
+ self.train_iou = JaccardIndex(task='binary')
59
+ self.val_iou = JaccardIndex(task='binary')
60
+
61
+ def forward(self, x):
62
+ return self.model(x)
63
+
64
+ def training_step(self, batch, batch_idx):
65
+ x, y = batch
66
+ y_hat = torch.sigmoid(self(x))
67
+
68
+ wce_loss = self.wce(y_hat, y)
69
+ dice = dice_loss(y_hat, y)
70
+
71
+ combined_loss = (1 - self.alpha) * wce_loss + self.alpha * dice
72
+
73
+ precision = self.train_precision(y_hat, y)
74
+ recall = self.train_recall(y_hat, y)
75
+ iou = self.train_iou(y_hat, y)
76
+ loss_f1 = self.train_f1(y_hat, y)
77
+
78
+ self.log('train_precision', precision)
79
+ self.log('train_recall', recall)
80
+ self.log('train_wce', wce_loss)
81
+ self.log('train_dice', dice)
82
+ self.log('train_iou', iou)
83
+ self.log('train_f1', loss_f1)
84
+ self.log('train_loss', combined_loss)
85
+ return {'loss': combined_loss}
86
+
87
+ def validation_step(self, batch, batch_idx):
88
+ x, y = batch
89
+ y_hat = torch.sigmoid(self(x))
90
+
91
+ wce_loss = self.wce(y_hat, y)
92
+ dice = dice_loss(y_hat, y)
93
+
94
+ combined_loss = (1 - self.alpha) * wce_loss + self.alpha * dice
95
+
96
+ precision = self.val_precision(y_hat, y)
97
+ recall = self.val_recall(y_hat, y)
98
+ iou = self.val_iou(y_hat, y)
99
+ loss_f1 = self.val_f1(y_hat, y)
100
+
101
+ self.log('val_precision', precision)
102
+ self.log('val_recall', recall)
103
+ self.log('val_wce', wce_loss)
104
+ self.log('val_dice', dice)
105
+ self.log('val_iou', iou)
106
+ self.log('val_f1', loss_f1)
107
+ self.log('val_loss', combined_loss)
108
+
109
+ if self.current_epoch % 10 == 0:
110
+ x = (x - x.min()) / (x.max() - x.min())
111
+ x = x[:, 0:3]
112
+ x = x.permute(0, 2, 3, 1)
113
+ y_hat = (y_hat > 0.5).float()
114
+
115
+ class_labels = {0: "no landslide", 1: "landslide"} # Define class_labels here
116
+
117
+ self.logger.experiment.log({
118
+ "image": wandb.Image(x[0].cpu().detach().numpy(), masks={
119
+ "predictions": {
120
+ "mask_data": y_hat[0][0].cpu().detach().numpy(),
121
+ "class_labels": class_labels
122
+ },
123
+ "ground_truth": {
124
+ "mask_data": y[0][0].cpu().detach().numpy(),
125
+ "class_labels": class_labels
126
+ }
127
+ })
128
+ })
129
+ return {'val_loss': combined_loss}
130
+
131
+ def configure_optimizers(self):
132
+ optimizer = Adam(self.parameters(), lr=self.lr)
133
+ scheduler = StepLR(optimizer, step_size=30, gamma=0.1)
134
+ return [optimizer], [scheduler]
135
+
136
+ class Block(nn.Module):
137
+ def __init__(self, inputs=3, middles=64, outs=64):
138
+ super().__init__()
139
+
140
+ self.conv1 = nn.Conv2d(inputs, middles, 3, 1, 1)
141
+ self.conv2 = nn.Conv2d(middles, outs, 3, 1, 1)
142
+ self.relu = nn.ReLU()
143
+ self.bn = nn.BatchNorm2d(outs)
144
+ self.pool = nn.MaxPool2d(2, 2)
145
+
146
+ def forward(self, x):
147
+ x = self.relu(self.conv1(x))
148
+ x = self.relu(self.bn(self.conv2(x)))
149
+ return self.pool(x), x
150
+
151
+ class UNet(nn.Module):
152
+ def __init__(self, in_channels=3, out_channels=1):
153
+ super().__init__()
154
+
155
+ self.en1 = Block(in_channels, 64, 64)
156
+ self.en2 = Block(64, 128, 128)
157
+ self.en3 = Block(128, 256, 256)
158
+ self.en4 = Block(256, 512, 512)
159
+ self.en5 = Block(512, 1024, 512)
160
+
161
+ self.upsample4 = nn.ConvTranspose2d(512, 512, 2, stride=2)
162
+ self.de4 = Block(1024, 512, 256)
163
+
164
+ self.upsample3 = nn.ConvTranspose2d(256, 256, 2, stride=2)
165
+ self.de3 = Block(512, 256, 128)
166
+
167
+ self.upsample2 = nn.ConvTranspose2d(128, 128, 2, stride=2)
168
+ self.de2 = Block(256, 128, 64)
169
+
170
+ self.upsample1 = nn.ConvTranspose2d(64, 64, 2, stride=2)
171
+ self.de1 = Block(128, 64, 64)
172
+
173
+ self.conv_last = nn.Conv2d(64, out_channels, kernel_size=1, stride=1, padding=0)
174
+
175
+ def forward(self, x):
176
+ x, e1 = self.en1(x)
177
+ x, e2 = self.en2(x)
178
+ x, e3 = self.en3(x)
179
+ x, e4 = self.en4(x)
180
+ _, x = self.en5(x)
181
+
182
+ x = self.upsample4(x)
183
+ x = torch.cat([x, e4], dim=1)
184
+ _, x = self.de4(x)
185
+
186
+ x = self.upsample3(x)
187
+ x = torch.cat([x, e3], dim=1)
188
+ _, x = self.de3(x)
189
+
190
+ x = self.upsample2(x)
191
+ x = torch.cat([x, e2], dim=1)
192
+ _, x = self.de2(x)
193
+
194
+ x = self.upsample1(x)
195
+ x = torch.cat([x, e1], dim=1)
196
+ _, x = self.de1(x)
197
+
198
+ x = self.conv_last(x)
199
+
200
+ return x
201
+
202
+ def dice_loss(y_hat, y):
203
+ smooth = 1e-6
204
+ y_hat = y_hat.view(-1)
205
+ y = y.view(-1)
206
+ intersection = (y_hat * y).sum()
207
+ union = y_hat.sum() + y.sum()
208
+ dice = (2 * intersection + smooth) / (union + smooth)
209
+ return 1 - dice