yvokeller committed
Commit 5b24075 · 1 Parent(s): 1eea5f1

first messis demo app version

.gitignore ADDED
@@ -0,0 +1,2 @@
1
+ hf_cache
2
+ __pycache__
inference.py ADDED
@@ -0,0 +1,235 @@
1
+ import os
2
+ import torch
3
+ import yaml
4
+ import json
5
+ import rasterio
6
+ from rasterio.windows import Window
7
+ from rasterio.transform import rowcol
8
+ from pyproj import Transformer
9
+ from torchvision import transforms
10
+ import numpy as np
11
+ from rasterio.features import shapes
12
+ from shapely.geometry import shape
13
+ import geopandas as gpd
14
+
15
+ from messis.messis import LogConfusionMatrix
16
+
17
+ class InferenceDataLoader:
18
+ def __init__(self, features_path, labels_path, field_ids_path, stats_path, window_size=224, n_timesteps=3, fold_indices=None, debug=False):
19
+ self.features_path = features_path
20
+ self.labels_path = labels_path
21
+ self.field_ids_path = field_ids_path
22
+ self.stats_path = stats_path
23
+ self.window_size = window_size
24
+ self.n_timesteps = n_timesteps
25
+ self.fold_indices = fold_indices if fold_indices is not None else []
26
+ self.debug = debug
27
+
28
+ # Load normalization stats
29
+ self.means, self.stds = self.load_stats()
30
+
31
+ # Set up the transformer for coordinate conversion
32
+ self.transformer = Transformer.from_crs("EPSG:4326", "EPSG:32632", always_xy=True)
33
+
34
+ def load_stats(self):
35
+ """Load normalization statistics for dataset from YAML file."""
36
+ if self.debug:
37
+ print(f"Loading mean/std stats from {self.stats_path}")
38
+ assert os.path.exists(self.stats_path), f"Mean/std stats file not found at {self.stats_path}"
39
+
40
+ with open(self.stats_path, 'r') as file:
41
+ stats = yaml.safe_load(file)
42
+
43
+ mean_list, std_list, n_list = [], [], []
44
+ for fold in self.fold_indices:
45
+ key = f'fold_{fold}'
46
+ if key not in stats:
47
+ raise ValueError(f"Mean/std stats for fold {fold} not found in {self.stats_path}")
48
+ if self.debug:
49
+ print(f"Stats with selected test fold {fold}: {stats[key]} over {self.n_timesteps} timesteps.")
50
+ mean_list.append(torch.tensor(stats[key]['mean'])) # list of 6 means
51
+ std_list.append(torch.tensor(stats[key]['std'])) # list of 6 stds
52
+ n_list.append(stats[key]['n_chips']) # list of 6 ns
53
+
54
+ means, stds = [], []
55
+ for channel in range(mean_list[0].shape[0]):
56
+ means.append(torch.stack([mean_list[i][channel] for i in range(len(mean_list))]).mean())
57
+ variances = torch.stack([std_list[i][channel] ** 2 for i in range(len(std_list))])
58
+ n = torch.tensor([n_list[i] for i in range(len(n_list))], dtype=torch.float32)
59
+ combined_variance = torch.sum(variances * (n - 1)) / (torch.sum(n) - len(n_list))
60
+ stds.append(torch.sqrt(combined_variance))
61
+
62
+ return means * self.n_timesteps, stds * self.n_timesteps # repeat the per-band stats per timestep to get the flat channel list expected by Normalize
63
+
64
+ def identify_window(self, path, lon, lat):
65
+ """Identify the 224x224 window centered on the clicked coordinates (lon, lat) from the specified GeoTIFF."""
66
+ with rasterio.open(path) as src:
67
+ # Transform the coordinates from WGS84 to UTM (EPSG:32632)
68
+ utm_x, utm_y = self.transformer.transform(lon, lat)
69
+
70
+ try:
71
+ py, px = rowcol(src.transform, utm_x, utm_y) # rasterio's rowcol returns (row, col)
72
+ except ValueError:
73
+ raise ValueError("Coordinates out of bounds for this raster.")
74
+
75
+ if self.debug:
76
+ print(f"Row: {py}, Column: {px}")
77
+
78
+ half_window_size = self.window_size // 2
79
+
80
+ col_off = px - half_window_size
81
+ row_off = py - half_window_size
82
+
83
+ if col_off < 0:
84
+ col_off = 0
85
+ if row_off < 0:
86
+ row_off = 0
87
+ if col_off + self.window_size > src.width:
88
+ col_off = src.width - self.window_size
89
+ if row_off + self.window_size > src.height:
90
+ row_off = src.height - self.window_size
91
+
92
+ window = Window(col_off, row_off, self.window_size, self.window_size)
93
+ window_transform = src.window_transform(window)
94
+ crs = src.crs
95
+
96
+ return window, window_transform, crs
97
+
98
+ def extract_window(self, path, window):
99
+ """Extract data from the specified window from the GeoTIFF."""
100
+ with rasterio.open(path) as src:
101
+ window_data = src.read(window=window)
102
+
103
+ if self.debug:
104
+ print(f"Extracted window data from {path}")
105
+ print(f"Min: {window_data.min()}, Max: {window_data.max()}")
106
+
107
+ return window_data
108
+
109
+ def prepare_data_for_model(self, features_data):
110
+ """Prepare the window data for model inference."""
111
+ # Convert to tensor
112
+ features_data = torch.tensor(features_data, dtype=torch.float32)
113
+
114
+ # Normalize
115
+ normalize = transforms.Normalize(mean=self.means, std=self.stds)
116
+ features_data = normalize(features_data)
117
+
118
+ # Permute the dimensions if needed
119
+ height, width = features_data.shape[-2:]
120
+ features_data = features_data.view(self.n_timesteps, 6, height, width).permute(1, 0, 2, 3)
121
+
122
+ # Add batch dimension
123
+ features_data = features_data.unsqueeze(0)
124
+
125
+ return features_data
126
+
127
+ def get_data(self, lon, lat):
128
+ """Extract, normalize, and prepare data for inference, including labels and field IDs."""
129
+ # Identify the window and get the georeferencing information
130
+ window, features_transform, features_crs = self.identify_window(self.features_path, lon, lat)
131
+
132
+ # Extract data from the GeoTIFF, labels, and field IDs
133
+ features_data = self.extract_window(self.features_path, window)
134
+ label_data = self.extract_window(self.labels_path, window)
135
+ field_ids_data = self.extract_window(self.field_ids_path, window)
136
+
137
+ # Prepare the window data for the model
138
+ prepared_features_data = self.prepare_data_for_model(features_data)
139
+
140
+ # Convert labels and field_ids to tensors (without normalization)
141
+ label_data = torch.tensor(label_data, dtype=torch.long)
142
+ field_ids_data = torch.tensor(field_ids_data, dtype=torch.long)
143
+
144
+ # Return the prepared data along with transform and CRS
145
+ return prepared_features_data, label_data, field_ids_data, features_transform, features_crs
146
+
147
+ def crop_predictions_to_gdf(field_ids, targets, predictions, transform, crs, class_names):
148
+ """
149
+ Convert field_ids, targets, and predictions tensors to field polygons with corresponding class reference.
150
+
151
+ :param field_ids: PyTorch tensor of shape (1, 224, 224) representing individual fields
152
+ :param targets: PyTorch tensor of shape (1, 224, 224) for targets
153
+ :param predictions: PyTorch tensor of shape (1, 224, 224) for predictions
154
+ :param transform: Affine transform for georeferencing
155
+ :param crs: Coordinate reference system (CRS) of the data
156
+ :param class_names: Dictionary mapping class indices to class names
157
+ :return: GeoPandas DataFrame with polygons, prediction class labels, and target class labels
158
+ """
159
+ field_array = field_ids.squeeze().cpu().numpy().astype(np.int32)
160
+ target_array = targets.squeeze().cpu().numpy().astype(np.int8)
161
+ pred_array = predictions.squeeze().cpu().numpy().astype(np.int8)
162
+
163
+ polygons = []
164
+ field_values = []
165
+ target_values = []
166
+ pred_values = []
167
+
168
+ for geom, field_value in shapes(field_array, transform=transform):
169
+ polygons.append(shape(geom))
170
+ field_values.append(field_value)
171
+
172
+ # Get a single value from the field area for targets and predictions
173
+ target_value = target_array[field_array == field_value][0]
174
+ pred_value = pred_array[field_array == field_value][0]
175
+
176
+ target_values.append(target_value)
177
+ pred_values.append(pred_value)
178
+
179
+ gdf = gpd.GeoDataFrame({
180
+ 'geometry': polygons,
181
+ 'field_id': field_values,
182
+ 'target': target_values,
183
+ 'prediction': pred_values
184
+ }, crs=crs)
185
+
186
+ gdf['prediction_class'] = gdf['prediction'].apply(lambda x: class_names[x])
187
+ gdf['target_class'] = gdf['target'].apply(lambda x: class_names[x])
188
+
189
+ gdf['correct'] = gdf['target'] == gdf['prediction']
190
+
191
+ gdf = gdf[gdf.geometry.area > 250] # Drop small polygons (area below 250 m² in the projected CRS)
192
+
193
+ return gdf
194
+
195
+ def perform_inference(lon, lat, model, config, debug=False):
196
+ features_path = "./data/stacked_features.tif"
197
+ labels_path = "./data/labels.tif"
198
+ field_ids_path = "./data/field_ids.tif"
199
+ stats_path = "./data/chips_stats.yaml"
200
+
201
+ loader = InferenceDataLoader(features_path, labels_path, field_ids_path, stats_path, n_timesteps=9, fold_indices=[0], debug=debug)
202
+
203
+ # Coordinates must be in EPSG:4326 (lon, lat order); they are converted to the raster CRS (EPSG:32632)
204
+ satellite_data, label_data, field_ids_data, features_transform, features_crs = loader.get_data(lon, lat)
205
+
206
+ if debug:
207
+ # Print the shape of the extracted data
208
+ print(satellite_data.shape)
209
+ print(label_data.shape)
210
+ print(field_ids_data.shape)
211
+
212
+ with open('./data/dataset_info.json', 'r') as file:
213
+ dataset_info = json.load(file)
214
+ class_names = dataset_info['tier3']
215
+
216
+ tiers_dict = {k: v for k, v in config.hparams.get('heads_spec').items() if v.get('is_metrics_tier', False)}
217
+ tiers = list(tiers_dict.keys())
218
+
219
+ # Perform inference
220
+ model.eval()
221
+ with torch.no_grad():
222
+ output = model(satellite_data)['tier3_refinement_head']
223
+
224
+ pixelwise_outputs_stacked, majority_outputs_stacked = LogConfusionMatrix.get_pixelwise_and_majority_outputs(output, tiers, field_ids=field_ids_data, dataset_info=dataset_info)
225
+ majority_tier3_predictions = majority_outputs_stacked[2] # Tier 3 predictions
226
+
227
+ # Convert the predictions to a GeoDataFrame
228
+ gdf = crop_predictions_to_gdf(field_ids_data, label_data, majority_tier3_predictions, features_transform, features_crs, class_names)
229
+
230
+ # Simple GeoDataFrame with only the necessary columns
231
+ gdf = gdf[['prediction_class', 'target_class', 'correct', 'geometry']]
232
+ gdf.columns = ['Prediction', 'Target', 'Correct', 'geometry']
233
+ # gdf = gdf[gdf['Target'] != 'Background']
234
+
235
+ return gdf
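
For orientation, here is a minimal usage sketch of the `InferenceDataLoader` defined above: it pulls one 224×224 window around a coordinate and prints the tensor shapes that would be fed to the model. The paths mirror the hard-coded ones in `perform_inference`; the coordinates are arbitrary example values and must fall inside the rasters (both are assumptions).

```python
# Hedged usage sketch for inference.py: fetch one window and inspect its shapes.
from inference import InferenceDataLoader

loader = InferenceDataLoader(
    features_path="./data/stacked_features.tif",
    labels_path="./data/labels.tif",
    field_ids_path="./data/field_ids.tif",
    stats_path="./data/chips_stats.yaml",
    n_timesteps=9,       # matches perform_inference above
    fold_indices=[0],
    debug=True,
)

# Example WGS84 coordinates (lon, lat order), chosen arbitrarily for illustration.
features, labels, field_ids, transform, crs = loader.get_data(lon=8.54, lat=47.37)
print(features.shape)   # (1, 6, 9, 224, 224): batch, bands, timesteps, height, width
print(labels.shape)     # (label bands / tiers, 224, 224)
print(field_ids.shape)  # (1, 224, 224)
```
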
main.py ADDED
@@ -0,0 +1,15 @@
1
+ import streamlit as st
2
+
3
+ def main():
4
+ st.set_page_config(layout="wide", page_title="Messis 🌾 - Crop Classification 🌎")
5
+
6
+ st.title("Messis 🌾 - Crop Classification 🌎")
7
+
8
+ st.write("Welcome to the Messis Crop Classification app. Use the sidebar to navigate between selecting coordinates and performing inference.")
9
+
10
+ st.page_link("main.py", label="Home", icon="🏠")
11
+ st.page_link("pages/1_Select_Location.py", label="Select Location", icon="📍")
12
+ st.page_link("pages/2_Perform_Crop_Classification.py", label="Perform Crop Classification", icon="🔍")
13
+
14
+ if __name__ == "__main__":
15
+ main()
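
Note that this is the entry point of a Streamlit multipage app: the `st.page_link` calls above point into a `pages/` directory, and the app is launched with `streamlit run main.py`.
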
messis/README.md ADDED
@@ -0,0 +1,7 @@
1
+ # About
2
+
3
+ This package contains the code for the crop classification model Messis.
4
+
5
+ The model is available on Hugging Face at [crop-classification/messis](https://huggingface.co/crop-classification/messis).
6
+
7
+ TODO: Add more information about the model.
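
Since `Messis` (defined in `messis/messis.py`) mixes in `PyTorchModelHubMixin`, the checkpoint linked above can be pulled straight from the Hub. A minimal loading sketch, assuming the linked repo holds mixin-compatible weights and config:

```python
# Hedged loading sketch; assumes the Hub repo above was pushed via PyTorchModelHubMixin.
from messis.messis import Messis

model = Messis.from_pretrained("crop-classification/messis")  # from_pretrained() is provided by the mixin
model.eval()  # ready to be passed to perform_inference() in inference.py
```
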
messis/__init__.py ADDED
File without changes
messis/dataloader.py ADDED
@@ -0,0 +1,287 @@
1
+ import random
2
+ import torch
3
+ from torch.utils.data import Dataset, DataLoader
4
+ from torchvision import transforms
5
+ from pytorch_lightning import LightningDataModule
6
+ import os
7
+ import re
8
+ import yaml
9
+ import rasterio
10
+ import dvc.api
11
+
12
+
13
+ params = dvc.api.params_show()
14
+ N_TIMESTEPS = params['number_of_timesteps']
15
+
16
+ class ToTensorTransform(object):
17
+ def __init__(self, dtype):
18
+ self.dtype = dtype
19
+
20
+ def __call__(self, data):
21
+ return torch.tensor(data, dtype=self.dtype)
22
+
23
+ class NormalizeTransform(object):
24
+ def __init__(self, means, stds):
25
+ self.mean = means
26
+ self.std = stds
27
+
28
+ def __call__(self, data):
29
+ return transforms.Normalize(self.mean, self.std)(data)
30
+
31
+ class PermuteTransform:
32
+ def __call__(self, data):
33
+ height, width = data.shape[-2:]
34
+
35
+ # Ensure the channel dimension is as expected
36
+ if data.shape[0] != N_TIMESTEPS * 6:
37
+ raise ValueError(f"Expected {N_TIMESTEPS*6} channels, got {data.shape[0]}")
38
+
39
+ # Step 1: Reshape the data to group the N_TIMESTEPS*6 bands into N_TIMESTEPS groups of 6 bands
40
+ data = data.view(N_TIMESTEPS, 6, height, width)
41
+
42
+ # Step 2: Permute to bring the bands to the front
43
+ data = data.permute(1, 0, 2, 3) # NOTE: Prithvi wants it bands first # after this, shape is (6, N_TIMESTEPS, height, width)
44
+ return data
45
+
46
+ class RandomFlipAndJitterTransform:
47
+ """
48
+ Apply random horizontal and vertical flips, and channel jitter to the input image and corresponding mask.
49
+
50
+ Parameters:
51
+ -----------
52
+ flip_prob : float, optional (default=0.5)
53
+ Probability of applying horizontal and vertical flips to the image and mask.
54
+ Each flip (horizontal and vertical) is applied independently based on this probability.
55
+
56
+ jitter_std : float, optional (default=0.02)
57
+ Standard deviation of the Gaussian noise added to the image channels for jitter.
58
+ This value controls the intensity of the random noise applied to the image channels.
59
+
60
+ Effects of Parameters:
61
+ ----------------------
62
+ flip_prob:
63
+ - Higher flip_prob increases the likelihood of the image and mask being flipped.
64
+ - A value of 0 means no flipping, while a value of 1 means always flip.
65
+
66
+ jitter_std:
67
+ - Higher jitter_std increases the intensity of the noise added to the image channels.
68
+ - A value of 0 means no noise, while larger values add more significant noise.
69
+ """
70
+ def __init__(self, flip_prob=0.5, jitter_std=0.02):
71
+ self.flip_prob = flip_prob
72
+ self.jitter_std = jitter_std
73
+
74
+ def __call__(self, img, mask, field_ids):
75
+ # Shapes (..., H, W)| img: torch.Size([6, N_TIMESTEPS, 224, 224]), mask: torch.Size([N_TIMESTEPS, 224, 224]), field_ids: torch.Size([1, 224, 224])
76
+
77
+ # Temporarily convert field_ids to int32 for flipping (flip not implemented for uint16)
78
+ field_ids = field_ids.to(torch.int32)
79
+
80
+ # Random horizontal flip
81
+ if random.random() < self.flip_prob:
82
+ img = torch.flip(img, [2])
83
+ mask = torch.flip(mask, [1])
84
+ field_ids = torch.flip(field_ids, [1])
85
+
86
+ # Random vertical flip
87
+ if random.random() < self.flip_prob:
88
+ img = torch.flip(img, [3])
89
+ mask = torch.flip(mask, [2])
90
+ field_ids = torch.flip(field_ids, [2])
91
+
92
+ # Convert field_ids back to uint16
93
+ field_ids = field_ids.to(torch.uint16)
94
+
95
+ # Channel jitter
96
+ noise = torch.randn(img.size()) * self.jitter_std
97
+ img += noise
98
+
99
+ return img, mask, field_ids
100
+
101
+ def get_img_transforms():
102
+ return transforms.Compose([])
103
+
104
+ def get_mask_transforms():
105
+ return transforms.Compose([])
106
+
107
+ class GeospatialDataset(Dataset):
108
+ def __init__(self, data_dir, fold_indicies, transform_img=None, transform_mask=None, transform_field_ids=None, debug=False, subset_size=None, data_augmentation=None):
109
+ self.data_dir = data_dir
110
+ self.chips_dir = os.path.join(data_dir, 'chips')
111
+ self.transform_img = transform_img
112
+ self.transform_mask = transform_mask
113
+ self.transform_field_ids = transform_field_ids
114
+ self.debug = debug
115
+ self.images = []
116
+ self.masks = []
117
+ self.field_ids = []
118
+ self.data_augmentation = data_augmentation
119
+
120
+ self.means, self.stds = self.load_stats(fold_indicies, N_TIMESTEPS)
121
+ self.transform_img_load = self.get_img_load_transforms(self.means, self.stds)
122
+ self.transform_mask_load = self.get_mask_load_transforms()
123
+ self.transform_field_ids_load = self.get_field_ids_load_transforms()
124
+
125
+ # Adjust file selection based on fold
126
+ for file in os.listdir(self.chips_dir):
127
+ if re.match(f".*_fold_[{''.join([str(f) for f in fold_indicies])}]_merged.tif", file):
128
+ self.images.append(file)
129
+ mask_file = file.replace("_merged.tif", "_mask.tif")
130
+ self.masks.append(mask_file)
131
+ field_ids_file = file.replace("_merged.tif", "_field_ids.tif")
132
+ self.field_ids.append(field_ids_file)
133
+
134
+ assert len(self.images) == len(self.masks), "Number of images and masks do not match"
135
+
136
+ # If subset_size is specified, randomly select a subset of the data
137
+ if subset_size is not None and len(self.images) > subset_size:
138
+ print(f"Randomly selecting {subset_size} samples from {len(self.images)} samples.")
139
+ selected_indices = random.sample(range(len(self.images)), subset_size)
140
+ self.images = [self.images[i] for i in selected_indices]
141
+ self.masks = [self.masks[i] for i in selected_indices]
142
+ self.field_ids = [self.field_ids[i] for i in selected_indices]
143
+
144
+ def load_stats(self, fold_indicies, n_timesteps=3):
145
+ """Load normalization statistics for dataset from YAML file."""
146
+ stats_path = os.path.join(self.data_dir, 'chips_stats.yaml')
147
+ if self.debug:
148
+ print(f"Loading mean/std stats from {stats_path}")
149
+ assert os.path.exists(stats_path), f"mean/std stats file for dataset not found at {stats_path}"
150
+ with open(stats_path, 'r') as file:
151
+ stats = yaml.safe_load(file)
152
+ mean_list, std_list, n_list = [], [], []
153
+ for fold in fold_indicies:
154
+ key = f'fold_{fold}'
155
+ if key not in stats:
156
+ raise ValueError(f"mean/std stats for fold {fold} not found in {stats_path}")
157
+ if self.debug:
158
+ print(f"Stats with selected test fold {fold}: {stats[key]} over {n_timesteps} timesteps.")
159
+ mean_list.append(torch.Tensor(stats[key]['mean'])) # list of 6 means
160
+ std_list.append(torch.Tensor(stats[key]['std'])) # list of 6 stds
161
+ n_list.append(stats[key]['n_chips']) # list of 6 ns
162
+ # aggregate means and stds over all folds
163
+ means, stds = [], []
164
+ for channel in range(mean_list[0].shape[0]):
165
+ means.append(torch.stack([mean_list[i][channel] for i in range(len(mean_list))]).mean())
166
+ # stds cannot simply be averaged; combine them as a pooled standard deviation:
167
+ # \sqrt{\frac{\sum_{i=1}^{k} \sigma_i^2 (n_i - 1)}{\sum_{i=1}^{k} n_i - k}}, where k is the number of folds
168
+ variances = torch.stack([std_list[i][channel] ** 2 for i in range(len(std_list))])
169
+ n = torch.tensor([n_list[i] for i in range(len(n_list))], dtype=torch.float32)
170
+ combined_variance = torch.sum(variances * (n - 1)) / (torch.sum(n) - len(n_list))
171
+ stds.append(torch.sqrt(combined_variance))
172
+
173
+ # make means and stds into 2d arrays, as torchvision would otherwise convert it into a 3d tensor which is incompatible with our 4d temporal images
174
+ # https://github.com/pytorch/vision/blob/6e18cea3485066b7277785415bf2e0422dbdb9da/torchvision/transforms/_functional_tensor.py#L923
175
+ return means * n_timesteps, stds * n_timesteps
176
+
177
+ def get_img_load_transforms(self, means, stds):
178
+ return transforms.Compose([
179
+ ToTensorTransform(torch.float32),
180
+ NormalizeTransform(means, stds),
181
+ PermuteTransform()
182
+ ])
183
+
184
+ def get_mask_load_transforms(self):
185
+ return transforms.Compose([
186
+ ToTensorTransform(torch.uint8)
187
+ ])
188
+
189
+ def get_field_ids_load_transforms(self):
190
+ return transforms.Compose([
191
+ ToTensorTransform(torch.uint16)
192
+ ])
193
+
194
+ def __len__(self):
195
+ return len(self.images)
196
+
197
+ def __getitem__(self, idx):
198
+ img_path = os.path.join(self.chips_dir, self.images[idx])
199
+ mask_path = os.path.join(self.chips_dir, self.masks[idx])
200
+ field_ids_path = os.path.join(self.chips_dir, self.field_ids[idx])
201
+
202
+ img = rasterio.open(img_path).read().astype('uint16')
203
+ mask = rasterio.open(mask_path).read().astype('uint8')
204
+ field_ids = rasterio.open(field_ids_path).read().astype('uint16')
205
+
206
+ # Apply our base transforms
207
+ img = self.transform_img_load(img)
208
+ mask = self.transform_mask_load(mask)
209
+ field_ids = self.transform_field_ids_load(field_ids)
210
+
211
+ # Apply additional transforms passed from GeospatialDataModule if applicable
212
+ if self.transform_img is not None:
213
+ img = self.transform_img(img)
214
+ if self.transform_mask is not None:
215
+ mask = self.transform_mask(mask)
216
+ if self.transform_field_ids is not None:
217
+ field_ids = self.transform_field_ids(field_ids)
218
+
219
+ # Apply data augmentation if enabled
220
+ if self.data_augmentation is not None and self.data_augmentation.get('enabled', True):
221
+ img, mask, field_ids = RandomFlipAndJitterTransform(
222
+ flip_prob=self.data_augmentation.get('flip_prob', 0.5),
223
+ jitter_std=self.data_augmentation.get('jitter_std', 0.02)
224
+ )(img, mask, field_ids)
225
+
226
+ # Load targets for given tiers
227
+ num_tiers = mask.shape[0]
228
+ targets = ()
229
+ for i in range(num_tiers):
230
+ targets += (mask[i, :, :].type(torch.long),)
231
+
232
+ return img, (targets, field_ids)
233
+
234
+ class GeospatialDataModule(LightningDataModule):
235
+ def __init__(self, data_dir, train_folds, val_folds, test_folds, batch_size=8, num_workers=4, debug=False, subsets=None, data_augmentation=None):
236
+ super().__init__()
237
+ self.data_dir = data_dir
238
+ self.batch_size = batch_size
239
+ self.num_workers = num_workers
240
+ self.debug = debug
241
+ self.subsets = subsets if subsets is not None else {}
242
+ self.data_augmentation = data_augmentation if data_augmentation is not None else {}
243
+
244
+ GeospatialDataModule.validate_folds(train_folds, val_folds, test_folds)
245
+ self.train_folds = train_folds
246
+ self.val_folds = val_folds
247
+ self.test_folds = test_folds
248
+
249
+ # NOTE: Transforms on this level not used for now
250
+ self.transform_img = get_img_transforms()
251
+ self.transform_mask = get_mask_transforms()
252
+
253
+ @staticmethod
254
+ def validate_folds(train, val, test):
255
+ if train is None or val is None or test is None:
256
+ raise ValueError("All fold sets must be specified")
257
+
258
+ if len(set(train) & set(val)) > 0 or len(set(train) & set(test)) > 0 or len(set(val) & set(test)) > 0:
259
+ raise ValueError("Folds must be mutually exclusive")
260
+
261
+ def setup(self, stage=None):
262
+ print(f"Setting up GeospatialDataModule for stage: {stage}. Data augmentation config: {self.data_augmentation}")
263
+ common_params = {
264
+ 'data_dir': self.data_dir,
265
+ 'debug': self.debug,
266
+ 'data_augmentation': self.data_augmentation
267
+ }
268
+ common_params_val_test = { # Never augment validation or test data
269
+ **common_params,
270
+ 'data_augmentation': {
271
+ 'enabled': False
272
+ }
273
+ }
274
+ if stage in ('fit', None):
275
+ self.train_dataset = GeospatialDataset(fold_indicies=self.train_folds, subset_size=self.subsets.get('train', None), **common_params)
276
+ self.val_dataset = GeospatialDataset(fold_indicies=self.val_folds, subset_size=self.subsets.get('val', None), **common_params_val_test)
277
+ if stage in ('test', None):
278
+ self.test_dataset = GeospatialDataset(fold_indicies=self.test_folds, subset_size=self.subsets.get('test', None), **common_params_val_test)
279
+
280
+ def train_dataloader(self):
281
+ return DataLoader(self.train_dataset, batch_size=self.batch_size, num_workers=self.num_workers, persistent_workers=True, shuffle=True)
282
+
283
+ def val_dataloader(self):
284
+ return DataLoader(self.val_dataset, batch_size=self.batch_size, num_workers=self.num_workers, persistent_workers=True)
285
+
286
+ def test_dataloader(self):
287
+ return DataLoader(self.test_dataset, batch_size=self.batch_size, num_workers=self.num_workers, persistent_workers=True)
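
The fold aggregation in `load_stats` above averages the per-fold channel means but pools the standard deviations using the combined-variance formula from the comment, sqrt( sum_i(std_i^2 * (n_i - 1)) / (sum_i(n_i) - k) ). A tiny numeric sketch of that aggregation, using made-up fold statistics:

```python
# Numeric sketch of the pooled-std aggregation in GeospatialDataset.load_stats();
# the per-fold statistics below are made-up illustration values.
import torch

fold_means = [torch.tensor([0.10, 0.20]), torch.tensor([0.30, 0.40])]  # per-fold channel means
fold_stds  = [torch.tensor([0.05, 0.07]), torch.tensor([0.06, 0.08])]  # per-fold channel stds
fold_n     = [100, 300]                                                 # chips per fold

means, stds = [], []
for c in range(fold_means[0].shape[0]):
    means.append(torch.stack([m[c] for m in fold_means]).mean())
    variances = torch.stack([s[c] ** 2 for s in fold_stds])
    n = torch.tensor(fold_n, dtype=torch.float32)
    pooled_var = torch.sum(variances * (n - 1)) / (torch.sum(n) - len(fold_n))
    stds.append(torch.sqrt(pooled_var))

print([round(m.item(), 4) for m in means])  # [0.2, 0.3] -- plain average of the fold means
print([round(s.item(), 4) for s in stds])   # pooled stds, weighted towards the larger fold
```
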
messis/messis.py ADDED
@@ -0,0 +1,919 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import pytorch_lightning as pl
4
+ from torchmetrics import classification
5
+ import wandb
6
+ from matplotlib import pyplot as plt
7
+ import numpy as np
8
+ import matplotlib.ticker as ticker
9
+ from matplotlib.colors import ListedColormap
10
+ from huggingface_hub import PyTorchModelHubMixin
11
+ from lion_pytorch import Lion
12
+
13
+ import json
14
+
15
+ from messis.prithvi import TemporalViTEncoder, ConvTransformerTokensToEmbeddingNeck, ConvTransformerTokensToEmbeddingBottleneckNeck
16
+
17
+
18
+ def safe_shape(x):
19
+ if isinstance(x, tuple):
20
+ # loop through tuple
21
+ shape_info = '(tuple) : '
22
+ for i in x:
23
+ shape_info += str(i.shape) + ', '
24
+ return shape_info
25
+ if isinstance(x, list):
26
+ # loop through list
27
+ shape_info = '(list) : '
28
+ for i in x:
29
+ shape_info += str(i.shape) + ', '
30
+ return shape_info
31
+ return x.shape
32
+
33
+ class ConvModule(nn.Module):
34
+ """
35
+ A simple convolutional module including Conv, BatchNorm, and ReLU layers.
36
+ """
37
+ def __init__(self, in_channels, out_channels, kernel_size, padding, dilation):
38
+ super(ConvModule, self).__init__()
39
+ self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=1, padding=padding, dilation=dilation, bias=False)
40
+ self.bn = nn.BatchNorm2d(out_channels)
41
+ self.relu = nn.ReLU(inplace=True)
42
+
43
+ def forward(self, x):
44
+ x = self.conv(x)
45
+ x = self.bn(x)
46
+ return self.relu(x)
47
+
48
+ class HierarchicalFCNHead(nn.Module):
49
+ """
50
+ Hierarchical FCN Head for semantic segmentation.
51
+ """
52
+ def __init__(self, in_channels, out_channels, num_classes, num_convs=2, kernel_size=3, dilation=1, dropout_p=0.1, debug=False):
53
+ super(HierarchicalFCNHead, self).__init__()
54
+
55
+ self.debug = debug
56
+
57
+ self.convs = nn.Sequential(*[
58
+ ConvModule(
59
+ in_channels if i == 0 else out_channels,
60
+ out_channels,
61
+ kernel_size,
62
+ padding=dilation * (kernel_size // 2),
63
+ dilation=dilation
64
+ ) for i in range(num_convs)
65
+ ])
66
+
67
+ self.conv_seg = nn.Conv2d(out_channels, num_classes, kernel_size=1)
68
+ self.dropout = nn.Dropout2d(p=dropout_p)
69
+
70
+ def forward(self, x):
71
+ if self.debug:
72
+ print('HierarchicalFCNHead forward INP: ', safe_shape(x))
73
+ x = self.convs(x)
74
+ features = self.dropout(x)
75
+ output = self.conv_seg(features)
76
+ if self.debug:
77
+ print('HierarchicalFCNHead forward features OUT: ', safe_shape(features))
78
+ print('HierarchicalFCNHead forward output OUT: ', safe_shape(output))
79
+ return output, features
80
+
81
+ class LabelRefinementHead(nn.Module):
82
+ """
83
+ Similar to the label refinement module introduced in the ZueriCrop paper, this module refines the predictions for tier 3.
84
+ It takes the raw predictions from head 1, head 2 and head 3 and refines them to produce the final prediction for tier 3.
85
+ According to ZueriCrop, this helps make the predictions more consistent across the different tiers.
86
+ """
87
+ def __init__(self, input_channels, num_classes):
88
+ super(LabelRefinementHead, self).__init__()
89
+
90
+ self.cnn_layers = nn.Sequential(
91
+ # 1x1 Convolutional layer
92
+ nn.Conv2d(in_channels=input_channels, out_channels=128, kernel_size=1, stride=1, padding=0),
93
+ nn.BatchNorm2d(128),
94
+ nn.ReLU(inplace=True),
95
+
96
+ # 3x3 Convolutional layer
97
+ nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),
98
+ nn.BatchNorm2d(128),
99
+ nn.ReLU(inplace=True),
100
+ nn.Dropout(p=0.5),
101
+
102
+ # Skip connection (implemented in forward method)
103
+
104
+ # Another 3x3 Convolutional layer
105
+ nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),
106
+ nn.BatchNorm2d(128),
107
+ nn.ReLU(inplace=True),
108
+
109
+ # 1x1 Convolutional layer to adjust the number of output channels to num_classes
110
+ nn.Conv2d(in_channels=128, out_channels=num_classes, kernel_size=1, stride=1, padding=0),
111
+ nn.Dropout(p=0.5)
112
+ )
113
+
114
+ def forward(self, x):
115
+ # Apply initial conv layer
116
+ y = self.cnn_layers[0:3](x)
117
+
118
+ # Save for skip connection
119
+ y_skip = y
120
+
121
+ # Apply the next two conv layers
122
+ y = self.cnn_layers[3:9](y)
123
+
124
+ # Skip connection (element-wise addition)
125
+ y = y + y_skip
126
+
127
+ # Apply the last conv layer
128
+ y = self.cnn_layers[9:](y)
129
+ return y
130
+
131
+ class HierarchicalClassifier(nn.Module):
132
+ def __init__(
133
+ self,
134
+ heads_spec,
135
+ dropout_p=0.1,
136
+ img_size=256,
137
+ patch_size=16,
138
+ num_frames=3,
139
+ bands=[0, 1, 2, 3, 4, 5],
140
+ backbone_weights_path=None,
141
+ freeze_backbone=True,
142
+ use_bottleneck_neck=False,
143
+ bottleneck_reduction_factor=4,
144
+ loss_ignore_background=False,
145
+ debug=False
146
+ ):
147
+ super(HierarchicalClassifier, self).__init__()
148
+
149
+ self.embed_dim = 768
150
+ if num_frames % 3 != 0:
151
+ raise ValueError(f"The number of frames must be a multiple of 3, but is {num_frames}")
152
+ self.num_frames = num_frames
153
+ self.hp, self.wp = img_size // patch_size, img_size // patch_size
154
+ self.heads_spec = heads_spec
155
+ self.dropout_p = dropout_p
156
+ self.loss_ignore_background = loss_ignore_background
157
+ self.debug = debug
158
+
159
+ if self.debug:
160
+ print('hp and wp: ', self.hp, self.wp)
161
+
162
+ self.prithvi = TemporalViTEncoder(
163
+ img_size=img_size,
164
+ patch_size=patch_size,
165
+ num_frames=3,
166
+ tubelet_size=1,
167
+ in_chans=len(bands),
168
+ embed_dim=self.embed_dim,
169
+ depth=12,
170
+ num_heads=8,
171
+ mlp_ratio=4.0,
172
+ norm_pix_loss=False,
173
+ pretrained=backbone_weights_path,
174
+ debug=self.debug
175
+ )
176
+
177
+ # (Un)freeze the backbone
178
+ for param in self.prithvi.parameters():
179
+ param.requires_grad = not freeze_backbone
180
+
181
+ # Neck to transform the token-based output of the transformer into a spatial feature map
182
+ number_of_necks = self.num_frames // 3
183
+ if use_bottleneck_neck:
184
+ self.necks = nn.ModuleList([ConvTransformerTokensToEmbeddingBottleneckNeck(
185
+ embed_dim=self.embed_dim * 3,
186
+ output_embed_dim=self.embed_dim * 3,
187
+ drop_cls_token=True,
188
+ Hp=self.hp,
189
+ Wp=self.wp,
190
+ bottleneck_reduction_factor=bottleneck_reduction_factor
191
+ ) for _ in range(number_of_necks)])
192
+ else:
193
+ self.necks = nn.ModuleList([ConvTransformerTokensToEmbeddingNeck(
194
+ embed_dim=self.embed_dim * 3,
195
+ output_embed_dim=self.embed_dim * 3,
196
+ drop_cls_token=True,
197
+ Hp=self.hp,
198
+ Wp=self.wp,
199
+ ) for _ in range(number_of_necks)])
200
+
201
+ # Initialize heads and loss weights based on tiers
202
+ self.heads = nn.ModuleDict()
203
+ self.loss_weights = {}
204
+ self.total_classes = 0
205
+
206
+ # Build HierarchicalFCNHeads
207
+ head_count = 0
208
+ for head_name, head_info in self.heads_spec.items():
209
+ head_type = head_info['type']
210
+ num_classes = head_info['num_classes_to_predict']
211
+ loss_weight = head_info['loss_weight']
212
+
213
+ if head_type == 'HierarchicalFCNHead':
214
+ num_classes = head_info['num_classes_to_predict']
215
+ loss_weight = head_info['loss_weight']
216
+ kernel_size = head_info.get('kernel_size', 3)
217
+ num_convs = head_info.get('num_convs', 1)
218
+ num_channels = head_info.get('num_channels', 256)
219
+ self.total_classes += num_classes
220
+
221
+ self.heads[head_name] = HierarchicalFCNHead(
222
+ in_channels=(self.embed_dim * self.num_frames) if head_count == 0 else num_channels,
223
+ out_channels=num_channels,
224
+ num_classes=num_classes,
225
+ num_convs=num_convs,
226
+ kernel_size=kernel_size,
227
+ dropout_p=self.dropout_p,
228
+ debug=self.debug
229
+ )
230
+ self.loss_weights[head_name] = loss_weight
231
+
232
+ # NOTE: LabelRefinementHead must be the last in the dict, otherwise the total_classes will be incorrect
233
+ if head_type == 'LabelRefinementHead':
234
+ self.refinement_head = LabelRefinementHead(input_channels=self.total_classes, num_classes=num_classes)
235
+ self.refinement_head_name = head_name
236
+ self.loss_weights[head_name] = loss_weight
237
+
238
+ head_count += 1
239
+
240
+ self.loss_func = nn.CrossEntropyLoss(ignore_index=-1)
241
+
242
+ def forward(self, x):
243
+ if self.debug:
244
+ print(f"Input shape: {safe_shape(x)}") # torch.Size([4, 6, 9, 224, 224])
245
+
246
+ # Extract features from the base model
247
+ if len(self.necks) == 1:
248
+ features = [x]
249
+ else:
250
+ features = torch.chunk(x, len(self.necks), dim=2)
251
+ features = [self.prithvi(x) for x in features]
252
+
253
+ if self.debug:
254
+ print(f"Features shape after base model: {', '.join([safe_shape(f) for f in features])}") # (tuple) : torch.Size([4, 589, 768]), , (tuple) : torch.Size
255
+
256
+ # Process through the neck
257
+ features = [neck(feat_) for feat_, neck in zip(features, self.necks)]
258
+
259
+ if self.debug:
260
+ print(f"Features shape after neck: {', '.join([safe_shape(f) for f in features])}") # (tuple) : torch.Size([4, 2304, 224, 224]), , (tuple) : torch.Size
261
+
262
+ # Remove from tuple
263
+ features = [feat[0] for feat in features]
264
+ # concatenate the per-neck feature maps along the channel dim, e.g. torch.Size([4, 6912, 224, 224])
265
+ features = torch.concatenate(features, dim=1)
266
+ if self.debug:
267
+ print(f"Features shape after removing tuple: {safe_shape(features)}") # torch.Size([4, 6912, 224, 224])
268
+
269
+ # Process through the heads
270
+ outputs = {}
271
+ for tier_name, head in self.heads.items():
272
+ output, features = head(features)
273
+ outputs[tier_name] = output
274
+
275
+ if self.debug:
276
+ print(f"Features shape after {tier_name} head: {safe_shape(features)}")
277
+ print(f"Output shape after {tier_name} head: {safe_shape(output)}")
278
+
279
+ # Process through the classification refinement head
280
+ output_concatenated = torch.cat(list(outputs.values()), dim=1)
281
+ output_refinement_head = self.refinement_head(output_concatenated)
282
+ outputs[self.refinement_head_name] = output_refinement_head
283
+
284
+ return outputs
285
+
286
+ def calculate_loss(self, outputs, targets):
287
+ total_loss = 0
288
+ loss_per_head = {}
289
+ for head_name, output in outputs.items():
290
+ if self.debug:
291
+ print(f"Target index for {head_name}: {self.heads_spec[head_name]['target_idx']}")
292
+ target = targets[self.heads_spec[head_name]['target_idx']]
293
+ loss_target = target
294
+ if self.loss_ignore_background:
295
+ loss_target = target.clone() # Clone as original target needed in backward pass
296
+ loss_target[loss_target == 0] = -1 # Set background class to ignore_index -1 for loss calculation
297
+ loss = self.loss_func(output, loss_target)
298
+ loss_per_head[f'{head_name}'] = loss
299
+ total_loss += loss * self.loss_weights[head_name]
300
+
301
+ return total_loss, loss_per_head
302
+
303
+ class Messis(pl.LightningModule, PyTorchModelHubMixin):
304
+ def __init__(self, hparams):
305
+ super().__init__()
306
+ self.save_hyperparameters(hparams)
307
+
308
+ self.model = HierarchicalClassifier(
309
+ heads_spec=hparams['heads_spec'],
310
+ dropout_p=hparams.get('dropout_p'),
311
+ img_size=hparams.get('img_size'),
312
+ patch_size=hparams.get('patch_size'),
313
+ num_frames=hparams.get('num_frames'),
314
+ bands=hparams.get('bands'),
315
+ backbone_weights_path=hparams.get('backbone_weights_path'),
316
+ freeze_backbone=hparams['freeze_backbone'],
317
+ use_bottleneck_neck=hparams.get('use_bottleneck_neck'),
318
+ bottleneck_reduction_factor=hparams.get('bottleneck_reduction_factor'),
319
+ loss_ignore_background=hparams.get('loss_ignore_background'),
320
+ debug=hparams.get('debug')
321
+ )
322
+
323
+ def forward(self, x):
324
+ return self.model(x)
325
+
326
+ def training_step(self, batch, batch_idx):
327
+ return self.__step(batch, batch_idx, "train")
328
+
329
+ def validation_step(self, batch, batch_idx):
330
+ return self.__step(batch, batch_idx, "val")
331
+
332
+ def test_step(self, batch, batch_idx):
333
+ return self.__step(batch, batch_idx, "test")
334
+
335
+ def configure_optimizers(self):
336
+ # select case on optimizer
337
+ match self.hparams.get('optimizer', 'Adam'):
338
+ case 'Adam':
339
+ optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.get('lr', 1e-3))
340
+ case 'AdamW':
341
+ optimizer = torch.optim.AdamW(self.parameters(), lr=self.hparams.get('lr', 1e-3), weight_decay=self.hparams.get('optimizer_weight_decay', 0.01))
342
+ case 'SGD':
343
+ optimizer = torch.optim.SGD(self.parameters(), lr=self.hparams.get('lr', 1e-3), momentum=self.hparams.get('optimizer_momentum', 0.9))
344
+ case 'Lion':
345
+ # https://github.com/lucidrains/lion-pytorch | Typically lr 3-10 times lower than Adam and weight_decay 3-10 times higher
346
+ optimizer = Lion(self.parameters(), lr=self.hparams.get('lr', 1e-4), weight_decay=self.hparams.get('optimizer_weight_decay', 0.1))
347
+ case _:
348
+ raise ValueError(f"Optimizer {self.hparams.get('optimizer')} not supported")
349
+ return optimizer
350
+
351
+ def __step(self, batch, batch_idx, stage):
352
+ inputs, targets = batch
353
+ targets = torch.stack(targets[0])
354
+ outputs = self(inputs)
355
+ loss, loss_per_head = self.model.calculate_loss(outputs, targets)
356
+ loss_per_head_named = {f'{stage}_loss_{head}': loss_per_head[head] for head in loss_per_head}
357
+ loss_proportions = { f'{stage}_loss_{head}_proportion': round(loss_per_head[head].item() / loss.item(), 2) for head in loss_per_head}
358
+ loss_detail_dict = {**loss_per_head_named, **loss_proportions}
359
+
360
+ if self.hparams.get('debug'):
361
+ print(f"Step Inputs shape: {safe_shape(inputs)}")
362
+ print(f"Step Targets shape: {safe_shape(targets)}")
363
+ print(f"Step Outputs dict keys: {outputs.keys()}")
364
+
365
+ # NOTE: All metrics other than loss are tracked by callbacks (LogMessisMetrics)
366
+ self.log_dict({f'{stage}_loss': loss, **loss_detail_dict}, on_step=True, on_epoch=True, prog_bar=True, logger=True)
367
+ return {'loss': loss, 'outputs': outputs}
368
+
369
+ class LogConfusionMatrix(pl.Callback):
370
+ def __init__(self, hparams, dataset_info_file, debug=False):
371
+ super().__init__()
372
+
373
+ assert hparams.get('heads_spec') is not None, "heads_spec must be defined in the hparams"
374
+ self.tiers_dict = {k: v for k, v in hparams.get('heads_spec').items() if v.get('is_metrics_tier', False)}
375
+ self.last_tier_name = next((k for k, v in hparams.get('heads_spec').items() if v.get('is_last_tier', False)), None)
376
+ self.final_head_name = next((k for k, v in hparams.get('heads_spec').items() if v.get('is_final_head', False)), None)
377
+
378
+ assert self.last_tier_name is not None, "No tier found with 'is_last_tier' set to True"
379
+ assert self.final_head_name is not None, "No head found with 'is_final_head' set to True"
380
+
381
+ self.tiers = list(self.tiers_dict.keys())
382
+ self.phases = ['train', 'val', 'test']
383
+ self.modes = ['pixelwise', 'majority']
384
+ self.debug = debug
385
+
386
+ if debug:
387
+ print(f"Final head identified as: {self.final_head_name}")
388
+ print(f"LogConfusionMatrix Metrics over | Phases: {self.phases}, Tiers: {self.tiers}, Modes: {self.modes}")
389
+
390
+ with open(dataset_info_file, 'r') as f:
391
+ self.dataset_info = json.load(f)
392
+
393
+ # Initialize confusion matrices
394
+ self.metrics_to_compute = ['confusion_matrix']
395
+ self.metrics = {phase: {tier: {mode: self.__init_metrics(tier, phase) for mode in self.modes} for tier in self.tiers} for phase in self.phases}
396
+
397
+ def __init_metrics(self, tier, phase):
398
+ num_classes = self.tiers_dict[tier]['num_classes_to_predict']
399
+ confusion_matrix = classification.MulticlassConfusionMatrix(num_classes=num_classes)
400
+
401
+ return {
402
+ 'confusion_matrix': confusion_matrix
403
+ }
404
+
405
+ def setup(self, trainer, pl_module, stage=None):
406
+ # Move all metrics to the correct device at the start of the training/validation
407
+ device = pl_module.device
408
+ for phase_metrics in self.metrics.values():
409
+ for tier_metrics in phase_metrics.values():
410
+ for mode_metrics in tier_metrics.values():
411
+ for metric in self.metrics_to_compute:
412
+ mode_metrics[metric].to(device)
413
+
414
+ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
415
+ self.__update_confusion_matrices(trainer, pl_module, outputs, batch, batch_idx, 'train')
416
+
417
+ def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
418
+ self.__update_confusion_matrices(trainer, pl_module, outputs, batch, batch_idx, 'val')
419
+
420
+ def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
421
+ self.__update_confusion_matrices(trainer, pl_module, outputs, batch, batch_idx, 'test')
422
+
423
+ def __update_confusion_matrices(self, trainer, pl_module, outputs, batch, batch_idx, phase):
424
+ if trainer.sanity_checking:
425
+ return
426
+
427
+ targets = torch.stack(batch[1][0]) # (tiers, batch, H, W)
428
+ outputs = outputs['outputs'][self.final_head_name] # (batch, C, H, W)
429
+ field_ids = batch[1][1].permute(1, 0, 2, 3)[0]
430
+
431
+ pixelwise_outputs, majority_outputs = LogConfusionMatrix.get_pixelwise_and_majority_outputs(outputs, self.tiers, field_ids, self.dataset_info)
432
+
433
+ for preds, mode in zip([pixelwise_outputs, majority_outputs], self.modes):
434
+ # Update all metrics
435
+ assert len(preds) == len(targets), f"Number of predictions and targets do not match: {len(preds)} vs {len(targets)}"
436
+ assert len(preds) == len(self.tiers), f"Number of predictions and tiers do not match: {len(preds)} vs {len(self.tiers)}"
437
+
438
+ for pred, target, tier in zip(preds, targets, self.tiers):
439
+ if self.debug:
440
+ print(f"Updating confusion matrix for {phase} {tier} {mode}")
441
+ metrics = self.metrics[phase][tier][mode]
442
+ # flatten and remove background class if the mode is majority (such that the background class is not included in the confusion matrix)
443
+ if mode == 'majority':
444
+ pred = pred[target != 0]
445
+ target = target[target != 0]
446
+ metrics['confusion_matrix'].update(pred, target)
447
+
448
+
449
+ @staticmethod
450
+ def get_pixelwise_and_majority_outputs(refinement_head_outputs, tiers, field_ids, dataset_info):
451
+ """
452
+ Get the pixelwise and majority predictions from the model outputs.
453
+ The pixelwise tier predictions are derived from the refinement_head_outputs predictions.
454
+ The majority last tier predictions are derived from the refinement_head_outputs. And then the majority lower-tier predictions are derived from the majority highest-tier predictions.
455
+
456
+ Also sets the background to 0 in all field-majority predictions (regardless of what the model predicts for background pixels),
457
+ since this is a field classification task rather than a segmentation task: the field boundaries are known beforehand and are not of interest.
458
+
459
+ Args:
460
+ refinement_head_outputs (torch.Tensor(batch, C, H, W)): The probability outputs from the model for the refined tier.
461
+ tiers (list of str): List of tiers e.g. ['tier1', 'tier2', 'tier3'].
462
+ field_ids (torch.Tensor(batch, H, W)): The field IDs for each prediction.
463
+ dataset_info (dict): The dataset information.
464
+
465
+ Returns:
466
+ torch.Tensor(tiers, batch, H, W): The pixelwise predictions.
467
+ torch.Tensor(tiers, batch, H, W): The majority predictions.
468
+ """
469
+
470
+ # Assuming the highest tier is the last one in the list
471
+ highest_tier = tiers[-1]
472
+
473
+ pixelwise_highest_tier = torch.softmax(refinement_head_outputs, dim=1).argmax(dim=1) # (batch, H, W)
474
+ majority_highest_tier = LogConfusionMatrix.get_field_majority_preds(refinement_head_outputs, field_ids)
475
+
476
+ tier_mapping = {tier: dataset_info[f'{highest_tier}_to_{tier}'] for tier in tiers if tier != highest_tier}
477
+
478
+ pixelwise_outputs = {highest_tier: pixelwise_highest_tier}
479
+ majority_outputs = {highest_tier: majority_highest_tier}
480
+
481
+ # Initialize pixelwise and majority outputs for each tier
482
+ for tier in tiers:
483
+ if tier != highest_tier:
484
+ pixelwise_outputs[tier] = torch.zeros_like(pixelwise_highest_tier)
485
+ majority_outputs[tier] = torch.zeros_like(majority_highest_tier)
486
+
487
+ # Map the highest tier to lower tiers
488
+ for i, mappings in enumerate(zip(*tier_mapping.values())):
489
+ for j, tier in enumerate(tier_mapping.keys()):
490
+ pixelwise_outputs[tier][pixelwise_highest_tier == i] = mappings[j]
491
+ majority_outputs[tier][majority_highest_tier == i] = mappings[j]
492
+
493
+ pixelwise_outputs_stacked = torch.stack([pixelwise_outputs[tier] for tier in tiers])
494
+ majority_outputs_stacked = torch.stack([majority_outputs[tier] for tier in tiers])
495
+
496
+ # Ensure these are tensors
497
+ assert isinstance(pixelwise_outputs_stacked, torch.Tensor), "pixelwise_outputs_stacked is not a tensor"
498
+ assert isinstance(majority_outputs_stacked, torch.Tensor), "majority_outputs_stacked is not a tensor"
499
+
500
+ return pixelwise_outputs_stacked, majority_outputs_stacked
501
+
502
+
503
+ @staticmethod
504
+ def get_field_majority_preds(output, field_ids):
505
+ """
506
+ Get the majority prediction for each field in the batch. The majority excludes the background class.
507
+
508
+ Args:
509
+ output (torch.Tensor(batch, C, H, W)): The probability outputs from the model (tier3_refined)
510
+ field_ids (torch.Tensor(batch, H, W)): The field IDs for each prediction.
511
+
512
+ Returns:
513
+ torch.Tensor(batch, H, W): The majority predictions.
514
+ """
515
+ # remove the background class
516
+ pixelwise = torch.softmax(output[:, 1:, :, :], dim=1).argmax(dim=1) + 1 # (batch, H, W)
517
+ majority_preds = torch.zeros_like(pixelwise)
518
+ for batch in range(len(pixelwise)):
519
+ field_ids_batch = field_ids[batch]
520
+ for field_id in np.unique(field_ids_batch.cpu().numpy()):
521
+ if field_id == 0:
522
+ continue
523
+ field_mask = field_ids_batch == field_id
524
+ flattened_pred = pixelwise[batch][field_mask].view(-1) # Flatten the prediction
525
+ flattened_pred = flattened_pred[flattened_pred != 0] # Exclude background class
526
+ if len(flattened_pred) == 0:
527
+ continue
528
+ mode_pred, _ = torch.mode(flattened_pred) # Compute mode prediction
529
+ majority_preds[batch][field_mask] = mode_pred.item()
530
+ return majority_preds
531
+
532
+ def on_train_epoch_end(self, trainer, pl_module):
533
+ # Log and then reset the confusion matrices after training epoch
534
+ self.__log_and_reset_confusion_matrices(trainer, pl_module, 'train')
535
+
536
+ def on_validation_epoch_end(self, trainer, pl_module):
537
+ # Log and then reset the confusion matrices after validation epoch
538
+ self.__log_and_reset_confusion_matrices(trainer, pl_module, 'val')
539
+
540
+ def on_test_epoch_end(self, trainer, pl_module):
541
+ # Log and then reset the confusion matrices after test epoch
542
+ self.__log_and_reset_confusion_matrices(trainer, pl_module, 'test')
543
+
544
+ def __log_and_reset_confusion_matrices(self, trainer, pl_module, phase):
545
+ if trainer.sanity_checking:
546
+ return
547
+
548
+ for tier in self.tiers:
549
+ for mode in self.modes:
550
+ metrics = self.metrics[phase][tier][mode]
551
+ confusion_matrix = metrics['confusion_matrix']
552
+ if self.debug:
553
+ print(f"Logging and resetting confusion matrix for {phase} {tier} Update count: {confusion_matrix._update_count}")
554
+ matrix = confusion_matrix.compute() # columns are predictions and rows are targets
555
+
556
+ # Calculate percentages
557
+ matrix = matrix.float()
558
+ row_sums = matrix.sum(dim=1, keepdim=True)
559
+ matrix_percent = matrix / row_sums
560
+
561
+ # Ensure percentages sum to 1 for each row or handle NaNs
562
+ row_sum_check = matrix_percent.sum(dim=1)
563
+ valid_rows = ~torch.isnan(row_sum_check)
564
+ if valid_rows.any():
565
+ assert torch.allclose(row_sum_check[valid_rows], torch.ones_like(row_sum_check[valid_rows]), atol=1e-2), "Percentages do not sum to 1 for some valid rows"
566
+
567
+ # Sort the matrix and labels by the total number of instances
568
+ sorted_indices = row_sums.squeeze().argsort(descending=True)
569
+ matrix_percent = matrix_percent[sorted_indices, :] # sort rows
570
+ matrix_percent = matrix_percent[:, sorted_indices] # sort columns
571
+ class_labels = [self.dataset_info[tier][i] for i in sorted_indices]
572
+ row_sums_sorted = row_sums[sorted_indices]
573
+
574
+ # Check for zero rows after sorting
575
+ zero_rows = (row_sums_sorted == 0).squeeze()
576
+
577
+ fig, ax = plt.subplots(figsize=(matrix.size(0), matrix.size(0)), dpi=140)
578
+
579
+ ax.matshow(matrix_percent.cpu().numpy(), cmap='viridis')
580
+
581
+ ax.xaxis.set_major_locator(ticker.FixedLocator(range(matrix.size(1) + 1)))
582
+ ax.yaxis.set_major_locator(ticker.FixedLocator(range(matrix.size(0) + 1)))
583
+
584
+ ax.set_xticklabels(class_labels + [''], rotation=45)
585
+ ax.set_yticklabels(class_labels + [''])
586
+
587
+ # Add total number of instances to the y-axis labels
588
+ y_labels = [f'{class_labels[i]} [n={int(row_sums_sorted[i].item()):,.0f}]'.replace(',', "'") for i in range(matrix.size(0))]
589
+ ax.set_yticklabels(y_labels + [''])
590
+
591
+ ax.set_xlabel('Predictions')
592
+ ax.set_ylabel('Targets')
593
+
594
+ # Move x-axis label and ticks to the top
595
+ ax.xaxis.set_label_position('top')
596
+ ax.xaxis.set_ticks_position('top')
597
+
598
+ fig.tight_layout()
599
+
600
+ for i in range(matrix.size(0)):
601
+ for j in range(matrix.size(1)):
602
+ if zero_rows[i]:
603
+ ax.text(j, i, 'N/A', ha='center', va='center', color='black')
604
+ else:
605
+ ax.text(j, i, f'{matrix_percent[i, j]:.2f}', ha='center', va='center', color='#F88379', weight='bold') # coral red
606
+ trainer.logger.experiment.log({f"{phase}_{tier}_confusion_matrix_{mode}": wandb.Image(fig)})
607
+ plt.close()
608
+ confusion_matrix.reset()
609
+
610
+ class LogMessisMetrics(pl.Callback):
611
+ def __init__(self, hparams, dataset_info_file, debug=False):
612
+ super().__init__()
613
+
614
+ assert hparams.get('heads_spec') is not None, "heads_spec must be defined in the hparams"
615
+ self.tiers_dict = {k: v for k, v in hparams.get('heads_spec').items() if v.get('is_metrics_tier', False)}
616
+ self.last_tier_name = next((k for k, v in hparams.get('heads_spec').items() if v.get('is_last_tier', False)), None)
617
+ self.final_head_name = next((k for k, v in hparams.get('heads_spec').items() if v.get('is_final_head', False)), None)
618
+
619
+ assert self.last_tier_name is not None, "No tier found with 'is_last_tier' set to True"
620
+ assert self.final_head_name is not None, "No head found with 'is_final_head' set to True"
621
+
622
+ self.tiers = list(self.tiers_dict.keys())
623
+ self.phases = ['train', 'val', 'test']
624
+ self.modes = ['pixelwise', 'majority']
625
+ self.debug = debug
626
+
627
+ if debug:
628
+ print(f"Last tier identified as: {self.last_tier_name}")
629
+ print(f"Final head identified as: {self.final_head_name}")
630
+ print(f"LogMessisMetrics Metrics over | Phases: {self.phases}, Tiers: {self.tiers}, Modes: {self.modes}")
631
+
632
+ with open(dataset_info_file, 'r') as f:
633
+ self.dataset_info = json.load(f)
634
+
635
+ # Initialize metrics
636
+ self.metrics_to_compute = ['accuracy', 'weighted_accuracy', 'precision', 'weighted_precision', 'recall', 'weighted_recall' ,'f1', 'weighted_f1', 'cohen_kappa']
637
+ self.metrics = {phase: {tier: {mode: self.__init_metrics(tier, phase) for mode in self.modes} for tier in self.tiers} for phase in self.phases}
638
+ self.images_to_log = {phase: {mode: None for mode in self.modes} for phase in self.phases}
639
+ self.images_to_log_targets = {phase: None for phase in self.phases}
640
+ self.field_ids_to_log_targets = {phase: None for phase in self.phases}
641
+ self.inputs_to_log = {phase: None for phase in self.phases}
642
+
643
+ def __init_metrics(self, tier, phase):
644
+ num_classes = self.tiers_dict[tier]['num_classes_to_predict']
645
+
646
+ accuracy = classification.MulticlassAccuracy(num_classes=num_classes, average='macro')
647
+ weighted_accuracy = classification.MulticlassAccuracy(num_classes=num_classes, average='weighted')
648
+ per_class_accuracies = {
649
+ class_index: classification.BinaryAccuracy() for class_index in range(num_classes)
650
+ }
651
+ precision = classification.MulticlassPrecision(num_classes=num_classes, average='macro')
652
+ weighted_precision = classification.MulticlassPrecision(num_classes=num_classes, average='weighted')
653
+ recall = classification.MulticlassRecall(num_classes=num_classes, average='macro')
654
+ weighted_recall = classification.MulticlassRecall(num_classes=num_classes, average='weighted')
655
+ f1 = classification.MulticlassF1Score(num_classes=num_classes, average='macro')
656
+ weighted_f1 = classification.MulticlassF1Score(num_classes=num_classes, average='weighted')
657
+ cohen_kappa = classification.MulticlassCohenKappa(num_classes=num_classes)
658
+
659
+ return {
660
+ 'accuracy': accuracy,
661
+ 'weighted_accuracy': weighted_accuracy,
662
+ 'per_class_accuracies': per_class_accuracies,
663
+ 'precision': precision,
664
+ 'weighted_precision': weighted_precision,
665
+ 'recall': recall,
666
+ 'weighted_recall': weighted_recall,
667
+ 'f1': f1,
668
+ 'weighted_f1': weighted_f1,
669
+ 'cohen_kappa': cohen_kappa
670
+ }
671
+
672
+ def setup(self, trainer, pl_module, stage=None):
673
+ # Move all metrics to the correct device at the start of the training/validation
674
+ device = pl_module.device
675
+ for phase_metrics in self.metrics.values():
676
+ for tier_metrics in phase_metrics.values():
677
+ for mode_metrics in tier_metrics.values():
678
+ for metric in self.metrics_to_compute:
679
+ mode_metrics[metric].to(device)
680
+ for class_accuracy in mode_metrics['per_class_accuracies'].values():
681
+ class_accuracy.to(device)
682
+
683
+ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
684
+ self.__on_batch_end(trainer, pl_module, outputs, batch, batch_idx, 'train')
685
+
686
+ def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
687
+ self.__on_batch_end(trainer, pl_module, outputs, batch, batch_idx, 'val')
688
+
689
+ def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
690
+ self.__on_batch_end(trainer, pl_module, outputs, batch, batch_idx, 'test')
691
+
692
+ def __on_batch_end(self, trainer: pl.Trainer, pl_module, outputs, batch, batch_idx, phase):
693
+ if trainer.sanity_checking:
694
+ return
695
+ if self.debug:
696
+ print(f"{phase} batch ended. Updating metrics...")
697
+
698
+ targets = torch.stack(batch[1][0]) # (tiers, batch, H, W)
699
+ outputs = outputs['outputs'][self.final_head_name] # (batch, C, H, W)
700
+ field_ids = batch[1][1].permute(1, 0, 2, 3)[0]
701
+
702
+ pixelwise_outputs, majority_outputs = LogConfusionMatrix.get_pixelwise_and_majority_outputs(outputs, self.tiers, field_ids, self.dataset_info)
703
+
704
+ for preds, mode in zip([pixelwise_outputs, majority_outputs], self.modes):
705
+
706
+ # Update all metrics
707
+ assert preds.shape == targets.shape, f"Shapes of predictions and targets do not match: {preds.shape} vs {targets.shape}"
708
+ assert preds.shape[0] == len(self.tiers), f"Number of tiers in predictions and tiers do not match: {preds.shape[0]} vs {len(self.tiers)}"
709
+
710
+ self.images_to_log[phase][mode] = preds[-1]
711
+
712
+ for pred, target, tier in zip(preds, targets, self.tiers):
713
+ # flatten and remove background class if the mode is majority (such that the background class is not considered in the metrics)
714
+ if mode == 'majority':
715
+ pred = pred[target != 0]
716
+ target = target[target != 0]
717
+ metrics = self.metrics[phase][tier][mode]
718
+ for metric in self.metrics_to_compute:
719
+ metrics[metric].update(pred, target)
720
+ if self.debug:
721
+ print(f"{phase} {tier} {mode} {metric} updated. Update count: {metrics[metric]._update_count}")
722
+ self.__update_per_class_metrics(pred, target, metrics['per_class_accuracies'])
723
+
724
+ self.images_to_log_targets[phase] = targets[-1]
725
+ self.field_ids_to_log_targets[phase] = field_ids
726
+ self.inputs_to_log[phase] = batch[0]
727
+
728
+ def __update_per_class_metrics(self, preds, targets, per_class_accuracies):
729
+ for class_index, class_accuracy in per_class_accuracies.items():
730
+ if not (targets == class_index).any():
731
+ continue
732
+
733
+ if class_index == 0:
734
+ # Mask out non-background elements for background class (0)
735
+ class_mask = targets != 0
736
+ else:
737
+ # Mask out background elements for other classes
738
+ class_mask = targets == 0
739
+
740
+ preds_fields = preds[~class_mask]
741
+ targets_fields = targets[~class_mask]
742
+
743
+ # Prepare for binary classification (needs to be float)
744
+ preds_class = (preds_fields == class_index).float()
745
+ targets_class = (targets_fields == class_index).float()
746
+
747
+ class_accuracy.update(preds_class, targets_class)
748
+
749
+ if self.debug:
750
+ print(f"Shape of preds_fields: {preds_fields.shape}")
751
+ print(f"Shape of targets_fields: {targets_fields.shape}")
752
+ print(f"Unique values in preds_fields: {torch.unique(preds_fields)}")
753
+ print(f"Unique values in targets_fields: {torch.unique(targets_fields)}")
754
+ print(f"Per-class metrics for class {class_index} updated. Update count: {per_class_accuracies[class_index]._update_count}")
755
+
756
+ def on_train_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
757
+ self.__on_epoch_end(trainer, pl_module, 'train')
758
+
759
+ def on_validation_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
760
+ self.__on_epoch_end(trainer, pl_module, 'val')
761
+
762
+ def on_test_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
763
+ self.__on_epoch_end(trainer, pl_module, 'test')
764
+
765
+ def __on_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule, phase):
766
+ if trainer.sanity_checking:
767
+ return # Skip during sanity check (avoid warning about metric compute being called before update)
768
+ for tier in self.tiers:
769
+ for mode in self.modes:
770
+ metrics = self.metrics[phase][tier][mode]
771
+
772
+ # Calculate and reset in tier: Accuracy, WeightedAccuracy, Precision, Recall, F1, Cohen's Kappa
773
+ metrics_dict = {metric: metrics[metric].compute() for metric in self.metrics_to_compute}
774
+ pl_module.log_dict({f"{phase}_{metric}_{tier}_{mode}": v for metric, v in metrics_dict.items()}, on_step=False, on_epoch=True)
775
+ for metric in self.metrics_to_compute:
776
+ metrics[metric].reset()
777
+
778
+ # Per-class metrics
779
+ # NOTE: Some literature reports "per class accuracy" but what they actually mean is "per class recall".
780
+ # Using the accuracy formula per class has no value in our imbalanced multi-class setting (TNs inflate the scores!)
781
+ # We calculate all 4 metrics. This allows us to calculate any macro/micro score later if needed.
782
+ class_metrics = []
783
+ class_names_mapping = self.dataset_info[tier.split('_')[0] if '_refined' in tier else tier]
784
+ for class_index, class_accuracy in metrics['per_class_accuracies'].items():
785
+ if class_accuracy._update_count == 0:
786
+ continue # Skip if no updates have been made
787
+ tp, tn, fp, fn = class_accuracy.tp, class_accuracy.tn, class_accuracy.fp, class_accuracy.fn
788
+ recall = (tp / (tp + fn)).item() if tp + fn > 0 else 0
789
+ precision = (tp / (tp + fp)).item() if tp + fp > 0 else 0
790
+ f1 = (2 * (precision * recall) / (precision + recall)) if precision + recall > 0 else 0
791
+ n_of_class = (tp + fn).item()
792
+ class_metrics.append([class_index, class_names_mapping[class_index], precision, recall, f1, class_accuracy.compute().item(), n_of_class])
793
+ class_accuracy.reset()
794
+ wandb_table = wandb.Table(data=class_metrics, columns=["Class Index", "Class Name", "Precision", "Recall", "F1", "Accuracy", "N"])
795
+ trainer.logger.experiment.log({f"{phase}_per_class_metrics_{tier}_{mode}": wandb_table})
796
+
797
+ # use the same n_classes for all images, such that they are comparable
798
+ n_classes = max([
799
+ torch.max(self.images_to_log_targets[phase]),
800
+ torch.max(self.images_to_log[phase]["majority"]),
801
+ torch.max(self.images_to_log[phase]["pixelwise"])
802
+ ])
803
+ images = [LogMessisMetrics.process_images(self.images_to_log[phase][mode], n_classes) for mode in self.modes]
804
+ images.append(LogMessisMetrics.create_positive_negative_image(self.images_to_log[phase]["majority"], self.images_to_log_targets[phase]))
805
+ images.append(LogMessisMetrics.process_images(self.images_to_log_targets[phase], n_classes))
806
+ images.append(LogMessisMetrics.process_images(self.field_ids_to_log_targets[phase].cpu()))
807
+
808
+ examples = []
809
+ for i in range(len(images[0])):
810
+ example = np.concatenate([img[i] for img in images], axis=0)
811
+ examples.append(wandb.Image(example, caption=f"From Top to Bottom: {self.modes[0]}, {self.modes[1]}, right/wrong classifications, target, fields"))
812
+
813
+ trainer.logger.experiment.log({f"{phase}_examples": examples})
814
+
815
+ # Log segmentation masks
816
+ batch_input_data = self.inputs_to_log[phase].cpu() # shape [BS, 6, N_TIMESTEPS, 224, 224]
817
+ ground_truth_masks = self.images_to_log_targets[phase].cpu().numpy()
818
+ pixel_wise_masks = self.images_to_log[phase]["pixelwise"].cpu().numpy()
819
+ field_majority_masks = self.images_to_log[phase]["majority"].cpu().numpy()
820
+ correctness_masks = self.create_positive_negative_segmentation_mask(field_majority_masks, ground_truth_masks)
821
+ class_labels = {idx: name for idx, name in enumerate(self.dataset_info[self.last_tier_name])}
822
+
823
+ segmentation_masks = []
824
+ for input_data, ground_truth_mask, pixel_wise_mask, field_majority_mask, correctness_mask in zip(batch_input_data, ground_truth_masks, pixel_wise_masks, field_majority_masks, correctness_masks):
825
+ middle_timestep_index = input_data.shape[1] // 2 # Get the middle timestamp index
826
+ gamma = 2.5 # Gamma for brightness adjustment
827
+ rgb_image = input_data[:3, middle_timestep_index, :, :].permute(1, 2, 0).numpy() # Shape [224, 224, 3]
828
+ rgb_image = (rgb_image - rgb_image.min()) / (rgb_image.max() - rgb_image.min())
829
+ rgb_image = np.power(rgb_image, 1.0 / gamma)
830
+ rgb_image = (rgb_image * 255).astype(np.uint8)
831
+
832
+ mask_img = wandb.Image(
833
+ rgb_image,
834
+ masks={
835
+ "predictions_pixel_wise": {"mask_data": pixel_wise_mask, "class_labels": class_labels},
836
+ "predictions_field_majority": {"mask_data": field_majority_mask, "class_labels": class_labels},
837
+ "ground_truth": {"mask_data": ground_truth_mask, "class_labels": class_labels},
838
+ "correctness": {"mask_data": correctness_mask, "class_labels": { 0: "Background", 1: "Wrong", 2: "Right" }},
839
+ },
840
+ )
841
+ segmentation_masks.append(mask_img)
842
+
843
+ trainer.logger.experiment.log({f"{phase}_segmentation_mask": segmentation_masks})
844
+
845
+ if self.debug:
846
+ print(f"{phase} epoch ended. Logging & resetting metrics...", trainer.sanity_checking)
847
+
848
+ @staticmethod
849
+ def create_positive_negative_segmentation_mask(field_majority_masks, ground_truth_masks):
850
+ """
851
+ Create a tensor that shows the positive and negative classifications of the model.
852
+
853
+ Args:
854
+ field_majority_masks (np.ndarray): The field majority masks generated by the model.
855
+ ground_truth_masks (np.ndarray): The ground truth masks.
856
+
857
+ Returns:
858
+ np.ndarray: An array with values:
859
+ - 0 where the target is 0,
860
+ - 2 where the prediction matches the target,
861
+ - 1 where the prediction does not match the target.
862
+ """
863
+ correctness_mask = np.zeros_like(ground_truth_masks, dtype=int)
864
+
865
+ matches = (field_majority_masks == ground_truth_masks) & (ground_truth_masks != 0)
866
+ correctness_mask[matches] = 2
867
+
868
+ mismatches = (field_majority_masks != ground_truth_masks) & (ground_truth_masks != 0)
869
+ correctness_mask[mismatches] = 1
870
+
871
+ return correctness_mask
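For illustration, a tiny example of the encoding this helper produces (the arrays are made up):

import numpy as np

pred   = np.array([[3, 5],
                   [0, 2]])
target = np.array([[3, 4],
                   [0, 2]])
mask = LogMessisMetrics.create_positive_negative_segmentation_mask(pred, target)
# mask -> [[2, 1],
#          [0, 2]]   (0 = background, 1 = wrong, 2 = correct)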
872
+
873
+ @staticmethod
874
+ def create_positive_negative_image(generated_images, target_images):
875
+ """
876
+ Create an image that shows the positive and negative classifications of the model.
877
+
878
+ Args:
879
+ generated_images (torch.Tensor): The images generated by the model.
880
+ target_images (torch.Tensor): The target images.
881
+
882
+ Returns:
883
+ list: A list of processed images.
884
+ """
885
+ classification_masks = generated_images == target_images
886
+ processed_imgs = []
887
+ for mask, target in zip(classification_masks, target_images):
888
+ # color the background white, right classifications green, wrong classifications red
889
+ colored_img = torch.zeros((mask.shape[0], mask.shape[1], 3), dtype=torch.uint8)
890
+ mask = mask.bool() # Convert to boolean tensor
891
+ colored_img[mask] = torch.tensor([0, 255, 0], dtype=torch.uint8)
892
+ colored_img[~mask] = torch.tensor([255, 0, 0], dtype=torch.uint8)
893
+ colored_img[target == 0] = torch.tensor([0, 0, 0], dtype=torch.uint8)
894
+ processed_imgs.append(colored_img.cpu())
895
+ return processed_imgs
896
+
897
+ @staticmethod
898
+ def process_images(imgs, max=None):
899
+ """
900
+ Process a batch of images to be logged on wandb.
901
+
902
+ Args:
903
+ imgs (torch.Tensor): A batch of images with shape (B, H, W) to be processed.
904
+ max (float, optional): The maximum value to normalize the images. Defaults to None. If None, the maximum value in the batch is used.
905
+ """
906
+ if max is None:
907
+ max = np.max(imgs.cpu().numpy())
908
+ normalized_img = imgs / max
909
+ processed_imgs = []
910
+ for img in normalized_img.cpu().numpy():
911
+ if max < 60:
912
+ cmap = ListedColormap(plt.get_cmap('tab20').colors + plt.get_cmap('tab20b').colors + plt.get_cmap('tab20c').colors)
913
+ else:
914
+ cmap = plt.get_cmap('viridis')
915
+ colored_img = cmap(img)
916
+ colored_img[img == 0] = [0, 0, 0, 1]
917
+ colored_img_uint8 = (colored_img[:, :, :3] * 255).astype(np.uint8)
918
+ processed_imgs.append(colored_img_uint8)
919
+ return processed_imgs
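As a usage sketch (not part of this commit): a callback like the one above is attached to a PyTorch Lightning Trainer together with a wandb logger. Only part of the constructor is visible in this hunk, so the argument names below are assumptions.

import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger

# Hypothetical wiring; constructor arguments, model and datamodule names are illustrative.
metrics_callback = LogMessisMetrics(hparams=hparams, dataset_info_file="data/dataset_info.json", debug=False)
trainer = pl.Trainer(
    max_epochs=100,
    logger=WandbLogger(project="messis"),
    callbacks=[metrics_callback],
)
trainer.fit(model, datamodule=datamodule)  # model: Messis LightningModule, datamodule: the ZüriCrop data module (assumed)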
messis/prithvi.py ADDED
@@ -0,0 +1,555 @@
1
+ from safetensors import safe_open
2
+ import torch
3
+ import torch.nn as nn
4
+ import numpy as np
5
+
6
+ from timm.models.layers import to_2tuple
7
+ from timm.models.vision_transformer import Block
8
+
9
+ # Taken and adapted from Prithvi `geospatial_fm.py`, for the purpose of avoiding MMCV/MMSegmentation dependencies
10
+
11
+ def _convTranspose2dOutput(
12
+ input_size: int,
13
+ stride: int,
14
+ padding: int,
15
+ dilation: int,
16
+ kernel_size: int,
17
+ output_padding: int,
18
+ ):
19
+ """
20
+ Calculate the output size of a ConvTranspose2d.
21
+ Taken from: https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html
22
+ """
23
+ return (
24
+ (input_size - 1) * stride
25
+ - 2 * padding
26
+ + dilation * (kernel_size - 1)
27
+ + output_padding
28
+ + 1
29
+ )
30
+
31
+
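A quick, illustrative check of this formula with the settings used by the necks further down (kernel_size=2, stride=2, no padding): each transposed convolution doubles the spatial size, so four of them upscale a 14x14 patch map to 224x224.

size = 14
for _ in range(4):
    size = _convTranspose2dOutput(size, stride=2, padding=0, dilation=1, kernel_size=2, output_padding=0)
print(size)  # 14 -> 28 -> 56 -> 112 -> 224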
32
+ def get_1d_sincos_pos_embed_from_grid(embed_dim: int, pos: torch.Tensor):
33
+ """
34
+ embed_dim: output dimension for each position
35
+ pos: a list of positions to be encoded: size (M,)
36
+ out: (M, D)
37
+ """
38
+ assert embed_dim % 2 == 0
39
+ omega = np.arange(embed_dim // 2, dtype=np.float32)
40
+ omega /= embed_dim / 2.0
41
+ omega = 1.0 / 10000**omega # (D/2,)
42
+
43
+ pos = pos.reshape(-1) # (M,)
44
+ out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
45
+
46
+ emb_sin = np.sin(out) # (M, D/2)
47
+ emb_cos = np.cos(out) # (M, D/2)
48
+
49
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
50
+ return emb
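A quick shape check (illustrative values):

emb = get_1d_sincos_pos_embed_from_grid(64, np.arange(10))
print(emb.shape)  # (10, 64): 32 sine + 32 cosine features per position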
51
+
52
+
53
+ def get_3d_sincos_pos_embed(embed_dim: int, grid_size: tuple, cls_token: bool = False):
54
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
55
+ # All rights reserved.
56
+
57
+ # This source code is licensed under the license found in the
58
+ # LICENSE file in the root directory of this source tree.
59
+ # --------------------------------------------------------
60
+ # Position embedding utils
61
+ # --------------------------------------------------------
62
+ """
63
+ grid_size: 3d tuple of grid size: t, h, w
64
+ return:
65
+ pos_embed: L, D
66
+ """
67
+
68
+ assert embed_dim % 16 == 0
69
+
70
+ t_size, h_size, w_size = grid_size
71
+
72
+ w_embed_dim = embed_dim // 16 * 6
73
+ h_embed_dim = embed_dim // 16 * 6
74
+ t_embed_dim = embed_dim // 16 * 4
75
+
76
+ w_pos_embed = get_1d_sincos_pos_embed_from_grid(w_embed_dim, np.arange(w_size))
77
+ h_pos_embed = get_1d_sincos_pos_embed_from_grid(h_embed_dim, np.arange(h_size))
78
+ t_pos_embed = get_1d_sincos_pos_embed_from_grid(t_embed_dim, np.arange(t_size))
79
+
80
+ w_pos_embed = np.tile(w_pos_embed, (t_size * h_size, 1))
81
+ h_pos_embed = np.tile(np.repeat(h_pos_embed, w_size, axis=0), (t_size, 1))
82
+ t_pos_embed = np.repeat(t_pos_embed, h_size * w_size, axis=0)
83
+
84
+ pos_embed = np.concatenate((w_pos_embed, h_pos_embed, t_pos_embed), axis=1)
85
+
86
+ if cls_token:
87
+ pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
88
+ return pos_embed
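For the grid shape used by this repo's temporal inputs (3 timesteps of 14x14 patches), and embed_dim=1024 purely as an example:

pos_embed = get_3d_sincos_pos_embed(1024, (3, 14, 14), cls_token=True)
print(pos_embed.shape)  # (3*14*14 + 1, 1024) == (589, 1024), including the cls-token row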
89
+
90
+
91
+ class Norm2d(nn.Module):
92
+ def __init__(self, embed_dim: int):
93
+ super().__init__()
94
+ self.ln = nn.LayerNorm(embed_dim, eps=1e-6)
95
+
96
+ def forward(self, x):
97
+ x = x.permute(0, 2, 3, 1)
98
+ x = self.ln(x)
99
+ x = x.permute(0, 3, 1, 2).contiguous()
100
+ return x
101
+
102
+
103
+ class PatchEmbed(nn.Module):
104
+ """Frames of 2D Images to Patch Embedding
105
+ The 3D version of timm.models.vision_transformer.PatchEmbed
106
+ """
107
+
108
+ def __init__(
109
+ self,
110
+ img_size: int = 224,
111
+ patch_size: int = 16,
112
+ num_frames: int = 3,
113
+ tubelet_size: int = 1,
114
+ in_chans: int = 3,
115
+ embed_dim: int = 768,
116
+ norm_layer: nn.Module = None,
117
+ flatten: bool = True,
118
+ bias: bool = True,
119
+ ):
120
+ super().__init__()
121
+ img_size = to_2tuple(img_size)
122
+ patch_size = to_2tuple(patch_size)
123
+ self.img_size = img_size
124
+ self.patch_size = patch_size
125
+ self.num_frames = num_frames
126
+ self.tubelet_size = tubelet_size
127
+ self.grid_size = (
128
+ num_frames // tubelet_size,
129
+ img_size[0] // patch_size[0],
130
+ img_size[1] // patch_size[1],
131
+ )
132
+ self.num_patches = self.grid_size[0] * self.grid_size[1] * self.grid_size[2]
133
+ self.flatten = flatten
134
+
135
+ self.proj = nn.Conv3d(
136
+ in_chans,
137
+ embed_dim,
138
+ kernel_size=(tubelet_size, patch_size[0], patch_size[1]),
139
+ stride=(tubelet_size, patch_size[0], patch_size[1]),
140
+ bias=bias,
141
+ )
142
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
143
+
144
+ def forward(self, x):
145
+ B, C, T, H, W = x.shape
146
+ assert (
147
+ H == self.img_size[0]
148
+ ), f"Input image height ({H}) doesn't match model ({self.img_size[0]})."
149
+ assert (
150
+ W == self.img_size[1]
151
+ ), f"Input image width ({W}) doesn't match model ({self.img_size[1]})."
152
+ x = self.proj(x)
153
+ Hp, Wp = x.shape[3], x.shape[4]
154
+ if self.flatten:
155
+ x = x.flatten(2).transpose(1, 2) # B,C,T,H,W -> B,C,L -> B,L,C
156
+ x = self.norm(x)
157
+ return x, Hp, Wp
158
+
159
+
160
+ class ConvTransformerTokensToEmbeddingNeck(nn.Module):
161
+ """
162
+ Neck that transforms the token-based output of the transformer into a single embedding suitable for processing with standard layers.
163
+ Performs 4 ConvTranspose2d operations on the rearranged input with kernel_size=2 and stride=2
164
+ """
165
+
166
+ def __init__(
167
+ self,
168
+ embed_dim: int,
169
+ output_embed_dim: int,
170
+ # num_frames: int = 1,
171
+ Hp: int = 14,
172
+ Wp: int = 14,
173
+ drop_cls_token: bool = True,
174
+ ):
175
+ """
176
+
177
+ Args:
178
+ embed_dim (int): Input embedding dimension
179
+ output_embed_dim (int): Output embedding dimension
180
+ Hp (int, optional): Height (in patches) of embedding to be upscaled. Defaults to 14.
181
+ Wp (int, optional): Width (in patches) of embedding to be upscaled. Defaults to 14.
182
+ drop_cls_token (bool, optional): Whether there is a cls_token, which should be dropped. This assumes the cls token is the first token. Defaults to True.
183
+ """
184
+ super().__init__()
185
+ self.drop_cls_token = drop_cls_token
186
+ self.Hp = Hp
187
+ self.Wp = Wp
188
+ self.H_out = Hp
189
+ self.W_out = Wp
190
+ # self.num_frames = num_frames
191
+
192
+ kernel_size = 2
193
+ stride = 2
194
+ dilation = 1
195
+ padding = 0
196
+ output_padding = 0
197
+ for _ in range(4):
198
+ self.H_out = _convTranspose2dOutput(
199
+ self.H_out, stride, padding, dilation, kernel_size, output_padding
200
+ )
201
+ self.W_out = _convTranspose2dOutput(
202
+ self.W_out, stride, padding, dilation, kernel_size, output_padding
203
+ )
204
+
205
+ self.embed_dim = embed_dim
206
+ self.output_embed_dim = output_embed_dim
207
+ self.fpn1 = nn.Sequential(
208
+ nn.ConvTranspose2d(
209
+ self.embed_dim,
210
+ self.output_embed_dim,
211
+ kernel_size=kernel_size,
212
+ stride=stride,
213
+ dilation=dilation,
214
+ padding=padding,
215
+ output_padding=output_padding,
216
+ ),
217
+ Norm2d(self.output_embed_dim),
218
+ nn.GELU(),
219
+ nn.ConvTranspose2d(
220
+ self.output_embed_dim,
221
+ self.output_embed_dim,
222
+ kernel_size=kernel_size,
223
+ stride=stride,
224
+ dilation=dilation,
225
+ padding=padding,
226
+ output_padding=output_padding,
227
+ ),
228
+ )
229
+ self.fpn2 = nn.Sequential(
230
+ nn.ConvTranspose2d(
231
+ self.output_embed_dim,
232
+ self.output_embed_dim,
233
+ kernel_size=kernel_size,
234
+ stride=stride,
235
+ dilation=dilation,
236
+ padding=padding,
237
+ output_padding=output_padding,
238
+ ),
239
+ Norm2d(self.output_embed_dim),
240
+ nn.GELU(),
241
+ nn.ConvTranspose2d(
242
+ self.output_embed_dim,
243
+ self.output_embed_dim,
244
+ kernel_size=kernel_size,
245
+ stride=stride,
246
+ dilation=dilation,
247
+ padding=padding,
248
+ output_padding=output_padding,
249
+ ),
250
+ )
251
+
252
+ def forward(self, x):
253
+ x = x[0]
254
+ if self.drop_cls_token:
255
+ x = x[:, 1:, :]
256
+ x = x.permute(0, 2, 1).reshape(x.shape[0], -1, self.Hp, self.Wp)
257
+
258
+ x = self.fpn1(x)
259
+ x = self.fpn2(x)
260
+
261
+ x = x.reshape((-1, self.output_embed_dim, self.H_out, self.W_out))
262
+ out = tuple([x])
263
+ return out
264
+
265
+ class ConvTransformerTokensToEmbeddingBottleneckNeck(nn.Module):
266
+ """
267
+ Neck that transforms the token-based output of the transformer into a single embedding suitable for processing with standard layers.
268
+ Performs ConvTranspose2d operations with bottleneck layers to reduce channels.
269
+ """
270
+
271
+ def __init__(
272
+ self,
273
+ embed_dim: int,
274
+ output_embed_dim: int,
275
+ Hp: int = 14,
276
+ Wp: int = 14,
277
+ drop_cls_token: bool = True,
278
+ bottleneck_reduction_factor: int = 4,
279
+ ):
280
+ """
281
+ Args:
282
+ embed_dim (int): Input embedding dimension
283
+ output_embed_dim (int): Output embedding dimension
284
+ Hp (int, optional): Height (in patches) of embedding to be upscaled. Defaults to 14.
285
+ Wp (int, optional): Width (in patches) of embedding to be upscaled. Defaults to 14.
286
+ drop_cls_token (bool, optional): Whether there is a cls_token, which should be dropped. Defaults to True.
287
+ bottleneck_reduction_factor (int, optional): Factor by which the channel count is reduced in the bottleneck layers. Defaults to 4.
288
+ """
289
+ super().__init__()
290
+ self.drop_cls_token = drop_cls_token
291
+ self.Hp = Hp
292
+ self.Wp = Wp
293
+ self.H_out = Hp
294
+ self.W_out = Wp
295
+
296
+ kernel_size = 2
297
+ stride = 2
298
+ dilation = 1
299
+ padding = 0
300
+ output_padding = 0
301
+ for _ in range(4):
302
+ self.H_out = _convTranspose2dOutput(
303
+ self.H_out, stride, padding, dilation, kernel_size, output_padding
304
+ )
305
+ self.W_out = _convTranspose2dOutput(
306
+ self.W_out, stride, padding, dilation, kernel_size, output_padding
307
+ )
308
+
309
+ self.embed_dim = embed_dim
310
+ self.output_embed_dim = output_embed_dim
311
+ bottleneck_dim = self.embed_dim // bottleneck_reduction_factor
312
+
313
+ self.fpn1 = nn.Sequential(
314
+ nn.Conv2d(
315
+ self.embed_dim,
316
+ bottleneck_dim,
317
+ kernel_size=1
318
+ ),
319
+ Norm2d(bottleneck_dim),
320
+ nn.GELU(),
321
+ nn.ConvTranspose2d(
322
+ bottleneck_dim,
323
+ bottleneck_dim,
324
+ kernel_size=kernel_size,
325
+ stride=stride,
326
+ padding=padding,
327
+ output_padding=output_padding
328
+ ),
329
+ Norm2d(bottleneck_dim),
330
+ nn.GELU(),
331
+ nn.ConvTranspose2d(
332
+ bottleneck_dim,
333
+ bottleneck_dim,
334
+ kernel_size=kernel_size,
335
+ stride=stride,
336
+ padding=padding,
337
+ output_padding=output_padding
338
+ ),
339
+ Norm2d(bottleneck_dim),
340
+ nn.GELU(),
341
+ nn.Conv2d(
342
+ bottleneck_dim,
343
+ self.output_embed_dim,
344
+ kernel_size=1
345
+ ),
346
+ Norm2d(self.output_embed_dim),
347
+ nn.GELU(),
348
+ )
349
+
350
+ self.fpn2 = nn.Sequential(
351
+ nn.Conv2d(
352
+ self.output_embed_dim,
353
+ bottleneck_dim,
354
+ kernel_size=1
355
+ ),
356
+ Norm2d(bottleneck_dim),
357
+ nn.GELU(),
358
+ nn.ConvTranspose2d(
359
+ bottleneck_dim,
360
+ bottleneck_dim,
361
+ kernel_size=kernel_size,
362
+ stride=stride,
363
+ padding=padding,
364
+ output_padding=output_padding
365
+ ),
366
+ Norm2d(bottleneck_dim),
367
+ nn.GELU(),
368
+ nn.ConvTranspose2d(
369
+ bottleneck_dim,
370
+ bottleneck_dim,
371
+ kernel_size=kernel_size,
372
+ stride=stride,
373
+ padding=padding,
374
+ output_padding=output_padding
375
+ ),
376
+ Norm2d(bottleneck_dim),
377
+ nn.GELU(),
378
+ nn.Conv2d(
379
+ bottleneck_dim,
380
+ self.output_embed_dim,
381
+ kernel_size=1
382
+ ),
383
+ Norm2d(self.output_embed_dim),
384
+ nn.GELU(),
385
+ )
386
+
387
+ def forward(self, x):
388
+ x = x[0]
389
+ if self.drop_cls_token:
390
+ x = x[:, 1:, :]
391
+ x = x.permute(0, 2, 1).reshape(x.shape[0], -1, self.Hp, self.Wp)
392
+
393
+ x = self.fpn1(x)
394
+ x = self.fpn2(x)
395
+
396
+ x = x.reshape((-1, self.output_embed_dim, self.H_out, self.W_out))
397
+ out = tuple([x])
398
+ return out
399
+
400
+ class TemporalViTEncoder(nn.Module):
401
+ """Encoder from a ViT with the capability to take in temporal input.
402
+
403
+ This class defines an encoder taken from a ViT architecture.
404
+ """
405
+
406
+ def __init__(
407
+ self,
408
+ img_size: int = 224,
409
+ patch_size: int = 16,
410
+ num_frames: int = 1,
411
+ tubelet_size: int = 1,
412
+ in_chans: int = 3,
413
+ embed_dim: int = 1024,
414
+ depth: int = 24,
415
+ num_heads: int = 16,
416
+ mlp_ratio: float = 4.0,
417
+ norm_layer: nn.Module = nn.LayerNorm,
418
+ norm_pix_loss: bool = False,
419
+ pretrained: str = None,
420
+ debug=False
421
+ ):
422
+ """
423
+
424
+ Args:
425
+ img_size (int, optional): Input image size. Defaults to 224.
426
+ patch_size (int, optional): Patch size to be used by the transformer. Defaults to 16.
427
+ num_frames (int, optional): Number of frames (temporal dimension) to be input to the encoder. Defaults to 1.
428
+ tubelet_size (int, optional): Tubelet size used in patch embedding. Defaults to 1.
429
+ in_chans (int, optional): Number of input channels. Defaults to 3.
430
+ embed_dim (int, optional): Embedding dimension. Defaults to 1024.
431
+ depth (int, optional): Encoder depth. Defaults to 24.
432
+ num_heads (int, optional): Number of heads used in the encoder blocks. Defaults to 16.
433
+ mlp_ratio (float, optional): Ratio to be used for the size of the MLP in encoder blocks. Defaults to 4.0.
434
+ norm_layer (nn.Module, optional): Norm layer to be used. Defaults to nn.LayerNorm.
435
+ norm_pix_loss (bool, optional): Whether to use Norm Pix Loss. Defaults to False.
436
+ pretrained (str, optional): Path to pretrained encoder weights. Defaults to None.
437
+ """
438
+ super().__init__()
439
+
440
+ # --------------------------------------------------------------------------
441
+ # MAE encoder specifics
442
+ self.embed_dim = embed_dim
443
+ self.patch_embed = PatchEmbed(
444
+ img_size, patch_size, num_frames, tubelet_size, in_chans, embed_dim
445
+ )
446
+ num_patches = self.patch_embed.num_patches
447
+ self.num_frames = num_frames
448
+
449
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
450
+ self.pos_embed = nn.Parameter(
451
+ torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False
452
+ ) # fixed sin-cos embedding
453
+
454
+ self.blocks = nn.ModuleList(
455
+ [
456
+ Block(
457
+ embed_dim,
458
+ num_heads,
459
+ mlp_ratio,
460
+ qkv_bias=True,
461
+ norm_layer=norm_layer,
462
+ )
463
+ for _ in range(depth)
464
+ ]
465
+ )
466
+ self.norm = norm_layer(embed_dim)
467
+
468
+ self.norm_pix_loss = norm_pix_loss
469
+ self.pretrained = pretrained
470
+ self.debug = debug
471
+
472
+ self.initialize_weights()
473
+
474
+ def initialize_weights(self):
475
+ # initialize (and freeze) pos_embed by sin-cos embedding
476
+ pos_embed = get_3d_sincos_pos_embed(
477
+ self.pos_embed.shape[-1], self.patch_embed.grid_size, cls_token=True
478
+ )
479
+ self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
480
+
481
+ # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
482
+ w = self.patch_embed.proj.weight.data
483
+ torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
484
+
485
+ # TODO: FIX huggingface config
486
+ # load pretrained weights
487
+ # if self.pretrained:
488
+ # if self.pretrained.endswith('.safetensors'):
489
+ # self._load_safetensors_weights()
490
+ # elif self.pretrained == 'huggingface':
491
+ # print("TemporalViTEncoder | Using HuggingFace pretrained weights.")
492
+ # else:
493
+ # self._load_pt_weights()
494
+ # else:
495
+ # self.apply(self._init_weights)
496
+
497
+ def _load_safetensors_weights(self):
498
+ with safe_open(self.pretrained, framework='pt', device='cpu') as f:
499
+ checkpoint_state_dict = {k: f.get_tensor(k) for k in f.keys()}  # safe_open exposes keys()/get_tensor(), not items()
500
+ missing_keys, unexpected_keys = self.load_state_dict(checkpoint_state_dict, strict=False)
501
+ if missing_keys:
502
+ print("TemporalViTEncoder | Warning: Missing keys in the state dict:", missing_keys)
503
+ if unexpected_keys:
504
+ print("TemporalViTEncoder | Warning: Unexpected keys in the state dict:", unexpected_keys)
505
+ print(f"TemporalViTEncoder | Loaded pretrained weights from '{self.pretrained}' (safetensors).")
506
+
507
+ def _load_pt_weights(self):
508
+ checkpoint = torch.load(self.pretrained, map_location='cpu')
509
+ checkpoint_state_dict = checkpoint.get('state_dict', checkpoint)
510
+ missing_keys, unexpected_keys = self.load_state_dict(checkpoint_state_dict, strict=False)
511
+ if missing_keys:
512
+ print("TemporalViTEncoder | Warning: Missing keys in the state dict:", missing_keys)
513
+ if unexpected_keys:
514
+ print("TemporalViTEncoder | Warning: Unexpected keys in the state dict:", unexpected_keys)
515
+ print(f"TemporalViTEncoder | Loaded pretrained weights from '{self.pretrained}' (pt file).")
516
+
517
+ def _init_weights(self, m):
518
+ print("TemporalViTEncoder | Newly Initializing weights...")
519
+ if isinstance(m, nn.Linear):
520
+ # we use xavier_uniform following official JAX ViT:
521
+ torch.nn.init.xavier_uniform_(m.weight)
522
+ if isinstance(m, nn.Linear) and m.bias is not None:
523
+ nn.init.constant_(m.bias, 0)
524
+ elif isinstance(m, nn.LayerNorm):
525
+ nn.init.constant_(m.bias, 0)
526
+ nn.init.constant_(m.weight, 1.0)
527
+
528
+ def forward(self, x):
529
+ if self.debug:
530
+ print('TemporalViTEncoder IN:', x.shape)
531
+
532
+ # embed patches
533
+ x, _, _ = self.patch_embed(x)
534
+
535
+ if self.debug:
536
+ print('TemporalViTEncoder EMBED:', x.shape)
537
+
538
+ # add pos embed w/o cls token
539
+ x = x + self.pos_embed[:, 1:, :]
540
+
541
+ # append cls token
542
+ cls_token = self.cls_token + self.pos_embed[:, :1, :]
543
+ cls_tokens = cls_token.expand(x.shape[0], -1, -1)
544
+ x = torch.cat((cls_tokens, x), dim=1)
545
+
546
+ # apply Transformer blocks
547
+ for blk in self.blocks:
548
+ x = blk(x)
549
+
550
+ x = self.norm(x)
551
+
552
+ if self.debug:
553
+ print('TemporalViTEncoder OUT:', x.shape)
554
+
555
+ return tuple([x])
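A minimal shape walk-through of the two modules above (hyperparameters are illustrative, not the configuration of the released Messis checkpoint):

import torch

encoder = TemporalViTEncoder(img_size=224, patch_size=16, num_frames=3, tubelet_size=1,
                             in_chans=6, embed_dim=768, depth=2, num_heads=8)
neck = ConvTransformerTokensToEmbeddingNeck(embed_dim=768 * 3, output_embed_dim=128, Hp=14, Wp=14)

x = torch.randn(1, 6, 3, 224, 224)   # (B, C, T, H, W): 6 bands, 3 timesteps
tokens = encoder(x)                   # tuple holding (1, 1 + 3*14*14, 768) tokens
features = neck(tokens)[0]            # (1, 128, 224, 224) after four 2x upsamplings
print(tokens[0].shape, features.shape)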
pages/1_Select_Location.py ADDED
@@ -0,0 +1,78 @@
1
+ import streamlit as st
2
+ from streamlit_folium import st_folium
3
+ import folium
4
+ from geopy.geocoders import Nominatim
5
+
6
+ # Define the bounding box
7
+ ZUERICH_BBOX = [8.364, 47.240, 9.0405, 47.69894]
8
+
9
+ def within_bbox(lat, lon, bbox):
10
+ """Check if a point is within the given bounding box."""
11
+ return bbox[1] <= lat <= bbox[3] and bbox[0] <= lon <= bbox[2]
12
+
13
+ def select_coordinates():
14
+ st.title("Step 1: Select Location")
15
+
16
+ instructions = """
17
+ 1. Choose a crop classification location. Search for a location or click on the map.
18
+ 2. Proceed to the "Perform Crop Classification" step.
19
+
20
+ _Note:_ The location must be within the green ZüriCrop area.
21
+ """
22
+ st.sidebar.header("Instructions")
23
+ st.sidebar.markdown(instructions)
24
+
25
+ # Initialize a map centered around the midpoint of the bounding box
26
+ midpoint_lat = (ZUERICH_BBOX[1] + ZUERICH_BBOX[3]) / 2
27
+ midpoint_lon = (ZUERICH_BBOX[0] + ZUERICH_BBOX[2]) / 2
28
+ m = folium.Map(location=[midpoint_lat, midpoint_lon], zoom_start=9)
29
+
30
+ # Add the bounding box to the map as a rectangle
31
+ folium.Rectangle(
32
+ bounds=[[ZUERICH_BBOX[1], ZUERICH_BBOX[0]], [ZUERICH_BBOX[3], ZUERICH_BBOX[2]]],
33
+ color="green",
34
+ fill=True,
35
+ fill_opacity=0.1
36
+ ).add_to(m)
37
+
38
+ # Search for a location
39
+ geolocator = Nominatim(user_agent="streamlit-app")
40
+ location_query = st.text_input("Search for a location:")
41
+
42
+ if location_query:
43
+ location = geolocator.geocode(location_query)
44
+ if location:
45
+ lat, lon = location.latitude, location.longitude
46
+ folium.Marker([lat, lon], tooltip=location.address).add_to(m)
47
+ m.location = [lat, lon]
48
+ m.zoom_start = 12
49
+
50
+ if within_bbox(lat, lon, ZUERICH_BBOX):
51
+ st.success(f"Location found: {location.address}. It is within the bounding box.")
52
+ st.session_state["selected_location"] = (lat, lon)
53
+ else:
54
+ st.error(f"Location found: {location.address}. It is outside the bounding box.")
55
+ else:
56
+ st.error("Location not found. Please try again.")
57
+
58
+ # Add a click event listener to capture coordinates
59
+ m.add_child(folium.LatLngPopup())
60
+
61
+ # Display the map using streamlit-folium
62
+ st_data = st_folium(m, height=500, width=800)
63
+
64
+ # Check if the user clicked within the bounding box
65
+ if st_data["last_clicked"]:
66
+ lat, lon = st_data["last_clicked"]["lat"], st_data["last_clicked"]["lng"]
67
+ if within_bbox(lat, lon, ZUERICH_BBOX):
68
+ st.success(f"Selected Location: Latitude {lat}, Longitude {lon}")
69
+ st.session_state["selected_location"] = (lat, lon)
70
+ else:
71
+ st.error(f"Selected Location is outside the allowed area. Please select a location within the bounding box.")
72
+
73
+ # Proceed to the next step
74
+ link_disabled = "selected_location" not in st.session_state
75
+ st.sidebar.page_link("pages/2_Perform_Crop_Classification.py", label="Proceed to Crop Classification", icon="🌾", disabled=link_disabled)
76
+
77
+ if __name__ == "__main__":
78
+ select_coordinates()
pages/2_Perform_Crop_Classification.py ADDED
@@ -0,0 +1,99 @@
1
+ import streamlit as st
2
+ import leafmap.foliumap as leafmap
3
+ from transformers import PretrainedConfig
4
+ from folium import Icon
5
+
6
+ from messis.messis import Messis
7
+ from inference import perform_inference
8
+
9
+ st.set_page_config(layout="wide")
10
+
11
+ GEOTIFF_PATH = "./data/stacked_features.tif"
12
+
13
+ # Load the model
14
+ @st.cache_resource
15
+ def load_model():
16
+ config = PretrainedConfig.from_pretrained('crop-classification/messis', revision='47d9ca4')
17
+ model = Messis.from_pretrained('crop-classification/messis', cache_dir='./hf_cache/', revision='47d9ca4')
18
+ return model, config
19
+ model, config = load_model()
20
+
21
+ def perform_inference_step():
22
+ st.title("Step 2: Perform Crop Classification")
23
+
24
+ if "selected_location" not in st.session_state:
25
+ st.error("No location selected. Please select a location first.")
26
+ st.page_link("pages/1_Select_Location.py", label="Select Location", icon="📍")
27
+ return
28
+
29
+ lat, lon = st.session_state["selected_location"]
30
+
31
+ # Sidebar
32
+ st.sidebar.header("Settings")
33
+
34
+ # Timestep Slider
35
+ timestep = st.sidebar.slider("Select Timestep", 1, 9, 5)
36
+
37
+ # Band Dropdown
38
+ band_options = {
39
+ "RGB": [1, 2, 3], # Adjust indices based on the actual bands in your GeoTIFF
40
+ "NIR": [4],
41
+ "SWIR1": [5],
42
+ "SWIR2": [6]
43
+ }
44
+ vmin_vmax = {
45
+ "RGB": (89, 1878),
46
+ "NIR": (165, 5468),
47
+ "SWIR1": (120, 3361),
48
+ "SWIR2": (94, 2700)
49
+ }
50
+ selected_band = st.sidebar.selectbox("Select Satellite Band to Display", options=list(band_options.keys()), index=0)
51
+
52
+ # Calculate the band indices based on the selected timestep
53
+ selected_bands = [band + (timestep - 1) * 6 for band in band_options[selected_band]]
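(For example, timestep 5 with the "RGB" option [1, 2, 3] resolves to bands [25, 26, 27], since the stacked GeoTIFF holds 6 bands per timestep.)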
54
+
55
+ instructions = """
56
+ Click the button "Perform Crop Classification".
57
+
58
+ _Note:_
59
+ - Messis will classify the crop types for the fields in your selected location.
60
+ - Hover over the fields to see the predicted and true crop type.
61
+ - The satellite images might take a few seconds to load.
62
+ """
63
+ st.sidebar.header("Instructions")
64
+ st.sidebar.markdown(instructions)
65
+
66
+ # Initialize the map
67
+ m = leafmap.Map(center=(lat, lon), zoom=10, draw_control=False)
68
+
69
+ # Perform inference
70
+ if st.sidebar.button("Perform Crop Classification", type="primary"):
71
+ predictions = perform_inference(lon, lat, model, config, debug=True)
72
+
73
+ m.add_data(predictions,
74
+ layer_name = "Predictions",
75
+ column="Correct",
76
+ add_legend=False,
77
+ style_function=lambda x: {"fillColor": "green" if x["properties"]["Correct"] else "red", "color": "black", "weight": 0, "fillOpacity": 0.25},
78
+ )
79
+ st.success("Inference completed!")
80
+
81
+ # GeoTIFF Satellite Imagery with selected timestep and band
82
+ m.add_raster(
83
+ GEOTIFF_PATH,
84
+ layer_name="Satellite Image",
85
+ bands=selected_bands,
86
+ fit_bounds=True,
87
+ vmin=vmin_vmax[selected_band][0],
88
+ vmax=vmin_vmax[selected_band][1],
89
+ )
90
+
91
+ # Show the POI on the map
92
+ poi_icon = Icon(color="green", prefix="fa", icon="crosshairs")
93
+ m.add_marker(location=[lat, lon], popup="Selected Location", layer_name="POI", icon=poi_icon)
94
+
95
+ # Display the map in the Streamlit app
96
+ m.to_streamlit()
97
+
98
+ if __name__ == "__main__":
99
+ perform_inference_step()
requirements.txt ADDED
@@ -0,0 +1,23 @@
1
+ torch==2.3.0
2
+ PyYAML==6.0.1
3
+ rasterio==1.3.10
4
+ torchvision==0.18.0
5
+ shapely==2.0.4
6
+ geopandas==0.14.4
7
+ pytorch-lightning==2.2.3
8
+ dvc==3.50.1
9
+ streamlit==1.37.0
10
+ leafmap==0.36.6
11
+ transformers==4.41.2
12
+ folium==0.17.0
13
+ streamlit-folium==0.22.0
14
+ geopy==2.4.1
15
+ localtileserver==0.10.3
16
+ xarray==2024.7.0
17
+ scipy==1.14.0
18
+ mapclassify==2.8.0
19
+ wandb==0.16.6
20
+ numpy==1.26.4
21
+ lion-pytorch==0.2.2
22
+ timm==0.9.16
23
+ pyproj